From 21a14088dd198e93f139f0d374321a243cd67553 Mon Sep 17 00:00:00 2001 From: Stein Magnus Jodal Date: Thu, 14 Nov 2013 21:52:36 +0100 Subject: [PATCH 0001/1539] Fix reference to 'websockets.protocols' in plural --- websockets/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 381ea4ac6..de9f843e4 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -1,5 +1,5 @@ """ -The :mod:`websockets.protocols` module handles WebSocket control and data +The :mod:`websockets.protocol` module handles WebSocket control and data frames as specified in `sections 4 to 8 of RFC 6455`_. .. _sections 4 to 8 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 From 53f61fa9c4c3ba10789a1de86aa7a69b8f623414 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 22 Nov 2013 18:48:35 +0100 Subject: [PATCH 0002/1539] Don't rely on utf-8 being the default system encoding. Fix #7. --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 07afe6cbc..4719ef520 100644 --- a/setup.py +++ b/setup.py @@ -14,10 +14,10 @@ description = "An implementation of the WebSocket Protocol (RFC 6455)" -with open(os.path.join(root, 'README')) as f: +with open(os.path.join(root, 'README'), encoding='utf-8') as f: long_description = '\n\n'.join(f.read().split('\n\n')[1:]) -with open(os.path.join(root, 'websockets', 'version.py')) as f: +with open(os.path.join(root, 'websockets', 'version.py'), encoding='utf-8') as f: exec(f.read()) py_version = sys.version_info[:2] From a13be1f1f2f1af137d1ab256997b1a75079cd59c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 29 Nov 2013 17:07:08 +0100 Subject: [PATCH 0003/1539] Support ping cancellation. Fix #8. --- websockets/protocol.py | 3 ++- websockets/test_protocol.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index de9f843e4..c474a3761 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -280,7 +280,8 @@ def read_data_frame(self): ping_id = None while ping_id != frame.data: ping_id, waiter = self.pings.popitem(0) - waiter.set_result(None) + if not waiter.cancelled(): + waiter.set_result(None) # 5.6. Data Frames else: return frame diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index f44cef236..4e66cc8f1 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -185,6 +185,15 @@ def test_acknowledge_previous_pings(self): self.assertTrue(pings[1][0].done()) self.assertFalse(pings[2][0].done()) + def test_cancel_ping(self): + ping = self.protocol.ping() + ping_frame = self.loop.run_until_complete(self.sent()) + ping.cancel() + pong_frame = Frame(True, OP_PONG, ping_frame.data) + self.feed(pong_frame) + self.process_control_frames() + self.assertTrue(ping.cancelled()) + def test_duplicate_ping(self): self.protocol.ping(b'foobar') self.assertFrameSent(True, OP_PING, b'foobar') From 2659246d4e50cc61fdad57f185440609a3b13ce3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 6 Feb 2014 22:25:55 +0100 Subject: [PATCH 0004/1539] Account for minor API changes in asyncio. --- websockets/framing.py | 6 +++--- websockets/protocol.py | 5 ++--- websockets/test_framing.py | 2 +- websockets/test_protocol.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index fa5a21eb5..d65e220b5 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -92,10 +92,10 @@ def read_frame(reader, mask): @asyncio.coroutine def read_bytes(reader, n): # Undocumented utility function. - data = yield from reader(n) - if len(data) != n: + try: + return (yield from reader(n)) + except asyncio.IncompleteReadError: raise WebSocketProtocolError("Unexpected EOF") - return data def write_frame(frame, writer, mask): diff --git a/websockets/protocol.py b/websockets/protocol.py index c474a3761..f1d5d66c5 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -10,12 +10,11 @@ import codecs import collections import logging -import queue import random import struct import asyncio -from asyncio.queues import Queue +from asyncio.queues import Queue, QueueEmpty from .exceptions import InvalidState, WebSocketProtocolError from .framing import * @@ -138,7 +137,7 @@ def recv(self): # Return any available message try: return self.messages.get_nowait() - except queue.Empty: + except QueueEmpty: pass # Wait for a message until the connection is closed diff --git a/websockets/test_framing.py b/websockets/test_framing.py index e352e40ab..2a323e059 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -12,12 +12,12 @@ class FramingTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.stream = asyncio.StreamReader() def tearDown(self): self.loop.close() def decode(self, message, mask=False): + self.stream = asyncio.StreamReader() self.stream.feed_data(message) self.stream.feed_eof() reader = self.stream.readexactly diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 4e66cc8f1..da13832da 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -48,7 +48,7 @@ def sent(self): stream.feed_data(data) self.transport.write.call_args_list = [] stream.feed_eof() - if stream._byte_count: + if not stream.at_eof(): return read_frame(stream.readexactly, self.protocol.is_client) @asyncio.coroutine From 181c8a392c795da19de4f9a8462320ff990a316a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Feb 2014 20:32:18 +0100 Subject: [PATCH 0005/1539] AutoBahn no longer reports incorrect failures. --- compliance/README.rst | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index c0c310c0c..736643771 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -30,6 +30,4 @@ Conformance notes ----------------- Test cases 6.4.2, 6.4.3, and 6.4.4 are actually more strict than the RFC. -Given its implementation, ``websockets`` should get a "Non-Strict", but due to -a bug in the test suite runner, it gets a "Fail". For more information see -issues 1, 3, 9, and 14 on https://github.com/tavendo/AutobahnTestSuite/issues. +Given its implementation, ``websockets`` gets a "Non-Strict". From 2bf6a14464a97ffeabad8890b11bc403bd59c91c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 6 Feb 2014 23:30:46 +0100 Subject: [PATCH 0006/1539] Take advantage of StreamReaderProtocol. This is a step towards supporting flow control. Refactor connection termination to better follow the specification: close the TCP connection properly with shutdown() then close(). --- websockets/client.py | 6 +-- websockets/protocol.py | 42 ++++++++++------- websockets/server.py | 10 ++-- websockets/test_protocol.py | 94 ++++++++++++++++++++++++------------- 4 files changed, 94 insertions(+), 58 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 496ded4cc..a2d7c02bc 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -15,7 +15,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ - Complete WebSocket client implementation as a Tulip protocol. + Complete WebSocket client implementation as an asyncio protocol. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -41,11 +41,11 @@ def handshake(self, uri): key = build_request(set_header) request.append('\r\n') request = '\r\n'.join(request).encode() - self.transport.write(request) + self.writer.write(request) # Read handshake response. try: - status_code, headers = yield from read_response(self.stream) + status_code, headers = yield from read_response(self.reader) except Exception as exc: raise InvalidHandshake("Malformed HTTP message") from exc if status_code != 101: diff --git a/websockets/protocol.py b/websockets/protocol.py index f1d5d66c5..0ca5dcaa5 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -class WebSocketCommonProtocol(asyncio.Protocol): +class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): """ This class implements common parts of the WebSocket protocol. @@ -54,8 +54,9 @@ class WebSocketCommonProtocol(asyncio.Protocol): is_client = False state = 'OPEN' - def __init__(self, timeout=10): + def __init__(self, timeout=10, loop=None): self.timeout = timeout + super().__init__(asyncio.StreamReader(), self.client_connected, loop) self.close_code = None self.close_reason = '' @@ -288,7 +289,7 @@ def read_data_frame(self): @asyncio.coroutine def read_frame(self): is_masked = not self.is_client - frame = yield from read_frame(self.stream.readexactly, is_masked) + frame = yield from read_frame(self.reader.readexactly, is_masked) side = 'client' if self.is_client else 'server' logger.debug("%s << %s", side, frame) return frame @@ -302,7 +303,7 @@ def write_frame(self, opcode, data=b'', expected_state='OPEN'): side = 'client' if self.is_client else 'server' logger.debug("%s >> %s", side, frame) is_masked = self.is_client - write_frame(frame, self.transport.write, is_masked) + write_frame(frame, self.writer.write, is_masked) @asyncio.coroutine def close_connection(self): @@ -322,8 +323,18 @@ def close_connection(self): except (asyncio.CancelledError, asyncio.TimeoutError): pass - if self.state != 'CLOSED': - self.transport.close() + if self.state == 'CLOSED': + return + + assert self.writer.can_write_eof(), "WebSocket runs over TCP/IP!" + self.writer.write_eof() + self.writer.close() + + try: + yield from asyncio.wait_for(self.connection_closed, + timeout=self.timeout) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass @asyncio.coroutine def fail_connection(self, code=1011, reason=''): @@ -340,22 +351,17 @@ def fail_connection(self, code=1011, reason=''): self.closing_handshake.set_result(False) yield from self.close_connection() - # Tulip Protocol methods - - def connection_made(self, transport): - self.transport = transport - self.stream = asyncio.StreamReader() - - def data_received(self, data): - self.stream.feed_data(data) + # asyncio StreamReaderProtocol methods - def eof_received(self): - self.stream.feed_eof() - self.transport.close() + def client_connected(self, reader, writer): + self.reader = reader + self.writer = writer def connection_lost(self, exc): # 7.1.4. The WebSocket Connection is Closed self.state = 'CLOSED' - self.connection_closed.set_result(None) + if not self.connection_closed.done(): + self.connection_closed.set_result(None) if self.close_code is None: self.close_code = 1006 + super().connection_lost(exc) diff --git a/websockets/server.py b/websockets/server.py index f998173a0..1824e9bfc 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -19,7 +19,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ - Complete WebSocket server implementation as a Tulip protocol. + Complete WebSocket server implementation as an asyncio protocol. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -45,7 +45,7 @@ def handler(self): uri = yield from self.handshake() except Exception as exc: logger.info("Exception in opening handshake: {}".format(exc)) - self.transport.close() + self.writer.close() return try: @@ -59,7 +59,7 @@ def handler(self): yield from self.close() except Exception as exc: logger.info("Exception in closing handshake: {}".format(exc)) - self.transport.close() + self.writer.close() return @asyncio.coroutine @@ -71,7 +71,7 @@ def handshake(self): """ # Read handshake request. try: - uri, headers = yield from read_request(self.stream) + uri, headers = yield from read_request(self.reader) except Exception as exc: raise InvalidHandshake("Malformed HTTP message") from exc get_header = lambda k: headers.get(k, '') @@ -85,7 +85,7 @@ def handshake(self): build_response(set_header, key) response.append('\r\n') response = '\r\n'.join(response).encode() - self.transport.write(response) + self.writer.write(response) self.state = 'OPEN' self.opening_handshake.set_result(True) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index da13832da..7970c443c 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -32,14 +32,6 @@ def feed(self, frame): mask = not self.protocol.is_client write_frame(frame, self.protocol.data_received, mask) - def feed_eof(self): - """Feed end-of-file to the protocol.""" - # The transport is mocked so this sequence will happen synchronously: - # proto.eof_received() -> transport.close() -> proto.connection_lost() - # To allow processing frames before shutting down the connection, - # delay self.feed_eof() with self.loop.call_later(). - self.protocol.eof_received() - @asyncio.coroutine def sent(self): """Read the next frame sent to the transport.""" @@ -59,11 +51,8 @@ def echo(self): @asyncio.coroutine def fast_connection_failure(self): """Ensure the connection failure terminates quickly.""" - sent = yield from self.sent() - if sent and sent.opcode == OP_CLOSE: - self.feed(sent) - if self.protocol.is_client: - self.feed_eof() + self.protocol.eof_received() + self.protocol.connection_lost(None) def process_control_frames(self): """Process control frames fed to the protocol.""" @@ -85,11 +74,11 @@ def assertConnectionClosed(self, code, message): def test_open(self): self.assertTrue(self.protocol.open) - self.feed_eof() + self.protocol.connection_lost(None) self.assertFalse(self.protocol.open) def test_connection_lost(self): - self.feed_eof() + self.protocol.connection_lost(None) self.assertConnectionClosed(1006, '') def test_recv_text(self): @@ -126,7 +115,8 @@ def read_message(): self.assertConnectionClosed(1011, '') def test_recv_on_closed_connection(self): - self.feed_eof() + self.protocol.eof_received() + self.protocol.connection_lost(None) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) def test_send_text(self): @@ -143,7 +133,8 @@ def test_send_type_error(self): self.assertNoFrameSent() def test_send_on_closed_connection(self): - self.feed_eof() + self.protocol.eof_received() + self.protocol.connection_lost(None) with self.assertRaises(InvalidState): self.protocol.send('foobar') self.assertNoFrameSent() @@ -239,9 +230,10 @@ def test_close_handshake_in_fragmented_text(self): def test_connection_close_in_fragmented_text(self): self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.loop.call_later(MS, self.feed_eof) + self.loop.call_later(MS, self.protocol.eof_received) + self.loop.call_later(2 * MS, lambda: self.protocol.connection_lost(None)) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) - self.assertConnectionClosed(1006, '') + self.assertConnectionClosed(1002, '') class ServerTests(CommonTests, unittest.TestCase): @@ -262,10 +254,10 @@ def test_client_close(self): # non standard client-initiated close self.loop.call_later(MS, self.feed, frame) # The server is waiting for some data at this point, and won't get it. self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) - # After recv() returns None the connection is closed. + # After recv() returns None, the connection is closed. self.assertConnectionClosed(1000, 'because.') self.assertFrameSent(*frame) - # The server may call close() later without any effect. + # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() @@ -287,7 +279,7 @@ def test_close_drops_frames(self): # Only one frame is emitted, and it's consumed by self.echo(). self.assertNoFrameSent() - def test_close_timeout(self): + def test_close_handshake_timeout(self): self.after = asyncio.Future() self.loop.call_later(2 * MS, self.after.cancel) self.before = asyncio.Future() @@ -299,13 +291,30 @@ def test_close_timeout(self): self.assertFalse(self.before.cancelled()) self.before.cancel() + def test_close_timeout_before_connection_lost(self): + # Prevent the connection from terminating. + self.protocol.connection_lost = unittest.mock.Mock() + + self.after = asyncio.Future() + self.loop.call_later(4 * MS, self.after.cancel) + self.before = asyncio.Future() + self.loop.call_later(8 * MS, self.before.cancel) + self.protocol.timeout = 5 * MS + self.loop.call_later(MS, asyncio.async, self.echo()) + self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.assertEqual(self.protocol.state, 'CLOSING') + self.assertTrue(self.after.cancelled()) + self.assertFalse(self.before.cancelled()) + self.before.cancel() + def test_close_protocol_error(self): self.loop.call_later(MS, self.feed, Frame(True, OP_CLOSE, b'\x00')) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1002, '') def test_close_connection_lost(self): - self.loop.call_later(MS, self.feed_eof) + self.loop.call_later(MS, self.protocol.eof_received) + self.loop.call_later(2 * MS, lambda: self.protocol.connection_lost(None)) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1002, '') @@ -335,20 +344,22 @@ def setUp(self): def test_close(self): # standard server-initiated close frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) self.loop.call_later(MS, self.feed, frame) - self.loop.call_later(2 * MS, self.feed_eof) + self.loop.call_later(2 * MS, self.protocol.eof_received) + self.loop.call_later(3 * MS, lambda: self.protocol.connection_lost(None)) # The client is waiting for some data at this point, and won't get it. self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) - # After recv() returns None the connection is closed. + # After recv() returns None, the connection is closed. self.assertConnectionClosed(1000, 'because.') self.assertFrameSent(*frame) - # The client may call close() later without any effect. + # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close('oh noes!')) self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() def test_client_close(self): # non standard client-initiated close self.loop.call_later(MS, asyncio.async, self.echo()) - self.loop.call_later(2 * MS, self.feed_eof) + self.loop.call_later(2 * MS, self.protocol.eof_received) + self.loop.call_later(3 * MS, lambda: self.protocol.connection_lost(None)) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') # Only one frame is emitted, and it's consumed by self.echo(). @@ -362,22 +373,41 @@ def test_simultaneous_close(self): # non standard close from both sides server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) self.loop.call_later(MS, self.feed, server_close) - self.loop.call_later(2 * MS, self.feed_eof) + self.loop.call_later(2 * MS, self.protocol.eof_received) + self.loop.call_later(3 * MS, lambda: self.protocol.connection_lost(None)) self.loop.run_until_complete(self.protocol.close(reason='client')) self.assertConnectionClosed(1000, 'server') self.assertFrameSent(*client_close) self.assertNoFrameSent() - def test_connection_close_timeout(self): - # If the server doesn't drop the connection quickly, the client will. + def test_close_timeout_before_eof_received(self): self.after = asyncio.Future() - self.loop.call_later(2 * MS, self.after.cancel) + self.loop.call_later(4 * MS, self.after.cancel) self.before = asyncio.Future() - self.loop.call_later(10 * MS, self.before.cancel) + self.loop.call_later(8 * MS, self.before.cancel) self.protocol.timeout = 5 * MS self.loop.call_later(MS, asyncio.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) + # If the server doesn't drop the connection quickly, the client will. self.assertConnectionClosed(1000, 'because.') self.assertTrue(self.after.cancelled()) self.assertFalse(self.before.cancelled()) self.before.cancel() + + def test_close_timeout_before_connection_lost(self): + # Prevent the connection from terminating. + self.protocol.connection_lost = unittest.mock.Mock() + + self.after = asyncio.Future() + self.loop.call_later(9 * MS, self.after.cancel) + self.before = asyncio.Future() + self.loop.call_later(13 * MS, self.before.cancel) + self.protocol.timeout = 5 * MS + self.loop.call_later(MS, asyncio.async, self.echo()) + self.loop.call_later(2 * MS, self.protocol.eof_received) + self.loop.run_until_complete(self.protocol.close(reason='because.')) + # If the server doesn't drop the connection quickly, the client will. + self.assertEqual(self.protocol.state, 'CLOSING') + self.assertTrue(self.after.cancelled()) + self.assertFalse(self.before.cancelled()) + self.before.cancel() From b198151cc6a8e35d34fcc751d02534973855f672 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Feb 2014 20:05:55 +0100 Subject: [PATCH 0007/1539] Add flow control. Fix #10. Invoke drain() after each write. --- compliance/test_client.py | 2 +- compliance/test_server.py | 2 +- docs/index.rst | 17 +++++++++++++++++ example/client.py | 2 +- example/server.py | 2 +- websockets/protocol.py | 29 +++++++++++++++++++---------- websockets/test_client_server.py | 8 ++++---- websockets/test_protocol.py | 19 ++++++++++--------- 8 files changed, 54 insertions(+), 27 deletions(-) diff --git a/compliance/test_client.py b/compliance/test_client.py index 5f6701351..2f02e3e6f 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -21,7 +21,7 @@ class EchoClientProtocol(websockets.WebSocketClientProtocol): def read_message(self): msg = yield from super(EchoClientProtocol, self).read_message() if msg is not None: - self.send(msg) + yield from self.send(msg) return msg diff --git a/compliance/test_server.py b/compliance/test_server.py index adc6295d4..b974d9b76 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -16,7 +16,7 @@ class EchoServerProtocol(websockets.WebSocketServerProtocol): def read_message(self): msg = yield from super(EchoServerProtocol, self).read_message() if msg is not None: - self.send(msg) + yield from self.send(msg) return msg diff --git a/docs/index.rst b/docs/index.rst index 054d1f782..710d827ff 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -145,6 +145,23 @@ Utilities .. automodule:: websockets.http :members: +Changelog +--------- + +2.0 +... + +* Backwards-incompatible API change: + :meth:`~websockets.protocol.WebSocketCommonProtocol.send`, + :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` are coroutines. + They used to be regular functions. +* Add flow control. + +1.0 +... + +* Initial public release. Limitations ----------- diff --git a/example/client.py b/example/client.py index f5e9395c5..6bc960227 100755 --- a/example/client.py +++ b/example/client.py @@ -7,7 +7,7 @@ def hello(): websocket = yield from websockets.connect('ws://localhost:8765/') name = input("What's your name? ") - websocket.send(name) + yield from websocket.send(name) print("> {}".format(name)) greeting = yield from websocket.recv() print("< {}".format(greeting)) diff --git a/example/server.py b/example/server.py index 9df53bff3..c76f0bae0 100755 --- a/example/server.py +++ b/example/server.py @@ -9,7 +9,7 @@ def hello(websocket, uri): print("< {}".format(name)) greeting = "Hello {}!".format(name) print("> {}".format(greeting)) - websocket.send(greeting) + yield from websocket.send(greeting) start_server = websockets.serve(hello, 'localhost', 8765) diff --git a/websockets/protocol.py b/websockets/protocol.py index 0ca5dcaa5..0b3a087cc 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -109,7 +109,7 @@ def close(self, code=1000, reason=''): if self.state == 'OPEN': # 7.1.2. Start the WebSocket Closing Handshake self.close_code, self.close_reason = code, reason - self.write_frame(OP_CLOSE, serialize_close(code, reason)) + yield from self.write_frame(OP_CLOSE, serialize_close(code, reason)) # 7.1.3. The WebSocket Closing Handshake is Started self.state = 'CLOSING' @@ -149,9 +149,10 @@ def recv(self): if next_message in done: return next_message.result() + @asyncio.coroutine def send(self, data): """ - This function sends a message. + This coroutine sends a message. It sends a :class:`str` as a text frame and :class:`bytes` as a binary frame. @@ -166,11 +167,12 @@ def send(self, data): opcode = 2 else: raise TypeError("data must be bytes or str") - self.write_frame(opcode, data) + yield from self.write_frame(opcode, data) + @asyncio.coroutine def ping(self, data=None): """ - This function sends a ping. + This coroutine sends a ping. It returns a Future which will be completed when the corresponding pong is received and which you may ignore if you don't want to wait. @@ -185,16 +187,17 @@ def ping(self, data=None): data = struct.pack('!I', random.getrandbits(32)) self.pings[data] = asyncio.Future() - self.write_frame(OP_PING, data) + yield from self.write_frame(OP_PING, data) return self.pings[data] + @asyncio.coroutine def pong(self, data=b''): """ - This function sends a pong. + This coroutine sends a pong. An unsolicited pong may serve as a unidirectional heartbeat. """ - self.write_frame(OP_PONG, data) + yield from self.write_frame(OP_PONG, data) # Private methods - no guarantees. @@ -267,12 +270,12 @@ def read_data_frame(self): if self.state != 'CLOSING': # 7.1.3. The WebSocket Closing Handshake is Started self.state = 'CLOSING' - self.write_frame(OP_CLOSE, frame.data, 'CLOSING') + yield from self.write_frame(OP_CLOSE, frame.data, 'CLOSING') self.closing_handshake.set_result(True) return elif frame.opcode == OP_PING: # Answer pings. - self.pong(frame.data) + yield from self.pong(frame.data) elif frame.opcode == OP_PONG: # Do not acknowledge pings on unsolicited pongs. if frame.data in self.pings: @@ -294,6 +297,7 @@ def read_frame(self): logger.debug("%s << %s", side, frame) return frame + @asyncio.coroutine def write_frame(self, opcode, data=b'', expected_state='OPEN'): # This may happen if a user attempts to write on a closed connection. if self.state != expected_state: @@ -304,6 +308,11 @@ def write_frame(self, opcode, data=b'', expected_state='OPEN'): logger.debug("%s >> %s", side, frame) is_masked = self.is_client write_frame(frame, self.writer.write, is_masked) + # Handle flow control automatically. + try: + yield from self.writer.drain() + except ConnectionResetError: + pass @asyncio.coroutine def close_connection(self): @@ -345,7 +354,7 @@ def fail_connection(self, code=1011, reason=''): # 7.1.7. Fail the WebSocket Connection logger.info("Failing the WebSocket connection: %d %s", code, reason) if self.state == 'OPEN': - self.write_frame(OP_CLOSE, serialize_close(code, reason)) + yield from self.write_frame(OP_CLOSE, serialize_close(code, reason)) self.state = 'CLOSING' if not self.closing_handshake.done(): self.closing_handshake.set_result(False) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 77c144101..eea8b02fc 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -11,7 +11,7 @@ @asyncio.coroutine def echo(ws, uri): - ws.send((yield from ws.recv())) + yield from ws.send((yield from ws.recv())) class ClientServerTests(unittest.TestCase): @@ -42,7 +42,7 @@ def stop_server(self): def test_basic(self): self.start_client() - self.client.send("Hello!") + self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") self.stop_client() @@ -95,7 +95,7 @@ def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") self.start_client() - self.client.send("Hello!") + self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, None) self.stop_client() @@ -108,7 +108,7 @@ def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") self.start_client() - self.client.send("Hello!") + self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") self.stop_client() diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 7970c443c..4c9a40a85 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -19,6 +19,7 @@ def setUp(self): asyncio.set_event_loop(self.loop) self.protocol = WebSocketCommonProtocol() self.transport = unittest.mock.Mock() + self.transport._conn_lost = 0 # checked by drain() self.transport.close = unittest.mock.Mock( side_effect=lambda: self.protocol.connection_lost(None)) self.protocol.connection_made(self.transport) @@ -120,23 +121,23 @@ def test_recv_on_closed_connection(self): self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) def test_send_text(self): - self.protocol.send('café') + self.loop.run_until_complete(self.protocol.send('café')) self.assertFrameSent(True, OP_TEXT, 'café'.encode('utf-8')) def test_send_binary(self): - self.protocol.send(b'tea') + self.loop.run_until_complete(self.protocol.send(b'tea')) self.assertFrameSent(True, OP_BINARY, b'tea') def test_send_type_error(self): with self.assertRaises(TypeError): - self.protocol.send(42) + self.loop.run_until_complete(self.protocol.send(42)) self.assertNoFrameSent() def test_send_on_closed_connection(self): self.protocol.eof_received() self.protocol.connection_lost(None) with self.assertRaises(InvalidState): - self.protocol.send('foobar') + self.loop.run_until_complete(self.protocol.send('foobar')) self.assertNoFrameSent() def test_answer_ping(self): @@ -150,7 +151,7 @@ def test_ignore_pong(self): self.assertNoFrameSent() def test_acknowledge_ping(self): - ping = self.protocol.ping() + ping = self.loop.run_until_complete(self.protocol.ping()) self.assertFalse(ping.done()) ping_frame = self.loop.run_until_complete(self.sent()) pong_frame = Frame(True, OP_PONG, ping_frame.data) @@ -160,7 +161,7 @@ def test_acknowledge_ping(self): def test_acknowledge_previous_pings(self): pings = [( - self.protocol.ping(), + self.loop.run_until_complete(self.protocol.ping()), self.loop.run_until_complete(self.sent()), ) for i in range(3)] # Unsolicited pong doesn't acknowledge pings @@ -177,7 +178,7 @@ def test_acknowledge_previous_pings(self): self.assertFalse(pings[2][0].done()) def test_cancel_ping(self): - ping = self.protocol.ping() + ping = self.loop.run_until_complete(self.protocol.ping()) ping_frame = self.loop.run_until_complete(self.sent()) ping.cancel() pong_frame = Frame(True, OP_PONG, ping_frame.data) @@ -186,10 +187,10 @@ def test_cancel_ping(self): self.assertTrue(ping.cancelled()) def test_duplicate_ping(self): - self.protocol.ping(b'foobar') + self.loop.run_until_complete(self.protocol.ping(b'foobar')) self.assertFrameSent(True, OP_PING, b'foobar') with self.assertRaises(ValueError): - self.protocol.ping(b'foobar') + self.loop.run_until_complete(self.protocol.ping(b'foobar')) self.assertNoFrameSent() def test_fragmented_text(self): From 024c4a5c98e5f17805bae271c76eeb63252c1ae8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2014 11:14:34 +0100 Subject: [PATCH 0008/1539] Don't rely on `coverage` being in $PATH. --- Makefile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index bee2d4e5e..6f014b953 100644 --- a/Makefile +++ b/Makefile @@ -2,9 +2,9 @@ test: python -m unittest coverage: - coverage erase - coverage run --branch --source=websockets -m unittest - coverage html --omit='websockets/test_*.py' + python -m coverage erase + python -m coverage run --branch --source=websockets -m unittest + python -m coverage html --omit='websockets/test_*.py' clean: find . -name '*.pyc' -delete From c9587401f3ba5a890d3d75078c367bbc0b6ac4a3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2014 19:36:02 +0100 Subject: [PATCH 0009/1539] Fix ResourceWarnings under Python 3.4. --- websockets/client.py | 3 ++- websockets/server.py | 2 ++ websockets/test_client_server.py | 4 ++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/websockets/client.py b/websockets/client.py index a2d7c02bc..1b5d90d53 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -88,7 +88,8 @@ def connect(uri, *, try: yield from protocol.handshake(uri) except Exception: - transport.close() + protocol.writer.write_eof() + protocol.writer.close() raise return protocol diff --git a/websockets/server.py b/websockets/server.py index 1824e9bfc..204dc169d 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -45,6 +45,7 @@ def handler(self): uri = yield from self.handshake() except Exception as exc: logger.info("Exception in opening handshake: {}".format(exc)) + self.writer.write_eof() self.writer.close() return @@ -59,6 +60,7 @@ def handler(self): yield from self.close() except Exception as exc: logger.info("Exception in closing handshake: {}".format(exc)) + self.writer.write_eof() self.writer.close() return diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index eea8b02fc..7a08ab45d 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -61,6 +61,10 @@ def test_client_receives_malformed_response(self, _read_response): with self.assertRaises(InvalidHandshake): self.start_client() + # Now the server believes the connection is open. Run the event loop + # once to make it notice the connection was closed. Interesting hack. + yield from tulip.sleep(0) + @patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(set_header): From d03ba5669bd386f8e873a620ff6a5a6926b340b4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2014 19:38:18 +0100 Subject: [PATCH 0010/1539] Bump version number. --- LICENSE | 2 +- docs/conf.py | 6 +++--- websockets/version.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/LICENSE b/LICENSE index aa75fbe2b..60ab1d99f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013 Aymeric Augustin. +Copyright (c) 2013-2014 Aymeric Augustin. All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/docs/conf.py b/docs/conf.py index c91f69a35..d00ce37d8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,16 +41,16 @@ # General information about the project. project = u'websockets' -copyright = u'2013, Aymeric Augustin' +copyright = u'2013-2014, Aymeric Augustin' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.0' +version = '2.0' # The full version, including alpha/beta/rc tags. -release = '1.0.0' +release = '2.0.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 294b9514d..fc2e06f51 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '1.0' +version = '2.0' From f714c3fd95bc407a0c59fc496909bf17dc4e9d31 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Apr 2014 23:01:18 +0200 Subject: [PATCH 0011/1539] Be more consistent in ReST markup. --- websockets/client.py | 4 ++-- websockets/http.py | 12 ++++++------ websockets/protocol.py | 2 +- websockets/server.py | 10 +++++----- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 1b5d90d53..02194f2f3 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -63,8 +63,8 @@ def connect(uri, *, """ This coroutine connects to a WebSocket server. - It's a thin wrapper around the event loop's ``create_connection`` method. - Extra keyword arguments are passed to ``create_server``. + It's a thin wrapper around the event loop's `create_connection` method. + Extra keyword arguments are passed to `create_server`. It returns a :class:`~websockets.client.WebSocketClientProtocol` which can then be used to send and receive messages. diff --git a/websockets/http.py b/websockets/http.py index 9c9d2d368..a50be8fc3 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -32,8 +32,8 @@ def read_request(stream): """ Read an HTTP/1.1 request from `stream`. - Return `(uri, headers)` where `uri` is a :class:`str` and `headers` - is a :class:`~email.message.Message`; `uri` isn't URL-decoded. + Return `(uri, headers)` where `uri` is a :class:`str` and `headers` is a + :class:`~email.message.Message`; `uri` isn't URL-decoded. Raise an exception if the request isn't well formatted. @@ -53,8 +53,8 @@ def read_response(stream): """ Read an HTTP/1.1 response from `stream`. - Return `(status, headers)` where `status` is a :class:`int` and - `headers` is a :class:`~email.message.Message`. + Return `(status, headers)` where `status` is a :class:`int` and `headers` + is a :class:`~email.message.Message`. Raise an exception if the request isn't well formatted. @@ -72,8 +72,8 @@ def read_message(stream): """ Read an HTTP message from `stream`. - Return `(start_line, headers)` where `start_line` is :class:`bytes` - and `headers` is a :class:`~email.message.Message`. + Return `(start_line, headers)` where `start_line` is :class:`bytes` and + `headers` is a :class:`~email.message.Message`. The message is assumed not to contain a body. """ diff --git a/websockets/protocol.py b/websockets/protocol.py index 0b3a087cc..1397a796c 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -101,7 +101,7 @@ def close(self, code=1000, reason=''): It waits for the other end to complete the handshake. It doesn't do anything once the connection is closed. - It's usually safe to wrap this coroutine in ``asyncio.async()`` since + It's usually safe to wrap this coroutine in `asyncio.async()` since errors during connection termination aren't particularly useful. The `code` must be an :class:`int` and the `reason` a :class:`str`. diff --git a/websockets/server.py b/websockets/server.py index 204dc169d..cdf3abb4e 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -101,16 +101,16 @@ def serve(ws_handler, host=None, port=None, *, """ This coroutine creates a WebSocket server. - It's a thin wrapper around the event loop's ``create_server`` method. - ``host``, ``port`` as well as extra keyword arguments are passed to - ``create_server``. + It's a thin wrapper around the event loop's `create_server` method. + `host`, `port` as well as extra keyword arguments are passed to + `create_server`. - It returns a ``Server`` object with a ``close`` method to stop the server. + It returns a `Server` object with a `close` method to stop the server. `ws_handler` is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`~websockets.server.WebSocketServerProtocol` and the request URI. The `host` and `port` arguments and other keyword - arguments are passed to ``create_server``. + arguments are passed to `create_server`. Whenever a client connects, the server accepts the connection, creates a :class:`~websockets.server.WebSocketServerProtocol`, performs the opening From bfac179d417b44e3678fa4a6fb1ab67cbb3cbe1a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Apr 2014 23:33:12 +0200 Subject: [PATCH 0012/1539] Implement Origin checking. --- docs/index.rst | 7 +++++ websockets/client.py | 12 ++++++-- websockets/server.py | 27 ++++++++++++------ websockets/test_client_server.py | 47 ++++++++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 12 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 710d827ff..aa7e07dd9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -148,6 +148,13 @@ Utilities Changelog --------- +2.1 +... + +* Added support for providing and checking Origin_. + +.. _Origin: https://tools.ietf.org/html/rfc6455#section-10.2 + 2.0 ... diff --git a/websockets/client.py b/websockets/client.py index 02194f2f3..974c2778b 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -25,9 +25,11 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): state = 'CONNECTING' @asyncio.coroutine - def handshake(self, uri): + def handshake(self, uri, origin=None): """ Perform the client side of the opening handshake. + + If provided, ``origin`` sets the HTTP Origin header. """ # Send handshake request. Since the uri and the headers only contain # ASCII characters, we can keep this simple. @@ -37,6 +39,8 @@ def handshake(self, uri): set_header('Host', uri.host) else: set_header('Host', '{}:{}'.format(uri.host, uri.port)) + if origin is not None: + set_header('Origin', origin) set_header('User-Agent', USER_AGENT) key = build_request(set_header) request.append('\r\n') @@ -59,10 +63,12 @@ def handshake(self, uri): @asyncio.coroutine def connect(uri, *, - klass=WebSocketClientProtocol, **kwds): + klass=WebSocketClientProtocol, origin=None, **kwds): """ This coroutine connects to a WebSocket server. + It accepts an ``origin`` keyword argument to set the Origin HTTP header. + It's a thin wrapper around the event loop's `create_connection` method. Extra keyword arguments are passed to `create_server`. @@ -86,7 +92,7 @@ def connect(uri, *, klass, uri.host, uri.port, **kwds) try: - yield from protocol.handshake(uri) + yield from protocol.handshake(uri, origin=origin) except Exception: protocol.writer.write_eof() protocol.writer.close() diff --git a/websockets/server.py b/websockets/server.py index cdf3abb4e..44d14c24e 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -31,8 +31,9 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): state = 'CONNECTING' - def __init__(self, ws_handler=None, **kwargs): + def __init__(self, ws_handler=None, *, origins=None, **kwargs): self.ws_handler = ws_handler + self.origins = origins super().__init__(**kwargs) def connection_made(self, transport): @@ -42,7 +43,7 @@ def connection_made(self, transport): @asyncio.coroutine def handler(self): try: - uri = yield from self.handshake() + uri = yield from self.handshake(origins=self.origins) except Exception as exc: logger.info("Exception in opening handshake: {}".format(exc)) self.writer.write_eof() @@ -65,10 +66,13 @@ def handler(self): return @asyncio.coroutine - def handshake(self): + def handshake(self, origins=None): """ Perform the server side of the opening handshake. + If provided, ``origins`` is a list of acceptable HTTP Origin values. + Include ``''`` in the list if the lack of an origin is acceptable. + Return the URI of the request. """ # Read handshake request. @@ -79,6 +83,12 @@ def handshake(self): get_header = lambda k: headers.get(k, '') key = check_request(get_header) + # Check origin in request. + if origins is not None: + origin = get_header('Origin') + if not set(origin.split() or ('',))<= set(origins): + raise InvalidHandshake("Bad origin: {}".format(origin)) + # Send handshake response. Since the headers only contain ASCII # characters, we can keep this simple. response = ['HTTP/1.1 101 Switching Protocols'] @@ -97,7 +107,7 @@ def handshake(self): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - klass=WebSocketServerProtocol, **kwds): + klass=WebSocketServerProtocol, origins=None, **kwds): """ This coroutine creates a WebSocket server. @@ -105,12 +115,11 @@ def serve(ws_handler, host=None, port=None, *, `host`, `port` as well as extra keyword arguments are passed to `create_server`. - It returns a `Server` object with a `close` method to stop the server. - `ws_handler` is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`~websockets.server.WebSocketServerProtocol` and - the request URI. The `host` and `port` arguments and other keyword - arguments are passed to `create_server`. + the request URI. `origin` is a list of acceptable Origin HTTP headers. + + It returns a `Server` object with a `close` method to stop the server. Whenever a client connects, the server accepts the connection, creates a :class:`~websockets.server.WebSocketServerProtocol`, performs the opening @@ -119,4 +128,4 @@ def serve(ws_handler, host=None, port=None, *, connection. """ return (yield from asyncio.get_event_loop().create_server( - lambda: klass(ws_handler), host, port, **kwds)) + lambda: klass(ws_handler, origins=origins), host, port, **kwds)) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 7a08ab45d..8faf345c3 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -119,3 +119,50 @@ def test_server_close_crashes(self, close): # Connection ends with a protocol error. self.assertEqual(self.client.close_code, 1002) + + +class ClientServerOriginTests(unittest.TestCase): + + def test_checking_origin_succeeds(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + server = loop.run_until_complete( + serve(echo, 'localhost', 8642, origins=['http://localhost'])) + client = loop.run_until_complete( + connect('ws://localhost:8642/', origin='http://localhost')) + + loop.run_until_complete(client.send("Hello!")) + self.assertEqual(loop.run_until_complete(client.recv()), "Hello!") + + server.close() + loop.run_until_complete(server.wait_closed()) + loop.run_until_complete(client.worker) + + def test_checking_origin_fails(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + server = loop.run_until_complete( + serve(echo, 'localhost', 8642, origins=['http://localhost'])) + with self.assertRaises(InvalidHandshake): + loop.run_until_complete( + connect('ws://localhost:8642/', origin='http://otherhost')) + + server.close() + loop.run_until_complete(server.wait_closed()) + + def test_checking_lack_of_origin_succeeds(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + server = loop.run_until_complete( + serve(echo, 'localhost', 8642, origins=[''])) + client = loop.run_until_complete(connect('ws://localhost:8642/')) + + loop.run_until_complete(client.send("Hello!")) + self.assertEqual(loop.run_until_complete(client.recv()), "Hello!") + + server.close() + loop.run_until_complete(server.wait_closed()) + loop.run_until_complete(client.worker) From 13bf3354736f219e739dc78e5f5c0cc2ab3058c7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Apr 2014 09:58:49 +0200 Subject: [PATCH 0013/1539] Fix a test that wasn't running. --- websockets/test_client_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 8faf345c3..38f1751c3 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -63,7 +63,7 @@ def test_client_receives_malformed_response(self, _read_response): # Now the server believes the connection is open. Run the event loop # once to make it notice the connection was closed. Interesting hack. - yield from tulip.sleep(0) + self.loop.run_until_complete(asyncio.sleep(0)) @patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): From 69103a2ca1f8b586c0db2fcbb20f3ad9d1e39eb5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Apr 2014 10:00:26 +0200 Subject: [PATCH 0014/1539] Add HTTP responses on handshake errors. --- websockets/server.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 44d14c24e..19b8aae39 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -25,8 +25,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): :class:`~websockets.protocol.WebSocketCommonProtocol`. For the sake of simplicity, this protocol doesn't inherit a proper HTTP - implementation, and it doesn't send appropriate HTTP responses when - something goes wrong. + implementation. Its support for HTTP responses is very limited. """ state = 'CONNECTING' @@ -46,6 +45,12 @@ def handler(self): uri = yield from self.handshake(origins=self.origins) except Exception as exc: logger.info("Exception in opening handshake: {}".format(exc)) + if isinstance(exc, InvalidHandshake): + response = 'HTTP/1.1 400 Bad Request\r\n\r\n' + str(exc) + else: + response = ('HTTP/1.1 500 Internal Server Error\r\n\r\n' + 'See server log for more information.') + self.writer.write(response.encode()) self.writer.write_eof() self.writer.close() return @@ -80,6 +85,7 @@ def handshake(self, origins=None): uri, headers = yield from read_request(self.reader) except Exception as exc: raise InvalidHandshake("Malformed HTTP message") from exc + get_header = lambda k: headers.get(k, '') key = check_request(get_header) From 53086cf0eabb3503671609edd00985b2c275d22e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Apr 2014 09:28:45 +0200 Subject: [PATCH 0015/1539] Disambiguate URI and path. Technically, handlers receive a "resource_name" or "full_path" ie. a path + a query string, but these names are too long. Refs #20. Thanks Ian Kelly for the report. --- compliance/test_server.py | 2 +- example/server.py | 2 +- websockets/http.py | 8 ++++---- websockets/server.py | 8 ++++---- websockets/test_client_server.py | 2 +- websockets/test_http.py | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/compliance/test_server.py b/compliance/test_server.py index b974d9b76..ecdce46d4 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -21,7 +21,7 @@ def read_message(self): @asyncio.coroutine -def noop(ws, uri): +def noop(ws, path): yield from ws.worker diff --git a/example/server.py b/example/server.py index c76f0bae0..958697318 100755 --- a/example/server.py +++ b/example/server.py @@ -4,7 +4,7 @@ import websockets @asyncio.coroutine -def hello(websocket, uri): +def hello(websocket, path): name = yield from websocket.recv() print("< {}".format(name)) greeting = "Hello {}!".format(name) diff --git a/websockets/http.py b/websockets/http.py index a50be8fc3..6262bf017 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -32,20 +32,20 @@ def read_request(stream): """ Read an HTTP/1.1 request from `stream`. - Return `(uri, headers)` where `uri` is a :class:`str` and `headers` is a - :class:`~email.message.Message`; `uri` isn't URL-decoded. + Return `(path, headers)` where `path` is a :class:`str` and `headers` is a + :class:`~email.message.Message`; `path` isn't URL-decoded. Raise an exception if the request isn't well formatted. The request is assumed not to contain a body. """ request_line, headers = yield from read_message(stream) - method, uri, version = request_line[:-2].decode().split(None, 2) + method, path, version = request_line[:-2].decode().split(None, 2) if method != 'GET': raise ValueError("Unsupported method") if version != 'HTTP/1.1': raise ValueError("Unsupported HTTP version") - return uri, headers + return path, headers @asyncio.coroutine diff --git a/websockets/server.py b/websockets/server.py index 19b8aae39..a5d00951b 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -42,7 +42,7 @@ def connection_made(self, transport): @asyncio.coroutine def handler(self): try: - uri = yield from self.handshake(origins=self.origins) + path = yield from self.handshake(origins=self.origins) except Exception as exc: logger.info("Exception in opening handshake: {}".format(exc)) if isinstance(exc, InvalidHandshake): @@ -56,7 +56,7 @@ def handler(self): return try: - yield from self.ws_handler(self, uri) + yield from self.ws_handler(self, path) except Exception: logger.info("Exception in connection handler", exc_info=True) yield from self.fail_connection(1011) @@ -82,7 +82,7 @@ def handshake(self, origins=None): """ # Read handshake request. try: - uri, headers = yield from read_request(self.reader) + path, headers = yield from read_request(self.reader) except Exception as exc: raise InvalidHandshake("Malformed HTTP message") from exc @@ -108,7 +108,7 @@ def handshake(self, origins=None): self.state = 'OPEN' self.opening_handshake.set_result(True) - return uri + return path @asyncio.coroutine diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 38f1751c3..287733905 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -10,7 +10,7 @@ @asyncio.coroutine -def echo(ws, uri): +def echo(ws, path): yield from ws.send((yield from ws.recv())) diff --git a/websockets/test_http.py b/websockets/test_http.py index 072246012..2caf891e2 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -31,8 +31,8 @@ def test_read_request(self): b'Sec-WebSocket-Version: 13\r\n' b'\r\n' ) - uri, hdrs = self.loop.run_until_complete(read_request(self.stream)) - self.assertEqual(uri, '/chat') + path, hdrs = self.loop.run_until_complete(read_request(self.stream)) + self.assertEqual(path, '/chat') self.assertEqual(hdrs['Upgrade'], 'websocket') def test_read_response(self): From a47fbd5c9cc24320f3d65b1a83fd74900281de56 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Apr 2014 09:37:11 +0200 Subject: [PATCH 0016/1539] Disambiguate URI strings from WebSocketURI tuples. --- websockets/client.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 974c2778b..6fd8c4c7c 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -25,20 +25,20 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): state = 'CONNECTING' @asyncio.coroutine - def handshake(self, uri, origin=None): + def handshake(self, wsuri, origin=None): """ Perform the client side of the opening handshake. If provided, ``origin`` sets the HTTP Origin header. """ - # Send handshake request. Since the uri and the headers only contain + # Send handshake request. Since the URI and the headers only contain # ASCII characters, we can keep this simple. - request = ['GET %s HTTP/1.1' % uri.resource_name] + request = ['GET %s HTTP/1.1' % wsuri.resource_name] set_header = lambda k, v: request.append('{}: {}'.format(k, v)) - if uri.port == (443 if uri.secure else 80): # pragma: no cover - set_header('Host', uri.host) + if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover + set_header('Host', wsuri.host) else: - set_header('Host', '{}:{}'.format(uri.host, uri.port)) + set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port)) if origin is not None: set_header('Origin', origin) set_header('User-Agent', USER_AGENT) @@ -86,13 +86,13 @@ def connect(uri, *, Connection" in RFC 6455, except for the requirement that "there MUST be no more than one connection in a CONNECTING state." """ - uri = parse_uri(uri) - kwds.setdefault('ssl', uri.secure) + wsuri = parse_uri(uri) + kwds.setdefault('ssl', wsuri.secure) transport, protocol = yield from asyncio.get_event_loop().create_connection( - klass, uri.host, uri.port, **kwds) + klass, wsuri.host, wsuri.port, **kwds) try: - yield from protocol.handshake(uri, origin=origin) + yield from protocol.handshake(wsuri, origin=origin) except Exception: protocol.writer.write_eof() protocol.writer.close() From 429308623165d3e8772c75d38d7e8ce5b191748b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Apr 2014 09:44:19 +0200 Subject: [PATCH 0017/1539] Use the new super. --- compliance/test_client.py | 2 +- compliance/test_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/compliance/test_client.py b/compliance/test_client.py index 2f02e3e6f..5c75006a1 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -19,7 +19,7 @@ class EchoClientProtocol(websockets.WebSocketClientProtocol): @asyncio.coroutine def read_message(self): - msg = yield from super(EchoClientProtocol, self).read_message() + msg = yield from super().read_message() if msg is not None: yield from self.send(msg) return msg diff --git a/compliance/test_server.py b/compliance/test_server.py index ecdce46d4..3df861f53 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -14,7 +14,7 @@ class EchoServerProtocol(websockets.WebSocketServerProtocol): @asyncio.coroutine def read_message(self): - msg = yield from super(EchoServerProtocol, self).read_message() + msg = yield from super().read_message() if msg is not None: yield from self.send(msg) return msg From d744d63ceb45e7de8b1e5a2e39ed9fd39d26fbe6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 15 Apr 2014 23:42:47 +0200 Subject: [PATCH 0018/1539] Add tests for SSL connections. --- websockets/test_client_server.py | 24 ++++++++++++++++++++++++ websockets/testcert.pem | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 websockets/testcert.pem diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 287733905..5f1028ace 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1,3 +1,5 @@ +import os +import ssl import unittest from unittest.mock import patch @@ -9,6 +11,9 @@ from .server import * +testcert = os.path.join(os.path.dirname(__file__), 'testcert.pem') + + @asyncio.coroutine def echo(ws, path): yield from ws.send((yield from ws.recv())) @@ -121,6 +126,25 @@ def test_server_close_crashes(self, close): self.assertEqual(self.client.close_code, 1002) +@unittest.skipUnless(os.path.exists(testcert), "test certificate is missing") +class SSLClientServerTests(ClientServerTests): + + def start_server(self): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ssl_context.load_cert_chain(testcert) + + server = serve(echo, 'localhost', 8642, ssl=ssl_context) + self.server = self.loop.run_until_complete(server) + + def start_client(self): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ssl_context.load_verify_locations(testcert) + ssl_context.verify_mode = ssl.CERT_REQUIRED + + client = connect('ws://localhost:8642/', ssl=ssl_context) + self.client = self.loop.run_until_complete(client) + + class ClientServerOriginTests(unittest.TestCase): def test_checking_origin_succeeds(self): diff --git a/websockets/testcert.pem b/websockets/testcert.pem new file mode 100644 index 000000000..1ed5fd458 --- /dev/null +++ b/websockets/testcert.pem @@ -0,0 +1,32 @@ +-----BEGIN PRIVATE KEY----- +MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBANSBDRjLau8ur0s1 +WNVJdpa1x6PMdistb9VU9lBqxJzu8sgWnuzvy1Nt+1lCl6j6QtQxma99bPjbcZ9S +rXJUwtBLq067Zy01VQ/lpBfjqRZShYUVimg4We9KB5DFvWzP52L8Oj0U3sm46mek +vcddtJQz6WwbPiROOSvF80W206fNAgMBAAECgYAfSKBU9h1X+Nd1ivT48Ue0CC7L +vl3nHVlJXqikThODxumW6z2aQ/L65UYLbfJFvhH4ixTE8QIJ4MRpYBKIslG7c3DX +cX6MP6KPaUjxSbjB9RlS9VdKbovxxeecbWzfSY+Cz/alyg++J0iOwbJVGL+RlaJw +g8hQM+UWyJLN764/QQJBAP/NeBHChjU7QyA36lv2Lm/lUpkYy3Zy4ZTGPyiuBjLC +SNqF1PMxrvuHHL05NaE6R02VFXztxJf2ci1rZKDG2N8CQQDUqwdsWZFlmTA5hqTB +mEYw3feCij3t4sy0KDV1wV851WJRbVrzrbxN+rHL5MKwd3qcxs1TXCfF1A9qbPXS +phjTAkBtd/KgNwzUDu5lBUjH3gx1WkAEwHWh1PvwfP5eXErOwhIHYiqFgIePoHyO +BcOLobMN4nT1p5LwLUkjYsgHfdElAkBgbBL3izyjBeuZiXSV2gapDVq1MxyVCOmr +HTfv5fbY7+id5qkAJttjt7B5M4UaIXHUN0bM7tGRnm5G4JQsJ+bFAkAQ/pYfrC9l +2hXI29YTSYTsw4iDjgJF6RAxw2108M8KybSJdyvQ43N4U40BQx8BRQmxZwSyG5QX +s+j9Cb63orCr +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICijCCAfOgAwIBAgIJANEJitrPxb96MA0GCSqGSIb3DQEBBQUAMF0xCzAJBgNV +BAYTAkZSMQ8wDQYDVQQIDAZGcmFuY2UxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQK +DBBBeW1lcmljIEF1Z3VzdGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMTQwNDE1 +MjEzMjI5WhgPMjExNDA0MTYyMTMyMjlaMF0xCzAJBgNVBAYTAkZSMQ8wDQYDVQQI +DAZGcmFuY2UxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQKDBBBeW1lcmljIEF1Z3Vz +dGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJ +AoGBANSBDRjLau8ur0s1WNVJdpa1x6PMdistb9VU9lBqxJzu8sgWnuzvy1Nt+1lC +l6j6QtQxma99bPjbcZ9SrXJUwtBLq067Zy01VQ/lpBfjqRZShYUVimg4We9KB5DF +vWzP52L8Oj0U3sm46mekvcddtJQz6WwbPiROOSvF80W206fNAgMBAAGjUDBOMB0G +A1UdDgQWBBRcFzeirOD3zMnjCptlc0sh9VWZJjAfBgNVHSMEGDAWgBRcFzeirOD3 +zMnjCptlc0sh9VWZJjAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAFyv +MGP9hnrMbDnwRtCYX/g99nvxjc5KXJyDw91Vo3hmHjdVRXY/oJbjiUtOBf1OsgoN +rv7KsaMb9+060K+uDtQIIiwPcxF1nQOZDtv6Nyzj8hwM2XFl+XiVgUD2pg++scWF +PDfbpmeEDQnUMEqHETM7JTMLB349/s5UUQqsSBE0 +-----END CERTIFICATE----- From ad5cb55dfe97d1b107422ca8bd1e7a725ac3b331 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 15 Apr 2014 23:50:56 +0200 Subject: [PATCH 0019/1539] Incorrect assertion when closing SSL connections. Fix #22. --- websockets/client.py | 1 - websockets/protocol.py | 4 ++-- websockets/server.py | 2 -- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 6fd8c4c7c..bdfbe9e4e 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -94,7 +94,6 @@ def connect(uri, *, try: yield from protocol.handshake(wsuri, origin=origin) except Exception: - protocol.writer.write_eof() protocol.writer.close() raise diff --git a/websockets/protocol.py b/websockets/protocol.py index 1397a796c..33bf76d5c 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -335,8 +335,8 @@ def close_connection(self): if self.state == 'CLOSED': return - assert self.writer.can_write_eof(), "WebSocket runs over TCP/IP!" - self.writer.write_eof() + if self.writer.can_write_eof(): + self.writer.write_eof() self.writer.close() try: diff --git a/websockets/server.py b/websockets/server.py index a5d00951b..7e6cda0a4 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -51,7 +51,6 @@ def handler(self): response = ('HTTP/1.1 500 Internal Server Error\r\n\r\n' 'See server log for more information.') self.writer.write(response.encode()) - self.writer.write_eof() self.writer.close() return @@ -66,7 +65,6 @@ def handler(self): yield from self.close() except Exception as exc: logger.info("Exception in closing handshake: {}".format(exc)) - self.writer.write_eof() self.writer.close() return From 18a809ca6ac546e4a3c0e5ca3540f80bdf42afb9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Apr 2014 00:11:16 +0200 Subject: [PATCH 0020/1539] Fix some resource warnings in tests. --- websockets/test_client_server.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 5f1028ace..f28d9d768 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -99,6 +99,10 @@ def wrong_read_response(stream): with self.assertRaises(InvalidHandshake): self.start_client() + # Now the server believes the connection is open. Run the event loop + # once to make it notice the connection was closed. Interesting hack. + self.loop.run_until_complete(asyncio.sleep(0)) + @patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") @@ -162,6 +166,7 @@ def test_checking_origin_succeeds(self): server.close() loop.run_until_complete(server.wait_closed()) loop.run_until_complete(client.worker) + loop.close() def test_checking_origin_fails(self): loop = asyncio.new_event_loop() @@ -175,6 +180,7 @@ def test_checking_origin_fails(self): server.close() loop.run_until_complete(server.wait_closed()) + loop.close() def test_checking_lack_of_origin_succeeds(self): loop = asyncio.new_event_loop() @@ -190,3 +196,4 @@ def test_checking_lack_of_origin_succeeds(self): server.close() loop.run_until_complete(server.wait_closed()) loop.run_until_complete(client.worker) + loop.close() From 8d0ebc241c3d318b698e45c64523ee10eba61b68 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 13:59:24 +0200 Subject: [PATCH 0021/1539] Make error handling more robust in the server. Fix #23. --- websockets/server.py | 55 ++++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 7e6cda0a4..b266815e3 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -41,32 +41,43 @@ def connection_made(self, transport): @asyncio.coroutine def handler(self): + # Since this method doesn't have a caller able to handle exceptions, + # it attemps to log relevant ones and close the connection properly. try: - path = yield from self.handshake(origins=self.origins) - except Exception as exc: - logger.info("Exception in opening handshake: {}".format(exc)) - if isinstance(exc, InvalidHandshake): - response = 'HTTP/1.1 400 Bad Request\r\n\r\n' + str(exc) - else: - response = ('HTTP/1.1 500 Internal Server Error\r\n\r\n' - 'See server log for more information.') - self.writer.write(response.encode()) - self.writer.close() - return - try: - yield from self.ws_handler(self, path) + try: + path = yield from self.handshake(origins=self.origins) + except Exception as exc: + logger.info("Exception in opening handshake: {}".format(exc)) + if isinstance(exc, InvalidHandshake): + response = 'HTTP/1.1 400 Bad Request\r\n\r\n' + str(exc) + else: + response = ('HTTP/1.1 500 Internal Server Error\r\n\r\n' + 'See server log for more information.') + self.writer.write(response.encode()) + raise + + try: + yield from self.ws_handler(self, path) + except Exception: + logger.info("Exception in connection handler", exc_info=True) + yield from self.fail_connection(1011) + raise + + try: + yield from self.close() + except Exception as exc: + logger.info("Exception in closing handshake: {}".format(exc)) + raise + except Exception: - logger.info("Exception in connection handler", exc_info=True) - yield from self.fail_connection(1011) - return + # Last-ditch attempt to avoid leaking connections on errors. + try: + self.writer.close() + except Exception: # pragma: no cover + pass + - try: - yield from self.close() - except Exception as exc: - logger.info("Exception in closing handshake: {}".format(exc)) - self.writer.close() - return @asyncio.coroutine def handshake(self, origins=None): From 598f9bac94494a5fd7d757fa440ebdb837949374 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 14:42:16 +0200 Subject: [PATCH 0022/1539] Add host, port and secure attributes to protocols. Refs #20. --- websockets/client.py | 5 +-- websockets/protocol.py | 11 +++++- websockets/server.py | 5 ++- websockets/test_client_server.py | 58 +++++++++++++++++++++++++------- 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index bdfbe9e4e..af5b4f101 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -87,9 +87,10 @@ def connect(uri, *, more than one connection in a CONNECTING state." """ wsuri = parse_uri(uri) - kwds.setdefault('ssl', wsuri.secure) + kwds.setdefault('ssl', True) + factory = lambda: klass(host=wsuri.host, port=wsuri.port, secure=wsuri.secure) transport, protocol = yield from asyncio.get_event_loop().create_connection( - klass, wsuri.host, wsuri.port, **kwds) + factory, wsuri.host, wsuri.port, **kwds) try: yield from protocol.handshake(wsuri, origin=origin) diff --git a/websockets/protocol.py b/websockets/protocol.py index 33bf76d5c..210bd8b85 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -37,6 +37,9 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): control frames automatically. It sends outgoing data frames and performs the closing handshake. + The `host`, `port` and `secure` parameters are simply stored as attributes + for handlers that need them. + The `timeout` parameter defines the maximum wait time in seconds for completing the closing handshake and, only on the client side, for terminating the TCP connection. :meth:`close()` will complete in at most @@ -54,8 +57,14 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): is_client = False state = 'OPEN' - def __init__(self, timeout=10, loop=None): + def __init__(self, *, + host=None, port=None, secure=None, timeout=10, loop=None): + self.host = host + self.port = port + self.secure = secure + self.timeout = timeout + super().__init__(asyncio.StreamReader(), self.client_connected, loop) self.close_code = None diff --git a/websockets/server.py b/websockets/server.py index b266815e3..e121603c4 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -142,5 +142,8 @@ def serve(ws_handler, host=None, port=None, *, completes, the server performs the closing handshake and closes the connection. """ + secure = kwds.get('ssl') is not None + factory = lambda: klass(ws_handler, + host=host, port=port, secure=secure, origins=origins) return (yield from asyncio.get_event_loop().create_server( - lambda: klass(ws_handler, origins=origins), host, port, **kwds)) + factory, host, port, **kwds)) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index f28d9d768..d660d9eef 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -15,8 +15,11 @@ @asyncio.coroutine -def echo(ws, path): - yield from ws.send((yield from ws.recv())) +def handler(ws, path): + if path == '/attributes': + yield from ws.send(repr((ws.host, ws.port, ws.secure))) + else: + yield from ws.send((yield from ws.recv())) class ClientServerTests(unittest.TestCase): @@ -31,7 +34,7 @@ def tearDown(self): self.loop.close() def start_server(self): - server = serve(echo, 'localhost', 8642) + server = serve(handler, 'localhost', 8642) self.server = self.loop.run_until_complete(server) def start_client(self): @@ -52,6 +55,18 @@ def test_basic(self): self.assertEqual(reply, "Hello!") self.stop_client() + def test_protocol_attributes(self): + client = connect('ws://localhost:8642/attributes') + client = self.loop.run_until_complete(client) + try: + expected_attrs = repr(('localhost', 8642, False)) + client_attrs = repr((client.host, client.port, client.secure)) + self.assertEqual(client_attrs, expected_attrs) + server_attrs = self.loop.run_until_complete(client.recv()) + self.assertEqual(server_attrs, expected_attrs) + finally: + self.loop.run_until_complete(client.worker) + @patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") @@ -133,21 +148,40 @@ def test_server_close_crashes(self, close): @unittest.skipUnless(os.path.exists(testcert), "test certificate is missing") class SSLClientServerTests(ClientServerTests): - def start_server(self): + @property + def server_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ssl_context.load_cert_chain(testcert) + return ssl_context - server = serve(echo, 'localhost', 8642, ssl=ssl_context) - self.server = self.loop.run_until_complete(server) - - def start_client(self): + @property + def client_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ssl_context.load_verify_locations(testcert) ssl_context.verify_mode = ssl.CERT_REQUIRED + return ssl_context - client = connect('ws://localhost:8642/', ssl=ssl_context) + def start_server(self): + server = serve(handler, 'localhost', 8642, ssl=self.server_context) + self.server = self.loop.run_until_complete(server) + + def start_client(self): + client = connect('wss://localhost:8642/', ssl=self.client_context) self.client = self.loop.run_until_complete(client) + def test_protocol_attributes(self): + client = connect('wss://localhost:8642/attributes', + ssl=self.client_context) + client = self.loop.run_until_complete(client) + try: + expected_attrs = repr(('localhost', 8642, True)) + client_attrs = repr((client.host, client.port, client.secure)) + self.assertEqual(client_attrs, expected_attrs) + server_attrs = self.loop.run_until_complete(client.recv()) + self.assertEqual(server_attrs, expected_attrs) + finally: + self.loop.run_until_complete(client.worker) + class ClientServerOriginTests(unittest.TestCase): @@ -156,7 +190,7 @@ def test_checking_origin_succeeds(self): asyncio.set_event_loop(loop) server = loop.run_until_complete( - serve(echo, 'localhost', 8642, origins=['http://localhost'])) + serve(handler, 'localhost', 8642, origins=['http://localhost'])) client = loop.run_until_complete( connect('ws://localhost:8642/', origin='http://localhost')) @@ -173,7 +207,7 @@ def test_checking_origin_fails(self): asyncio.set_event_loop(loop) server = loop.run_until_complete( - serve(echo, 'localhost', 8642, origins=['http://localhost'])) + serve(handler, 'localhost', 8642, origins=['http://localhost'])) with self.assertRaises(InvalidHandshake): loop.run_until_complete( connect('ws://localhost:8642/', origin='http://otherhost')) @@ -187,7 +221,7 @@ def test_checking_lack_of_origin_succeeds(self): asyncio.set_event_loop(loop) server = loop.run_until_complete( - serve(echo, 'localhost', 8642, origins=[''])) + serve(handler, 'localhost', 8642, origins=[''])) client = loop.run_until_complete(connect('ws://localhost:8642/')) loop.run_until_complete(client.send("Hello!")) From dec978e0c5fbb0e947913427e88bdd92185520fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 14:43:57 +0200 Subject: [PATCH 0023/1539] Reject ws:// URIs with an SSL context. It seems better to ask for an explicit wss:// URI. The reverse isn't true: it's reasonable to create a default SSL context to connect to a wss:// URI when one isn't provided. --- websockets/client.py | 6 +++++- websockets/test_client_server.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/websockets/client.py b/websockets/client.py index af5b4f101..aecc3b399 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -87,7 +87,11 @@ def connect(uri, *, more than one connection in a CONNECTING state." """ wsuri = parse_uri(uri) - kwds.setdefault('ssl', True) + if wsuri.secure: + kwds.setdefault('ssl', True) + elif 'ssl' in kwds: + raise ValueError("connect() received a SSL context for a ws:// URI. " + "Use a wss:// URI to enable TLS.") factory = lambda: klass(host=wsuri.host, port=wsuri.port, secure=wsuri.secure) transport, protocol = yield from asyncio.get_event_loop().create_connection( factory, wsuri.host, wsuri.port, **kwds) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index d660d9eef..a823eec6b 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -182,6 +182,10 @@ def test_protocol_attributes(self): finally: self.loop.run_until_complete(client.worker) + def test_ws_uri_is_rejected(self): + client = connect('ws://localhost:8642/', ssl=self.client_context) + with self.assertRaises(ValueError): + self.loop.run_until_complete(client) class ClientServerOriginTests(unittest.TestCase): From cc1de6051f897567a70d0e92ce7809ed458c0851 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 14:49:13 +0200 Subject: [PATCH 0024/1539] Add new feature to changelog. --- docs/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/index.rst b/docs/index.rst index aa7e07dd9..48ec98382 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -151,6 +151,7 @@ Changelog 2.1 ... +* Added `host`, `port` and `secure` attributes on protocols. * Added support for providing and checking Origin_. .. _Origin: https://tools.ietf.org/html/rfc6455#section-10.2 From a292c73d0f136d650814bdae5257683b7d250055 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 15:19:37 +0200 Subject: [PATCH 0025/1539] Add request and response headers to protocols. They're lists of 2-uples reflecting the HTTP request and response. I'm not documenting them yet because I'm not sure this is the most convenient API (dict-like access in more intuitive) but at least the raw data is available for those who need it. Refs #20. --- websockets/client.py | 13 ++++--- websockets/server.py | 13 ++++--- websockets/test_client_server.py | 59 +++++++++++++++++--------------- 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index aecc3b399..f8e041ad5 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -31,10 +31,8 @@ def handshake(self, wsuri, origin=None): If provided, ``origin`` sets the HTTP Origin header. """ - # Send handshake request. Since the URI and the headers only contain - # ASCII characters, we can keep this simple. - request = ['GET %s HTTP/1.1' % wsuri.resource_name] - set_header = lambda k, v: request.append('{}: {}'.format(k, v)) + headers = [] + set_header = lambda k, v: headers.append((k, v)) if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover set_header('Host', wsuri.host) else: @@ -43,6 +41,12 @@ def handshake(self, wsuri, origin=None): set_header('Origin', origin) set_header('User-Agent', USER_AGENT) key = build_request(set_header) + self.request_headers = headers + + # Send handshake request. Since the URI and the headers only contain + # ASCII characters, we can keep this simple. + request = ['GET %s HTTP/1.1' % wsuri.resource_name] + request.extend('{}: {}'.format(k, v) for k, v in headers) request.append('\r\n') request = '\r\n'.join(request).encode() self.writer.write(request) @@ -54,6 +58,7 @@ def handshake(self, wsuri, origin=None): raise InvalidHandshake("Malformed HTTP message") from exc if status_code != 101: raise InvalidHandshake("Bad status code: {}".format(status_code)) + self.response_headers = list(headers.raw_items()) get_header = lambda k: headers.get(k, '') check_response(get_header, key) diff --git a/websockets/server.py b/websockets/server.py index e121603c4..0f6771f6a 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -95,6 +95,7 @@ def handshake(self, origins=None): except Exception as exc: raise InvalidHandshake("Malformed HTTP message") from exc + self.request_headers = list(headers.raw_items()) get_header = lambda k: headers.get(k, '') key = check_request(get_header) @@ -104,12 +105,16 @@ def handshake(self, origins=None): if not set(origin.split() or ('',))<= set(origins): raise InvalidHandshake("Bad origin: {}".format(origin)) - # Send handshake response. Since the headers only contain ASCII - # characters, we can keep this simple. - response = ['HTTP/1.1 101 Switching Protocols'] - set_header = lambda k, v: response.append('{}: {}'.format(k, v)) + headers = [] + set_header = lambda k, v: headers.append((k, v)) set_header('Server', USER_AGENT) build_response(set_header, key) + self.response_headers = headers + + # Send handshake response. Since the status line and headers only + # contain ASCII characters, we can keep this simple. + response = ['HTTP/1.1 101 Switching Protocols'] + response.extend('{}: {}'.format(k, v) for k, v in headers) response.append('\r\n') response = '\r\n'.join(response).encode() self.writer.write(response) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index a823eec6b..815cb02cd 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -7,7 +7,7 @@ from .client import * from .exceptions import InvalidHandshake -from .http import read_response +from .http import read_response, USER_AGENT from .server import * @@ -18,12 +18,17 @@ def handler(ws, path): if path == '/attributes': yield from ws.send(repr((ws.host, ws.port, ws.secure))) + elif path == '/headers': + yield from ws.send(repr(ws.request_headers)) + yield from ws.send(repr(ws.response_headers)) else: yield from ws.send((yield from ws.recv())) class ClientServerTests(unittest.TestCase): + secure = False + def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -37,8 +42,8 @@ def start_server(self): server = serve(handler, 'localhost', 8642) self.server = self.loop.run_until_complete(server) - def start_client(self): - client = connect('ws://localhost:8642/') + def start_client(self, path=''): + client = connect('ws://localhost:8642/' + path) self.client = self.loop.run_until_complete(client) def stop_client(self): @@ -56,16 +61,25 @@ def test_basic(self): self.stop_client() def test_protocol_attributes(self): - client = connect('ws://localhost:8642/attributes') - client = self.loop.run_until_complete(client) - try: - expected_attrs = repr(('localhost', 8642, False)) - client_attrs = repr((client.host, client.port, client.secure)) - self.assertEqual(client_attrs, expected_attrs) - server_attrs = self.loop.run_until_complete(client.recv()) - self.assertEqual(server_attrs, expected_attrs) - finally: - self.loop.run_until_complete(client.worker) + self.start_client('attributes') + expected_attrs = ('localhost', 8642, self.secure) + client_attrs = (self.client.host, self.client.port, self.client.secure) + self.assertEqual(client_attrs, expected_attrs) + server_attrs = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_attrs, repr(expected_attrs)) + self.stop_client() + + def test_protocol_headers(self): + self.start_client('headers') + client_req = self.client.request_headers + client_resp = self.client.response_headers + self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT) + self.assertEqual(dict(client_resp)['Server'], USER_AGENT) + server_req = self.loop.run_until_complete(self.client.recv()) + server_resp = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_req, repr(client_req)) + self.assertEqual(server_resp, repr(client_resp)) + self.stop_client() @patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): @@ -148,6 +162,8 @@ def test_server_close_crashes(self, close): @unittest.skipUnless(os.path.exists(testcert), "test certificate is missing") class SSLClientServerTests(ClientServerTests): + secure = True + @property def server_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) @@ -165,23 +181,10 @@ def start_server(self): server = serve(handler, 'localhost', 8642, ssl=self.server_context) self.server = self.loop.run_until_complete(server) - def start_client(self): - client = connect('wss://localhost:8642/', ssl=self.client_context) + def start_client(self, path=''): + client = connect('wss://localhost:8642/' + path, ssl=self.client_context) self.client = self.loop.run_until_complete(client) - def test_protocol_attributes(self): - client = connect('wss://localhost:8642/attributes', - ssl=self.client_context) - client = self.loop.run_until_complete(client) - try: - expected_attrs = repr(('localhost', 8642, True)) - client_attrs = repr((client.host, client.port, client.secure)) - self.assertEqual(client_attrs, expected_attrs) - server_attrs = self.loop.run_until_complete(client.recv()) - self.assertEqual(server_attrs, expected_attrs) - finally: - self.loop.run_until_complete(client.worker) - def test_ws_uri_is_rejected(self): client = connect('ws://localhost:8642/', ssl=self.client_context) with self.assertRaises(ValueError): From ff9e82c912c2fa8581c0acaec364b5563e945cbd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 15:28:59 +0200 Subject: [PATCH 0026/1539] Rename attributes to raw_request/response_headers. This leaves room for a better API (with case-insensitive lookup etc.) to be implemented as request/response_headers. --- websockets/client.py | 4 ++-- websockets/server.py | 4 ++-- websockets/test_client_server.py | 14 +++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index f8e041ad5..107401549 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -41,7 +41,7 @@ def handshake(self, wsuri, origin=None): set_header('Origin', origin) set_header('User-Agent', USER_AGENT) key = build_request(set_header) - self.request_headers = headers + self.raw_request_headers = headers # Send handshake request. Since the URI and the headers only contain # ASCII characters, we can keep this simple. @@ -58,7 +58,7 @@ def handshake(self, wsuri, origin=None): raise InvalidHandshake("Malformed HTTP message") from exc if status_code != 101: raise InvalidHandshake("Bad status code: {}".format(status_code)) - self.response_headers = list(headers.raw_items()) + self.raw_response_headers = list(headers.raw_items()) get_header = lambda k: headers.get(k, '') check_response(get_header, key) diff --git a/websockets/server.py b/websockets/server.py index 0f6771f6a..8f27bd778 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -95,7 +95,7 @@ def handshake(self, origins=None): except Exception as exc: raise InvalidHandshake("Malformed HTTP message") from exc - self.request_headers = list(headers.raw_items()) + self.raw_request_headers = list(headers.raw_items()) get_header = lambda k: headers.get(k, '') key = check_request(get_header) @@ -109,7 +109,7 @@ def handshake(self, origins=None): set_header = lambda k, v: headers.append((k, v)) set_header('Server', USER_AGENT) build_response(set_header, key) - self.response_headers = headers + self.raw_response_headers = headers # Send handshake response. Since the status line and headers only # contain ASCII characters, we can keep this simple. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 815cb02cd..43be66120 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -18,9 +18,9 @@ def handler(ws, path): if path == '/attributes': yield from ws.send(repr((ws.host, ws.port, ws.secure))) - elif path == '/headers': - yield from ws.send(repr(ws.request_headers)) - yield from ws.send(repr(ws.response_headers)) + elif path == '/raw_headers': + yield from ws.send(repr(ws.raw_request_headers)) + yield from ws.send(repr(ws.raw_response_headers)) else: yield from ws.send((yield from ws.recv())) @@ -69,10 +69,10 @@ def test_protocol_attributes(self): self.assertEqual(server_attrs, repr(expected_attrs)) self.stop_client() - def test_protocol_headers(self): - self.start_client('headers') - client_req = self.client.request_headers - client_resp = self.client.response_headers + def test_protocol_raw_headers(self): + self.start_client('raw_headers') + client_req = self.client.raw_request_headers + client_resp = self.client.raw_response_headers self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT) self.assertEqual(dict(client_resp)['Server'], USER_AGENT) server_req = self.loop.run_until_complete(self.client.recv()) From bb64dc7aef10d55eebd139255263d35c7a092e76 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 15:46:40 +0200 Subject: [PATCH 0027/1539] Death to unicode literals. --- docs/conf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d00ce37d8..ce299d709 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,8 +40,8 @@ master_doc = 'index' # General information about the project. -project = u'websockets' -copyright = u'2013-2014, Aymeric Augustin' +project = 'websockets' +copyright = '2013-2014, Aymeric Augustin' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -184,8 +184,8 @@ # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'websockets.tex', u'websockets Documentation', - u'Aymeric Augustin', 'manual'), + ('index', 'websockets.tex', 'websockets Documentation', + 'Aymeric Augustin', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -214,8 +214,8 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'websockets', u'websockets Documentation', - [u'Aymeric Augustin'], 1) + ('index', 'websockets', 'websockets Documentation', + ['Aymeric Augustin'], 1) ] # If true, show URL addresses after external links. @@ -228,8 +228,8 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'websockets', u'websockets Documentation', - u'Aymeric Augustin', 'websockets', 'One line description of project.', + ('index', 'websockets', 'websockets Documentation', + 'Aymeric Augustin', 'websockets', 'One line description of project.', 'Miscellaneous'), ] From 07c13192aaa177a86f65acc6041e91d2dd03e2ce Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 15:47:19 +0200 Subject: [PATCH 0028/1539] Bump version number. --- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index ce299d709..2ec2a4245 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '2.0' +version = '2.1' # The full version, including alpha/beta/rc tags. -release = '2.0.0' +release = '2.1.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index fc2e06f51..dd4e0aee4 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '2.0' +version = '2.1' From 481986fd209c5d0667195390b5f1f56f2ee899a7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Apr 2014 15:47:49 +0200 Subject: [PATCH 0029/1539] Create universal wheels. --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 setup.cfg diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..e57d130e3 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[wheel] +universal = True From 415b3d560cf3691d465189e07339157e2bff598b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 May 2014 22:41:03 +0200 Subject: [PATCH 0030/1539] No more wheels, fix #25. --- setup.cfg | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 setup.cfg diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index e57d130e3..000000000 --- a/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[wheel] -universal = True From 90e0494f53c8670a0bd233b1811ea3b596843a5c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 May 2014 21:43:57 +0200 Subject: [PATCH 0031/1539] Close connection properly when the socket dies. This avoids flooding logs with messages output by asyncio: "socket.send() raised exception." Fix #23 (hopefully). --- websockets/protocol.py | 16 ++++++++++++---- websockets/test_client_server.py | 4 ++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 210bd8b85..8d85c38ff 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -317,11 +317,13 @@ def write_frame(self, opcode, data=b'', expected_state='OPEN'): logger.debug("%s >> %s", side, frame) is_masked = self.is_client write_frame(frame, self.writer.write, is_masked) - # Handle flow control automatically. try: + # Handle flow control automatically. yield from self.writer.drain() except ConnectionResetError: - pass + # Terminate the connection if the socket died. + self.state = 'CLOSING' + yield from self.fail_connection(1006) @asyncio.coroutine def close_connection(self): @@ -344,8 +346,14 @@ def close_connection(self): if self.state == 'CLOSED': return - if self.writer.can_write_eof(): - self.writer.write_eof() + # Attempt to terminate the TCP connection properly. + # If the socket is already closed, this will crash. + try: + if self.writer.can_write_eof(): + self.writer.write_eof() + except Exception: + pass + self.writer.close() try: diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 43be66120..68163a2ae 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -155,8 +155,8 @@ def test_server_close_crashes(self, close): self.assertEqual(reply, "Hello!") self.stop_client() - # Connection ends with a protocol error. - self.assertEqual(self.client.close_code, 1002) + # Connection ends with an abnormal closure. + self.assertEqual(self.client.close_code, 1006) @unittest.skipUnless(os.path.exists(testcert), "test certificate is missing") From a01b6ca96e71063af43c0b29d3e1326a8f959f09 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Jun 2014 17:50:11 +0200 Subject: [PATCH 0032/1539] Document how to show stack traces for exceptions in handlers. --- websockets/server.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/websockets/server.py b/websockets/server.py index 8f27bd778..438764eaa 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -146,6 +146,15 @@ def serve(ws_handler, host=None, port=None, *, handshake, and delegates to the WebSocket handler. Once the handler completes, the server performs the closing handshake and closes the connection. + + Since there's no useful way to propagate exceptions triggered in handlers, + they're sent to the `websockets.server` logger instead. Debugging is much + easier if you configure logging to print them:: + + import logging + logger = logging.getLogger('websockets.server') + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) """ secure = kwds.get('ssl') is not None factory = lambda: klass(ws_handler, From 48b9c0226dd4d770b92459b68d0ef073008bacf2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Jun 2014 17:50:36 +0200 Subject: [PATCH 0033/1539] Update signatures in docs, especially for Origin checking. --- docs/index.rst | 12 ++++++------ websockets/server.py | 11 +++++------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 48ec98382..97457bcdf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -81,9 +81,9 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, origins=None, **kwds) - .. autoclass:: WebSocketServerProtocol(self, ws_handler, timeout=10) + .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, host=None, port=None, secure=None, timeout=10, loop=None) :members: handshake Client @@ -91,9 +91,9 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, **kwds) + .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, origin=None, **kwds) - .. autoclass:: WebSocketClientProtocol(self, timeout=10) + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, loop=None) :members: handshake Shared @@ -101,7 +101,7 @@ Shared .. automodule:: websockets.protocol - .. autoclass:: WebSocketCommonProtocol(self, timeout=10) + .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, loop=None) .. autoattribute:: open .. automethod:: close(code=1000, reason='') @@ -110,7 +110,7 @@ Shared .. automethod:: send(data) .. automethod:: ping(data=None) - .. automethod:: pong() + .. automethod:: pong(data=b'') Low-level API ------------- diff --git a/websockets/server.py b/websockets/server.py index 438764eaa..a1ae0a222 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -30,7 +30,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): state = 'CONNECTING' - def __init__(self, ws_handler=None, *, origins=None, **kwargs): + def __init__(self, ws_handler, *, origins=None, **kwargs): self.ws_handler = ws_handler self.origins = origins super().__init__(**kwargs) @@ -60,7 +60,7 @@ def handler(self): try: yield from self.ws_handler(self, path) except Exception: - logger.info("Exception in connection handler", exc_info=True) + logger.error("Exception in connection handler", exc_info=True) yield from self.fail_connection(1011) raise @@ -77,15 +77,13 @@ def handler(self): except Exception: # pragma: no cover pass - - @asyncio.coroutine def handshake(self, origins=None): """ Perform the server side of the opening handshake. If provided, ``origins`` is a list of acceptable HTTP Origin values. - Include ``''`` in the list if the lack of an origin is acceptable. + Include ``''`` if the lack of an origin is acceptable. Return the URI of the request. """ @@ -137,7 +135,8 @@ def serve(ws_handler, host=None, port=None, *, `ws_handler` is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`~websockets.server.WebSocketServerProtocol` and - the request URI. `origin` is a list of acceptable Origin HTTP headers. + the request URI. If provided, `origin` is a list of acceptable Origin HTTP + headers. Include ``''`` if the lack of an origin is acceptable. It returns a `Server` object with a `close` method to stop the server. From 567db4eca52008cd4401124e344ee916f8229d0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Jun 2014 21:39:15 +0200 Subject: [PATCH 0034/1539] Avoid showing stack traces during test runs. --- websockets/test_client_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 68163a2ae..4918644c1 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1,3 +1,4 @@ +import logging import os import ssl import unittest @@ -11,6 +12,9 @@ from .server import * +# Avoid displaying stack traces at the ERROR logging level. +logging.basicConfig(level=logging.CRITICAL) + testcert = os.path.join(os.path.dirname(__file__), 'testcert.pem') From c19008269801741af1f32ebf0f7fe1fe0658ead8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Jun 2014 21:43:21 +0200 Subject: [PATCH 0035/1539] Run the event loop a bit to avoid race conditions. --- websockets/test_client_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 4918644c1..a0177986d 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -51,10 +51,14 @@ def start_client(self, path=''): self.client = self.loop.run_until_complete(client) def stop_client(self): + # Run the event loop to finish processing and avoid race conditions. + self.loop.run_until_complete(asyncio.sleep(0.001)) self.loop.run_until_complete(self.client.worker) def stop_server(self): self.server.close() + # Run the event loop to finish processing and avoid race conditions. + self.loop.run_until_complete(asyncio.sleep(0.001)) self.loop.run_until_complete(self.server.wait_closed()) def test_basic(self): From fdae6c828549a148345bc32aa6de8e10136ca035 Mon Sep 17 00:00:00 2001 From: Rui Abreu Ferreira Date: Wed, 18 Jun 2014 09:05:33 +0100 Subject: [PATCH 0036/1539] Add support for limiting message length. Initial patch by Rui Abreu Ferreira, final patch by Aymeric Augustin. Fix #28, #29. --- LICENSE | 2 +- docs/index.rst | 11 +++++-- websockets/exceptions.py | 4 +++ websockets/framing.py | 10 +++++-- websockets/protocol.py | 45 +++++++++++++++++++++------- websockets/test_framing.py | 10 +++++-- websockets/test_protocol.py | 58 ++++++++++++++++++++++++++++++++++++- 7 files changed, 120 insertions(+), 20 deletions(-) diff --git a/LICENSE b/LICENSE index 60ab1d99f..fcf63659c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013-2014 Aymeric Augustin. +Copyright (c) 2013-2014 Aymeric Augustin and contributors. All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/docs/index.rst b/docs/index.rst index 97457bcdf..722656b4e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -83,7 +83,7 @@ Server .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, origins=None, **kwds) - .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, host=None, port=None, secure=None, timeout=10, loop=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) :members: handshake Client @@ -93,7 +93,7 @@ Client .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, origin=None, **kwds) - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, loop=None) + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) :members: handshake Shared @@ -101,7 +101,7 @@ Shared .. automodule:: websockets.protocol - .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, loop=None) + .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) .. autoattribute:: open .. automethod:: close(code=1000, reason='') @@ -148,6 +148,11 @@ Utilities Changelog --------- +2.2 +... + +* Added support for limiting message size. + 2.1 ... diff --git a/websockets/exceptions.py b/websockets/exceptions.py index cb3e9a409..20cc0ba2f 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -13,6 +13,10 @@ class InvalidURI(Exception): """Exception raised when an URI is invalid.""" +class PayloadTooBig(Exception): + """Exception raised when the payload in a frame exceeds the maximum size.""" + + class WebSocketProtocolError(Exception): # Internal exception raised when the other end breaks the protocol. # It's private because it shouldn't leak outside of WebSocketCommonProtocol. diff --git a/websockets/framing.py b/websockets/framing.py index d65e220b5..d281306e1 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -15,7 +15,7 @@ import asyncio -from .exceptions import WebSocketProtocolError +from .exceptions import WebSocketProtocolError, PayloadTooBig __all__ = [ @@ -47,7 +47,7 @@ @asyncio.coroutine -def read_frame(reader, mask): +def read_frame(reader, mask, *, max_size=None): """ Read a WebSocket frame and return a :class:`Frame` object. @@ -57,6 +57,9 @@ def read_frame(reader, mask): `mask` is a :class:`bool` telling whether the frame should be masked, ie. whether the read happens on the server side. + If `max_size` is set and the payload exceeds this size in bytes, + :exc:`PayloadTooBig` is raised. + This function validates the frame before returning it and raises :exc:`WebSocketProtocolError` if it contains incorrect values. """ @@ -76,6 +79,9 @@ def read_frame(reader, mask): elif length == 127: data = yield from read_bytes(reader, 8) length, = struct.unpack('!Q', data) + if max_size is not None and length > max_size: + raise PayloadTooBig("Payload exceeds limit " + "({} > {} bytes)".format(length, max_size)) if mask: mask_bits = yield from read_bytes(reader, 4) diff --git a/websockets/protocol.py b/websockets/protocol.py index 8d85c38ff..99696ade3 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -16,7 +16,7 @@ import asyncio from asyncio.queues import Queue, QueueEmpty -from .exceptions import InvalidState, WebSocketProtocolError +from .exceptions import InvalidState, PayloadTooBig, WebSocketProtocolError from .framing import * from .handshake import * @@ -45,6 +45,11 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): terminating the TCP connection. :meth:`close()` will complete in at most this time on the server side and twice this time on the client side. + The `max_size` parameter enforces the maximum size for incoming messages + in bytes. The default value is 1MB. ``None`` disables the limit. If a + message larger than the maximum size is received, :meth:`recv()` will + return ``None`` and the connection will be closed with status code 1009. + Once the connection is closed, the status code is available in the :attr:`close_code` attribute and the reason in :attr:`close_reason`. """ @@ -58,12 +63,13 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): state = 'OPEN' def __init__(self, *, - host=None, port=None, secure=None, timeout=10, loop=None): + host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None): self.host = host self.port = port self.secure = secure self.timeout = timeout + self.max_size = max_size super().__init__(asyncio.StreamReader(), self.client_connected, loop) @@ -226,6 +232,8 @@ def run(self): yield from self.fail_connection(1002) except UnicodeDecodeError: yield from self.fail_connection(1007) + except PayloadTooBig: + yield from self.fail_connection(1009) except Exception: yield from self.fail_connection(1011) raise @@ -234,7 +242,7 @@ def run(self): @asyncio.coroutine def read_message(self): # Reassemble fragmented messages. - frame = yield from self.read_data_frame() + frame = yield from self.read_data_frame(max_size=self.max_size) if frame is None: return if frame.opcode == OP_TEXT: @@ -250,15 +258,32 @@ def read_message(self): # 5.4. Fragmentation chunks = [] + max_size = self.max_size if text: decoder = codecs.getincrementaldecoder('utf-8')(errors='strict') - append = lambda f: chunks.append(decoder.decode(f.data, f.fin)) + if max_size is None: + def append(frame): + nonlocal chunks + chunks.append(decoder.decode(frame.data, frame.fin)) + else: + def append(frame): + nonlocal chunks, max_size + chunks.append(decoder.decode(frame.data, frame.fin)) + max_size -= len(frame.data) else: - append = lambda f: chunks.append(f.data) + if max_size is None: + def append(frame): + nonlocal chunks + chunks.append(frame.data) + else: + def append(frame): + nonlocal chunks, max_size + chunks.append(frame.data) + max_size -= len(frame.data) append(frame) while not frame.fin: - frame = yield from self.read_data_frame() + frame = yield from self.read_data_frame(max_size=max_size) if frame is None: raise WebSocketProtocolError("Incomplete fragmented message") if frame.opcode != OP_CONT: @@ -268,11 +293,11 @@ def read_message(self): return ('' if text else b'').join(chunks) @asyncio.coroutine - def read_data_frame(self): + def read_data_frame(self, max_size): # Deal with control frames automatically and return next data frame. # 6.2. Receiving Data while True: - frame = yield from self.read_frame() + frame = yield from self.read_frame(max_size) # 5.5. Control Frames if frame.opcode == OP_CLOSE: self.close_code, self.close_reason = parse_close(frame.data) @@ -299,9 +324,9 @@ def read_data_frame(self): return frame @asyncio.coroutine - def read_frame(self): + def read_frame(self, max_size): is_masked = not self.is_client - frame = yield from read_frame(self.reader.readexactly, is_masked) + frame = yield from read_frame(self.reader.readexactly, is_masked, max_size=max_size) side = 'client' if self.is_client else 'server' logger.debug("%s << %s", side, frame) return frame diff --git a/websockets/test_framing.py b/websockets/test_framing.py index 2a323e059..bf3c7967b 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -3,7 +3,7 @@ import asyncio -from .exceptions import WebSocketProtocolError +from .exceptions import WebSocketProtocolError, PayloadTooBig from .framing import * @@ -16,12 +16,12 @@ def setUp(self): def tearDown(self): self.loop.close() - def decode(self, message, mask=False): + def decode(self, message, mask=False, max_size=None): self.stream = asyncio.StreamReader() self.stream.feed_data(message) self.stream.feed_eof() reader = self.stream.readexactly - return self.loop.run_until_complete(read_frame(reader, mask)) + return self.loop.run_until_complete(read_frame(reader, mask, max_size=max_size)) def encode(self, frame, mask=False): encoded = io.BytesIO() @@ -89,6 +89,10 @@ def test_very_long(self): b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + 65536 * b'a', Frame(True, OP_BINARY, 65536 * b'a')) + def test_payload_too_big(self): + with self.assertRaises(PayloadTooBig): + self.decode(b'\x82\x7e\x04\x01' + 1025 * b'a', max_size=1024) + def test_bad_reserved_bits(self): with self.assertRaises(WebSocketProtocolError): self.decode(b'\xc0\x00') diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 4c9a40a85..1248323de 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -3,7 +3,7 @@ import asyncio -from .exceptions import InvalidState +from .exceptions import InvalidState, PayloadTooBig from .framing import * from .protocol import WebSocketCommonProtocol @@ -104,6 +104,32 @@ def test_recv_unicode_error(self): self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1007, '') + def test_recv_text_payload_too_big(self): + self.protocol.max_size = 1024 + self.feed(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) + self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.assertConnectionClosed(1009, '') + + def test_recv_binary_payload_too_big(self): + self.protocol.max_size = 1024 + self.feed(Frame(True, OP_BINARY, b'tea' * 342)) + self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.assertConnectionClosed(1009, '') + + def test_recv_text_no_max_size(self): + self.protocol.max_size = None # for test coverage + self.feed(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, 'café' * 205) + + def test_recv_binary_no_max_size(self): + self.protocol.max_size = None # for test coverage + self.feed(Frame(True, OP_BINARY, b'tea' * 342)) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, b'tea' * 342) + def test_recv_other_error(self): @asyncio.coroutine def read_message(): @@ -206,6 +232,36 @@ def test_fragmented_binary(self): data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea') + def test_fragmented_text_payload_too_big(self): + self.protocol.max_size = 1024 + self.feed(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) + self.feed(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) + self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.assertConnectionClosed(1009, '') + + def test_fragmented_binary_payload_too_big(self): + self.protocol.max_size = 1024 + self.feed(Frame(False, OP_BINARY, b'tea' * 171)) + self.feed(Frame(True, OP_CONT, b'tea' * 171)) + self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.assertConnectionClosed(1009, '') + + def test_fragmented_text_no_max_size(self): + self.protocol.max_size = None # for test coverage + self.feed(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) + self.feed(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, 'café' * 205) + + def test_fragmented_binary_no_max_size(self): + self.protocol.max_size = None # for test coverage + self.feed(Frame(False, OP_BINARY, b'tea' * 171)) + self.feed(Frame(True, OP_CONT, b'tea' * 171)) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, b'tea' * 342) + def test_control_frame_within_fragmented_text(self): self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.feed(Frame(True, OP_PING, b'')) From edf3a323e657ffd73229d333995e9b583af17920 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 28 Jul 2014 11:03:10 +0200 Subject: [PATCH 0037/1539] Document how to write recv and send loops. Refs #30. --- docs/index.rst | 52 +++++++++++++++++++++++++++++++++++++++-------- example/server.py | 2 +- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 722656b4e..057056591 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -37,20 +37,56 @@ message. .. literalinclude:: ../example/server.py -.. note:: - - The handler function, ``hello``, is executed once for each WebSocket - connection. The connection is automatically closed when the handler - returns. If you want to process several messages in the same connection, - you must write a loop, most likely with :attr:`websocket.open - `. - .. _client-example: Here's a corresponding client example. .. literalinclude:: ../example/client.py +.. note:: + + On the server side, the handler coroutine ``hello`` is executed once for + each WebSocket connection. The connection is automatically closed when the + handler returns. + + You will almost always want to process several messages during the + lifetime of a connection. Therefore you must write a loop. Here are the + recommended patterns to exit cleanly when the connection drops, either + because the other side closed it or for any other reason. + + For receiving messages and passing them to a ``consumer`` coroutine:: + + @asyncio.coroutine + def handler(websocket, path): + while True: + message = yield from websocket.recv() + if message is None: + break + yield from consumer(message) + + :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` returns ``None`` + when the connection is closed. In other words, ``None`` marks the end of + the message stream. The handler coroutine should check for that case and + return when it happens. + + For getting messages from a ``producer`` coroutine and sending them:: + + @asyncio.coroutine + def handler(websocket, path): + while True: + message = yield from producer() + if not websocket.open: + break + yield from websocket.send(message) + + :meth:`~websockets.protocol.WebSocketCommonProtocol.send` fails with an + exception when it's called on a closed connection. Therefore the handler + coroutine should check that the connection is still open before attempting + to write and return otherwise. + + Of course, you can combine the two patterns shown above to read and write + messages on the same connection. + Design ------ diff --git a/example/server.py b/example/server.py index 958697318..dea1fd40c 100755 --- a/example/server.py +++ b/example/server.py @@ -8,8 +8,8 @@ def hello(websocket, path): name = yield from websocket.recv() print("< {}".format(name)) greeting = "Hello {}!".format(name) - print("> {}".format(greeting)) yield from websocket.send(greeting) + print("> {}".format(greeting)) start_server = websockets.serve(hello, 'localhost', 8765) From 7e4ef005696a956def34fa735422ca4e1ef17197 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 28 Jul 2014 11:14:48 +0200 Subject: [PATCH 0038/1539] Bump version number. --- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 2ec2a4245..157e604fa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '2.1' +version = '2.2' # The full version, including alpha/beta/rc tags. -release = '2.1.0' +release = '2.2.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index dd4e0aee4..62d14aae4 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '2.1' +version = '2.2' From 1f8bd3ade280a3d7e1254cee7a7379a96a247154 Mon Sep 17 00:00:00 2001 From: housleyjk Date: Mon, 29 Sep 2014 13:11:42 -0600 Subject: [PATCH 0039/1539] Prevent incomplete task leak in recv. --- websockets/protocol.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 99696ade3..06cb42f66 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -157,12 +157,14 @@ def recv(self): pass # Wait for a message until the connection is closed - next_message = asyncio.Task(self.messages.get()) + next_message = asyncio.async(self.messages.get()) done, pending = yield from asyncio.wait( [next_message, self.worker], return_when=asyncio.FIRST_COMPLETED) if next_message in done: return next_message.result() + else: + next_message.cancel() @asyncio.coroutine def send(self, data): From 2f9c24e9adff980855e5d4c2493a68259e295e68 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Nov 2014 19:20:46 +0100 Subject: [PATCH 0040/1539] Normalize tolerance in tests. --- websockets/test_protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 1248323de..f95e3cfa7 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -338,9 +338,9 @@ def test_close_drops_frames(self): def test_close_handshake_timeout(self): self.after = asyncio.Future() - self.loop.call_later(2 * MS, self.after.cancel) + self.loop.call_later(4 * MS, self.after.cancel) self.before = asyncio.Future() - self.loop.call_later(10 * MS, self.before.cancel) + self.loop.call_later(8 * MS, self.before.cancel) self.protocol.timeout = 5 * MS self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') From 1520cc1f2d1310529fdb5164ddfced1e4ac1c080 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 Nov 2014 22:40:46 +0100 Subject: [PATCH 0041/1539] Connection failures are 1006 errors, not 1002. --- websockets/framing.py | 19 +++++-------------- websockets/protocol.py | 2 ++ websockets/test_framing.py | 4 ---- websockets/test_protocol.py | 4 ++-- 4 files changed, 9 insertions(+), 20 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index d281306e1..b83b45466 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -64,7 +64,7 @@ def read_frame(reader, mask, *, max_size=None): :exc:`WebSocketProtocolError` if it contains incorrect values. """ # Read the header - data = yield from read_bytes(reader, 2) + data = yield from reader(2) head1, head2 = struct.unpack('!BB', data) fin = bool(head1 & 0b10000000) if head1 & 0b01110000: @@ -74,19 +74,19 @@ def read_frame(reader, mask, *, max_size=None): raise WebSocketProtocolError("Incorrect masking") length = head2 & 0b01111111 if length == 126: - data = yield from read_bytes(reader, 2) + data = yield from reader(2) length, = struct.unpack('!H', data) elif length == 127: - data = yield from read_bytes(reader, 8) + data = yield from reader(8) length, = struct.unpack('!Q', data) if max_size is not None and length > max_size: raise PayloadTooBig("Payload exceeds limit " "({} > {} bytes)".format(length, max_size)) if mask: - mask_bits = yield from read_bytes(reader, 4) + mask_bits = yield from reader(4) # Read the data - data = yield from read_bytes(reader, length) + data = yield from reader(length) if mask: data = bytes(b ^ mask_bits[i % 4] for i, b in enumerate(data)) @@ -95,15 +95,6 @@ def read_frame(reader, mask, *, max_size=None): return frame -@asyncio.coroutine -def read_bytes(reader, n): - # Undocumented utility function. - try: - return (yield from reader(n)) - except asyncio.IncompleteReadError: - raise WebSocketProtocolError("Unexpected EOF") - - def write_frame(frame, writer, mask): """ Write a WebSocket frame. diff --git a/websockets/protocol.py b/websockets/protocol.py index 06cb42f66..a0eb0b21b 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -232,6 +232,8 @@ def run(self): break except WebSocketProtocolError: yield from self.fail_connection(1002) + except asyncio.IncompleteReadError: + yield from self.fail_connection(1006) except UnicodeDecodeError: yield from self.fail_connection(1007) except PayloadTooBig: diff --git a/websockets/test_framing.py b/websockets/test_framing.py index bf3c7967b..86e2c595f 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -124,10 +124,6 @@ def test_fragmented_control_frame(self): with self.assertRaises(WebSocketProtocolError): self.decode(b'\x08\x00') - def test_truncated_message(self): - with self.assertRaises(WebSocketProtocolError): - self.decode(b'\x80\x01') - def test_parse_close(self): self.round_trip_close(b'\x03\xe8', 1000, '') self.round_trip_close(b'\x03\xe8OK', 1000, 'OK') diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index f95e3cfa7..461e0b9c4 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -290,7 +290,7 @@ def test_connection_close_in_fragmented_text(self): self.loop.call_later(MS, self.protocol.eof_received) self.loop.call_later(2 * MS, lambda: self.protocol.connection_lost(None)) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) - self.assertConnectionClosed(1002, '') + self.assertConnectionClosed(1006, '') class ServerTests(CommonTests, unittest.TestCase): @@ -373,7 +373,7 @@ def test_close_connection_lost(self): self.loop.call_later(MS, self.protocol.eof_received) self.loop.call_later(2 * MS, lambda: self.protocol.connection_lost(None)) self.loop.run_until_complete(self.protocol.close(reason='because.')) - self.assertConnectionClosed(1002, '') + self.assertConnectionClosed(1006, '') def test_close_during_recv(self): recv = asyncio.async(self.protocol.recv()) From 43d334f4e19e2a9755c44448b47fa8ff07e9b0fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 Nov 2014 22:07:08 +0100 Subject: [PATCH 0042/1539] Improve handling of concurrent close from both sides. This fixes random failures in the test suite. --- websockets/protocol.py | 21 +++++++++++++++---- websockets/test_client_server.py | 4 ---- websockets/test_protocol.py | 36 ++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index a0eb0b21b..1c73f3f93 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -79,6 +79,7 @@ def __init__(self, *, # Futures tracking steps in the connection's lifecycle. self.opening_handshake = asyncio.Future() self.closing_handshake = asyncio.Future() + self.connection_failed = asyncio.Future() self.connection_closed = asyncio.Future() # Queue of received messages. @@ -309,7 +310,8 @@ def read_data_frame(self, max_size): # 7.1.3. The WebSocket Closing Handshake is Started self.state = 'CLOSING' yield from self.write_frame(OP_CLOSE, frame.data, 'CLOSING') - self.closing_handshake.set_result(True) + if not self.closing_handshake.done(): + self.closing_handshake.set_result(True) return elif frame.opcode == OP_PING: # Answer pings. @@ -350,9 +352,11 @@ def write_frame(self, opcode, data=b'', expected_state='OPEN'): # Handle flow control automatically. yield from self.writer.drain() except ConnectionResetError: - # Terminate the connection if the socket died. - self.state = 'CLOSING' - yield from self.fail_connection(1006) + # Terminate the connection if the socket died, + # unless it's already being closed. + if expected_state != 'CLOSING': + self.state = 'CLOSING' + yield from self.fail_connection(1006) @asyncio.coroutine def close_connection(self): @@ -393,6 +397,15 @@ def close_connection(self): @asyncio.coroutine def fail_connection(self, code=1011, reason=''): + # Avoid calling fail_connection more than once to minimize + # the consequences of race conditions between the two sides. + if self.connection_failed.done(): + # Wait until the other coroutine calls connection_lost. + yield from self.connection_closed + return + else: + self.connection_failed.set_result(None) + # Losing the connection usually results in a protocol error. # Preserve the original error code in this case. if self.close_code != 1006: diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index a0177986d..4918644c1 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -51,14 +51,10 @@ def start_client(self, path=''): self.client = self.loop.run_until_complete(client) def stop_client(self): - # Run the event loop to finish processing and avoid race conditions. - self.loop.run_until_complete(asyncio.sleep(0.001)) self.loop.run_until_complete(self.client.worker) def stop_server(self): self.server.close() - # Run the event loop to finish processing and avoid race conditions. - self.loop.run_until_complete(asyncio.sleep(0.001)) self.loop.run_until_complete(self.server.wait_closed()) def test_basic(self): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 461e0b9c4..0e57d49e3 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -364,6 +364,23 @@ def test_close_timeout_before_connection_lost(self): self.assertFalse(self.before.cancelled()) self.before.cancel() + def test_client_close_race_with_failing_connection(self): + original_write_frame = self.protocol.write_frame + @asyncio.coroutine + def delayed_write_frame(*args): + yield from original_write_frame(*args) + yield from asyncio.sleep(2 * MS) + self.protocol.write_frame = delayed_write_frame + + frame = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) + # Trigger the race condition between answering the close frame from + # the client and sending another close frame from the server. + self.loop.call_later(MS, self.feed, frame) + self.loop.call_later(2 * MS, asyncio.async, self.protocol.fail_connection(1000, 'server')) + self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.assertConnectionClosed(1000, 'server') + self.assertFrameSent(*frame) + def test_close_protocol_error(self): self.loop.call_later(MS, self.feed, Frame(True, OP_CLOSE, b'\x00')) self.loop.run_until_complete(self.protocol.close(reason='because.')) @@ -468,3 +485,22 @@ def test_close_timeout_before_connection_lost(self): self.assertTrue(self.after.cancelled()) self.assertFalse(self.before.cancelled()) self.before.cancel() + + def test_server_close_race_with_failing_connection(self): + original_write_frame = self.protocol.write_frame + @asyncio.coroutine + def delayed_write_frame(*args): + yield from original_write_frame(*args) + yield from asyncio.sleep(2 * MS) + self.protocol.write_frame = delayed_write_frame + + frame = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) + # Trigger the race condition between answering the close frame from + # the server and sending another close frame from the client. + self.loop.call_later(MS, self.feed, frame) + self.loop.call_later(2 * MS, asyncio.async, self.protocol.fail_connection(1000, 'client')) + self.loop.call_later(3 * MS, self.protocol.eof_received) + self.loop.call_later(4 * MS, lambda: self.protocol.connection_lost(None)) + self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.assertConnectionClosed(1000, 'client') + self.assertFrameSent(*frame) From 634c329958f5707a687f96d51106a9f17477d355 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 Nov 2014 23:00:56 +0100 Subject: [PATCH 0043/1539] Second attempt at producing wheels. Fix #35. Refs #25. --- setup.cfg | 2 ++ setup.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 setup.cfg diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..87175524c --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[bdist_wheel] +python-tag = py33.py34 diff --git a/setup.py b/setup.py index 4719ef520..72b6575dd 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,9 @@ packages=[ 'websockets', ], - install_requires=['asyncio'] if py_version == (3, 3) else [], + extras_require={ + ':python_version=="3.3"': ['asyncio'], + }, classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", @@ -47,6 +49,7 @@ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: 3.4", ], platforms='all', license='BSD' From 0f72a4d80e8ce67febfca1d5f95cbaf952e4ebb3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 Nov 2014 23:04:36 +0100 Subject: [PATCH 0044/1539] Changelog for 2.3. --- docs/index.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 057056591..9a8c0e5f5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -184,6 +184,11 @@ Utilities Changelog --------- +2.3 +... + +* Improved compliance of close codes. + 2.2 ... From 67fa773d72b6b3888e86ae65f3d4f89ceb705198 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 Nov 2014 23:05:09 +0100 Subject: [PATCH 0045/1539] Bump version number. --- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 157e604fa..cfff51116 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '2.2' +version = '2.3' # The full version, including alpha/beta/rc tags. -release = '2.2.0' +release = '2.3.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 62d14aae4..5a1c16e2d 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '2.2' +version = '2.3' From 5b204ddec43cea65e53d184b7656fc4cae1d4071 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Jan 2015 16:35:57 +0100 Subject: [PATCH 0046/1539] Fix compliance tests. --- compliance/README.rst | 7 +++++-- compliance/test_client.py | 4 ++++ compliance/test_server.py | 4 ++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index 736643771..e85b7e1ff 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -29,5 +29,8 @@ supports Python 3; you need two different environments. Conformance notes ----------------- -Test cases 6.4.2, 6.4.3, and 6.4.4 are actually more strict than the RFC. -Given its implementation, ``websockets`` gets a "Non-Strict". +Test cases 6.4.3, and 6.4.4 are actually more strict than the RFC. Given its +implementation, ``websockets`` gets a "Non-Strict". + +Test cases 12.* and 13.* don't run because ``websockets`` doesn't implement +compression at this time. diff --git a/compliance/test_client.py b/compliance/test_client.py index 5c75006a1..f72f7de49 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -17,6 +17,10 @@ class EchoClientProtocol(websockets.WebSocketClientProtocol): """WebSocket client protocol that echoes messages synchronously.""" + def __init__(self, *args, **kwargs): + kwargs['max_size'] = 2 ** 25 + super().__init__(*args, **kwargs) + @asyncio.coroutine def read_message(self): msg = yield from super().read_message() diff --git a/compliance/test_server.py b/compliance/test_server.py index 3df861f53..85b059011 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -12,6 +12,10 @@ class EchoServerProtocol(websockets.WebSocketServerProtocol): """WebSocket server protocol that echoes messages synchronously.""" + def __init__(self, *args, **kwargs): + kwargs['max_size'] = 2 ** 25 + super().__init__(*args, **kwargs) + @asyncio.coroutine def read_message(self): msg = yield from super().read_message() From d62a5f267116ccc575f645f5954d027dd3c41cb0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Jan 2015 16:41:42 +0100 Subject: [PATCH 0047/1539] Support non-default event loop. Fix #42. --- docs/index.rst | 11 +++++++++-- websockets/client.py | 8 ++++++-- websockets/protocol.py | 30 ++++++++++++++++-------------- websockets/server.py | 11 +++++++---- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 9a8c0e5f5..51598f6fe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -117,7 +117,7 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, origins=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, **kwds) .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) :members: handshake @@ -127,7 +127,7 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, origin=None, **kwds) + .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, **kwds) .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) :members: handshake @@ -184,6 +184,13 @@ Utilities Changelog --------- +2.4 +... + +* Supported non-default event loop. +* Added `loop` argument to :func:`~websockets.client.connect` and + :func:`~websockets.server.serve`. + 2.3 ... diff --git a/websockets/client.py b/websockets/client.py index 107401549..e0d5a6a3f 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -68,7 +68,7 @@ def handshake(self, wsuri, origin=None): @asyncio.coroutine def connect(uri, *, - klass=WebSocketClientProtocol, origin=None, **kwds): + loop=None, klass=WebSocketClientProtocol, origin=None, **kwds): """ This coroutine connects to a WebSocket server. @@ -91,6 +91,9 @@ def connect(uri, *, Connection" in RFC 6455, except for the requirement that "there MUST be no more than one connection in a CONNECTING state." """ + if loop is None: + loop = asyncio.get_event_loop() + wsuri = parse_uri(uri) if wsuri.secure: kwds.setdefault('ssl', True) @@ -98,7 +101,8 @@ def connect(uri, *, raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") factory = lambda: klass(host=wsuri.host, port=wsuri.port, secure=wsuri.secure) - transport, protocol = yield from asyncio.get_event_loop().create_connection( + + transport, protocol = yield from loop.create_connection( factory, wsuri.host, wsuri.port, **kwds) try: diff --git a/websockets/protocol.py b/websockets/protocol.py index 1c73f3f93..92f838006 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -71,16 +71,17 @@ def __init__(self, *, self.timeout = timeout self.max_size = max_size - super().__init__(asyncio.StreamReader(), self.client_connected, loop) + stream_reader = asyncio.StreamReader(loop=loop) + super().__init__(stream_reader, self.client_connected, loop) self.close_code = None self.close_reason = '' # Futures tracking steps in the connection's lifecycle. - self.opening_handshake = asyncio.Future() - self.closing_handshake = asyncio.Future() - self.connection_failed = asyncio.Future() - self.connection_closed = asyncio.Future() + self.opening_handshake = asyncio.Future(loop=loop) + self.closing_handshake = asyncio.Future(loop=loop) + self.connection_failed = asyncio.Future(loop=loop) + self.connection_closed = asyncio.Future(loop=loop) # Queue of received messages. self.messages = Queue() @@ -89,7 +90,7 @@ def __init__(self, *, self.pings = collections.OrderedDict() # Task managing the connection. - self.worker = asyncio.async(self.run()) + self.worker = asyncio.async(self.run(), loop=loop) # In a subclass implementing the opening handshake, the state will be # CONNECTING at this point. @@ -132,7 +133,8 @@ def close(self, code=1000, reason=''): # If the connection doesn't terminate within the timeout, break out of # the worker loop. try: - yield from asyncio.wait_for(self.worker, timeout=self.timeout) + yield from asyncio.wait_for( + self.worker, self.timeout, loop=self._loop) except asyncio.TimeoutError: self.worker.cancel() @@ -158,10 +160,10 @@ def recv(self): pass # Wait for a message until the connection is closed - next_message = asyncio.async(self.messages.get()) + next_message = asyncio.async(self.messages.get(), loop=self._loop) done, pending = yield from asyncio.wait( [next_message, self.worker], - return_when=asyncio.FIRST_COMPLETED) + loop=self._loop, return_when=asyncio.FIRST_COMPLETED) if next_message in done: return next_message.result() else: @@ -204,7 +206,7 @@ def ping(self, data=None): while data is None or data in self.pings: data = struct.pack('!I', random.getrandbits(32)) - self.pings[data] = asyncio.Future() + self.pings[data] = asyncio.Future(loop=self._loop) yield from self.write_frame(OP_PING, data) return self.pings[data] @@ -371,8 +373,8 @@ def close_connection(self): if self.is_client: try: - yield from asyncio.wait_for(self.connection_closed, - timeout=self.timeout) + yield from asyncio.wait_for( + self.connection_closed, self.timeout, loop=self._loop) except (asyncio.CancelledError, asyncio.TimeoutError): pass @@ -390,8 +392,8 @@ def close_connection(self): self.writer.close() try: - yield from asyncio.wait_for(self.connection_closed, - timeout=self.timeout) + yield from asyncio.wait_for( + self.connection_closed, self.timeout, loop=self._loop) except (asyncio.CancelledError, asyncio.TimeoutError): pass diff --git a/websockets/server.py b/websockets/server.py index a1ae0a222..8358c13fe 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -37,7 +37,7 @@ def __init__(self, ws_handler, *, origins=None, **kwargs): def connection_made(self, transport): super().connection_made(transport) - asyncio.async(self.handler()) + asyncio.async(self.handler(), loop=self._loop) @asyncio.coroutine def handler(self): @@ -125,7 +125,7 @@ def handshake(self, origins=None): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - klass=WebSocketServerProtocol, origins=None, **kwds): + loop=None, klass=WebSocketServerProtocol, origins=None, **kwds): """ This coroutine creates a WebSocket server. @@ -155,8 +155,11 @@ def serve(ws_handler, host=None, port=None, *, logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) """ + if loop is None: + loop = asyncio.get_event_loop() + secure = kwds.get('ssl') is not None factory = lambda: klass(ws_handler, host=host, port=port, secure=secure, origins=origins) - return (yield from asyncio.get_event_loop().create_server( - factory, host, port, **kwds)) + + return (yield from loop.create_server(factory, host, port, **kwds)) From 0f1697fcfde2656b8a2b6a2f100dd99b7ec70bb1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Jan 2015 19:00:24 +0100 Subject: [PATCH 0048/1539] Add support for subprotocols. Fix #38. --- docs/index.rst | 12 ++--- websockets/client.py | 23 ++++++++-- websockets/protocol.py | 5 +++ websockets/server.py | 45 ++++++++++++++----- websockets/test_client_server.py | 75 ++++++++++++++++++++++++++++---- 5 files changed, 131 insertions(+), 29 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 51598f6fe..37bbf5ebb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -117,17 +117,17 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, **kwds) .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - :members: handshake + :members: handshake, select_subprotocol Client ...... .. automodule:: websockets.client - .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, **kwds) + .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, subprotocols=None, **kwds) .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) :members: handshake @@ -187,6 +187,7 @@ Changelog 2.4 ... +* Added support for subprotocols. * Supported non-default event loop. * Added `loop` argument to :func:`~websockets.client.connect` and :func:`~websockets.server.serve`. @@ -227,10 +228,9 @@ Changelog Limitations ----------- -Subprotocols_ and Extensions_ aren't implemented. Few subprotocols and no -extensions are registered_ at the time of writing. +Extensions_ aren't implemented. No extensions are registered_ at the time of +writing. -.. _Subprotocols: http://tools.ietf.org/html/rfc6455#section-1.9 .. _Extensions: http://tools.ietf.org/html/rfc6455#section-9 .. _registered: http://www.iana.org/assignments/websocket/websocket.xml diff --git a/websockets/client.py b/websockets/client.py index e0d5a6a3f..3361dfae5 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -25,11 +25,14 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): state = 'CONNECTING' @asyncio.coroutine - def handshake(self, wsuri, origin=None): + def handshake(self, wsuri, origin=None, subprotocols=None): """ Perform the client side of the opening handshake. If provided, ``origin`` sets the HTTP Origin header. + + If provided, ``subprotocols`` is a list of supported subprotocols, in + order of decreasing preference. """ headers = [] set_header = lambda k, v: headers.append((k, v)) @@ -39,6 +42,8 @@ def handshake(self, wsuri, origin=None): set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port)) if origin is not None: set_header('Origin', origin) + if subprotocols is not None: + set_header('Sec-WebSocket-Protocol', ', '.join(subprotocols)) set_header('User-Agent', USER_AGENT) key = build_request(set_header) self.raw_request_headers = headers @@ -62,17 +67,26 @@ def handshake(self, wsuri, origin=None): get_header = lambda k: headers.get(k, '') check_response(get_header, key) + self.subprotocol = headers.get('Sec-WebSocket-Protocol', None) + if (self.subprotocol is not None + and self.subprotocol not in subprotocols): + raise InvalidHandshake( + "Unknown subprotocol: {}".format(self.subprotocol)) + self.state = 'OPEN' self.opening_handshake.set_result(True) @asyncio.coroutine def connect(uri, *, - loop=None, klass=WebSocketClientProtocol, origin=None, **kwds): + loop=None, klass=WebSocketClientProtocol, origin=None, + subprotocols=None, **kwds): """ This coroutine connects to a WebSocket server. - It accepts an ``origin`` keyword argument to set the Origin HTTP header. + It accepts an ``origin`` keyword argument to set the Origin HTTP header + and a ``subprotocols`` keyword argument to provide a list of supported + subprotocols. It's a thin wrapper around the event loop's `create_connection` method. Extra keyword arguments are passed to `create_server`. @@ -106,7 +120,8 @@ def connect(uri, *, factory, wsuri.host, wsuri.port, **kwds) try: - yield from protocol.handshake(wsuri, origin=origin) + yield from protocol.handshake( + wsuri, origin=origin, subprotocols=subprotocols) except Exception: protocol.writer.close() raise diff --git a/websockets/protocol.py b/websockets/protocol.py index 92f838006..78795dc14 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -50,6 +50,9 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): message larger than the maximum size is received, :meth:`recv()` will return ``None`` and the connection will be closed with status code 1009. + Once the handshake is complete, if a subprotocol was negotiated, it's + available in the :attr:`subprotocol` attribute. + Once the connection is closed, the status code is available in the :attr:`close_code` attribute and the reason in :attr:`close_reason`. """ @@ -74,6 +77,8 @@ def __init__(self, *, stream_reader = asyncio.StreamReader(loop=loop) super().__init__(stream_reader, self.client_connected, loop) + self.subprotocol = None + self.close_code = None self.close_reason = '' diff --git a/websockets/server.py b/websockets/server.py index 8358c13fe..b971ac173 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -30,10 +30,12 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): state = 'CONNECTING' - def __init__(self, ws_handler, *, origins=None, **kwargs): + def __init__(self, ws_handler, *, + origins=None, subprotocols=None, **kwds): self.ws_handler = ws_handler self.origins = origins - super().__init__(**kwargs) + self.subprotocols = subprotocols + super().__init__(**kwds) def connection_made(self, transport): super().connection_made(transport) @@ -46,7 +48,8 @@ def handler(self): try: try: - path = yield from self.handshake(origins=self.origins) + path = yield from self.handshake( + origins=self.origins, subprotocols=self.subprotocols) except Exception as exc: logger.info("Exception in opening handshake: {}".format(exc)) if isinstance(exc, InvalidHandshake): @@ -78,7 +81,7 @@ def handler(self): pass @asyncio.coroutine - def handshake(self, origins=None): + def handshake(self, origins=None, subprotocols=None): """ Perform the server side of the opening handshake. @@ -97,15 +100,23 @@ def handshake(self, origins=None): get_header = lambda k: headers.get(k, '') key = check_request(get_header) - # Check origin in request. if origins is not None: origin = get_header('Origin') - if not set(origin.split() or ('',))<= set(origins): + if not set(origin.split() or ['']) <= set(origins): raise InvalidHandshake("Bad origin: {}".format(origin)) + if subprotocols is not None: + protocol = get_header('Sec-WebSocket-Protocol') + if protocol: + client_subprotocols = [p.strip() for p in protocol.split(',')] + self.subprotocol = self.select_subprotocol( + client_subprotocols, subprotocols) + headers = [] set_header = lambda k, v: headers.append((k, v)) set_header('Server', USER_AGENT) + if self.subprotocol: + set_header('Sec-WebSocket-Protocol', self.subprotocol) build_response(set_header, key) self.raw_response_headers = headers @@ -122,10 +133,21 @@ def handshake(self, origins=None): return path + def select_subprotocol(self, client_protos, server_protos): + """ + Pick a subprotocol among those offered by the client. + """ + common_protos = set(client_protos) & set(server_protos) + if not common_protos: + return None + priority = lambda p: client_protos.index(p) + server_protos.index(p) + return sorted(common_protos, key=priority)[0] + @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - loop=None, klass=WebSocketServerProtocol, origins=None, **kwds): + loop=None, klass=WebSocketServerProtocol, origins=None, + subprotocols=None, **kwds): """ This coroutine creates a WebSocket server. @@ -136,7 +158,8 @@ def serve(ws_handler, host=None, port=None, *, `ws_handler` is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`~websockets.server.WebSocketServerProtocol` and the request URI. If provided, `origin` is a list of acceptable Origin HTTP - headers. Include ``''`` if the lack of an origin is acceptable. + headers. Include ``''`` if the lack of an origin is acceptable. If + provided, `subprotocols` is a list of supported subprotocols. It returns a `Server` object with a `close` method to stop the server. @@ -159,7 +182,7 @@ def serve(ws_handler, host=None, port=None, *, loop = asyncio.get_event_loop() secure = kwds.get('ssl') is not None - factory = lambda: klass(ws_handler, - host=host, port=port, secure=secure, origins=origins) - + factory = lambda: klass( + ws_handler, host=host, port=port, secure=secure, + origins=origins, subprotocols=subprotocols) return (yield from loop.create_server(factory, host, port, **kwds)) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 4918644c1..a0b99c888 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -25,6 +25,8 @@ def handler(ws, path): elif path == '/raw_headers': yield from ws.send(repr(ws.raw_request_headers)) yield from ws.send(repr(ws.raw_response_headers)) + elif path == '/subprotocol': + yield from ws.send(repr(ws.subprotocol)) else: yield from ws.send((yield from ws.recv())) @@ -42,12 +44,12 @@ def tearDown(self): self.stop_server() self.loop.close() - def start_server(self): - server = serve(handler, 'localhost', 8642) + def start_server(self, **kwds): + server = serve(handler, 'localhost', 8642, **kwds) self.server = self.loop.run_until_complete(server) - def start_client(self, path=''): - client = connect('ws://localhost:8642/' + path) + def start_client(self, path='', **kwds): + client = connect('ws://localhost:8642/' + path, **kwds) self.client = self.loop.run_until_complete(client) def stop_client(self): @@ -85,6 +87,61 @@ def test_protocol_raw_headers(self): self.assertEqual(server_resp, repr(client_resp)) self.stop_client() + def test_no_subprotocol(self): + self.start_client('subprotocol') + server_subprotocol = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() + + def test_subprotocol_found(self): + self.stop_server() + self.start_server(subprotocols=['superchat', 'chat']) + + self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) + server_subprotocol = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_subprotocol, repr('chat')) + self.assertEqual(self.client.subprotocol, 'chat') + self.stop_client() + + def test_subprotocol_not_found(self): + self.stop_server() + self.start_server(subprotocols=['superchat']) + + self.start_client('subprotocol', subprotocols=['otherchat']) + server_subprotocol = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() + + def test_subprotocol_not_offered(self): + self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) + server_subprotocol = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() + + def test_subprotocol_not_requested(self): + self.stop_server() + self.start_server(subprotocols=['superchat', 'chat']) + + self.start_client('subprotocol') + server_subprotocol = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() + + @patch.object(WebSocketServerProtocol, 'select_subprotocol', autospec=True) + def test_subprotocol_error(self, _select_subprotocol): + _select_subprotocol.return_value = 'superchat' + + self.stop_server() + self.start_server(subprotocols=['superchat']) + + with self.assertRaises(InvalidHandshake): + self.start_client('subprotocol', subprotocols=['otherchat']) + print(_select_subprotocol.call_args_list) + @patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") @@ -181,12 +238,14 @@ def client_context(self): ssl_context.verify_mode = ssl.CERT_REQUIRED return ssl_context - def start_server(self): - server = serve(handler, 'localhost', 8642, ssl=self.server_context) + def start_server(self, *args, **kwds): + kwds['ssl'] = self.server_context + server = serve(handler, 'localhost', 8642, **kwds) self.server = self.loop.run_until_complete(server) - def start_client(self, path=''): - client = connect('wss://localhost:8642/' + path, ssl=self.client_context) + def start_client(self, path='', **kwds): + kwds['ssl'] = self.client_context + client = connect('wss://localhost:8642/' + path, **kwds) self.client = self.loop.run_until_complete(client) def test_ws_uri_is_rejected(self): From 2893f08c19bddc9912c58821bfbd429e9b98716a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Jan 2015 21:10:38 +0100 Subject: [PATCH 0049/1539] Automatically bump copyright year in docs. --- docs/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index cfff51116..bb69e43fe 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,7 +11,7 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys, os +import sys, os, datetime # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -41,7 +41,7 @@ # General information about the project. project = 'websockets' -copyright = '2013-2014, Aymeric Augustin' +copyright = '2013-{}, Aymeric Augustin'.format(datetime.date.today().year) # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the From 9b81b7e5da8cc37cd7644ad58ac9648070703fe8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Jan 2015 21:10:50 +0100 Subject: [PATCH 0050/1539] Bump version number. --- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index bb69e43fe..f15b50f92 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '2.3' +version = '2.4' # The full version, including alpha/beta/rc tags. -release = '2.3.0' +release = '2.4' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 5a1c16e2d..d3300deb3 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '2.3' +version = '2.4' From 7d8191699a6d647c1b45e3e11681c5987437e5b5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Jan 2015 21:33:18 +0100 Subject: [PATCH 0051/1539] Improve documentation. --- LICENSE | 2 +- docs/index.rst | 2 +- websockets/client.py | 10 ++++++---- websockets/exceptions.py | 9 +++++---- websockets/framing.py | 15 +++++++++------ websockets/handshake.py | 7 ++++--- websockets/http.py | 3 +-- websockets/protocol.py | 11 ++++++----- websockets/server.py | 20 +++++++++++--------- websockets/uri.py | 3 ++- 10 files changed, 46 insertions(+), 36 deletions(-) diff --git a/LICENSE b/LICENSE index fcf63659c..f46af27b5 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013-2014 Aymeric Augustin and contributors. +Copyright (c) 2013-2015 Aymeric Augustin and contributors. All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/docs/index.rst b/docs/index.rst index 37bbf5ebb..061963d03 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -119,7 +119,7 @@ Server .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, **kwds) - .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, subprotocols=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) :members: handshake, select_subprotocol Client diff --git a/websockets/client.py b/websockets/client.py index 3361dfae5..b9151bc42 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -15,7 +15,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ - Complete WebSocket client implementation as an asyncio protocol. + Complete WebSocket client implementation as an :mod:`asyncio` protocol. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -88,8 +88,9 @@ def connect(uri, *, and a ``subprotocols`` keyword argument to provide a list of supported subprotocols. - It's a thin wrapper around the event loop's `create_connection` method. - Extra keyword arguments are passed to `create_server`. + It's a thin wrapper around the event loop's + :meth:`~asyncio.BaseEventLoop.create_connection` method. Extra keyword + arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. It returns a :class:`~websockets.client.WebSocketClientProtocol` which can then be used to send and receive messages. @@ -103,7 +104,8 @@ def connect(uri, *, :func:`connect` implements the sequence called "Establish a WebSocket Connection" in RFC 6455, except for the requirement that "there MUST be no - more than one connection in a CONNECTING state." + more than one connection in a CONNECTING state" because it cannot be + enforced at that level. """ if loop is None: loop = asyncio.get_event_loop() diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 20cc0ba2f..a1e130d4e 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,4 +1,7 @@ -__all__ = ['InvalidHandshake', 'InvalidState', 'InvalidURI'] +__all__ = [ + 'InvalidHandshake', 'InvalidState', 'InvalidURI', + 'PayloadTooBig', 'WebSocketProtocolError', +] class InvalidHandshake(Exception): @@ -18,6 +21,4 @@ class PayloadTooBig(Exception): class WebSocketProtocolError(Exception): - # Internal exception raised when the other end breaks the protocol. - # It's private because it shouldn't leak outside of WebSocketCommonProtocol. - pass + """Internal exception raised when the remote side breaks the protocol.""" diff --git a/websockets/framing.py b/websockets/framing.py index b83b45466..a6d45b366 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -58,10 +58,11 @@ def read_frame(reader, mask, *, max_size=None): whether the read happens on the server side. If `max_size` is set and the payload exceeds this size in bytes, - :exc:`PayloadTooBig` is raised. + :exc:`~websockets.exceptions.PayloadTooBig` is raised. This function validates the frame before returning it and raises - :exc:`WebSocketProtocolError` if it contains incorrect values. + :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains + incorrect values. """ # Read the header data = yield from reader(2) @@ -107,7 +108,8 @@ def write_frame(frame, writer, mask): whether the write happens on the client side. This function validates the frame before sending it and raises - :exc:`WebSocketProtocolError` if it contains incorrect values. + :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains + incorrect values. """ check_frame(frame) @@ -138,7 +140,8 @@ def write_frame(frame, writer, mask): def check_frame(frame): """ - Raise :exc:`WebSocketProtocolError` if the frame contains incorrect values. + Raise :exc:`~websockets.exceptions.WebSocketProtocolError` if the frame + contains incorrect values. """ if frame.opcode in (OP_CONT, OP_TEXT, OP_BINARY): return @@ -158,8 +161,8 @@ def parse_close(data): Return `(code, reason)` when `code` is an :class:`int` and `reason` a :class:`str`. - Raise :exc:`WebSocketProtocolError` or :exc:`UnicodeDecodeError` if the - data is invalid. + Raise :exc:`~websockets.exceptions.WebSocketProtocolError` or + :exc:`UnicodeDecodeError` if the data is invalid. """ length = len(data) if length == 0: diff --git a/websockets/handshake.py b/websockets/handshake.py index 61dc08e0b..a400147f6 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -69,8 +69,8 @@ def check_request(get_header): If the handshake is valid, this function returns the `key` which must be passed to :func:`build_response`. - Otherwise, it raises an :exc:`InvalidHandshake` exception and the server - must return an error, usually 400 Bad Request. + Otherwise, it raises an :exc:`~websockets.exceptions.InvalidHandshake` + exception and the server must return an error, usually 400 Bad Request. This function doesn't verify that the request is an HTTP/1.1 or higher GET request and doesn't perform Host and Origin checks. These controls are @@ -108,7 +108,8 @@ def check_response(get_header, key): If the handshake is valid, this function returns ``None``. - Otherwise, it raises an :exc:`InvalidHandshake` exception. + Otherwise, it raises an :exc:`~websockets.exceptions.InvalidHandshake` + exception. This function doesn't verify that the response is an HTTP/1.1 or higher response with a 101 status code. These controls are the responsibility of diff --git a/websockets/http.py b/websockets/http.py index 6262bf017..93f151dc6 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -1,7 +1,6 @@ """ The :mod:`websockets.http` module provides HTTP parsing functions. They're -merely adequate for the WebSocket handshake messages. They're used by the -sample client and servers. +merely adequate for the WebSocket handshake messages. These functions cannot be imported from :mod:`websockets`; they must be imported from :mod:`websockets.http`. diff --git a/websockets/protocol.py b/websockets/protocol.py index 78795dc14..7e4a94eb9 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -123,8 +123,8 @@ def close(self, code=1000, reason=''): It waits for the other end to complete the handshake. It doesn't do anything once the connection is closed. - It's usually safe to wrap this coroutine in `asyncio.async()` since - errors during connection termination aren't particularly useful. + It's usually safe to wrap this coroutine in :func:`~asyncio.async` + since errors during connection termination aren't particularly useful. The `code` must be an :class:`int` and the `reason` a :class:`str`. """ @@ -183,7 +183,7 @@ def send(self, data): frame. It raises a :exc:`TypeError` for other inputs and - :exc:`InvalidState` once the connection is closed. + :exc:`~websockets.exceptions.InvalidState` once the connection is closed. """ if isinstance(data, str): opcode = 1 @@ -199,8 +199,9 @@ def ping(self, data=None): """ This coroutine sends a ping. - It returns a Future which will be completed when the corresponding - pong is received and which you may ignore if you don't want to wait. + It returns a :class:`~asyncio.Future` which will be completed when the + corresponding pong is received and which you may ignore if you don't + want to wait. A ping may serve as a keepalive. """ diff --git a/websockets/server.py b/websockets/server.py index b971ac173..26dec8d6e 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -19,7 +19,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ - Complete WebSocket server implementation as an asyncio protocol. + Complete WebSocket server implementation as an :mod:`asyncio` protocol. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -85,8 +85,9 @@ def handshake(self, origins=None, subprotocols=None): """ Perform the server side of the opening handshake. - If provided, ``origins`` is a list of acceptable HTTP Origin values. - Include ``''`` if the lack of an origin is acceptable. + If provided, `origins` is a list of acceptable HTTP Origin values. + Include ``''`` if the lack of an origin is acceptable. If provided, + `subprotocols` is a list of supported subprotocols. Return the URI of the request. """ @@ -151,17 +152,18 @@ def serve(ws_handler, host=None, port=None, *, """ This coroutine creates a WebSocket server. - It's a thin wrapper around the event loop's `create_server` method. - `host`, `port` as well as extra keyword arguments are passed to - `create_server`. + It's a thin wrapper around the event loop's + :meth:`~asyncio.BaseEventLoop.create_server` method. `host`, `port` as + well as extra keyword arguments are passed to + :meth:`~asyncio.BaseEventLoop.create_server`. `ws_handler` is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`~websockets.server.WebSocketServerProtocol` and - the request URI. If provided, `origin` is a list of acceptable Origin HTTP - headers. Include ``''`` if the lack of an origin is acceptable. If + the request URI. If provided, `origins` is a list of acceptable Origin + HTTP headers. Include ``''`` if the lack of an origin is acceptable. If provided, `subprotocols` is a list of supported subprotocols. - It returns a `Server` object with a `close` method to stop the server. + `serve` yields a `Server` object with a `close` method to stop the server. Whenever a client connects, the server accepts the connection, creates a :class:`~websockets.server.WebSocketServerProtocol`, performs the opening diff --git a/websockets/uri.py b/websockets/uri.py index b8d993f41..03a977424 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -24,7 +24,8 @@ def parse_uri(uri): If the URI is valid, it returns a namedtuple `(secure, host, port, resource_name)` - Otherwise, it raises an :exc:`InvalidURI` exception. + Otherwise, it raises an :exc:`~websockets.exceptions.InvalidURI` + exception. """ uri = urllib.parse.urlparse(uri) try: From e65f032116a1ee95f8e80b622362ee813f5158a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20HUBSCHER?= Date: Fri, 3 Apr 2015 17:15:14 +0200 Subject: [PATCH 0052/1539] =?UTF-8?q?Fix=20#50=20=E2=80=94=20Make=20sure?= =?UTF-8?q?=20protocol.recv()=20can=20be=20Cancelled=20without=20loosing?= =?UTF-8?q?=20messages.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- websockets/protocol.py | 8 +++++++- websockets/test_protocol.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 7e4a94eb9..400b7bd4d 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -166,9 +166,15 @@ def recv(self): # Wait for a message until the connection is closed next_message = asyncio.async(self.messages.get(), loop=self._loop) - done, pending = yield from asyncio.wait( + try: + done, pending = yield from asyncio.wait( [next_message, self.worker], loop=self._loop, return_when=asyncio.FIRST_COMPLETED) + except asyncio.CancelledError: + # Handle the Task.cancel() + next_message.cancel() + raise + if next_message in done: return next_message.result() else: diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 0e57d49e3..c462b9948 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -146,6 +146,18 @@ def test_recv_on_closed_connection(self): self.protocol.connection_lost(None) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + def test_recv_cancelled(self): + try: + data = self.loop.run_until_complete( + asyncio.wait_for(self.protocol.recv(), 1, loop=self.loop) + ) + except asyncio.TimeoutError: + self.feed(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + data = self.loop.run_until_complete( + asyncio.wait_for(self.protocol.recv(), 1, loop=self.loop) + ) # We use wait_for here to make sure the test fail and don't hang + self.assertEqual(data, 'café') + def test_send_text(self): self.loop.run_until_complete(self.protocol.send('café')) self.assertFrameSent(True, OP_TEXT, 'café'.encode('utf-8')) From 9469965a6552c038232ae987e473199451295730 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 14 Jul 2015 21:34:50 +0200 Subject: [PATCH 0053/1539] The closing handshake can be initiated by the client. I had misunderstood the RFC. It's the TCP closing handshake that should be initiated by the server. Fix #53. --- compliance/test_client.py | 4 ++-- example/client.py | 1 + websockets/protocol.py | 2 -- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/compliance/test_client.py b/compliance/test_client.py index f72f7de49..ed3915364 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -34,7 +34,7 @@ def get_case_count(server): uri = server + '/getCaseCount' ws = yield from websockets.connect(uri) msg = yield from ws.recv() - yield from ws.worker + yield from ws.close() return json.loads(msg) @@ -49,7 +49,7 @@ def run_case(server, case, agent): def update_reports(server, agent): uri = server + '/updateReports?agent={}'.format(agent) ws = yield from websockets.connect(uri) - yield from ws.worker + yield from ws.close() @asyncio.coroutine diff --git a/example/client.py b/example/client.py index 6bc960227..a75992975 100755 --- a/example/client.py +++ b/example/client.py @@ -11,5 +11,6 @@ def hello(): print("> {}".format(name)) greeting = yield from websocket.recv() print("< {}".format(greeting)) + yield from websocket.close() asyncio.get_event_loop().run_until_complete(hello()) diff --git a/websockets/protocol.py b/websockets/protocol.py index 7e4a94eb9..92c02a54d 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -118,8 +118,6 @@ def close(self, code=1000, reason=''): """ This coroutine performs the closing handshake. - This is the expected way to terminate a connection on the server side. - It waits for the other end to complete the handshake. It doesn't do anything once the connection is closed. From 72bfb7af489001a50c78f9e2f79ec9748d30914f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 14 Jul 2015 21:36:37 +0200 Subject: [PATCH 0054/1539] Minor fixes. --- compliance/README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index e85b7e1ff..277804ad7 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -21,7 +21,7 @@ Run the first command in a shell. Run the second command in another shell. It should take about one minute to complete. Then kill the first one with Ctrl-C. The test client or server shouldn't display any exceptions. The results are -stored in reports/index.html. +stored in reports/clients/index.html. Note that the Autobahn software only supports Python 2, while websockets only supports Python 3; you need two different environments. @@ -29,7 +29,7 @@ supports Python 3; you need two different environments. Conformance notes ----------------- -Test cases 6.4.3, and 6.4.4 are actually more strict than the RFC. Given its +Test cases 6.4.3 and 6.4.4 are actually more strict than the RFC. Given its implementation, ``websockets`` gets a "Non-Strict". Test cases 12.* and 13.* don't run because ``websockets`` doesn't implement From 8c3e856390a2bc27e3e97c8367e036c166a84e6a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 14 Jul 2015 21:36:53 +0200 Subject: [PATCH 0055/1539] Add a cheat sheet to the documentation. Also improve markup while I'm in the area. --- docs/index.rst | 63 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 061963d03..d88748769 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -87,6 +87,53 @@ Here's a corresponding client example. Of course, you can combine the two patterns shown above to read and write messages on the same connection. +That's really all you have to know! ``websockets`` manages the connection +under the hood so you don't have to. + +Cheat sheet +----------- + +Server +...... + +* Write a coroutine that handles a single connection. It receives a websocket + protocol instance and the URI path in argument. + + * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and + :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and + send messages at any time. + + * You may :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` or + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish + but it isn't needed in general. + +* Create a server with :func:`~websockets.server.serve` which is similar to + asyncio's :meth:`~asyncio.BaseEventLoop.create_server`. + + * The server takes care of establishing connections, then lets the handler + execute the application logic, and finally closes the connection after + the handler returns. + + * You may subclass :class:`~websockets.server.WebSocketServerProtocol` if + you have an advanced use case. + +Client +...... + +* Create a server with :func:`~websockets.client.connect` which is similar to + asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. + +* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and + :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and + send messages at any time. + +* You may :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` or + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish but it + isn't needed in general. + +* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.close` to terminate + the connection. + Design ------ @@ -95,6 +142,7 @@ the examples above. These functions are built on top of low-level APIs reflecting the two phases of the WebSocket protocol: 1. An opening handshake, in the form of an HTTP Upgrade request; + 2. Data transfer, as framed messages, ending with a closing handshake. The first phase is designed to integrate with existing HTTP software. @@ -120,7 +168,9 @@ Server .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, **kwds) .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, subprotocols=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - :members: handshake, select_subprotocol + + .. automethod:: handshake(origins=None, subprotocols=None) + .. automethod:: select_subprotocol(client_protos, server_protos) Client ...... @@ -130,7 +180,8 @@ Client .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, subprotocols=None, **kwds) .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - :members: handshake + + .. automethod:: handshake(wsuri, origin=None, subprotocols=None) Shared ...... @@ -139,7 +190,7 @@ Shared .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - .. autoattribute:: open + .. autoattribute:: open() .. automethod:: close(code=1000, reason='') .. automethod:: recv() @@ -188,7 +239,9 @@ Changelog ... * Added support for subprotocols. + * Supported non-default event loop. + * Added `loop` argument to :func:`~websockets.client.connect` and :func:`~websockets.server.serve`. @@ -206,6 +259,7 @@ Changelog ... * Added `host`, `port` and `secure` attributes on protocols. + * Added support for providing and checking Origin_. .. _Origin: https://tools.ietf.org/html/rfc6455#section-10.2 @@ -218,7 +272,8 @@ Changelog :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` are coroutines. They used to be regular functions. -* Add flow control. + +* Added flow control. 1.0 ... From e8a12dfc54ed23904a1926ef196523d1b07837c7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 14 Jul 2015 21:43:28 +0200 Subject: [PATCH 0056/1539] Add tox configuration file. --- .gitignore | 3 ++- tox.ini | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 tox.ini diff --git a/.gitignore b/.gitignore index e78549946..f453c4914 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.pyc .coverage .DS_Store +.tox build/ compliance/reports/ dist/ @@ -9,4 +10,4 @@ htmlcov/ MANIFEST README README.html -websockets.egg-info/ \ No newline at end of file +websockets.egg-info/ diff --git a/tox.ini b/tox.ini new file mode 100644 index 000000000..f9aee3b26 --- /dev/null +++ b/tox.ini @@ -0,0 +1,7 @@ +[tox] +envlist = py33,py34 + +[testenv] +deps = + py33: asyncio +commands = python -m unittest From 0a815f860550ac89eefd27d626d5b7d02543c48d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 14 Jul 2015 22:02:39 +0200 Subject: [PATCH 0057/1539] Document how to multiplex reads and writes. Fix #48. --- docs/index.rst | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index d88748769..e0b4f8f57 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -85,7 +85,35 @@ Here's a corresponding client example. to write and return otherwise. Of course, you can combine the two patterns shown above to read and write - messages on the same connection. + messages on the same connection:: + + @asyncio.coroutine + def handler(websocket, path): + while True: + listener_task = asyncio.ensure_future(websocket.recv()) + producer_task = asyncio.ensure_future(producer()) + done, pending = yield from asyncio.wait( + [listener_task, producer_task], + return_when=asyncio.FIRST_COMPLETED) + + if listener_task in done: + message = listener_task.result() + if message is None: + break + yield from consumer(message) + else: + listener_task.cancel() + + if producer_task in done: + message = producer_task.result() + if not websocket.open: + break + yield from websocket.send(message) + else: + producer_task.cancel() + + (This code looks convoluted. If you know a more straightforward solution, + please let me know about it!) That's really all you have to know! ``websockets`` manages the connection under the hood so you don't have to. From 6ce37977296199aa888881b0bfd75ed0f67a2502 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 14 Jul 2015 22:39:38 +0200 Subject: [PATCH 0058/1539] Allow customizing request or response HTTP headers. Thanks @knutae for providing an initial patch. Fix #47. --- docs/index.rst | 10 +++---- websockets/client.py | 37 +++++++++++++++++------- websockets/server.py | 48 +++++++++++++++++++++++--------- websockets/test_client_server.py | 30 ++++++++++++++++++++ 4 files changed, 97 insertions(+), 28 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index e0b4f8f57..2b2e370bc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -193,11 +193,11 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, extra_headers=None, **kwds) - .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, subprotocols=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, subprotocols=None, extra_headers=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - .. automethod:: handshake(origins=None, subprotocols=None) + .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) .. automethod:: select_subprotocol(client_protos, server_protos) Client @@ -205,11 +205,11 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, subprotocols=None, **kwds) + .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, subprotocols=None, extra_headers=None, **kwds) .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - .. automethod:: handshake(wsuri, origin=None, subprotocols=None) + .. automethod:: handshake(wsuri, origin=None, subprotocols=None, extra_headers=None) Shared ...... diff --git a/websockets/client.py b/websockets/client.py index b9151bc42..925d27031 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -5,6 +5,7 @@ __all__ = ['connect', 'WebSocketClientProtocol'] import asyncio +import collections from .exceptions import InvalidHandshake from .handshake import build_request, check_response @@ -25,14 +26,18 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): state = 'CONNECTING' @asyncio.coroutine - def handshake(self, wsuri, origin=None, subprotocols=None): + def handshake(self, wsuri, + origin=None, subprotocols=None, extra_headers=None): """ Perform the client side of the opening handshake. - If provided, ``origin`` sets the HTTP Origin header. + If provided, ``origin`` sets the Origin HTTP header. - If provided, ``subprotocols`` is a list of supported subprotocols, in + If provided, ``subprotocols`` is a list of supported subprotocols in order of decreasing preference. + + If provided, ``extra_headers`` sets additional HTTP request headers. + It must be a mapping or an iterable of (name, value) pairs. """ headers = [] set_header = lambda k, v: headers.append((k, v)) @@ -44,8 +49,14 @@ def handshake(self, wsuri, origin=None, subprotocols=None): set_header('Origin', origin) if subprotocols is not None: set_header('Sec-WebSocket-Protocol', ', '.join(subprotocols)) + if extra_headers is not None: + if isinstance(extra_headers, collections.abc.Mapping): + extra_headers = extra_headers.items() + for name, value in extra_headers: + set_header(name, value) set_header('User-Agent', USER_AGENT) key = build_request(set_header) + self.raw_request_headers = headers # Send handshake request. Since the URI and the headers only contain @@ -79,19 +90,24 @@ def handshake(self, wsuri, origin=None, subprotocols=None): @asyncio.coroutine def connect(uri, *, - loop=None, klass=WebSocketClientProtocol, origin=None, - subprotocols=None, **kwds): + loop=None, klass=WebSocketClientProtocol, + origin=None, subprotocols=None, extra_headers=None, + **kwds): """ This coroutine connects to a WebSocket server. - It accepts an ``origin`` keyword argument to set the Origin HTTP header - and a ``subprotocols`` keyword argument to provide a list of supported - subprotocols. - It's a thin wrapper around the event loop's :meth:`~asyncio.BaseEventLoop.create_connection` method. Extra keyword arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. + This coroutine accepts several optional arguments: + + * ``origin`` sets the Origin HTTP header + * ``subprotocols`` is a list of supported subprotocols in order of + decreasing preference + * ``extra_headers`` sets additional HTTP request headers – it can be a + mapping or an iterable of (name, value) pairs + It returns a :class:`~websockets.client.WebSocketClientProtocol` which can then be used to send and receive messages. @@ -123,7 +139,8 @@ def connect(uri, *, try: yield from protocol.handshake( - wsuri, origin=origin, subprotocols=subprotocols) + wsuri, origin=origin, subprotocols=subprotocols, + extra_headers=extra_headers) except Exception: protocol.writer.close() raise diff --git a/websockets/server.py b/websockets/server.py index 26dec8d6e..73ce8ff37 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -4,6 +4,7 @@ __all__ = ['serve', 'WebSocketServerProtocol'] +import collections import logging import asyncio @@ -31,10 +32,11 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): state = 'CONNECTING' def __init__(self, ws_handler, *, - origins=None, subprotocols=None, **kwds): + origins=None, subprotocols=None, extra_headers=None, **kwds): self.ws_handler = ws_handler self.origins = origins self.subprotocols = subprotocols + self.extra_headers = extra_headers super().__init__(**kwds) def connection_made(self, transport): @@ -49,7 +51,8 @@ def handler(self): try: path = yield from self.handshake( - origins=self.origins, subprotocols=self.subprotocols) + origins=self.origins, subprotocols=self.subprotocols, + extra_headers=self.extra_headers) except Exception as exc: logger.info("Exception in opening handshake: {}".format(exc)) if isinstance(exc, InvalidHandshake): @@ -81,13 +84,18 @@ def handler(self): pass @asyncio.coroutine - def handshake(self, origins=None, subprotocols=None): + def handshake(self, origins=None, subprotocols=None, extra_headers=None): """ Perform the server side of the opening handshake. - If provided, `origins` is a list of acceptable HTTP Origin values. - Include ``''`` if the lack of an origin is acceptable. If provided, - `subprotocols` is a list of supported subprotocols. + If provided, ``origins`` is a list of acceptable HTTP Origin values. + Include ``''`` if the lack of an origin is acceptable. + + If provided, ``subprotocols`` is a list of supported subprotocols in + order of decreasing preference. + + If provided, ``extra_headers`` sets additional HTTP response headers. + It must be a mapping or an iterable of (name, value) pairs. Return the URI of the request. """ @@ -118,6 +126,11 @@ def handshake(self, origins=None, subprotocols=None): set_header('Server', USER_AGENT) if self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) + if extra_headers is not None: + if isinstance(extra_headers, collections.abc.Mapping): + extra_headers = extra_headers.items() + for name, value in extra_headers: + set_header(name, value) build_response(set_header, key) self.raw_response_headers = headers @@ -147,8 +160,9 @@ def select_subprotocol(self, client_protos, server_protos): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - loop=None, klass=WebSocketServerProtocol, origins=None, - subprotocols=None, **kwds): + loop=None, klass=WebSocketServerProtocol, + origins=None, subprotocols=None, extra_headers=None, + **kwds): """ This coroutine creates a WebSocket server. @@ -157,11 +171,18 @@ def serve(ws_handler, host=None, port=None, *, well as extra keyword arguments are passed to :meth:`~asyncio.BaseEventLoop.create_server`. - `ws_handler` is the WebSocket handler. It must be a coroutine accepting + ``ws_handler`` is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`~websockets.server.WebSocketServerProtocol` and - the request URI. If provided, `origins` is a list of acceptable Origin - HTTP headers. Include ``''`` if the lack of an origin is acceptable. If - provided, `subprotocols` is a list of supported subprotocols. + the request URI. + + This coroutine accepts several optional arguments: + + * ``origins`` defines acceptable Origin HTTP headers — include + ``''`` if the lack of an origin is acceptable + * ``subprotocols`` is a list of supported subprotocols in order of + decreasing preference + * ``extra_headers`` sets additional HTTP response headers – it can be a + mapping or an iterable of (name, value) pairs `serve` yields a `Server` object with a `close` method to stop the server. @@ -186,5 +207,6 @@ def serve(ws_handler, host=None, port=None, *, secure = kwds.get('ssl') is not None factory = lambda: klass( ws_handler, host=host, port=port, secure=secure, - origins=origins, subprotocols=subprotocols) + origins=origins, subprotocols=subprotocols, + extra_headers=extra_headers) return (yield from loop.create_server(factory, host, port, **kwds)) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index a0b99c888..5b40414c6 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -87,6 +87,36 @@ def test_protocol_raw_headers(self): self.assertEqual(server_resp, repr(client_resp)) self.stop_client() + def test_protocol_custom_request_headers_dict(self): + self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) + req_headers = self.loop.run_until_complete(self.client.recv()) + _resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", req_headers) + + def test_protocol_custom_request_headers_list(self): + self.start_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) + req_headers = self.loop.run_until_complete(self.client.recv()) + _resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", req_headers) + + def test_protocol_custom_response_headers_dict(self): + self.stop_server() + self.start_server(extra_headers={'X-Spam': 'Eggs'}) + + self.start_client('raw_headers') + _req_headers = self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + + def test_protocol_custom_response_headers_list(self): + self.stop_server() + self.start_server(extra_headers=[('X-Spam', 'Eggs')]) + + self.start_client('raw_headers') + _req_headers = self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + def test_no_subprotocol(self): self.start_client('subprotocol') server_subprotocol = self.loop.run_until_complete(self.client.recv()) From a621764215c5abdd78808df7ec9bd1b3c265d39c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 14 Jul 2015 23:02:19 +0200 Subject: [PATCH 0059/1539] Documented the fix for #50. --- docs/index.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 2b2e370bc..be5de2f15 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -263,6 +263,16 @@ Utilities Changelog --------- +2.5 +... + +* Allowed customizing handshake request and response HTTP headers. + +* Improved documentation. + +* Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no + longer drops the next message. + 2.4 ... From 23878bc5b0ad85d6a2cf6e7f9c80876b99916f23 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Jul 2015 15:37:43 +0200 Subject: [PATCH 0060/1539] Import collections.abc for Python 3.4. Thanks @SzieberthAdam for the report and initial patch. Fix #62. --- websockets/client.py | 2 +- websockets/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 925d27031..f585a6873 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -5,7 +5,7 @@ __all__ = ['connect', 'WebSocketClientProtocol'] import asyncio -import collections +import collections.abc from .exceptions import InvalidHandshake from .handshake import build_request, check_response diff --git a/websockets/server.py b/websockets/server.py index 73ce8ff37..a35734aa7 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -4,7 +4,7 @@ __all__ = ['serve', 'WebSocketServerProtocol'] -import collections +import collections.abc import logging import asyncio From 8e406194ed412b06a6c88ec9824684a06a41d3c1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 15:58:38 +0200 Subject: [PATCH 0061/1539] Add a missing "yield from" in a test utility. Since the result of this method was always yielded from, this bug didn't have visible effects. --- websockets/test_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index c462b9948..b5a866882 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -42,7 +42,7 @@ def sent(self): self.transport.write.call_args_list = [] stream.feed_eof() if not stream.at_eof(): - return read_frame(stream.readexactly, self.protocol.is_client) + return (yield from read_frame(stream.readexactly, self.protocol.is_client)) @asyncio.coroutine def echo(self): From 6624dc8e0a6ca4cdaf8af2e65ba8bbc85589b96a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 16:00:29 +0200 Subject: [PATCH 0062/1539] Enable asyncio's debug tools in tests. --- Makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Makefile b/Makefile index 6f014b953..7f96bcdd7 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,6 @@ +export PYTHONASYNCIODEBUG=1 +export PYTHONWARNINGS=default + test: python -m unittest From fb58aa61c8f77e73adb7e591e7a65acbf8470499 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 15:29:06 +0200 Subject: [PATCH 0063/1539] Very minor cleanup. --- websockets/test_uri.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/websockets/test_uri.py b/websockets/test_uri.py index 4e30fb433..537062023 100644 --- a/websockets/test_uri.py +++ b/websockets/test_uri.py @@ -4,26 +4,28 @@ from .uri import * -VALID_URIS = ( +VALID_URIS = [ ('ws://localhost/', (False, 'localhost', 80, '/')), ('wss://localhost/', (True, 'localhost', 443, '/')), ('ws://localhost/path?query', (False, 'localhost', 80, '/path?query')), -) +] -INVALID_URIS = ( +INVALID_URIS = [ 'http://localhost/', 'https://localhost/', 'http://localhost/path#fragment' -) +] class URITests(unittest.TestCase): def test_success(self): for uri, parsed in VALID_URIS: + # wrap in `with self.subTest():` when dropping Python 3.3 self.assertEqual(parse_uri(uri), parsed) def test_error(self): for uri in INVALID_URIS: + # wrap in `with self.subTest():` when dropping Python 3.3 with self.assertRaises(InvalidURI): parse_uri(uri) From 381a3641ba997a33e19acb73721e324c96d01635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szieberth=20=C3=81d=C3=A1m?= Date: Thu, 16 Jul 2015 10:16:39 +0200 Subject: [PATCH 0064/1539] Pass event loop explicitly to all asyncio objects Fix #60. --- websockets/client.py | 3 +- websockets/protocol.py | 2 +- websockets/server.py | 2 +- websockets/test_client_server.py | 4 +- websockets/test_framing.py | 2 +- websockets/test_http.py | 2 +- websockets/test_protocol.py | 69 +++++++++++++++++--------------- 7 files changed, 45 insertions(+), 39 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index f585a6873..777106bed 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -132,7 +132,8 @@ def connect(uri, *, elif 'ssl' in kwds: raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") - factory = lambda: klass(host=wsuri.host, port=wsuri.port, secure=wsuri.secure) + factory = lambda: klass(host=wsuri.host, port=wsuri.port, + secure=wsuri.secure, loop=loop) transport, protocol = yield from loop.create_connection( factory, wsuri.host, wsuri.port, **kwds) diff --git a/websockets/protocol.py b/websockets/protocol.py index efe50c543..9300cd377 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -89,7 +89,7 @@ def __init__(self, *, self.connection_closed = asyncio.Future(loop=loop) # Queue of received messages. - self.messages = Queue() + self.messages = Queue(loop=loop) # Mapping of ping IDs to waiters, in chronological order. self.pings = collections.OrderedDict() diff --git a/websockets/server.py b/websockets/server.py index a35734aa7..6984a5d17 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -208,5 +208,5 @@ def serve(ws_handler, host=None, port=None, *, factory = lambda: klass( ws_handler, host=host, port=port, secure=secure, origins=origins, subprotocols=subprotocols, - extra_headers=extra_headers) + extra_headers=extra_headers, loop=loop) return (yield from loop.create_server(factory, host, port, **kwds)) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 5b40414c6..70b5c402b 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -188,7 +188,7 @@ def test_client_receives_malformed_response(self, _read_response): # Now the server believes the connection is open. Run the event loop # once to make it notice the connection was closed. Interesting hack. - self.loop.run_until_complete(asyncio.sleep(0)) + self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) @patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): @@ -221,7 +221,7 @@ def wrong_read_response(stream): # Now the server believes the connection is open. Run the event loop # once to make it notice the connection was closed. Interesting hack. - self.loop.run_until_complete(asyncio.sleep(0)) + self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) @patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): diff --git a/websockets/test_framing.py b/websockets/test_framing.py index 86e2c595f..674fa1373 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -17,7 +17,7 @@ def tearDown(self): self.loop.close() def decode(self, message, mask=False, max_size=None): - self.stream = asyncio.StreamReader() + self.stream = asyncio.StreamReader(loop=self.loop) self.stream.feed_data(message) self.stream.feed_eof() reader = self.stream.readexactly diff --git a/websockets/test_http.py b/websockets/test_http.py index 2caf891e2..0607b8196 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -12,7 +12,7 @@ def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.stream = asyncio.StreamReader() + self.stream = asyncio.StreamReader(loop=self.loop) def tearDown(self): self.loop.close() diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index b5a866882..a188197ae 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -2,6 +2,7 @@ import unittest.mock import asyncio +from functools import partial from .exceptions import InvalidState, PayloadTooBig from .framing import * @@ -24,6 +25,10 @@ def setUp(self): side_effect=lambda: self.protocol.connection_lost(None)) self.protocol.connection_made(self.transport) + @property + def async(self): + return partial(asyncio.async, loop=self.loop) + def tearDown(self): self.loop.close() super().tearDown() @@ -36,7 +41,7 @@ def feed(self, frame): @asyncio.coroutine def sent(self): """Read the next frame sent to the transport.""" - stream = asyncio.StreamReader() + stream = asyncio.StreamReader(loop=self.loop) for (data,), kw in self.transport.write.call_args_list: stream.feed_data(data) self.transport.write.call_args_list = [] @@ -94,27 +99,27 @@ def test_recv_binary(self): def test_recv_protocol_error(self): self.feed(Frame(True, OP_CONT, 'café'.encode('utf-8'))) - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_recv_unicode_error(self): self.feed(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1007, '') def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 self.feed(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 self.feed(Frame(True, OP_BINARY, b'tea' * 342)) - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -135,7 +140,7 @@ def test_recv_other_error(self): def read_message(): raise Exception("BOOM") self.protocol.read_message = read_message - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) with self.assertRaises(Exception): self.loop.run_until_complete(self.protocol.worker) @@ -248,7 +253,7 @@ def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 self.feed(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) self.feed(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -256,7 +261,7 @@ def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 self.feed(Frame(False, OP_BINARY, b'tea' * 171)) self.feed(Frame(True, OP_CONT, b'tea' * 171)) - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -286,14 +291,14 @@ def test_unterminated_fragmented_text(self): self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) # Missing the second part of the fragmented frame. self.feed(Frame(True, OP_BINARY, b'tea')) - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_close_handshake_in_fragmented_text(self): self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.feed(Frame(True, OP_CLOSE, b'')) - self.loop.call_later(MS, asyncio.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') @@ -308,7 +313,7 @@ def test_connection_close_in_fragmented_text(self): class ServerTests(CommonTests, unittest.TestCase): def test_close(self): # standard server-initiated close - self.loop.call_later(MS, asyncio.async, self.echo()) + self.loop.call_later(MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') # Only one frame is emitted, and it's consumed by self.echo(). @@ -342,16 +347,16 @@ def test_simultaneous_close(self): # non standard close from both sides def test_close_drops_frames(self): self.loop.call_later(MS, self.feed, Frame(True, OP_TEXT, b'')) - self.loop.call_later(2 * MS, asyncio.async, self.echo()) + self.loop.call_later(2 * MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') # Only one frame is emitted, and it's consumed by self.echo(). self.assertNoFrameSent() def test_close_handshake_timeout(self): - self.after = asyncio.Future() + self.after = asyncio.Future(loop=self.loop) self.loop.call_later(4 * MS, self.after.cancel) - self.before = asyncio.Future() + self.before = asyncio.Future(loop=self.loop) self.loop.call_later(8 * MS, self.before.cancel) self.protocol.timeout = 5 * MS self.loop.run_until_complete(self.protocol.close(reason='because.')) @@ -364,12 +369,12 @@ def test_close_timeout_before_connection_lost(self): # Prevent the connection from terminating. self.protocol.connection_lost = unittest.mock.Mock() - self.after = asyncio.Future() + self.after = asyncio.Future(loop=self.loop) self.loop.call_later(4 * MS, self.after.cancel) - self.before = asyncio.Future() + self.before = asyncio.Future(loop=self.loop) self.loop.call_later(8 * MS, self.before.cancel) self.protocol.timeout = 5 * MS - self.loop.call_later(MS, asyncio.async, self.echo()) + self.loop.call_later(MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertEqual(self.protocol.state, 'CLOSING') self.assertTrue(self.after.cancelled()) @@ -381,14 +386,14 @@ def test_client_close_race_with_failing_connection(self): @asyncio.coroutine def delayed_write_frame(*args): yield from original_write_frame(*args) - yield from asyncio.sleep(2 * MS) + yield from asyncio.sleep(2 * MS, loop=self.loop) self.protocol.write_frame = delayed_write_frame frame = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) # Trigger the race condition between answering the close frame from # the client and sending another close frame from the server. self.loop.call_later(MS, self.feed, frame) - self.loop.call_later(2 * MS, asyncio.async, self.protocol.fail_connection(1000, 'server')) + self.loop.call_later(2 * MS, self.async, self.protocol.fail_connection(1000, 'server')) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1000, 'server') self.assertFrameSent(*frame) @@ -405,19 +410,19 @@ def test_close_connection_lost(self): self.assertConnectionClosed(1006, '') def test_close_during_recv(self): - recv = asyncio.async(self.protocol.recv()) - self.loop.call_later(MS, asyncio.async, self.echo()) + recv = self.async(self.protocol.recv()) + self.loop.call_later(MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertIsNone(self.loop.run_until_complete(recv)) def test_close_after_cancelled_recv(self): - recv = asyncio.async(self.protocol.recv()) + recv = self.async(self.protocol.recv()) self.loop.call_later(MS, recv.cancel) with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(recv) # Closing the connection shouldn't crash. # I can't find a way to test this on the client side. - self.loop.call_later(MS, asyncio.async, self.echo()) + self.loop.call_later(MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) @@ -443,7 +448,7 @@ def test_close(self): # standard server-initiated close self.assertNoFrameSent() def test_client_close(self): # non standard client-initiated close - self.loop.call_later(MS, asyncio.async, self.echo()) + self.loop.call_later(MS, self.async, self.echo()) self.loop.call_later(2 * MS, self.protocol.eof_received) self.loop.call_later(3 * MS, lambda: self.protocol.connection_lost(None)) self.loop.run_until_complete(self.protocol.close(reason='because.')) @@ -467,12 +472,12 @@ def test_simultaneous_close(self): # non standard close from both sides self.assertNoFrameSent() def test_close_timeout_before_eof_received(self): - self.after = asyncio.Future() + self.after = asyncio.Future(loop=self.loop) self.loop.call_later(4 * MS, self.after.cancel) - self.before = asyncio.Future() + self.before = asyncio.Future(loop=self.loop) self.loop.call_later(8 * MS, self.before.cancel) self.protocol.timeout = 5 * MS - self.loop.call_later(MS, asyncio.async, self.echo()) + self.loop.call_later(MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) # If the server doesn't drop the connection quickly, the client will. self.assertConnectionClosed(1000, 'because.') @@ -484,12 +489,12 @@ def test_close_timeout_before_connection_lost(self): # Prevent the connection from terminating. self.protocol.connection_lost = unittest.mock.Mock() - self.after = asyncio.Future() + self.after = asyncio.Future(loop=self.loop) self.loop.call_later(9 * MS, self.after.cancel) - self.before = asyncio.Future() + self.before = asyncio.Future(loop=self.loop) self.loop.call_later(13 * MS, self.before.cancel) self.protocol.timeout = 5 * MS - self.loop.call_later(MS, asyncio.async, self.echo()) + self.loop.call_later(MS, self.async, self.echo()) self.loop.call_later(2 * MS, self.protocol.eof_received) self.loop.run_until_complete(self.protocol.close(reason='because.')) # If the server doesn't drop the connection quickly, the client will. @@ -503,14 +508,14 @@ def test_server_close_race_with_failing_connection(self): @asyncio.coroutine def delayed_write_frame(*args): yield from original_write_frame(*args) - yield from asyncio.sleep(2 * MS) + yield from asyncio.sleep(2 * MS, loop=self.loop) self.protocol.write_frame = delayed_write_frame frame = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) # Trigger the race condition between answering the close frame from # the server and sending another close frame from the client. self.loop.call_later(MS, self.feed, frame) - self.loop.call_later(2 * MS, asyncio.async, self.protocol.fail_connection(1000, 'client')) + self.loop.call_later(2 * MS, self.async, self.protocol.fail_connection(1000, 'client')) self.loop.call_later(3 * MS, self.protocol.eof_received) self.loop.call_later(4 * MS, lambda: self.protocol.connection_lost(None)) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) From 14473689ae0f92f7aee16868fd41665e3ad9aac0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 16:25:25 +0200 Subject: [PATCH 0065/1539] Small adjustments & changelog for previous commit. Ref #60. --- docs/index.rst | 2 ++ websockets/protocol.py | 15 +++++++++------ websockets/server.py | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index be5de2f15..b91d82efb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -268,6 +268,8 @@ Changelog * Allowed customizing handshake request and response HTTP headers. +* Supported running on a non-default event loop. + * Improved documentation. * Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no diff --git a/websockets/protocol.py b/websockets/protocol.py index 9300cd377..93cbe2873 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -73,6 +73,9 @@ def __init__(self, *, self.timeout = timeout self.max_size = max_size + # Store a reference to loop to avoid relying on self.loop, a private + # attribute of StreamReaderProtocol, inherited from FlowControlMixin. + self.loop = loop stream_reader = asyncio.StreamReader(loop=loop) super().__init__(stream_reader, self.client_connected, loop) @@ -137,7 +140,7 @@ def close(self, code=1000, reason=''): # the worker loop. try: yield from asyncio.wait_for( - self.worker, self.timeout, loop=self._loop) + self.worker, self.timeout, loop=self.loop) except asyncio.TimeoutError: self.worker.cancel() @@ -163,11 +166,11 @@ def recv(self): pass # Wait for a message until the connection is closed - next_message = asyncio.async(self.messages.get(), loop=self._loop) + next_message = asyncio.async(self.messages.get(), loop=self.loop) try: done, pending = yield from asyncio.wait( [next_message, self.worker], - loop=self._loop, return_when=asyncio.FIRST_COMPLETED) + loop=self.loop, return_when=asyncio.FIRST_COMPLETED) except asyncio.CancelledError: # Handle the Task.cancel() next_message.cancel() @@ -216,7 +219,7 @@ def ping(self, data=None): while data is None or data in self.pings: data = struct.pack('!I', random.getrandbits(32)) - self.pings[data] = asyncio.Future(loop=self._loop) + self.pings[data] = asyncio.Future(loop=self.loop) yield from self.write_frame(OP_PING, data) return self.pings[data] @@ -384,7 +387,7 @@ def close_connection(self): if self.is_client: try: yield from asyncio.wait_for( - self.connection_closed, self.timeout, loop=self._loop) + self.connection_closed, self.timeout, loop=self.loop) except (asyncio.CancelledError, asyncio.TimeoutError): pass @@ -403,7 +406,7 @@ def close_connection(self): try: yield from asyncio.wait_for( - self.connection_closed, self.timeout, loop=self._loop) + self.connection_closed, self.timeout, loop=self.loop) except (asyncio.CancelledError, asyncio.TimeoutError): pass diff --git a/websockets/server.py b/websockets/server.py index 6984a5d17..a9b8ffb8a 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -41,7 +41,7 @@ def __init__(self, ws_handler, *, def connection_made(self, transport): super().connection_made(transport) - asyncio.async(self.handler(), loop=self._loop) + asyncio.async(self.handler(), loop=self.loop) @asyncio.coroutine def handler(self): From 6a3ed1ecc6edb788a05783028af1d64dcd25f612 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 16:59:09 +0200 Subject: [PATCH 0066/1539] Add flake8 and fix warnings. --- setup.cfg | 3 +++ tox.ini | 7 ++++++- websockets/__init__.py | 2 +- websockets/client.py | 14 +++++++------- websockets/exceptions.py | 4 ++-- websockets/handshake.py | 10 ++++++---- websockets/protocol.py | 24 +++++++++++++++--------- websockets/server.py | 8 ++++---- websockets/test_client_server.py | 9 +++++---- websockets/test_framing.py | 28 ++++++++++++++-------------- websockets/test_protocol.py | 31 ++++++++++++++++++------------- websockets/uri.py | 4 ++-- 12 files changed, 83 insertions(+), 61 deletions(-) diff --git a/setup.cfg b/setup.cfg index 87175524c..0530ab2e0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,5 @@ [bdist_wheel] python-tag = py33.py34 + +[flake8] +ignore = F403 diff --git a/tox.ini b/tox.ini index f9aee3b26..fe9b0af07 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,12 @@ [tox] -envlist = py33,py34 +envlist = py33,py34,flake8 [testenv] deps = py33: asyncio commands = python -m unittest + +[testenv:flake8] +commands = flake8 websockets +deps = + flake8 diff --git a/websockets/__init__.py b/websockets/__init__.py index 7f21bd6c1..60bc9c5fe 100644 --- a/websockets/__init__.py +++ b/websockets/__init__.py @@ -14,4 +14,4 @@ + uri.__all__ ) -from .version import version as __version__ +from .version import version as __version__ # noqa diff --git a/websockets/client.py b/websockets/client.py index 777106bed..5ca06271c 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -41,7 +41,7 @@ def handshake(self, wsuri, """ headers = [] set_header = lambda k, v: headers.append((k, v)) - if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover + if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover set_header('Host', wsuri.host) else: set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port)) @@ -82,7 +82,7 @@ def handshake(self, wsuri, if (self.subprotocol is not None and self.subprotocol not in subprotocols): raise InvalidHandshake( - "Unknown subprotocol: {}".format(self.subprotocol)) + "Unknown subprotocol: {}".format(self.subprotocol)) self.state = 'OPEN' self.opening_handshake.set_result(True) @@ -132,16 +132,16 @@ def connect(uri, *, elif 'ssl' in kwds: raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") - factory = lambda: klass(host=wsuri.host, port=wsuri.port, - secure=wsuri.secure, loop=loop) + factory = lambda: klass( + host=wsuri.host, port=wsuri.port, secure=wsuri.secure, loop=loop) transport, protocol = yield from loop.create_connection( - factory, wsuri.host, wsuri.port, **kwds) + factory, wsuri.host, wsuri.port, **kwds) try: yield from protocol.handshake( - wsuri, origin=origin, subprotocols=subprotocols, - extra_headers=extra_headers) + wsuri, origin=origin, subprotocols=subprotocols, + extra_headers=extra_headers) except Exception: protocol.writer.close() raise diff --git a/websockets/exceptions.py b/websockets/exceptions.py index a1e130d4e..32b95be5d 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -13,11 +13,11 @@ class InvalidState(Exception): class InvalidURI(Exception): - """Exception raised when an URI is invalid.""" + """Exception raised when an URI isn't a valid websocket URI.""" class PayloadTooBig(Exception): - """Exception raised when the payload in a frame exceeds the maximum size.""" + """Exception raised when a frame's payload exceeds the maximum size.""" class WebSocketProtocolError(Exception): diff --git a/websockets/handshake.py b/websockets/handshake.py index a400147f6..799c93b40 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -79,8 +79,9 @@ def check_request(get_header): """ try: assert get_header('Upgrade').lower() == 'websocket' - assert any(token.strip() == 'upgrade' - for token in get_header('Connection').lower().split(',')) + assert any( + token.strip() == 'upgrade' + for token in get_header('Connection').lower().split(',')) key = get_header('Sec-WebSocket-Key') assert len(base64.b64decode(key.encode())) == 16 assert get_header('Sec-WebSocket-Version') == '13' @@ -117,8 +118,9 @@ def check_response(get_header, key): """ try: assert get_header('Upgrade').lower() == 'websocket' - assert any(token.strip() == 'upgrade' - for token in get_header('Connection').lower().split(',')) + assert any( + token.strip() == 'upgrade' + for token in get_header('Connection').lower().split(',')) assert get_header('Sec-WebSocket-Accept') == accept(key) except (AssertionError, KeyError) as exc: raise InvalidHandshake("Invalid response") from exc diff --git a/websockets/protocol.py b/websockets/protocol.py index 93cbe2873..00c11b0a0 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -66,7 +66,8 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): state = 'OPEN' def __init__(self, *, - host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None): + host=None, port=None, secure=None, + timeout=10, max_size=2 ** 20, loop=None): self.host = host self.port = port self.secure = secure @@ -132,7 +133,8 @@ def close(self, code=1000, reason=''): if self.state == 'OPEN': # 7.1.2. Start the WebSocket Closing Handshake self.close_code, self.close_reason = code, reason - yield from self.write_frame(OP_CLOSE, serialize_close(code, reason)) + frame_data = serialize_close(code, reason) + yield from self.write_frame(OP_CLOSE, frame_data) # 7.1.3. The WebSocket Closing Handshake is Started self.state = 'CLOSING' @@ -140,7 +142,7 @@ def close(self, code=1000, reason=''): # the worker loop. try: yield from asyncio.wait_for( - self.worker, self.timeout, loop=self.loop) + self.worker, self.timeout, loop=self.loop) except asyncio.TimeoutError: self.worker.cancel() @@ -190,7 +192,8 @@ def send(self, data): frame. It raises a :exc:`TypeError` for other inputs and - :exc:`~websockets.exceptions.InvalidState` once the connection is closed. + :exc:`~websockets.exceptions.InvalidState` once the connection is + closed. """ if isinstance(data, str): opcode = 1 @@ -324,7 +327,8 @@ def read_data_frame(self, max_size): if self.state != 'CLOSING': # 7.1.3. The WebSocket Closing Handshake is Started self.state = 'CLOSING' - yield from self.write_frame(OP_CLOSE, frame.data, 'CLOSING') + yield from self.write_frame( + OP_CLOSE, frame.data, expected_state='CLOSING') if not self.closing_handshake.done(): self.closing_handshake.set_result(True) return @@ -347,7 +351,8 @@ def read_data_frame(self, max_size): @asyncio.coroutine def read_frame(self, max_size): is_masked = not self.is_client - frame = yield from read_frame(self.reader.readexactly, is_masked, max_size=max_size) + frame = yield from read_frame( + self.reader.readexactly, is_masked, max_size=max_size) side = 'client' if self.is_client else 'server' logger.debug("%s << %s", side, frame) return frame @@ -387,7 +392,7 @@ def close_connection(self): if self.is_client: try: yield from asyncio.wait_for( - self.connection_closed, self.timeout, loop=self.loop) + self.connection_closed, self.timeout, loop=self.loop) except (asyncio.CancelledError, asyncio.TimeoutError): pass @@ -406,7 +411,7 @@ def close_connection(self): try: yield from asyncio.wait_for( - self.connection_closed, self.timeout, loop=self.loop) + self.connection_closed, self.timeout, loop=self.loop) except (asyncio.CancelledError, asyncio.TimeoutError): pass @@ -428,7 +433,8 @@ def fail_connection(self, code=1011, reason=''): # 7.1.7. Fail the WebSocket Connection logger.info("Failing the WebSocket connection: %d %s", code, reason) if self.state == 'OPEN': - yield from self.write_frame(OP_CLOSE, serialize_close(code, reason)) + frame_data = serialize_close(code, reason) + yield from self.write_frame(OP_CLOSE, frame_data) self.state = 'CLOSING' if not self.closing_handshake.done(): self.closing_handshake.set_result(False) diff --git a/websockets/server.py b/websockets/server.py index a9b8ffb8a..43bfabffe 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -119,7 +119,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): if protocol: client_subprotocols = [p.strip() for p in protocol.split(',')] self.subprotocol = self.select_subprotocol( - client_subprotocols, subprotocols) + client_subprotocols, subprotocols) headers = [] set_header = lambda k, v: headers.append((k, v)) @@ -206,7 +206,7 @@ def serve(ws_handler, host=None, port=None, *, secure = kwds.get('ssl') is not None factory = lambda: klass( - ws_handler, host=host, port=port, secure=secure, - origins=origins, subprotocols=subprotocols, - extra_headers=extra_headers, loop=loop) + ws_handler, host=host, port=port, secure=secure, + origins=origins, subprotocols=subprotocols, + extra_headers=extra_headers, loop=loop) return (yield from loop.create_server(factory, host, port, **kwds)) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 70b5c402b..e73ccee98 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -90,13 +90,13 @@ def test_protocol_raw_headers(self): def test_protocol_custom_request_headers_dict(self): self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) req_headers = self.loop.run_until_complete(self.client.recv()) - _resp_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) def test_protocol_custom_request_headers_list(self): self.start_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) req_headers = self.loop.run_until_complete(self.client.recv()) - _resp_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) def test_protocol_custom_response_headers_dict(self): @@ -104,7 +104,7 @@ def test_protocol_custom_response_headers_dict(self): self.start_server(extra_headers={'X-Spam': 'Eggs'}) self.start_client('raw_headers') - _req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @@ -113,7 +113,7 @@ def test_protocol_custom_response_headers_list(self): self.start_server(extra_headers=[('X-Spam', 'Eggs')]) self.start_client('raw_headers') - _req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @@ -283,6 +283,7 @@ def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): self.loop.run_until_complete(client) + class ClientServerOriginTests(unittest.TestCase): def test_checking_origin_succeeds(self): diff --git a/websockets/test_framing.py b/websockets/test_framing.py index 674fa1373..cf2a678eb 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -20,8 +20,8 @@ def decode(self, message, mask=False, max_size=None): self.stream = asyncio.StreamReader(loop=self.loop) self.stream.feed_data(message) self.stream.feed_eof() - reader = self.stream.readexactly - return self.loop.run_until_complete(read_frame(reader, mask, max_size=max_size)) + return self.loop.run_until_complete(read_frame( + self.stream.readexactly, mask, max_size=max_size)) def encode(self, frame, mask=False): encoded = io.BytesIO() @@ -49,26 +49,26 @@ def test_text(self): def test_text_masked(self): self.round_trip( - b'\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5', - Frame(True, OP_TEXT, b'Spam'), mask=True) + b'\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5', + Frame(True, OP_TEXT, b'Spam'), mask=True) def test_binary(self): self.round_trip(b'\x82\x04Eggs', Frame(True, OP_BINARY, b'Eggs')) def test_binary_masked(self): self.round_trip( - b'\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa', - Frame(True, OP_BINARY, b'Eggs'), mask=True) + b'\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa', + Frame(True, OP_BINARY, b'Eggs'), mask=True) def test_non_ascii_text(self): self.round_trip( - b'\x81\x05caf\xc3\xa9', - Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + b'\x81\x05caf\xc3\xa9', + Frame(True, OP_TEXT, 'café'.encode('utf-8'))) def test_non_ascii_text_masked(self): self.round_trip( - b'\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd', - Frame(True, OP_TEXT, 'café'.encode('utf-8')), mask=True) + b'\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd', + Frame(True, OP_TEXT, 'café'.encode('utf-8')), mask=True) def test_close(self): self.round_trip(b'\x88\x00', Frame(True, OP_CLOSE, b'')) @@ -81,13 +81,13 @@ def test_pong(self): def test_long(self): self.round_trip( - b'\x82\x7e\x00\x7e' + 126 * b'a', - Frame(True, OP_BINARY, 126 * b'a')) + b'\x82\x7e\x00\x7e' + 126 * b'a', + Frame(True, OP_BINARY, 126 * b'a')) def test_very_long(self): self.round_trip( - b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + 65536 * b'a', - Frame(True, OP_BINARY, 65536 * b'a')) + b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + 65536 * b'a', + Frame(True, OP_BINARY, 65536 * b'a')) def test_payload_too_big(self): with self.assertRaises(PayloadTooBig): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index a188197ae..309f7c7e2 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -4,7 +4,7 @@ import asyncio from functools import partial -from .exceptions import InvalidState, PayloadTooBig +from .exceptions import InvalidState from .framing import * from .protocol import WebSocketCommonProtocol @@ -22,7 +22,7 @@ def setUp(self): self.transport = unittest.mock.Mock() self.transport._conn_lost = 0 # checked by drain() self.transport.close = unittest.mock.Mock( - side_effect=lambda: self.protocol.connection_lost(None)) + side_effect=lambda: self.protocol.connection_lost(None)) self.protocol.connection_made(self.transport) @property @@ -47,7 +47,8 @@ def sent(self): self.transport.write.call_args_list = [] stream.feed_eof() if not stream.at_eof(): - return (yield from read_frame(stream.readexactly, self.protocol.is_client)) + return (yield from read_frame( + stream.readexactly, self.protocol.is_client)) @asyncio.coroutine def echo(self): @@ -305,7 +306,7 @@ def test_close_handshake_in_fragmented_text(self): def test_connection_close_in_fragmented_text(self): self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.loop.call_later(MS, self.protocol.eof_received) - self.loop.call_later(2 * MS, lambda: self.protocol.connection_lost(None)) + self.loop.call_later(2 * MS, self.protocol.connection_lost, None) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1006, '') @@ -383,10 +384,12 @@ def test_close_timeout_before_connection_lost(self): def test_client_close_race_with_failing_connection(self): original_write_frame = self.protocol.write_frame + @asyncio.coroutine - def delayed_write_frame(*args): - yield from original_write_frame(*args) + def delayed_write_frame(*args, **kwargs): + yield from original_write_frame(*args, **kwargs) yield from asyncio.sleep(2 * MS, loop=self.loop) + self.protocol.write_frame = delayed_write_frame frame = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) @@ -405,7 +408,7 @@ def test_close_protocol_error(self): def test_close_connection_lost(self): self.loop.call_later(MS, self.protocol.eof_received) - self.loop.call_later(2 * MS, lambda: self.protocol.connection_lost(None)) + self.loop.call_later(2 * MS, self.protocol.connection_lost, None) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1006, '') @@ -436,7 +439,7 @@ def test_close(self): # standard server-initiated close frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) self.loop.call_later(MS, self.feed, frame) self.loop.call_later(2 * MS, self.protocol.eof_received) - self.loop.call_later(3 * MS, lambda: self.protocol.connection_lost(None)) + self.loop.call_later(3 * MS, self.protocol.connection_lost, None) # The client is waiting for some data at this point, and won't get it. self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) # After recv() returns None, the connection is closed. @@ -450,7 +453,7 @@ def test_close(self): # standard server-initiated close def test_client_close(self): # non standard client-initiated close self.loop.call_later(MS, self.async, self.echo()) self.loop.call_later(2 * MS, self.protocol.eof_received) - self.loop.call_later(3 * MS, lambda: self.protocol.connection_lost(None)) + self.loop.call_later(3 * MS, self.protocol.connection_lost, None) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') # Only one frame is emitted, and it's consumed by self.echo(). @@ -465,7 +468,7 @@ def test_simultaneous_close(self): # non standard close from both sides client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) self.loop.call_later(MS, self.feed, server_close) self.loop.call_later(2 * MS, self.protocol.eof_received) - self.loop.call_later(3 * MS, lambda: self.protocol.connection_lost(None)) + self.loop.call_later(3 * MS, self.protocol.connection_lost, None) self.loop.run_until_complete(self.protocol.close(reason='client')) self.assertConnectionClosed(1000, 'server') self.assertFrameSent(*client_close) @@ -505,10 +508,12 @@ def test_close_timeout_before_connection_lost(self): def test_server_close_race_with_failing_connection(self): original_write_frame = self.protocol.write_frame + @asyncio.coroutine - def delayed_write_frame(*args): - yield from original_write_frame(*args) + def delayed_write_frame(*args, **kwargs): + yield from original_write_frame(*args, **kwargs) yield from asyncio.sleep(2 * MS, loop=self.loop) + self.protocol.write_frame = delayed_write_frame frame = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) @@ -517,7 +522,7 @@ def delayed_write_frame(*args): self.loop.call_later(MS, self.feed, frame) self.loop.call_later(2 * MS, self.async, self.protocol.fail_connection(1000, 'client')) self.loop.call_later(3 * MS, self.protocol.eof_received) - self.loop.call_later(4 * MS, lambda: self.protocol.connection_lost(None)) + self.loop.call_later(4 * MS, self.protocol.connection_lost, None) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1000, 'client') self.assertFrameSent(*frame) diff --git a/websockets/uri.py b/websockets/uri.py index 03a977424..3b4ae7a50 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -13,8 +13,8 @@ from .exceptions import InvalidURI -WebSocketURI = collections.namedtuple('WebSocketURI', - ('secure', 'host', 'port', 'resource_name')) +WebSocketURI = collections.namedtuple( + 'WebSocketURI', ('secure', 'host', 'port', 'resource_name')) def parse_uri(uri): From c3a71a1405269db361b256af623daebf4202f338 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 21:58:49 +0200 Subject: [PATCH 0067/1539] Remove leftover debugging code. --- websockets/test_client_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index e73ccee98..b2b461966 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -170,7 +170,6 @@ def test_subprotocol_error(self, _select_subprotocol): with self.assertRaises(InvalidHandshake): self.start_client('subprotocol', subprotocols=['otherchat']) - print(_select_subprotocol.call_args_list) @patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): From 3b170ad097db8a5b86d574572d283cf981c7b36b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 22:06:45 +0200 Subject: [PATCH 0068/1539] Close sockets correctly in tests. This prevents spurious ResourceWarnings on Python 3.4. --- websockets/test_client_server.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index b2b461966..0b08ef802 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -92,12 +92,14 @@ def test_protocol_custom_request_headers_dict(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) + self.stop_client() def test_protocol_custom_request_headers_list(self): self.start_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) + self.stop_client() def test_protocol_custom_response_headers_dict(self): self.stop_server() @@ -107,6 +109,7 @@ def test_protocol_custom_response_headers_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() def test_protocol_custom_response_headers_list(self): self.stop_server() @@ -116,6 +119,7 @@ def test_protocol_custom_response_headers_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() def test_no_subprotocol(self): self.start_client('subprotocol') @@ -171,6 +175,10 @@ def test_subprotocol_error(self, _select_subprotocol): with self.assertRaises(InvalidHandshake): self.start_client('subprotocol', subprotocols=['otherchat']) + # Now the server believes the connection is open. Run the event loop + # once to make it notice the connection was closed. Interesting hack. + self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + @patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") From e7cdda188006cdc5946551507a252b378609c450 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 22:32:38 +0200 Subject: [PATCH 0069/1539] Factor out repeated code. --- websockets/test_client_server.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 0b08ef802..5c28f72c7 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -55,6 +55,11 @@ def start_client(self, path='', **kwds): def stop_client(self): self.loop.run_until_complete(self.client.worker) + def notice_connection_close(self): + # When the client closes the connection, the server still believes + # it's open until the event loop has run once. Interesting hack. + self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + def stop_server(self): self.server.close() self.loop.run_until_complete(self.server.wait_closed()) @@ -174,10 +179,7 @@ def test_subprotocol_error(self, _select_subprotocol): with self.assertRaises(InvalidHandshake): self.start_client('subprotocol', subprotocols=['otherchat']) - - # Now the server believes the connection is open. Run the event loop - # once to make it notice the connection was closed. Interesting hack. - self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + self.notice_connection_close() @patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): @@ -192,10 +194,7 @@ def test_client_receives_malformed_response(self, _read_response): with self.assertRaises(InvalidHandshake): self.start_client() - - # Now the server believes the connection is open. Run the event loop - # once to make it notice the connection was closed. Interesting hack. - self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + self.notice_connection_close() @patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): @@ -225,10 +224,7 @@ def wrong_read_response(stream): with self.assertRaises(InvalidHandshake): self.start_client() - - # Now the server believes the connection is open. Run the event loop - # once to make it notice the connection was closed. Interesting hack. - self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + self.notice_connection_close() @patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): From 01980b8ed6aecd86f36883293d7fc0b71c13ad8d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 22:37:08 +0200 Subject: [PATCH 0070/1539] Start and stop server explicitly in tests. This is a bit more verbose but avoids the awkward pattern of stopping the server and restarting it with different arguments. --- websockets/test_client_server.py | 50 +++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 5c28f72c7..728ebff1d 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -38,10 +38,8 @@ class ClientServerTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.start_server() def tearDown(self): - self.stop_server() self.loop.close() def start_server(self, **kwds): @@ -65,13 +63,16 @@ def stop_server(self): self.loop.run_until_complete(self.server.wait_closed()) def test_basic(self): + self.start_server() self.start_client() self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") self.stop_client() + self.stop_server() def test_protocol_attributes(self): + self.start_server() self.start_client('attributes') expected_attrs = ('localhost', 8642, self.secure) client_attrs = (self.client.host, self.client.port, self.client.secure) @@ -79,8 +80,10 @@ def test_protocol_attributes(self): server_attrs = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_attrs, repr(expected_attrs)) self.stop_client() + self.stop_server() def test_protocol_raw_headers(self): + self.start_server() self.start_client('raw_headers') client_req = self.client.raw_request_headers client_resp = self.client.raw_response_headers @@ -91,110 +94,117 @@ def test_protocol_raw_headers(self): self.assertEqual(server_req, repr(client_req)) self.assertEqual(server_resp, repr(client_resp)) self.stop_client() + self.stop_server() def test_protocol_custom_request_headers_dict(self): + self.start_server() self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) self.stop_client() + self.stop_server() def test_protocol_custom_request_headers_list(self): + self.start_server() self.start_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) self.stop_client() + self.stop_server() def test_protocol_custom_response_headers_dict(self): - self.stop_server() self.start_server(extra_headers={'X-Spam': 'Eggs'}) - self.start_client('raw_headers') self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) self.stop_client() + self.stop_server() def test_protocol_custom_response_headers_list(self): - self.stop_server() self.start_server(extra_headers=[('X-Spam', 'Eggs')]) - self.start_client('raw_headers') self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) self.stop_client() + self.stop_server() def test_no_subprotocol(self): + self.start_server() self.start_client('subprotocol') server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) self.stop_client() + self.stop_server() def test_subprotocol_found(self): - self.stop_server() self.start_server(subprotocols=['superchat', 'chat']) - self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr('chat')) self.assertEqual(self.client.subprotocol, 'chat') self.stop_client() + self.stop_server() def test_subprotocol_not_found(self): - self.stop_server() self.start_server(subprotocols=['superchat']) - self.start_client('subprotocol', subprotocols=['otherchat']) server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) self.stop_client() + self.stop_server() def test_subprotocol_not_offered(self): + self.start_server() self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) self.stop_client() + self.stop_server() def test_subprotocol_not_requested(self): - self.stop_server() self.start_server(subprotocols=['superchat', 'chat']) - self.start_client('subprotocol') server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) self.stop_client() + self.stop_server() @patch.object(WebSocketServerProtocol, 'select_subprotocol', autospec=True) def test_subprotocol_error(self, _select_subprotocol): _select_subprotocol.return_value = 'superchat' - self.stop_server() self.start_server(subprotocols=['superchat']) - with self.assertRaises(InvalidHandshake): self.start_client('subprotocol', subprotocols=['otherchat']) self.notice_connection_close() + self.stop_server() @patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") + self.start_server() with self.assertRaises(InvalidHandshake): self.start_client() + self.stop_server() @patch('websockets.client.read_response') def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") + self.start_server() with self.assertRaises(InvalidHandshake): self.start_client() self.notice_connection_close() + self.stop_server() @patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): @@ -202,8 +212,10 @@ def wrong_build_request(set_header): return '42' _build_request.side_effect = wrong_build_request + self.start_server() with self.assertRaises(InvalidHandshake): self.start_client() + self.stop_server() @patch('websockets.server.build_response') def test_server_sends_invalid_handshake_response(self, _build_response): @@ -211,8 +223,10 @@ def wrong_build_response(set_header, key): return build_response(set_header, '42') _build_response.side_effect = wrong_build_response + self.start_server() with self.assertRaises(InvalidHandshake): self.start_client() + self.stop_server() @patch('websockets.client.read_response') def test_server_does_not_switch_protocols(self, _read_response): @@ -222,19 +236,23 @@ def wrong_read_response(stream): return 400, headers _read_response.side_effect = wrong_read_response + self.start_server() with self.assertRaises(InvalidHandshake): self.start_client() self.notice_connection_close() + self.stop_server() @patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") + self.start_server() self.start_client() self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, None) self.stop_client() + self.stop_server() # Connection ends with an unexpected error. self.assertEqual(self.client.close_code, 1011) @@ -243,11 +261,13 @@ def test_server_handler_crashes(self, send): def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") + self.start_server() self.start_client() self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") self.stop_client() + self.stop_server() # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) @@ -282,9 +302,11 @@ def start_client(self, path='', **kwds): self.client = self.loop.run_until_complete(client) def test_ws_uri_is_rejected(self): + self.start_server() client = connect('ws://localhost:8642/', ssl=self.client_context) with self.assertRaises(ValueError): self.loop.run_until_complete(client) + self.stop_server() class ClientServerOriginTests(unittest.TestCase): From 1df3c9989add437fcb1ac7ffc798f7a3e75be56e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jul 2015 22:41:28 +0200 Subject: [PATCH 0071/1539] Uniformize setup and teardown code. --- websockets/test_client_server.py | 49 ++++++++++++++------------------ 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 728ebff1d..eafc81b98 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -311,49 +311,44 @@ def test_ws_uri_is_rejected(self): class ClientServerOriginTests(unittest.TestCase): - def test_checking_origin_succeeds(self): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() - server = loop.run_until_complete( + def test_checking_origin_succeeds(self): + server = self.loop.run_until_complete( serve(handler, 'localhost', 8642, origins=['http://localhost'])) - client = loop.run_until_complete( + client = self.loop.run_until_complete( connect('ws://localhost:8642/', origin='http://localhost')) - loop.run_until_complete(client.send("Hello!")) - self.assertEqual(loop.run_until_complete(client.recv()), "Hello!") + self.loop.run_until_complete(client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") + self.loop.run_until_complete(client.close()) server.close() - loop.run_until_complete(server.wait_closed()) - loop.run_until_complete(client.worker) - loop.close() + self.loop.run_until_complete(server.wait_closed()) def test_checking_origin_fails(self): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - server = loop.run_until_complete( + server = self.loop.run_until_complete( serve(handler, 'localhost', 8642, origins=['http://localhost'])) with self.assertRaises(InvalidHandshake): - loop.run_until_complete( + self.loop.run_until_complete( connect('ws://localhost:8642/', origin='http://otherhost')) server.close() - loop.run_until_complete(server.wait_closed()) - loop.close() + self.loop.run_until_complete(server.wait_closed()) def test_checking_lack_of_origin_succeeds(self): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - server = loop.run_until_complete( + server = self.loop.run_until_complete( serve(handler, 'localhost', 8642, origins=[''])) - client = loop.run_until_complete(connect('ws://localhost:8642/')) + client = self.loop.run_until_complete(connect('ws://localhost:8642/')) - loop.run_until_complete(client.send("Hello!")) - self.assertEqual(loop.run_until_complete(client.recv()), "Hello!") + self.loop.run_until_complete(client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") + self.loop.run_until_complete(client.close()) server.close() - loop.run_until_complete(server.wait_closed()) - loop.run_until_complete(client.worker) - loop.close() + self.loop.run_until_complete(server.wait_closed()) From 4bb47d3b6f587785cba0d6210636fec126aa9be6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 09:21:12 +0200 Subject: [PATCH 0072/1539] No one remembers what Tulip was. --- README.rst | 3 +-- docs/index.rst | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index ca96d3d94..189c6b18d 100644 --- a/README.rst +++ b/README.rst @@ -11,7 +11,7 @@ concurrent applications. Installation is as simple as ``pip install websockets``. It requires Python ≥ 3.4 or Python 3.3 with the ``asyncio`` module, which is available with ``pip -install asyncio`` or in the `Tulip`_ repository. +install asyncio``. Documentation is available at http://aaugustin.github.io/websockets/. @@ -23,6 +23,5 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _RFC 6455: http://tools.ietf.org/html/rfc6455 .. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst .. _PEP 3156: http://www.python.org/dev/peps/pep-3156/ -.. _Tulip: http://code.google.com/p/tulip/ .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ diff --git a/docs/index.rst b/docs/index.rst index b91d82efb..283de6cb2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,7 +13,7 @@ concurrent applications. Installation is as simple as ``pip install websockets``. It requires Python ≥ 3.4 or Python 3.3 with the ``asyncio`` module, which is available with ``pip -install asyncio`` or in the `Tulip`_ repository. +install asyncio``. Bug reports, patches and suggestions welcome! Just open an issue_ or send a `pull request`_. @@ -23,7 +23,6 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _RFC 6455: http://tools.ietf.org/html/rfc6455 .. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst .. _PEP 3156: http://www.python.org/dev/peps/pep-3156/ -.. _Tulip: http://code.google.com/p/tulip/ .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ From ad9f24d56dc21493dba0d88b472096b78b872ebe Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 10:13:17 +0200 Subject: [PATCH 0073/1539] Allow customizing response headers depending on request headers. --- websockets/server.py | 8 ++++++-- websockets/test_client_server.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 43bfabffe..e8e0dcc86 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -95,7 +95,8 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): order of decreasing preference. If provided, ``extra_headers`` sets additional HTTP response headers. - It must be a mapping or an iterable of (name, value) pairs. + It can be a mapping or an iterable of (name, value) pairs. It can also + be a callable taking the request path and headers in arguments. Return the URI of the request. """ @@ -127,6 +128,8 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): if self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: + if callable(extra_headers): + extra_headers = extra_headers(path, self.raw_request_headers) if isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: @@ -182,7 +185,8 @@ def serve(ws_handler, host=None, port=None, *, * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference * ``extra_headers`` sets additional HTTP response headers – it can be a - mapping or an iterable of (name, value) pairs + mapping, an iterable of (name, value) pairs, or a callable taking the + request path and headers in arguments. `serve` yields a `Server` object with a `close` method to stop the server. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index eafc81b98..05ac9bf8e 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -114,6 +114,24 @@ def test_protocol_custom_request_headers_list(self): self.stop_client() self.stop_server() + def test_protocol_custom_response_headers_callable_dict(self): + self.start_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() + self.stop_server() + + def test_protocol_custom_response_headers_callable_list(self): + self.start_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]) + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() + self.stop_server() + def test_protocol_custom_response_headers_dict(self): self.start_server(extra_headers={'X-Spam': 'Eggs'}) self.start_client('raw_headers') From 54f32334490895d63b56ea9c872038292080365a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 10:26:30 +0200 Subject: [PATCH 0074/1539] Return a HTTP 403 when Origin isn't allowed. --- docs/index.rst | 3 +++ websockets/exceptions.py | 6 +++++- websockets/server.py | 8 +++++--- websockets/test_client_server.py | 2 +- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 283de6cb2..424406c86 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -271,6 +271,9 @@ Changelog * Improved documentation. +* Returned a 403 error code instead of 400 when the request Origin isn't + allowed. + * Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 32b95be5d..258780eec 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,5 +1,5 @@ __all__ = [ - 'InvalidHandshake', 'InvalidState', 'InvalidURI', + 'InvalidHandshake', 'InvalidOrigin', 'InvalidState', 'InvalidURI', 'PayloadTooBig', 'WebSocketProtocolError', ] @@ -8,6 +8,10 @@ class InvalidHandshake(Exception): """Exception raised when a handshake request or response is invalid.""" +class InvalidOrigin(InvalidHandshake): + """Exception raised when the origin in a handshake request is forbidden.""" + + class InvalidState(Exception): """Exception raised when an operation is forbidden in the current state.""" diff --git a/websockets/server.py b/websockets/server.py index e8e0dcc86..67cb42f0f 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -9,7 +9,7 @@ import asyncio -from .exceptions import InvalidHandshake +from .exceptions import InvalidHandshake, InvalidOrigin from .handshake import check_request, build_response from .http import read_request, USER_AGENT from .protocol import WebSocketCommonProtocol @@ -55,7 +55,9 @@ def handler(self): extra_headers=self.extra_headers) except Exception as exc: logger.info("Exception in opening handshake: {}".format(exc)) - if isinstance(exc, InvalidHandshake): + if isinstance(exc, InvalidOrigin): + response = 'HTTP/1.1 403 Forbidden\r\n\r\n' + str(exc) + elif isinstance(exc, InvalidHandshake): response = 'HTTP/1.1 400 Bad Request\r\n\r\n' + str(exc) else: response = ('HTTP/1.1 500 Internal Server Error\r\n\r\n' @@ -113,7 +115,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): if origins is not None: origin = get_header('Origin') if not set(origin.split() or ['']) <= set(origins): - raise InvalidHandshake("Bad origin: {}".format(origin)) + raise InvalidOrigin("Origin not allowed: {}".format(origin)) if subprotocols is not None: protocol = get_header('Sec-WebSocket-Protocol') diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 05ac9bf8e..aae22e453 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -352,7 +352,7 @@ def test_checking_origin_succeeds(self): def test_checking_origin_fails(self): server = self.loop.run_until_complete( serve(handler, 'localhost', 8642, origins=['http://localhost'])) - with self.assertRaises(InvalidHandshake): + with self.assertRaisesRegex(InvalidHandshake, "Bad status code: 403"): self.loop.run_until_complete( connect('ws://localhost:8642/', origin='http://otherhost')) From 1b90eb080a47e369ed01d4ff434d5c37f2a1561e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 14:10:52 +0200 Subject: [PATCH 0075/1539] Scheme and hostname are case-insensitive. No code changes are needed because urllib.parse.urlparse already does the right thing. --- websockets/test_uri.py | 1 + 1 file changed, 1 insertion(+) diff --git a/websockets/test_uri.py b/websockets/test_uri.py index 537062023..7594b31ca 100644 --- a/websockets/test_uri.py +++ b/websockets/test_uri.py @@ -8,6 +8,7 @@ ('ws://localhost/', (False, 'localhost', 80, '/')), ('wss://localhost/', (True, 'localhost', 443, '/')), ('ws://localhost/path?query', (False, 'localhost', 80, '/path?query')), + ('WS://LOCALHOST/PATH?QUERY', (False, 'localhost', 80, '/PATH?QUERY')), ] INVALID_URIS = [ From 25aba2beceda000d89aab969fec96fc1678e6f6a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 14:12:20 +0200 Subject: [PATCH 0076/1539] Fix a test case and add another. --- websockets/test_uri.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/websockets/test_uri.py b/websockets/test_uri.py index 7594b31ca..d1102ca65 100644 --- a/websockets/test_uri.py +++ b/websockets/test_uri.py @@ -14,7 +14,8 @@ INVALID_URIS = [ 'http://localhost/', 'https://localhost/', - 'http://localhost/path#fragment' + 'ws://localhost/path#fragment', + 'ws://user:pass@localhost/', ] From 1a1cf65b3472be3a0a061e4207edf4f4d0c5496d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 14:21:13 +0200 Subject: [PATCH 0077/1539] Move documentation of a limitation. --- docs/index.rst | 5 +++++ websockets/client.py | 5 ----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 424406c86..b94f8052f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -328,6 +328,11 @@ Limitations Extensions_ aren't implemented. No extensions are registered_ at the time of writing. +The client doesn't attempt to guarantee that there is no more than one +connection to a given IP adress in a CONNECTING state. + +The client doesn't support connecting through a proxy. + .. _Extensions: http://tools.ietf.org/html/rfc6455#section-9 .. _registered: http://www.iana.org/assignments/websocket/websocket.xml diff --git a/websockets/client.py b/websockets/client.py index 5ca06271c..6c62e1ded 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -117,11 +117,6 @@ def connect(uri, *, Clients shouldn't close the WebSocket connection. Instead, they should wait until the server performs the closing handshake by yielding from the protocol's :attr:`worker` attribute. - - :func:`connect` implements the sequence called "Establish a WebSocket - Connection" in RFC 6455, except for the requirement that "there MUST be no - more than one connection in a CONNECTING state" because it cannot be - enforced at that level. """ if loop is None: loop = asyncio.get_event_loop() From 659cb30126d3cd24b4c05f84499b26efa7e891ad Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 14:30:43 +0200 Subject: [PATCH 0078/1539] Use symbolic constants for protocol states. --- websockets/client.py | 6 ++--- websockets/protocol.py | 45 +++++++++++++++++++++---------------- websockets/server.py | 6 ++--- websockets/test_protocol.py | 8 +++---- 4 files changed, 36 insertions(+), 29 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 6c62e1ded..8bd3124ad 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -10,7 +10,7 @@ from .exceptions import InvalidHandshake from .handshake import build_request, check_response from .http import read_response, USER_AGENT -from .protocol import WebSocketCommonProtocol +from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol from .uri import parse_uri @@ -23,7 +23,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ is_client = True - state = 'CONNECTING' + state = CONNECTING @asyncio.coroutine def handshake(self, wsuri, @@ -84,7 +84,7 @@ def handshake(self, wsuri, raise InvalidHandshake( "Unknown subprotocol: {}".format(self.subprotocol)) - self.state = 'OPEN' + self.state = OPEN self.opening_handshake.set_result(True) diff --git a/websockets/protocol.py b/websockets/protocol.py index 00c11b0a0..ad432b5d9 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -24,6 +24,9 @@ logger = logging.getLogger(__name__) +CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): """ This class implements common parts of the WebSocket protocol. @@ -63,7 +66,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): # To get the client-side behavior, set is_client = True. is_client = False - state = 'OPEN' + state = OPEN def __init__(self, *, host=None, port=None, secure=None, @@ -103,9 +106,13 @@ def __init__(self, *, # In a subclass implementing the opening handshake, the state will be # CONNECTING at this point. - if self.state == 'OPEN': + if self.state == OPEN: self.opening_handshake.set_result(True) + @property + def state_name(self): + return ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED'][self.state] + # Public API @property @@ -115,7 +122,7 @@ def open(self): It may be used to handle disconnections gracefully. """ - return self.state == 'OPEN' + return self.state == OPEN @asyncio.coroutine def close(self, code=1000, reason=''): @@ -130,13 +137,13 @@ def close(self, code=1000, reason=''): The `code` must be an :class:`int` and the `reason` a :class:`str`. """ - if self.state == 'OPEN': + if self.state == OPEN: # 7.1.2. Start the WebSocket Closing Handshake self.close_code, self.close_reason = code, reason frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) # 7.1.3. The WebSocket Closing Handshake is Started - self.state = 'CLOSING' + self.state = CLOSING # If the connection doesn't terminate within the timeout, break out of # the worker loop. @@ -324,11 +331,11 @@ def read_data_frame(self, max_size): # 5.5. Control Frames if frame.opcode == OP_CLOSE: self.close_code, self.close_reason = parse_close(frame.data) - if self.state != 'CLOSING': + if self.state != CLOSING: # 7.1.3. The WebSocket Closing Handshake is Started - self.state = 'CLOSING' + self.state = CLOSING yield from self.write_frame( - OP_CLOSE, frame.data, expected_state='CLOSING') + OP_CLOSE, frame.data, expected_state=CLOSING) if not self.closing_handshake.done(): self.closing_handshake.set_result(True) return @@ -358,11 +365,11 @@ def read_frame(self, max_size): return frame @asyncio.coroutine - def write_frame(self, opcode, data=b'', expected_state='OPEN'): + def write_frame(self, opcode, data=b'', expected_state=OPEN): # This may happen if a user attempts to write on a closed connection. if self.state != expected_state: raise InvalidState("Cannot write to a WebSocket " - "in the {} state".format(self.state)) + "in the {} state".format(self.state_name)) frame = Frame(True, opcode, data) side = 'client' if self.is_client else 'server' logger.debug("%s >> %s", side, frame) @@ -374,20 +381,20 @@ def write_frame(self, opcode, data=b'', expected_state='OPEN'): except ConnectionResetError: # Terminate the connection if the socket died, # unless it's already being closed. - if expected_state != 'CLOSING': - self.state = 'CLOSING' + if expected_state != CLOSING: + self.state = CLOSING yield from self.fail_connection(1006) @asyncio.coroutine def close_connection(self): # 7.1.1. Close the WebSocket Connection - if self.state == 'CLOSED': + if self.state == CLOSED: return # Defensive assertion for protocol compliance. - if self.state != 'CLOSING': # pragma: no cover + if self.state != CLOSING: # pragma: no cover raise InvalidState("Cannot close a WebSocket connection " - "in the {} state".format(self.state)) + "in the {} state".format(self.state_name)) if self.is_client: try: @@ -396,7 +403,7 @@ def close_connection(self): except (asyncio.CancelledError, asyncio.TimeoutError): pass - if self.state == 'CLOSED': + if self.state == CLOSED: return # Attempt to terminate the TCP connection properly. @@ -432,10 +439,10 @@ def fail_connection(self, code=1011, reason=''): self.close_code, self.close_reason = code, reason # 7.1.7. Fail the WebSocket Connection logger.info("Failing the WebSocket connection: %d %s", code, reason) - if self.state == 'OPEN': + if self.state == OPEN: frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) - self.state = 'CLOSING' + self.state = CLOSING if not self.closing_handshake.done(): self.closing_handshake.set_result(False) yield from self.close_connection() @@ -448,7 +455,7 @@ def client_connected(self, reader, writer): def connection_lost(self, exc): # 7.1.4. The WebSocket Connection is Closed - self.state = 'CLOSED' + self.state = CLOSED if not self.connection_closed.done(): self.connection_closed.set_result(None) if self.close_code is None: diff --git a/websockets/server.py b/websockets/server.py index 67cb42f0f..06573c8d2 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -12,7 +12,7 @@ from .exceptions import InvalidHandshake, InvalidOrigin from .handshake import check_request, build_response from .http import read_request, USER_AGENT -from .protocol import WebSocketCommonProtocol +from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): implementation. Its support for HTTP responses is very limited. """ - state = 'CONNECTING' + state = CONNECTING def __init__(self, ws_handler, *, origins=None, subprotocols=None, extra_headers=None, **kwds): @@ -147,7 +147,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): response = '\r\n'.join(response).encode() self.writer.write(response) - self.state = 'OPEN' + self.state = OPEN self.opening_handshake.set_result(True) return path diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 309f7c7e2..e2d5510dd 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -6,7 +6,7 @@ from .exceptions import InvalidState from .framing import * -from .protocol import WebSocketCommonProtocol +from .protocol import CLOSED, CLOSING, WebSocketCommonProtocol MS = 0.001 # Unit for timeouts. May be increased on slow machines. @@ -75,7 +75,7 @@ def assertNoFrameSent(self): self.assertIsNone(sent) def assertConnectionClosed(self, code, message): - self.assertEqual(self.protocol.state, 'CLOSED') + self.assertEqual(self.protocol.state, CLOSED) self.assertEqual(self.protocol.close_code, code) self.assertEqual(self.protocol.close_reason, message) @@ -377,7 +377,7 @@ def test_close_timeout_before_connection_lost(self): self.protocol.timeout = 5 * MS self.loop.call_later(MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) - self.assertEqual(self.protocol.state, 'CLOSING') + self.assertEqual(self.protocol.state, CLOSING) self.assertTrue(self.after.cancelled()) self.assertFalse(self.before.cancelled()) self.before.cancel() @@ -501,7 +501,7 @@ def test_close_timeout_before_connection_lost(self): self.loop.call_later(2 * MS, self.protocol.eof_received) self.loop.run_until_complete(self.protocol.close(reason='because.')) # If the server doesn't drop the connection quickly, the client will. - self.assertEqual(self.protocol.state, 'CLOSING') + self.assertEqual(self.protocol.state, CLOSING) self.assertTrue(self.after.cancelled()) self.assertFalse(self.before.cancelled()) self.before.cancel() From 739f9d1057cea1e0039b664bbf5f6e88138c46c6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 14:47:38 +0200 Subject: [PATCH 0079/1539] Normalize import style with isort. --- setup.cfg | 4 ++++ setup.py | 2 ++ tox.ini | 7 ++++++- websockets/client.py | 2 +- websockets/framing.py | 5 ++--- websockets/http.py | 3 +-- websockets/protocol.py | 9 ++++----- websockets/server.py | 7 +++---- websockets/test_client_server.py | 24 ++++++++++++------------ websockets/test_framing.py | 5 ++--- websockets/test_handshake.py | 2 +- websockets/test_http.py | 5 ++--- websockets/test_protocol.py | 13 ++++++------- 13 files changed, 46 insertions(+), 42 deletions(-) diff --git a/setup.cfg b/setup.cfg index 0530ab2e0..d1368ad2f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,3 +3,7 @@ python-tag = py33.py34 [flake8] ignore = F403 + +[isort] +known_standard_library = asyncio +lines_after_imports = 2 diff --git a/setup.py b/setup.py index 72b6575dd..4f9eaa8cb 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,9 @@ import os import sys + import setuptools + # Avoid polluting the .tar.gz with ._* files under Mac OS X os.putenv('COPYFILE_DISABLE', 'true') diff --git a/tox.ini b/tox.ini index fe9b0af07..78c9dd8e4 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py33,py34,flake8 +envlist = py33,py34,flake8,isort [testenv] deps = @@ -10,3 +10,8 @@ commands = python -m unittest commands = flake8 websockets deps = flake8 + +[testenv:isort] +commands = isort --check-only --recursive websockets +deps = + isort diff --git a/websockets/client.py b/websockets/client.py index 8bd3124ad..4c37824a4 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -9,7 +9,7 @@ from .exceptions import InvalidHandshake from .handshake import build_request, check_response -from .http import read_response, USER_AGENT +from .http import USER_AGENT, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol from .uri import parse_uri diff --git a/websockets/framing.py b/websockets/framing.py index a6d45b366..db9112b84 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -8,14 +8,13 @@ .. _section 5 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-5 """ +import asyncio import collections import io import random import struct -import asyncio - -from .exceptions import WebSocketProtocolError, PayloadTooBig +from .exceptions import PayloadTooBig, WebSocketProtocolError __all__ = [ diff --git a/websockets/http.py b/websockets/http.py index 93f151dc6..e99a6ee9e 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -8,12 +8,11 @@ __all__ = ['read_request', 'read_response', 'USER_AGENT'] +import asyncio import email.parser import io import sys -import asyncio - from .version import version as websockets_version diff --git a/websockets/protocol.py b/websockets/protocol.py index ad432b5d9..c465454fa 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -7,15 +7,14 @@ __all__ = ['WebSocketCommonProtocol'] +import asyncio +import asyncio.queues import codecs import collections import logging import random import struct -import asyncio -from asyncio.queues import Queue, QueueEmpty - from .exceptions import InvalidState, PayloadTooBig, WebSocketProtocolError from .framing import * from .handshake import * @@ -96,7 +95,7 @@ def __init__(self, *, self.connection_closed = asyncio.Future(loop=loop) # Queue of received messages. - self.messages = Queue(loop=loop) + self.messages = asyncio.queues.Queue(loop=loop) # Mapping of ping IDs to waiters, in chronological order. self.pings = collections.OrderedDict() @@ -171,7 +170,7 @@ def recv(self): # Return any available message try: return self.messages.get_nowait() - except QueueEmpty: + except asyncio.queues.QueueEmpty: pass # Wait for a message until the connection is closed diff --git a/websockets/server.py b/websockets/server.py index 06573c8d2..74c738097 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -4,14 +4,13 @@ __all__ = ['serve', 'WebSocketServerProtocol'] +import asyncio import collections.abc import logging -import asyncio - from .exceptions import InvalidHandshake, InvalidOrigin -from .handshake import check_request, build_response -from .http import read_request, USER_AGENT +from .handshake import build_response, check_request +from .http import USER_AGENT, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index aae22e453..d23051581 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1,14 +1,13 @@ +import asyncio import logging import os import ssl import unittest -from unittest.mock import patch - -import asyncio +import unittest.mock from .client import * from .exceptions import InvalidHandshake -from .http import read_response, USER_AGENT +from .http import USER_AGENT, read_response from .server import * @@ -195,7 +194,8 @@ def test_subprotocol_not_requested(self): self.stop_client() self.stop_server() - @patch.object(WebSocketServerProtocol, 'select_subprotocol', autospec=True) + @unittest.mock.patch.object( + WebSocketServerProtocol, 'select_subprotocol', autospec=True) def test_subprotocol_error(self, _select_subprotocol): _select_subprotocol.return_value = 'superchat' @@ -205,7 +205,7 @@ def test_subprotocol_error(self, _select_subprotocol): self.notice_connection_close() self.stop_server() - @patch('websockets.server.read_request') + @unittest.mock.patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") @@ -214,7 +214,7 @@ def test_server_receives_malformed_request(self, _read_request): self.start_client() self.stop_server() - @patch('websockets.client.read_response') + @unittest.mock.patch('websockets.client.read_response') def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") @@ -224,7 +224,7 @@ def test_client_receives_malformed_response(self, _read_response): self.notice_connection_close() self.stop_server() - @patch('websockets.client.build_request') + @unittest.mock.patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(set_header): return '42' @@ -235,7 +235,7 @@ def wrong_build_request(set_header): self.start_client() self.stop_server() - @patch('websockets.server.build_response') + @unittest.mock.patch('websockets.server.build_response') def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(set_header, key): return build_response(set_header, '42') @@ -246,7 +246,7 @@ def wrong_build_response(set_header, key): self.start_client() self.stop_server() - @patch('websockets.client.read_response') + @unittest.mock.patch('websockets.client.read_response') def test_server_does_not_switch_protocols(self, _read_response): @asyncio.coroutine def wrong_read_response(stream): @@ -260,7 +260,7 @@ def wrong_read_response(stream): self.notice_connection_close() self.stop_server() - @patch('websockets.server.WebSocketServerProtocol.send') + @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") @@ -275,7 +275,7 @@ def test_server_handler_crashes(self, send): # Connection ends with an unexpected error. self.assertEqual(self.client.close_code, 1011) - @patch('websockets.server.WebSocketServerProtocol.close') + @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") diff --git a/websockets/test_framing.py b/websockets/test_framing.py index cf2a678eb..7fe914395 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -1,9 +1,8 @@ +import asyncio import io import unittest -import asyncio - -from .exceptions import WebSocketProtocolError, PayloadTooBig +from .exceptions import PayloadTooBig, WebSocketProtocolError from .framing import * diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index 8a6254822..c859b20e8 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -2,7 +2,7 @@ from .exceptions import InvalidHandshake from .handshake import * -from .handshake import accept # private API +from .handshake import accept # private API class HandshakeTests(unittest.TestCase): diff --git a/websockets/test_http.py b/websockets/test_http.py index 0607b8196..b31bd84d0 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -1,9 +1,8 @@ -import unittest - import asyncio +import unittest from .http import * -from .http import read_message # private API +from .http import read_message # private API class HTTPTests(unittest.TestCase): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index e2d5510dd..36153f487 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -1,9 +1,8 @@ +import asyncio +import functools import unittest import unittest.mock -import asyncio -from functools import partial - from .exceptions import InvalidState from .framing import * from .protocol import CLOSED, CLOSING, WebSocketCommonProtocol @@ -25,14 +24,14 @@ def setUp(self): side_effect=lambda: self.protocol.connection_lost(None)) self.protocol.connection_made(self.transport) - @property - def async(self): - return partial(asyncio.async, loop=self.loop) - def tearDown(self): self.loop.close() super().tearDown() + @property + def async(self): + return functools.partial(asyncio.async, loop=self.loop) + def feed(self, frame): """Feed a frame to the protocol.""" mask = not self.protocol.is_client From 9638bee3b7890c48529e7b93f167d0fdbe14eff1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 14:48:59 +0200 Subject: [PATCH 0080/1539] Fix remaining flake8 warnings. --- docs/index.rst | 4 ++-- websockets/test_protocol.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index b94f8052f..2add66d1e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -92,8 +92,8 @@ Here's a corresponding client example. listener_task = asyncio.ensure_future(websocket.recv()) producer_task = asyncio.ensure_future(producer()) done, pending = yield from asyncio.wait( - [listener_task, producer_task], - return_when=asyncio.FIRST_COMPLETED) + [listener_task, producer_task], + return_when=asyncio.FIRST_COMPLETED) if listener_task in done: message = listener_task.result() diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 36153f487..1f0a8ab6a 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -395,7 +395,8 @@ def delayed_write_frame(*args, **kwargs): # Trigger the race condition between answering the close frame from # the client and sending another close frame from the server. self.loop.call_later(MS, self.feed, frame) - self.loop.call_later(2 * MS, self.async, self.protocol.fail_connection(1000, 'server')) + fail_connection = self.protocol.fail_connection(1000, 'server') + self.loop.call_later(2 * MS, self.async, fail_connection) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1000, 'server') self.assertFrameSent(*frame) @@ -519,7 +520,8 @@ def delayed_write_frame(*args, **kwargs): # Trigger the race condition between answering the close frame from # the server and sending another close frame from the client. self.loop.call_later(MS, self.feed, frame) - self.loop.call_later(2 * MS, self.async, self.protocol.fail_connection(1000, 'client')) + fail_connection = self.protocol.fail_connection(1000, 'client') + self.loop.call_later(2 * MS, self.async, fail_connection) self.loop.call_later(3 * MS, self.protocol.eof_received) self.loop.call_later(4 * MS, self.protocol.connection_lost, None) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) From 25d8de0ab0e4c9ffda63997066ebd91959cb1d56 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 15:42:18 +0200 Subject: [PATCH 0081/1539] Provide access to HTTP headers as MIME messages. Document the private API to acces them as a list of (name, value) pairs. Fix #24. --- docs/index.rst | 2 ++ websockets/client.py | 7 +++++++ websockets/protocol.py | 17 +++++++++++++++-- websockets/server.py | 7 +++++++ websockets/test_client_server.py | 17 +++++++++++++++++ 5 files changed, 48 insertions(+), 2 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 2add66d1e..cb0971ecc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -265,6 +265,8 @@ Changelog 2.5 ... +* Provided access to handshake request and response HTTP headers. + * Allowed customizing handshake request and response HTTP headers. * Supported running on a non-default event loop. diff --git a/websockets/client.py b/websockets/client.py index 4c37824a4..26d9e5078 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -6,6 +6,7 @@ import asyncio import collections.abc +import email.message from .exceptions import InvalidHandshake from .handshake import build_request, check_response @@ -57,6 +58,9 @@ def handshake(self, wsuri, set_header('User-Agent', USER_AGENT) key = build_request(set_header) + self.request_headers = email.message.Message() + for name, value in headers: + self.request_headers[name] = value self.raw_request_headers = headers # Send handshake request. Since the URI and the headers only contain @@ -74,7 +78,10 @@ def handshake(self, wsuri, raise InvalidHandshake("Malformed HTTP message") from exc if status_code != 101: raise InvalidHandshake("Bad status code: {}".format(status_code)) + + self.response_headers = headers self.raw_response_headers = list(headers.raw_items()) + get_header = lambda k: headers.get(k, '') check_response(get_header, key) diff --git a/websockets/protocol.py b/websockets/protocol.py index c465454fa..bc9465469 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -52,8 +52,16 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): message larger than the maximum size is received, :meth:`recv()` will return ``None`` and the connection will be closed with status code 1009. - Once the handshake is complete, if a subprotocol was negotiated, it's - available in the :attr:`subprotocol` attribute. + Once the handshake is complete, request and response HTTP headers are + available: + + * as a MIME :class:`~email.message.Message` in the ``request_headers`` and + ``response_headers`` attributes + * as an iterable of (name, value) pairs in the ``raw_request_headers`` and + ``raw_response_headers`` attributes + + If a subprotocol was negotiated, it's available in the :attr:`subprotocol` + attribute. Once the connection is closed, the status code is available in the :attr:`close_code` attribute and the reason in :attr:`close_reason`. @@ -83,6 +91,11 @@ def __init__(self, *, stream_reader = asyncio.StreamReader(loop=loop) super().__init__(stream_reader, self.client_connected, loop) + self.request_headers = None + self.raw_request_headers = None + self.response_headers = None + self.raw_response_headers = None + self.subprotocol = None self.close_code = None diff --git a/websockets/server.py b/websockets/server.py index 74c738097..0f24aa878 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -6,6 +6,7 @@ import asyncio import collections.abc +import email.message import logging from .exceptions import InvalidHandshake, InvalidOrigin @@ -107,7 +108,9 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): except Exception as exc: raise InvalidHandshake("Malformed HTTP message") from exc + self.request_headers = headers self.raw_request_headers = list(headers.raw_items()) + get_header = lambda k: headers.get(k, '') key = check_request(get_header) @@ -136,6 +139,10 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): for name, value in extra_headers: set_header(name, value) build_response(set_header, key) + + self.response_headers = email.message.Message() + for name, value in headers: + self.response_headers[name] = value self.raw_response_headers = headers # Send handshake response. Since the status line and headers only diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index d23051581..0fdf170ef 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -21,6 +21,9 @@ def handler(ws, path): if path == '/attributes': yield from ws.send(repr((ws.host, ws.port, ws.secure))) + elif path == '/headers': + yield from ws.send(str(ws.request_headers)) + yield from ws.send(str(ws.response_headers)) elif path == '/raw_headers': yield from ws.send(repr(ws.raw_request_headers)) yield from ws.send(repr(ws.raw_response_headers)) @@ -81,6 +84,20 @@ def test_protocol_attributes(self): self.stop_client() self.stop_server() + def test_protocol_headers(self): + self.start_server() + self.start_client('headers') + client_req = self.client.request_headers + client_resp = self.client.response_headers + self.assertEqual(client_req['User-Agent'], USER_AGENT) + self.assertEqual(client_resp['Server'], USER_AGENT) + server_req = self.loop.run_until_complete(self.client.recv()) + server_resp = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_req, str(client_req)) + self.assertEqual(server_resp, str(client_resp)) + self.stop_client() + self.stop_server() + def test_protocol_raw_headers(self): self.start_server() self.start_client('raw_headers') From 8971978a7ccb2fc0a43739239912ba073905e536 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 21:17:19 +0200 Subject: [PATCH 0082/1539] Improve the cheat sheet a bit. --- docs/index.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index cb0971ecc..5abde0cb2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -141,8 +141,8 @@ Server execute the application logic, and finally closes the connection after the handler returns. - * You may subclass :class:`~websockets.server.WebSocketServerProtocol` if - you have an advanced use case. + * You may subclass :class:`~websockets.server.WebSocketServerProtocol` and + pass it in the ``klass`` keyword argument for advanced customization. Client ...... @@ -150,6 +150,9 @@ Client * Create a server with :func:`~websockets.client.connect` which is similar to asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. + * You may subclass :class:`~websockets.server.WebSocketClientProtocol` and + pass it in the ``klass`` keyword argument for advanced customization. + * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and send messages at any time. From c7ee9f2076b02796c76b7816adb334c74c258697 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 21:17:33 +0200 Subject: [PATCH 0083/1539] Add debugging advice. Try to deflect asyncio-related problems to Python's support channels. --- docs/index.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 5abde0cb2..7eba30df3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -164,6 +164,29 @@ Client * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.close` to terminate the connection. +Debugging +......... + +If you don't understand what ``websockets`` is doing, enable logging:: + + import logging + logger = logging.getLogger('websockets') + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler()) + +The logs contains: + +* Exceptions in the connection handler at the ``ERROR`` level +* Exceptions in the opening or closing handshake at the ``INFO`` level +* All frames at the ``DEBUG`` level — this can be very verbose + +If you're new to ``asyncio``, you will certainly encounter issues that are +related to asynchronous programming in general rather than to ``websockets`` +in particular. Fortunately Python's official documentation provides advice to +`develop with asyncio`_. Check it out: it's invaluable! + +.. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html + Design ------ From 7cb30eaa88fa0b1352d7cbcb0950c74d6cf88eac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 22:11:25 +0200 Subject: [PATCH 0084/1539] Proof-read and improve the documentation. * Explain how to serve or connect to wss:// endpoints. * Make WebSocketURI a public API. * Add docstrings to namedtuples. * Add cross-references. * Normalize markup. --- docs/index.rst | 14 +++++++------- websockets/client.py | 16 ++++++++++------ websockets/framing.py | 32 +++++++++++++++++++++----------- websockets/handshake.py | 23 ++++++++++++----------- websockets/http.py | 21 +++++++++++---------- websockets/protocol.py | 18 +++++++++--------- websockets/server.py | 39 ++++++++++++++++++++------------------- websockets/uri.py | 15 ++++++++++----- 8 files changed, 100 insertions(+), 78 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 7eba30df3..ba778f9ee 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -252,15 +252,15 @@ Shared .. automethod:: ping(data=None) .. automethod:: pong(data=b'') -Low-level API -------------- - Exceptions .......... .. automodule:: websockets.exceptions :members: +Low-level API +------------- + Opening handshake ................. @@ -291,14 +291,14 @@ Changelog 2.5 ... +* Improved documentation. + * Provided access to handshake request and response HTTP headers. * Allowed customizing handshake request and response HTTP headers. * Supported running on a non-default event loop. -* Improved documentation. - * Returned a 403 error code instead of 400 when the request Origin isn't allowed. @@ -312,7 +312,7 @@ Changelog * Supported non-default event loop. -* Added `loop` argument to :func:`~websockets.client.connect` and +* Added ``loop`` argument to :func:`~websockets.client.connect` and :func:`~websockets.server.serve`. 2.3 @@ -328,7 +328,7 @@ Changelog 2.1 ... -* Added `host`, `port` and `secure` attributes on protocols. +* Added ``host``, ``port`` and ``secure`` attributes on protocols. * Added support for providing and checking Origin_. diff --git a/websockets/client.py b/websockets/client.py index 26d9e5078..3c648d787 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -17,7 +17,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ - Complete WebSocket client implementation as an :mod:`asyncio` protocol. + Complete WebSocket client implementation as an :class:`asyncio.Protocol`. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -103,11 +103,15 @@ def connect(uri, *, """ This coroutine connects to a WebSocket server. - It's a thin wrapper around the event loop's + It's a wrapper around the event loop's :meth:`~asyncio.BaseEventLoop.create_connection` method. Extra keyword arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. + For example, you can set the ``ssl`` keyword argument to a + :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to + a ``wss://`` URI, if this argument isn't provided explicitly, it's set to + ``True``, which means Python's default :class:`~ssl.SSLContext` is used. - This coroutine accepts several optional arguments: + :func:`connect` accepts several optional arguments: * ``origin`` sets the Origin HTTP header * ``subprotocols`` is a list of supported subprotocols in order of @@ -115,10 +119,10 @@ def connect(uri, *, * ``extra_headers`` sets additional HTTP request headers – it can be a mapping or an iterable of (name, value) pairs - It returns a :class:`~websockets.client.WebSocketClientProtocol` which can - then be used to send and receive messages. + :func:`connect` yields a :class:`WebSocketClientProtocol` which can then + be used to send and receive messages. - It raises :exc:`~websockets.uri.InvalidURI` if `uri` is invalid and + It raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the handshake fails. Clients shouldn't close the WebSocket connection. Instead, they should diff --git a/websockets/framing.py b/websockets/framing.py index db9112b84..b1c13aad2 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -43,6 +43,16 @@ Frame = collections.namedtuple('Frame', ('fin', 'opcode', 'data')) +Frame.__doc__ = """WebSocket frame. + +* ``fin`` is the FIN bit +* ``opcode`` is the opcode +* ``data`` is the payload data + +Only these three fields are needed by higher level code. The MASK bit, payload +length and masking-key are handled on the fly by :func:`read_frame` and +:func:`write_frame`. +""" @asyncio.coroutine @@ -50,13 +60,13 @@ def read_frame(reader, mask, *, max_size=None): """ Read a WebSocket frame and return a :class:`Frame` object. - `reader` is a coroutine taking an integer argument and reading exactly this - number of bytes, unless the end of file is reached. + ``reader`` is a coroutine taking an integer argument and reading exactly + this number of bytes, unless the end of file is reached. - `mask` is a :class:`bool` telling whether the frame should be masked, ie. - whether the read happens on the server side. + ``mask`` is a :class:`bool` telling whether the frame should be masked + i.e. whether the read happens on the server side. - If `max_size` is set and the payload exceeds this size in bytes, + If ``max_size`` is set and the payload exceeds this size in bytes, :exc:`~websockets.exceptions.PayloadTooBig` is raised. This function validates the frame before returning it and raises @@ -99,12 +109,12 @@ def write_frame(frame, writer, mask): """ Write a WebSocket frame. - `frame` is the :class:`Frame` object to write. + ``frame`` is the :class:`Frame` object to write. - `writer` is a function accepting bytes. + ``writer`` is a function accepting bytes. - `mask` is a :class:`bool` telling whether the frame should be masked, ie. - whether the write happens on the client side. + ``mask`` is a :class:`bool` telling whether the frame should be masked + i.e. whether the write happens on the client side. This function validates the frame before sending it and raises :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains @@ -157,8 +167,8 @@ def parse_close(data): """ Parse the data in a close frame. - Return `(code, reason)` when `code` is an :class:`int` and `reason` a - :class:`str`. + Return ``(code, reason)`` when ``code`` is an :class:`int` and ``reason`` + a :class:`str`. Raise :exc:`~websockets.exceptions.WebSocketProtocolError` or :exc:`UnicodeDecodeError` if the data is invalid. diff --git a/websockets/handshake.py b/websockets/handshake.py index 799c93b40..74cfd966f 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -7,11 +7,11 @@ It provides functions to implement the handshake with any existing HTTP library. You must pass to these functions: -- A `set_header` function accepting a header name and a header value, -- A `get_header` function accepting a header name and returning the header +- A ``set_header`` function accepting a header name and a header value, +- A ``get_header`` function accepting a header name and returning the header value. -The inputs and outputs of `get_header` and `set_header` are :class:`str` +The inputs and outputs of ``get_header`` and ``set_header`` are :class:`str` objects containing only ASCII characters. Some checks cannot be performed because they depend too much on the @@ -22,7 +22,8 @@ - Read the request, check that the method is GET, and check the headers with :func:`check_request`, - Send a 101 response to the client with the headers created by - :func:`build_response` if the request is valid; otherwise, send a 400. + :func:`build_response` if the request is valid; otherwise, send an + appropriate HTTP error code. To open a connection, a client must: @@ -51,7 +52,7 @@ def build_request(set_header): """ Build a handshake request to send to the server. - Return the `key` which must be passed to :func:`check_response`. + Return the ``key`` which must be passed to :func:`check_response`. """ rand = bytes(random.getrandbits(8) for _ in range(16)) key = base64.b64encode(rand).decode() @@ -66,11 +67,11 @@ def check_request(get_header): """ Check a handshake request received from the client. - If the handshake is valid, this function returns the `key` which must be + If the handshake is valid, this function returns the ``key`` which must be passed to :func:`build_response`. - Otherwise, it raises an :exc:`~websockets.exceptions.InvalidHandshake` - exception and the server must return an error, usually 400 Bad Request. + Otherwise it raises an :exc:`~websockets.exceptions.InvalidHandshake` + exception and the server must return an error like 400 Bad Request. This function doesn't verify that the request is an HTTP/1.1 or higher GET request and doesn't perform Host and Origin checks. These controls are @@ -94,7 +95,7 @@ def build_response(set_header, key): """ Build a handshake response to send to the client. - `key` comes from :func:`check_request`. + ``key`` comes from :func:`check_request`. """ set_header('Upgrade', 'WebSocket') set_header('Connection', 'Upgrade') @@ -105,11 +106,11 @@ def check_response(get_header, key): """ Check a handshake response received from the server. - `key` comes from :func:`build_request`. + ``key`` comes from :func:`build_request`. If the handshake is valid, this function returns ``None``. - Otherwise, it raises an :exc:`~websockets.exceptions.InvalidHandshake` + Otherwise it raises an :exc:`~websockets.exceptions.InvalidHandshake` exception. This function doesn't verify that the response is an HTTP/1.1 or higher diff --git a/websockets/http.py b/websockets/http.py index e99a6ee9e..6e1972fa8 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -28,10 +28,11 @@ @asyncio.coroutine def read_request(stream): """ - Read an HTTP/1.1 request from `stream`. + Read an HTTP/1.1 request from ``stream``. - Return `(path, headers)` where `path` is a :class:`str` and `headers` is a - :class:`~email.message.Message`; `path` isn't URL-decoded. + Return ``(path, headers)`` where ``path`` is a :class:`str` and + ``headers`` is a :class:`~email.message.Message`. ``path`` isn't + URL-decoded. Raise an exception if the request isn't well formatted. @@ -49,10 +50,10 @@ def read_request(stream): @asyncio.coroutine def read_response(stream): """ - Read an HTTP/1.1 response from `stream`. + Read an HTTP/1.1 response from ``stream``. - Return `(status, headers)` where `status` is a :class:`int` and `headers` - is a :class:`~email.message.Message`. + Return ``(status, headers)`` where ``status`` is a :class:`int` and + ``headers`` is a :class:`~email.message.Message`. Raise an exception if the request isn't well formatted. @@ -68,10 +69,10 @@ def read_response(stream): @asyncio.coroutine def read_message(stream): """ - Read an HTTP message from `stream`. + Read an HTTP message from ``stream``. - Return `(start_line, headers)` where `start_line` is :class:`bytes` and - `headers` is a :class:`~email.message.Message`. + Return ``(start_line, headers)`` where ``start_line`` is :class:`bytes` + and ``headers`` is a :class:`~email.message.Message`. The message is assumed not to contain a body. """ @@ -92,7 +93,7 @@ def read_message(stream): @asyncio.coroutine def read_line(stream): """ - Read a single line from `stream`. + Read a single line from ``stream``. """ line = yield from stream.readline() if len(line) > MAX_LINE: diff --git a/websockets/protocol.py b/websockets/protocol.py index bc9465469..7e4d97ca8 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -39,15 +39,15 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): control frames automatically. It sends outgoing data frames and performs the closing handshake. - The `host`, `port` and `secure` parameters are simply stored as attributes - for handlers that need them. + The ``host``, ``port`` and ``secure`` parameters are simply stored as + attributes for handlers that need them. - The `timeout` parameter defines the maximum wait time in seconds for + The ``timeout`` parameter defines the maximum wait time in seconds for completing the closing handshake and, only on the client side, for terminating the TCP connection. :meth:`close()` will complete in at most this time on the server side and twice this time on the client side. - The `max_size` parameter enforces the maximum size for incoming messages + The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1MB. ``None`` disables the limit. If a message larger than the maximum size is received, :meth:`recv()` will return ``None`` and the connection will be closed with status code 1009. @@ -55,10 +55,10 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): Once the handshake is complete, request and response HTTP headers are available: - * as a MIME :class:`~email.message.Message` in the ``request_headers`` and - ``response_headers`` attributes - * as an iterable of (name, value) pairs in the ``raw_request_headers`` and - ``raw_response_headers`` attributes + * as a MIME :class:`~email.message.Message` in the :attr:`request_headers` + and :attr:`response_headers` attributes + * as an iterable of (name, value) pairs in the :attr:`raw_request_headers` + and :attr:`raw_response_headers` attributes If a subprotocol was negotiated, it's available in the :attr:`subprotocol` attribute. @@ -147,7 +147,7 @@ def close(self, code=1000, reason=''): It's usually safe to wrap this coroutine in :func:`~asyncio.async` since errors during connection termination aren't particularly useful. - The `code` must be an :class:`int` and the `reason` a :class:`str`. + ``code`` must be an :class:`int` and ``reason`` a :class:`str`. """ if self.state == OPEN: # 7.1.2. Start the WebSocket Closing Handshake diff --git a/websockets/server.py b/websockets/server.py index 0f24aa878..997440a3a 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -20,13 +20,13 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ - Complete WebSocket server implementation as an :mod:`asyncio` protocol. + Complete WebSocket server implementation as an :class:`asyncio.Protocol`. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. - For the sake of simplicity, this protocol doesn't inherit a proper HTTP - implementation. Its support for HTTP responses is very limited. + For the sake of simplicity, it doesn't rely on a full HTTP implementation. + Its support for HTTP responses is very limited. """ state = CONNECTING @@ -177,40 +177,41 @@ def serve(ws_handler, host=None, port=None, *, """ This coroutine creates a WebSocket server. - It's a thin wrapper around the event loop's - :meth:`~asyncio.BaseEventLoop.create_server` method. `host`, `port` as + It's a wrapper around the event loop's + :meth:`~asyncio.BaseEventLoop.create_server` method. ``host``, ``port`` as well as extra keyword arguments are passed to - :meth:`~asyncio.BaseEventLoop.create_server`. + :meth:`~asyncio.BaseEventLoop.create_server`. For example, you can set the + ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. ``ws_handler`` is the WebSocket handler. It must be a coroutine accepting - two arguments: a :class:`~websockets.server.WebSocketServerProtocol` and - the request URI. + two arguments: a :class:`WebSocketServerProtocol` and the request URI. - This coroutine accepts several optional arguments: + :func:`serve` accepts several optional arguments: * ``origins`` defines acceptable Origin HTTP headers — include ``''`` if the lack of an origin is acceptable * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference - * ``extra_headers`` sets additional HTTP response headers – it can be a + decreasing preference + * ``extra_headers`` sets additional HTTP response headers — it can be a mapping, an iterable of (name, value) pairs, or a callable taking the request path and headers in arguments. - `serve` yields a `Server` object with a `close` method to stop the server. + :func:`serve` yields a :class:`~asyncio.Server` which provides a + :meth:`~asyncio.Server.close` method and a + :meth:`~asyncio.Server.wait_closed` coroutine to stop serving requests. Whenever a client connects, the server accepts the connection, creates a - :class:`~websockets.server.WebSocketServerProtocol`, performs the opening - handshake, and delegates to the WebSocket handler. Once the handler - completes, the server performs the closing handshake and closes the - connection. + :class:`WebSocketServerProtocol`, performs the opening handshake, and + delegates to the WebSocket handler. Once the handler completes, the server + performs the closing handshake and closes the connection. Since there's no useful way to propagate exceptions triggered in handlers, - they're sent to the `websockets.server` logger instead. Debugging is much - easier if you configure logging to print them:: + they're sent to the ``'websockets.server'`` logger instead. Debugging is + much easier if you configure logging to print them:: import logging logger = logging.getLogger('websockets.server') - logger.setLevel(logging.DEBUG) + logger.setLevel(logging.ERROR) logger.addHandler(logging.StreamHandler()) """ if loop is None: diff --git a/websockets/uri.py b/websockets/uri.py index 3b4ae7a50..2bb474fbe 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -5,7 +5,7 @@ .. _section 3 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-3 """ -__all__ = ['parse_uri'] +__all__ = ['parse_uri', 'WebSocketURI'] import collections import urllib.parse @@ -15,17 +15,22 @@ WebSocketURI = collections.namedtuple( 'WebSocketURI', ('secure', 'host', 'port', 'resource_name')) +WebSocketURI.__doc__ = """WebSocket URI. + +* ``secure`` is the secure flag +* ``host`` is the lower-case host +* ``port`` if the integer port, it's always provided even if it's the default +* ``resource_name`` is the resource name, that is, the path and optional query +""" def parse_uri(uri): """ This function parses and validates a WebSocket URI. - If the URI is valid, it returns a namedtuple `(secure, host, port, - resource_name)` + If the URI is valid, it returns a :class:`WebSocketURI`. - Otherwise, it raises an :exc:`~websockets.exceptions.InvalidURI` - exception. + Otherwise it raises an :exc:`~websockets.exceptions.InvalidURI` exception. """ uri = urllib.parse.urlparse(uri) try: From 4bb82d95ae34592fcf861cdaf02faeca0c6a6249 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 23:17:15 +0200 Subject: [PATCH 0085/1539] Remove a misunderstanding of the specification. As explained properly in the remainder of the library, clients are allowed to start the WebSocket closing handshake. It's the TCP connection that should be closed by the server (to allow reuse). --- websockets/client.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 3c648d787..3908fe5bf 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -124,10 +124,6 @@ def connect(uri, *, It raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the handshake fails. - - Clients shouldn't close the WebSocket connection. Instead, they should - wait until the server performs the closing handshake by yielding from the - protocol's :attr:`worker` attribute. """ if loop is None: loop = asyncio.get_event_loop() From 6e9f7e4bdecdc6c4f95603a74611ec40dfad2c60 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jul 2015 23:58:57 +0200 Subject: [PATCH 0086/1539] Add an environment variable to adjust test timeouts. --- websockets/test_protocol.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 1f0a8ab6a..2dabc10bf 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -1,5 +1,6 @@ import asyncio import functools +import os import unittest import unittest.mock @@ -8,7 +9,9 @@ from .protocol import CLOSED, CLOSING, WebSocketCommonProtocol -MS = 0.001 # Unit for timeouts. May be increased on slow machines. +# Unit for timeouts. May be increased on slow machines by setting the +# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variables. +MS = 0.001 * int(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1)) class CommonTests: From a3444c4dc39d452cb2b1cd368b5ec1b6972d2972 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jul 2015 09:53:47 +0200 Subject: [PATCH 0087/1539] Remove unneeded mock attribute. The implementation of asyncio must have changed since it was added. --- websockets/test_protocol.py | 1 - 1 file changed, 1 deletion(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 2dabc10bf..13a709bad 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -22,7 +22,6 @@ def setUp(self): asyncio.set_event_loop(self.loop) self.protocol = WebSocketCommonProtocol() self.transport = unittest.mock.Mock() - self.transport._conn_lost = 0 # checked by drain() self.transport.close = unittest.mock.Mock( side_effect=lambda: self.protocol.connection_lost(None)) self.protocol.connection_made(self.transport) From 0bf62b695f128e451f13bc5e6b7f655781c8393f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jul 2015 12:22:23 +0200 Subject: [PATCH 0088/1539] Initialize worker when reader/writer are available. Theoretically there was a race condition here. --- websockets/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 7e4d97ca8..c04de13f9 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -113,8 +113,8 @@ def __init__(self, *, # Mapping of ping IDs to waiters, in chronological order. self.pings = collections.OrderedDict() - # Task managing the connection. - self.worker = asyncio.async(self.run(), loop=loop) + # Task managing the connection, initalized in self.client_connected. + self.worker = None # In a subclass implementing the opening handshake, the state will be # CONNECTING at this point. @@ -464,6 +464,8 @@ def fail_connection(self, code=1011, reason=''): def client_connected(self, reader, writer): self.reader = reader self.writer = writer + # Start the task that handles incoming messages. + self.worker = asyncio.async(self.run(), loop=self.loop) def connection_lost(self, exc): # 7.1.4. The WebSocket Connection is Closed From 2437caa994e5a6cbfa79674ec99805890ee2e7c0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jul 2015 12:26:59 +0200 Subject: [PATCH 0089/1539] Materialize the transport mock in its own class. --- websockets/test_protocol.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 13a709bad..dbb0cf3bf 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -14,6 +14,28 @@ MS = 0.001 * int(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1)) +class TransportMock(unittest.mock.Mock): + """ + Transport mock to control the protocol's inputs and outputs in tests. + + It calls the protocol's connection_made and connection_lost methods like + actual transports. + + To simulate incoming data, tests call the protocol's data_received and + eof_received methods directly. + + They could also pause_writing and resume_writing to test flow control. + """ + # This should happen in __init__ but overriding Mock.__init__ is hard. + def connect(self, loop, protocol): + self.loop = loop + self.protocol = protocol + self.protocol.connection_made(self) + + def close(self): + self.protocol.connection_lost(None) + + class CommonTests: def setUp(self): @@ -21,10 +43,8 @@ def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) self.protocol = WebSocketCommonProtocol() - self.transport = unittest.mock.Mock() - self.transport.close = unittest.mock.Mock( - side_effect=lambda: self.protocol.connection_lost(None)) - self.protocol.connection_made(self.transport) + self.transport = TransportMock() + self.transport.connect(self.loop, self.protocol) def tearDown(self): self.loop.close() From 56ad98fc1f0ba025edfa172f7b4caef31978d2d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jul 2015 12:29:48 +0200 Subject: [PATCH 0090/1539] Rename feed to receive_frame. This is more explicit. --- websockets/test_protocol.py | 102 +++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 49 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index dbb0cf3bf..fe8d0580a 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -54,10 +54,13 @@ def tearDown(self): def async(self): return functools.partial(asyncio.async, loop=self.loop) - def feed(self, frame): - """Feed a frame to the protocol.""" + def receive_frame(self, frame): + """ + Make the protocol receive a frame. + """ + writer = self.protocol.data_received mask = not self.protocol.is_client - write_frame(frame, self.protocol.data_received, mask) + write_frame(frame, writer, mask) @asyncio.coroutine def sent(self): @@ -74,7 +77,7 @@ def sent(self): @asyncio.coroutine def echo(self): """Echo to the protocol the next frame sent to the transport.""" - self.feed((yield from self.sent())) + self.receive_frame((yield from self.sent())) @asyncio.coroutine def fast_connection_failure(self): @@ -84,7 +87,7 @@ def fast_connection_failure(self): def process_control_frames(self): """Process control frames fed to the protocol.""" - self.feed(Frame(True, OP_TEXT, b'')) + self.receive_frame(Frame(True, OP_TEXT, b'')) self.loop.run_until_complete(self.protocol.recv()) def assertFrameSent(self, fin, opcode, data): @@ -110,50 +113,50 @@ def test_connection_lost(self): self.assertConnectionClosed(1006, '') def test_recv_text(self): - self.feed(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café') def test_recv_binary(self): - self.feed(Frame(True, OP_BINARY, b'tea')) + self.receive_frame(Frame(True, OP_BINARY, b'tea')) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea') def test_recv_protocol_error(self): - self.feed(Frame(True, OP_CONT, 'café'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8'))) self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_recv_unicode_error(self): - self.feed(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) + self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1007, '') def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 - self.feed(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) + self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 - self.feed(Frame(True, OP_BINARY, b'tea' * 342)) + self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.feed(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) + self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café' * 205) def test_recv_binary_no_max_size(self): self.protocol.max_size = None # for test coverage - self.feed(Frame(True, OP_BINARY, b'tea' * 342)) + self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea' * 342) @@ -179,7 +182,7 @@ def test_recv_cancelled(self): asyncio.wait_for(self.protocol.recv(), 1, loop=self.loop) ) except asyncio.TimeoutError: - self.feed(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) data = self.loop.run_until_complete( asyncio.wait_for(self.protocol.recv(), 1, loop=self.loop) ) # We use wait_for here to make sure the test fail and don't hang @@ -206,12 +209,12 @@ def test_send_on_closed_connection(self): self.assertNoFrameSent() def test_answer_ping(self): - self.feed(Frame(True, OP_PING, b'test')) + self.receive_frame(Frame(True, OP_PING, b'test')) self.process_control_frames() self.assertFrameSent(True, OP_PONG, b'test') def test_ignore_pong(self): - self.feed(Frame(True, OP_PONG, b'test')) + self.receive_frame(Frame(True, OP_PONG, b'test')) self.process_control_frames() self.assertNoFrameSent() @@ -220,7 +223,7 @@ def test_acknowledge_ping(self): self.assertFalse(ping.done()) ping_frame = self.loop.run_until_complete(self.sent()) pong_frame = Frame(True, OP_PONG, ping_frame.data) - self.feed(pong_frame) + self.receive_frame(pong_frame) self.process_control_frames() self.assertTrue(ping.done()) @@ -230,13 +233,13 @@ def test_acknowledge_previous_pings(self): self.loop.run_until_complete(self.sent()), ) for i in range(3)] # Unsolicited pong doesn't acknowledge pings - self.feed(Frame(True, OP_PONG, b'')) + self.receive_frame(Frame(True, OP_PONG, b'')) self.process_control_frames() self.assertFalse(pings[0][0].done()) self.assertFalse(pings[1][0].done()) self.assertFalse(pings[2][0].done()) # Pong acknowledges all previous pings - self.feed(Frame(True, OP_PONG, pings[1][1].data)) + self.receive_frame(Frame(True, OP_PONG, pings[1][1].data)) self.process_control_frames() self.assertTrue(pings[0][0].done()) self.assertTrue(pings[1][0].done()) @@ -247,7 +250,7 @@ def test_cancel_ping(self): ping_frame = self.loop.run_until_complete(self.sent()) ping.cancel() pong_frame = Frame(True, OP_PONG, ping_frame.data) - self.feed(pong_frame) + self.receive_frame(pong_frame) self.process_control_frames() self.assertTrue(ping.cancelled()) @@ -259,73 +262,73 @@ def test_duplicate_ping(self): self.assertNoFrameSent() def test_fragmented_text(self): - self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.feed(Frame(True, OP_CONT, 'fé'.encode('utf-8'))) + self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8'))) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café') def test_fragmented_binary(self): - self.feed(Frame(False, OP_BINARY, b't')) - self.feed(Frame(False, OP_CONT, b'e')) - self.feed(Frame(True, OP_CONT, b'a')) + self.receive_frame(Frame(False, OP_BINARY, b't')) + self.receive_frame(Frame(False, OP_CONT, b'e')) + self.receive_frame(Frame(True, OP_CONT, b'a')) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea') def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 - self.feed(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) - self.feed(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) + self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) + self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 - self.feed(Frame(False, OP_BINARY, b'tea' * 171)) - self.feed(Frame(True, OP_CONT, b'tea' * 171)) + self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) + self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.feed(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) - self.feed(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) + self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) + self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café' * 205) def test_fragmented_binary_no_max_size(self): self.protocol.max_size = None # for test coverage - self.feed(Frame(False, OP_BINARY, b'tea' * 171)) - self.feed(Frame(True, OP_CONT, b'tea' * 171)) + self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) + self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea' * 342) def test_control_frame_within_fragmented_text(self): - self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.feed(Frame(True, OP_PING, b'')) - self.feed(Frame(True, OP_CONT, 'fé'.encode('utf-8'))) + self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_PING, b'')) + self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8'))) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café') self.assertFrameSent(True, OP_PONG, b'') def test_unterminated_fragmented_text(self): - self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) + self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) # Missing the second part of the fragmented frame. - self.feed(Frame(True, OP_BINARY, b'tea')) + self.receive_frame(Frame(True, OP_BINARY, b'tea')) self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_close_handshake_in_fragmented_text(self): - self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.feed(Frame(True, OP_CLOSE, b'')) + self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_CLOSE, b'')) self.loop.call_later(MS, self.async, self.fast_connection_failure()) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_connection_close_in_fragmented_text(self): - self.feed(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) + self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.loop.call_later(MS, self.protocol.eof_received) self.loop.call_later(2 * MS, self.protocol.connection_lost, None) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) @@ -347,7 +350,7 @@ def test_close(self): # standard server-initiated close def test_client_close(self): # non standard client-initiated close frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) - self.loop.call_later(MS, self.feed, frame) + self.loop.call_later(MS, self.receive_frame, frame) # The server is waiting for some data at this point, and won't get it. self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) # After recv() returns None, the connection is closed. @@ -361,14 +364,14 @@ def test_client_close(self): # non standard client-initiated close def test_simultaneous_close(self): # non standard close from both sides client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) - self.loop.call_later(MS, self.feed, client_close) + self.loop.call_later(MS, self.receive_frame, client_close) self.loop.run_until_complete(self.protocol.close(reason='server')) self.assertConnectionClosed(1000, 'client') self.assertFrameSent(*server_close) self.assertNoFrameSent() def test_close_drops_frames(self): - self.loop.call_later(MS, self.feed, Frame(True, OP_TEXT, b'')) + self.loop.call_later(MS, self.receive_frame, Frame(True, OP_TEXT, b'')) self.loop.call_later(2 * MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') @@ -416,7 +419,7 @@ def delayed_write_frame(*args, **kwargs): frame = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) # Trigger the race condition between answering the close frame from # the client and sending another close frame from the server. - self.loop.call_later(MS, self.feed, frame) + self.loop.call_later(MS, self.receive_frame, frame) fail_connection = self.protocol.fail_connection(1000, 'server') self.loop.call_later(2 * MS, self.async, fail_connection) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) @@ -424,7 +427,8 @@ def delayed_write_frame(*args, **kwargs): self.assertFrameSent(*frame) def test_close_protocol_error(self): - self.loop.call_later(MS, self.feed, Frame(True, OP_CLOSE, b'\x00')) + invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') + self.loop.call_later(MS, self.receive_frame, invalid_close_frame) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1002, '') @@ -459,7 +463,7 @@ def setUp(self): def test_close(self): # standard server-initiated close frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) - self.loop.call_later(MS, self.feed, frame) + self.loop.call_later(MS, self.receive_frame, frame) self.loop.call_later(2 * MS, self.protocol.eof_received) self.loop.call_later(3 * MS, self.protocol.connection_lost, None) # The client is waiting for some data at this point, and won't get it. @@ -488,7 +492,7 @@ def test_client_close(self): # non standard client-initiated close def test_simultaneous_close(self): # non standard close from both sides server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) - self.loop.call_later(MS, self.feed, server_close) + self.loop.call_later(MS, self.receive_frame, server_close) self.loop.call_later(2 * MS, self.protocol.eof_received) self.loop.call_later(3 * MS, self.protocol.connection_lost, None) self.loop.run_until_complete(self.protocol.close(reason='client')) @@ -541,7 +545,7 @@ def delayed_write_frame(*args, **kwargs): frame = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) # Trigger the race condition between answering the close frame from # the server and sending another close frame from the client. - self.loop.call_later(MS, self.feed, frame) + self.loop.call_later(MS, self.receive_frame, frame) fail_connection = self.protocol.fail_connection(1000, 'client') self.loop.call_later(2 * MS, self.async, fail_connection) self.loop.call_later(3 * MS, self.protocol.eof_received) From b6120037b71752443bddd683eaba620ef5ea5dab Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jul 2015 12:41:56 +0200 Subject: [PATCH 0091/1539] Add receive_eof to match receive_frame. Clarify how the connection close sequence works in asyncio transports. Remove two tests for a sequence that cannot happen in practice (unless the event loop runs exceedingly slow, in which case there isn't much we can do). As explained in the docstring of receive_eof and TransportMock, transports always close themselves and call connection_lost. --- websockets/test_protocol.py | 103 +++++++++++++----------------------- 1 file changed, 37 insertions(+), 66 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index fe8d0580a..7b8d46e86 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -6,7 +6,7 @@ from .exceptions import InvalidState from .framing import * -from .protocol import CLOSED, CLOSING, WebSocketCommonProtocol +from .protocol import CLOSED, WebSocketCommonProtocol # Unit for timeouts. May be increased on slow machines by setting the @@ -62,6 +62,25 @@ def receive_frame(self, frame): mask = not self.protocol.is_client write_frame(frame, writer, mask) + def receive_eof(self): + """ + Make the protocol receive the end of stream. + + WebSocketCommonProtocol.eof_received returns None — it is inherited + from StreamReaderProtocol. (Returning True wouldn't work on secure + connections anyway.) As a consequence, actual transports close + themselves after calling it. + + To emulate this behavior, tests must close the transport just after + calling the protocol's eof_received. Closing the transport will have + the side-effect calling the protocol's connection_lost. + + This method is often called shortly after simulating invalid data to + ensure that the connection fails quickly. + """ + self.protocol.eof_received() + self.transport.close() + @asyncio.coroutine def sent(self): """Read the next frame sent to the transport.""" @@ -79,12 +98,6 @@ def echo(self): """Echo to the protocol the next frame sent to the transport.""" self.receive_frame((yield from self.sent())) - @asyncio.coroutine - def fast_connection_failure(self): - """Ensure the connection failure terminates quickly.""" - self.protocol.eof_received() - self.protocol.connection_lost(None) - def process_control_frames(self): """Process control frames fed to the protocol.""" self.receive_frame(Frame(True, OP_TEXT, b'')) @@ -124,27 +137,27 @@ def test_recv_binary(self): def test_recv_protocol_error(self): self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8'))) - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_recv_unicode_error(self): self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1007, '') def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -165,15 +178,14 @@ def test_recv_other_error(self): def read_message(): raise Exception("BOOM") self.protocol.read_message = read_message - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) with self.assertRaises(Exception): self.loop.run_until_complete(self.protocol.worker) self.assertConnectionClosed(1011, '') def test_recv_on_closed_connection(self): - self.protocol.eof_received() - self.protocol.connection_lost(None) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) def test_recv_cancelled(self): @@ -202,8 +214,7 @@ def test_send_type_error(self): self.assertNoFrameSent() def test_send_on_closed_connection(self): - self.protocol.eof_received() - self.protocol.connection_lost(None) + self.receive_eof() with self.assertRaises(InvalidState): self.loop.run_until_complete(self.protocol.send('foobar')) self.assertNoFrameSent() @@ -278,7 +289,7 @@ def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -286,7 +297,7 @@ def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -316,21 +327,20 @@ def test_unterminated_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b'tea')) - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_close_handshake_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.receive_frame(Frame(True, OP_CLOSE, b'')) - self.loop.call_later(MS, self.async, self.fast_connection_failure()) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_connection_close_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.loop.call_later(MS, self.protocol.eof_received) - self.loop.call_later(2 * MS, self.protocol.connection_lost, None) + self.loop.call_later(MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1006, '') @@ -390,22 +400,6 @@ def test_close_handshake_timeout(self): self.assertFalse(self.before.cancelled()) self.before.cancel() - def test_close_timeout_before_connection_lost(self): - # Prevent the connection from terminating. - self.protocol.connection_lost = unittest.mock.Mock() - - self.after = asyncio.Future(loop=self.loop) - self.loop.call_later(4 * MS, self.after.cancel) - self.before = asyncio.Future(loop=self.loop) - self.loop.call_later(8 * MS, self.before.cancel) - self.protocol.timeout = 5 * MS - self.loop.call_later(MS, self.async, self.echo()) - self.loop.run_until_complete(self.protocol.close(reason='because.')) - self.assertEqual(self.protocol.state, CLOSING) - self.assertTrue(self.after.cancelled()) - self.assertFalse(self.before.cancelled()) - self.before.cancel() - def test_client_close_race_with_failing_connection(self): original_write_frame = self.protocol.write_frame @@ -433,8 +427,7 @@ def test_close_protocol_error(self): self.assertConnectionClosed(1002, '') def test_close_connection_lost(self): - self.loop.call_later(MS, self.protocol.eof_received) - self.loop.call_later(2 * MS, self.protocol.connection_lost, None) + self.loop.call_later(MS, self.receive_eof) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1006, '') @@ -464,8 +457,7 @@ def setUp(self): def test_close(self): # standard server-initiated close frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) self.loop.call_later(MS, self.receive_frame, frame) - self.loop.call_later(2 * MS, self.protocol.eof_received) - self.loop.call_later(3 * MS, self.protocol.connection_lost, None) + self.loop.call_later(2 * MS, self.receive_eof) # The client is waiting for some data at this point, and won't get it. self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) # After recv() returns None, the connection is closed. @@ -478,8 +470,7 @@ def test_close(self): # standard server-initiated close def test_client_close(self): # non standard client-initiated close self.loop.call_later(MS, self.async, self.echo()) - self.loop.call_later(2 * MS, self.protocol.eof_received) - self.loop.call_later(3 * MS, self.protocol.connection_lost, None) + self.loop.call_later(2 * MS, self.receive_eof) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') # Only one frame is emitted, and it's consumed by self.echo(). @@ -493,8 +484,7 @@ def test_simultaneous_close(self): # non standard close from both sides server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) self.loop.call_later(MS, self.receive_frame, server_close) - self.loop.call_later(2 * MS, self.protocol.eof_received) - self.loop.call_later(3 * MS, self.protocol.connection_lost, None) + self.loop.call_later(2 * MS, self.receive_eof) self.loop.run_until_complete(self.protocol.close(reason='client')) self.assertConnectionClosed(1000, 'server') self.assertFrameSent(*client_close) @@ -514,24 +504,6 @@ def test_close_timeout_before_eof_received(self): self.assertFalse(self.before.cancelled()) self.before.cancel() - def test_close_timeout_before_connection_lost(self): - # Prevent the connection from terminating. - self.protocol.connection_lost = unittest.mock.Mock() - - self.after = asyncio.Future(loop=self.loop) - self.loop.call_later(9 * MS, self.after.cancel) - self.before = asyncio.Future(loop=self.loop) - self.loop.call_later(13 * MS, self.before.cancel) - self.protocol.timeout = 5 * MS - self.loop.call_later(MS, self.async, self.echo()) - self.loop.call_later(2 * MS, self.protocol.eof_received) - self.loop.run_until_complete(self.protocol.close(reason='because.')) - # If the server doesn't drop the connection quickly, the client will. - self.assertEqual(self.protocol.state, CLOSING) - self.assertTrue(self.after.cancelled()) - self.assertFalse(self.before.cancelled()) - self.before.cancel() - def test_server_close_race_with_failing_connection(self): original_write_frame = self.protocol.write_frame @@ -548,8 +520,7 @@ def delayed_write_frame(*args, **kwargs): self.loop.call_later(MS, self.receive_frame, frame) fail_connection = self.protocol.fail_connection(1000, 'client') self.loop.call_later(2 * MS, self.async, fail_connection) - self.loop.call_later(3 * MS, self.protocol.eof_received) - self.loop.call_later(4 * MS, self.protocol.connection_lost, None) + self.loop.call_later(3 * MS, self.receive_eof) self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1000, 'client') self.assertFrameSent(*frame) From dc711581a48e2f9b57dee98aecc716128654154e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jul 2015 12:52:12 +0200 Subject: [PATCH 0092/1539] Client is allowed to close the connection. Remove more comments related to my misunderstanding of this point. --- websockets/test_protocol.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 7b8d46e86..7150c8371 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -347,7 +347,7 @@ def test_connection_close_in_fragmented_text(self): class ServerTests(CommonTests, unittest.TestCase): - def test_close(self): # standard server-initiated close + def test_server_close(self): self.loop.call_later(MS, self.async, self.echo()) self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') @@ -358,7 +358,7 @@ def test_close(self): # standard server-initiated close self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() - def test_client_close(self): # non standard client-initiated close + def test_client_close(self): frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) self.loop.call_later(MS, self.receive_frame, frame) # The server is waiting for some data at this point, and won't get it. @@ -371,7 +371,7 @@ def test_client_close(self): # non standard client-initiated close self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() - def test_simultaneous_close(self): # non standard close from both sides + def test_simultaneous_close(self): client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) self.loop.call_later(MS, self.receive_frame, client_close) @@ -454,7 +454,19 @@ def setUp(self): super().setUp() self.protocol.is_client = True - def test_close(self): # standard server-initiated close + def test_client_close(self): + self.loop.call_later(MS, self.async, self.echo()) + self.loop.call_later(2 * MS, self.receive_eof) + self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.assertConnectionClosed(1000, 'because.') + # Only one frame is emitted, and it's consumed by self.echo(). + self.assertNoFrameSent() + # Closing the connection again is a no-op. + self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) + self.assertConnectionClosed(1000, 'because.') + self.assertNoFrameSent() + + def test_server_close(self): frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) self.loop.call_later(MS, self.receive_frame, frame) self.loop.call_later(2 * MS, self.receive_eof) @@ -468,19 +480,7 @@ def test_close(self): # standard server-initiated close self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() - def test_client_close(self): # non standard client-initiated close - self.loop.call_later(MS, self.async, self.echo()) - self.loop.call_later(2 * MS, self.receive_eof) - self.loop.run_until_complete(self.protocol.close(reason='because.')) - self.assertConnectionClosed(1000, 'because.') - # Only one frame is emitted, and it's consumed by self.echo(). - self.assertNoFrameSent() - # Closing the connection again is a no-op. - self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) - self.assertConnectionClosed(1000, 'because.') - self.assertNoFrameSent() - - def test_simultaneous_close(self): # non standard close from both sides + def test_simultaneous_close(self): server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) self.loop.call_later(MS, self.receive_frame, server_close) From d3219088d393f4e7bc6d7610d21c8b0d0580dfb2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jul 2015 21:14:34 +0200 Subject: [PATCH 0093/1539] Improve implementation of test_recv_cancelled. Use an explicit cancellation after 1ms instead of a timeout after 1s to cancel recv(). This makes the entire test suite ~5 times faster as this test, which runs twice, accounted for ~80% of the 2.5s test run time. --- websockets/test_protocol.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 7150c8371..2181b03f6 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -189,16 +189,15 @@ def test_recv_on_closed_connection(self): self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) def test_recv_cancelled(self): - try: - data = self.loop.run_until_complete( - asyncio.wait_for(self.protocol.recv(), 1, loop=self.loop) - ) - except asyncio.TimeoutError: - self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) - data = self.loop.run_until_complete( - asyncio.wait_for(self.protocol.recv(), 1, loop=self.loop) - ) # We use wait_for here to make sure the test fail and don't hang - self.assertEqual(data, 'café') + recv = self.async(self.protocol.recv()) + self.loop.call_later(MS, recv.cancel) + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(recv) + + # The next frame doesn't disappear in a vacuum (it used to). + self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, 'café') def test_send_text(self): self.loop.run_until_complete(self.protocol.send('café')) From d1be44c579c1884e0ff676c8f2a1eeb11a446760 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jul 2015 21:29:41 +0200 Subject: [PATCH 0094/1539] Test passing explicitly an event loop to connect/serve. --- websockets/test_client_server.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 0fdf170ef..4b0d8da43 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -73,6 +73,15 @@ def test_basic(self): self.stop_client() self.stop_server() + def test_explicit_event_loop(self): + self.start_server(loop=self.loop) + self.start_client(loop=self.loop) + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() + self.stop_server() + def test_protocol_attributes(self): self.start_server() self.start_client('attributes') From 8b0ab64c9b1c2664d01721fb7b328a024e2464cc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 22 Jul 2015 10:25:15 +0200 Subject: [PATCH 0095/1539] Large cleanup in the tests for closing scenarios. Remove the echo() utility. Check sent frames and receive frames explicitly instead. Factor out the definition of close frames which are used by several tests. Separate setup from assertions. Document tests that rely on timings. Check that no unexpected frames are sent. Run the same tests on the server and the client side. Rewrite test_eof_received_timeout which I removed recently. Remove test_close_after_cancelled_recv which is superseded by the more general test_recv_cancelled. --- docs/index.rst | 2 + websockets/protocol.py | 4 +- websockets/test_protocol.py | 215 +++++++++++++++++++++++++----------- 3 files changed, 154 insertions(+), 67 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index ba778f9ee..2ae398eb4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -305,6 +305,8 @@ Changelog * Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. +* Improved tests. + 2.4 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index c04de13f9..2bb2de4fd 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -419,11 +419,11 @@ def close_connection(self): return # Attempt to terminate the TCP connection properly. - # If the socket is already closed, this will crash. + # If the socket is already closed, this may crash. try: if self.writer.can_write_eof(): self.writer.write_eof() - except Exception: + except Exception: # pragma: no cover pass self.writer.close() diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 2181b03f6..4a94be1ba 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -50,6 +50,11 @@ def tearDown(self): self.loop.close() super().tearDown() + # These frames are used in the ServerTests and ClientTests subclasses. + close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) + client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) + server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) + @property def async(self): return functools.partial(asyncio.async, loop=self.loop) @@ -93,11 +98,6 @@ def sent(self): return (yield from read_frame( stream.readexactly, self.protocol.is_client)) - @asyncio.coroutine - def echo(self): - """Echo to the protocol the next frame sent to the transport.""" - self.receive_frame((yield from self.sent())) - def process_control_frames(self): """Process control frames fed to the protocol.""" self.receive_frame(Frame(True, OP_TEXT, b'')) @@ -112,6 +112,7 @@ def assertNoFrameSent(self): self.assertIsNone(sent) def assertConnectionClosed(self, code, message): + # The following line guarantees that connection_lost was called. self.assertEqual(self.protocol.state, CLOSED) self.assertEqual(self.protocol.close_code, code) self.assertEqual(self.protocol.close_reason, message) @@ -344,56 +345,72 @@ def test_connection_close_in_fragmented_text(self): self.assertConnectionClosed(1006, '') -class ServerTests(CommonTests, unittest.TestCase): +class ServerCloseTests(CommonTests, unittest.TestCase): def test_server_close(self): - self.loop.call_later(MS, self.async, self.echo()) + self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.assertConnectionClosed(1000, 'because.') - # Only one frame is emitted, and it's consumed by self.echo(). + self.assertFrameSent(*self.close_frame) self.assertNoFrameSent() + # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) + self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() def test_client_close(self): - frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) - self.loop.call_later(MS, self.receive_frame, frame) - # The server is waiting for some data at this point, and won't get it. - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.loop.call_later(MS, self.receive_frame, self.close_frame) + # The server is waiting for some data at this point but won't get it. + next_message = self.loop.run_until_complete(self.protocol.recv()) + + self.assertIsNone(next_message) # After recv() returns None, the connection is closed. self.assertConnectionClosed(1000, 'because.') - self.assertFrameSent(*frame) + self.assertFrameSent(*self.close_frame) + self.assertNoFrameSent() + # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) + self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() def test_simultaneous_close(self): - client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) - server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) - self.loop.call_later(MS, self.receive_frame, client_close) + self.loop.call_later(MS, self.receive_frame, self.client_close) self.loop.run_until_complete(self.protocol.close(reason='server')) + + # The close code and reason are taken from the remote side because + # that's presumably more useful that the values from the local side. self.assertConnectionClosed(1000, 'client') - self.assertFrameSent(*server_close) + self.assertFrameSent(*self.server_close) self.assertNoFrameSent() def test_close_drops_frames(self): - self.loop.call_later(MS, self.receive_frame, Frame(True, OP_TEXT, b'')) - self.loop.call_later(2 * MS, self.async, self.echo()) + text_frame = Frame(True, OP_TEXT, b'') + self.loop.call_later(MS, self.receive_frame, text_frame) + self.loop.call_later(2 * MS, self.receive_frame, self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.assertConnectionClosed(1000, 'because.') - # Only one frame is emitted, and it's consumed by self.echo(). + self.assertFrameSent(*self.close_frame) self.assertNoFrameSent() def test_close_handshake_timeout(self): + # Timeout is expected in 1 + 10 = 11ms. + # Check the timing within -1/+5ms for robustness. self.after = asyncio.Future(loop=self.loop) - self.loop.call_later(4 * MS, self.after.cancel) + self.loop.call_later(10 * MS, self.after.cancel) self.before = asyncio.Future(loop=self.loop) - self.loop.call_later(8 * MS, self.before.cancel) - self.protocol.timeout = 5 * MS + self.loop.call_later(15 * MS, self.before.cancel) + self.protocol.timeout = 10 * MS + + # Unlike previous tests, no close frame will be received in response. + # The server will stop waiting for the close frame and timeout. self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.assertConnectionClosed(1000, 'because.') self.assertTrue(self.after.cancelled()) self.assertFalse(self.before.cancelled()) @@ -409,95 +426,137 @@ def delayed_write_frame(*args, **kwargs): self.protocol.write_frame = delayed_write_frame - frame = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) - # Trigger the race condition between answering the close frame from - # the client and sending another close frame from the server. - self.loop.call_later(MS, self.receive_frame, frame) + # Trigger the race condition by failing the connection while answering + # the closing handshake initiated by the client. + self.loop.call_later(MS, self.receive_frame, self.client_close) fail_connection = self.protocol.fail_connection(1000, 'server') self.loop.call_later(2 * MS, self.async, fail_connection) - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + next_message = self.loop.run_until_complete(self.protocol.recv()) + + self.assertIsNone(next_message) self.assertConnectionClosed(1000, 'server') - self.assertFrameSent(*frame) + self.assertFrameSent(*self.client_close) + self.assertNoFrameSent() def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') self.loop.call_later(MS, self.receive_frame, invalid_close_frame) self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.assertConnectionClosed(1002, '') def test_close_connection_lost(self): self.loop.call_later(MS, self.receive_eof) self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.assertConnectionClosed(1006, '') def test_close_during_recv(self): recv = self.async(self.protocol.recv()) - self.loop.call_later(MS, self.async, self.echo()) + self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='because.')) - self.assertIsNone(self.loop.run_until_complete(recv)) - def test_close_after_cancelled_recv(self): - recv = self.async(self.protocol.recv()) - self.loop.call_later(MS, recv.cancel) - with self.assertRaises(asyncio.CancelledError): - self.loop.run_until_complete(recv) - # Closing the connection shouldn't crash. - # I can't find a way to test this on the client side. - self.loop.call_later(MS, self.async, self.echo()) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + # Receiving a message shouldn't crash. + next_message = self.loop.run_until_complete(recv) + self.assertIsNone(next_message) -class ClientTests(CommonTests, unittest.TestCase): +class ClientCloseTests(CommonTests, unittest.TestCase): def setUp(self): super().setUp() self.protocol.is_client = True def test_client_close(self): - self.loop.call_later(MS, self.async, self.echo()) + self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(2 * MS, self.receive_eof) self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.assertConnectionClosed(1000, 'because.') - # Only one frame is emitted, and it's consumed by self.echo(). + self.assertFrameSent(*self.close_frame) self.assertNoFrameSent() + # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) + self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() def test_server_close(self): - frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) - self.loop.call_later(MS, self.receive_frame, frame) + self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(2 * MS, self.receive_eof) - # The client is waiting for some data at this point, and won't get it. - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + # The client is waiting for some data at this point but won't get it. + next_message = self.loop.run_until_complete(self.protocol.recv()) + + self.assertIsNone(next_message) # After recv() returns None, the connection is closed. self.assertConnectionClosed(1000, 'because.') - self.assertFrameSent(*frame) + self.assertFrameSent(*self.close_frame) + self.assertNoFrameSent() + # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close('oh noes!')) + self.assertConnectionClosed(1000, 'because.') self.assertNoFrameSent() def test_simultaneous_close(self): - server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) - client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) - self.loop.call_later(MS, self.receive_frame, server_close) + self.loop.call_later(MS, self.receive_frame, self.server_close) self.loop.call_later(2 * MS, self.receive_eof) self.loop.run_until_complete(self.protocol.close(reason='client')) + + # The close code and reason are taken from the remote side because + # that's presumably more useful that the values from the local side. self.assertConnectionClosed(1000, 'server') - self.assertFrameSent(*client_close) + self.assertFrameSent(*self.client_close) + self.assertNoFrameSent() + + def test_close_drops_frames(self): + text_frame = Frame(True, OP_TEXT, b'') + self.loop.call_later(MS, self.receive_frame, text_frame) + self.loop.call_later(2 * MS, self.receive_frame, self.close_frame) + self.loop.call_later(3 * MS, self.receive_eof) + self.loop.run_until_complete(self.protocol.close(reason='because.')) + + self.assertConnectionClosed(1000, 'because.') + self.assertFrameSent(*self.close_frame) self.assertNoFrameSent() - def test_close_timeout_before_eof_received(self): + def test_close_handshake_timeout(self): + # Timeout is expected in 1 + 2 * 10 = 21ms. + # Check the timing within -1/+5ms for robustness. + self.after = asyncio.Future(loop=self.loop) + self.loop.call_later(20 * MS, self.after.cancel) + self.before = asyncio.Future(loop=self.loop) + self.loop.call_later(25 * MS, self.before.cancel) + self.protocol.timeout = 10 * MS + + # Unlike previous tests, no close frame will be received in response + # and the connection will not be closed. The client will stop waiting + # for the close frame and timeout, then stop waiting for the + # connection close and timeout again. + self.loop.run_until_complete(self.protocol.close(reason='because.')) + + self.assertConnectionClosed(1000, 'because.') + self.assertTrue(self.after.cancelled()) + self.assertFalse(self.before.cancelled()) + self.before.cancel() + + def test_eof_received_timeout(self): + # Timeout is expected in 1 + 10 = 11ms. + # Check the timing within -1/+5ms for robustness. self.after = asyncio.Future(loop=self.loop) - self.loop.call_later(4 * MS, self.after.cancel) + self.loop.call_later(10 * MS, self.after.cancel) self.before = asyncio.Future(loop=self.loop) - self.loop.call_later(8 * MS, self.before.cancel) - self.protocol.timeout = 5 * MS - self.loop.call_later(MS, self.async, self.echo()) + self.loop.call_later(15 * MS, self.before.cancel) + self.protocol.timeout = 10 * MS + + # Unlike previous tests, the close frame will be received in response + # but the connection will not be closed. The client will stop waiting + # for the connection close and timeout. + self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='because.')) - # If the server doesn't drop the connection quickly, the client will. + self.assertConnectionClosed(1000, 'because.') self.assertTrue(self.after.cancelled()) self.assertFalse(self.before.cancelled()) @@ -513,13 +572,39 @@ def delayed_write_frame(*args, **kwargs): self.protocol.write_frame = delayed_write_frame - frame = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) - # Trigger the race condition between answering the close frame from - # the server and sending another close frame from the client. - self.loop.call_later(MS, self.receive_frame, frame) + # Trigger the race condition by failing the connection while answering + # the closing handshake initiated by the server. + self.loop.call_later(MS, self.receive_frame, self.server_close) fail_connection = self.protocol.fail_connection(1000, 'client') self.loop.call_later(2 * MS, self.async, fail_connection) - self.loop.call_later(3 * MS, self.receive_eof) - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.loop.call_later(4 * MS, self.receive_eof) + next_message = self.loop.run_until_complete(self.protocol.recv()) + + self.assertIsNone(next_message) self.assertConnectionClosed(1000, 'client') - self.assertFrameSent(*frame) + self.assertFrameSent(*self.server_close) + self.assertNoFrameSent() + + def test_close_protocol_error(self): + invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') + self.loop.call_later(MS, self.receive_frame, invalid_close_frame) + self.loop.call_later(2 * MS, self.receive_eof) + self.loop.run_until_complete(self.protocol.close(reason='because.')) + + self.assertConnectionClosed(1002, '') + + def test_close_connection_lost(self): + self.loop.call_later(MS, self.receive_eof) + self.loop.run_until_complete(self.protocol.close(reason='because.')) + + self.assertConnectionClosed(1006, '') + + def test_close_during_recv(self): + recv = self.async(self.protocol.recv()) + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(2 * MS, self.receive_eof) + self.loop.run_until_complete(self.protocol.close(reason='because.')) + + # Receiving a message shouldn't crash. + next_message = self.loop.run_until_complete(recv) + self.assertIsNone(next_message) From d4aaf8ca6754934f765ca0e6a20ff8b38db0e32f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 22 Jul 2015 10:40:52 +0200 Subject: [PATCH 0096/1539] Fix tests for sent frames. The previous implementation silently ignored trailing data, which means assertFrameSent + assertNoFrameSent didn't behave as one might expect. --- websockets/test_protocol.py | 91 ++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 42 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 4a94be1ba..f6456206f 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -86,30 +86,47 @@ def receive_eof(self): self.protocol.eof_received() self.transport.close() - @asyncio.coroutine - def sent(self): - """Read the next frame sent to the transport.""" + def process_control_frames(self): + """ + Process control frames received by the protocol. + + To ensure that recv completes quickly, receive an additional dummy + frame, which recv() will drop. + """ + self.receive_frame(Frame(True, OP_TEXT, b'')) + self.loop.run_until_complete(self.protocol.recv()) + + def last_sent_frame(self): + """ + Read the last frame sent to the transport. + + This method assumes that at most one frame was sent. It raises an + AssertionError otherwise. + """ stream = asyncio.StreamReader(loop=self.loop) + for (data,), kw in self.transport.write.call_args_list: stream.feed_data(data) self.transport.write.call_args_list = [] stream.feed_eof() - if not stream.at_eof(): - return (yield from read_frame( + + if stream.at_eof(): + frame = None + else: + frame = self.loop.run_until_complete(read_frame( stream.readexactly, self.protocol.is_client)) - def process_control_frames(self): - """Process control frames fed to the protocol.""" - self.receive_frame(Frame(True, OP_TEXT, b'')) - self.loop.run_until_complete(self.protocol.recv()) + if not stream.at_eof(): + data = self.loop.run_until_complete(stream.read()) + raise AssertionError("Trailing data found: {!r}".format(data)) + + return frame - def assertFrameSent(self, fin, opcode, data): - sent = self.loop.run_until_complete(self.sent()) - self.assertEqual(sent, Frame(fin, opcode, data)) + def assertOneFrameSent(self, fin, opcode, data): + self.assertEqual(self.last_sent_frame(), Frame(fin, opcode, data)) def assertNoFrameSent(self): - sent = self.loop.run_until_complete(self.sent()) - self.assertIsNone(sent) + self.assertIsNone(self.last_sent_frame()) def assertConnectionClosed(self, code, message): # The following line guarantees that connection_lost was called. @@ -202,11 +219,11 @@ def test_recv_cancelled(self): def test_send_text(self): self.loop.run_until_complete(self.protocol.send('café')) - self.assertFrameSent(True, OP_TEXT, 'café'.encode('utf-8')) + self.assertOneFrameSent(True, OP_TEXT, 'café'.encode('utf-8')) def test_send_binary(self): self.loop.run_until_complete(self.protocol.send(b'tea')) - self.assertFrameSent(True, OP_BINARY, b'tea') + self.assertOneFrameSent(True, OP_BINARY, b'tea') def test_send_type_error(self): with self.assertRaises(TypeError): @@ -222,7 +239,7 @@ def test_send_on_closed_connection(self): def test_answer_ping(self): self.receive_frame(Frame(True, OP_PING, b'test')) self.process_control_frames() - self.assertFrameSent(True, OP_PONG, b'test') + self.assertOneFrameSent(True, OP_PONG, b'test') def test_ignore_pong(self): self.receive_frame(Frame(True, OP_PONG, b'test')) @@ -232,7 +249,7 @@ def test_ignore_pong(self): def test_acknowledge_ping(self): ping = self.loop.run_until_complete(self.protocol.ping()) self.assertFalse(ping.done()) - ping_frame = self.loop.run_until_complete(self.sent()) + ping_frame = self.last_sent_frame() pong_frame = Frame(True, OP_PONG, ping_frame.data) self.receive_frame(pong_frame) self.process_control_frames() @@ -241,7 +258,7 @@ def test_acknowledge_ping(self): def test_acknowledge_previous_pings(self): pings = [( self.loop.run_until_complete(self.protocol.ping()), - self.loop.run_until_complete(self.sent()), + self.last_sent_frame(), ) for i in range(3)] # Unsolicited pong doesn't acknowledge pings self.receive_frame(Frame(True, OP_PONG, b'')) @@ -258,7 +275,7 @@ def test_acknowledge_previous_pings(self): def test_cancel_ping(self): ping = self.loop.run_until_complete(self.protocol.ping()) - ping_frame = self.loop.run_until_complete(self.sent()) + ping_frame = self.last_sent_frame() ping.cancel() pong_frame = Frame(True, OP_PONG, ping_frame.data) self.receive_frame(pong_frame) @@ -267,7 +284,7 @@ def test_cancel_ping(self): def test_duplicate_ping(self): self.loop.run_until_complete(self.protocol.ping(b'foobar')) - self.assertFrameSent(True, OP_PING, b'foobar') + self.assertOneFrameSent(True, OP_PING, b'foobar') with self.assertRaises(ValueError): self.loop.run_until_complete(self.protocol.ping(b'foobar')) self.assertNoFrameSent() @@ -321,7 +338,7 @@ def test_control_frame_within_fragmented_text(self): self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8'))) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café') - self.assertFrameSent(True, OP_PONG, b'') + self.assertOneFrameSent(True, OP_PONG, b'') def test_unterminated_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) @@ -352,8 +369,7 @@ def test_server_close(self): self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') - self.assertFrameSent(*self.close_frame) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) @@ -369,8 +385,7 @@ def test_client_close(self): self.assertIsNone(next_message) # After recv() returns None, the connection is closed. self.assertConnectionClosed(1000, 'because.') - self.assertFrameSent(*self.close_frame) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) @@ -385,8 +400,7 @@ def test_simultaneous_close(self): # The close code and reason are taken from the remote side because # that's presumably more useful that the values from the local side. self.assertConnectionClosed(1000, 'client') - self.assertFrameSent(*self.server_close) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.server_close) def test_close_drops_frames(self): text_frame = Frame(True, OP_TEXT, b'') @@ -395,8 +409,7 @@ def test_close_drops_frames(self): self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') - self.assertFrameSent(*self.close_frame) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.close_frame) def test_close_handshake_timeout(self): # Timeout is expected in 1 + 10 = 11ms. @@ -435,8 +448,7 @@ def delayed_write_frame(*args, **kwargs): self.assertIsNone(next_message) self.assertConnectionClosed(1000, 'server') - self.assertFrameSent(*self.client_close) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.client_close) def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') @@ -473,8 +485,7 @@ def test_client_close(self): self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') - self.assertFrameSent(*self.close_frame) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) @@ -491,8 +502,7 @@ def test_server_close(self): self.assertIsNone(next_message) # After recv() returns None, the connection is closed. self.assertConnectionClosed(1000, 'because.') - self.assertFrameSent(*self.close_frame) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close('oh noes!')) @@ -508,8 +518,7 @@ def test_simultaneous_close(self): # The close code and reason are taken from the remote side because # that's presumably more useful that the values from the local side. self.assertConnectionClosed(1000, 'server') - self.assertFrameSent(*self.client_close) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.client_close) def test_close_drops_frames(self): text_frame = Frame(True, OP_TEXT, b'') @@ -519,8 +528,7 @@ def test_close_drops_frames(self): self.loop.run_until_complete(self.protocol.close(reason='because.')) self.assertConnectionClosed(1000, 'because.') - self.assertFrameSent(*self.close_frame) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.close_frame) def test_close_handshake_timeout(self): # Timeout is expected in 1 + 2 * 10 = 21ms. @@ -582,8 +590,7 @@ def delayed_write_frame(*args, **kwargs): self.assertIsNone(next_message) self.assertConnectionClosed(1000, 'client') - self.assertFrameSent(*self.server_close) - self.assertNoFrameSent() + self.assertOneFrameSent(*self.server_close) def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') From 88bb97de0fef710b5320c3a65e3ca6b6fd4195b3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 22 Jul 2015 12:25:18 +0200 Subject: [PATCH 0097/1539] Improve hack to run the event loop once. --- websockets/test_client_server.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 4b0d8da43..d1f6c3812 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -44,6 +44,13 @@ def setUp(self): def tearDown(self): self.loop.close() + def run_loop_once(self): + # Process callbacks scheduled with call_soon. This pattern works + # because stop schedules a callback to stop the event loop and + # run_forever runs the loop until it hits this callback. + self.loop.stop() + self.loop.run_forever() + def start_server(self, **kwds): server = serve(handler, 'localhost', 8642, **kwds) self.server = self.loop.run_until_complete(server) @@ -55,11 +62,6 @@ def start_client(self, path='', **kwds): def stop_client(self): self.loop.run_until_complete(self.client.worker) - def notice_connection_close(self): - # When the client closes the connection, the server still believes - # it's open until the event loop has run once. Interesting hack. - self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) - def stop_server(self): self.server.close() self.loop.run_until_complete(self.server.wait_closed()) @@ -228,7 +230,7 @@ def test_subprotocol_error(self, _select_subprotocol): self.start_server(subprotocols=['superchat']) with self.assertRaises(InvalidHandshake): self.start_client('subprotocol', subprotocols=['otherchat']) - self.notice_connection_close() + self.run_loop_once() self.stop_server() @unittest.mock.patch('websockets.server.read_request') @@ -247,7 +249,7 @@ def test_client_receives_malformed_response(self, _read_response): self.start_server() with self.assertRaises(InvalidHandshake): self.start_client() - self.notice_connection_close() + self.run_loop_once() self.stop_server() @unittest.mock.patch('websockets.client.build_request') @@ -283,7 +285,7 @@ def wrong_read_response(stream): self.start_server() with self.assertRaises(InvalidHandshake): self.start_client() - self.notice_connection_close() + self.run_loop_once() self.stop_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send') From be439b5c28c81e5a43bc88c1f931a89dbf926c30 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 22 Jul 2015 13:23:30 +0200 Subject: [PATCH 0098/1539] Don't send a close frame on a broken connection. This fix is exercised by a test refactor I'm working on in parallel. --- websockets/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 2bb2de4fd..b82e05501 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -452,8 +452,10 @@ def fail_connection(self, code=1011, reason=''): # 7.1.7. Fail the WebSocket Connection logger.info("Failing the WebSocket connection: %d %s", code, reason) if self.state == OPEN: - frame_data = serialize_close(code, reason) - yield from self.write_frame(OP_CLOSE, frame_data) + # Don't send a close frame is the connection is broken already. + if not (code == 1006 or self.connection_closed.done()): + frame_data = serialize_close(code, reason) + yield from self.write_frame(OP_CLOSE, frame_data) self.state = CLOSING if not self.closing_handshake.done(): self.closing_handshake.set_result(False) From bb8db1e5a16dbc8b77e2b1b2138eea52ed330e14 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2015 23:20:14 +0200 Subject: [PATCH 0099/1539] Remove supefluous check. fail_connection() can now be called safely multiple times in parallel, solving the general problem, while this check only adressed an instance. --- websockets/protocol.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index b82e05501..821bbb5b8 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -391,11 +391,8 @@ def write_frame(self, opcode, data=b'', expected_state=OPEN): # Handle flow control automatically. yield from self.writer.drain() except ConnectionResetError: - # Terminate the connection if the socket died, - # unless it's already being closed. - if expected_state != CLOSING: - self.state = CLOSING - yield from self.fail_connection(1006) + # Terminate the connection if the socket died. + yield from self.fail_connection(1006) @asyncio.coroutine def close_connection(self): From 6c2667bde53086e9854d53d78a1608172227f409 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2015 23:21:49 +0200 Subject: [PATCH 0100/1539] Always update the state before yielding control. Otherwise multiple close frames could be sent. This fix is required by changes I'm making to the test suite and will commit soon. --- websockets/protocol.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 821bbb5b8..5081bf01d 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -151,11 +151,13 @@ def close(self, code=1000, reason=''): """ if self.state == OPEN: # 7.1.2. Start the WebSocket Closing Handshake + # 7.1.3. The WebSocket Closing Handshake is Started self.close_code, self.close_reason = code, reason frame_data = serialize_close(code, reason) - yield from self.write_frame(OP_CLOSE, frame_data) - # 7.1.3. The WebSocket Closing Handshake is Started + # Change the state before yielding control to avoid sending more + # than one close frame. self.state = CLOSING + yield from self.write_frame(OP_CLOSE, frame_data) # If the connection doesn't terminate within the timeout, break out of # the worker loop. @@ -449,11 +451,11 @@ def fail_connection(self, code=1011, reason=''): # 7.1.7. Fail the WebSocket Connection logger.info("Failing the WebSocket connection: %d %s", code, reason) if self.state == OPEN: + self.state = CLOSING # Don't send a close frame is the connection is broken already. if not (code == 1006 or self.connection_closed.done()): frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) - self.state = CLOSING if not self.closing_handshake.done(): self.closing_handshake.set_result(False) yield from self.close_connection() From 68da42aff05a8e5640f02cfa3b2050f4bfe0d4c2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2015 23:22:52 +0200 Subject: [PATCH 0101/1539] Tighten state transitions. The code now always checks the current state before updating it (and never yields control in the meantime). --- websockets/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 5081bf01d..411465f06 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -345,7 +345,7 @@ def read_data_frame(self, max_size): # 5.5. Control Frames if frame.opcode == OP_CLOSE: self.close_code, self.close_reason = parse_close(frame.data) - if self.state != CLOSING: + if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started self.state = CLOSING yield from self.write_frame( From 65748576877f9520e3ac0394ecbbf87261d16035 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2015 23:23:46 +0200 Subject: [PATCH 0102/1539] Remove the expected_state parameter. Since we now always set the state to CLOSING before sending a close frame, it isn't needed anymore. --- websockets/protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 411465f06..521a560a4 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -348,8 +348,7 @@ def read_data_frame(self, max_size): if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started self.state = CLOSING - yield from self.write_frame( - OP_CLOSE, frame.data, expected_state=CLOSING) + yield from self.write_frame(OP_CLOSE, frame.data) if not self.closing_handshake.done(): self.closing_handshake.set_result(True) return @@ -379,7 +378,8 @@ def read_frame(self, max_size): return frame @asyncio.coroutine - def write_frame(self, opcode, data=b'', expected_state=OPEN): + def write_frame(self, opcode, data=b''): + expected_state = CLOSING if opcode == OP_CLOSE else OPEN # This may happen if a user attempts to write on a closed connection. if self.state != expected_state: raise InvalidState("Cannot write to a WebSocket " From 66df21b6a8e45fceeb9234a2e64bc2c4cdfc5d37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2015 23:24:42 +0200 Subject: [PATCH 0103/1539] Large-scale test refactor. Back to 100% branch coverage. * Delay invocation of protocol methods (connection_made, data_received, eof_received, connection_lost) with call_soon() to match what asyncio does. This enforces a higher level of asynchrony in tests. * Use call_soon(...) instead of call_later(x * MS, ...) to schedule series of callbacks. asyncio will invoke them in order. * Use a context manager instead of a pair of futures for timing tests. * Factor out code that makes writes slow. * Add a test for connection close during send to mirror the existing test for connection close during recv. --- websockets/protocol.py | 3 + websockets/test_protocol.py | 297 ++++++++++++++++++++---------------- 2 files changed, 167 insertions(+), 133 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 521a560a4..3dbfe767b 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -103,8 +103,11 @@ def __init__(self, *, # Futures tracking steps in the connection's lifecycle. self.opening_handshake = asyncio.Future(loop=loop) + # Set to True when the closing handshake has completed properly and to + # False when the connection terminates abnormally. self.closing_handshake = asyncio.Future(loop=loop) self.connection_failed = asyncio.Future(loop=loop) + # Set to None when the connection state becomes CLOSED. self.connection_closed = asyncio.Future(loop=loop) # Queue of received messages. diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index f6456206f..903e12356 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import functools import os import unittest @@ -30,10 +31,10 @@ class TransportMock(unittest.mock.Mock): def connect(self, loop, protocol): self.loop = loop self.protocol = protocol - self.protocol.connection_made(self) + self.loop.call_soon(self.protocol.connection_made, self) def close(self): - self.protocol.connection_lost(None) + self.loop.call_soon(self.protocol.connection_lost, None) class CommonTests: @@ -50,8 +51,28 @@ def tearDown(self): self.loop.close() super().tearDown() + def run_loop_once(self): + # Process callbacks scheduled with call_soon. This pattern works + # because stop schedules a callback to stop the event loop and + # run_forever runs the loop until it hits this callback. + self.loop.stop() + self.loop.run_forever() + + def make_drain_slow(self): + # Process connection_made in order to initialize self.protocol.writer. + self.run_loop_once() + + original_drain = self.protocol.writer.drain + + @asyncio.coroutine + def delayed_drain(): + yield from asyncio.sleep(3 * MS, loop=self.loop) + yield from original_drain() + + self.protocol.writer.drain = delayed_drain + # These frames are used in the ServerTests and ClientTests subclasses. - close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'because.')) + close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'close')) client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) @@ -65,7 +86,7 @@ def receive_frame(self, frame): """ writer = self.protocol.data_received mask = not self.protocol.is_client - write_frame(frame, writer, mask) + self.loop.call_soon(write_frame, frame, writer, mask) def receive_eof(self): """ @@ -83,8 +104,8 @@ def receive_eof(self): This method is often called shortly after simulating invalid data to ensure that the connection fails quickly. """ - self.protocol.eof_received() - self.transport.close() + self.loop.call_soon(self.protocol.eof_received) + self.loop.call_soon(self.transport.close) def process_control_frames(self): """ @@ -134,6 +155,19 @@ def assertConnectionClosed(self, code, message): self.assertEqual(self.protocol.close_code, code) self.assertEqual(self.protocol.close_reason, message) + @contextlib.contextmanager + def assertCompletesWithin(self, min_time, max_time): + min_time *= MS + max_time *= MS + t0 = self.loop.time() + yield + t1 = self.loop.time() + dt = t1 - t0 + self.assertGreaterEqual( + dt, min_time, "Too fast: {} < {}".format(dt, min_time)) + self.assertLess( + dt, max_time, "Too slow: {} >= {}".format(dt, max_time)) + def test_open(self): self.assertTrue(self.protocol.open) self.protocol.connection_lost(None) @@ -155,27 +189,27 @@ def test_recv_binary(self): def test_recv_protocol_error(self): self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8'))) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_recv_unicode_error(self): self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1007, '') def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -196,7 +230,7 @@ def test_recv_other_error(self): def read_message(): raise Exception("BOOM") self.protocol.read_message = read_message - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) with self.assertRaises(Exception): self.loop.run_until_complete(self.protocol.worker) @@ -208,7 +242,7 @@ def test_recv_on_closed_connection(self): def test_recv_cancelled(self): recv = self.async(self.protocol.recv()) - self.loop.call_later(MS, recv.cancel) + self.loop.call_soon(recv.cancel) with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(recv) @@ -232,6 +266,8 @@ def test_send_type_error(self): def test_send_on_closed_connection(self): self.receive_eof() + # Ensure the protocol processes the connection termination. + self.loop.run_until_complete(self.protocol.recv()) with self.assertRaises(InvalidState): self.loop.run_until_complete(self.protocol.send('foobar')) self.assertNoFrameSent() @@ -306,7 +342,7 @@ def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -314,7 +350,7 @@ def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1009, '') @@ -344,20 +380,20 @@ def test_unterminated_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b'tea')) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_close_handshake_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.receive_frame(Frame(True, OP_CLOSE, b'')) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1002, '') def test_connection_close_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.loop.call_later(MS, self.receive_eof) + self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) self.assertConnectionClosed(1006, '') @@ -365,36 +401,36 @@ def test_connection_close_in_fragmented_text(self): class ServerCloseTests(CommonTests, unittest.TestCase): def test_server_close(self): - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertNoFrameSent() def test_client_close(self): - self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.receive_frame(self.close_frame) # The server is waiting for some data at this point but won't get it. next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertIsNone(next_message) # After recv() returns None, the connection is closed. - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertNoFrameSent() def test_simultaneous_close(self): - self.loop.call_later(MS, self.receive_frame, self.client_close) + self.receive_frame(self.client_close) self.loop.run_until_complete(self.protocol.close(reason='server')) # The close code and reason are taken from the remote side because @@ -404,46 +440,31 @@ def test_simultaneous_close(self): def test_close_drops_frames(self): text_frame = Frame(True, OP_TEXT, b'') - self.loop.call_later(MS, self.receive_frame, text_frame) - self.loop.call_later(2 * MS, self.receive_frame, self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_frame(text_frame) + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) def test_close_handshake_timeout(self): - # Timeout is expected in 1 + 10 = 11ms. - # Check the timing within -1/+5ms for robustness. - self.after = asyncio.Future(loop=self.loop) - self.loop.call_later(10 * MS, self.after.cancel) - self.before = asyncio.Future(loop=self.loop) - self.loop.call_later(15 * MS, self.before.cancel) + # Timeout is expected in 10ms. self.protocol.timeout = 10 * MS - - # Unlike previous tests, no close frame will be received in response. - # The server will stop waiting for the close frame and timeout. - self.loop.run_until_complete(self.protocol.close(reason='because.')) - - self.assertConnectionClosed(1000, 'because.') - self.assertTrue(self.after.cancelled()) - self.assertFalse(self.before.cancelled()) - self.before.cancel() + # Check the timing within -1/+5ms for robustness. + with self.assertCompletesWithin(9, 15): + # Unlike previous tests, no close frame will be received in + # response. The server will stop waiting for the close frame and + # timeout. + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1000, 'close') def test_client_close_race_with_failing_connection(self): - original_write_frame = self.protocol.write_frame - - @asyncio.coroutine - def delayed_write_frame(*args, **kwargs): - yield from original_write_frame(*args, **kwargs) - yield from asyncio.sleep(2 * MS, loop=self.loop) - - self.protocol.write_frame = delayed_write_frame + self.make_drain_slow() - # Trigger the race condition by failing the connection while answering - # the closing handshake initiated by the client. - self.loop.call_later(MS, self.receive_frame, self.client_close) + # Fail the connection while answering a close frame from the client. + self.loop.call_soon(self.receive_frame, self.client_close) fail_connection = self.protocol.fail_connection(1000, 'server') - self.loop.call_later(2 * MS, self.async, fail_connection) + self.loop.call_later(MS, self.async, fail_connection) next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertIsNone(next_message) @@ -452,26 +473,43 @@ def delayed_write_frame(*args, **kwargs): def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') - self.loop.call_later(MS, self.receive_frame, invalid_close_frame) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_frame(invalid_close_frame) + self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1002, '') def test_close_connection_lost(self): - self.loop.call_later(MS, self.receive_eof) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_eof() + self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1006, '') def test_close_during_recv(self): recv = self.async(self.protocol.recv()) - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason='close')) # Receiving a message shouldn't crash. next_message = self.loop.run_until_complete(recv) self.assertIsNone(next_message) + self.assertConnectionClosed(1000, 'close') + + def test_close_during_send(self): + self.make_drain_slow() + + send = self.async(self.protocol.send('hello')) + self.receive_frame(self.close_frame) + self.receive_eof() + + # Sending a message shouldn't crash. + self.loop.run_until_complete(send) + + # Complete the connection. + self.loop.run_until_complete(self.protocol.close(reason='close')) + + self.assertConnectionClosed(1006, '') + class ClientCloseTests(CommonTests, unittest.TestCase): @@ -480,39 +518,39 @@ def setUp(self): self.protocol.is_client = True def test_client_close(self): - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.call_later(2 * MS, self.receive_eof) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_frame(self.close_frame) + self.receive_eof() + self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertNoFrameSent() def test_server_close(self): - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.call_later(2 * MS, self.receive_eof) + self.receive_frame(self.close_frame) + self.receive_eof() # The client is waiting for some data at this point but won't get it. next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertIsNone(next_message) # After recv() returns None, the connection is closed. - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close('oh noes!')) - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertNoFrameSent() def test_simultaneous_close(self): - self.loop.call_later(MS, self.receive_frame, self.server_close) - self.loop.call_later(2 * MS, self.receive_eof) + self.receive_frame(self.server_close) + self.receive_eof() self.loop.run_until_complete(self.protocol.close(reason='client')) # The close code and reason are taken from the remote side because @@ -522,70 +560,46 @@ def test_simultaneous_close(self): def test_close_drops_frames(self): text_frame = Frame(True, OP_TEXT, b'') - self.loop.call_later(MS, self.receive_frame, text_frame) - self.loop.call_later(2 * MS, self.receive_frame, self.close_frame) - self.loop.call_later(3 * MS, self.receive_eof) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_frame(text_frame) + self.receive_frame(self.close_frame) + self.receive_eof() + self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'because.') + self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) def test_close_handshake_timeout(self): - # Timeout is expected in 1 + 2 * 10 = 21ms. - # Check the timing within -1/+5ms for robustness. - self.after = asyncio.Future(loop=self.loop) - self.loop.call_later(20 * MS, self.after.cancel) - self.before = asyncio.Future(loop=self.loop) - self.loop.call_later(25 * MS, self.before.cancel) + # Timeout is expected in 2 * 10 = 20ms. self.protocol.timeout = 10 * MS - - # Unlike previous tests, no close frame will be received in response - # and the connection will not be closed. The client will stop waiting - # for the close frame and timeout, then stop waiting for the - # connection close and timeout again. - self.loop.run_until_complete(self.protocol.close(reason='because.')) - - self.assertConnectionClosed(1000, 'because.') - self.assertTrue(self.after.cancelled()) - self.assertFalse(self.before.cancelled()) - self.before.cancel() + # Check the timing within -1/+5ms for robustness. + with self.assertCompletesWithin(19, 25): + # Unlike previous tests, no close frame will be received in + # response and the connection will not be closed. The client will + # stop waiting for the close frame and timeout, then stop waiting + # for the connection close and timeout again. + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1000, 'close') def test_eof_received_timeout(self): - # Timeout is expected in 1 + 10 = 11ms. - # Check the timing within -1/+5ms for robustness. - self.after = asyncio.Future(loop=self.loop) - self.loop.call_later(10 * MS, self.after.cancel) - self.before = asyncio.Future(loop=self.loop) - self.loop.call_later(15 * MS, self.before.cancel) + # Timeout is expected in 10ms. self.protocol.timeout = 10 * MS - - # Unlike previous tests, the close frame will be received in response - # but the connection will not be closed. The client will stop waiting - # for the connection close and timeout. - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason='because.')) - - self.assertConnectionClosed(1000, 'because.') - self.assertTrue(self.after.cancelled()) - self.assertFalse(self.before.cancelled()) - self.before.cancel() + # Check the timing within -1/+5ms for robustness. + with self.assertCompletesWithin(9, 15): + # Unlike previous tests, the close frame will be received in + # response but the connection will not be closed. The client will + # stop waiting for the connection close and timeout. + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1000, 'close') def test_server_close_race_with_failing_connection(self): - original_write_frame = self.protocol.write_frame + self.make_drain_slow() - @asyncio.coroutine - def delayed_write_frame(*args, **kwargs): - yield from original_write_frame(*args, **kwargs) - yield from asyncio.sleep(2 * MS, loop=self.loop) - - self.protocol.write_frame = delayed_write_frame - - # Trigger the race condition by failing the connection while answering - # the closing handshake initiated by the server. - self.loop.call_later(MS, self.receive_frame, self.server_close) + # Fail the connection while answering a close frame from the server. + self.loop.call_soon(self.receive_frame, self.server_close) fail_connection = self.protocol.fail_connection(1000, 'client') - self.loop.call_later(2 * MS, self.async, fail_connection) - self.loop.call_later(4 * MS, self.receive_eof) + self.loop.call_later(MS, self.async, fail_connection) + self.loop.call_later(2 * MS, self.receive_eof) next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertIsNone(next_message) @@ -594,24 +608,41 @@ def delayed_write_frame(*args, **kwargs): def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') - self.loop.call_later(MS, self.receive_frame, invalid_close_frame) - self.loop.call_later(2 * MS, self.receive_eof) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_frame(invalid_close_frame) + self.receive_eof() + self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1002, '') def test_close_connection_lost(self): - self.loop.call_later(MS, self.receive_eof) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_eof() + self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1006, '') def test_close_during_recv(self): recv = self.async(self.protocol.recv()) - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.call_later(2 * MS, self.receive_eof) - self.loop.run_until_complete(self.protocol.close(reason='because.')) + self.receive_frame(self.close_frame) + self.receive_eof() + self.loop.run_until_complete(self.protocol.close(reason='close')) # Receiving a message shouldn't crash. next_message = self.loop.run_until_complete(recv) self.assertIsNone(next_message) + + self.assertConnectionClosed(1000, 'close') + + def test_close_during_send(self): + self.make_drain_slow() + + send = self.async(self.protocol.send('hello')) + self.receive_frame(self.close_frame) + self.receive_eof() + + # Sending a message shouldn't crash. + self.loop.run_until_complete(send) + + # Complete the connection. + self.loop.run_until_complete(self.protocol.close(reason='close')) + + self.assertConnectionClosed(1006, '') From 8be2e61b7889d6911ed342caefe22cfdb4bd692b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2015 23:40:48 +0200 Subject: [PATCH 0104/1539] Close connection more efficiently on handshake errors. This avoids a warning about an unclosed SSL socket in tests (at least). --- websockets/client.py | 2 +- websockets/protocol.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 3908fe5bf..ac83ca4ac 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -145,7 +145,7 @@ def connect(uri, *, wsuri, origin=origin, subprotocols=subprotocols, extra_headers=extra_headers) except Exception: - protocol.writer.close() + yield from protocol.close_connection(force=True) raise return protocol diff --git a/websockets/protocol.py b/websockets/protocol.py index 3dbfe767b..888cd64ea 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -400,17 +400,17 @@ def write_frame(self, opcode, data=b''): yield from self.fail_connection(1006) @asyncio.coroutine - def close_connection(self): + def close_connection(self, force=False): # 7.1.1. Close the WebSocket Connection if self.state == CLOSED: return # Defensive assertion for protocol compliance. - if self.state != CLOSING: # pragma: no cover + if self.state != CLOSING and not force: # pragma: no cover raise InvalidState("Cannot close a WebSocket connection " "in the {} state".format(self.state_name)) - if self.is_client: + if self.is_client and not force: try: yield from asyncio.wait_for( self.connection_closed, self.timeout, loop=self.loop) From f406c2a99d41dc95fb9626acfd5e48c2b49e8148 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 27 Jul 2015 11:52:48 +0200 Subject: [PATCH 0105/1539] Add tests for hanshake requests and responses. --- websockets/handshake.py | 9 ++-- websockets/test_handshake.py | 90 +++++++++++++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 10 deletions(-) diff --git a/websockets/handshake.py b/websockets/handshake.py index 74cfd966f..322d5557a 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -84,11 +84,12 @@ def check_request(get_header): token.strip() == 'upgrade' for token in get_header('Connection').lower().split(',')) key = get_header('Sec-WebSocket-Key') - assert len(base64.b64decode(key.encode())) == 16 + assert len(base64.b64decode(key.encode(), validate=True)) == 16 assert get_header('Sec-WebSocket-Version') == '13' - return key - except (AssertionError, KeyError) as exc: + except Exception as exc: raise InvalidHandshake("Invalid request") from exc + else: + return key def build_response(set_header, key): @@ -123,7 +124,7 @@ def check_response(get_header, key): token.strip() == 'upgrade' for token in get_header('Connection').lower().split(',')) assert get_header('Sec-WebSocket-Accept') == accept(key) - except (AssertionError, KeyError) as exc: + except Exception as exc: raise InvalidHandshake("Invalid response") from exc diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index c859b20e8..60ed808e6 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -1,3 +1,4 @@ +import contextlib import unittest from .exceptions import InvalidHandshake @@ -22,16 +23,93 @@ def test_round_trip(self): build_response(response_headers.__setitem__, response_key) check_response(response_headers.__getitem__, request_key) - def test_bad_request(self): + @contextlib.contextmanager + def assert_invalid_request_headers(self): + """ + Provide request headers for corruption. + + Assert that the transformation made them invalid. + """ headers = {} build_request(headers.__setitem__) - del headers['Sec-WebSocket-Key'] + yield headers with self.assertRaises(InvalidHandshake): check_request(headers.__getitem__) - def test_bad_response(self): + def test_request_invalid_upgrade(self): + with self.assert_invalid_request_headers() as headers: + headers['Upgrade'] = 'socketweb' + + def test_request_missing_upgrade(self): + with self.assert_invalid_request_headers() as headers: + del headers['Upgrade'] + + def test_request_invalid_connection(self): + with self.assert_invalid_request_headers() as headers: + headers['Connection'] = 'Downgrade' + + def test_request_missing_connection(self): + with self.assert_invalid_request_headers() as headers: + del headers['Connection'] + + def test_request_invalid_key_not_base64(self): + with self.assert_invalid_request_headers() as headers: + headers['Sec-WebSocket-Key'] = "!@#$%^&*()" + + def test_request_invalid_key_not_well_padded(self): + with self.assert_invalid_request_headers() as headers: + headers['Sec-WebSocket-Key'] = "CSIRmL8dWYxeAdr/XpEHRw" + + def test_request_invalid_key_not_16_bytes_long(self): + with self.assert_invalid_request_headers() as headers: + headers['Sec-WebSocket-Key'] = "ZLpprpvK4PE=" + + def test_request_missing_key(self): + with self.assert_invalid_request_headers() as headers: + del headers['Sec-WebSocket-Key'] + + def test_request_invalid_version(self): + with self.assert_invalid_request_headers() as headers: + headers['Sec-WebSocket-Version'] = '42' + + def test_request_missing_version(self): + with self.assert_invalid_request_headers() as headers: + del headers['Sec-WebSocket-Version'] + + @contextlib.contextmanager + def assert_invalid_response_headers(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): + """ + Provide response headers for corruption. + + Assert that the transformation made them invalid. + """ headers = {} - build_response(headers.__setitem__, 'blabla') - del headers['Sec-WebSocket-Accept'] + build_response(headers.__setitem__, key) + yield headers with self.assertRaises(InvalidHandshake): - check_response(headers.__getitem__, 'blabla') + check_response(headers.__getitem__, key) + + def test_response_invalid_upgrade(self): + with self.assert_invalid_response_headers() as headers: + headers['Upgrade'] = 'socketweb' + + def test_response_missing_upgrade(self): + with self.assert_invalid_response_headers() as headers: + del headers['Upgrade'] + + def test_response_invalid_connection(self): + with self.assert_invalid_response_headers() as headers: + headers['Connection'] = 'Downgrade' + + def test_response_missing_connection(self): + with self.assert_invalid_response_headers() as headers: + del headers['Connection'] + + def test_response_invalid_accept(self): + with self.assert_invalid_response_headers() as headers: + other_key = "1Eq4UDEFQYg3YspNgqxv5g==" + headers['Sec-WebSocket-Accept'] = accept(other_key) + + def test_response_missing_accept(self): + with self.assert_invalid_response_headers() as headers: + del headers['Sec-WebSocket-Accept'] From a16d30de2abf3fe3d241fcac672657fe213d8731 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 27 Jul 2015 11:55:05 +0200 Subject: [PATCH 0106/1539] Fail tests if branch coverage isn't at 100%. --- Makefile | 2 +- tox.ini | 15 ++++++++++----- websockets/test_protocol.py | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 7f96bcdd7..e942ecd75 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ test: coverage: python -m coverage erase python -m coverage run --branch --source=websockets -m unittest - python -m coverage html --omit='websockets/test_*.py' + python -m coverage html clean: find . -name '*.pyc' -delete diff --git a/tox.ini b/tox.ini index 78c9dd8e4..56e949418 100644 --- a/tox.ini +++ b/tox.ini @@ -1,17 +1,22 @@ [tox] -envlist = py33,py34,flake8,isort +envlist = py33,py34,coverage,flake8,isort [testenv] deps = py33: asyncio commands = python -m unittest +[testenv:coverage] +commands = + python -m coverage erase + python -m coverage run --branch --source=websockets -m unittest + python -m coverage report --fail-under=100 +deps = coverage + [testenv:flake8] commands = flake8 websockets -deps = - flake8 +deps = flake8 [testenv:isort] commands = isort --check-only --recursive websockets -deps = - isort +deps = isort diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 903e12356..4df4247dd 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -137,7 +137,7 @@ def last_sent_frame(self): frame = self.loop.run_until_complete(read_frame( stream.readexactly, self.protocol.is_client)) - if not stream.at_eof(): + if not stream.at_eof(): # pragma: no cover data = self.loop.run_until_complete(stream.read()) raise AssertionError("Trailing data found: {!r}".format(data)) From e085d334b409f6e797da426a127b3680b8e21156 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 27 Jul 2015 12:20:48 +0200 Subject: [PATCH 0107/1539] Remove lock protecting fail_connection. Since we now always set the state to CLOSING before sending a close frame, it isn't needed anymore. --- websockets/protocol.py | 10 ---------- websockets/test_protocol.py | 6 ++++-- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 888cd64ea..3bcb514e9 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -106,7 +106,6 @@ def __init__(self, *, # Set to True when the closing handshake has completed properly and to # False when the connection terminates abnormally. self.closing_handshake = asyncio.Future(loop=loop) - self.connection_failed = asyncio.Future(loop=loop) # Set to None when the connection state becomes CLOSED. self.connection_closed = asyncio.Future(loop=loop) @@ -438,15 +437,6 @@ def close_connection(self, force=False): @asyncio.coroutine def fail_connection(self, code=1011, reason=''): - # Avoid calling fail_connection more than once to minimize - # the consequences of race conditions between the two sides. - if self.connection_failed.done(): - # Wait until the other coroutine calls connection_lost. - yield from self.connection_closed - return - else: - self.connection_failed.set_result(None) - # Losing the connection usually results in a protocol error. # Preserve the original error code in this case. if self.close_code != 1006: diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 4df4247dd..644a654ec 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -468,7 +468,8 @@ def test_client_close_race_with_failing_connection(self): next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertIsNone(next_message) - self.assertConnectionClosed(1000, 'server') + # The connection was closed before the close frame could be sent. + self.assertConnectionClosed(1006, '') self.assertOneFrameSent(*self.client_close) def test_close_protocol_error(self): @@ -603,7 +604,8 @@ def test_server_close_race_with_failing_connection(self): next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertIsNone(next_message) - self.assertConnectionClosed(1000, 'client') + # The connection was closed before the close frame could be sent. + self.assertConnectionClosed(1006, '') self.assertOneFrameSent(*self.server_close) def test_close_protocol_error(self): From db6b7a51be61aa4db6a75997765783db309255ec Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 27 Jul 2015 13:51:43 +0200 Subject: [PATCH 0108/1539] Add a comment for consistency. --- websockets/protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/websockets/protocol.py b/websockets/protocol.py index 3bcb514e9..d3485d8b9 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -102,6 +102,7 @@ def __init__(self, *, self.close_reason = '' # Futures tracking steps in the connection's lifecycle. + # Set to True when the opening handshake has completed properly. self.opening_handshake = asyncio.Future(loop=loop) # Set to True when the closing handshake has completed properly and to # False when the connection terminates abnormally. From 9bdfb81780f71eac09dd154a75a09e9dbe03b81e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 27 Jul 2015 14:32:02 +0200 Subject: [PATCH 0109/1539] Increase guarantees against invalid state changes. --- websockets/client.py | 1 + websockets/protocol.py | 25 ++++++++++++++++--------- websockets/server.py | 1 + 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index ac83ca4ac..d3677c321 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -91,6 +91,7 @@ def handshake(self, wsuri, raise InvalidHandshake( "Unknown subprotocol: {}".format(self.subprotocol)) + assert self.state == CONNECTING self.state = OPEN self.opening_handshake.set_result(True) diff --git a/websockets/protocol.py b/websockets/protocol.py index d3485d8b9..7bfb82c7c 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -23,8 +23,14 @@ logger = logging.getLogger(__name__) +# A WebSocket connection goes through the following four states, in order: + CONNECTING, OPEN, CLOSING, CLOSED = range(4) +# In order to ensure consistency, the code always checks the current value of +# WebSocketCommonProtocol.state before assigning a new value and never yields +# between the check and the assignment. + class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): """ @@ -157,9 +163,6 @@ def close(self, code=1000, reason=''): # 7.1.3. The WebSocket Closing Handshake is Started self.close_code, self.close_reason = code, reason frame_data = serialize_close(code, reason) - # Change the state before yielding control to avoid sending more - # than one close frame. - self.state = CLOSING yield from self.write_frame(OP_CLOSE, frame_data) # If the connection doesn't terminate within the timeout, break out of @@ -350,7 +353,6 @@ def read_data_frame(self, max_size): self.close_code, self.close_reason = parse_close(frame.data) if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started - self.state = CLOSING yield from self.write_frame(OP_CLOSE, frame.data) if not self.closing_handshake.done(): self.closing_handshake.set_result(True) @@ -382,11 +384,14 @@ def read_frame(self, max_size): @asyncio.coroutine def write_frame(self, opcode, data=b''): - expected_state = CLOSING if opcode == OP_CLOSE else OPEN # This may happen if a user attempts to write on a closed connection. - if self.state != expected_state: + if self.state != OPEN: raise InvalidState("Cannot write to a WebSocket " "in the {} state".format(self.state_name)) + # Make sure no other frame will be sent after a close frame. Do this + # before yielding control to avoid sending more than one close frame. + if opcode == OP_CLOSE: + self.state = CLOSING frame = Frame(True, opcode, data) side = 'client' if self.is_client else 'server' logger.debug("%s >> %s", side, frame) @@ -445,9 +450,11 @@ def fail_connection(self, code=1011, reason=''): # 7.1.7. Fail the WebSocket Connection logger.info("Failing the WebSocket connection: %d %s", code, reason) if self.state == OPEN: - self.state = CLOSING - # Don't send a close frame is the connection is broken already. - if not (code == 1006 or self.connection_closed.done()): + if code == 1006: + # Don't send a close frame is the connection is broken. Set + # the state to CLOSING to allow close_connection to proceed. + self.state = CLOSING + else: frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) if not self.closing_handshake.done(): diff --git a/websockets/server.py b/websockets/server.py index 997440a3a..078155850 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -153,6 +153,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): response = '\r\n'.join(response).encode() self.writer.write(response) + assert self.state == CONNECTING self.state = OPEN self.opening_handshake.set_result(True) From 471cbbc72760802f334fdf68e5c1c3ef54ca98eb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 28 Jul 2015 15:08:26 +0200 Subject: [PATCH 0110/1539] Terminate each connection in tests. This prevents spurious warnings about unterminated tasks. --- websockets/test_protocol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 644a654ec..a0003f5f0 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -48,6 +48,8 @@ def setUp(self): self.transport.connect(self.loop, self.protocol) def tearDown(self): + self.loop.run_until_complete( + self.protocol.close_connection(force=True)) self.loop.close() super().tearDown() From 88f067fdc04accd62b0664d83f744844b4f4578c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 28 Jul 2015 22:22:04 +0200 Subject: [PATCH 0111/1539] Set the close status code and reason more consistently. Set them when the closing handshake is considered complete or aborted. --- docs/index.rst | 2 ++ websockets/protocol.py | 16 ++++++++-------- websockets/test_protocol.py | 20 +++++++++----------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 2ae398eb4..39eafc1f0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -305,6 +305,8 @@ Changelog * Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. +* Set the close status code and reason more consistently. + * Improved tests. 2.4 diff --git a/websockets/protocol.py b/websockets/protocol.py index 7bfb82c7c..16da73b3f 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -104,6 +104,7 @@ def __init__(self, *, self.subprotocol = None + # Code and reason must be set when the closing handshake completes. self.close_code = None self.close_reason = '' @@ -161,7 +162,6 @@ def close(self, code=1000, reason=''): if self.state == OPEN: # 7.1.2. Start the WebSocket Closing Handshake # 7.1.3. The WebSocket Closing Handshake is Started - self.close_code, self.close_reason = code, reason frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) @@ -350,11 +350,13 @@ def read_data_frame(self, max_size): frame = yield from self.read_frame(max_size) # 5.5. Control Frames if frame.opcode == OP_CLOSE: - self.close_code, self.close_reason = parse_close(frame.data) + # Make sure the close frame is valid before echoing it. + code, reason = parse_close(frame.data) if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started yield from self.write_frame(OP_CLOSE, frame.data) if not self.closing_handshake.done(): + self.close_code, self.close_reason = code, reason self.closing_handshake.set_result(True) return elif frame.opcode == OP_PING: @@ -443,10 +445,6 @@ def close_connection(self, force=False): @asyncio.coroutine def fail_connection(self, code=1011, reason=''): - # Losing the connection usually results in a protocol error. - # Preserve the original error code in this case. - if self.close_code != 1006: - self.close_code, self.close_reason = code, reason # 7.1.7. Fail the WebSocket Connection logger.info("Failing the WebSocket connection: %d %s", code, reason) if self.state == OPEN: @@ -458,6 +456,7 @@ def fail_connection(self, code=1011, reason=''): frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) if not self.closing_handshake.done(): + self.close_code, self.close_reason = code, reason self.closing_handshake.set_result(False) yield from self.close_connection() @@ -472,8 +471,9 @@ def client_connected(self, reader, writer): def connection_lost(self, exc): # 7.1.4. The WebSocket Connection is Closed self.state = CLOSED + if not self.closing_handshake.done(): + self.close_code, self.close_reason = 1006, '' + self.closing_handshake.set_result(False) if not self.connection_closed.done(): self.connection_closed.set_result(None) - if self.close_code is None: - self.close_code = 1006 super().connection_lost(exc) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index a0003f5f0..8d11be325 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -391,7 +391,7 @@ def test_close_handshake_in_fragmented_text(self): self.receive_frame(Frame(True, OP_CLOSE, b'')) self.receive_eof() self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) - self.assertConnectionClosed(1002, '') + self.assertConnectionClosed(1005, '') def test_connection_close_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) @@ -458,20 +458,19 @@ def test_close_handshake_timeout(self): # response. The server will stop waiting for the close frame and # timeout. self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1006, '') def test_client_close_race_with_failing_connection(self): self.make_drain_slow() # Fail the connection while answering a close frame from the client. self.loop.call_soon(self.receive_frame, self.client_close) - fail_connection = self.protocol.fail_connection(1000, 'server') - self.loop.call_later(MS, self.async, fail_connection) + self.loop.call_later(MS, self.async, self.protocol.fail_connection()) next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertIsNone(next_message) - # The connection was closed before the close frame could be sent. - self.assertConnectionClosed(1006, '') + # The closing handshake was completed by fail_connection. + self.assertConnectionClosed(1011, '') self.assertOneFrameSent(*self.client_close) def test_close_protocol_error(self): @@ -581,7 +580,7 @@ def test_close_handshake_timeout(self): # stop waiting for the close frame and timeout, then stop waiting # for the connection close and timeout again. self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1006, '') def test_eof_received_timeout(self): # Timeout is expected in 10ms. @@ -600,14 +599,13 @@ def test_server_close_race_with_failing_connection(self): # Fail the connection while answering a close frame from the server. self.loop.call_soon(self.receive_frame, self.server_close) - fail_connection = self.protocol.fail_connection(1000, 'client') - self.loop.call_later(MS, self.async, fail_connection) + self.loop.call_later(MS, self.async, self.protocol.fail_connection()) self.loop.call_later(2 * MS, self.receive_eof) next_message = self.loop.run_until_complete(self.protocol.recv()) self.assertIsNone(next_message) - # The connection was closed before the close frame could be sent. - self.assertConnectionClosed(1006, '') + # The closing handshake was completed by fail_connection. + self.assertConnectionClosed(1011, '') self.assertOneFrameSent(*self.server_close) def test_close_protocol_error(self): From 930b9d98f44a3a142c58d94640014d191a1e851e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 28 Jul 2015 22:33:09 +0200 Subject: [PATCH 0112/1539] Add changelog for changes spread across multiple recent commits. --- docs/index.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 39eafc1f0..1bfc179f7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -305,9 +305,13 @@ Changelog * Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. +* Clarified that the closing handshake can be initiated by the client. + * Set the close status code and reason more consistently. -* Improved tests. +* Strengthened connection termination by simplifying the implementation. + +* Improved tests, added tox configuration, and enforced 100% branch coverage. 2.4 ... From 0c43316b07e93bddb3a688a104713f8a91dfaff5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 28 Jul 2015 22:41:01 +0200 Subject: [PATCH 0113/1539] Bump version number. --- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index f15b50f92..6bcffe9b7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '2.4' +version = '2.5' # The full version, including alpha/beta/rc tags. -release = '2.4' +release = '2.5' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index d3300deb3..2250731ca 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '2.4' +version = '2.5' From 9fcad1999e6feec4c2f25ef5e36a195454ce5036 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Aug 2015 15:15:24 +0200 Subject: [PATCH 0114/1539] Increase timeout 10 times in asyncio debug mode. --- websockets/test_protocol.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 8d11be325..d718ffede 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -14,6 +14,10 @@ # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variables. MS = 0.001 * int(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1)) +# asyncio's debug mode has a 10x performance penalty for this test suite. +if os.environ.get('PYTHONASYNCIODEBUG'): # pragma: no cover + MS *= 10 + class TransportMock(unittest.mock.Mock): """ From 5482cccab5be5aee7a80c959f3335697862477a6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Aug 2015 15:15:57 +0200 Subject: [PATCH 0115/1539] Make assertCompletesWithin more explicit. Also increase tolerance as much as possible. --- websockets/test_protocol.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index d718ffede..e4d25893b 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -163,8 +163,6 @@ def assertConnectionClosed(self, code, message): @contextlib.contextmanager def assertCompletesWithin(self, min_time, max_time): - min_time *= MS - max_time *= MS t0 = self.loop.time() yield t1 = self.loop.time() @@ -456,8 +454,8 @@ def test_close_drops_frames(self): def test_close_handshake_timeout(self): # Timeout is expected in 10ms. self.protocol.timeout = 10 * MS - # Check the timing within -1/+5ms for robustness. - with self.assertCompletesWithin(9, 15): + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(9 * MS, 19 * MS): # Unlike previous tests, no close frame will be received in # response. The server will stop waiting for the close frame and # timeout. @@ -577,8 +575,8 @@ def test_close_drops_frames(self): def test_close_handshake_timeout(self): # Timeout is expected in 2 * 10 = 20ms. self.protocol.timeout = 10 * MS - # Check the timing within -1/+5ms for robustness. - with self.assertCompletesWithin(19, 25): + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(19 * MS, 29 * MS): # Unlike previous tests, no close frame will be received in # response and the connection will not be closed. The client will # stop waiting for the close frame and timeout, then stop waiting @@ -589,8 +587,8 @@ def test_close_handshake_timeout(self): def test_eof_received_timeout(self): # Timeout is expected in 10ms. self.protocol.timeout = 10 * MS - # Check the timing within -1/+5ms for robustness. - with self.assertCompletesWithin(9, 15): + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(9 * MS, 19 * MS): # Unlike previous tests, the close frame will be received in # response but the connection will not be closed. The client will # stop waiting for the connection close and timeout. From 8e5a3954d62d01e2201aedaf4fe366d67e6af2c3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Aug 2015 15:05:14 +0200 Subject: [PATCH 0116/1539] Close open connections on server shutdown. Fix #64. --- docs/index.rst | 7 ++- websockets/protocol.py | 2 +- websockets/server.py | 85 +++++++++++++++++++++++++++++--- websockets/test_client_server.py | 5 ++ 4 files changed, 91 insertions(+), 8 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 1bfc179f7..028d442e8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -220,7 +220,7 @@ Server .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, extra_headers=None, **kwds) - .. autoclass:: WebSocketServerProtocol(ws_handler, *, origins=None, subprotocols=None, extra_headers=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, origins=None, subprotocols=None, extra_headers=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) .. automethod:: select_subprotocol(client_protos, server_protos) @@ -288,6 +288,11 @@ Utilities Changelog --------- +2.6 +... + +* Closed open connections with code 1001 when a server shuts down. + 2.5 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index 16da73b3f..1ac5bf2a3 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -90,7 +90,7 @@ def __init__(self, *, self.timeout = timeout self.max_size = max_size - # Store a reference to loop to avoid relying on self.loop, a private + # Store a reference to loop to avoid relying on self._loop, a private # attribute of StreamReaderProtocol, inherited from FlowControlMixin. self.loop = loop diff --git a/websockets/server.py b/websockets/server.py index 078155850..238e9aa59 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -31,9 +31,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): state = CONNECTING - def __init__(self, ws_handler, *, + def __init__(self, ws_handler, ws_server, *, origins=None, subprotocols=None, extra_headers=None, **kwds): self.ws_handler = ws_handler + self.ws_server = ws_server self.origins = origins self.subprotocols = subprotocols self.extra_headers = extra_headers @@ -169,6 +170,68 @@ def select_subprotocol(self, client_protos, server_protos): priority = lambda p: client_protos.index(p) + server_protos.index(p) return sorted(common_protos, key=priority)[0] + def client_connected(self, reader, writer): + super().client_connected(reader, writer) + self.ws_server.register(self) + + def connection_lost(self, exc): + self.ws_server.unregister(self) + super().connection_lost(exc) + + +class WebSocketServer(asyncio.AbstractServer): + """ + Wrapper for :class:`~asyncio.Server` that triggers the closing handshake. + """ + + def __init__(self, loop=None): + # Store a reference to loop to avoid relying on self.server._loop. + self.loop = loop + + self.websockets = set() + self.closing_tasks = set() + + def wrap(self, server): + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.BaseEventLoop.create_server` doesn't support + injecting a custom ``Server`` class, a simple solution that doesn't + rely on private APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.BaseEventLoop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + """ + self.server = server + + def register(self, protocol): + self.websockets.add(protocol) + + def unregister(self, protocol): + self.websockets.remove(protocol) + + def close(self): + """ + Stop serving and trigger a closing handshake on open connections. + """ + self.closing_tasks = { + asyncio.async(websocket.fail_connection(1001), loop=self.loop) + for websocket in self.websockets + } + self.server.close() + + @asyncio.coroutine + def wait_closed(self): + """ + Wait until all connections are closed. + """ + # asyncio.wait doesn't accept an empty first argument. + if self.closing_tasks: + yield from asyncio.wait(self.closing_tasks, loop=self.loop) + yield from self.server.wait_closed() + @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, @@ -197,9 +260,12 @@ def serve(ws_handler, host=None, port=None, *, mapping, an iterable of (name, value) pairs, or a callable taking the request path and headers in arguments. - :func:`serve` yields a :class:`~asyncio.Server` which provides a - :meth:`~asyncio.Server.close` method and a - :meth:`~asyncio.Server.wait_closed` coroutine to stop serving requests. + :func:`serve` yields a :class:`~asyncio.Server` which provides: + + * a :meth:`~asyncio.Server.close` method that closes open connections with + status code 1001 and stops accepting new connections + * a :meth:`~asyncio.Server.wait_closed` coroutine that waits until closing + handshakes complete and connections are closed. Whenever a client connects, the server accepts the connection, creates a :class:`WebSocketServerProtocol`, performs the opening handshake, and @@ -218,9 +284,16 @@ def serve(ws_handler, host=None, port=None, *, if loop is None: loop = asyncio.get_event_loop() + ws_server = WebSocketServer() + secure = kwds.get('ssl') is not None factory = lambda: klass( - ws_handler, host=host, port=port, secure=secure, + ws_handler, ws_server, + host=host, port=port, secure=secure, origins=origins, subprotocols=subprotocols, extra_headers=extra_headers, loop=loop) - return (yield from loop.create_server(factory, host, port, **kwds)) + server = yield from loop.create_server(factory, host, port, **kwds) + + ws_server.wrap(server) + + return ws_server diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index d1f6c3812..81235f66b 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -75,6 +75,11 @@ def test_basic(self): self.stop_client() self.stop_server() + def test_server_close_while_client_connected(self): + self.start_server() + self.start_client() + self.stop_server() + def test_explicit_event_loop(self): self.start_server(loop=self.loop) self.start_client(loop=self.loop) From a52cbbb1cf68b89f7690f1570ff6aa7721f8efe7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Aug 2015 23:14:23 +0200 Subject: [PATCH 0117/1539] Wait on the handler task instead of the closing handshake. Fix #64 (again). --- websockets/server.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 238e9aa59..0b57fc7bb 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -42,7 +42,7 @@ def __init__(self, ws_handler, ws_server, *, def connection_made(self, transport): super().connection_made(transport) - asyncio.async(self.handler(), loop=self.loop) + self.handler_task = asyncio.async(self.handler(), loop=self.loop) @asyncio.coroutine def handler(self): @@ -189,7 +189,6 @@ def __init__(self, loop=None): self.loop = loop self.websockets = set() - self.closing_tasks = set() def wrap(self, server): """ @@ -216,10 +215,8 @@ def close(self): """ Stop serving and trigger a closing handshake on open connections. """ - self.closing_tasks = { + for websocket in self.websockets: asyncio.async(websocket.fail_connection(1001), loop=self.loop) - for websocket in self.websockets - } self.server.close() @asyncio.coroutine @@ -228,8 +225,9 @@ def wait_closed(self): Wait until all connections are closed. """ # asyncio.wait doesn't accept an empty first argument. - if self.closing_tasks: - yield from asyncio.wait(self.closing_tasks, loop=self.loop) + if self.websockets: + yield from asyncio.wait( + [ws.handler_task for ws in self.websockets], loop=self.loop) yield from self.server.wait_closed() From 93bad7d638db60d91cc3909fbc97b8923d0765e4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Aug 2015 13:01:08 +0200 Subject: [PATCH 0118/1539] Fix test failures on Windows. Fix #63. --- websockets/test_protocol.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index e4d25893b..18dd25d76 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -2,6 +2,7 @@ import contextlib import functools import os +import time import unittest import unittest.mock @@ -18,6 +19,9 @@ if os.environ.get('PYTHONASYNCIODEBUG'): # pragma: no cover MS *= 10 +# Ensure that timeouts are larger than the clock's resolution (for Windows). +MS = max(MS, 2.5 * time.get_clock_info('monotonic').resolution) + class TransportMock(unittest.mock.Mock): """ From 5666434e8d250a6b0731b6328ce389eb99eced1a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Aug 2015 22:05:53 +0200 Subject: [PATCH 0119/1539] Close connections instead of failing them on shutdown. This lets the client complete the closing handshake. Previously the TCP connection was closed just after sending a close frame. Fix #64 (again). --- websockets/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/server.py b/websockets/server.py index 0b57fc7bb..94135d170 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -216,7 +216,7 @@ def close(self): Stop serving and trigger a closing handshake on open connections. """ for websocket in self.websockets: - asyncio.async(websocket.fail_connection(1001), loop=self.loop) + asyncio.async(websocket.close(1001), loop=self.loop) self.server.close() @asyncio.coroutine From 2a41dd8a4bd6105e53dea6803872c79a7db969f5 Mon Sep 17 00:00:00 2001 From: Mike Putnam Date: Sun, 16 Aug 2015 12:55:28 -0500 Subject: [PATCH 0120/1539] Fix typo "on top of" --- README.rst | 2 +- docs/index.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 189c6b18d..ec3c42c4e 100644 --- a/README.rst +++ b/README.rst @@ -5,7 +5,7 @@ WebSockets Python. It implements `RFC 6455`_ with a focus on correctness and simplicity. It passes the `Autobahn Testsuite`_. -Built on top on Python's asynchronous I/O support introduced in `PEP 3156`_, +Built on top of Python's asynchronous I/O support introduced in `PEP 3156`_, it provides an API based on coroutines, making it easy to write highly concurrent applications. diff --git a/docs/index.rst b/docs/index.rst index 028d442e8..fa6452eee 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ WebSockets Python. It implements `RFC 6455`_ with a focus on correctness and simplicity. It passes the `Autobahn Testsuite`_. -Built on top on Python's asynchronous I/O support introduced in `PEP 3156`_, +Built on top of Python's asynchronous I/O support introduced in `PEP 3156`_, it provides an API based on coroutines, making it easy to write highly concurrent applications. From bb2269a9b2d40801bc88fb62a455cb6cd3181256 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 18 Aug 2015 08:55:44 +0200 Subject: [PATCH 0121/1539] Avoided TCP fragmentation of small frames. Fix #68. --- docs/index.rst | 2 ++ websockets/framing.py | 20 +++++++++++--------- websockets/test_framing.py | 17 ++++++++++++----- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index fa6452eee..2ae0e3466 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -293,6 +293,8 @@ Changelog * Closed open connections with code 1001 when a server shuts down. +* Avoided TCP fragmentation of small frames. + 2.5 ... diff --git a/websockets/framing.py b/websockets/framing.py index b1c13aad2..7714a8671 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -121,30 +121,32 @@ def write_frame(frame, writer, mask): incorrect values. """ check_frame(frame) + output = io.BytesIO() - # Write the header - header = io.BytesIO() + # Prepare the header head1 = 0b10000000 if frame.fin else 0 head1 |= frame.opcode head2 = 0b10000000 if mask else 0 length = len(frame.data) if length < 0x7e: - header.write(struct.pack('!BB', head1, head2 | length)) + output.write(struct.pack('!BB', head1, head2 | length)) elif length < 0x10000: - header.write(struct.pack('!BBH', head1, head2 | 126, length)) + output.write(struct.pack('!BBH', head1, head2 | 126, length)) else: - header.write(struct.pack('!BBQ', head1, head2 | 127, length)) + output.write(struct.pack('!BBQ', head1, head2 | 127, length)) if mask: mask_bits = struct.pack('!I', random.getrandbits(32)) - header.write(mask_bits) - writer(header.getvalue()) + output.write(mask_bits) - # Write the data + # Prepare the data if mask: data = bytes(b ^ mask_bits[i % 4] for i, b in enumerate(frame.data)) else: data = frame.data - writer(data) + output.write(data) + + # Send the frame + writer(output.getvalue()) def check_frame(frame): diff --git a/websockets/test_framing.py b/websockets/test_framing.py index 7fe914395..f88ee3bcc 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -1,6 +1,6 @@ import asyncio -import io import unittest +import unittest.mock from .exceptions import PayloadTooBig, WebSocketProtocolError from .framing import * @@ -19,13 +19,20 @@ def decode(self, message, mask=False, max_size=None): self.stream = asyncio.StreamReader(loop=self.loop) self.stream.feed_data(message) self.stream.feed_eof() - return self.loop.run_until_complete(read_frame( + frame = self.loop.run_until_complete(read_frame( self.stream.readexactly, mask, max_size=max_size)) + # Make sure all the data was consumed. + self.assertTrue(self.stream.at_eof()) + return frame def encode(self, frame, mask=False): - encoded = io.BytesIO() - write_frame(frame, encoded.write, mask) - return encoded.getvalue() + writer = unittest.mock.Mock() + write_frame(frame, writer, mask) + # Ensure the entire frame is sent with a single call to writer(). + # Multiple calls cause TCP fragmentation and degrade performance. + self.assertEqual(writer.call_count, 1) + # The frame data is the single positional argument of that call. + return writer.call_args[0][0] def round_trip(self, message, expected, mask=False): decoded = self.decode(message, mask) From be26a070c7c37e0838410734466587a267d047c7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 18 Aug 2015 13:28:41 +0200 Subject: [PATCH 0122/1539] Add local/remote_address attributes on protocols. Fix #66. --- docs/index.rst | 7 ++++++- websockets/protocol.py | 27 +++++++++++++++++++++++++++ websockets/test_protocol.py | 20 ++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 2ae0e3466..469f86977 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -243,7 +243,10 @@ Shared .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - .. autoattribute:: open() + .. autoattribute:: local_address + .. autoattribute:: remote_address + + .. autoattribute:: open .. automethod:: close(code=1000, reason='') .. automethod:: recv() @@ -291,6 +294,8 @@ Changelog 2.6 ... +* Added ``local_address`` and ``remote_address`` attributes on protocols. + * Closed open connections with code 1001 when a server shuts down. * Avoided TCP fragmentation of small frames. diff --git a/websockets/protocol.py b/websockets/protocol.py index 1ac5bf2a3..00dfa3797 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -97,6 +97,9 @@ def __init__(self, *, stream_reader = asyncio.StreamReader(loop=loop) super().__init__(stream_reader, self.client_connected, loop) + self.reader = None + self.writer = None + self.request_headers = None self.raw_request_headers = None self.response_headers = None @@ -137,6 +140,30 @@ def state_name(self): # Public API + @property + def local_address(self): + """ + Local address of the connection. + + The address is a ``(host, port)`` tuple or ``None`` if the connection + hasn't been established yet. + """ + if self.writer is None: + return None + return self.writer.get_extra_info('sockname') + + @property + def remote_address(self): + """ + Remote address of the connection. + + The address is a ``(host, port)`` tuple or ``None`` if the connection + hasn't been established yet. + """ + if self.writer is None: + return None + return self.writer.get_extra_info('peername') + @property def open(self): """ diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 18dd25d76..7add736e4 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -176,6 +176,26 @@ def assertCompletesWithin(self, min_time, max_time): self.assertLess( dt, max_time, "Too slow: {} >= {}".format(dt, max_time)) + def test_local_address(self): + get_extra_info = unittest.mock.Mock(return_value=('host', 4312)) + self.transport.get_extra_info = get_extra_info + # The connection isn't established yet. + self.assertEqual(self.protocol.local_address, None) + self.run_loop_once() + # The connection is established. + self.assertEqual(self.protocol.local_address, ('host', 4312)) + get_extra_info.assert_called_once_with('sockname', None) + + def test_remote_address(self): + get_extra_info = unittest.mock.Mock(return_value=('host', 4312)) + self.transport.get_extra_info = get_extra_info + # The connection isn't established yet. + self.assertEqual(self.protocol.remote_address, None) + self.run_loop_once() + # The connection is established. + self.assertEqual(self.protocol.remote_address, ('host', 4312)) + get_extra_info.assert_called_once_with('peername', None) + def test_open(self): self.assertTrue(self.protocol.open) self.protocol.connection_lost(None) From 72f9b68b87395bdf598411398d07917b043dfc41 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 18 Aug 2015 20:13:37 +0200 Subject: [PATCH 0123/1539] Bump version number. --- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 6bcffe9b7..16e9463f6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '2.5' +version = '2.6' # The full version, including alpha/beta/rc tags. -release = '2.5' +release = '2.6' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 2250731ca..410fd1e00 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '2.5' +version = '2.6' From ea5087c2a917fd6ec9a2e3954b28a548fd29c145 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Aug 2015 23:01:19 +0200 Subject: [PATCH 0124/1539] Tie registration with handle tasks' lifecycle. Since the purpose of registering or unregistering is to keep a list of active connection handlers, it makes sense to register when starting a connection handler and unregister only when the connection handler is about to terminate. Fix #69. --- websockets/server.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 94135d170..5c4c89a01 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -42,6 +42,11 @@ def __init__(self, ws_handler, ws_server, *, def connection_made(self, transport): super().connection_made(transport) + # Register the connection with the server when creating the handler + # task. (Registering at the beginning of the handler coroutine would + # create a race condition between the creation of the task, which + # schedules its execution, and the moment the handler starts running.) + self.ws_server.register(self) self.handler_task = asyncio.async(self.handler(), loop=self.loop) @asyncio.coroutine @@ -86,6 +91,13 @@ def handler(self): except Exception: # pragma: no cover pass + finally: + # Unregister the connection with the server when the handler task + # terminates. Registration is tied to the lifecycle of the handler + # task because the server waits for tasks attached to registered + # connections before terminating. + self.ws_server.unregister(self) + @asyncio.coroutine def handshake(self, origins=None, subprotocols=None, extra_headers=None): """ @@ -170,14 +182,6 @@ def select_subprotocol(self, client_protos, server_protos): priority = lambda p: client_protos.index(p) + server_protos.index(p) return sorted(common_protos, key=priority)[0] - def client_connected(self, reader, writer): - super().client_connected(reader, writer) - self.ws_server.register(self) - - def connection_lost(self, exc): - self.ws_server.unregister(self) - super().connection_lost(exc) - class WebSocketServer(asyncio.AbstractServer): """ From 212fed7b5d30314bcd5b466f339fa79e81c0075e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 13:39:09 +0100 Subject: [PATCH 0125/1539] Add compatibility with Python 3.5. --- tox.ini | 2 +- websockets/compatibility.py | 8 ++++++++ websockets/protocol.py | 8 +++++--- websockets/server.py | 6 ++++-- websockets/test_protocol.py | 3 ++- 5 files changed, 20 insertions(+), 7 deletions(-) create mode 100644 websockets/compatibility.py diff --git a/tox.ini b/tox.ini index 56e949418..4955cb38c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py33,py34,coverage,flake8,isort +envlist = py33,py34,py35,coverage,flake8,isort [testenv] deps = diff --git a/websockets/compatibility.py b/websockets/compatibility.py new file mode 100644 index 000000000..90afea9ba --- /dev/null +++ b/websockets/compatibility.py @@ -0,0 +1,8 @@ +import asyncio + + +# Replace with BaseEventLoop.create_task when dropping Python < 3.4.2. +try: # pragma: no cover + asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5 +except AttributeError: # pragma: no cover + asyncio_ensure_future = asyncio.async # Python < 3.5 diff --git a/websockets/protocol.py b/websockets/protocol.py index 00dfa3797..0af76b4a8 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -15,6 +15,7 @@ import random import struct +from .compatibility import asyncio_ensure_future from .exceptions import InvalidState, PayloadTooBig, WebSocketProtocolError from .framing import * from .handshake import * @@ -181,7 +182,7 @@ def close(self, code=1000, reason=''): It waits for the other end to complete the handshake. It doesn't do anything once the connection is closed. - It's usually safe to wrap this coroutine in :func:`~asyncio.async` + It's safe to wrap this coroutine in :func:`~asyncio.ensure_future` since errors during connection termination aren't particularly useful. ``code`` must be an :class:`int` and ``reason`` a :class:`str`. @@ -222,7 +223,8 @@ def recv(self): pass # Wait for a message until the connection is closed - next_message = asyncio.async(self.messages.get(), loop=self.loop) + next_message = asyncio_ensure_future( + self.messages.get(), loop=self.loop) try: done, pending = yield from asyncio.wait( [next_message, self.worker], @@ -493,7 +495,7 @@ def client_connected(self, reader, writer): self.reader = reader self.writer = writer # Start the task that handles incoming messages. - self.worker = asyncio.async(self.run(), loop=self.loop) + self.worker = asyncio_ensure_future(self.run(), loop=self.loop) def connection_lost(self, exc): # 7.1.4. The WebSocket Connection is Closed diff --git a/websockets/server.py b/websockets/server.py index 5c4c89a01..1953848a0 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -9,6 +9,7 @@ import email.message import logging +from .compatibility import asyncio_ensure_future from .exceptions import InvalidHandshake, InvalidOrigin from .handshake import build_response, check_request from .http import USER_AGENT, read_request @@ -47,7 +48,8 @@ def connection_made(self, transport): # create a race condition between the creation of the task, which # schedules its execution, and the moment the handler starts running.) self.ws_server.register(self) - self.handler_task = asyncio.async(self.handler(), loop=self.loop) + self.handler_task = asyncio_ensure_future( + self.handler(), loop=self.loop) @asyncio.coroutine def handler(self): @@ -220,7 +222,7 @@ def close(self): Stop serving and trigger a closing handshake on open connections. """ for websocket in self.websockets: - asyncio.async(websocket.close(1001), loop=self.loop) + asyncio_ensure_future(websocket.close(1001), loop=self.loop) self.server.close() @asyncio.coroutine diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 7add736e4..414e38692 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -6,6 +6,7 @@ import unittest import unittest.mock +from .compatibility import asyncio_ensure_future from .exceptions import InvalidState from .framing import * from .protocol import CLOSED, WebSocketCommonProtocol @@ -88,7 +89,7 @@ def delayed_drain(): @property def async(self): - return functools.partial(asyncio.async, loop=self.loop) + return functools.partial(asyncio_ensure_future, loop=self.loop) def receive_frame(self, frame): """ From be9852913637932351feb09729048eaa18ccf58d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 14:06:16 +0100 Subject: [PATCH 0126/1539] Close connections correctly on Python 3.5. * Counteract change in StreamReaderProtocol.eof_received() * Ensure the transport is closed to fix resource warnings Fix #76. --- websockets/protocol.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/websockets/protocol.py b/websockets/protocol.py index 0af76b4a8..b8ca6e3ea 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -497,6 +497,20 @@ def client_connected(self, reader, writer): # Start the task that handles incoming messages. self.worker = asyncio_ensure_future(self.run(), loop=self.loop) + def eof_received(self): + super().eof_received() + # Since Python 3.5, StreamReaderProtocol.eof_received() returns True + # to leave the transport open (http://bugs.python.org/issue24539). + # This is inappropriate for websockets for at least three reasons. + # 1. The use case is to read data until EOF with self.reader.read(-1). + # Since websockets is a TLV protocol, this never happens. + # 2. It doesn't work on SSL connections. A falsy value must be + # returned to have the same behavior on SSL and plain connections. + # 3. The websockets protocol has its own closing handshake. Endpoints + # close the TCP connection after sending a Close frame. + # As a consequence we revert to the previous, more useful behavior. + return + def connection_lost(self, exc): # 7.1.4. The WebSocket Connection is Closed self.state = CLOSED @@ -505,4 +519,7 @@ def connection_lost(self, exc): self.closing_handshake.set_result(False) if not self.connection_closed.done(): self.connection_closed.set_result(None) + # Close the transport in case close_connection() wasn't executed. + if self.writer is not None: + self.writer.close() super().connection_lost(exc) From e3a7c289ffcda22ea83d82c190753ea851de7cc3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 14:11:41 +0100 Subject: [PATCH 0127/1539] Add changelog entry for previous commit. --- docs/index.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 469f86977..15916e898 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -291,6 +291,11 @@ Utilities Changelog --------- +2.7 +... + +* Added compatibility with Python 3.5. + 2.6 ... From 9af8776d8e33e31315c72021515226e1d75483d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 16:56:03 +0100 Subject: [PATCH 0128/1539] Change docs theme. --- docs/conf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 16e9463f6..a1f020799 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -91,8 +91,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'pydoctheme' -html_theme_path = ['.'] +html_theme = 'alabaster' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the From 37e6234a5f4c9908827dbf2eb9b375cd7253ca52 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 16:56:12 +0100 Subject: [PATCH 0129/1539] Move docs to Read the Docs. It appears to have gained support for Python 3.4 now. --- README.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index ec3c42c4e..01f357cf5 100644 --- a/README.rst +++ b/README.rst @@ -13,7 +13,7 @@ Installation is as simple as ``pip install websockets``. It requires Python ≥ 3.4 or Python 3.3 with the ``asyncio`` module, which is available with ``pip install asyncio``. -Documentation is available at http://aaugustin.github.io/websockets/. +Documentation is available on `Read the Docs`_. Bug reports, patches and suggestions welcome! Just open an issue_ or send a `pull request`_. @@ -23,5 +23,6 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _RFC 6455: http://tools.ietf.org/html/rfc6455 .. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst .. _PEP 3156: http://www.python.org/dev/peps/pep-3156/ +.. _Read the Docs: https://websockets.readthedocs.org/ .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ From 46986ef80728c2aeef2a72795014803c9890c4ed Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 18:31:33 +0100 Subject: [PATCH 0130/1539] Improve docs with the alabaster theme and a logo. Also split them across multiple pages. Fix #30. --- docs/_static/websockets.svg | 16 + docs/api.rst | 103 +++++++ docs/changelog.rst | 86 ++++++ docs/cheatsheet.rst | 69 +++++ docs/conf.py | 19 +- docs/index.rst | 408 +++----------------------- docs/intro.rst | 88 ++++++ docs/license.rst | 4 + docs/limitations.rst | 13 + docs/pydoctheme/static/pydoctheme.css | 170 ----------- docs/pydoctheme/theme.conf | 23 -- 11 files changed, 429 insertions(+), 570 deletions(-) create mode 100644 docs/_static/websockets.svg create mode 100644 docs/api.rst create mode 100644 docs/changelog.rst create mode 100644 docs/cheatsheet.rst create mode 100644 docs/intro.rst create mode 100644 docs/license.rst create mode 100644 docs/limitations.rst delete mode 100644 docs/pydoctheme/static/pydoctheme.css delete mode 100644 docs/pydoctheme/theme.conf diff --git a/docs/_static/websockets.svg b/docs/_static/websockets.svg new file mode 100644 index 000000000..409afb71d --- /dev/null +++ b/docs/_static/websockets.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + + diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 000000000..a36fdc255 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,103 @@ +API +=== + +Design +------ + +``websockets`` provides complete client and server implementations, as shown in +the examples above. These functions are built on top of low-level APIs +reflecting the two phases of the WebSocket protocol: + +1. An opening handshake, in the form of an HTTP Upgrade request; + +2. Data transfer, as framed messages, ending with a closing handshake. + +The first phase is designed to integrate with existing HTTP software. +``websockets`` provides functions to build and validate the request and +response headers. + +The second phase is the core of the WebSocket protocol. ``websockets`` +provides a standalone implementation on top of ``asyncio`` with a very simple +API. + +For convenience, public APIs can be imported directly from the +:mod:`websockets` package, unless noted otherwise. Anything that isn't listed +in this document is a private API. + +High-level +---------- + +Server +...... + +.. automodule:: websockets.server + + .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, extra_headers=None, **kwds) + + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, origins=None, subprotocols=None, extra_headers=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + + .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) + .. automethod:: select_subprotocol(client_protos, server_protos) + +Client +...... + +.. automodule:: websockets.client + + .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, subprotocols=None, extra_headers=None, **kwds) + + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + + .. automethod:: handshake(wsuri, origin=None, subprotocols=None, extra_headers=None) + +Shared +...... + +.. automodule:: websockets.protocol + + .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + + .. autoattribute:: local_address + .. autoattribute:: remote_address + + .. autoattribute:: open + .. automethod:: close(code=1000, reason='') + + .. automethod:: recv() + .. automethod:: send(data) + + .. automethod:: ping(data=None) + .. automethod:: pong(data=b'') + +Exceptions +.......... + +.. automodule:: websockets.exceptions + :members: + +Low-level +--------- + +Opening handshake +................. + +.. automodule:: websockets.handshake + :members: + +Data transfer +............. + +.. automodule:: websockets.framing + :members: + +URI parser +.......... + +.. automodule:: websockets.uri + :members: + +Utilities +......... + +.. automodule:: websockets.http + :members: diff --git a/docs/changelog.rst b/docs/changelog.rst new file mode 100644 index 000000000..9561bf39a --- /dev/null +++ b/docs/changelog.rst @@ -0,0 +1,86 @@ +Changelog +--------- + +2.7 +... + +* Added compatibility with Python 3.5. + +2.6 +... + +* Added ``local_address`` and ``remote_address`` attributes on protocols. + +* Closed open connections with code 1001 when a server shuts down. + +* Avoided TCP fragmentation of small frames. + +2.5 +... + +* Improved documentation. + +* Provided access to handshake request and response HTTP headers. + +* Allowed customizing handshake request and response HTTP headers. + +* Supported running on a non-default event loop. + +* Returned a 403 error code instead of 400 when the request Origin isn't + allowed. + +* Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no + longer drops the next message. + +* Clarified that the closing handshake can be initiated by the client. + +* Set the close status code and reason more consistently. + +* Strengthened connection termination by simplifying the implementation. + +* Improved tests, added tox configuration, and enforced 100% branch coverage. + +2.4 +... + +* Added support for subprotocols. + +* Supported non-default event loop. + +* Added ``loop`` argument to :func:`~websockets.client.connect` and + :func:`~websockets.server.serve`. + +2.3 +... + +* Improved compliance of close codes. + +2.2 +... + +* Added support for limiting message size. + +2.1 +... + +* Added ``host``, ``port`` and ``secure`` attributes on protocols. + +* Added support for providing and checking Origin_. + +.. _Origin: https://tools.ietf.org/html/rfc6455#section-10.2 + +2.0 +... + +* Backwards-incompatible API change: + :meth:`~websockets.protocol.WebSocketCommonProtocol.send`, + :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` are coroutines. + They used to be regular functions. + +* Added flow control. + +1.0 +... + +* Initial public release. diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst new file mode 100644 index 000000000..75347fb98 --- /dev/null +++ b/docs/cheatsheet.rst @@ -0,0 +1,69 @@ +Cheat sheet +=========== + +Server +------ + +* Write a coroutine that handles a single connection. It receives a websocket + protocol instance and the URI path in argument. + + * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and + :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and + send messages at any time. + + * You may :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` or + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish + but it isn't needed in general. + +* Create a server with :func:`~websockets.server.serve` which is similar to + asyncio's :meth:`~asyncio.BaseEventLoop.create_server`. + + * The server takes care of establishing connections, then lets the handler + execute the application logic, and finally closes the connection after + the handler returns. + + * You may subclass :class:`~websockets.server.WebSocketServerProtocol` and + pass it in the ``klass`` keyword argument for advanced customization. + +Client +------ + +* Create a server with :func:`~websockets.client.connect` which is similar to + asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. + + * You may subclass :class:`~websockets.server.WebSocketClientProtocol` and + pass it in the ``klass`` keyword argument for advanced customization. + +* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and + :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and + send messages at any time. + +* You may :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` or + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish but it + isn't needed in general. + +* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.close` to terminate + the connection. + +Debugging +--------- + +If you don't understand what ``websockets`` is doing, enable logging:: + + import logging + logger = logging.getLogger('websockets') + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler()) + +The logs contains: + +* Exceptions in the connection handler at the ``ERROR`` level +* Exceptions in the opening or closing handshake at the ``INFO`` level +* All frames at the ``DEBUG`` level — this can be very verbose + +If you're new to ``asyncio``, you will certainly encounter issues that are +related to asynchronous programming in general rather than to ``websockets`` +in particular. Fortunately Python's official documentation provides advice to +`develop with asyncio`_. Check it out: it's invaluable! + +.. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html diff --git a/docs/conf.py b/docs/conf.py index a1f020799..1a02b24eb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -96,7 +96,13 @@ # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +html_theme_options = { + 'logo': 'websockets.svg', + 'description': 'WebSockets for Python 3', + 'github_button': True, + 'github_user': 'aaugustin', + 'github_repo': 'websockets', +} # Add any paths that contain custom themes here, relative to this directory. #html_theme_path = [] @@ -120,7 +126,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ['_static'] +html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. @@ -131,7 +137,14 @@ #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +html_sidebars = { + '**': [ + 'about.html', + 'navigation.html', + 'relations.html', + 'searchbox.html', + ] +} # Additional templates that should be rendered to pages, maps page names to # template names. diff --git a/docs/index.rst b/docs/index.rst index 15916e898..c0712592a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,5 +1,3 @@ -.. module:: websockets - WebSockets ========== @@ -7,13 +5,33 @@ WebSockets Python. It implements `RFC 6455`_ with a focus on correctness and simplicity. It passes the `Autobahn Testsuite`_. -Built on top of Python's asynchronous I/O support introduced in `PEP 3156`_, -it provides an API based on coroutines, making it easy to write highly -concurrent applications. +Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, +it provides a straightforward API based on coroutines, making it easy to write +highly concurrent applications. + +Installation +------------ + +Installation is as simple as ``pip install websockets``. + +It requires Python ≥ 3.4 or Python 3.3 with the ``asyncio`` module, which is +available with ``pip install asyncio``. + +User guide +---------- + +If you're new to ``websockets``, :doc:`intro` describes usage patterns and +provides examples. + +If you've used ``websockets`` before and just need a quick reference, have a +look at :doc:`cheatsheet`. + +If you need more details, the :doc:`api` documentation is for you. -Installation is as simple as ``pip install websockets``. It requires Python ≥ -3.4 or Python 3.3 with the ``asyncio`` module, which is available with ``pip -install asyncio``. +If you're upgrading ``websockets``, check the :doc:`changelog`. + +Contributing +------------ Bug reports, patches and suggestions welcome! Just open an issue_ or send a `pull request`_. @@ -26,370 +44,12 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ -Example -------- - -.. _server-example: - -Here's a WebSocket server example. It reads a name from the client and sends a -message. - -.. literalinclude:: ../example/server.py - -.. _client-example: - -Here's a corresponding client example. - -.. literalinclude:: ../example/client.py - -.. note:: - - On the server side, the handler coroutine ``hello`` is executed once for - each WebSocket connection. The connection is automatically closed when the - handler returns. - - You will almost always want to process several messages during the - lifetime of a connection. Therefore you must write a loop. Here are the - recommended patterns to exit cleanly when the connection drops, either - because the other side closed it or for any other reason. - - For receiving messages and passing them to a ``consumer`` coroutine:: - - @asyncio.coroutine - def handler(websocket, path): - while True: - message = yield from websocket.recv() - if message is None: - break - yield from consumer(message) - - :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` returns ``None`` - when the connection is closed. In other words, ``None`` marks the end of - the message stream. The handler coroutine should check for that case and - return when it happens. - - For getting messages from a ``producer`` coroutine and sending them:: - - @asyncio.coroutine - def handler(websocket, path): - while True: - message = yield from producer() - if not websocket.open: - break - yield from websocket.send(message) - - :meth:`~websockets.protocol.WebSocketCommonProtocol.send` fails with an - exception when it's called on a closed connection. Therefore the handler - coroutine should check that the connection is still open before attempting - to write and return otherwise. - - Of course, you can combine the two patterns shown above to read and write - messages on the same connection:: - - @asyncio.coroutine - def handler(websocket, path): - while True: - listener_task = asyncio.ensure_future(websocket.recv()) - producer_task = asyncio.ensure_future(producer()) - done, pending = yield from asyncio.wait( - [listener_task, producer_task], - return_when=asyncio.FIRST_COMPLETED) - - if listener_task in done: - message = listener_task.result() - if message is None: - break - yield from consumer(message) - else: - listener_task.cancel() - - if producer_task in done: - message = producer_task.result() - if not websocket.open: - break - yield from websocket.send(message) - else: - producer_task.cancel() - - (This code looks convoluted. If you know a more straightforward solution, - please let me know about it!) - -That's really all you have to know! ``websockets`` manages the connection -under the hood so you don't have to. - -Cheat sheet ------------ - -Server -...... - -* Write a coroutine that handles a single connection. It receives a websocket - protocol instance and the URI path in argument. - - * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and - :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and - send messages at any time. - - * You may :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` or - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish - but it isn't needed in general. - -* Create a server with :func:`~websockets.server.serve` which is similar to - asyncio's :meth:`~asyncio.BaseEventLoop.create_server`. - - * The server takes care of establishing connections, then lets the handler - execute the application logic, and finally closes the connection after - the handler returns. - - * You may subclass :class:`~websockets.server.WebSocketServerProtocol` and - pass it in the ``klass`` keyword argument for advanced customization. - -Client -...... - -* Create a server with :func:`~websockets.client.connect` which is similar to - asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. - - * You may subclass :class:`~websockets.server.WebSocketClientProtocol` and - pass it in the ``klass`` keyword argument for advanced customization. - -* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and - :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and - send messages at any time. - -* You may :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` or - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish but it - isn't needed in general. - -* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.close` to terminate - the connection. - -Debugging -......... - -If you don't understand what ``websockets`` is doing, enable logging:: - - import logging - logger = logging.getLogger('websockets') - logger.setLevel(logging.INFO) - logger.addHandler(logging.StreamHandler()) - -The logs contains: - -* Exceptions in the connection handler at the ``ERROR`` level -* Exceptions in the opening or closing handshake at the ``INFO`` level -* All frames at the ``DEBUG`` level — this can be very verbose - -If you're new to ``asyncio``, you will certainly encounter issues that are -related to asynchronous programming in general rather than to ``websockets`` -in particular. Fortunately Python's official documentation provides advice to -`develop with asyncio`_. Check it out: it's invaluable! - -.. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html - -Design ------- - -``websockets`` provides complete client and server implementations, as shown in -the examples above. These functions are built on top of low-level APIs -reflecting the two phases of the WebSocket protocol: - -1. An opening handshake, in the form of an HTTP Upgrade request; - -2. Data transfer, as framed messages, ending with a closing handshake. - -The first phase is designed to integrate with existing HTTP software. -``websockets`` provides functions to build and validate the request and -response headers. - -The second phase is the core of the WebSocket protocol. ``websockets`` -provides a standalone implementation on top of ``asyncio`` with a very simple -API. - -For convenience, public APIs can be imported directly from the -:mod:`websockets` package, unless noted otherwise. Anything that isn't listed -in this document is a private API. - -High-level API --------------- - -Server -...... - -.. automodule:: websockets.server - - .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, extra_headers=None, **kwds) - - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, origins=None, subprotocols=None, extra_headers=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - - .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) - .. automethod:: select_subprotocol(client_protos, server_protos) - -Client -...... - -.. automodule:: websockets.client - - .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, subprotocols=None, extra_headers=None, **kwds) - - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - - .. automethod:: handshake(wsuri, origin=None, subprotocols=None, extra_headers=None) - -Shared -...... - -.. automodule:: websockets.protocol - - .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - - .. autoattribute:: local_address - .. autoattribute:: remote_address - - .. autoattribute:: open - .. automethod:: close(code=1000, reason='') - - .. automethod:: recv() - .. automethod:: send(data) - - .. automethod:: ping(data=None) - .. automethod:: pong(data=b'') - -Exceptions -.......... - -.. automodule:: websockets.exceptions - :members: - -Low-level API -------------- - -Opening handshake -................. - -.. automodule:: websockets.handshake - :members: - -Data transfer -............. - -.. automodule:: websockets.framing - :members: - -URI parser -.......... - -.. automodule:: websockets.uri - :members: - -Utilities -......... - -.. automodule:: websockets.http - :members: - -Changelog ---------- - -2.7 -... - -* Added compatibility with Python 3.5. - -2.6 -... - -* Added ``local_address`` and ``remote_address`` attributes on protocols. - -* Closed open connections with code 1001 when a server shuts down. - -* Avoided TCP fragmentation of small frames. - -2.5 -... - -* Improved documentation. - -* Provided access to handshake request and response HTTP headers. - -* Allowed customizing handshake request and response HTTP headers. - -* Supported running on a non-default event loop. - -* Returned a 403 error code instead of 400 when the request Origin isn't - allowed. - -* Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no - longer drops the next message. - -* Clarified that the closing handshake can be initiated by the client. - -* Set the close status code and reason more consistently. - -* Strengthened connection termination by simplifying the implementation. - -* Improved tests, added tox configuration, and enforced 100% branch coverage. - -2.4 -... - -* Added support for subprotocols. - -* Supported non-default event loop. - -* Added ``loop`` argument to :func:`~websockets.client.connect` and - :func:`~websockets.server.serve`. - -2.3 -... - -* Improved compliance of close codes. - -2.2 -... - -* Added support for limiting message size. - -2.1 -... - -* Added ``host``, ``port`` and ``secure`` attributes on protocols. - -* Added support for providing and checking Origin_. - -.. _Origin: https://tools.ietf.org/html/rfc6455#section-10.2 - -2.0 -... - -* Backwards-incompatible API change: - :meth:`~websockets.protocol.WebSocketCommonProtocol.send`, - :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` are coroutines. - They used to be regular functions. - -* Added flow control. - -1.0 -... - -* Initial public release. - -Limitations ------------ - -Extensions_ aren't implemented. No extensions are registered_ at the time of -writing. - -The client doesn't attempt to guarantee that there is no more than one -connection to a given IP adress in a CONNECTING state. - -The client doesn't support connecting through a proxy. - -.. _Extensions: http://tools.ietf.org/html/rfc6455#section-9 -.. _registered: http://www.iana.org/assignments/websocket/websocket.xml - -License -------- +.. toctree:: + :hidden: -.. literalinclude:: ../LICENSE + intro + cheatsheet + api + limitations + changelog + license diff --git a/docs/intro.rst b/docs/intro.rst new file mode 100644 index 000000000..3c96a446d --- /dev/null +++ b/docs/intro.rst @@ -0,0 +1,88 @@ +Getting started +=============== + +.. _server-example: + +Here's a WebSocket server example. It reads a name from the client and sends a +message. + +.. literalinclude:: ../example/server.py + +.. _client-example: + +Here's a corresponding client example. + +.. literalinclude:: ../example/client.py + +On the server side, the handler coroutine ``hello`` is executed once for +each WebSocket connection. The connection is automatically closed when the +handler returns. + +You will almost always want to process several messages during the +lifetime of a connection. Therefore you must write a loop. Here are the +recommended patterns to exit cleanly when the connection drops, either +because the other side closed it or for any other reason. + +For receiving messages and passing them to a ``consumer`` coroutine:: + + @asyncio.coroutine + def handler(websocket, path): + while True: + message = yield from websocket.recv() + if message is None: + break + yield from consumer(message) + +:meth:`~websockets.protocol.WebSocketCommonProtocol.recv` returns ``None`` +when the connection is closed. In other words, ``None`` marks the end of +the message stream. The handler coroutine should check for that case and +return when it happens. + +For getting messages from a ``producer`` coroutine and sending them:: + + @asyncio.coroutine + def handler(websocket, path): + while True: + message = yield from producer() + if not websocket.open: + break + yield from websocket.send(message) + +:meth:`~websockets.protocol.WebSocketCommonProtocol.send` fails with an +exception when it's called on a closed connection. Therefore the handler +coroutine should check that the connection is still open before attempting +to write and return otherwise. + +Of course, you can combine the two patterns shown above to read and write +messages on the same connection:: + + @asyncio.coroutine + def handler(websocket, path): + while True: + listener_task = asyncio.ensure_future(websocket.recv()) + producer_task = asyncio.ensure_future(producer()) + done, pending = yield from asyncio.wait( + [listener_task, producer_task], + return_when=asyncio.FIRST_COMPLETED) + + if listener_task in done: + message = listener_task.result() + if message is None: + break + yield from consumer(message) + else: + listener_task.cancel() + + if producer_task in done: + message = producer_task.result() + if not websocket.open: + break + yield from websocket.send(message) + else: + producer_task.cancel() + +(This code looks convoluted. If you know a more straightforward solution, +please let me know about it!) + +That's really all you have to know! ``websockets`` manages the connection +under the hood so you don't have to. diff --git a/docs/license.rst b/docs/license.rst new file mode 100644 index 000000000..842d3b07f --- /dev/null +++ b/docs/license.rst @@ -0,0 +1,4 @@ +License +------- + +.. literalinclude:: ../LICENSE diff --git a/docs/limitations.rst b/docs/limitations.rst new file mode 100644 index 000000000..8cf5314d9 --- /dev/null +++ b/docs/limitations.rst @@ -0,0 +1,13 @@ +Limitations +----------- + +Extensions_ aren't implemented. No extensions are registered_ at the time of +writing. + +The client doesn't attempt to guarantee that there is no more than one +connection to a given IP adress in a CONNECTING state. + +The client doesn't support connecting through a proxy. + +.. _Extensions: http://tools.ietf.org/html/rfc6455#section-9 +.. _registered: http://www.iana.org/assignments/websocket/websocket.xml diff --git a/docs/pydoctheme/static/pydoctheme.css b/docs/pydoctheme/static/pydoctheme.css deleted file mode 100644 index 9942ca631..000000000 --- a/docs/pydoctheme/static/pydoctheme.css +++ /dev/null @@ -1,170 +0,0 @@ -@import url("default.css"); - -body { - background-color: white; - margin-left: 1em; - margin-right: 1em; -} - -div.related { - margin-bottom: 1.2em; - padding: 0.5em 0; - border-top: 1px solid #ccc; - margin-top: 0.5em; -} - -div.related a:hover { - color: #0095C4; -} - -div.related:first-child { - border-top: 0; - border-bottom: 1px solid #ccc; -} - -div.sphinxsidebar { - background-color: #eeeeee; - border-radius: 5px; - line-height: 130%; - font-size: smaller; -} - -div.sphinxsidebar h3, div.sphinxsidebar h4 { - margin-top: 1.5em; -} - -div.sphinxsidebarwrapper > h3:first-child { - margin-top: 0.2em; -} - -div.sphinxsidebarwrapper > ul > li > ul > li { - margin-bottom: 0.4em; -} - -div.sphinxsidebar a:hover { - color: #0095C4; -} - -div.sphinxsidebar input { - font-family: 'Lucida Grande',Arial,sans-serif; - border: 1px solid #999999; - font-size: smaller; - border-radius: 3px; -} - -div.sphinxsidebar input[type=text] { - max-width: 150px; -} - -div.body { - padding: 0 0 0 1.2em; -} - -div.body p { - line-height: 140%; -} - -div.body h1, div.body h2, div.body h3, div.body h4, div.body h5, div.body h6 { - margin: 0; - border: 0; - padding: 0.3em 0; -} - -div.body hr { - border: 0; - background-color: #ccc; - height: 1px; -} - -div.body pre { - border-radius: 3px; - border: 1px solid #ac9; -} - -div.body div.admonition, div.body div.impl-detail { - border-radius: 3px; -} - -div.body div.impl-detail > p { - margin: 0; -} - -div.body div.seealso { - border: 1px solid #dddd66; -} - -div.body a { - color: #00608f; -} - -div.body a:visited { - color: #30306f; -} - -div.body a:hover { - color: #00B0E4; -} - -tt, pre { - font-family: monospace, sans-serif; - font-size: 96.5%; -} - -div.body tt { - border-radius: 3px; -} - -div.body tt.descname { - font-size: 120%; -} - -div.body tt.xref, div.body a tt { - font-weight: normal; -} - -p.deprecated { - border-radius: 3px; -} - -table.docutils { - border: 1px solid #ddd; - min-width: 20%; - border-radius: 3px; - margin-top: 10px; - margin-bottom: 10px; -} - -table.docutils td, table.docutils th { - border: 1px solid #ddd !important; - border-radius: 3px; -} - -table p, table li { - text-align: left !important; -} - -table.docutils th { - background-color: #eee; - padding: 0.3em 0.5em; -} - -table.docutils td { - background-color: white; - padding: 0.3em 0.5em; -} - -table.footnote, table.footnote td { - border: 0 !important; -} - -div.footer { - line-height: 150%; - margin-top: -2em; - text-align: right; - width: auto; - margin-right: 10px; -} - -div.footer a:hover { - color: #0095C4; -} diff --git a/docs/pydoctheme/theme.conf b/docs/pydoctheme/theme.conf deleted file mode 100644 index 0c4388167..000000000 --- a/docs/pydoctheme/theme.conf +++ /dev/null @@ -1,23 +0,0 @@ -[theme] -inherit = default -stylesheet = pydoctheme.css -pygments_style = sphinx - -[options] -bodyfont = 'Lucida Grande', Arial, sans-serif -headfont = 'Lucida Grande', Arial, sans-serif -footerbgcolor = white -footertextcolor = #555555 -relbarbgcolor = white -relbartextcolor = #666666 -relbarlinkcolor = #444444 -sidebarbgcolor = white -sidebartextcolor = #444444 -sidebarlinkcolor = #444444 -bgcolor = white -textcolor = #222222 -linkcolor = #0090c0 -visitedlinkcolor = #00608f -headtextcolor = #1a1a1a -headbgcolor = white -headlinkcolor = #aaaaaa From 292cd17b6021e5239525303f3635af088c40e071 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 18:36:37 +0100 Subject: [PATCH 0131/1539] Next version will be 3.0. --- docs/changelog.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9561bf39a..7d3ee8b81 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,11 +1,13 @@ Changelog --------- -2.7 +3.0 ... * Added compatibility with Python 3.5. +* Refreshed documentation. + 2.6 ... From db9eb7033d81fba3c30711bf1c9d9635fc288a18 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 19:33:16 +0100 Subject: [PATCH 0132/1539] Add an HTML + JS example. Fix #74. --- docs/intro.rst | 29 +++++++++++++++++++++++++++++ example/client.py | 3 +++ example/server.py | 1 + example/time.html | 20 ++++++++++++++++++++ example/time.py | 20 ++++++++++++++++++++ 5 files changed, 73 insertions(+) create mode 100644 example/time.html create mode 100644 example/time.py diff --git a/docs/intro.rst b/docs/intro.rst index 3c96a446d..adfe2694a 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -1,6 +1,9 @@ Getting started =============== +Basic example +------------- + .. _server-example: Here's a WebSocket server example. It reads a name from the client and sends a @@ -18,11 +21,31 @@ On the server side, the handler coroutine ``hello`` is executed once for each WebSocket connection. The connection is automatically closed when the handler returns. +Browser-based example +--------------------- + +Here's an example of how to run a WebSocket server and connect from a browser. + +Run this script in a console: + +.. literalinclude:: ../example/time.py + +Then open this HTML file in a browser. + +.. literalinclude:: ../example/time.html + :language: html + +Common patterns +--------------- + You will almost always want to process several messages during the lifetime of a connection. Therefore you must write a loop. Here are the recommended patterns to exit cleanly when the connection drops, either because the other side closed it or for any other reason. +Consumer +........ + For receiving messages and passing them to a ``consumer`` coroutine:: @asyncio.coroutine @@ -38,6 +61,9 @@ when the connection is closed. In other words, ``None`` marks the end of the message stream. The handler coroutine should check for that case and return when it happens. +Producer +........ + For getting messages from a ``producer`` coroutine and sending them:: @asyncio.coroutine @@ -53,6 +79,9 @@ exception when it's called on a closed connection. Therefore the handler coroutine should check that the connection is still open before attempting to write and return otherwise. +Both +.... + Of course, you can combine the two patterns shown above to read and write messages on the same connection:: diff --git a/example/client.py b/example/client.py index a75992975..66ff2b4d2 100755 --- a/example/client.py +++ b/example/client.py @@ -6,11 +6,14 @@ @asyncio.coroutine def hello(): websocket = yield from websockets.connect('ws://localhost:8765/') + name = input("What's your name? ") yield from websocket.send(name) print("> {}".format(name)) + greeting = yield from websocket.recv() print("< {}".format(greeting)) + yield from websocket.close() asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/server.py b/example/server.py index dea1fd40c..7074c9ab5 100755 --- a/example/server.py +++ b/example/server.py @@ -8,6 +8,7 @@ def hello(websocket, path): name = yield from websocket.recv() print("< {}".format(name)) greeting = "Hello {}!".format(name) + yield from websocket.send(greeting) print("> {}".format(greeting)) diff --git a/example/time.html b/example/time.html new file mode 100644 index 000000000..721f44264 --- /dev/null +++ b/example/time.html @@ -0,0 +1,20 @@ + + + + WebSocket demo + + + + + diff --git a/example/time.py b/example/time.py new file mode 100644 index 000000000..3fa7ce966 --- /dev/null +++ b/example/time.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import random +import websockets + +@asyncio.coroutine +def time(websocket, path): + while True: + now = datetime.datetime.utcnow().isoformat() + 'Z' + if not websocket.open: + return + yield from websocket.send(now) + yield from asyncio.sleep(random.random() * 3) + +start_server = websockets.serve(time, '127.0.0.1', 5678) + +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() From a8b098d30386c7a05dff59baf04ba12370a841c3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Oct 2015 19:37:16 +0100 Subject: [PATCH 0133/1539] Document the lack of deployment best-practices. Refs #74. --- docs/deployment.rst | 12 ++++++++++++ docs/index.rst | 1 + 2 files changed, 13 insertions(+) create mode 100644 docs/deployment.rst diff --git a/docs/deployment.rst b/docs/deployment.rst new file mode 100644 index 000000000..e5d952538 --- /dev/null +++ b/docs/deployment.rst @@ -0,0 +1,12 @@ +Deployment +---------- + +The author of ``websockets`` isn't aware of best practices for deploying +network services based on :mod:`asyncio`. + +He suggests running a Python script similar to the :ref:`server example +`, perhaps inside a supervisor if you deem it useful. + +If you can share knowledge on this topic, please file an issue_. Thanks! + +.. _issue: https://github.com/aaugustin/websockets/issues/new diff --git a/docs/index.rst b/docs/index.rst index c0712592a..ffcd7afe1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -50,6 +50,7 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a intro cheatsheet api + deployment limitations changelog license From fd33d297d836bf0acf22b793450be0bd00e9e391 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 18 Nov 2015 19:26:41 +0100 Subject: [PATCH 0134/1539] Make tests for error conditions more robust. This will prevent the tests from locking on Python 3.6, assuming that the proposed fix for http://bugs.python.org/issue25593 is implemented. https://groups.google.com/d/msg/python-tulip/r8nN53Fq_x0/8Z7BCdgsAAAJ Thanks Guido for the patch. --- websockets/test_client_server.py | 7 +++---- websockets/test_protocol.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 81235f66b..9f67282bb 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -45,10 +45,9 @@ def tearDown(self): self.loop.close() def run_loop_once(self): - # Process callbacks scheduled with call_soon. This pattern works - # because stop schedules a callback to stop the event loop and - # run_forever runs the loop until it hits this callback. - self.loop.stop() + # Process callbacks scheduled with call_soon by appending a callback + # to stop the event loop then running it until it hits that callback. + self.loop.call_soon(self.loop.stop) self.loop.run_forever() def start_server(self, **kwds): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 414e38692..8ce81cd9c 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -63,10 +63,9 @@ def tearDown(self): super().tearDown() def run_loop_once(self): - # Process callbacks scheduled with call_soon. This pattern works - # because stop schedules a callback to stop the event loop and - # run_forever runs the loop until it hits this callback. - self.loop.stop() + # Process callbacks scheduled with call_soon by appending a callback + # to stop the event loop then running it until it hits that callback. + self.loop.call_soon(self.loop.stop) self.loop.run_forever() def make_drain_slow(self): From 2bd32e949e397f6c82d7de5d45363889bae65b9d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 18 Nov 2015 20:52:09 +0100 Subject: [PATCH 0135/1539] Bump version number. --- docs/changelog.rst | 5 +++++ docs/conf.py | 4 ++-- websockets/version.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 7d3ee8b81..5db2b6a44 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,11 @@ Changelog 3.0 ... +*In development* + +2.7 +... + * Added compatibility with Python 3.5. * Refreshed documentation. diff --git a/docs/conf.py b/docs/conf.py index 1a02b24eb..e8223a507 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '2.6' +version = '2.7' # The full version, including alpha/beta/rc tags. -release = '2.6' +release = '2.7' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 410fd1e00..328c2005d 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '2.6' +version = '2.7' From 6ec274636ec89453d04decc1c51be68bac6c3af9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Dec 2015 15:19:50 +0100 Subject: [PATCH 0136/1539] Document how to check that messages were received. --- websockets/protocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index b8ca6e3ea..513803f01 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -269,7 +269,8 @@ def ping(self, data=None): corresponding pong is received and which you may ignore if you don't want to wait. - A ping may serve as a keepalive. + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point, with ``yield from ws.ping()``. """ # Protect against duplicates if a payload is explicitly set. if data in self.pings: From 46741370aeaa44fd2957d19e02e5d66b71d35f38 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Dec 2015 12:15:08 +0100 Subject: [PATCH 0137/1539] Avoid busy loop during connection termination. Refs #84. Refs https://github.com/python/asyncio/pull/280. --- docs/changelog.rst | 2 ++ websockets/protocol.py | 18 ++++++++++ websockets/test_protocol.py | 70 ++++++++++++++++++++----------------- 3 files changed, 58 insertions(+), 32 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5db2b6a44..2276971fe 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,8 @@ Changelog *In development* +* Worked around an asyncio bug affecting connection termination under load. + 2.7 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index 513803f01..4acf27575 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -429,6 +429,24 @@ def write_frame(self, opcode, data=b''): logger.debug("%s >> %s", side, frame) is_masked = self.is_client write_frame(frame, self.writer.write, is_masked) + + # Backport of the combined logic of: + # https://github.com/python/asyncio/pull/280 + # https://github.com/python/asyncio/pull/291 + # Remove when dropping support for Python < 3.6. + transport = self.writer._transport + if transport is not None: # pragma: no cover + # PR 291 added the is_closing method to transports shortly after + # PR 280 fixed the bug we're trying to work around in this block. + if not hasattr(transport, 'is_closing'): + # This emulates what is_closing would return if it existed. + try: + is_closing = transport._closing + except AttributeError: + is_closing = transport._closed + if is_closing: + yield + try: # Handle flow control automatically. yield from self.writer.drain() diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 8ce81cd9c..2e0301b1b 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -40,9 +40,13 @@ class TransportMock(unittest.mock.Mock): def connect(self, loop, protocol): self.loop = loop self.protocol = protocol + # Remove when dropping support for Python < 3.6. + self._closing = False self.loop.call_soon(self.protocol.connection_made, self) def close(self): + # Remove when dropping support for Python < 3.6. + self._closing = True self.loop.call_soon(self.protocol.connection_lost, None) @@ -107,16 +111,26 @@ def receive_eof(self): connections anyway.) As a consequence, actual transports close themselves after calling it. - To emulate this behavior, tests must close the transport just after - calling the protocol's eof_received. Closing the transport will have + To emulate this behavior, this function closes the transport just + after calling the protocol's eof_received. Closing the transport has the side-effect calling the protocol's connection_lost. - - This method is often called shortly after simulating invalid data to - ensure that the connection fails quickly. """ self.loop.call_soon(self.protocol.eof_received) self.loop.call_soon(self.transport.close) + def process_invalid_frames(self): + """ + Make the protocol fail quickly after simulating invalid data. + + To achieve this, this function triggers the protocol's eof_received, + which interrupts pending reads waiting for more data. It delays this + operation with call_later because the protocol must start processing + frames first. Otherwise it will see a closed connection and no data. + """ + self.loop.call_later(MS, self.receive_eof) + next_message = self.loop.run_until_complete(self.protocol.recv()) + self.assertIsNone(next_message) + def process_control_frames(self): """ Process control frames received by the protocol. @@ -125,7 +139,8 @@ def process_control_frames(self): frame, which recv() will drop. """ self.receive_frame(Frame(True, OP_TEXT, b'')) - self.loop.run_until_complete(self.protocol.recv()) + next_message = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(next_message, '') def last_sent_frame(self): """ @@ -217,28 +232,24 @@ def test_recv_binary(self): def test_recv_protocol_error(self): self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8'))) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1002, '') def test_recv_unicode_error(self): self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1007, '') def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1009, '') def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1009, '') def test_recv_text_no_max_size(self): @@ -258,15 +269,13 @@ def test_recv_other_error(self): def read_message(): raise Exception("BOOM") self.protocol.read_message = read_message - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() with self.assertRaises(Exception): self.loop.run_until_complete(self.protocol.worker) self.assertConnectionClosed(1011, '') def test_recv_on_closed_connection(self): - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() def test_recv_cancelled(self): recv = self.async(self.protocol.recv()) @@ -293,9 +302,9 @@ def test_send_type_error(self): self.assertNoFrameSent() def test_send_on_closed_connection(self): - self.receive_eof() - # Ensure the protocol processes the connection termination. - self.loop.run_until_complete(self.protocol.recv()) + # This is a way to terminate the connection. + self.process_invalid_frames() + with self.assertRaises(InvalidState): self.loop.run_until_complete(self.protocol.send('foobar')) self.assertNoFrameSent() @@ -370,16 +379,14 @@ def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1009, '') def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1009, '') def test_fragmented_text_no_max_size(self): @@ -408,21 +415,18 @@ def test_unterminated_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b'tea')) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1002, '') def test_close_handshake_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.receive_frame(Frame(True, OP_CLOSE, b'')) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1005, '') def test_connection_close_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.receive_eof() - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + self.process_invalid_frames() self.assertConnectionClosed(1006, '') @@ -561,7 +565,9 @@ def test_client_close(self): def test_server_close(self): self.receive_frame(self.close_frame) - self.receive_eof() + # The client expects the server to close the connection. Simulate it + # to avoid having to wait for the connection timeout. + self.loop.call_later(MS, self.receive_eof) # The client is waiting for some data at this point but won't get it. next_message = self.loop.run_until_complete(self.protocol.recv()) From 29b8bc595c3e66738c8330abbea7c265e679fbbd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Dec 2015 14:14:18 +0100 Subject: [PATCH 0138/1539] Catch connection errors a bit more liberally. --- websockets/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 4acf27575..6316cbf35 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -450,7 +450,7 @@ def write_frame(self, opcode, data=b''): try: # Handle flow control automatically. yield from self.writer.drain() - except ConnectionResetError: + except ConnectionError: # Terminate the connection if the socket died. yield from self.fail_connection(1006) From f4959b8ad8bd2dfaf542cbc60dc0774a3d2f69af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Dec 2015 21:30:14 +0100 Subject: [PATCH 0139/1539] Document the registration pattern. --- docs/intro.rst | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/docs/intro.rst b/docs/intro.rst index adfe2694a..6e76968f1 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -83,7 +83,9 @@ Both .... Of course, you can combine the two patterns shown above to read and write -messages on the same connection:: +messages on the same connection. + +:: @asyncio.coroutine def handler(websocket, path): @@ -113,5 +115,33 @@ messages on the same connection:: (This code looks convoluted. If you know a more straightforward solution, please let me know about it!) +Registration +............ + +If you need to maintain a list of currently connected clients, you must +register clients when they connect and unregister them when they disconnect. + +:: + + connected = set() + + @asyncio.coroutine + def handler(websocket, path): + global connected + # Register. + connected.add(websocket) + try: + # Implement logic here. + yield from asyncio.wait( + [ws.send("Hello!") for ws in connected]) + yield from asyncio.sleep(10) + finally: + # Unregister. + connected.remove(websocket) + +This simplistic example keeps track of connected clients in memory. This only +works as long as you run a single process. In a practical application, the +handler may subscribe to some channels on a message broker, for example. + That's really all you have to know! ``websockets`` manages the connection under the hood so you don't have to. From 49658fc697bd67ce069d648439fa94a94cab45ef Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Dec 2015 21:38:24 +0100 Subject: [PATCH 0140/1539] Clarify the conclusion of the introduction. --- docs/intro.rst | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index 6e76968f1..4738c335c 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -143,5 +143,12 @@ This simplistic example keeps track of connected clients in memory. This only works as long as you run a single process. In a practical application, the handler may subscribe to some channels on a message broker, for example. -That's really all you have to know! ``websockets`` manages the connection -under the hood so you don't have to. +That's all! +----------- + +The design of the ``websockets`` API was driven by simplicity. + +You don't have to worry about performing the opening or the closing handshake, +answering pings, or any other behavior required by the specification. + +``websockets`` handles all this under the hood so you don't have to. From 96e620f275a748069950d761df8feb555d2d73a4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 14:01:42 +0100 Subject: [PATCH 0141/1539] Proof-read and improve documentation. --- docs/api.rst | 6 +++--- docs/changelog.rst | 27 ++++++++++++++++++++++----- docs/cheatsheet.rst | 4 ++-- websockets/client.py | 2 +- websockets/protocol.py | 12 ++++++------ 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index a36fdc255..61f82ba84 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,9 +4,9 @@ API Design ------ -``websockets`` provides complete client and server implementations, as shown in -the examples above. These functions are built on top of low-level APIs -reflecting the two phases of the WebSocket protocol: +``websockets`` provides complete client and server implementations, as shown +in the :doc:`getting started guide `. These functions are built on top +of low-level APIs reflecting the two phases of the WebSocket protocol: 1. An opening handshake, in the form of an HTTP Upgrade request; diff --git a/docs/changelog.rst b/docs/changelog.rst index 2276971fe..5212cf1b4 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,8 @@ Changelog * Worked around an asyncio bug affecting connection termination under load. +* Improved documentation. + 2.7 ... @@ -81,11 +83,26 @@ Changelog 2.0 ... -* Backwards-incompatible API change: - :meth:`~websockets.protocol.WebSocketCommonProtocol.send`, - :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` are coroutines. - They used to be regular functions. +.. warning:: + + **Version 2.0 introduces a backwards-incompatible change in the** + :meth:`~websockets.protocol.WebSocketCommonProtocol.send`, + :meth:`~websockets.protocol.WebSocketCommonProtocol.ping`, and + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` **APIs.** + + **If you're upgrading from 1.x or earlier, please read this carefully.** + + These APIs used to be functions. Now they're coroutines. + + Instead of:: + + websocket.send(message) + + you must now write:: + + yield from websocket.send(message) + +Also: * Added flow control. diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 75347fb98..33a2c0eee 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -20,7 +20,7 @@ Server * The server takes care of establishing connections, then lets the handler execute the application logic, and finally closes the connection after - the handler returns. + the handler exits normally or with an exception. * You may subclass :class:`~websockets.server.WebSocketServerProtocol` and pass it in the ``klass`` keyword argument for advanced customization. @@ -28,7 +28,7 @@ Server Client ------ -* Create a server with :func:`~websockets.client.connect` which is similar to +* Create a client with :func:`~websockets.client.connect` which is similar to asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. * You may subclass :class:`~websockets.server.WebSocketClientProtocol` and diff --git a/websockets/client.py b/websockets/client.py index d3677c321..b84323914 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -116,7 +116,7 @@ def connect(uri, *, * ``origin`` sets the Origin HTTP header * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference + decreasing preference * ``extra_headers`` sets additional HTTP request headers – it can be a mapping or an iterable of (name, value) pairs diff --git a/websockets/protocol.py b/websockets/protocol.py index 6316cbf35..16d1ef9e7 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -146,8 +146,8 @@ def local_address(self): """ Local address of the connection. - The address is a ``(host, port)`` tuple or ``None`` if the connection - hasn't been established yet. + This is a ``(host, port)`` tuple or ``None`` if the connection hasn't + been established yet. """ if self.writer is None: return None @@ -158,8 +158,8 @@ def remote_address(self): """ Remote address of the connection. - The address is a ``(host, port)`` tuple or ``None`` if the connection - hasn't been established yet. + This is a ``(host, port)`` tuple or ``None`` if the connection hasn't + been established yet. """ if self.writer is None: return None @@ -180,7 +180,7 @@ def close(self, code=1000, reason=''): This coroutine performs the closing handshake. It waits for the other end to complete the handshake. It doesn't do - anything once the connection is closed. + anything once the connection is closed. Thus it's idemptotent. It's safe to wrap this coroutine in :func:`~asyncio.ensure_future` since errors during connection termination aren't particularly useful. @@ -244,7 +244,7 @@ def send(self, data): """ This coroutine sends a message. - It sends a :class:`str` as a text frame and :class:`bytes` as a binary + It sends :class:`str` as a text frame and :class:`bytes` as a binary frame. It raises a :exc:`TypeError` for other inputs and From 345d9e7dee84856db9e124368e2b213ddbe9585e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 14:45:17 +0100 Subject: [PATCH 0142/1539] Support passing data as str in ping and pong. --- docs/changelog.rst | 4 +++ websockets/protocol.py | 20 +++++++++++ websockets/test_protocol.py | 71 +++++++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5212cf1b4..553ff6fd2 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,10 @@ Changelog *In development* +* :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` supports + data passed as :class:`str` in addition to :class:`bytes`. + * Worked around an asyncio bug affecting connection termination under load. * Improved documentation. diff --git a/websockets/protocol.py b/websockets/protocol.py index 16d1ef9e7..80c5181ef 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -271,7 +271,13 @@ def ping(self, data=None): A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point, with ``yield from ws.ping()``. + + By default, the ping contains four random bytes. The content may be + overridden with the optional ``data`` argument which must be of type + :class:`str` (which will be encoded to UTF-8) or :class:`bytes`. """ + if data is not None: + data = self.encode_data(data) # Protect against duplicates if a payload is explicitly set. if data in self.pings: raise ValueError("Already waiting for a pong with the same data") @@ -289,11 +295,25 @@ def pong(self, data=b''): This coroutine sends a pong. An unsolicited pong may serve as a unidirectional heartbeat. + + The content may be overridden with the optional ``data`` argument + which must be of type :class:`str` (which will be encoded to UTF-8) or + :class:`bytes`. """ + data = self.encode_data(data) yield from self.write_frame(OP_PONG, data) # Private methods - no guarantees. + def encode_data(self, data): + # Expect str or bytes, return bytes. + if isinstance(data, str): + return data.encode('utf-8') + elif isinstance(data, bytes): + return data + else: + raise TypeError("data must be bytes or str") + @asyncio.coroutine def run(self): # This coroutine guarantees that the connection is closed at exit. diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 2e0301b1b..b45fb2a20 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -220,6 +220,8 @@ def test_connection_lost(self): self.protocol.connection_lost(None) self.assertConnectionClosed(1006, '') + # Test the recv coroutine. + def test_recv_text(self): self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) data = self.loop.run_until_complete(self.protocol.recv()) @@ -288,6 +290,8 @@ def test_recv_cancelled(self): data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café') + # Test the send coroutine. + def test_send_text(self): self.loop.run_until_complete(self.protocol.send('café')) self.assertOneFrameSent(True, OP_TEXT, 'café'.encode('utf-8')) @@ -309,6 +313,67 @@ def test_send_on_closed_connection(self): self.loop.run_until_complete(self.protocol.send('foobar')) self.assertNoFrameSent() + # Test the ping coroutine. + + def test_ping_default(self): + self.loop.run_until_complete(self.protocol.ping()) + # With our testing tools, it's more convenient to extract the expected + # ping data from the library's internals than from the frame sent. + ping_data = next(iter(self.protocol.pings)) + self.assertIsInstance(ping_data, bytes) + self.assertEqual(len(ping_data), 4) + self.assertOneFrameSent(True, OP_PING, ping_data) + + def test_ping_text(self): + self.loop.run_until_complete(self.protocol.ping('café')) + self.assertOneFrameSent(True, OP_PING, 'café'.encode('utf-8')) + + def test_ping_binary(self): + self.loop.run_until_complete(self.protocol.ping(b'tea')) + self.assertOneFrameSent(True, OP_PING, b'tea') + + def test_ping_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.ping(42)) + self.assertNoFrameSent() + + def test_ping_on_closed_connection(self): + # This is a way to terminate the connection. + self.process_invalid_frames() + + with self.assertRaises(InvalidState): + self.loop.run_until_complete(self.protocol.ping()) + self.assertNoFrameSent() + + # Test the pong coroutine. + + def test_pong_default(self): + self.loop.run_until_complete(self.protocol.pong()) + self.assertOneFrameSent(True, OP_PONG, b'') + + def test_pong_text(self): + self.loop.run_until_complete(self.protocol.pong('café')) + self.assertOneFrameSent(True, OP_PONG, 'café'.encode('utf-8')) + + def test_pong_binary(self): + self.loop.run_until_complete(self.protocol.pong(b'tea')) + self.assertOneFrameSent(True, OP_PONG, b'tea') + + def test_pong_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.pong(42)) + self.assertNoFrameSent() + + def test_pong_on_closed_connection(self): + # This is a way to terminate the connection. + self.process_invalid_frames() + + with self.assertRaises(InvalidState): + self.loop.run_until_complete(self.protocol.pong()) + self.assertNoFrameSent() + + # Test the protocol's logic for acknowledging pings with pongs. + def test_answer_ping(self): self.receive_frame(Frame(True, OP_PING, b'test')) self.process_control_frames() @@ -362,6 +427,8 @@ def test_duplicate_ping(self): self.loop.run_until_complete(self.protocol.ping(b'foobar')) self.assertNoFrameSent() + # Test the protocol's logic for rebuilding fragmented messages. + def test_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8'))) @@ -432,6 +499,8 @@ def test_connection_close_in_fragmented_text(self): class ServerCloseTests(CommonTests, unittest.TestCase): + # Test the protocol logic for closing the connection on the server side. + def test_server_close(self): self.receive_frame(self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='close')) @@ -549,6 +618,8 @@ def setUp(self): super().setUp() self.protocol.is_client = True + # Test the protocol logic for closing the connection on the client side. + def test_client_close(self): self.receive_frame(self.close_frame) self.receive_eof() From 188ec3276e110106dc8e2bbd6771538efbc8c877 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 14:53:02 +0100 Subject: [PATCH 0143/1539] Improve some old tests a bit. --- websockets/test_protocol.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index b45fb2a20..956255ee8 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -213,11 +213,16 @@ def test_remote_address(self): def test_open(self): self.assertTrue(self.protocol.open) - self.protocol.connection_lost(None) + + # This is a way to terminate the connection. + self.process_invalid_frames() + self.assertFalse(self.protocol.open) def test_connection_lost(self): + # Test calling connection_lost without going through close_connection. self.protocol.connection_lost(None) + self.assertConnectionClosed(1006, '') # Test the recv coroutine. From 34f9c3e7b2ce7f344c221baa47a98b59e604d0b4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 19:27:50 +0100 Subject: [PATCH 0144/1539] Make state_name a public API and test it. --- docs/api.rst | 11 +++++++---- docs/changelog.rst | 2 ++ websockets/protocol.py | 16 ++++++++++++---- websockets/test_protocol.py | 8 ++++++++ 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 61f82ba84..0ead249af 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -57,10 +57,6 @@ Shared .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) - .. autoattribute:: local_address - .. autoattribute:: remote_address - - .. autoattribute:: open .. automethod:: close(code=1000, reason='') .. automethod:: recv() @@ -69,6 +65,13 @@ Shared .. automethod:: ping(data=None) .. automethod:: pong(data=b'') + .. autoattribute:: local_address + .. autoattribute:: remote_address + + .. autoattribute:: open + .. autoattribute:: state_name + + Exceptions .......... diff --git a/docs/changelog.rst b/docs/changelog.rst index 553ff6fd2..767b0711f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -12,6 +12,8 @@ Changelog * Worked around an asyncio bug affecting connection termination under load. +* Made ``state_name`` atttribute on protocols a public API. + * Improved documentation. 2.7 diff --git a/websockets/protocol.py b/websockets/protocol.py index 80c5181ef..094117a4e 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -135,10 +135,6 @@ def __init__(self, *, if self.state == OPEN: self.opening_handshake.set_result(True) - @property - def state_name(self): - return ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED'][self.state] - # Public API @property @@ -174,6 +170,18 @@ def open(self): """ return self.state == OPEN + @property + def state_name(self): + """ + Current connection state, as a string. + + Possible states are defined in the WebSocket specification: + CONNECTING, OPEN, CLOSING, or CLOSED. + + To check if the connection is open, use :attr:`open` instead. + """ + return ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED'][self.state] + @asyncio.coroutine def close(self, code=1000, reason=''): """ diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 956255ee8..e508225a2 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -219,6 +219,14 @@ def test_open(self): self.assertFalse(self.protocol.open) + def test_state_name(self): + self.assertEqual(self.protocol.state_name, 'OPEN') + + # This is a way to terminate the connection. + self.process_invalid_frames() + + self.assertEqual(self.protocol.state_name, 'CLOSED') + def test_connection_lost(self): # Test calling connection_lost without going through close_connection. self.protocol.connection_lost(None) From 3a04aa2f2b363d6fb8c4874acc64446b2495a64e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 14:17:00 +0100 Subject: [PATCH 0145/1539] Check the connection state in send, ping and pong. This doesn't change the behavior significantly. It raises a slightly more explicit ConnectionClosed error that carries the close code and reason instead of the technical InvalidState (which it subclasses to preserve API compatibility). --- websockets/exceptions.py | 12 +++++- websockets/protocol.py | 47 ++++++++++++++++++---- websockets/test_protocol.py | 77 +++++++++++++++++++++++++++++++------ 3 files changed, 116 insertions(+), 20 deletions(-) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 258780eec..76bb22448 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,6 +1,6 @@ __all__ = [ 'InvalidHandshake', 'InvalidOrigin', 'InvalidState', 'InvalidURI', - 'PayloadTooBig', 'WebSocketProtocolError', + 'ConnectionClosed', 'PayloadTooBig', 'WebSocketProtocolError', ] @@ -16,6 +16,16 @@ class InvalidState(Exception): """Exception raised when an operation is forbidden in the current state.""" +class ConnectionClosed(InvalidState): + """Exception raised when trying to read or write on a closed connection.""" + + def __init__(self, code, reason): + self.code = code + self.reason = reason + super().__init__('WebSocket connection is closed: ' + 'code = {}, reason = {}'.format(code, reason)) + + class InvalidURI(Exception): """Exception raised when an URI isn't a valid websocket URI.""" diff --git a/websockets/protocol.py b/websockets/protocol.py index 094117a4e..e12985960 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -16,7 +16,8 @@ import struct from .compatibility import asyncio_ensure_future -from .exceptions import InvalidState, PayloadTooBig, WebSocketProtocolError +from .exceptions import (ConnectionClosed, InvalidState, PayloadTooBig, + WebSocketProtocolError) from .framing import * from .handshake import * @@ -253,12 +254,10 @@ def send(self, data): This coroutine sends a message. It sends :class:`str` as a text frame and :class:`bytes` as a binary - frame. - - It raises a :exc:`TypeError` for other inputs and - :exc:`~websockets.exceptions.InvalidState` once the connection is - closed. + frame. It raises a :exc:`TypeError` for other inputs. """ + yield from self.ensure_open() + if isinstance(data, str): opcode = 1 data = data.encode('utf-8') @@ -266,6 +265,7 @@ def send(self, data): opcode = 2 else: raise TypeError("data must be bytes or str") + yield from self.write_frame(opcode, data) @asyncio.coroutine @@ -284,11 +284,15 @@ def ping(self, data=None): overridden with the optional ``data`` argument which must be of type :class:`str` (which will be encoded to UTF-8) or :class:`bytes`. """ + yield from self.ensure_open() + if data is not None: data = self.encode_data(data) + # Protect against duplicates if a payload is explicitly set. if data in self.pings: raise ValueError("Already waiting for a pong with the same data") + # Generate a unique random payload otherwise. while data is None or data in self.pings: data = struct.pack('!I', random.getrandbits(32)) @@ -308,7 +312,10 @@ def pong(self, data=b''): which must be of type :class:`str` (which will be encoded to UTF-8) or :class:`bytes`. """ + yield from self.ensure_open() + data = self.encode_data(data) + yield from self.write_frame(OP_PONG, data) # Private methods - no guarantees. @@ -322,6 +329,29 @@ def encode_data(self, data): else: raise TypeError("data must be bytes or str") + @asyncio.coroutine + def ensure_open(self): + # Raise a suitable exception if the connection isn't open. + # Handle cases from the most common to the least common. + + if self.state == OPEN: + return + + if self.state == CLOSED: + raise ConnectionClosed(self.close_code, self.close_reason) + + # If the closing handshake is in progress, let it complete to get the + # proper close status and code. As an safety measure, the timeout is + # longer than the worst case (2 * self.timeout) but not unlimited. + if self.state == CLOSING: + yield from asyncio.wait_for( + self.connection_closed, 3 * self.timeout, loop=self.loop) + raise ConnectionClosed(self.close_code, self.close_reason) + + # Control may only reach this point in buggy third-party subclasses. + assert self.state == CONNECTING + raise InvalidState("WebSocket connection isn't established yet.") + @asyncio.coroutine def run(self): # This coroutine guarantees that the connection is closed at exit. @@ -444,10 +474,11 @@ def read_frame(self, max_size): @asyncio.coroutine def write_frame(self, opcode, data=b''): - # This may happen if a user attempts to write on a closed connection. - if self.state != OPEN: + # Defensive assertion for protocol compliance. + if self.state != OPEN: # pragma: no cover raise InvalidState("Cannot write to a WebSocket " "in the {} state".format(self.state_name)) + # Make sure no other frame will be sent after a close frame. Do this # before yielding control to avoid sending more than one close frame. if opcode == OP_CLOSE: diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index e508225a2..26d61d2e6 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -7,9 +7,9 @@ import unittest.mock from .compatibility import asyncio_ensure_future -from .exceptions import InvalidState +from .exceptions import ConnectionClosed, InvalidState from .framing import * -from .protocol import CLOSED, WebSocketCommonProtocol +from .protocol import CLOSED, CONNECTING, WebSocketCommonProtocol # Unit for timeouts. May be increased on slow machines by setting the @@ -191,6 +191,8 @@ def assertCompletesWithin(self, min_time, max_time): self.assertLess( dt, max_time, "Too slow: {} >= {}".format(dt, max_time)) + # Test public attributes. + def test_local_address(self): get_extra_info = unittest.mock.Mock(return_value=('host', 4312)) self.transport.get_extra_info = get_extra_info @@ -227,12 +229,6 @@ def test_state_name(self): self.assertEqual(self.protocol.state_name, 'CLOSED') - def test_connection_lost(self): - # Test calling connection_lost without going through close_connection. - self.protocol.connection_lost(None) - - self.assertConnectionClosed(1006, '') - # Test the recv coroutine. def test_recv_text(self): @@ -318,11 +314,26 @@ def test_send_type_error(self): self.loop.run_until_complete(self.protocol.send(42)) self.assertNoFrameSent() + def test_send_on_closing_connection(self): + # This is a way to start a closing handshake. + self.async(self.protocol.close()) + self.run_loop_once() + self.assertOneFrameSent(True, OP_CLOSE, b'\x03\xe8') + + # Complete the closing handshake while running the send. + self.receive_frame(self.close_frame) + if self.protocol.is_client: + self.receive_eof() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.send('foobar')) + self.assertNoFrameSent() + def test_send_on_closed_connection(self): # This is a way to terminate the connection. self.process_invalid_frames() - with self.assertRaises(InvalidState): + with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.send('foobar')) self.assertNoFrameSent() @@ -350,11 +361,26 @@ def test_ping_type_error(self): self.loop.run_until_complete(self.protocol.ping(42)) self.assertNoFrameSent() + def test_ping_on_closing_connection(self): + # This is a way to start a closing handshake. + self.async(self.protocol.close()) + self.run_loop_once() + self.assertOneFrameSent(True, OP_CLOSE, b'\x03\xe8') + + # Complete the closing handshake while running the ping. + self.receive_frame(self.close_frame) + if self.protocol.is_client: + self.receive_eof() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.ping()) + self.assertNoFrameSent() + def test_ping_on_closed_connection(self): # This is a way to terminate the connection. self.process_invalid_frames() - with self.assertRaises(InvalidState): + with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ping()) self.assertNoFrameSent() @@ -377,11 +403,26 @@ def test_pong_type_error(self): self.loop.run_until_complete(self.protocol.pong(42)) self.assertNoFrameSent() + def test_pong_on_closing_connection(self): + # This is a way to start a closing handshake. + self.async(self.protocol.close()) + self.run_loop_once() + self.assertOneFrameSent(True, OP_CLOSE, b'\x03\xe8') + + # Complete the closing handshake while running the pong. + self.receive_frame(self.close_frame) + if self.protocol.is_client: + self.receive_eof() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.pong()) + self.assertNoFrameSent() + def test_pong_on_closed_connection(self): # This is a way to terminate the connection. self.process_invalid_frames() - with self.assertRaises(InvalidState): + with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.pong()) self.assertNoFrameSent() @@ -509,6 +550,20 @@ def test_connection_close_in_fragmented_text(self): self.process_invalid_frames() self.assertConnectionClosed(1006, '') + # Test miscellaneous code paths to ensure full coverage. + + def test_connection_lost(self): + # Test calling connection_lost without going through close_connection. + self.protocol.connection_lost(None) + + self.assertConnectionClosed(1006, '') + + def test_ensure_connection_before_opening_handshake(self): + self.protocol.state = CONNECTING + + with self.assertRaises(InvalidState): + self.loop.run_until_complete(self.protocol.ensure_open()) + class ServerCloseTests(CommonTests, unittest.TestCase): From 52e7750a8ddf1d7674fa27ad2f46fd2a448ce0cb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 21:26:03 +0100 Subject: [PATCH 0146/1539] Raise an exception in recv instead of returning None. This makes the API more Pythonic. Fix #77. --- docs/changelog.rst | 33 +++++++++++++++ docs/intro.rst | 29 ++++--------- example/time.py | 2 - websockets/client.py | 5 ++- websockets/protocol.py | 36 +++++++++++++--- websockets/server.py | 4 +- websockets/test_client_server.py | 6 +-- websockets/test_protocol.py | 73 +++++++++++++++++++++++--------- 8 files changed, 134 insertions(+), 54 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 767b0711f..540772da1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,39 @@ Changelog :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` supports data passed as :class:`str` in addition to :class:`bytes`. +.. warning:: + + **Version 3.0 introduces a backwards-incompatible change in the** + :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` **API.** + + **If you're upgrading from 2.x or earlier, please read this carefully.** + + :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` used to return + ``None`` when the connection was closed. This required checking the return + value of every call:: + + message = yield from websocket.recv() + if message is None: + return + + Now it raises a :exc:`~websockets.exceptions.ConnectionClosed` exception + instead. This is more Pythonic. The previous code can be simplified to:: + + message = yield from websocket.recv() + + When implementing a server, which is the more popular use case, there's no + strong reason to handle such exceptions. Let them bubble up, terminate the + handler coroutine, and the server will simply ignore them. + + In order to avoid stranding projects built upon an earlier version, the + previous behavior can be restored by passing ``legacy_recv=True`` to + :func:`~websockets.server.serve`, :func:`~websockets.client.connect`, + :class:`~websockets.server.WebSocketServerProtocol`, or + :class:`~websockets.client.WebSocketClientProtocol`. ``legacy_recv`` isn't + documented in their signatures but isn't scheduled for deprecation either. + +Also: + * Worked around an asyncio bug affecting connection termination under load. * Made ``state_name`` atttribute on protocols a public API. diff --git a/docs/intro.rst b/docs/intro.rst index 4738c335c..a48bda71d 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -38,10 +38,9 @@ Then open this HTML file in a browser. Common patterns --------------- -You will almost always want to process several messages during the -lifetime of a connection. Therefore you must write a loop. Here are the -recommended patterns to exit cleanly when the connection drops, either -because the other side closed it or for any other reason. +You will usually want to process several messages during the lifetime of a +connection. Therefore you must write a loop. Here are the basic patterns for +building a WebSocket server. Consumer ........ @@ -52,14 +51,11 @@ For receiving messages and passing them to a ``consumer`` coroutine:: def handler(websocket, path): while True: message = yield from websocket.recv() - if message is None: - break yield from consumer(message) -:meth:`~websockets.protocol.WebSocketCommonProtocol.recv` returns ``None`` -when the connection is closed. In other words, ``None`` marks the end of -the message stream. The handler coroutine should check for that case and -return when it happens. +:meth:`~websockets.protocol.WebSocketCommonProtocol.recv` raises a +:exc:`~websockets.exceptions.ConnectionClosed` exception when the client +disconnects, which breaks out of the ``while True`` loop. Producer ........ @@ -70,14 +66,11 @@ For getting messages from a ``producer`` coroutine and sending them:: def handler(websocket, path): while True: message = yield from producer() - if not websocket.open: - break yield from websocket.send(message) -:meth:`~websockets.protocol.WebSocketCommonProtocol.send` fails with an -exception when it's called on a closed connection. Therefore the handler -coroutine should check that the connection is still open before attempting -to write and return otherwise. +:meth:`~websockets.protocol.WebSocketCommonProtocol.send` raises a +:exc:`~websockets.exceptions.ConnectionClosed` exception when the client +disconnects, which breaks out of the ``while True`` loop. Both .... @@ -98,16 +91,12 @@ messages on the same connection. if listener_task in done: message = listener_task.result() - if message is None: - break yield from consumer(message) else: listener_task.cancel() if producer_task in done: message = producer_task.result() - if not websocket.open: - break yield from websocket.send(message) else: producer_task.cancel() diff --git a/example/time.py b/example/time.py index 3fa7ce966..7374bb8ad 100644 --- a/example/time.py +++ b/example/time.py @@ -9,8 +9,6 @@ def time(websocket, path): while True: now = datetime.datetime.utcnow().isoformat() + 'Z' - if not websocket.open: - return yield from websocket.send(now) yield from asyncio.sleep(random.random() * 3) diff --git a/websockets/client.py b/websockets/client.py index b84323914..90a37ed69 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -98,7 +98,7 @@ def handshake(self, wsuri, @asyncio.coroutine def connect(uri, *, - loop=None, klass=WebSocketClientProtocol, + loop=None, klass=WebSocketClientProtocol, legacy_recv=False, origin=None, subprotocols=None, extra_headers=None, **kwds): """ @@ -136,7 +136,8 @@ def connect(uri, *, raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") factory = lambda: klass( - host=wsuri.host, port=wsuri.port, secure=wsuri.secure, loop=loop) + host=wsuri.host, port=wsuri.port, secure=wsuri.secure, + loop=loop, legacy_recv=legacy_recv) transport, protocol = yield from loop.create_connection( factory, wsuri.host, wsuri.port, **kwds) diff --git a/websockets/protocol.py b/websockets/protocol.py index e12985960..71e74861b 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -58,7 +58,8 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1MB. ``None`` disables the limit. If a message larger than the maximum size is received, :meth:`recv()` will - return ``None`` and the connection will be closed with status code 1009. + raise :exc:`~websockets.exceptions.ConnectionClosed` and the connection + will be closed with status code 1009. Once the handshake is complete, request and response HTTP headers are available: @@ -85,7 +86,8 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): def __init__(self, *, host=None, port=None, secure=None, - timeout=10, max_size=2 ** 20, loop=None): + timeout=10, max_size=2 ** 20, loop=None, + legacy_recv=False): self.host = host self.port = port self.secure = secure @@ -96,6 +98,8 @@ def __init__(self, *, # attribute of StreamReaderProtocol, inherited from FlowControlMixin. self.loop = loop + self.legacy_recv = legacy_recv + stream_reader = asyncio.StreamReader(loop=loop) super().__init__(stream_reader, self.client_connected, loop) @@ -167,7 +171,11 @@ def open(self): """ This property is ``True`` when the connection is usable. - It may be used to handle disconnections gracefully. + It may be used to detect disconnections but this is discouraged per + the EAFP_ principle. When ``open`` is ``False``, using the connection + raises a :exc:`~websockets.exceptions.ConnectionClosed` exception. + + .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp """ return self.state == OPEN @@ -221,16 +229,27 @@ def recv(self): It returns a :class:`str` for a text frame and :class:`bytes` for a binary frame. - When the end of the message stream is reached, or when a protocol - error occurs, :meth:`recv` returns ``None``, indicating that the - connection is closed. + When the end of the message stream is reached, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. This can happen after + a normal connection closure, a protocol error or a network failure. + + .. versionchanged:: 3.0 + + :meth:`recv` used to return ``None`` instead. Refer to the + changelog for details. """ + # Don't yield from self.ensure_open() here because messages could be + # available in the queue even if the connection is closed. + # Return any available message try: return self.messages.get_nowait() except asyncio.queues.QueueEmpty: pass + # Don't yield from self.ensure_open() here because messages could be + # received before the closing frame even if the connection is closing. + # Wait for a message until the connection is closed next_message = asyncio_ensure_future( self.messages.get(), loop=self.loop) @@ -243,10 +262,15 @@ def recv(self): next_message.cancel() raise + # Now there's no need to yield from self.ensure_open(). Either a + # message was received or the connection was closed. + if next_message in done: return next_message.result() else: next_message.cancel() + if not self.legacy_recv: + raise ConnectionClosed(self.close_code, self.close_reason) @asyncio.coroutine def send(self, data): diff --git a/websockets/server.py b/websockets/server.py index 1953848a0..0fe9a729d 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -239,7 +239,7 @@ def wait_closed(self): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - loop=None, klass=WebSocketServerProtocol, + loop=None, klass=WebSocketServerProtocol, legacy_recv=False, origins=None, subprotocols=None, extra_headers=None, **kwds): """ @@ -295,7 +295,7 @@ def serve(ws_handler, host=None, port=None, *, ws_handler, ws_server, host=host, port=port, secure=secure, origins=origins, subprotocols=subprotocols, - extra_headers=extra_headers, loop=loop) + extra_headers=extra_headers, loop=loop, legacy_recv=legacy_recv) server = yield from loop.create_server(factory, host, port, **kwds) ws_server.wrap(server) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 9f67282bb..8ce75703a 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -6,7 +6,7 @@ import unittest.mock from .client import * -from .exceptions import InvalidHandshake +from .exceptions import ConnectionClosed, InvalidHandshake from .http import USER_AGENT, read_response from .server import * @@ -299,8 +299,8 @@ def test_server_handler_crashes(self, send): self.start_server() self.start_client() self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, None) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) self.stop_client() self.stop_server() diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 26d61d2e6..eb231dc45 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -128,8 +128,8 @@ def process_invalid_frames(self): frames first. Otherwise it will see a closed connection and no data. """ self.loop.call_later(MS, self.receive_eof) - next_message = self.loop.run_until_complete(self.protocol.recv()) - self.assertIsNone(next_message) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) def process_control_frames(self): """ @@ -241,6 +241,27 @@ def test_recv_binary(self): data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea') + def test_recv_on_closing_connection(self): + # This is a way to start a closing handshake. + self.async(self.protocol.close()) + self.run_loop_once() + self.assertOneFrameSent(True, OP_CLOSE, b'\x03\xe8') + + # Complete the closing handshake while running the recv. + self.receive_frame(self.close_frame) + if self.protocol.is_client: + self.receive_eof() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + + def test_recv_on_closed_connection(self): + # This is a way to terminate the connection. + self.process_invalid_frames() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + def test_recv_protocol_error(self): self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8'))) self.process_invalid_frames() @@ -285,9 +306,6 @@ def read_message(): self.loop.run_until_complete(self.protocol.worker) self.assertConnectionClosed(1011, '') - def test_recv_on_closed_connection(self): - self.process_invalid_frames() - def test_recv_cancelled(self): recv = self.async(self.protocol.recv()) self.loop.call_soon(recv.cancel) @@ -564,6 +582,19 @@ def test_ensure_connection_before_opening_handshake(self): with self.assertRaises(InvalidState): self.loop.run_until_complete(self.protocol.ensure_open()) + def test_legacy_recv(self): + # By default legacy_recv in disabled. + self.assertEqual(self.protocol.legacy_recv, False) + + # This is a way to terminate the connection. + self.process_invalid_frames() + + # Enable legacy_recv. + self.protocol.legacy_recv = True + + # Now recv() returns None instead of raising ConnectionClosed. + self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + class ServerCloseTests(CommonTests, unittest.TestCase): @@ -584,11 +615,12 @@ def test_server_close(self): def test_client_close(self): self.receive_frame(self.close_frame) + # The server is waiting for some data at this point but won't get it. - next_message = self.loop.run_until_complete(self.protocol.recv()) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) - self.assertIsNone(next_message) - # After recv() returns None, the connection is closed. + # After recv() raises ConnectionClosed, the connection is closed. self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) @@ -633,9 +665,10 @@ def test_client_close_race_with_failing_connection(self): # Fail the connection while answering a close frame from the client. self.loop.call_soon(self.receive_frame, self.client_close) self.loop.call_later(MS, self.async, self.protocol.fail_connection()) - next_message = self.loop.run_until_complete(self.protocol.recv()) - self.assertIsNone(next_message) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + # The closing handshake was completed by fail_connection. self.assertConnectionClosed(1011, '') self.assertOneFrameSent(*self.client_close) @@ -659,8 +692,8 @@ def test_close_during_recv(self): self.loop.run_until_complete(self.protocol.close(reason='close')) # Receiving a message shouldn't crash. - next_message = self.loop.run_until_complete(recv) - self.assertIsNone(next_message) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(recv) self.assertConnectionClosed(1000, 'close') @@ -707,11 +740,12 @@ def test_server_close(self): # The client expects the server to close the connection. Simulate it # to avoid having to wait for the connection timeout. self.loop.call_later(MS, self.receive_eof) + # The client is waiting for some data at this point but won't get it. - next_message = self.loop.run_until_complete(self.protocol.recv()) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) - self.assertIsNone(next_message) - # After recv() returns None, the connection is closed. + # After recv() raises ConnectionClosed, the connection is closed. self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) @@ -772,9 +806,10 @@ def test_server_close_race_with_failing_connection(self): self.loop.call_soon(self.receive_frame, self.server_close) self.loop.call_later(MS, self.async, self.protocol.fail_connection()) self.loop.call_later(2 * MS, self.receive_eof) - next_message = self.loop.run_until_complete(self.protocol.recv()) - self.assertIsNone(next_message) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + # The closing handshake was completed by fail_connection. self.assertConnectionClosed(1011, '') self.assertOneFrameSent(*self.server_close) @@ -800,8 +835,8 @@ def test_close_during_recv(self): self.loop.run_until_complete(self.protocol.close(reason='close')) # Receiving a message shouldn't crash. - next_message = self.loop.run_until_complete(recv) - self.assertIsNone(next_message) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(recv) self.assertConnectionClosed(1000, 'close') From 5e7849004265a68def05d1cee088d14d23e34072 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 21:37:40 +0100 Subject: [PATCH 0147/1539] Improve and normalize docstring style. --- compliance/test_client.py | 5 +++-- compliance/test_server.py | 5 +++-- websockets/client.py | 5 ++++- websockets/exceptions.py | 35 ++++++++++++++++++++++++++++------- websockets/framing.py | 7 +++++++ websockets/handshake.py | 5 +++++ websockets/http.py | 5 +++++ websockets/protocol.py | 12 +++++++++++- websockets/server.py | 11 +++++++++-- websockets/test_handshake.py | 2 ++ websockets/test_protocol.py | 6 ++++++ websockets/uri.py | 3 +++ 12 files changed, 86 insertions(+), 15 deletions(-) diff --git a/compliance/test_client.py b/compliance/test_client.py index ed3915364..3804e9f92 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -14,9 +14,10 @@ class EchoClientProtocol(websockets.WebSocketClientProtocol): + """ + WebSocket client protocol that echoes messages synchronously. - """WebSocket client protocol that echoes messages synchronously.""" - + """ def __init__(self, *args, **kwargs): kwargs['max_size'] = 2 ** 25 super().__init__(*args, **kwargs) diff --git a/compliance/test_server.py b/compliance/test_server.py index 85b059011..46e48128f 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -9,9 +9,10 @@ class EchoServerProtocol(websockets.WebSocketServerProtocol): + """ + WebSocket server protocol that echoes messages synchronously. - """WebSocket server protocol that echoes messages synchronously.""" - + """ def __init__(self, *args, **kwargs): kwargs['max_size'] = 2 ** 25 super().__init__(*args, **kwargs) diff --git a/websockets/client.py b/websockets/client.py index 90a37ed69..958825cde 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -1,5 +1,6 @@ """ The :mod:`websockets.client` module defines a simple WebSocket client API. + """ __all__ = ['connect', 'WebSocketClientProtocol'] @@ -21,8 +22,8 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. - """ + """ is_client = True state = CONNECTING @@ -39,6 +40,7 @@ def handshake(self, wsuri, If provided, ``extra_headers`` sets additional HTTP request headers. It must be a mapping or an iterable of (name, value) pairs. + """ headers = [] set_header = lambda k, v: headers.append((k, v)) @@ -125,6 +127,7 @@ def connect(uri, *, It raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the handshake fails. + """ if loop is None: loop = asyncio.get_event_loop() diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 76bb22448..c06280f6a 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -5,19 +5,31 @@ class InvalidHandshake(Exception): - """Exception raised when a handshake request or response is invalid.""" + """ + Exception raised when a handshake request or response is invalid. + + """ class InvalidOrigin(InvalidHandshake): - """Exception raised when the origin in a handshake request is forbidden.""" + """ + Exception raised when the origin in a handshake request is forbidden. + + """ class InvalidState(Exception): - """Exception raised when an operation is forbidden in the current state.""" + """ + Exception raised when an operation is forbidden in the current state. + + """ class ConnectionClosed(InvalidState): - """Exception raised when trying to read or write on a closed connection.""" + """ + Exception raised when trying to read or write on a closed connection. + + """ def __init__(self, code, reason): self.code = code @@ -27,12 +39,21 @@ def __init__(self, code, reason): class InvalidURI(Exception): - """Exception raised when an URI isn't a valid websocket URI.""" + """ + Exception raised when an URI isn't a valid websocket URI. + + """ class PayloadTooBig(Exception): - """Exception raised when a frame's payload exceeds the maximum size.""" + """ + Exception raised when a frame's payload exceeds the maximum size. + + """ class WebSocketProtocolError(Exception): - """Internal exception raised when the remote side breaks the protocol.""" + """ + Internal exception raised when the remote side breaks the protocol. + + """ diff --git a/websockets/framing.py b/websockets/framing.py index 7714a8671..544ac27ec 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -6,6 +6,7 @@ of frames is implemented in :mod:`websockets.protocol`. .. _section 5 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-5 + """ import asyncio @@ -52,6 +53,7 @@ Only these three fields are needed by higher level code. The MASK bit, payload length and masking-key are handled on the fly by :func:`read_frame` and :func:`write_frame`. + """ @@ -72,6 +74,7 @@ def read_frame(reader, mask, *, max_size=None): This function validates the frame before returning it and raises :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains incorrect values. + """ # Read the header data = yield from reader(2) @@ -119,6 +122,7 @@ def write_frame(frame, writer, mask): This function validates the frame before sending it and raises :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains incorrect values. + """ check_frame(frame) output = io.BytesIO() @@ -153,6 +157,7 @@ def check_frame(frame): """ Raise :exc:`~websockets.exceptions.WebSocketProtocolError` if the frame contains incorrect values. + """ if frame.opcode in (OP_CONT, OP_TEXT, OP_BINARY): return @@ -174,6 +179,7 @@ def parse_close(data): Raise :exc:`~websockets.exceptions.WebSocketProtocolError` or :exc:`UnicodeDecodeError` if the data is invalid. + """ length = len(data) if length == 0: @@ -193,5 +199,6 @@ def serialize_close(code, reason): Serialize the data for a close frame. This is the reverse of :func:`parse_close`. + """ return struct.pack('!H', code) + reason.encode('utf-8') diff --git a/websockets/handshake.py b/websockets/handshake.py index 322d5557a..e2ed644d9 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -31,6 +31,7 @@ :func:`build_request`, - Read the response, check that the status code is 101, and check the headers with :func:`check_response`. + """ __all__ = [ @@ -53,6 +54,7 @@ def build_request(set_header): Build a handshake request to send to the server. Return the ``key`` which must be passed to :func:`check_response`. + """ rand = bytes(random.getrandbits(8) for _ in range(16)) key = base64.b64encode(rand).decode() @@ -77,6 +79,7 @@ def check_request(get_header): request and doesn't perform Host and Origin checks. These controls are usually performed earlier in the HTTP request handling code. They're the responsibility of the caller. + """ try: assert get_header('Upgrade').lower() == 'websocket' @@ -97,6 +100,7 @@ def build_response(set_header, key): Build a handshake response to send to the client. ``key`` comes from :func:`check_request`. + """ set_header('Upgrade', 'WebSocket') set_header('Connection', 'Upgrade') @@ -117,6 +121,7 @@ def check_response(get_header, key): This function doesn't verify that the response is an HTTP/1.1 or higher response with a 101 status code. These controls are the responsibility of the caller. + """ try: assert get_header('Upgrade').lower() == 'websocket' diff --git a/websockets/http.py b/websockets/http.py index 6e1972fa8..1452d0a36 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -4,6 +4,7 @@ These functions cannot be imported from :mod:`websockets`; they must be imported from :mod:`websockets.http`. + """ __all__ = ['read_request', 'read_response', 'USER_AGENT'] @@ -37,6 +38,7 @@ def read_request(stream): Raise an exception if the request isn't well formatted. The request is assumed not to contain a body. + """ request_line, headers = yield from read_message(stream) method, path, version = request_line[:-2].decode().split(None, 2) @@ -58,6 +60,7 @@ def read_response(stream): Raise an exception if the request isn't well formatted. The response is assumed not to contain a body. + """ status_line, headers = yield from read_message(stream) version, status, reason = status_line[:-2].decode().split(None, 2) @@ -75,6 +78,7 @@ def read_message(stream): and ``headers`` is a :class:`~email.message.Message`. The message is assumed not to contain a body. + """ start_line = yield from read_line(stream) header_lines = io.BytesIO() @@ -94,6 +98,7 @@ def read_message(stream): def read_line(stream): """ Read a single line from ``stream``. + """ line = yield from stream.readline() if len(line) > MAX_LINE: diff --git a/websockets/protocol.py b/websockets/protocol.py index 71e74861b..7008a451d 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -3,6 +3,7 @@ frames as specified in `sections 4 to 8 of RFC 6455`_. .. _sections 4 to 8 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 + """ __all__ = ['WebSocketCommonProtocol'] @@ -74,8 +75,8 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): Once the connection is closed, the status code is available in the :attr:`close_code` attribute and the reason in :attr:`close_reason`. - """ + """ # There are only two differences between the client-side and the server- # side behavior: masking the payload and closing the underlying TCP # connection. This class implements the server-side behavior by default. @@ -149,6 +150,7 @@ def local_address(self): This is a ``(host, port)`` tuple or ``None`` if the connection hasn't been established yet. + """ if self.writer is None: return None @@ -161,6 +163,7 @@ def remote_address(self): This is a ``(host, port)`` tuple or ``None`` if the connection hasn't been established yet. + """ if self.writer is None: return None @@ -176,6 +179,7 @@ def open(self): raises a :exc:`~websockets.exceptions.ConnectionClosed` exception. .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp + """ return self.state == OPEN @@ -188,6 +192,7 @@ def state_name(self): CONNECTING, OPEN, CLOSING, or CLOSED. To check if the connection is open, use :attr:`open` instead. + """ return ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED'][self.state] @@ -203,6 +208,7 @@ def close(self, code=1000, reason=''): since errors during connection termination aren't particularly useful. ``code`` must be an :class:`int` and ``reason`` a :class:`str`. + """ if self.state == OPEN: # 7.1.2. Start the WebSocket Closing Handshake @@ -237,6 +243,7 @@ def recv(self): :meth:`recv` used to return ``None`` instead. Refer to the changelog for details. + """ # Don't yield from self.ensure_open() here because messages could be # available in the queue even if the connection is closed. @@ -279,6 +286,7 @@ def send(self, data): It sends :class:`str` as a text frame and :class:`bytes` as a binary frame. It raises a :exc:`TypeError` for other inputs. + """ yield from self.ensure_open() @@ -307,6 +315,7 @@ def ping(self, data=None): By default, the ping contains four random bytes. The content may be overridden with the optional ``data`` argument which must be of type :class:`str` (which will be encoded to UTF-8) or :class:`bytes`. + """ yield from self.ensure_open() @@ -335,6 +344,7 @@ def pong(self, data=b''): The content may be overridden with the optional ``data`` argument which must be of type :class:`str` (which will be encoded to UTF-8) or :class:`bytes`. + """ yield from self.ensure_open() diff --git a/websockets/server.py b/websockets/server.py index 0fe9a729d..763d6ff7c 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -1,5 +1,6 @@ """ The :mod:`websockets.server` module defines a simple WebSocket server API. + """ __all__ = ['serve', 'WebSocketServerProtocol'] @@ -28,8 +29,8 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): For the sake of simplicity, it doesn't rely on a full HTTP implementation. Its support for HTTP responses is very limited. - """ + """ state = CONNECTING def __init__(self, ws_handler, ws_server, *, @@ -116,6 +117,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): be a callable taking the request path and headers in arguments. Return the URI of the request. + """ # Read handshake request. try: @@ -177,6 +179,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): def select_subprotocol(self, client_protos, server_protos): """ Pick a subprotocol among those offered by the client. + """ common_protos = set(client_protos) & set(server_protos) if not common_protos: @@ -188,8 +191,8 @@ def select_subprotocol(self, client_protos, server_protos): class WebSocketServer(asyncio.AbstractServer): """ Wrapper for :class:`~asyncio.Server` that triggers the closing handshake. - """ + """ def __init__(self, loop=None): # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop @@ -208,6 +211,7 @@ def wrap(self, server): - give the protocol factory a reference to that instance - call :meth:`~asyncio.BaseEventLoop.create_server` with the factory - attach the resulting :class:`~asyncio.Server` with this method + """ self.server = server @@ -220,6 +224,7 @@ def unregister(self, protocol): def close(self): """ Stop serving and trigger a closing handshake on open connections. + """ for websocket in self.websockets: asyncio_ensure_future(websocket.close(1001), loop=self.loop) @@ -229,6 +234,7 @@ def close(self): def wait_closed(self): """ Wait until all connections are closed. + """ # asyncio.wait doesn't accept an empty first argument. if self.websockets: @@ -284,6 +290,7 @@ def serve(ws_handler, host=None, port=None, *, logger = logging.getLogger('websockets.server') logger.setLevel(logging.ERROR) logger.addHandler(logging.StreamHandler()) + """ if loop is None: loop = asyncio.get_event_loop() diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index 60ed808e6..2642d3855 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -29,6 +29,7 @@ def assert_invalid_request_headers(self): Provide request headers for corruption. Assert that the transformation made them invalid. + """ headers = {} build_request(headers.__setitem__) @@ -82,6 +83,7 @@ def assert_invalid_response_headers(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): Provide response headers for corruption. Assert that the transformation made them invalid. + """ headers = {} build_response(headers.__setitem__, key) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index eb231dc45..cda4d1cb5 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -35,6 +35,7 @@ class TransportMock(unittest.mock.Mock): eof_received methods directly. They could also pause_writing and resume_writing to test flow control. + """ # This should happen in __init__ but overriding Mock.__init__ is hard. def connect(self, loop, protocol): @@ -97,6 +98,7 @@ def async(self): def receive_frame(self, frame): """ Make the protocol receive a frame. + """ writer = self.protocol.data_received mask = not self.protocol.is_client @@ -114,6 +116,7 @@ def receive_eof(self): To emulate this behavior, this function closes the transport just after calling the protocol's eof_received. Closing the transport has the side-effect calling the protocol's connection_lost. + """ self.loop.call_soon(self.protocol.eof_received) self.loop.call_soon(self.transport.close) @@ -126,6 +129,7 @@ def process_invalid_frames(self): which interrupts pending reads waiting for more data. It delays this operation with call_later because the protocol must start processing frames first. Otherwise it will see a closed connection and no data. + """ self.loop.call_later(MS, self.receive_eof) with self.assertRaises(ConnectionClosed): @@ -137,6 +141,7 @@ def process_control_frames(self): To ensure that recv completes quickly, receive an additional dummy frame, which recv() will drop. + """ self.receive_frame(Frame(True, OP_TEXT, b'')) next_message = self.loop.run_until_complete(self.protocol.recv()) @@ -148,6 +153,7 @@ def last_sent_frame(self): This method assumes that at most one frame was sent. It raises an AssertionError otherwise. + """ stream = asyncio.StreamReader(loop=self.loop) diff --git a/websockets/uri.py b/websockets/uri.py index 2bb474fbe..6e27a7cc3 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -3,6 +3,7 @@ according to `section 3 of RFC 6455`_. .. _section 3 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-3 + """ __all__ = ['parse_uri', 'WebSocketURI'] @@ -21,6 +22,7 @@ * ``host`` is the lower-case host * ``port`` if the integer port, it's always provided even if it's the default * ``resource_name`` is the resource name, that is, the path and optional query + """ @@ -31,6 +33,7 @@ def parse_uri(uri): If the URI is valid, it returns a :class:`WebSocketURI`. Otherwise it raises an :exc:`~websockets.exceptions.InvalidURI` exception. + """ uri = urllib.parse.urlparse(uri) try: From 98219d6e93ddedf2bce2d227360259d569b12918 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 21:40:45 +0100 Subject: [PATCH 0148/1539] Improve message of ConnectionClosed. --- websockets/exceptions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index c06280f6a..cd185ab05 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -34,8 +34,10 @@ class ConnectionClosed(InvalidState): def __init__(self, code, reason): self.code = code self.reason = reason - super().__init__('WebSocket connection is closed: ' - 'code = {}, reason = {}'.format(code, reason)) + message = 'WebSocket connection is closed: ' + message += 'code = {}, '.format(code) if code else 'no code, ' + message += 'reason = {}.'.format(reason) if reason else 'no reason.' + super().__init__(message) class InvalidURI(Exception): From f7f49084810ad47cfb0d3b518ad6a18b7bbab49e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 21:42:21 +0100 Subject: [PATCH 0149/1539] Prevent unintended timeout. The library considers that its job is done when the worker task exits. Usually connection_closed will have been set by connection_lost, but this is somewhat outside of our control, making it preferrable to test consistently for the worker's termination. --- websockets/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 7008a451d..7222060c7 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -379,7 +379,7 @@ def ensure_open(self): # longer than the worst case (2 * self.timeout) but not unlimited. if self.state == CLOSING: yield from asyncio.wait_for( - self.connection_closed, 3 * self.timeout, loop=self.loop) + self.worker, 3 * self.timeout, loop=self.loop) raise ConnectionClosed(self.close_code, self.close_reason) # Control may only reach this point in buggy third-party subclasses. From 6bd24ac1c9daa6d9ceeaf8941e69244b3f7e5a9a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 21:49:41 +0100 Subject: [PATCH 0150/1539] Make attributes of ConnectionClosed a public API. --- websockets/exceptions.py | 4 +++- websockets/test_protocol.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index cd185ab05..8824df03c 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -29,8 +29,10 @@ class ConnectionClosed(InvalidState): """ Exception raised when trying to read or write on a closed connection. - """ + Provides the connection close code and reason in its ``code`` and + ``reason`` attributes respectively. + """ def __init__(self, code, reason): self.code = code self.reason = reason diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index cda4d1cb5..3e20f0596 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -601,6 +601,18 @@ def test_legacy_recv(self): # Now recv() returns None instead of raising ConnectionClosed. self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + def test_connection_closed_attributes(self): + self.receive_frame(self.close_frame) + self.receive_eof() + self.loop.run_until_complete(self.protocol.close(reason='close')) + + with self.assertRaises(ConnectionClosed) as context: + self.loop.run_until_complete(self.protocol.recv()) + + connection_closed = context.exception + self.assertEqual(connection_closed.code, 1000) + self.assertEqual(connection_closed.reason, 'close') + class ServerCloseTests(CommonTests, unittest.TestCase): From ff85c929992796a19629b0e34034c209ff33c954 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 22:20:17 +0100 Subject: [PATCH 0151/1539] Fix incorrect order due to history rewriting. --- docs/changelog.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 540772da1..9d45f5e1a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,10 +6,6 @@ Changelog *In development* -* :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` supports - data passed as :class:`str` in addition to :class:`bytes`. - .. warning:: **Version 3.0 introduces a backwards-incompatible change in the** @@ -43,6 +39,10 @@ Changelog Also: +* :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and + :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` supports + data passed as :class:`str` in addition to :class:`bytes`. + * Worked around an asyncio bug affecting connection termination under load. * Made ``state_name`` atttribute on protocols a public API. From 7e673ece009ddd666040288185741c4d34e7c55a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2015 22:20:50 +0100 Subject: [PATCH 0152/1539] Negligible clarification. --- docs/intro.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index a48bda71d..ea8c76b52 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -6,8 +6,8 @@ Basic example .. _server-example: -Here's a WebSocket server example. It reads a name from the client and sends a -message. +Here's a WebSocket server example. It reads a name from the client, sends a +greeting, and closes the connection. .. literalinclude:: ../example/server.py From f531d696dad7a0b3c8fe59e3512992f3bc9c3037 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 14 Dec 2015 09:05:48 +0100 Subject: [PATCH 0153/1539] Avoid swallowing an exception in write_frame. This raises a consistent ConnectionClosed exception when the connection is closed while attempting to send a frame. Thanks @dzen for the suggestion. --- websockets/protocol.py | 13 +++++++++---- websockets/test_protocol.py | 16 ++++------------ 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 7222060c7..51cfd622d 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -470,6 +470,7 @@ def read_data_frame(self, max_size): # 6.2. Receiving Data while True: frame = yield from self.read_frame(max_size) + # 5.5. Control Frames if frame.opcode == OP_CLOSE: # Make sure the close frame is valid before echoing it. @@ -477,13 +478,14 @@ def read_data_frame(self, max_size): if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started yield from self.write_frame(OP_CLOSE, frame.data) - if not self.closing_handshake.done(): - self.close_code, self.close_reason = code, reason - self.closing_handshake.set_result(True) + self.close_code, self.close_reason = code, reason + self.closing_handshake.set_result(True) return + elif frame.opcode == OP_PING: # Answer pings. yield from self.pong(frame.data) + elif frame.opcode == OP_PONG: # Do not acknowledge pings on unsolicited pongs. if frame.data in self.pings: @@ -493,6 +495,7 @@ def read_data_frame(self, max_size): ping_id, waiter = self.pings.popitem(0) if not waiter.cancelled(): waiter.set_result(None) + # 5.6. Data Frames else: return frame @@ -546,6 +549,8 @@ def write_frame(self, opcode, data=b''): except ConnectionError: # Terminate the connection if the socket died. yield from self.fail_connection(1006) + # And raise an exception, since the frame couldn't be sent. + raise ConnectionClosed(self.close_code, self.close_reason) @asyncio.coroutine def close_connection(self, force=False): @@ -590,7 +595,7 @@ def fail_connection(self, code=1011, reason=''): logger.info("Failing the WebSocket connection: %d %s", code, reason) if self.state == OPEN: if code == 1006: - # Don't send a close frame is the connection is broken. Set + # Don't send a close frame if the connection is broken. Set # the state to CLOSING to allow close_connection to proceed. self.state = CLOSING else: diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 3e20f0596..94c3ab3da 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -709,7 +709,6 @@ def test_close_during_recv(self): self.receive_frame(self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='close')) - # Receiving a message shouldn't crash. with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(recv) @@ -722,11 +721,8 @@ def test_close_during_send(self): self.receive_frame(self.close_frame) self.receive_eof() - # Sending a message shouldn't crash. - self.loop.run_until_complete(send) - - # Complete the connection. - self.loop.run_until_complete(self.protocol.close(reason='close')) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(send) self.assertConnectionClosed(1006, '') @@ -852,7 +848,6 @@ def test_close_during_recv(self): self.receive_eof() self.loop.run_until_complete(self.protocol.close(reason='close')) - # Receiving a message shouldn't crash. with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(recv) @@ -865,10 +860,7 @@ def test_close_during_send(self): self.receive_frame(self.close_frame) self.receive_eof() - # Sending a message shouldn't crash. - self.loop.run_until_complete(send) - - # Complete the connection. - self.loop.run_until_complete(self.protocol.close(reason='close')) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(send) self.assertConnectionClosed(1006, '') From 85050db93f0fabaf67ef68889bfda77b1038cb07 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 14 Dec 2015 10:08:29 +0100 Subject: [PATCH 0154/1539] Explain difference between client and server tests. --- websockets/test_protocol.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 94c3ab3da..a8d835356 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -52,7 +52,12 @@ def close(self): class CommonTests: + """ + Mixin that defines most tests but doesn't inherit unittest.TestCase. + + Tests are run by the ServerTests and ClientTests subclasses. + """ def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() @@ -121,6 +126,17 @@ def receive_eof(self): self.loop.call_soon(self.protocol.eof_received) self.loop.call_soon(self.transport.close) + def receive_eof_if_client(self): + """ + Like receive_eof, but only if this is the client side. + + Since the server is supposed to initiate the termination of the TCP + connection, this method helps making tests work for both sides. + + """ + if self.protocol.is_client: + self.receive_eof() + def process_invalid_frames(self): """ Make the protocol fail quickly after simulating invalid data. @@ -255,8 +271,7 @@ def test_recv_on_closing_connection(self): # Complete the closing handshake while running the recv. self.receive_frame(self.close_frame) - if self.protocol.is_client: - self.receive_eof() + self.receive_eof_if_client() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) @@ -346,8 +361,7 @@ def test_send_on_closing_connection(self): # Complete the closing handshake while running the send. self.receive_frame(self.close_frame) - if self.protocol.is_client: - self.receive_eof() + self.receive_eof_if_client() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.send('foobar')) @@ -393,8 +407,7 @@ def test_ping_on_closing_connection(self): # Complete the closing handshake while running the ping. self.receive_frame(self.close_frame) - if self.protocol.is_client: - self.receive_eof() + self.receive_eof_if_client() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ping()) @@ -435,8 +448,7 @@ def test_pong_on_closing_connection(self): # Complete the closing handshake while running the pong. self.receive_frame(self.close_frame) - if self.protocol.is_client: - self.receive_eof() + self.receive_eof_if_client() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.pong()) From 6f38c47f58257f0882f7c917cbc52a3b486d0d74 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 14 Dec 2015 10:23:00 +0100 Subject: [PATCH 0155/1539] Factor common client and server protocol tests. Only tests with significant differences remain in the subclasses. --- websockets/test_protocol.py | 215 +++++++++++------------------------- 1 file changed, 65 insertions(+), 150 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index a8d835356..7ed1f76bb 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -72,6 +72,8 @@ def tearDown(self): self.loop.close() super().tearDown() + # Utilities for writing tests. + def run_loop_once(self): # Process callbacks scheduled with call_soon by appending a callback # to stop the event loop then running it until it hits that callback. @@ -91,10 +93,9 @@ def delayed_drain(): self.protocol.writer.drain = delayed_drain - # These frames are used in the ServerTests and ClientTests subclasses. close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'close')) - client_close = Frame(True, OP_CLOSE, serialize_close(1000, 'client')) - server_close = Frame(True, OP_CLOSE, serialize_close(1000, 'server')) + local_close = Frame(True, OP_CLOSE, serialize_close(1000, 'local')) + remote_close = Frame(True, OP_CLOSE, serialize_close(1000, 'remote')) @property def async(self): @@ -625,13 +626,14 @@ def test_connection_closed_attributes(self): self.assertEqual(connection_closed.code, 1000) self.assertEqual(connection_closed.reason, 'close') + # Test the protocol logic for closing the connection. -class ServerCloseTests(CommonTests, unittest.TestCase): - - # Test the protocol logic for closing the connection on the server side. - - def test_server_close(self): + def test_local_close(self): + # Emulate how the remote endpoint answers the closing handshake. self.receive_frame(self.close_frame) + self.receive_eof_if_client() + + # Run the closing handshake. self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1000, 'close') @@ -643,14 +645,16 @@ def test_server_close(self): self.assertConnectionClosed(1000, 'close') self.assertNoFrameSent() - def test_client_close(self): + def test_remote_close(self): + # Emulate how the remote endpoint initiates the closing handshake. self.receive_frame(self.close_frame) + self.receive_eof_if_client() - # The server is waiting for some data at this point but won't get it. + # Wait for some data in order to process the handshake. + # After recv() raises ConnectionClosed, the connection is closed. with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) - # After recv() raises ConnectionClosed, the connection is closed. self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) @@ -661,51 +665,29 @@ def test_client_close(self): self.assertNoFrameSent() def test_simultaneous_close(self): - self.receive_frame(self.client_close) - self.loop.run_until_complete(self.protocol.close(reason='server')) + self.receive_frame(self.remote_close) + self.receive_eof_if_client() + self.loop.run_until_complete(self.protocol.close(reason='local')) # The close code and reason are taken from the remote side because # that's presumably more useful that the values from the local side. - self.assertConnectionClosed(1000, 'client') - self.assertOneFrameSent(*self.server_close) + self.assertConnectionClosed(1000, 'remote') + self.assertOneFrameSent(*self.local_close) def test_close_drops_frames(self): text_frame = Frame(True, OP_TEXT, b'') self.receive_frame(text_frame) self.receive_frame(self.close_frame) + self.receive_eof_if_client() self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) - def test_close_handshake_timeout(self): - # Timeout is expected in 10ms. - self.protocol.timeout = 10 * MS - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(9 * MS, 19 * MS): - # Unlike previous tests, no close frame will be received in - # response. The server will stop waiting for the close frame and - # timeout. - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1006, '') - - def test_client_close_race_with_failing_connection(self): - self.make_drain_slow() - - # Fail the connection while answering a close frame from the client. - self.loop.call_soon(self.receive_frame, self.client_close) - self.loop.call_later(MS, self.async, self.protocol.fail_connection()) - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - # The closing handshake was completed by fail_connection. - self.assertConnectionClosed(1011, '') - self.assertOneFrameSent(*self.client_close) - def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') self.receive_frame(invalid_close_frame) + self.receive_eof_if_client() self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1002, '') @@ -716,9 +698,29 @@ def test_close_connection_lost(self): self.assertConnectionClosed(1006, '') - def test_close_during_recv(self): + def test_remote_close_race_with_failing_connection(self): + self.make_drain_slow() + + # Fail the connection while answering a close frame from the client. + self.loop.call_soon(self.receive_frame, self.remote_close) + self.loop.call_later(MS, self.async, self.protocol.fail_connection()) + # The client expects the server to close the connection. + # Simulate it instead of waiting for the connection timeout. + self.loop.call_later(MS, self.receive_eof_if_client) + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + + # The closing handshake was completed by fail_connection. + self.assertConnectionClosed(1011, '') + self.assertOneFrameSent(*self.remote_close) + + def test_local_close_during_recv(self): recv = self.async(self.protocol.recv()) + self.receive_frame(self.close_frame) + self.receive_eof_if_client() + self.loop.run_until_complete(self.protocol.close(reason='close')) with self.assertRaises(ConnectionClosed): @@ -726,10 +728,13 @@ def test_close_during_recv(self): self.assertConnectionClosed(1000, 'close') - def test_close_during_send(self): - self.make_drain_slow() + # There is no test_remote_close_during_recv because it would be identical + # to test_remote_close. + def test_remote_close_during_send(self): + self.make_drain_slow() send = self.async(self.protocol.send('hello')) + self.receive_frame(self.close_frame) self.receive_eof() @@ -738,68 +743,29 @@ def test_close_during_send(self): self.assertConnectionClosed(1006, '') + # There is no test_local_close_during_send because this cannot really + # happen, considering that writes are serialized. -class ClientCloseTests(CommonTests, unittest.TestCase): - - def setUp(self): - super().setUp() - self.protocol.is_client = True - - # Test the protocol logic for closing the connection on the client side. - - def test_client_close(self): - self.receive_frame(self.close_frame) - self.receive_eof() - self.loop.run_until_complete(self.protocol.close(reason='close')) - - self.assertConnectionClosed(1000, 'close') - self.assertOneFrameSent(*self.close_frame) - - # Closing the connection again is a no-op. - self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) - self.assertConnectionClosed(1000, 'close') - self.assertNoFrameSent() +class ServerTests(CommonTests, unittest.TestCase): - def test_server_close(self): - self.receive_frame(self.close_frame) - # The client expects the server to close the connection. Simulate it - # to avoid having to wait for the connection timeout. - self.loop.call_later(MS, self.receive_eof) - - # The client is waiting for some data at this point but won't get it. - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - # After recv() raises ConnectionClosed, the connection is closed. - self.assertConnectionClosed(1000, 'close') - self.assertOneFrameSent(*self.close_frame) - - # Closing the connection again is a no-op. - self.loop.run_until_complete(self.protocol.close('oh noes!')) - - self.assertConnectionClosed(1000, 'close') - self.assertNoFrameSent() - - def test_simultaneous_close(self): - self.receive_frame(self.server_close) - self.receive_eof() - self.loop.run_until_complete(self.protocol.close(reason='client')) + def test_close_handshake_timeout(self): + # Timeout is expected in 10ms. + self.protocol.timeout = 10 * MS + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(9 * MS, 19 * MS): + # Unlike previous tests, no close frame will be received in + # response. The server will stop waiting for the close frame and + # timeout. + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1006, '') - # The close code and reason are taken from the remote side because - # that's presumably more useful that the values from the local side. - self.assertConnectionClosed(1000, 'server') - self.assertOneFrameSent(*self.client_close) - def test_close_drops_frames(self): - text_frame = Frame(True, OP_TEXT, b'') - self.receive_frame(text_frame) - self.receive_frame(self.close_frame) - self.receive_eof() - self.loop.run_until_complete(self.protocol.close(reason='close')) +class ClientTests(CommonTests, unittest.TestCase): - self.assertConnectionClosed(1000, 'close') - self.assertOneFrameSent(*self.close_frame) + def setUp(self): + super().setUp() + self.protocol.is_client = True def test_close_handshake_timeout(self): # Timeout is expected in 2 * 10 = 20ms. @@ -823,56 +789,5 @@ def test_eof_received_timeout(self): # stop waiting for the connection close and timeout. self.receive_frame(self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'close') - - def test_server_close_race_with_failing_connection(self): - self.make_drain_slow() - - # Fail the connection while answering a close frame from the server. - self.loop.call_soon(self.receive_frame, self.server_close) - self.loop.call_later(MS, self.async, self.protocol.fail_connection()) - self.loop.call_later(2 * MS, self.receive_eof) - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - # The closing handshake was completed by fail_connection. - self.assertConnectionClosed(1011, '') - self.assertOneFrameSent(*self.server_close) - - def test_close_protocol_error(self): - invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') - self.receive_frame(invalid_close_frame) - self.receive_eof() - self.loop.run_until_complete(self.protocol.close(reason='close')) - - self.assertConnectionClosed(1002, '') - - def test_close_connection_lost(self): - self.receive_eof() - self.loop.run_until_complete(self.protocol.close(reason='close')) - - self.assertConnectionClosed(1006, '') - - def test_close_during_recv(self): - recv = self.async(self.protocol.recv()) - self.receive_frame(self.close_frame) - self.receive_eof() - self.loop.run_until_complete(self.protocol.close(reason='close')) - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(recv) self.assertConnectionClosed(1000, 'close') - - def test_close_during_send(self): - self.make_drain_slow() - - send = self.async(self.protocol.send('hello')) - self.receive_frame(self.close_frame) - self.receive_eof() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(send) - - self.assertConnectionClosed(1006, '') From 1de8f529dbd12eed001b63fd715f5fddda16118b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 14 Dec 2015 17:05:17 +0100 Subject: [PATCH 0156/1539] Messages can be read after the close handshake. This test implied the opposite of the current, and desired, behavior. --- websockets/test_protocol.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 7ed1f76bb..363e0b302 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -674,9 +674,8 @@ def test_simultaneous_close(self): self.assertConnectionClosed(1000, 'remote') self.assertOneFrameSent(*self.local_close) - def test_close_drops_frames(self): - text_frame = Frame(True, OP_TEXT, b'') - self.receive_frame(text_frame) + def test_close_preserves_incoming_frames(self): + self.receive_frame(Frame(True, OP_TEXT, b'hello')) self.receive_frame(self.close_frame) self.receive_eof_if_client() self.loop.run_until_complete(self.protocol.close(reason='close')) @@ -684,6 +683,9 @@ def test_close_drops_frames(self): self.assertConnectionClosed(1000, 'close') self.assertOneFrameSent(*self.close_frame) + next_message = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(next_message, 'hello') + def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') self.receive_frame(invalid_close_frame) From 61ee954bb45708eec8021dcb40210c446b7c1650 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 14 Dec 2015 17:11:48 +0100 Subject: [PATCH 0157/1539] Fix typo. --- websockets/test_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 363e0b302..59a01c408 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -13,7 +13,7 @@ # Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variables. +# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. MS = 0.001 * int(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1)) # asyncio's debug mode has a 10x performance penalty for this test suite. From abec6594f81d60c23e5ef8af650284ee44ef44e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szieberth=20=C3=81d=C3=A1m?= Date: Mon, 14 Dec 2015 18:06:35 +0100 Subject: [PATCH 0158/1539] Avoid conflict with the built-in time module. Fix #87. --- docs/intro.rst | 4 ++-- example/{time.py => sendtime.py} | 0 example/{time.html => showtime.html} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename example/{time.py => sendtime.py} (100%) rename example/{time.html => showtime.html} (100%) diff --git a/docs/intro.rst b/docs/intro.rst index ea8c76b52..45922ab6e 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -28,11 +28,11 @@ Here's an example of how to run a WebSocket server and connect from a browser. Run this script in a console: -.. literalinclude:: ../example/time.py +.. literalinclude:: ../example/sendtime.py Then open this HTML file in a browser. -.. literalinclude:: ../example/time.html +.. literalinclude:: ../example/showtime.html :language: html Common patterns diff --git a/example/time.py b/example/sendtime.py similarity index 100% rename from example/time.py rename to example/sendtime.py diff --git a/example/time.html b/example/showtime.html similarity index 100% rename from example/time.html rename to example/showtime.html From 00d04cbf12e2c6691cdf34d22162d8204625d702 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Dec 2015 18:57:31 +0100 Subject: [PATCH 0159/1539] Improve emulation of receiving EOF. This fixes tests on Python 3.5.1. Fix #92. --- websockets/test_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 59a01c408..203d4d18e 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -125,7 +125,7 @@ def receive_eof(self): """ self.loop.call_soon(self.protocol.eof_received) - self.loop.call_soon(self.transport.close) + self.loop.call_soon(self.loop.call_soon, self.transport.close) def receive_eof_if_client(self): """ From b9a235d8f37872924e3a8fe3a44f0373f552d827 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Dec 2015 21:16:24 +0100 Subject: [PATCH 0160/1539] Factor out testing on closed connections. --- websockets/test_protocol.py | 43 +++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 203d4d18e..8e77707da 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -138,6 +138,20 @@ def receive_eof_if_client(self): if self.protocol.is_client: self.receive_eof() + def close_connection(self, code=1000, reason='close'): + """ + Close the connection with a standard closing handshake. + + """ + close_frame_data = serialize_close(code, reason) + # Prepare the response to the closing handshake from the remote side. + self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) + self.receive_eof_if_client() + # Trigger the closing handshake from the local side and complete it. + self.loop.run_until_complete(self.protocol.close(code, reason)) + # Empty the outgoing data stream so we can make assertions later on. + self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) + def process_invalid_frames(self): """ Make the protocol fail quickly after simulating invalid data. @@ -238,18 +252,12 @@ def test_remote_address(self): def test_open(self): self.assertTrue(self.protocol.open) - - # This is a way to terminate the connection. - self.process_invalid_frames() - + self.close_connection() self.assertFalse(self.protocol.open) def test_state_name(self): self.assertEqual(self.protocol.state_name, 'OPEN') - - # This is a way to terminate the connection. - self.process_invalid_frames() - + self.close_connection() self.assertEqual(self.protocol.state_name, 'CLOSED') # Test the recv coroutine. @@ -278,8 +286,7 @@ def test_recv_on_closing_connection(self): self.loop.run_until_complete(self.protocol.recv()) def test_recv_on_closed_connection(self): - # This is a way to terminate the connection. - self.process_invalid_frames() + self.close_connection() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) @@ -369,8 +376,7 @@ def test_send_on_closing_connection(self): self.assertNoFrameSent() def test_send_on_closed_connection(self): - # This is a way to terminate the connection. - self.process_invalid_frames() + self.close_connection() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.send('foobar')) @@ -415,8 +421,7 @@ def test_ping_on_closing_connection(self): self.assertNoFrameSent() def test_ping_on_closed_connection(self): - # This is a way to terminate the connection. - self.process_invalid_frames() + self.close_connection() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ping()) @@ -456,8 +461,7 @@ def test_pong_on_closing_connection(self): self.assertNoFrameSent() def test_pong_on_closed_connection(self): - # This is a way to terminate the connection. - self.process_invalid_frames() + self.close_connection() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.pong()) @@ -605,8 +609,7 @@ def test_legacy_recv(self): # By default legacy_recv in disabled. self.assertEqual(self.protocol.legacy_recv, False) - # This is a way to terminate the connection. - self.process_invalid_frames() + self.close_connection() # Enable legacy_recv. self.protocol.legacy_recv = True @@ -615,9 +618,7 @@ def test_legacy_recv(self): self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) def test_connection_closed_attributes(self): - self.receive_frame(self.close_frame) - self.receive_eof() - self.loop.run_until_complete(self.protocol.close(reason='close')) + self.close_connection() with self.assertRaises(ConnectionClosed) as context: self.loop.run_until_complete(self.protocol.recv()) From 3c0f78f8279e45ec3bf83046495977a867c76cd8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Dec 2015 21:21:52 +0100 Subject: [PATCH 0161/1539] Factor out testing on closing connections. --- websockets/test_protocol.py | 56 ++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 8e77707da..53d67223b 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -142,6 +142,8 @@ def close_connection(self, code=1000, reason='close'): """ Close the connection with a standard closing handshake. + This puts the connection in the CLOSED state. + """ close_frame_data = serialize_close(code, reason) # Prepare the response to the closing handshake from the remote side. @@ -152,6 +154,24 @@ def close_connection(self, code=1000, reason='close'): # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) + def close_connection_partial(self, code=1000, reason='close'): + """ + Initiate a standard closing handshake but do not complete it. + + The main difference with `close_connection` is that the connection is + left in the CLOSING state until the event loop runs again. + + """ + close_frame_data = serialize_close(code, reason) + # Trigger the closing handshake from the local side. + self.async(self.protocol.close(code, reason)) + self.run_loop_once() + # Empty the outgoing data stream so we can make assertions later on. + self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) + # Prepare the response to the closing handshake from the remote side. + self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) + self.receive_eof_if_client() + def process_invalid_frames(self): """ Make the protocol fail quickly after simulating invalid data. @@ -273,14 +293,7 @@ def test_recv_binary(self): self.assertEqual(data, b'tea') def test_recv_on_closing_connection(self): - # This is a way to start a closing handshake. - self.async(self.protocol.close()) - self.run_loop_once() - self.assertOneFrameSent(True, OP_CLOSE, b'\x03\xe8') - - # Complete the closing handshake while running the recv. - self.receive_frame(self.close_frame) - self.receive_eof_if_client() + self.close_connection_partial() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) @@ -362,14 +375,7 @@ def test_send_type_error(self): self.assertNoFrameSent() def test_send_on_closing_connection(self): - # This is a way to start a closing handshake. - self.async(self.protocol.close()) - self.run_loop_once() - self.assertOneFrameSent(True, OP_CLOSE, b'\x03\xe8') - - # Complete the closing handshake while running the send. - self.receive_frame(self.close_frame) - self.receive_eof_if_client() + self.close_connection_partial() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.send('foobar')) @@ -407,14 +413,7 @@ def test_ping_type_error(self): self.assertNoFrameSent() def test_ping_on_closing_connection(self): - # This is a way to start a closing handshake. - self.async(self.protocol.close()) - self.run_loop_once() - self.assertOneFrameSent(True, OP_CLOSE, b'\x03\xe8') - - # Complete the closing handshake while running the ping. - self.receive_frame(self.close_frame) - self.receive_eof_if_client() + self.close_connection_partial() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ping()) @@ -447,14 +446,7 @@ def test_pong_type_error(self): self.assertNoFrameSent() def test_pong_on_closing_connection(self): - # This is a way to start a closing handshake. - self.async(self.protocol.close()) - self.run_loop_once() - self.assertOneFrameSent(True, OP_CLOSE, b'\x03\xe8') - - # Complete the closing handshake while running the pong. - self.receive_frame(self.close_frame) - self.receive_eof_if_client() + self.close_connection_partial() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.pong()) From 62e0873cdbf1da8330a3249b92f6a8f6a9af18a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szieberth=20=C3=81d=C3=A1m?= Date: Tue, 15 Dec 2015 17:30:13 +0100 Subject: [PATCH 0162/1539] Allow using connect() as an async context manager. This only works on Python 3.5. --- example/client35.py | 16 ++++++++++++++++ websockets/client.py | 11 +++++++++++ websockets/python35.py | 25 +++++++++++++++++++++++++ 3 files changed, 52 insertions(+) create mode 100644 example/client35.py create mode 100644 websockets/python35.py diff --git a/example/client35.py b/example/client35.py new file mode 100644 index 000000000..702cbf18e --- /dev/null +++ b/example/client35.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import asyncio +import websockets + +async def hello(): + async with websockets.connect('ws://localhost:8765') as websocket: + + name = input("What's your name? ") + await websocket.send(name) + print("> {}".format(name)) + + greeting = await websocket.recv() + print("< {}".format(greeting)) + +asyncio.get_event_loop().run_until_complete(hello()) diff --git a/websockets/client.py b/websockets/client.py index 958825cde..399d6c231 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -154,3 +154,14 @@ def connect(uri, *, raise return protocol + + +try: + from .python35 import Connect +except SyntaxError: + pass +else: + Connect.__wrapped__ = connect + # Copy over docstring to support building documentation on Python 3.5. + Connect.__doc__ = connect.__doc__ + connect = Connect diff --git a/websockets/python35.py b/websockets/python35.py new file mode 100644 index 000000000..57b05c396 --- /dev/null +++ b/websockets/python35.py @@ -0,0 +1,25 @@ +class Connect: + """ + This class wraps :func:`connect` on Python 3.5 and above. + + It can be used as an asynchronous context manager. + + """ + + __wrapped__ = NotImplemented + + def __init__(self, *args, **kwargs): + connect = self.__class__.__wrapped__ + self.connect_coroutine = connect(*args, **kwargs) + + def __await__(self): + return (yield from self.connect_coroutine) + + async def __aenter__(self): + self.websocket = await self + return self.websocket + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.websocket.close() + + __iter__ = __await__ From edb4158aa511825eae3f9dc6392223919fc20609 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Dec 2015 22:37:45 +0100 Subject: [PATCH 0163/1539] Add tests for connect() as a context manager. Coverage report is at 100% but it requires this patch: https://bitbucket.org/ned/coveragepy/issues/434/indexerror-in-python-35#comment-24146893 --- websockets/client.py | 2 +- websockets/py35_test_client_server.py | 33 +++++++++++++++++++++++++++ websockets/test_client_server.py | 10 ++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 websockets/py35_test_client_server.py diff --git a/websockets/client.py b/websockets/client.py index 399d6c231..e051fa991 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -158,7 +158,7 @@ def connect(uri, *, try: from .python35 import Connect -except SyntaxError: +except SyntaxError: # pragma: no cover pass else: Connect.__wrapped__ = connect diff --git a/websockets/py35_test_client_server.py b/websockets/py35_test_client_server.py new file mode 100644 index 000000000..0b40df9d1 --- /dev/null +++ b/websockets/py35_test_client_server.py @@ -0,0 +1,33 @@ +# Tests containing Python 3.5+ syntax, extracted from test_client_server.py. +# To avoid test discovery, this module's name must not start with test_. + +import asyncio + +from .client import * +from .server import * +from .test_client_server import handler + + +class ClientServerContextManager: + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_basic(self): + server = serve(handler, 'localhost', 8642) + self.server = self.loop.run_until_complete(server) + + async def basic(): + async with connect('ws://localhost:8642/') as client: + await client.send("Hello!") + reply = await client.recv() + self.assertEqual(reply, "Hello!") + + self.loop.run_until_complete(basic()) + + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 8ce75703a..657723cc6 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -402,3 +402,13 @@ def test_checking_lack_of_origin_succeeds(self): self.loop.run_until_complete(client.close()) server.close() self.loop.run_until_complete(server.wait_closed()) + + +try: + from .py35_test_client_server import ClientServerContextManager +except SyntaxError: # pragma: no cover + pass +else: + class ClientServerContextManagerTests(ClientServerContextManager, + unittest.TestCase): + pass From b89ac62ed3d08e0a479ada2b6fb0e853b8c1a065 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Dec 2015 22:44:26 +0100 Subject: [PATCH 0164/1539] Non-significant style changes to Connect. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I haven't written that sort of code before and changed my mind after giving Ádám precise advice. --- websockets/client.py | 2 +- websockets/{python35.py => py35_client.py} | 16 ++++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) rename websockets/{python35.py => py35_client.py} (51%) diff --git a/websockets/client.py b/websockets/client.py index e051fa991..d61bbd91c 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -157,7 +157,7 @@ def connect(uri, *, try: - from .python35 import Connect + from .py35_client import Connect except SyntaxError: # pragma: no cover pass else: diff --git a/websockets/python35.py b/websockets/py35_client.py similarity index 51% rename from websockets/python35.py rename to websockets/py35_client.py index 57b05c396..5ab7af034 100644 --- a/websockets/python35.py +++ b/websockets/py35_client.py @@ -1,19 +1,12 @@ class Connect: """ - This class wraps :func:`connect` on Python 3.5 and above. + This class wraps :func:`~websockets.client.connect` on Python ≥ 3.5. - It can be used as an asynchronous context manager. + This allows using it as an asynchronous context manager. """ - - __wrapped__ = NotImplemented - def __init__(self, *args, **kwargs): - connect = self.__class__.__wrapped__ - self.connect_coroutine = connect(*args, **kwargs) - - def __await__(self): - return (yield from self.connect_coroutine) + self.client = self.__class__.__wrapped__(*args, **kwargs) async def __aenter__(self): self.websocket = await self @@ -22,4 +15,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): await self.websocket.close() + def __await__(self): + return (yield from self.client) + __iter__ = __await__ From 00f0b530cd15b21a7ee6ba7193b8bb5e9f7990e8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Dec 2015 23:03:32 +0100 Subject: [PATCH 0165/1539] Document using connect as an async context manager. --- docs/changelog.rst | 3 +++ docs/cheatsheet.rst | 5 ++++- docs/intro.rst | 7 +++++++ example/client.py | 17 +++++++---------- example/client35.py | 16 ---------------- example/oldclient.py | 21 +++++++++++++++++++++ websockets/client.py | 3 +++ 7 files changed, 45 insertions(+), 27 deletions(-) mode change 100755 => 100644 example/client.py delete mode 100644 example/client35.py create mode 100755 example/oldclient.py diff --git a/docs/changelog.rst b/docs/changelog.rst index 9d45f5e1a..12ec1584d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -39,6 +39,9 @@ Changelog Also: +* :func:`~websockets.client.connect` can be used as an asynchronous context + manager on Python ≥ 3.5. + * :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` supports data passed as :class:`str` in addition to :class:`bytes`. diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 33a2c0eee..28c082bbe 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -31,6 +31,8 @@ Client * Create a client with :func:`~websockets.client.connect` which is similar to asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. + * On Python ≥ 3.5, you can also use it as an asynchronous context manager. + * You may subclass :class:`~websockets.server.WebSocketClientProtocol` and pass it in the ``klass`` keyword argument for advanced customization. @@ -42,7 +44,8 @@ Client :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't needed in general. -* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.close` to terminate +* If you aren't using :func:`~websockets.client.connect` as a context manager, + call :meth:`~websockets.protocol.WebSocketCommonProtocol.close` to terminate the connection. Debugging diff --git a/docs/intro.rst b/docs/intro.rst index 45922ab6e..1146745bc 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -4,6 +4,8 @@ Getting started Basic example ------------- +*This section assumes Python ≥ 3.5. For older versions, read below.* + .. _server-example: Here's a WebSocket server example. It reads a name from the client, sends a @@ -21,6 +23,11 @@ On the server side, the handler coroutine ``hello`` is executed once for each WebSocket connection. The connection is automatically closed when the handler returns. +``async`` and ``await`` aren't available in Python < 3.5. Here's how to adapt +the client example for older Python versions. + +.. literalinclude:: ../example/oldclient.py + Browser-based example --------------------- diff --git a/example/client.py b/example/client.py old mode 100755 new mode 100644 index 66ff2b4d2..702cbf18e --- a/example/client.py +++ b/example/client.py @@ -3,17 +3,14 @@ import asyncio import websockets -@asyncio.coroutine -def hello(): - websocket = yield from websockets.connect('ws://localhost:8765/') +async def hello(): + async with websockets.connect('ws://localhost:8765') as websocket: - name = input("What's your name? ") - yield from websocket.send(name) - print("> {}".format(name)) + name = input("What's your name? ") + await websocket.send(name) + print("> {}".format(name)) - greeting = yield from websocket.recv() - print("< {}".format(greeting)) - - yield from websocket.close() + greeting = await websocket.recv() + print("< {}".format(greeting)) asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/client35.py b/example/client35.py deleted file mode 100644 index 702cbf18e..000000000 --- a/example/client35.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import websockets - -async def hello(): - async with websockets.connect('ws://localhost:8765') as websocket: - - name = input("What's your name? ") - await websocket.send(name) - print("> {}".format(name)) - - greeting = await websocket.recv() - print("< {}".format(greeting)) - -asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/oldclient.py b/example/oldclient.py new file mode 100755 index 000000000..763627a4b --- /dev/null +++ b/example/oldclient.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +import asyncio +import websockets + +@asyncio.coroutine +def hello(): + websocket = yield from websockets.connect('ws://localhost:8765/') + + try: + name = input("What's your name? ") + yield from websocket.send(name) + print("> {}".format(name)) + + greeting = yield from websocket.recv() + print("< {}".format(greeting)) + + finally: + yield from websocket.close() + +asyncio.get_event_loop().run_until_complete(hello()) diff --git a/websockets/client.py b/websockets/client.py index d61bbd91c..ea744619c 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -128,6 +128,9 @@ def connect(uri, *, It raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the handshake fails. + On Python 3.5, it can be used as a asynchronous context manager. In that + case, the connection is closed when exiting the context. + """ if loop is None: loop = asyncio.get_event_loop() From 4f16e25625bfe67d713658b8d0b19e7b3b1d9121 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 15 Dec 2015 23:49:00 +0100 Subject: [PATCH 0166/1539] Update docs to async / await syntax. Fix #88. --- docs/changelog.rst | 8 +++--- docs/intro.rst | 63 ++++++++++++++++++++++++++++++++------------- example/sendtime.py | 7 +++-- example/server.py | 7 +++-- 4 files changed, 56 insertions(+), 29 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 12ec1584d..6144d35cf 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,14 +17,14 @@ Changelog ``None`` when the connection was closed. This required checking the return value of every call:: - message = yield from websocket.recv() + message = await websocket.recv() if message is None: return Now it raises a :exc:`~websockets.exceptions.ConnectionClosed` exception instead. This is more Pythonic. The previous code can be simplified to:: - message = yield from websocket.recv() + message = await websocket.recv() When implementing a server, which is the more popular use case, there's no strong reason to handle such exceptions. Let them bubble up, terminate the @@ -42,6 +42,8 @@ Also: * :func:`~websockets.client.connect` can be used as an asynchronous context manager on Python ≥ 3.5. +* Updated documentation with ``await`` and ``async`` syntax from Python 3.5. + * :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` supports data passed as :class:`str` in addition to :class:`bytes`. @@ -142,7 +144,7 @@ Also: you must now write:: - yield from websocket.send(message) + await websocket.send(message) Also: diff --git a/docs/intro.rst b/docs/intro.rst index 1146745bc..d2ae4b5d6 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -1,6 +1,11 @@ Getting started =============== +.. warning:: + + This documentation is written for Python ≥ 3.5. If you're using Python 3.4 + or 3.3, you will have to :ref:`adapt the code samples `. + Basic example ------------- @@ -54,11 +59,10 @@ Consumer For receiving messages and passing them to a ``consumer`` coroutine:: - @asyncio.coroutine - def handler(websocket, path): + async def handler(websocket, path): while True: - message = yield from websocket.recv() - yield from consumer(message) + message = await websocket.recv() + await consumer(message) :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` raises a :exc:`~websockets.exceptions.ConnectionClosed` exception when the client @@ -69,11 +73,10 @@ Producer For getting messages from a ``producer`` coroutine and sending them:: - @asyncio.coroutine - def handler(websocket, path): + async def handler(websocket, path): while True: - message = yield from producer() - yield from websocket.send(message) + message = await producer() + await websocket.send(message) :meth:`~websockets.protocol.WebSocketCommonProtocol.send` raises a :exc:`~websockets.exceptions.ConnectionClosed` exception when the client @@ -87,24 +90,23 @@ messages on the same connection. :: - @asyncio.coroutine - def handler(websocket, path): + async def handler(websocket, path): while True: listener_task = asyncio.ensure_future(websocket.recv()) producer_task = asyncio.ensure_future(producer()) - done, pending = yield from asyncio.wait( + done, pending = await asyncio.wait( [listener_task, producer_task], return_when=asyncio.FIRST_COMPLETED) if listener_task in done: message = listener_task.result() - yield from consumer(message) + await consumer(message) else: listener_task.cancel() if producer_task in done: message = producer_task.result() - yield from websocket.send(message) + await websocket.send(message) else: producer_task.cancel() @@ -121,16 +123,14 @@ register clients when they connect and unregister them when they disconnect. connected = set() - @asyncio.coroutine - def handler(websocket, path): + async def handler(websocket, path): global connected # Register. connected.add(websocket) try: # Implement logic here. - yield from asyncio.wait( - [ws.send("Hello!") for ws in connected]) - yield from asyncio.sleep(10) + await asyncio.wait([ws.send("Hello!") for ws in connected]) + await asyncio.sleep(10) finally: # Unregister. connected.remove(websocket) @@ -148,3 +148,30 @@ You don't have to worry about performing the opening or the closing handshake, answering pings, or any other behavior required by the specification. ``websockets`` handles all this under the hood so you don't have to. + +.. _python-lt-35: + +Python < 3.5 +------------ + +This documentation uses the ``await`` and ``async`` syntax introduced in +Python 3.5. + +If you're using Python 3.4 or 3.3, you must substitute:: + + async def ... + +with:: + + @asyncio.coroutine + def ... + +and:: + + await ... + +with:: + + yield from ... + +Otherwise you will encounter a :exc:`SyntaxError`. diff --git a/example/sendtime.py b/example/sendtime.py index 7374bb8ad..2b14827c8 100644 --- a/example/sendtime.py +++ b/example/sendtime.py @@ -5,12 +5,11 @@ import random import websockets -@asyncio.coroutine -def time(websocket, path): +async def time(websocket, path): while True: now = datetime.datetime.utcnow().isoformat() + 'Z' - yield from websocket.send(now) - yield from asyncio.sleep(random.random() * 3) + await websocket.send(now) + await asyncio.sleep(random.random() * 3) start_server = websockets.serve(time, '127.0.0.1', 5678) diff --git a/example/server.py b/example/server.py index 7074c9ab5..cda3323dc 100755 --- a/example/server.py +++ b/example/server.py @@ -3,13 +3,12 @@ import asyncio import websockets -@asyncio.coroutine -def hello(websocket, path): - name = yield from websocket.recv() +async def hello(websocket, path): + name = await websocket.recv() print("< {}".format(name)) greeting = "Hello {}!".format(name) - yield from websocket.send(greeting) + await websocket.send(greeting) print("> {}".format(greeting)) start_server = websockets.serve(hello, 'localhost', 8765) From c7048aa107b23c1038cf306889d29b72a02b34f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Dec 2015 12:08:41 +0100 Subject: [PATCH 0167/1539] Small docs improvements. --- docs/cheatsheet.rst | 2 +- docs/intro.rst | 8 ++++---- example/server.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 28c082bbe..eb2e1dd8c 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -58,7 +58,7 @@ If you don't understand what ``websockets`` is doing, enable logging:: logger.setLevel(logging.INFO) logger.addHandler(logging.StreamHandler()) -The logs contains: +The logs contain: * Exceptions in the connection handler at the ``ERROR`` level * Exceptions in the opening or closing handshake at the ``INFO`` level diff --git a/docs/intro.rst b/docs/intro.rst index d2ae4b5d6..df250ba1c 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -20,14 +20,14 @@ greeting, and closes the connection. .. _client-example: +On the server side, the handler coroutine ``hello`` is executed once for each +WebSocket connection. The connection is automatically closed when the handler +returns. + Here's a corresponding client example. .. literalinclude:: ../example/client.py -On the server side, the handler coroutine ``hello`` is executed once for -each WebSocket connection. The connection is automatically closed when the -handler returns. - ``async`` and ``await`` aren't available in Python < 3.5. Here's how to adapt the client example for older Python versions. diff --git a/example/server.py b/example/server.py index cda3323dc..37744b815 100755 --- a/example/server.py +++ b/example/server.py @@ -6,8 +6,8 @@ async def hello(websocket, path): name = await websocket.recv() print("< {}".format(name)) - greeting = "Hello {}!".format(name) + greeting = "Hello {}!".format(name) await websocket.send(greeting) print("> {}".format(greeting)) From 7d7cac390637021ab9944dc7028a7b2e8f84983f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Dec 2015 12:09:38 +0100 Subject: [PATCH 0168/1539] Bump version number. --- docs/changelog.rst | 2 -- docs/conf.py | 4 ++-- setup.cfg | 2 +- websockets/version.py | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6144d35cf..68d5f5bdb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,8 +4,6 @@ Changelog 3.0 ... -*In development* - .. warning:: **Version 3.0 introduces a backwards-incompatible change in the** diff --git a/docs/conf.py b/docs/conf.py index e8223a507..194e59bf4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '2.7' +version = '3.0' # The full version, including alpha/beta/rc tags. -release = '2.7' +release = '3.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.cfg b/setup.cfg index d1368ad2f..350c3450c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py33.py34 +python-tag = py33.py34.py35 [flake8] ignore = F403 diff --git a/websockets/version.py b/websockets/version.py index 328c2005d..25c840365 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '2.7' +version = '3.0' From 702f2c3aace0c36d911a23c9b9ca0e3f7c84a49d Mon Sep 17 00:00:00 2001 From: Matt Iversen Date: Tue, 5 Jan 2016 23:56:29 +1100 Subject: [PATCH 0169/1539] Add Python 3.5 classifier --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 4f9eaa8cb..f6bc74744 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.3", "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", ], platforms='all', license='BSD' From 992fc3a2794301ca4dbb94f9ce7801a4515e997d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Cardona?= Date: Fri, 12 Feb 2016 15:16:07 +0100 Subject: [PATCH 0170/1539] Fix setup.py errors/warnings on Python < 3.5. The core of the change is to move the py35_*.py files to a subpackage called websockets.py35, which we can then ignore in setup.py if running on Python < 3.5. --- MANIFEST.in | 4 +++- setup.py | 8 +++++--- websockets/client.py | 4 ++-- websockets/py35/__init__.py | 2 ++ websockets/{py35_client.py => py35/client.py} | 0 .../{py35_test_client_server.py => py35/client_server.py} | 6 +++--- websockets/test_client_server.py | 4 ++-- 7 files changed, 17 insertions(+), 11 deletions(-) create mode 100644 websockets/py35/__init__.py rename websockets/{py35_client.py => py35/client.py} (100%) rename websockets/{py35_test_client_server.py => py35/client_server.py} (90%) diff --git a/MANIFEST.in b/MANIFEST.in index cc0d11642..09205fb4b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,3 @@ -include LICENSE \ No newline at end of file +include LICENSE + +graft websockets/py35 diff --git a/setup.py b/setup.py index f6bc74744..06ff24b6c 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,10 @@ if py_version < (3, 3): raise Exception("websockets requires Python >= 3.3.") +packages = ['websockets'] +if py_version >= (3, 5): + packages.append('websockets/py35') + setuptools.setup( name='websockets', version=version, @@ -36,9 +40,7 @@ description=description, long_description=long_description, download_url='https://pypi.python.org/pypi/websockets', - packages=[ - 'websockets', - ], + packages=packages, extras_require={ ':python_version=="3.3"': ['asyncio'], }, diff --git a/websockets/client.py b/websockets/client.py index ea744619c..f67d43776 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -160,8 +160,8 @@ def connect(uri, *, try: - from .py35_client import Connect -except SyntaxError: # pragma: no cover + from .py35.client import Connect +except (SyntaxError, ImportError): # pragma: no cover pass else: Connect.__wrapped__ = connect diff --git a/websockets/py35/__init__.py b/websockets/py35/__init__.py new file mode 100644 index 000000000..9612d9dd7 --- /dev/null +++ b/websockets/py35/__init__.py @@ -0,0 +1,2 @@ +# This package contains code using async / await syntax added in Python 3.5. +# It cannot be imported on Python < 3.5 because it triggers syntax errors. diff --git a/websockets/py35_client.py b/websockets/py35/client.py similarity index 100% rename from websockets/py35_client.py rename to websockets/py35/client.py diff --git a/websockets/py35_test_client_server.py b/websockets/py35/client_server.py similarity index 90% rename from websockets/py35_test_client_server.py rename to websockets/py35/client_server.py index 0b40df9d1..bfafa39c3 100644 --- a/websockets/py35_test_client_server.py +++ b/websockets/py35/client_server.py @@ -3,9 +3,9 @@ import asyncio -from .client import * -from .server import * -from .test_client_server import handler +from ..client import * +from ..server import * +from ..test_client_server import handler class ClientServerContextManager: diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 657723cc6..5e2a78331 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -405,8 +405,8 @@ def test_checking_lack_of_origin_succeeds(self): try: - from .py35_test_client_server import ClientServerContextManager -except SyntaxError: # pragma: no cover + from .py35.client_server import ClientServerContextManager +except (SyntaxError, ImportError): # pragma: no cover pass else: class ClientServerContextManagerTests(ClientServerContextManager, From 6285a989881908579a65a002258f18b8b89ed094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Cardona?= Date: Thu, 24 Mar 2016 18:57:05 +0100 Subject: [PATCH 0171/1539] protocol: make sure the worker task always finishes, closes #102 When a client connection is closed before the initial handshake is done, the "opening_handshake" Future is not finished and thus blocks the "run" coroutine and its "worker" Task. This raises warnings later on when the object is garbage collected. --- websockets/protocol.py | 2 ++ websockets/test_client_server.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/websockets/protocol.py b/websockets/protocol.py index 51cfd622d..28508c30f 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -631,6 +631,8 @@ def eof_received(self): def connection_lost(self, exc): # 7.1.4. The WebSocket Connection is Closed self.state = CLOSED + if not self.opening_handshake.done(): + self.opening_handshake.set_result(False) if not self.closing_handshake.done(): self.close_code, self.close_reason = 1006, '' self.closing_handshake.set_result(False) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 5e2a78331..fe267822c 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -79,6 +79,27 @@ def test_server_close_while_client_connected(self): self.start_client() self.stop_server() + def test_server_client_quick_disconnect(self): + self.start_server() + self.server.unregister = unittest.mock.Mock() + # run a client that opens a socket, waits for the connection to be + # established, and immediately closes the connection without exchanging + # any data. + @asyncio.coroutine + def quick_disconnect(): + kwds = {'ssl': self.client_context} if self.secure else {} + _, writer = yield from asyncio.open_connection( + 'localhost', 8642, loop=self.loop, **kwds) + writer.close() + self.loop.run_until_complete(quick_disconnect()) + # yield to the loop to let the WebSocketServer call its cleanup methods + self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop)) + ws = next(iter(self.server.websockets)) + self.assertTrue(ws.worker.done()) + # do the job of the mocked "unregister" + self.server.websockets.clear() + self.stop_server() + def test_explicit_event_loop(self): self.start_server(loop=self.loop) self.start_client(loop=self.loop) From 55933e036686a8314e6ef39b09d16ea201c39cb9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 30 Mar 2016 23:00:04 +0200 Subject: [PATCH 0172/1539] Simplify test for #102. Also add changelog entry. --- docs/changelog.rst | 7 +++++++ websockets/test_client_server.py | 33 ++++++++++++-------------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 68d5f5bdb..4e3107aa4 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,13 @@ Changelog --------- +3.1 +... + +*In development* + +* Avoided a warning when closing a connection before the opening handshake. + 3.0 ... diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index fe267822c..e11d6c4f4 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -79,27 +79,6 @@ def test_server_close_while_client_connected(self): self.start_client() self.stop_server() - def test_server_client_quick_disconnect(self): - self.start_server() - self.server.unregister = unittest.mock.Mock() - # run a client that opens a socket, waits for the connection to be - # established, and immediately closes the connection without exchanging - # any data. - @asyncio.coroutine - def quick_disconnect(): - kwds = {'ssl': self.client_context} if self.secure else {} - _, writer = yield from asyncio.open_connection( - 'localhost', 8642, loop=self.loop, **kwds) - writer.close() - self.loop.run_until_complete(quick_disconnect()) - # yield to the loop to let the WebSocketServer call its cleanup methods - self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop)) - ws = next(iter(self.server.websockets)) - self.assertTrue(ws.worker.done()) - # do the job of the mocked "unregister" - self.server.websockets.clear() - self.stop_server() - def test_explicit_event_loop(self): self.start_server(loop=self.loop) self.start_client(loop=self.loop) @@ -343,6 +322,18 @@ def test_server_close_crashes(self, close): # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) + @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake') + def test_client_closes_connection_before_handshake(self, handshake): + self.start_server() + self.start_client() + # We have mocked the handshake() method to prevent the client from + # performing the opening handshake. Force it to close the connection. + self.loop.run_until_complete(self.client.close_connection(force=True)) + self.stop_client() + # The server should stop properly anyway. It used to hang because the + # worker handling the connection was waiting for the opening handshake. + self.stop_server() + @unittest.skipUnless(os.path.exists(testcert), "test certificate is missing") class SSLClientServerTests(ClientServerTests): From b3d4e18e562c20967bac59a93549c99b771d90fa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 30 Mar 2016 22:58:21 +0200 Subject: [PATCH 0173/1539] Make tests fail instead of hanging. --- websockets/test_client_server.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index e11d6c4f4..08d141a61 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -59,11 +59,19 @@ def start_client(self, path='', **kwds): self.client = self.loop.run_until_complete(client) def stop_client(self): - self.loop.run_until_complete(self.client.worker) + try: + self.loop.run_until_complete( + asyncio.wait_for(self.client.worker, timeout=1)) + except asyncio.TimeoutError: # pragma: no cover + self.fail("Client failed to stop") def stop_server(self): self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) + try: + self.loop.run_until_complete( + asyncio.wait_for(self.server.wait_closed(), timeout=1)) + except asyncio.TimeoutError: # pragma: no cover + self.fail("Server failed to stop") def test_basic(self): self.start_server() From fbc218ea93288e56dd7b55f69186392844171c5c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 30 Mar 2016 22:28:54 +0200 Subject: [PATCH 0174/1539] Fix assorted flake8 warnings or errors. --- setup.cfg | 2 +- websockets/__init__.py | 14 +++++++------- websockets/client.py | 9 +++++---- websockets/handshake.py | 10 +++++----- websockets/http.py | 4 ++-- websockets/protocol.py | 4 ++-- websockets/server.py | 4 ++-- websockets/uri.py | 4 ++-- 8 files changed, 26 insertions(+), 25 deletions(-) diff --git a/setup.cfg b/setup.cfg index 350c3450c..8e99103f0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ python-tag = py33.py34.py35 [flake8] -ignore = F403 +ignore = E731,F403 [isort] known_standard_library = asyncio diff --git a/websockets/__init__.py b/websockets/__init__.py index 60bc9c5fe..b394c5692 100644 --- a/websockets/__init__.py +++ b/websockets/__init__.py @@ -5,13 +5,13 @@ from .protocol import * from .server import * from .uri import * +from .version import version as __version__ # noqa + __all__ = ( - client.__all__ - + exceptions.__all__ - + protocol.__all__ - + server.__all__ - + uri.__all__ + client.__all__ + + exceptions.__all__ + + protocol.__all__ + + server.__all__ + + uri.__all__ ) - -from .version import version as __version__ # noqa diff --git a/websockets/client.py b/websockets/client.py index f67d43776..12cfa243d 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -3,8 +3,6 @@ """ -__all__ = ['connect', 'WebSocketClientProtocol'] - import asyncio import collections.abc import email.message @@ -16,6 +14,9 @@ from .uri import parse_uri +__all__ = ['connect', 'WebSocketClientProtocol'] + + class WebSocketClientProtocol(WebSocketCommonProtocol): """ Complete WebSocket client implementation as an :class:`asyncio.Protocol`. @@ -88,8 +89,8 @@ def handshake(self, wsuri, check_response(get_header, key) self.subprotocol = headers.get('Sec-WebSocket-Protocol', None) - if (self.subprotocol is not None - and self.subprotocol not in subprotocols): + if (self.subprotocol is not None and + self.subprotocol not in subprotocols): raise InvalidHandshake( "Unknown subprotocol: {}".format(self.subprotocol)) diff --git a/websockets/handshake.py b/websockets/handshake.py index e2ed644d9..0b99242c9 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -34,11 +34,6 @@ """ -__all__ = [ - 'build_request', 'check_request', - 'build_response', 'check_response', -] - import base64 import hashlib import random @@ -46,6 +41,11 @@ from .exceptions import InvalidHandshake +__all__ = [ + 'build_request', 'check_request', + 'build_response', 'check_response', +] + GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" diff --git a/websockets/http.py b/websockets/http.py index 1452d0a36..561e79803 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -7,8 +7,6 @@ """ -__all__ = ['read_request', 'read_response', 'USER_AGENT'] - import asyncio import email.parser import io @@ -17,6 +15,8 @@ from .version import version as websockets_version +__all__ = ['read_request', 'read_response', 'USER_AGENT'] + MAX_HEADERS = 256 MAX_LINE = 4096 diff --git a/websockets/protocol.py b/websockets/protocol.py index 28508c30f..99f801bf0 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -6,8 +6,6 @@ """ -__all__ = ['WebSocketCommonProtocol'] - import asyncio import asyncio.queues import codecs @@ -23,6 +21,8 @@ from .handshake import * +__all__ = ['WebSocketCommonProtocol'] + logger = logging.getLogger(__name__) diff --git a/websockets/server.py b/websockets/server.py index 763d6ff7c..c81fa3f6b 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -3,8 +3,6 @@ """ -__all__ = ['serve', 'WebSocketServerProtocol'] - import asyncio import collections.abc import email.message @@ -17,6 +15,8 @@ from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol +__all__ = ['serve', 'WebSocketServerProtocol'] + logger = logging.getLogger(__name__) diff --git a/websockets/uri.py b/websockets/uri.py index 6e27a7cc3..48c39c1a5 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -6,14 +6,14 @@ """ -__all__ = ['parse_uri', 'WebSocketURI'] - import collections import urllib.parse from .exceptions import InvalidURI +__all__ = ['parse_uri', 'WebSocketURI'] + WebSocketURI = collections.namedtuple( 'WebSocketURI', ('secure', 'host', 'port', 'resource_name')) WebSocketURI.__doc__ = """WebSocket URI. From fd22f008815e5e6a5dcd8cae876b172b7af77b91 Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Thu, 7 Apr 2016 20:15:06 -0400 Subject: [PATCH 0175/1539] Add flow control for reading from a websocket (#105) Add a configurable size limit on the reader's message queue. When the queue is full, the reader will block. (A well behaved server will slow down transmission to avoid packet loss.) --- websockets/protocol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 99f801bf0..0f4f05fde 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -87,8 +87,8 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): def __init__(self, *, host=None, port=None, secure=None, - timeout=10, max_size=2 ** 20, loop=None, - legacy_recv=False): + timeout=10, max_size=2 ** 20, max_queue=2 ** 10, + loop=None, legacy_recv=False): self.host = host self.port = port self.secure = secure @@ -128,7 +128,7 @@ def __init__(self, *, self.connection_closed = asyncio.Future(loop=loop) # Queue of received messages. - self.messages = asyncio.queues.Queue(loop=loop) + self.messages = asyncio.queues.Queue(loop=loop, maxsize=max_queue) # Mapping of ping IDs to waiters, in chronological order. self.pings = collections.OrderedDict() @@ -395,7 +395,7 @@ def run(self): msg = yield from self.read_message() if msg is None: break - self.messages.put_nowait(msg) + yield from self.messages.put(msg) except asyncio.CancelledError: break except WebSocketProtocolError: From 8afa9d7e5d1709603eaeb53263a9ab95d91a25b2 Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Fri, 8 Apr 2016 01:25:11 -0400 Subject: [PATCH 0176/1539] Update documentation to include max_queue parameter (#105) --- docs/changelog.rst | 2 ++ websockets/protocol.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 4e3107aa4..bf8bb44a5 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,8 @@ Changelog *In development* * Avoided a warning when closing a connection before the opening handshake. +* Add flow control when reading from a websocket. (The previous flow control + implementation only affected *writes* to a websocket.) 3.0 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index 0f4f05fde..ba7cb269f 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -62,6 +62,13 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): raise :exc:`~websockets.exceptions.ConnectionClosed` and the connection will be closed with status code 1009. + The ``max_queue`` parameter sets the maximum size for the incoming message + queue. The default value is 1024. ``0`` (zero) disables the limit. When the + queue is full, no more messages will be read from the websocket. In this + full condition, the system's receive buffer will being to fill and the TCP + receive window will shrink. A well-behaved peer will slow down transmission + in order to avoid packet loss. + Once the handshake is complete, request and response HTTP headers are available: From ad401af46f29592b957ce5b25ff721c856193105 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 21 Apr 2016 22:15:32 +0200 Subject: [PATCH 0177/1539] Tweak the default setting and docs for flow control. --- docs/changelog.rst | 6 +++--- websockets/protocol.py | 25 +++++++++++++++++-------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index bf8bb44a5..dc1bc5c8c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,8 +7,8 @@ Changelog *In development* * Avoided a warning when closing a connection before the opening handshake. -* Add flow control when reading from a websocket. (The previous flow control - implementation only affected *writes* to a websocket.) + +* Added flow control for incoming data. 3.0 ... @@ -155,7 +155,7 @@ Also: Also: -* Added flow control. +* Added flow control for outgoing data. 1.0 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index ba7cb269f..4b5cbcd7e 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -62,12 +62,21 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): raise :exc:`~websockets.exceptions.ConnectionClosed` and the connection will be closed with status code 1009. - The ``max_queue`` parameter sets the maximum size for the incoming message - queue. The default value is 1024. ``0`` (zero) disables the limit. When the - queue is full, no more messages will be read from the websocket. In this - full condition, the system's receive buffer will being to fill and the TCP - receive window will shrink. A well-behaved peer will slow down transmission - in order to avoid packet loss. + The ``max_queue`` parameter sets the maximum length of the queue that holds + incoming messages. The default value is 32. 0 disables the limit. Messages + are added to an in-memory queue when they're received; then :meth:`recv()` + pops from that queue. In order to prevent excessive memory consumption when + messages are received faster than they can be processed, the queue must be + bounded. If the queue fills up, the protocol stops processing incoming data + until :meth:`recv()` is called. In this situation, various receive buffers + (at least in ``asyncio`` and in the OS) will fill up, then the TCP receive + window will shrink, slowing down transmission to avoid packet loss. + + Since Python can use up to 4 bytes of memory to represent a single + character, each websocket connection may use up to ``4 * max_size * + max_queue`` bytes of memory to store incoming messages. By default, + this is 128MB. You may want to lower the limits, depending on your + application's requirements. Once the handshake is complete, request and response HTTP headers are available: @@ -94,7 +103,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): def __init__(self, *, host=None, port=None, secure=None, - timeout=10, max_size=2 ** 20, max_queue=2 ** 10, + timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, legacy_recv=False): self.host = host self.port = port @@ -135,7 +144,7 @@ def __init__(self, *, self.connection_closed = asyncio.Future(loop=loop) # Queue of received messages. - self.messages = asyncio.queues.Queue(loop=loop, maxsize=max_queue) + self.messages = asyncio.queues.Queue(max_queue, loop=loop) # Mapping of ping IDs to waiters, in chronological order. self.pings = collections.OrderedDict() From 817a14fd3b85c0bdc4449bfe792098eaa2a45bfe Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 21 Apr 2016 22:20:10 +0200 Subject: [PATCH 0178/1539] Bump version number. --- docs/changelog.rst | 2 -- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index dc1bc5c8c..90024a010 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,8 +4,6 @@ Changelog 3.1 ... -*In development* - * Avoided a warning when closing a connection before the opening handshake. * Added flow control for incoming data. diff --git a/docs/conf.py b/docs/conf.py index 194e59bf4..f5020a773 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '3.0' +version = '3.1' # The full version, including alpha/beta/rc tags. -release = '3.0' +release = '3.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 25c840365..8d35e0a76 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '3.0' +version = '3.1' From 1b4f0204601da06289b5c0bb3898205752c93e68 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 28 Apr 2016 12:15:54 +0200 Subject: [PATCH 0179/1539] Update link to Read the Docs. --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 01f357cf5..5a383721b 100644 --- a/README.rst +++ b/README.rst @@ -23,6 +23,6 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _RFC 6455: http://tools.ietf.org/html/rfc6455 .. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst .. _PEP 3156: http://www.python.org/dev/peps/pep-3156/ -.. _Read the Docs: https://websockets.readthedocs.org/ +.. _Read the Docs: https://websockets.readthedocs.io/ .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ From 9729e63d4c85c699222dab482de334a3621d76f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 19 May 2016 22:09:53 +0200 Subject: [PATCH 0180/1539] Add kwargs to connect() and serve(). Reorganize their docstrings for clarity. Refs #112. --- docs/api.rst | 8 ++++---- docs/changelog.rst | 3 +++ websockets/client.py | 34 ++++++++++++++++++++++------------ websockets/server.py | 41 ++++++++++++++++++++++++++--------------- 4 files changed, 55 insertions(+), 31 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 0ead249af..779ab5bf6 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,9 +32,9 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, loop=None, klass=WebSocketServerProtocol, origins=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, origins=None, subprotocols=None, extra_headers=None, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, origins=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) .. automethod:: select_subprotocol(client_protos, server_protos) @@ -44,9 +44,9 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, loop=None, klass=WebSocketClientProtocol, origin=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds) - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None) .. automethod:: handshake(wsuri, origin=None, subprotocols=None, extra_headers=None) diff --git a/docs/changelog.rst b/docs/changelog.rst index 90024a010..44c2cee10 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,9 @@ Changelog 3.1 ... +* Added ``timeout``, ``max_size``, and ``max_queue`` arguments to + :func:`~websockets.client.connect()` and :func:`~websockets.server.serve()`. + * Avoided a warning when closing a connection before the opening handshake. * Added flow control for incoming data. diff --git a/websockets/client.py b/websockets/client.py index 12cfa243d..9044c3e49 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -101,21 +101,31 @@ def handshake(self, wsuri, @asyncio.coroutine def connect(uri, *, - loop=None, klass=WebSocketClientProtocol, legacy_recv=False, + klass=WebSocketClientProtocol, + timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + loop=None, legacy_recv=False, origin=None, subprotocols=None, extra_headers=None, **kwds): """ - This coroutine connects to a WebSocket server. + This coroutine connects to a WebSocket server at a given ``uri``. - It's a wrapper around the event loop's + It yields a :class:`WebSocketClientProtocol` which can then be used to + send and receive messages. + + :func:`connect` is a wrapper around the event loop's :meth:`~asyncio.BaseEventLoop.create_connection` method. Extra keyword arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. + For example, you can set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to a ``wss://`` URI, if this argument isn't provided explicitly, it's set to ``True``, which means Python's default :class:`~ssl.SSLContext` is used. - :func:`connect` accepts several optional arguments: + The behavior of the ``timeout``, ``max_size``, and ``max_queue`` optional + arguments is described the documentation of + :class:`~websockets.protocol.WebSocketCommonProtocol`. + + :func:`connect` also accepts the following optional arguments: * ``origin`` sets the Origin HTTP header * ``subprotocols`` is a list of supported subprotocols in order of @@ -123,14 +133,12 @@ def connect(uri, *, * ``extra_headers`` sets additional HTTP request headers – it can be a mapping or an iterable of (name, value) pairs - :func:`connect` yields a :class:`WebSocketClientProtocol` which can then - be used to send and receive messages. - - It raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and - :exc:`~websockets.handshake.InvalidHandshake` if the handshake fails. + :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is + invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening + handshake fails. - On Python 3.5, it can be used as a asynchronous context manager. In that - case, the connection is closed when exiting the context. + On Python 3.5, :func:`connect` can be used as a asynchronous context + manager. In that case, the connection is closed when exiting the context. """ if loop is None: @@ -144,7 +152,9 @@ def connect(uri, *, "Use a wss:// URI to enable TLS.") factory = lambda: klass( host=wsuri.host, port=wsuri.port, secure=wsuri.secure, - loop=loop, legacy_recv=legacy_recv) + timeout=timeout, max_size=max_size, max_queue=max_queue, + loop=loop, legacy_recv=legacy_recv, + ) transport, protocol = yield from loop.create_connection( factory, wsuri.host, wsuri.port, **kwds) diff --git a/websockets/server.py b/websockets/server.py index c81fa3f6b..74ab08492 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -245,22 +245,37 @@ def wait_closed(self): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - loop=None, klass=WebSocketServerProtocol, legacy_recv=False, + klass=WebSocketServerProtocol, + timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + loop=None, legacy_recv=False, origins=None, subprotocols=None, extra_headers=None, **kwds): """ This coroutine creates a WebSocket server. - It's a wrapper around the event loop's - :meth:`~asyncio.BaseEventLoop.create_server` method. ``host``, ``port`` as - well as extra keyword arguments are passed to - :meth:`~asyncio.BaseEventLoop.create_server`. For example, you can set the - ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. + It yields a :class:`~asyncio.Server` which provides: + + * a :meth:`~asyncio.Server.close` method that closes open connections with + status code 1001 and stops accepting new connections + * a :meth:`~asyncio.Server.wait_closed` coroutine that waits until closing + handshakes complete and connections are closed. ``ws_handler`` is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`WebSocketServerProtocol` and the request URI. - :func:`serve` accepts several optional arguments: + :func:`serve` is a wrapper around the event loop's + :meth:`~asyncio.BaseEventLoop.create_server` method. ``host``, ``port`` as + well as extra keyword arguments are passed to + :meth:`~asyncio.BaseEventLoop.create_server`. + + For example, you can set the ``ssl`` keyword argument to a + :class:`~ssl.SSLContext` to enable TLS. + + The behavior of the ``timeout``, ``max_size``, and ``max_queue`` optional + arguments is described the documentation of + :class:`~websockets.protocol.WebSocketCommonProtocol`. + + :func:`serve` also accepts the following optional arguments: * ``origins`` defines acceptable Origin HTTP headers — include ``''`` if the lack of an origin is acceptable @@ -270,13 +285,6 @@ def serve(ws_handler, host=None, port=None, *, mapping, an iterable of (name, value) pairs, or a callable taking the request path and headers in arguments. - :func:`serve` yields a :class:`~asyncio.Server` which provides: - - * a :meth:`~asyncio.Server.close` method that closes open connections with - status code 1001 and stops accepting new connections - * a :meth:`~asyncio.Server.wait_closed` coroutine that waits until closing - handshakes complete and connections are closed. - Whenever a client connects, the server accepts the connection, creates a :class:`WebSocketServerProtocol`, performs the opening handshake, and delegates to the WebSocket handler. Once the handler completes, the server @@ -301,8 +309,11 @@ def serve(ws_handler, host=None, port=None, *, factory = lambda: klass( ws_handler, ws_server, host=host, port=port, secure=secure, + timeout=timeout, max_size=max_size, max_queue=max_queue, + loop=loop, legacy_recv=legacy_recv, origins=origins, subprotocols=subprotocols, - extra_headers=extra_headers, loop=loop, legacy_recv=legacy_recv) + extra_headers=extra_headers, + ) server = yield from loop.create_server(factory, host, port, **kwds) ws_server.wrap(server) From 58cb396c05efb03d5bf66592607c3e0dd4918c5d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 4 Jul 2016 11:24:48 +0200 Subject: [PATCH 0181/1539] Rectify changelog. Thanks @RemiCardona for the heads-up. --- docs/changelog.rst | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 44c2cee10..4e6e936a3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,12 +1,17 @@ Changelog --------- -3.1 +3.2 ... +*In development* + * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to :func:`~websockets.client.connect()` and :func:`~websockets.server.serve()`. +3.1 +... + * Avoided a warning when closing a connection before the opening handshake. * Added flow control for incoming data. From a05fe09ca8f40f00be03b12a378ece0c5e1d6809 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 17 Aug 2016 16:01:08 +0200 Subject: [PATCH 0182/1539] Rename worker to worker_task. This increases consistency with handler_task. --- compliance/test_client.py | 2 +- compliance/test_server.py | 2 +- websockets/protocol.py | 14 +++++++------- websockets/test_client_server.py | 2 +- websockets/test_protocol.py | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/compliance/test_client.py b/compliance/test_client.py index 3804e9f92..48898302c 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -43,7 +43,7 @@ def get_case_count(server): def run_case(server, case, agent): uri = server + '/runCase?case={}&agent={}'.format(case, agent) ws = yield from websockets.connect(uri, klass=EchoClientProtocol) - yield from ws.worker + yield from ws.worker_task @asyncio.coroutine diff --git a/compliance/test_server.py b/compliance/test_server.py index 46e48128f..7c29f9595 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -27,7 +27,7 @@ def read_message(self): @asyncio.coroutine def noop(ws, path): - yield from ws.worker + yield from ws.worker_task start_server = websockets.serve(noop, '127.0.0.1', 8642, klass=EchoServerProtocol) diff --git a/websockets/protocol.py b/websockets/protocol.py index 4b5cbcd7e..b397f81a7 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -150,7 +150,7 @@ def __init__(self, *, self.pings = collections.OrderedDict() # Task managing the connection, initalized in self.client_connected. - self.worker = None + self.worker_task = None # In a subclass implementing the opening handshake, the state will be # CONNECTING at this point. @@ -236,12 +236,12 @@ def close(self, code=1000, reason=''): # the worker loop. try: yield from asyncio.wait_for( - self.worker, self.timeout, loop=self.loop) + self.worker_task, self.timeout, loop=self.loop) except asyncio.TimeoutError: - self.worker.cancel() + self.worker_task.cancel() # The worker should terminate quickly once it has been cancelled. - yield from self.worker + yield from self.worker_task @asyncio.coroutine def recv(self): @@ -278,7 +278,7 @@ def recv(self): self.messages.get(), loop=self.loop) try: done, pending = yield from asyncio.wait( - [next_message, self.worker], + [next_message, self.worker_task], loop=self.loop, return_when=asyncio.FIRST_COMPLETED) except asyncio.CancelledError: # Handle the Task.cancel() @@ -395,7 +395,7 @@ def ensure_open(self): # longer than the worst case (2 * self.timeout) but not unlimited. if self.state == CLOSING: yield from asyncio.wait_for( - self.worker, 3 * self.timeout, loop=self.loop) + self.worker_task, 3 * self.timeout, loop=self.loop) raise ConnectionClosed(self.close_code, self.close_reason) # Control may only reach this point in buggy third-party subclasses. @@ -628,7 +628,7 @@ def client_connected(self, reader, writer): self.reader = reader self.writer = writer # Start the task that handles incoming messages. - self.worker = asyncio_ensure_future(self.run(), loop=self.loop) + self.worker_task = asyncio_ensure_future(self.run(), loop=self.loop) def eof_received(self): super().eof_received() diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 08d141a61..3678bfefc 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -61,7 +61,7 @@ def start_client(self, path='', **kwds): def stop_client(self): try: self.loop.run_until_complete( - asyncio.wait_for(self.client.worker, timeout=1)) + asyncio.wait_for(self.client.worker_task, timeout=1)) except asyncio.TimeoutError: # pragma: no cover self.fail("Client failed to stop") diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 53d67223b..2e256c662 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -345,7 +345,7 @@ def read_message(): self.protocol.read_message = read_message self.process_invalid_frames() with self.assertRaises(Exception): - self.loop.run_until_complete(self.protocol.worker) + self.loop.run_until_complete(self.protocol.worker_task) self.assertConnectionClosed(1011, '') def test_recv_cancelled(self): From 22876eaf44d86cb218caae8681edd9eb9ae3847b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 17 Aug 2016 16:01:59 +0200 Subject: [PATCH 0183/1539] Ensure proper references to the event loop. --- websockets/protocol.py | 2 ++ websockets/server.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index b397f81a7..bfc9cd775 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -113,6 +113,8 @@ def __init__(self, *, self.max_size = max_size # Store a reference to loop to avoid relying on self._loop, a private # attribute of StreamReaderProtocol, inherited from FlowControlMixin. + if loop is None: + loop = asyncio.get_event_loop() self.loop = loop self.legacy_recv = legacy_recv diff --git a/websockets/server.py b/websockets/server.py index 74ab08492..c7b3ac3f6 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -193,7 +193,7 @@ class WebSocketServer(asyncio.AbstractServer): Wrapper for :class:`~asyncio.Server` that triggers the closing handshake. """ - def __init__(self, loop=None): + def __init__(self, loop): # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop @@ -303,7 +303,7 @@ def serve(ws_handler, host=None, port=None, *, if loop is None: loop = asyncio.get_event_loop() - ws_server = WebSocketServer() + ws_server = WebSocketServer(loop) secure = kwds.get('ssl') is not None factory = lambda: klass( From 181acf3e29499ecff61b5e6892d73bae5332292e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2016 15:19:10 +0200 Subject: [PATCH 0184/1539] Improve server closing sequence. Previously, the server wouldn't close correctly when handlers were waiting on something other than the websocket connection. --- docs/changelog.rst | 2 ++ websockets/client.py | 2 +- websockets/server.py | 61 ++++++++++++++++++++++++++------ websockets/test_client_server.py | 26 ++++++++++++++ 4 files changed, 79 insertions(+), 12 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 4e6e936a3..0ad2f3c2f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -9,6 +9,8 @@ Changelog * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to :func:`~websockets.client.connect()` and :func:`~websockets.server.serve()`. +* Made server shutdown more robust. + 3.1 ... diff --git a/websockets/client.py b/websockets/client.py index 9044c3e49..acebce332 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -77,7 +77,7 @@ def handshake(self, wsuri, # Read handshake response. try: status_code, headers = yield from read_response(self.reader) - except Exception as exc: + except ValueError as exc: raise InvalidHandshake("Malformed HTTP message") from exc if status_code != 101: raise InvalidHandshake("Bad status code: {}".format(status_code)) diff --git a/websockets/server.py b/websockets/server.py index c7b3ac3f6..dab14796c 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -63,12 +63,15 @@ def handler(self): origins=self.origins, subprotocols=self.subprotocols, extra_headers=self.extra_headers) except Exception as exc: - logger.info("Exception in opening handshake: {}".format(exc)) - if isinstance(exc, InvalidOrigin): + if self._is_server_shutting_down(exc): + response = ('HTTP/1.1 503 Service Unavailable\r\n\r\n' + 'Server is shutting down.') + elif isinstance(exc, InvalidOrigin): response = 'HTTP/1.1 403 Forbidden\r\n\r\n' + str(exc) elif isinstance(exc, InvalidHandshake): response = 'HTTP/1.1 400 Bad Request\r\n\r\n' + str(exc) else: + logger.warning("Error in opening handshake", exc_info=True) response = ('HTTP/1.1 500 Internal Server Error\r\n\r\n' 'See server log for more information.') self.writer.write(response.encode()) @@ -76,15 +79,21 @@ def handler(self): try: yield from self.ws_handler(self, path) - except Exception: - logger.error("Exception in connection handler", exc_info=True) - yield from self.fail_connection(1011) + except Exception as exc: + if self._is_server_shutting_down(exc): + yield from self.fail_connection(1001) + else: + logger.error("Error in connection handler", exc_info=True) + yield from self.fail_connection(1011) raise try: yield from self.close() except Exception as exc: - logger.info("Exception in closing handshake: {}".format(exc)) + if self._is_server_shutting_down(exc): + pass + else: + logger.warning("Error in closing handshake", exc_info=True) raise except Exception: @@ -101,6 +110,16 @@ def handler(self): # connections before terminating. self.ws_server.unregister(self) + def _is_server_shutting_down(self, exc): + """ + Decide whether an exception means that the server is shutting down. + + """ + return ( + isinstance(exc, asyncio.CancelledError) and + self.ws_server.closing + ) + @asyncio.coroutine def handshake(self, origins=None, subprotocols=None, extra_headers=None): """ @@ -122,7 +141,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): # Read handshake request. try: path, headers = yield from read_request(self.reader) - except Exception as exc: + except ValueError as exc: raise InvalidHandshake("Malformed HTTP message") from exc self.request_headers = headers @@ -197,6 +216,7 @@ def __init__(self, loop): # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop + self.closing = False self.websockets = set() def wrap(self, server): @@ -223,23 +243,42 @@ def unregister(self, protocol): def close(self): """ - Stop serving and trigger a closing handshake on open connections. + Stop accepting new connections and close open connections. """ - for websocket in self.websockets: - asyncio_ensure_future(websocket.close(1001), loop=self.loop) + # Make a note that the server is shutting down. Websocket connections + # check this attribute to decide to send a "going away" close code. + self.closing = True + + # Stop accepting new connections. self.server.close() + # Close open connections. For each connection, two tasks are running: + # 1. self.worker_task shuffles messages between the network and queues + # 2. self.handler_task runs the opening handshake, the handler provided + # by the user and the closing handshake + # In the general case, cancelling the handler task will cause the + # handler provided by the user to exit with a CancelledError, which + # will then cause the worker task to terminate. + for websocket in self.websockets: + websocket.handler_task.cancel() + @asyncio.coroutine def wait_closed(self): """ Wait until all connections are closed. + This method must be called after :meth:`close()`. + """ # asyncio.wait doesn't accept an empty first argument. if self.websockets: + # The handler or the worker task can terminate first, depending + # on how the client behaves and the server is implemented. yield from asyncio.wait( - [ws.handler_task for ws in self.websockets], loop=self.loop) + [websocket.handler_task for websocket in self.websockets] + + [websocket.worker_task for websocket in self.websockets], + loop=self.loop) yield from self.server.wait_closed() diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 3678bfefc..948c631d3 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -342,6 +342,32 @@ def test_client_closes_connection_before_handshake(self, handshake): # worker handling the connection was waiting for the opening handshake. self.stop_server() + @unittest.mock.patch('websockets.server.read_request') + def test_server_shuts_down_during_opening_handshake(self, _read_request): + _read_request.side_effect = asyncio.CancelledError + + self.start_server() + self.server.closing = True + with self.assertRaises(InvalidHandshake) as raised: + self.start_client() + self.stop_server() + + # Opening handshake fails with 503 Service Unavailable + self.assertEqual(str(raised.exception), "Bad status code: 503") + + def test_server_shuts_down_during_connection_handling(self): + self.start_server() + self.start_client() + + self.server.close() + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) + self.stop_client() + self.stop_server() + + # Websocket connection terminates with 1001 Going Away. + self.assertEqual(self.client.close_code, 1001) + @unittest.skipUnless(os.path.exists(testcert), "test certificate is missing") class SSLClientServerTests(ClientServerTests): From 6e728f0357db864f5d1f6495ca9ab27c47c636af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 17 Aug 2016 23:22:45 +0200 Subject: [PATCH 0185/1539] Bump version number. --- docs/changelog.rst | 5 ++++- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 0ad2f3c2f..bdb9dae81 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,11 +1,14 @@ Changelog --------- -3.2 +3.3 ... *In development* +3.2 +... + * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to :func:`~websockets.client.connect()` and :func:`~websockets.server.serve()`. diff --git a/docs/conf.py b/docs/conf.py index f5020a773..aabe2a92e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '3.1' +version = '3.2' # The full version, including alpha/beta/rc tags. -release = '3.1' +release = '3.2' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 8d35e0a76..7970055a9 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '3.1' +version = '3.2' From 3a1174a8dff212c8d1a41f1e1afa91cf7dcea877 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Sep 2016 18:54:53 +0200 Subject: [PATCH 0186/1539] Simplify setup.py according to PyPA guidelines. --- setup.py | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/setup.py b/setup.py index 06ff24b6c..fecc1f6fc 100644 --- a/setup.py +++ b/setup.py @@ -1,25 +1,16 @@ -import os +import os.path import sys import setuptools - -# Avoid polluting the .tar.gz with ._* files under Mac OS X -os.putenv('COPYFILE_DISABLE', 'true') - -root = os.path.dirname(__file__) - -# Prevent distutils from complaining that a standard file wasn't found -README = os.path.join(root, 'README') -if not os.path.exists(README): - os.symlink(README + '.rst', README) +root_dir = os.path.abspath(os.path.dirname(__file__)) description = "An implementation of the WebSocket Protocol (RFC 6455)" -with open(os.path.join(root, 'README'), encoding='utf-8') as f: - long_description = '\n\n'.join(f.read().split('\n\n')[1:]) +with open(os.path.join(root_dir, 'README.rst')) as f: + long_description = f.read() -with open(os.path.join(root, 'websockets', 'version.py'), encoding='utf-8') as f: +with open(os.path.join(root_dir, 'websockets', 'version.py')) as f: exec(f.read()) py_version = sys.version_info[:2] @@ -28,22 +19,19 @@ raise Exception("websockets requires Python >= 3.3.") packages = ['websockets'] + if py_version >= (3, 5): packages.append('websockets/py35') setuptools.setup( name='websockets', version=version, - author='Aymeric Augustin', - author_email='aymeric.augustin@m4x.org', - url='https://github.com/aaugustin/websockets', description=description, long_description=long_description, - download_url='https://pypi.python.org/pypi/websockets', - packages=packages, - extras_require={ - ':python_version=="3.3"': ['asyncio'], - }, + url='https://github.com/aaugustin/websockets', + author='Aymeric Augustin', + author_email='aymeric.augustin@m4x.org', + license='BSD', classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", @@ -56,6 +44,8 @@ "Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.5", ], - platforms='all', - license='BSD' + packages=packages, + extras_require={ + ':python_version=="3.3"': ['asyncio'], + }, ) From 485738f2f324a86aa7e54feb1334ca0d2201d6ae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Sep 2016 19:15:27 +0200 Subject: [PATCH 0187/1539] Uniformize quotes. --- setup.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index fecc1f6fc..57ced1764 100644 --- a/setup.py +++ b/setup.py @@ -33,16 +33,16 @@ author_email='aymeric.augustin@m4x.org', license='BSD', classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.3", - "Programming Language :: Python :: 3.4", - "Programming Language :: Python :: 3.5", + 'Development Status :: 5 - Production/Stable', + 'Environment :: Web Environment', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: BSD License', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', ], packages=packages, extras_require={ From 858b261d7d5365128e4f302690892d16fa93e34c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Sep 2016 20:51:51 +0200 Subject: [PATCH 0188/1539] Clean up .gitignore. --- .gitignore | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitignore b/.gitignore index f453c4914..1a6a602c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ *.pyc .coverage -.DS_Store .tox build/ compliance/reports/ @@ -8,6 +7,4 @@ dist/ docs/_build/ htmlcov/ MANIFEST -README -README.html websockets.egg-info/ From c211a3a1fb29627ef653fff45cbecb17df0a6eb1 Mon Sep 17 00:00:00 2001 From: Lennart Grahl Date: Wed, 28 Sep 2016 18:42:00 +0200 Subject: [PATCH 0189/1539] Make select_subprotocol a static method --- websockets/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/websockets/server.py b/websockets/server.py index dab14796c..c9dc5b47e 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -195,7 +195,8 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): return path - def select_subprotocol(self, client_protos, server_protos): + @staticmethod + def select_subprotocol(client_protos, server_protos): """ Pick a subprotocol among those offered by the client. From c25954538aab85c259dc37e8235c1362d9543146 Mon Sep 17 00:00:00 2001 From: Mircea Baja Date: Thu, 24 Nov 2016 08:08:26 +0000 Subject: [PATCH 0190/1539] Improve getting started for -both- case (#144) * Improve getting started for -both- case * Minor tweaks. --- docs/intro.rst | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index df250ba1c..fcfbeb924 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -90,28 +90,26 @@ messages on the same connection. :: - async def handler(websocket, path): + async def consumer_handler(websocket): + while True: + message = await websocket.recv() + await consumer(message) + + async def producer_handler(websocket): while True: - listener_task = asyncio.ensure_future(websocket.recv()) - producer_task = asyncio.ensure_future(producer()) - done, pending = await asyncio.wait( - [listener_task, producer_task], - return_when=asyncio.FIRST_COMPLETED) - - if listener_task in done: - message = listener_task.result() - await consumer(message) - else: - listener_task.cancel() - - if producer_task in done: - message = producer_task.result() - await websocket.send(message) - else: - producer_task.cancel() - -(This code looks convoluted. If you know a more straightforward solution, -please let me know about it!) + message = await producer() + await websocket.send(message) + + async def handler(websocket, path): + consumer_task = asyncio.ensure_future(consumer_handler(websocket)) + producer_task = asyncio.ensure_future(producer_handler(websocket)) + done, pending = await asyncio.wait( + [consumer_task, producer_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + task.cancel() Registration ............ From 6b1db40e93fbedd8a7c4134bb0c82d29e651fcbf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Dec 2016 11:04:09 +0100 Subject: [PATCH 0191/1539] Fix tests on Python 3.5.2. Thanks Julien Enselme for the investigation. --- websockets/test_protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 2e256c662..b4ad4d2c5 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -258,7 +258,7 @@ def test_local_address(self): self.run_loop_once() # The connection is established. self.assertEqual(self.protocol.local_address, ('host', 4312)) - get_extra_info.assert_called_once_with('sockname', None) + get_extra_info.assert_called_with('sockname', None) def test_remote_address(self): get_extra_info = unittest.mock.Mock(return_value=('host', 4312)) @@ -268,7 +268,7 @@ def test_remote_address(self): self.run_loop_once() # The connection is established. self.assertEqual(self.protocol.remote_address, ('host', 4312)) - get_extra_info.assert_called_once_with('peername', None) + get_extra_info.assert_called_with('peername', None) def test_open(self): self.assertTrue(self.protocol.open) From d9f785327d53b76103b019ca4077138294d4e95c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Dec 2016 11:04:45 +0100 Subject: [PATCH 0192/1539] Ignore star imports in linting. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 8e99103f0..0770d01f5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ python-tag = py33.py34.py35 [flake8] -ignore = E731,F403 +ignore = E731,F403,F405 [isort] known_standard_library = asyncio From f1d5e1e8da1b206b300700eb14e0a801ec5aad4f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Dec 2016 11:15:04 +0100 Subject: [PATCH 0193/1539] Setup Travis CI. --- .travis.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..04677c73f --- /dev/null +++ b/.travis.yml @@ -0,0 +1,4 @@ +language: python +python: "3.5" +install: pip install tox +script: tox From 34ca55c368774d486a767ae307a817800f011722 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Dec 2016 11:21:07 +0100 Subject: [PATCH 0194/1539] Don't email me Travis. --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index 04677c73f..250ef4b9f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,3 +2,5 @@ language: python python: "3.5" install: pip install tox script: tox +notifications: + email: false From fd54f619769a5a07ad85d18152d3c1bf39dd5936 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jan 2017 12:48:49 +0100 Subject: [PATCH 0195/1539] Add Python 3.6 to supported versions. --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 4955cb38c..4c1e73138 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py33,py34,py35,coverage,flake8,isort +envlist = py33,py34,py35,py36,coverage,flake8,isort [testenv] deps = From 34c4031db2355c20785d0a7d50b5aa3ef5ccb90e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jan 2017 16:56:30 +0100 Subject: [PATCH 0196/1539] Fix regression from adf6fe4b. --- websockets/test_client_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 948c631d3..d63415645 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -234,8 +234,7 @@ def test_subprotocol_not_requested(self): self.stop_client() self.stop_server() - @unittest.mock.patch.object( - WebSocketServerProtocol, 'select_subprotocol', autospec=True) + @unittest.mock.patch.object(WebSocketServerProtocol, 'select_subprotocol') def test_subprotocol_error(self, _select_subprotocol): _select_subprotocol.return_value = 'superchat' From 2e7fd48e19731aa1fe8e5af57825b5a7f43a972c Mon Sep 17 00:00:00 2001 From: Niklas Keller Date: Mon, 6 Feb 2017 08:06:23 +0100 Subject: [PATCH 0197/1539] Fix read_response when reason is empty --- websockets/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/http.py b/websockets/http.py index 561e79803..81f22a824 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -63,7 +63,7 @@ def read_response(stream): """ status_line, headers = yield from read_message(stream) - version, status, reason = status_line[:-2].decode().split(None, 2) + version, status, reason = status_line[:-2].decode().split(" ", 2) if version != 'HTTP/1.1': raise ValueError("Unsupported HTTP version") return int(status), headers From 2466e74b17494191dc6d390f2b87dcdb46072f52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Cardona?= Date: Fri, 9 Dec 2016 18:14:50 +0100 Subject: [PATCH 0198/1539] server: don't print exceptions for network errors during opening and closing handshakes Since 181acf3, full tracebacks are printed if a connection is reset during handshake. Connection resets are to be expected, even in controlled environments. Simply log them as info rather than warning, and don't print a full traceback. Closes #126. --- websockets/server.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/websockets/server.py b/websockets/server.py index c9dc5b47e..dbc3eece8 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -62,6 +62,9 @@ def handler(self): path = yield from self.handshake( origins=self.origins, subprotocols=self.subprotocols, extra_headers=self.extra_headers) + except ConnectionError as exc: + logger.info('Connection error during opening handshake', exc_info=True) + raise except Exception as exc: if self._is_server_shutting_down(exc): response = ('HTTP/1.1 503 Service Unavailable\r\n\r\n' @@ -89,6 +92,11 @@ def handler(self): try: yield from self.close() + except ConnectionError as exc: + if self._is_server_shutting_down(exc): + pass + logger.info('Connection error in closing handshake', exc_info=True) + raise except Exception as exc: if self._is_server_shutting_down(exc): pass From 198b71537917adb44002573b14cbe23dbd4c21a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Cardona?= Date: Wed, 29 Mar 2017 00:17:49 +0200 Subject: [PATCH 0199/1539] protocol: add a lock around StreamWriter.drain(), closes #16 Works around the following error: File /usr/lib/python3.4/asyncio/streams.py, line 194, in _drain_helper assert waiter is None or waiter.cancelled() AssertionError when the write buffer reaches the high watermark (and thus blocks) and 2+ tasks try to call drain(). Clearly, asyncio's current code (all versions up to 3.6.1 included) is not "thread" safe. So use a lock around the only place drain() is called. --- websockets/protocol.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index bfc9cd775..979ad4b12 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -124,6 +124,7 @@ def __init__(self, *, self.reader = None self.writer = None + self._drain_lock = asyncio.Lock(loop=loop) self.request_headers = None self.raw_request_headers = None @@ -562,8 +563,12 @@ def write_frame(self, opcode, data=b''): yield try: - # Handle flow control automatically. - yield from self.writer.drain() + # drain() cannot be called concurrently by multiple coroutines: + # http://bugs.python.org/issue29930. Remove this lock when no + # version of Python where this bugs exists is supported anymore. + with (yield from self._drain_lock): + # Handle flow control automatically. + yield from self.writer.drain() except ConnectionError: # Terminate the connection if the socket died. yield from self.fail_connection(1006) From c46ab352f6046522b4ac096418367bbd12d76ae6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2017 15:47:49 +0200 Subject: [PATCH 0200/1539] Add changelog and bump version number. --- docs/changelog.rst | 9 ++++++++- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index bdb9dae81..dc0ae5f04 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,11 +1,18 @@ Changelog --------- -3.3 +3.4 ... *In development* +3.3 +... + +* Reduced noise in logs caused by connection resets. + +* Avoided crashing on concurrent writes on slow connections. + 3.2 ... diff --git a/docs/conf.py b/docs/conf.py index aabe2a92e..258b57ede 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '3.2' +version = '3.3' # The full version, including alpha/beta/rc tags. -release = '3.2' +release = '3.3' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 7970055a9..680144bc4 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '3.2' +version = '3.3' From cdfa863472b64b2eb3b4deac1fcac933aebd70cb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2017 15:48:43 +0200 Subject: [PATCH 0201/1539] Add 3.6 tag to wheel. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 0770d01f5..e4ece51a8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py33.py34.py35 +python-tag = py33.py34.py35.py36 [flake8] ignore = E731,F403,F405 From e7e681fc1cae2bb13e09418c9c5ab6624bf8160a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 May 2017 21:29:08 +0200 Subject: [PATCH 0202/1539] Specify charset for open() calls. Otherwise they're locale dependent, which can cause issues. Fix #171. --- setup.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 57ced1764..a2fe50a81 100644 --- a/setup.py +++ b/setup.py @@ -7,10 +7,12 @@ description = "An implementation of the WebSocket Protocol (RFC 6455)" -with open(os.path.join(root_dir, 'README.rst')) as f: +readme_file = os.path.join(root_dir, 'README.rst') +with open(readme_file, encoding='utf-8') as f: long_description = f.read() -with open(os.path.join(root_dir, 'websockets', 'version.py')) as f: +version_module = os.path.join(root_dir, 'websockets', 'version.py') +with open(version_module, encoding='utf-8') as f: exec(f.read()) py_version = sys.version_info[:2] From 4a570396c7e859340ca8539ab883aa368f8b5415 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 May 2017 21:44:12 +0200 Subject: [PATCH 0203/1539] Improve wording a bit. --- docs/intro.rst | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index fcfbeb924..38df55ce7 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -59,7 +59,7 @@ Consumer For receiving messages and passing them to a ``consumer`` coroutine:: - async def handler(websocket, path): + async def consumer_handler(websocket, path): while True: message = await websocket.recv() await consumer(message) @@ -73,7 +73,7 @@ Producer For getting messages from a ``producer`` coroutine and sending them:: - async def handler(websocket, path): + async def producer_handler(websocket, path): while True: message = await producer() await websocket.send(message) @@ -85,20 +85,8 @@ disconnects, which breaks out of the ``while True`` loop. Both .... -Of course, you can combine the two patterns shown above to read and write -messages on the same connection. - -:: - - async def consumer_handler(websocket): - while True: - message = await websocket.recv() - await consumer(message) - - async def producer_handler(websocket): - while True: - message = await producer() - await websocket.send(message) +You can read and write messages on the same connection by combining the two +patterns shown above and running the two tasks in parallel:: async def handler(websocket, path): consumer_task = asyncio.ensure_future(consumer_handler(websocket)) From a1dce912dfa0841d18d45e93f4f308e31cb1e49b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 May 2017 21:53:11 +0200 Subject: [PATCH 0204/1539] Support ssl=None when connecting to ws:// URIs. Fix #149. --- websockets/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/client.py b/websockets/client.py index acebce332..190afbc94 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -147,7 +147,7 @@ def connect(uri, *, wsuri = parse_uri(uri) if wsuri.secure: kwds.setdefault('ssl', True) - elif 'ssl' in kwds: + elif kwds.get('ssl') is not None: raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") factory = lambda: klass( From 388bd4a1e621370571059d4dd4bd584db7debbbd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 May 2017 21:56:38 +0200 Subject: [PATCH 0205/1539] Remove Travis CI config. Travis doesn't work. --- .travis.yml | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 250ef4b9f..000000000 --- a/.travis.yml +++ /dev/null @@ -1,6 +0,0 @@ -language: python -python: "3.5" -install: pip install tox -script: tox -notifications: - email: false From 438310501703ddb580d4bbc7628993190faf1152 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 May 2017 22:04:03 +0200 Subject: [PATCH 0206/1539] Document the type of the stream argument. FIx #174. --- websockets/http.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/websockets/http.py b/websockets/http.py index 81f22a824..173a35691 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -31,6 +31,8 @@ def read_request(stream): """ Read an HTTP/1.1 request from ``stream``. + ``stream`` is an :class:`~asyncio.StreamReader`. + Return ``(path, headers)`` where ``path`` is a :class:`str` and ``headers`` is a :class:`~email.message.Message`. ``path`` isn't URL-decoded. @@ -54,6 +56,8 @@ def read_response(stream): """ Read an HTTP/1.1 response from ``stream``. + ``stream`` is an :class:`~asyncio.StreamReader`. + Return ``(status, headers)`` where ``status`` is a :class:`int` and ``headers`` is a :class:`~email.message.Message`. @@ -74,6 +78,8 @@ def read_message(stream): """ Read an HTTP message from ``stream``. + ``stream`` is an :class:`~asyncio.StreamReader`. + Return ``(start_line, headers)`` where ``start_line`` is :class:`bytes` and ``headers`` is a :class:`~email.message.Message`. @@ -99,6 +105,8 @@ def read_line(stream): """ Read a single line from ``stream``. + ``stream`` is an :class:`~asyncio.StreamReader`. + """ line = yield from stream.readline() if len(line) > MAX_LINE: From 5b48c3c9a58c6d5c61b928c61915e4ee2c42aa41 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 4 May 2017 09:19:37 +0200 Subject: [PATCH 0207/1539] Fix flake8. --- compliance/test_client.py | 5 ++++- compliance/test_server.py | 7 +++++-- websockets/server.py | 6 ++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/compliance/test_client.py b/compliance/test_client.py index 48898302c..6cf7d1b76 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -7,7 +7,10 @@ logging.basicConfig(level=logging.WARNING) -#logging.getLogger('websockets').setLevel(logging.DEBUG) + +# Uncomment this line to make only websockets more verbose. +# logging.getLogger('websockets').setLevel(logging.DEBUG) + SERVER = 'ws://127.0.0.1:8642' AGENT = 'websockets' diff --git a/compliance/test_server.py b/compliance/test_server.py index 7c29f9595..75a2d9d00 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -5,7 +5,9 @@ logging.basicConfig(level=logging.WARNING) -#logging.getLogger('websockets').setLevel(logging.DEBUG) + +# Uncomment this line to make only websockets more verbose. +# logging.getLogger('websockets').setLevel(logging.DEBUG) class EchoServerProtocol(websockets.WebSocketServerProtocol): @@ -30,7 +32,8 @@ def noop(ws, path): yield from ws.worker_task -start_server = websockets.serve(noop, '127.0.0.1', 8642, klass=EchoServerProtocol) +start_server = websockets.serve( + noop, '127.0.0.1', 8642, klass=EchoServerProtocol) try: asyncio.get_event_loop().run_until_complete(start_server) diff --git a/websockets/server.py b/websockets/server.py index dbc3eece8..fe2adb77b 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -63,7 +63,8 @@ def handler(self): origins=self.origins, subprotocols=self.subprotocols, extra_headers=self.extra_headers) except ConnectionError as exc: - logger.info('Connection error during opening handshake', exc_info=True) + logger.info( + "Connection error during opening handshake", exc_info=True) raise except Exception as exc: if self._is_server_shutting_down(exc): @@ -95,7 +96,8 @@ def handler(self): except ConnectionError as exc: if self._is_server_shutting_down(exc): pass - logger.info('Connection error in closing handshake', exc_info=True) + logger.info( + "Connection error in closing handshake", exc_info=True) raise except Exception as exc: if self._is_server_shutting_down(exc): From 65179090062eac8c555bf0f3f85bfbee4c957baa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 4 May 2017 09:22:13 +0200 Subject: [PATCH 0208/1539] Configure Circle CI. --- circle.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 circle.yml diff --git a/circle.yml b/circle.yml new file mode 100644 index 000000000..13c54c381 --- /dev/null +++ b/circle.yml @@ -0,0 +1,13 @@ +machine: + post: + - pyenv global 3.3.6 3.4.4 3.5.2 3.6.1 + python: + version: 3.6.1 + +dependencies: + pre: + - pip install tox + +test: + override: + - tox From b763ac48b9dc3cdd446eaff14ed5bf62a2290ed5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 4 May 2017 09:39:08 +0200 Subject: [PATCH 0209/1539] Fix deprecation warning for Python 3.7. --- websockets/test_protocol.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index b4ad4d2c5..c08c57b47 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -98,7 +98,7 @@ def delayed_drain(): remote_close = Frame(True, OP_CLOSE, serialize_close(1000, 'remote')) @property - def async(self): + def ensure_future(self): return functools.partial(asyncio_ensure_future, loop=self.loop) def receive_frame(self, frame): @@ -164,7 +164,7 @@ def close_connection_partial(self, code=1000, reason='close'): """ close_frame_data = serialize_close(code, reason) # Trigger the closing handshake from the local side. - self.async(self.protocol.close(code, reason)) + self.ensure_future(self.protocol.close(code, reason)) self.run_loop_once() # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) @@ -349,7 +349,7 @@ def read_message(): self.assertConnectionClosed(1011, '') def test_recv_cancelled(self): - recv = self.async(self.protocol.recv()) + recv = self.ensure_future(self.protocol.recv()) self.loop.call_soon(recv.cancel) with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(recv) @@ -698,7 +698,8 @@ def test_remote_close_race_with_failing_connection(self): # Fail the connection while answering a close frame from the client. self.loop.call_soon(self.receive_frame, self.remote_close) - self.loop.call_later(MS, self.async, self.protocol.fail_connection()) + self.loop.call_later( + MS, self.ensure_future, self.protocol.fail_connection()) # The client expects the server to close the connection. # Simulate it instead of waiting for the connection timeout. self.loop.call_later(MS, self.receive_eof_if_client) @@ -711,7 +712,7 @@ def test_remote_close_race_with_failing_connection(self): self.assertOneFrameSent(*self.remote_close) def test_local_close_during_recv(self): - recv = self.async(self.protocol.recv()) + recv = self.ensure_future(self.protocol.recv()) self.receive_frame(self.close_frame) self.receive_eof_if_client() @@ -728,7 +729,7 @@ def test_local_close_during_recv(self): def test_remote_close_during_send(self): self.make_drain_slow() - send = self.async(self.protocol.send('hello')) + send = self.ensure_future(self.protocol.send('hello')) self.receive_frame(self.close_frame) self.receive_eof() From 5e2adb251ddcbd7856de1a72dd1dfa893c20ebd5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 4 May 2017 10:58:18 +0200 Subject: [PATCH 0210/1539] Downgrade connection errors in handshakes to debug. They're very common and either obvious (connection fails to establish) or harmless (the connection was closing anyway). --- websockets/server.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index fe2adb77b..c6ff259af 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -63,8 +63,8 @@ def handler(self): origins=self.origins, subprotocols=self.subprotocols, extra_headers=self.extra_headers) except ConnectionError as exc: - logger.info( - "Connection error during opening handshake", exc_info=True) + logger.debug( + "Connection error in opening handshake", exc_info=True) raise except Exception as exc: if self._is_server_shutting_down(exc): @@ -94,15 +94,12 @@ def handler(self): try: yield from self.close() except ConnectionError as exc: - if self._is_server_shutting_down(exc): - pass - logger.info( - "Connection error in closing handshake", exc_info=True) + if not self._is_server_shutting_down(exc): + logger.debug( + "Connection error in closing handshake", exc_info=True) raise except Exception as exc: - if self._is_server_shutting_down(exc): - pass - else: + if not self._is_server_shutting_down(exc): logger.warning("Error in closing handshake", exc_info=True) raise From 93c6f28ac0597d09f18c25d4075f22dc0528e9b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 4 May 2017 22:23:07 +0200 Subject: [PATCH 0211/1539] Get back to 100% branch coverage. Remove a branch that doesn't make sense in the process. --- websockets/server.py | 5 ++--- websockets/test_client_server.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index c6ff259af..e38ba6326 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -94,9 +94,8 @@ def handler(self): try: yield from self.close() except ConnectionError as exc: - if not self._is_server_shutting_down(exc): - logger.debug( - "Connection error in closing handshake", exc_info=True) + logger.debug( + "Connection error in closing handshake", exc_info=True) raise except Exception as exc: if not self._is_server_shutting_down(exc): diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index d63415645..e6346dd2e 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -367,6 +367,34 @@ def test_server_shuts_down_during_connection_handling(self): # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) + @unittest.mock.patch('websockets.server.read_request') + def test_connection_error_during_opening_handshake(self, _read_request): + _read_request.side_effect = ConnectionError + + self.start_server() + with self.assertRaises(InvalidHandshake) as raised: + self.start_client() + self.stop_server() + + # Opening handshake doesn't complete -- since we faked a connection + # error, the server doesn't send a response to the client. + self.assertEqual(str(raised.exception), "Malformed HTTP message") + + @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') + def test_connection_error_during_closing_handshake(self, close): + close.side_effect = ConnectionError + + self.start_server() + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() + self.stop_server() + + # Connection ends with an abnormal closure. + self.assertEqual(self.client.close_code, 1006) + @unittest.skipUnless(os.path.exists(testcert), "test certificate is missing") class SSLClientServerTests(ClientServerTests): From ec8a86038f4b8b157c56c9a0e689bd4166a27246 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jan 2017 12:41:14 +0100 Subject: [PATCH 0212/1539] Refactor server-side handshake handling. The primary goal is to make it easier to add authentication through subclassing. (Subclassing as an API is a mistake, but that's hard to change at this point.) --- docs/api.rst | 1 + websockets/client.py | 4 +- websockets/exceptions.py | 12 ++- websockets/server.py | 164 ++++++++++++++++++++++++++++----------- 4 files changed, 132 insertions(+), 49 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 779ab5bf6..d65418d15 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -38,6 +38,7 @@ Server .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) .. automethod:: select_subprotocol(client_protos, server_protos) + .. automethod:: get_response_status() Client ...... diff --git a/websockets/client.py b/websockets/client.py index 190afbc94..2279f271b 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -7,7 +7,7 @@ import collections.abc import email.message -from .exceptions import InvalidHandshake +from .exceptions import InvalidHandshake, InvalidMessage from .handshake import build_request, check_response from .http import USER_AGENT, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -78,7 +78,7 @@ def handshake(self, wsuri, try: status_code, headers = yield from read_response(self.reader) except ValueError as exc: - raise InvalidHandshake("Malformed HTTP message") from exc + raise InvalidMessage("Malformed HTTP message") from exc if status_code != 101: raise InvalidHandshake("Bad status code: {}".format(status_code)) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 8824df03c..3d3ad46f1 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,6 +1,7 @@ __all__ = [ - 'InvalidHandshake', 'InvalidOrigin', 'InvalidState', 'InvalidURI', - 'ConnectionClosed', 'PayloadTooBig', 'WebSocketProtocolError', + 'InvalidHandshake', 'InvalidMessage', 'InvalidOrigin', 'InvalidState', + 'InvalidURI', 'ConnectionClosed', 'PayloadTooBig', + 'WebSocketProtocolError', ] @@ -11,6 +12,13 @@ class InvalidHandshake(Exception): """ +class InvalidMessage(InvalidHandshake): + """ + Exception raised when the HTTP message in a handshake request is malformed. + + """ + + class InvalidOrigin(InvalidHandshake): """ Exception raised when the origin in a handshake request is forbidden. diff --git a/websockets/server.py b/websockets/server.py index e38ba6326..5ba7d6ac9 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -1,3 +1,4 @@ + """ The :mod:`websockets.server` module defines a simple WebSocket server API. @@ -6,10 +7,11 @@ import asyncio import collections.abc import email.message +import http import logging from .compatibility import asyncio_ensure_future -from .exceptions import InvalidHandshake, InvalidOrigin +from .exceptions import InvalidHandshake, InvalidMessage, InvalidOrigin from .handshake import build_response, check_request from .http import USER_AGENT, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -19,6 +21,13 @@ logger = logging.getLogger(__name__) +try: + SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS +except AttributeError: # pragma: no cover + class SWITCHING_PROTOCOLS: + value = 101 + phrase = 'Switching protocols' + class WebSocketServerProtocol(WebSocketCommonProtocol): """ @@ -127,51 +136,139 @@ def _is_server_shutting_down(self, exc): ) @asyncio.coroutine - def handshake(self, origins=None, subprotocols=None, extra_headers=None): + def read_request_headers(self): """ - Perform the server side of the opening handshake. - - If provided, ``origins`` is a list of acceptable HTTP Origin values. - Include ``''`` if the lack of an origin is acceptable. + Read headers from the HTTP request. - If provided, ``subprotocols`` is a list of supported subprotocols in - order of decreasing preference. - - If provided, ``extra_headers`` sets additional HTTP response headers. - It can be a mapping or an iterable of (name, value) pairs. It can also - be a callable taking the request path and headers in arguments. - - Return the URI of the request. + Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message + is malformed or isn't a HTTP/1.1 GET request. """ - # Read handshake request. try: path, headers = yield from read_request(self.reader) except ValueError as exc: - raise InvalidHandshake("Malformed HTTP message") from exc + raise InvalidMessage("Malformed HTTP message") from exc self.request_headers = headers self.raw_request_headers = list(headers.raw_items()) - get_header = lambda k: headers.get(k, '') - key = check_request(get_header) + return path, headers + @asyncio.coroutine + def write_response_headers(self, status, headers): + """ + Write headers to the HTTP response. + + """ + self.response_headers = email.message.Message() + for name, value in headers: + self.response_headers[name] = value + self.raw_response_headers = headers + + # Since the status line and headers only contain ASCII characters, + # we can keep this simple. + response = [ + 'HTTP/1.1 {value} {phrase}'.format( + value=status.value, phrase=status.phrase)] + response.extend('{}: {}'.format(k, v) for k, v in headers) + response.append('\r\n') + response = '\r\n'.join(response).encode() + + self.writer.write(response) + + def process_origin(self, get_header, origins=None): + """ + Handle the Origin HTTP header when ``origins`` is provided. + + Raise :exc:`~websockets.exceptions.InvalidOrigin` if the origin isn't + acceptable. + + """ if origins is not None: origin = get_header('Origin') - if not set(origin.split() or ['']) <= set(origins): + if origin not in origins: raise InvalidOrigin("Origin not allowed: {}".format(origin)) + return origin + + def process_subprotocol(self, get_header, subprotocols=None): + """ + Handle the Sec-WebSocket-Protocol HTTP header when ``subprotocols`` is provided. + """ if subprotocols is not None: protocol = get_header('Sec-WebSocket-Protocol') if protocol: client_subprotocols = [p.strip() for p in protocol.split(',')] - self.subprotocol = self.select_subprotocol( + return self.select_subprotocol( client_subprotocols, subprotocols) + @staticmethod + def select_subprotocol(client_protos, server_protos): + """ + Pick a subprotocol among those offered by the client. + + """ + common_protos = set(client_protos) & set(server_protos) + if not common_protos: + return None + priority = lambda p: client_protos.index(p) + server_protos.index(p) + return sorted(common_protos, key=priority)[0] + + @asyncio.coroutine + def get_response_status(self): + """ + Return a :class:`~http.HTTPStatus` for the HTTP response. + + (:class:`~http.HTTPStatus` was added in Python 3.5. On earlier + versions, a compatible object must be returned. Check the definition + of ``SWITCHING_PROTOCOLS`` for an example.) + + This method may be overridden to check the request headers and set a + different status, for example to authenticate the request and return + ``HTTPStatus.UNAUTHORIZED`` or ``HTTPStatus.FORBIDDEN``. + + It is declared as a coroutine because such authentication checks are + likely to require network requests. + + """ + return SWITCHING_PROTOCOLS + + @asyncio.coroutine + def handshake(self, origins=None, subprotocols=None, extra_headers=None): + """ + Perform the server side of the opening handshake. + + If provided, ``origins`` is a list of acceptable HTTP Origin values. + Include ``''`` if the lack of an origin is acceptable. + + If provided, ``subprotocols`` is a list of supported subprotocols in + order of decreasing preference. + + If provided, ``extra_headers`` sets additional HTTP response headers. + It can be a mapping or an iterable of (name, value) pairs. It can also + be a callable taking the request path and headers in arguments. + + Raise :exc:`~websockets.exceptions.InvalidHandshake` or a subclass if + the handshake fails. + + Return the URI of the request. + + """ + path, headers = yield from self.read_request_headers() + get_header = lambda k: headers.get(k, '') + + key = check_request(get_header) + + self.origin = self.process_origin(get_header, origins) + self.subprotocol = self.process_subprotocol(get_header, subprotocols) + headers = [] set_header = lambda k, v: headers.append((k, v)) + + status = yield from self.get_response_status() + set_header('Server', USER_AGENT) - if self.subprotocol: + if status.value == 101 and self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: if callable(extra_headers): @@ -182,18 +279,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): set_header(name, value) build_response(set_header, key) - self.response_headers = email.message.Message() - for name, value in headers: - self.response_headers[name] = value - self.raw_response_headers = headers - - # Send handshake response. Since the status line and headers only - # contain ASCII characters, we can keep this simple. - response = ['HTTP/1.1 101 Switching Protocols'] - response.extend('{}: {}'.format(k, v) for k, v in headers) - response.append('\r\n') - response = '\r\n'.join(response).encode() - self.writer.write(response) + yield from self.write_response_headers(status, headers) assert self.state == CONNECTING self.state = OPEN @@ -201,18 +287,6 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): return path - @staticmethod - def select_subprotocol(client_protos, server_protos): - """ - Pick a subprotocol among those offered by the client. - - """ - common_protos = set(client_protos) & set(server_protos) - if not common_protos: - return None - priority = lambda p: client_protos.index(p) + server_protos.index(p) - return sorted(common_protos, key=priority)[0] - class WebSocketServer(asyncio.AbstractServer): """ From 3051b2ea3261fea2168675168660caa91947f28e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jan 2017 13:16:59 +0100 Subject: [PATCH 0213/1539] Refactor client-side handshake handling. --- websockets/client.py | 82 ++++++++++++++++++++++++++++++-------------- websockets/server.py | 13 +++---- 2 files changed, 64 insertions(+), 31 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 2279f271b..29c27f6bb 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -28,6 +28,57 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): is_client = True state = CONNECTING + @asyncio.coroutine + def write_request_headers(self, path, headers): + """ + Write headers to the HTTP request. + + """ + self.request_headers = email.message.Message() + for name, value in headers: + self.request_headers[name] = value + self.raw_request_headers = headers + + # Since the path and headers only contain ASCII characters, + # we can keep this simple. + request = ['GET {path} HTTP/1.1'.format(path=path)] + request.extend('{}: {}'.format(k, v) for k, v in headers) + request.append('\r\n') + request = '\r\n'.join(request).encode() + + self.writer.write(request) + + @asyncio.coroutine + def read_response_headers(self): + """ + Read headers from the HTTP response. + + Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message + is malformed or isn't a HTTP/1.1 GET request. + + """ + try: + status_code, headers = yield from read_response(self.reader) + except ValueError as exc: + raise InvalidMessage("Malformed HTTP message") from exc + + self.response_headers = headers + self.raw_response_headers = list(headers.raw_items()) + + return status_code, headers + + def process_subprotocol(self, get_header, subprotocols=None): + """ + Handle the Sec-WebSocket-Protocol HTTP header. + + """ + subprotocol = get_header('Sec-WebSocket-Protocol') + if subprotocol: + if subprotocols is None or subprotocol not in subprotocols: + raise InvalidHandshake( + "Unknown subprotocol: {}".format(subprotocol)) + return subprotocol + @asyncio.coroutine def handshake(self, wsuri, origin=None, subprotocols=None, extra_headers=None): @@ -45,6 +96,7 @@ def handshake(self, wsuri, """ headers = [] set_header = lambda k, v: headers.append((k, v)) + if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover set_header('Host', wsuri.host) else: @@ -59,40 +111,20 @@ def handshake(self, wsuri, for name, value in extra_headers: set_header(name, value) set_header('User-Agent', USER_AGENT) + key = build_request(set_header) - self.request_headers = email.message.Message() - for name, value in headers: - self.request_headers[name] = value - self.raw_request_headers = headers + yield from self.write_request_headers(wsuri.resource_name, headers) - # Send handshake request. Since the URI and the headers only contain - # ASCII characters, we can keep this simple. - request = ['GET %s HTTP/1.1' % wsuri.resource_name] - request.extend('{}: {}'.format(k, v) for k, v in headers) - request.append('\r\n') - request = '\r\n'.join(request).encode() - self.writer.write(request) + status_code, headers = yield from self.read_response_headers() + get_header = lambda k: headers.get(k, '') - # Read handshake response. - try: - status_code, headers = yield from read_response(self.reader) - except ValueError as exc: - raise InvalidMessage("Malformed HTTP message") from exc if status_code != 101: raise InvalidHandshake("Bad status code: {}".format(status_code)) - self.response_headers = headers - self.raw_response_headers = list(headers.raw_items()) - - get_header = lambda k: headers.get(k, '') check_response(get_header, key) - self.subprotocol = headers.get('Sec-WebSocket-Protocol', None) - if (self.subprotocol is not None and - self.subprotocol not in subprotocols): - raise InvalidHandshake( - "Unknown subprotocol: {}".format(self.subprotocol)) + self.subprotocol = self.process_subprotocol(get_header, subprotocols) assert self.state == CONNECTING self.state = OPEN diff --git a/websockets/server.py b/websockets/server.py index 5ba7d6ac9..fbd733dce 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -178,7 +178,7 @@ def write_response_headers(self, status, headers): def process_origin(self, get_header, origins=None): """ - Handle the Origin HTTP header when ``origins`` is provided. + Handle the Origin HTTP header. Raise :exc:`~websockets.exceptions.InvalidOrigin` if the origin isn't acceptable. @@ -192,15 +192,16 @@ def process_origin(self, get_header, origins=None): def process_subprotocol(self, get_header, subprotocols=None): """ - Handle the Sec-WebSocket-Protocol HTTP header when ``subprotocols`` is provided. + Handle the Sec-WebSocket-Protocol HTTP header. """ if subprotocols is not None: - protocol = get_header('Sec-WebSocket-Protocol') - if protocol: - client_subprotocols = [p.strip() for p in protocol.split(',')] + subprotocol = get_header('Sec-WebSocket-Protocol') + if subprotocol: return self.select_subprotocol( - client_subprotocols, subprotocols) + [p.strip() for p in subprotocol.split(',')], + subprotocols, + ) @staticmethod def select_subprotocol(client_protos, server_protos): From e58c1ac1315f4dd73aea99eaca2a60509f2fcf03 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 May 2017 09:41:19 +0200 Subject: [PATCH 0214/1539] Simplify implementing auth during the handshake. --- docs/changelog.rst | 3 +++ websockets/client.py | 13 ++++++----- websockets/protocol.py | 6 +++-- websockets/server.py | 40 ++++++++++++++++++++++++-------- websockets/test_client_server.py | 34 +++++++++++++++++++++++++++ 5 files changed, 78 insertions(+), 18 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index dc0ae5f04..8f6567532 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,9 @@ Changelog *In development* +* Added support rejecting incoming connections by customizing + :meth:`~websockets.server.WebSocketServerProtocol.get_response_status()`. + 3.3 ... diff --git a/websockets/client.py b/websockets/client.py index 29c27f6bb..ad15ba166 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -29,11 +29,12 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): state = CONNECTING @asyncio.coroutine - def write_request_headers(self, path, headers): + def write_http_request(self, path, headers): """ - Write headers to the HTTP request. + Write status line and headers to the HTTP request. """ + self.path = path self.request_headers = email.message.Message() for name, value in headers: self.request_headers[name] = value @@ -49,9 +50,9 @@ def write_request_headers(self, path, headers): self.writer.write(request) @asyncio.coroutine - def read_response_headers(self): + def read_http_response(self): """ - Read headers from the HTTP response. + Read status line and headers from the HTTP response. Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message is malformed or isn't a HTTP/1.1 GET request. @@ -114,9 +115,9 @@ def handshake(self, wsuri, key = build_request(set_header) - yield from self.write_request_headers(wsuri.resource_name, headers) + yield from self.write_http_request(wsuri.resource_name, headers) - status_code, headers = yield from self.read_response_headers() + status_code, headers = yield from self.read_http_response() get_header = lambda k: headers.get(k, '') if status_code != 101: diff --git a/websockets/protocol.py b/websockets/protocol.py index 979ad4b12..1b8287f67 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -78,8 +78,9 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): this is 128MB. You may want to lower the limits, depending on your application's requirements. - Once the handshake is complete, request and response HTTP headers are - available: + As soon as the HTTP request and response in the opening handshake are + processed, the request path is available in the :attr:`path` attribute, + and the request and response HTTP headers are available: * as a MIME :class:`~email.message.Message` in the :attr:`request_headers` and :attr:`response_headers` attributes @@ -126,6 +127,7 @@ def __init__(self, *, self.writer = None self._drain_lock = asyncio.Lock(loop=loop) + self.path = None self.request_headers = None self.raw_request_headers = None self.response_headers = None diff --git a/websockets/server.py b/websockets/server.py index fbd733dce..cd8503ce8 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -90,6 +90,11 @@ def handler(self): self.writer.write(response.encode()) raise + # Subclasses can customize get_response_status() or handshake() to + # reject the handshake, typically after checking authentication. + if path is None: + return + try: yield from self.ws_handler(self, path) except Exception as exc: @@ -136,9 +141,9 @@ def _is_server_shutting_down(self, exc): ) @asyncio.coroutine - def read_request_headers(self): + def read_http_request(self): """ - Read headers from the HTTP request. + Read status line and headers from the HTTP request. Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message is malformed or isn't a HTTP/1.1 GET request. @@ -149,15 +154,16 @@ def read_request_headers(self): except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc + self.path = path self.request_headers = headers self.raw_request_headers = list(headers.raw_items()) return path, headers @asyncio.coroutine - def write_response_headers(self, status, headers): + def write_http_response(self, status, headers): """ - Write headers to the HTTP response. + Write status line and headers to the HTTP response. """ self.response_headers = email.message.Message() @@ -216,7 +222,7 @@ def select_subprotocol(client_protos, server_protos): return sorted(common_protos, key=priority)[0] @asyncio.coroutine - def get_response_status(self): + def get_response_status(self, set_header): """ Return a :class:`~http.HTTPStatus` for the HTTP response. @@ -231,6 +237,11 @@ def get_response_status(self): It is declared as a coroutine because such authentication checks are likely to require network requests. + The connection is closed immediately after sending the response when + the status code is not ``HTTPStatus.SWITCHING_PROTOCOLS``. + + Call ``set_header(key, value)`` to set additional response headers. + """ return SWITCHING_PROTOCOLS @@ -255,7 +266,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): Return the URI of the request. """ - path, headers = yield from self.read_request_headers() + path, headers = yield from self.read_http_request() get_header = lambda k: headers.get(k, '') key = check_request(get_header) @@ -266,10 +277,19 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): headers = [] set_header = lambda k, v: headers.append((k, v)) - status = yield from self.get_response_status() - set_header('Server', USER_AGENT) - if status.value == 101 and self.subprotocol: + + status = yield from self.get_response_status(set_header) + + # Abort the connection if the status code isn't 101. + if status.value != SWITCHING_PROTOCOLS.value: + yield from self.write_http_response(status, headers) + self.opening_handshake.set_result(False) + yield from self.close_connection(force=True) + return + + # Status code is 101, establish the connection. + if self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: if callable(extra_headers): @@ -280,7 +300,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): set_header(name, value) build_response(set_header, key) - yield from self.write_response_headers(status, headers) + yield from self.write_http_response(status, headers) assert self.state == CONNECTING self.state = OPEN diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index e6346dd2e..60e7616c7 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1,4 +1,5 @@ import asyncio +import http import logging import os import ssl @@ -21,6 +22,8 @@ def handler(ws, path): if path == '/attributes': yield from ws.send(repr((ws.host, ws.port, ws.secure))) + elif path == '/path': + yield from ws.send(str(ws.path)) elif path == '/headers': yield from ws.send(str(ws.request_headers)) yield from ws.send(str(ws.response_headers)) @@ -33,6 +36,21 @@ def handler(ws, path): yield from ws.send((yield from ws.recv())) +try: + FORBIDDEN = http.HTTPStatus.FORBIDDEN +except AttributeError: # pragma: no cover + class FORBIDDEN: + value = 403 + phrase = 'Forbidden' + + +class ForbiddenWebSocketServerProtocol(WebSocketServerProtocol): + + @asyncio.coroutine + def get_response_status(self, set_header): + return FORBIDDEN + + class ClientServerTests(unittest.TestCase): secure = False @@ -107,6 +125,16 @@ def test_protocol_attributes(self): self.stop_client() self.stop_server() + def test_protocol_path(self): + self.start_server() + self.start_client('path') + client_path = self.client.path + self.assertEqual(client_path, '/path') + server_path = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_path, '/path') + self.stop_client() + self.stop_server() + def test_protocol_headers(self): self.start_server() self.start_client('headers') @@ -189,6 +217,12 @@ def test_protocol_custom_response_headers_list(self): self.stop_client() self.stop_server() + def test_authentication(self): + self.start_server(klass=ForbiddenWebSocketServerProtocol) + with self.assertRaises(InvalidHandshake): + self.start_client() + self.stop_server() + def test_no_subprotocol(self): self.start_server() self.start_client('subprotocol') From a41c8310807af4de9b2275c685988b3cbcf01ae2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 May 2017 11:57:00 +0200 Subject: [PATCH 0215/1539] Document how to shut down the server. Fix #124, #103. --- docs/deployment.rst | 34 ++++++++++++++++++++++++++++++++-- example/client.py | 2 +- example/shutdown.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 example/shutdown.py diff --git a/docs/deployment.rst b/docs/deployment.rst index e5d952538..fb52421c2 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -4,9 +4,39 @@ Deployment The author of ``websockets`` isn't aware of best practices for deploying network services based on :mod:`asyncio`. -He suggests running a Python script similar to the :ref:`server example -`, perhaps inside a supervisor if you deem it useful. +You can run a script similar to the :ref:`server example `, +inside a supervisor if you deem that useful. + +You can also add a wrapper to daemonize the process. Third-party libraries +provide solutions for that. If you can share knowledge on this topic, please file an issue_. Thanks! .. _issue: https://github.com/aaugustin/websockets/issues/new + +Graceful shutdown +----------------- + +You may want to close connections gracefully when shutting down the server, +perhaps after executing some cleanup logic. + +The proper way to do this is to call the ``close()`` method of the object +returned by :func:`~websockets.server.serve`, then wait for ``wait_closed()`` +to complete. + +Tasks that handle connections will be cancelled, in the sense that +:meth:`~websockets.protocol.WebSocketCommonProtocol.recv` raises +:exc:`~asyncio.CancelledError`. + +On Unix systems, shutdown is usually triggered by sending a signal. + +Here's a full example (Unix-only): + +.. literalinclude:: ../example/shutdown.py + + +It's more difficult to achieve the same effect on Windows. Some third-party +projects try to help with this problem. + +If your server doesn't run in the main thread, look at +:func:`~asyncio.AbstractEventLoop.call_soon_threadsafe`. diff --git a/example/client.py b/example/client.py index 702cbf18e..7c589e39b 100644 --- a/example/client.py +++ b/example/client.py @@ -5,7 +5,7 @@ async def hello(): async with websockets.connect('ws://localhost:8765') as websocket: - + while True: name = input("What's your name? ") await websocket.send(name) print("> {}".format(name)) diff --git a/example/shutdown.py b/example/shutdown.py new file mode 100644 index 000000000..8846d1add --- /dev/null +++ b/example/shutdown.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import asyncio +import signal +import websockets + +async def echo(websocket, path): + while True: + try: + msg = await websocket.recv() + except websockets.ConnectionClosed: + pass + else: + await websocket.send(msg) + +loop = asyncio.get_event_loop() + +# Create the server. +start_server = websockets.serve(echo, 'localhost', 8765) +server = loop.run_until_complete(start_server) + +# Run the server until SIGTERM. +stop_server = asyncio.Future() +loop.add_signal_handler(signal.SIGTERM, stop_server.set_result, None) +loop.run_until_complete(stop_server) + +# Shut down the server. +server.close() +loop.run_until_complete(server.wait_closed()) From 3ec6f896e50820ab0048533d89155aa48cca52e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 May 2017 12:40:15 +0200 Subject: [PATCH 0216/1539] Make it possible to configure buffer sizes. Document how backpressure and buffers work. Refs #170. --- docs/deployment.rst | 74 +++++++++++++++++++++++++++++++++++++++++- websockets/client.py | 8 +++-- websockets/protocol.py | 23 +++++++++++-- websockets/server.py | 8 +++-- 4 files changed, 103 insertions(+), 10 deletions(-) diff --git a/docs/deployment.rst b/docs/deployment.rst index fb52421c2..c6fd1bf19 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -1,3 +1,76 @@ +Deployment +========== + +Backpressure +------------ + +.. note:: + + This section discusses the concept of backpressure from the perspective of + a server but the concepts also apply to clients. The issue is symmetrical. + +With a naive implementation, if a server receives inputs faster than it can +process them, or if it generates outputs faster than it can send them, data +accumulates in buffers, eventually causing the server to run out of memory and +crash. + +The solution to this problem is backpressure. Any part of the server that +receives inputs faster than it can it can process them and send the outputs +must propagate that information back to the previous part in the chain. + +``websockets`` is designed to make it easy to get backpressure right. + +For incoming data, ``websockets`` builds upon :class:`~asyncio.StreamReader` +which propagates backpressure to its own buffer and to the TCP stream. Frames +are parsed from the input stream and added to a bounded queue. If the queue +fills up, parsing halts until some the application reads a frame. + +For outgoing data, ``websockets`` builds upon :class:`~asyncio.StreamWriter` +which implements flow control. If the output buffers grow too large, it waits +until they're drained. That's why all APIs that write frames are asynchronous +in websockets (since version 2.0). + +Of course, it's still possible for an application to create its own unbounded +buffers and break the backpressure. Be careful with queues. + +Buffers +------- + +An asynchronous systems works best when its buffers are almost always empty. + +For example, if a client sends frames too fast for a server, the queue of +incoming frames will be constantly full. The server will always be 32 frames +(by default) behind the client. This consumes memory and adds latency for no +good reason. + +If buffers are almost always full and that problem cannot be solved by adding +capacity (typically because the system is bottlenecked by the output and +constantly regulated by backpressure), reducing the size of buffers minimizes +negative consequences. + +By default ``websockets`` has rather high limits. You can decrease them +according to your application's characteristics. + +Bufferbloat can happen at every level in the stack where there is a buffer. +The receiving side contains these buffers: + +- OS buffers: you shouldn't need to tune them in general. +- :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64kB. + You can set another limit by passing a ``read_limit`` keyword argument to + :func:`~websockets.client.connect` or :func:`~websockets.server.serve`. +- ``websockets`` frame buffer: its size depends both on the size and the + number of frames it contains. By default the maximum size is 1MB and the + maximum number is 32. You can adjust these limits by setting the + ``max_size`` and ``max_queue`` keyword arguments of + :func:`~websockets.client.connect` or :func:`~websockets.server.serve`. + +The sending side contains these buffers: + +- :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64kB. + You can set another limit by passing a ``write_limit`` keyword argument to + :func:`~websockets.client.connect` or :func:`~websockets.server.serve`. +- OS buffers: you shouldn't need to tune them in general. + Deployment ---------- @@ -34,7 +107,6 @@ Here's a full example (Unix-only): .. literalinclude:: ../example/shutdown.py - It's more difficult to achieve the same effect on Windows. Some third-party projects try to help with this problem. diff --git a/websockets/client.py b/websockets/client.py index ad15ba166..2ab81d92b 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -136,6 +136,7 @@ def handshake(self, wsuri, def connect(uri, *, klass=WebSocketClientProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, origin=None, subprotocols=None, extra_headers=None, **kwds): @@ -154,9 +155,9 @@ def connect(uri, *, a ``wss://`` URI, if this argument isn't provided explicitly, it's set to ``True``, which means Python's default :class:`~ssl.SSLContext` is used. - The behavior of the ``timeout``, ``max_size``, and ``max_queue`` optional - arguments is described the documentation of - :class:`~websockets.protocol.WebSocketCommonProtocol`. + The behavior of the ``timeout``, ``max_size``, and ``max_queue``, + ``read_limit``, and ``write_limit`` optional arguments is described in the + documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. :func:`connect` also accepts the following optional arguments: @@ -186,6 +187,7 @@ def connect(uri, *, factory = lambda: klass( host=wsuri.host, port=wsuri.port, secure=wsuri.secure, timeout=timeout, max_size=max_size, max_queue=max_queue, + read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, ) diff --git a/websockets/protocol.py b/websockets/protocol.py index 1b8287f67..530f998a0 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -78,6 +78,16 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): this is 128MB. You may want to lower the limits, depending on your application's requirements. + The ``read_limit`` argument sets the high-water limit of the buffer for + incoming bytes. The low-water limit is half the high-water limit. The + default value is 64kB, half of asyncio's default (based on the current + implementation of :class:`~asyncio.StreamReader`). + + The ``write_limit`` argument sets the high-water limit of the buffer for + outgoing bytes. The low-water limit is a quarter of the high-water limit. + The default value is 64kB, equal to asyncio's default (based on the + current implementation of ``_FlowControlMixin``). + As soon as the HTTP request and response in the opening handshake are processed, the request path is available in the :attr:`path` attribute, and the request and response HTTP headers are available: @@ -105,22 +115,27 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): def __init__(self, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False): self.host = host self.port = port self.secure = secure - self.timeout = timeout self.max_size = max_size + self.max_queue = max_queue + self.read_limit = read_limit + self.write_limit = write_limit + # Store a reference to loop to avoid relying on self._loop, a private - # attribute of StreamReaderProtocol, inherited from FlowControlMixin. + # attribute of StreamReaderProtocol, inherited from _FlowControlMixin. if loop is None: loop = asyncio.get_event_loop() self.loop = loop self.legacy_recv = legacy_recv - stream_reader = asyncio.StreamReader(loop=loop) + # This limit is both the line length limit and half the buffer limit. + stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) super().__init__(stream_reader, self.client_connected, loop) self.reader = None @@ -636,6 +651,8 @@ def fail_connection(self, code=1011, reason=''): def client_connected(self, reader, writer): self.reader = reader self.writer = writer + # Configure write buffer limit. + self.writer._transport.set_write_buffer_limits(self.write_limit) # Start the task that handles incoming messages. self.worker_task = asyncio_ensure_future(self.run(), loop=self.loop) diff --git a/websockets/server.py b/websockets/server.py index cd8503ce8..fd535c2fe 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -388,6 +388,7 @@ def wait_closed(self): def serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, origins=None, subprotocols=None, extra_headers=None, **kwds): @@ -412,9 +413,9 @@ def serve(ws_handler, host=None, port=None, *, For example, you can set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. - The behavior of the ``timeout``, ``max_size``, and ``max_queue`` optional - arguments is described the documentation of - :class:`~websockets.protocol.WebSocketCommonProtocol`. + The behavior of the ``timeout``, ``max_size``, and ``max_queue``, + ``read_limit``, and ``write_limit`` optional arguments is described in the + documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. :func:`serve` also accepts the following optional arguments: @@ -451,6 +452,7 @@ def serve(ws_handler, host=None, port=None, *, ws_handler, ws_server, host=host, port=port, secure=secure, timeout=timeout, max_size=max_size, max_queue=max_queue, + read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, origins=origins, subprotocols=subprotocols, extra_headers=extra_headers, From d0d434a71f3cfa0270a8b0bd1f809b9807f26f5c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 May 2017 13:01:43 +0200 Subject: [PATCH 0217/1539] Document that port sharing is unsupported. Fix #116. --- docs/deployment.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/deployment.rst b/docs/deployment.rst index c6fd1bf19..8786ad3c1 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -112,3 +112,16 @@ projects try to help with this problem. If your server doesn't run in the main thread, look at :func:`~asyncio.AbstractEventLoop.call_soon_threadsafe`. + +Port sharing +------------ + +The WebSocket protocol is an extension of HTTP/1.1. It can be tempting to +serve both HTTP and WebSocket on the same port. + +The author of ``websockets`` doesn't think that's a good idea, due to the +widely different operational characteristics of HTTP and WebSocket. + +If you need to respond to requests with a protocol other than WebSocket, for +example TCP or HTTP health checks, run a server for that protocol on another +port, within the same Python process, with :func:`~asyncio.start_server`. From d8739496f120e576014539b00194c09b86059d11 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 May 2017 13:33:27 +0200 Subject: [PATCH 0218/1539] Update changelog. --- docs/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 8f6567532..07f28bc5d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -9,6 +9,8 @@ Changelog * Added support rejecting incoming connections by customizing :meth:`~websockets.server.WebSocketServerProtocol.get_response_status()`. +* Made read and write buffer sizes configurable. + 3.3 ... From 8e6baa8d118cb6f6cff683c135f7170aeb763dba Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 May 2017 13:56:54 +0200 Subject: [PATCH 0219/1539] Fix tests on Circle CI. --- websockets/test_client_server.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 60e7616c7..6ad6ec20a 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -406,14 +406,13 @@ def test_connection_error_during_opening_handshake(self, _read_request): _read_request.side_effect = ConnectionError self.start_server() - with self.assertRaises(InvalidHandshake) as raised: + # Exception appears to be platform-dependent: InvalidHandshake on + # macOS, ConnectionResetError on Linux. This doesn't matter; this + # test primarily aims at covering a code path on the server side. + with self.assertRaises(Exception): self.start_client() self.stop_server() - # Opening handshake doesn't complete -- since we faked a connection - # error, the server doesn't send a response to the client. - self.assertEqual(str(raised.exception), "Malformed HTTP message") - @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError From 13f2e408707627260f2eed954bdec735f701daa5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 May 2017 14:42:09 +0200 Subject: [PATCH 0220/1539] Document how to keep connections open. Fix #110. --- docs/cheatsheet.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index eb2e1dd8c..04349badc 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -70,3 +70,24 @@ in particular. Fortunately Python's official documentation provides advice to `develop with asyncio`_. Check it out: it's invaluable! .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html + +Keeping connections open +------------------------ + +Pinging the other side once in a while is a good way to check whether the +connection is still working, and also to keep it open in case something kills +idle connections after some time:: + + while True: + try: + msg = await asyncio.wait_for(ws.recv(), timeout=20) + except asyncio.TimeoutError: + # No data in 20 seconds, check the connection. + try: + await asyncio.wait_for(ws.ping(), timeout=10) + except asyncio.TimeoutError: + # No response to ping in 10 seconds, disconnect. + break + else: + # do something with msg + ... From 76f02cab95879648481f2dd84e4c3c6924d52c54 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 5 May 2017 13:31:35 +0200 Subject: [PATCH 0221/1539] Make serve an async context manager. Fix #86. --- docs/changelog.rst | 3 +++ docs/deployment.rst | 14 ++++++++++---- example/oldshutdown.py | 29 +++++++++++++++++++++++++++++ example/shutdown.py | 20 +++++++++----------- websockets/client.py | 2 +- websockets/py35/client_server.py | 16 +++++++++++++--- websockets/py35/server.py | 22 ++++++++++++++++++++++ websockets/server.py | 16 +++++++++++++++- 8 files changed, 102 insertions(+), 20 deletions(-) create mode 100644 example/oldshutdown.py create mode 100644 websockets/py35/server.py diff --git a/docs/changelog.rst b/docs/changelog.rst index 07f28bc5d..26093b198 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,9 @@ Changelog *In development* +* :func:`~websockets.server.serve` can be used as an asynchronous context + manager on Python ≥ 3.5. + * Added support rejecting incoming connections by customizing :meth:`~websockets.server.WebSocketServerProtocol.get_response_status()`. diff --git a/docs/deployment.rst b/docs/deployment.rst index 8786ad3c1..9ce0745f6 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -91,11 +91,12 @@ Graceful shutdown ----------------- You may want to close connections gracefully when shutting down the server, -perhaps after executing some cleanup logic. +perhaps after executing some cleanup logic. There are two ways to achieve this +with the object returned by :func:`~websockets.server.serve`: -The proper way to do this is to call the ``close()`` method of the object -returned by :func:`~websockets.server.serve`, then wait for ``wait_closed()`` -to complete. +- using it as a asynchronous context manager, or +- calling its ``close()`` method, then waiting for its ``wait_closed()`` + method to complete. Tasks that handle connections will be cancelled, in the sense that :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` raises @@ -107,6 +108,11 @@ Here's a full example (Unix-only): .. literalinclude:: ../example/shutdown.py +``async``, ``await``, and asynchronous context managers aren't available in +Python < 3.5. Here's the equivalent for older Python versions: + +.. literalinclude:: ../example/oldshutdown.py + It's more difficult to achieve the same effect on Windows. Some third-party projects try to help with this problem. diff --git a/example/oldshutdown.py b/example/oldshutdown.py new file mode 100644 index 000000000..b95fa91a3 --- /dev/null +++ b/example/oldshutdown.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import asyncio +import signal +import websockets + +async def echo(websocket, path): + while True: + try: + msg = await websocket.recv() + except websockets.ConnectionClosed: + pass + else: + await websocket.send(msg) + +loop = asyncio.get_event_loop() + +# Create the server. +start_server = websockets.serve(echo, 'localhost', 8765) +server = loop.run_until_complete(start_server) + +# Run the server until SIGTERM. +stop = asyncio.Future() +loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) +loop.run_until_complete(stop) + +# Shut down the server. +server.close() +loop.run_until_complete(server.wait_closed()) diff --git a/example/shutdown.py b/example/shutdown.py index 8846d1add..1f686d160 100644 --- a/example/shutdown.py +++ b/example/shutdown.py @@ -13,17 +13,15 @@ async def echo(websocket, path): else: await websocket.send(msg) -loop = asyncio.get_event_loop() +async def echo_server(stop): + async with websockets.serve(echo, 'localhost', 8765): + await stop -# Create the server. -start_server = websockets.serve(echo, 'localhost', 8765) -server = loop.run_until_complete(start_server) +loop = asyncio.get_event_loop() -# Run the server until SIGTERM. -stop_server = asyncio.Future() -loop.add_signal_handler(signal.SIGTERM, stop_server.set_result, None) -loop.run_until_complete(stop_server) +# The stop condition is set when receiving SIGTERM. +stop = asyncio.Future() +loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) -# Shut down the server. -server.close() -loop.run_until_complete(server.wait_closed()) +# Run the server until the stop condition is met. +loop.run_until_complete(echo_server(stop)) diff --git a/websockets/client.py b/websockets/client.py index 2ab81d92b..a90939e18 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -147,7 +147,7 @@ def connect(uri, *, send and receive messages. :func:`connect` is a wrapper around the event loop's - :meth:`~asyncio.BaseEventLoop.create_connection` method. Extra keyword + :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. For example, you can set the ``ssl`` keyword argument to a diff --git a/websockets/py35/client_server.py b/websockets/py35/client_server.py index bfafa39c3..624824aa1 100644 --- a/websockets/py35/client_server.py +++ b/websockets/py35/client_server.py @@ -17,17 +17,27 @@ def setUp(self): def tearDown(self): self.loop.close() - def test_basic(self): + def test_client(self): server = serve(handler, 'localhost', 8642) self.server = self.loop.run_until_complete(server) - async def basic(): + async def run_client(): async with connect('ws://localhost:8642/') as client: await client.send("Hello!") reply = await client.recv() self.assertEqual(reply, "Hello!") - self.loop.run_until_complete(basic()) + self.loop.run_until_complete(run_client()) self.server.close() self.loop.run_until_complete(self.server.wait_closed()) + + def test_server(self): + async def run_server(): + async with serve(handler, 'localhost', 8642): + client = await connect('ws://localhost:8642/') + await client.send("Hello!") + reply = await client.recv() + self.assertEqual(reply, "Hello!") + + self.loop.run_until_complete(run_server()) diff --git a/websockets/py35/server.py b/websockets/py35/server.py new file mode 100644 index 000000000..3aba1c84e --- /dev/null +++ b/websockets/py35/server.py @@ -0,0 +1,22 @@ +class Serve: + """ + This class wraps :func:`~websockets.server.serve` on Python ≥ 3.5. + + This allows using it as an asynchronous context manager. + + """ + def __init__(self, *args, **kwargs): + self.server = self.__class__.__wrapped__(*args, **kwargs) + + async def __aenter__(self): + self.server = await self + return self.server + + async def __aexit__(self, exc_type, exc_value, traceback): + self.server.close() + await self.server.wait_closed() + + def __await__(self): + return (yield from self.server) + + __iter__ = __await__ diff --git a/websockets/server.py b/websockets/server.py index fd535c2fe..119e251e5 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -407,7 +407,7 @@ def serve(ws_handler, host=None, port=None, *, :func:`serve` is a wrapper around the event loop's :meth:`~asyncio.BaseEventLoop.create_server` method. ``host``, ``port`` as - well as extra keyword arguments are passed to + well as unknown keyword arguments are passed to :meth:`~asyncio.BaseEventLoop.create_server`. For example, you can set the ``ssl`` keyword argument to a @@ -441,6 +441,9 @@ def serve(ws_handler, host=None, port=None, *, logger.setLevel(logging.ERROR) logger.addHandler(logging.StreamHandler()) + On Python 3.5, :func:`serve` can be used as a asynchronous context + manager. In that case, the server is shut down when exiting the context. + """ if loop is None: loop = asyncio.get_event_loop() @@ -462,3 +465,14 @@ def serve(ws_handler, host=None, port=None, *, ws_server.wrap(server) return ws_server + + +try: + from .py35.server import Serve +except (SyntaxError, ImportError): # pragma: no cover + pass +else: + Serve.__wrapped__ = serve + # Copy over docstring to support building documentation on Python 3.5. + Serve.__doc__ = serve.__doc__ + serve = Serve From 7dfccc446ec73eea28c5fee6ad2c76b080005fab Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 9 May 2017 22:46:23 +0200 Subject: [PATCH 0222/1539] Simplify compliance tests. --- compliance/README.rst | 14 ++++++++++---- compliance/test_client.py | 26 +++++++------------------- compliance/test_server.py | 28 ++++++++-------------------- 3 files changed, 25 insertions(+), 43 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index 277804ad7..cfaaafca7 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -29,8 +29,14 @@ supports Python 3; you need two different environments. Conformance notes ----------------- -Test cases 6.4.3 and 6.4.4 are actually more strict than the RFC. Given its -implementation, ``websockets`` gets a "Non-Strict". +Some test cases are more strict than the RFC. Given the implementation of the +library and the test echo client or server, ``websockets`` gets a "Non-Strict" +in these cases. -Test cases 12.* and 13.* don't run because ``websockets`` doesn't implement -compression at this time. +In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 ``websockets`` notices the +protocol error and closes the connection before it has had a chance to echo +the previous frame. + +In 6.4.3 and 6.4.4, even though it uses an incremental decoder, ``websockets`` +doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These +tests are more strict than the RFC. diff --git a/compliance/test_client.py b/compliance/test_client.py index 6cf7d1b76..382d06a05 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -16,23 +16,6 @@ AGENT = 'websockets' -class EchoClientProtocol(websockets.WebSocketClientProtocol): - """ - WebSocket client protocol that echoes messages synchronously. - - """ - def __init__(self, *args, **kwargs): - kwargs['max_size'] = 2 ** 25 - super().__init__(*args, **kwargs) - - @asyncio.coroutine - def read_message(self): - msg = yield from super().read_message() - if msg is not None: - yield from self.send(msg) - return msg - - @asyncio.coroutine def get_case_count(server): uri = server + '/getCaseCount' @@ -45,8 +28,13 @@ def get_case_count(server): @asyncio.coroutine def run_case(server, case, agent): uri = server + '/runCase?case={}&agent={}'.format(case, agent) - ws = yield from websockets.connect(uri, klass=EchoClientProtocol) - yield from ws.worker_task + ws = yield from websockets.connect(uri, max_size=2 ** 25, max_queue=1) + while True: + try: + msg = yield from ws.recv() + yield from ws.send(msg) + except websockets.ConnectionClosed: + break @asyncio.coroutine diff --git a/compliance/test_server.py b/compliance/test_server.py index 75a2d9d00..75e0e3044 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -10,30 +10,18 @@ # logging.getLogger('websockets').setLevel(logging.DEBUG) -class EchoServerProtocol(websockets.WebSocketServerProtocol): - """ - WebSocket server protocol that echoes messages synchronously. - - """ - def __init__(self, *args, **kwargs): - kwargs['max_size'] = 2 ** 25 - super().__init__(*args, **kwargs) - - @asyncio.coroutine - def read_message(self): - msg = yield from super().read_message() - if msg is not None: - yield from self.send(msg) - return msg - - @asyncio.coroutine -def noop(ws, path): - yield from ws.worker_task +def echo(ws, path): + while True: + try: + msg = yield from ws.recv() + yield from ws.send(msg) + except websockets.ConnectionClosed: + break start_server = websockets.serve( - noop, '127.0.0.1', 8642, klass=EchoServerProtocol) + echo, '127.0.0.1', 8642, max_size=2 ** 25, max_queue=1) try: asyncio.get_event_loop().run_until_complete(start_server) From 5d361a361a31ce19896dd0fb2c9dc684514e2a09 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 10 May 2017 23:04:03 +0200 Subject: [PATCH 0223/1539] Revert accidental change. Fix #183. --- example/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/example/client.py b/example/client.py index 7c589e39b..5a3a026b4 100644 --- a/example/client.py +++ b/example/client.py @@ -5,7 +5,6 @@ async def hello(): async with websockets.connect('ws://localhost:8765') as websocket: - while True: name = input("What's your name? ") await websocket.send(name) print("> {}".format(name)) From b1d09a1b1c36215a59f6a630215ed15d4582b3a8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Jul 2017 18:49:10 +0200 Subject: [PATCH 0224/1539] Fix Circle CI. --- circle.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/circle.yml b/circle.yml index 13c54c381..51fd7f60e 100644 --- a/circle.yml +++ b/circle.yml @@ -1,11 +1,11 @@ machine: post: - - pyenv global 3.3.6 3.4.4 3.5.2 3.6.1 + - pyenv global 3.6.1 3.5.3 3.4.4 3.3.6 python: version: 3.6.1 dependencies: - pre: + override: - pip install tox test: From f120792cbb52d55be1ad8388d24fbd087adfc111 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Jul 2017 18:27:45 +0200 Subject: [PATCH 0225/1539] Replace MIME parsing with custom HTTP parsing. Given that websockets makes straightforward use of HTTP, that websocket implementations can be expected not to exhibit legacy behaviors, and that RFC 7230 deprecates this behavior, parsing HTTP is doable. Thanks https://github.com/njsmith/h11 for providing some inspiration, especially for translating the RFC to regular expressions and figuring out some edge cases. I expect the new implementation to be faster, since it has a much tighter focus than the stdlib's general purpose MIME parser, and possibly more secure, since it was written from the beginning with security as a primary goal (with the caveat that it's new code, which means it's more likely to have security issues). Fix #19. --- docs/changelog.rst | 2 + websockets/client.py | 14 ++--- websockets/http.py | 133 +++++++++++++++++++++++++++++++--------- websockets/protocol.py | 4 +- websockets/server.py | 15 +++-- websockets/test_http.py | 40 +++++++++--- 6 files changed, 155 insertions(+), 53 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 26093b198..0bd7910c0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -14,6 +14,8 @@ Changelog * Made read and write buffer sizes configurable. +* Rewrote HTTP handling for simplicity and performance. + 3.3 ... diff --git a/websockets/client.py b/websockets/client.py index a90939e18..9aea4f6b7 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -5,7 +5,7 @@ import asyncio import collections.abc -import email.message +import http.client from .exceptions import InvalidHandshake, InvalidMessage from .handshake import build_request, check_response @@ -35,9 +35,8 @@ def write_http_request(self, path, headers): """ self.path = path - self.request_headers = email.message.Message() - for name, value in headers: - self.request_headers[name] = value + self.request_headers = http.client.HTTPMessage() + self.request_headers._headers = headers # HACK self.raw_request_headers = headers # Since the path and headers only contain ASCII characters, @@ -63,10 +62,11 @@ def read_http_response(self): except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc - self.response_headers = headers - self.raw_response_headers = list(headers.raw_items()) + self.response_headers = http.client.HTTPMessage() + self.response_headers._headers = headers # HACK + self.raw_response_headers = headers - return status_code, headers + return status_code, self.response_headers def process_subprotocol(self, get_header, subprotocols=None): """ diff --git a/websockets/http.py b/websockets/http.py index 173a35691..48ec2f5e8 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -8,8 +8,7 @@ """ import asyncio -import email.parser -import io +import re import sys from .version import version as websockets_version @@ -26,6 +25,26 @@ )) +# See https://tools.ietf.org/html/rfc7230#appendix-B. + +# Regex for validating header names. + +_token_re = re.compile(rb'^[-!#$%&\'*+.^_`|~0-9a-zA-Z]+$') + +# Regex for validating header values. + +# We don't attempt to support obsolete line folding. + +# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). + +# The ABNF is complicated because it attempts to express that optional +# whitespace is ignored. We strip whitespace and don't revalidate that. + +# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + +_value_re = re.compile(rb'^[\x09\x20-\x7e\x80-\xff]*$') + + @asyncio.coroutine def read_request(stream): """ @@ -34,20 +53,38 @@ def read_request(stream): ``stream`` is an :class:`~asyncio.StreamReader`. Return ``(path, headers)`` where ``path`` is a :class:`str` and - ``headers`` is a :class:`~email.message.Message`. ``path`` isn't - URL-decoded. + ``headers`` is a list of ``(name, value)`` tuples. + + ``path`` isn't URL-decoded or validated in any way. + + Non-ASCII characters are represented with surrogate escapes. Raise an exception if the request isn't well formatted. The request is assumed not to contain a body. """ - request_line, headers = yield from read_message(stream) - method, path, version = request_line[:-2].decode().split(None, 2) - if method != 'GET': - raise ValueError("Unsupported method") - if version != 'HTTP/1.1': - raise ValueError("Unsupported HTTP version") + # https://tools.ietf.org/html/rfc7230#section-3.1.1 + + # Parsing is simple because fixed values are expected for method and + # version and because path isn't checked. Since WebSocket software tends + # to implement HTTP/1.1 strictly, there's little need for lenient parsing. + + # Given the implementation of read_line(), request_line ends with CRLF. + request_line = yield from read_line(stream) + + # This may raise "ValueError: not enough values to unpack" + method, path, version = request_line[:-2].split(b' ', 2) + + if method != b'GET': + raise ValueError("Unsupported HTTP method: %r" % method) + if version != b'HTTP/1.1': + raise ValueError("Unsupported HTTP version: %r" % version) + + path = path.decode('ascii', 'surrogateescape') + + headers = yield from read_headers(stream) + return path, headers @@ -59,45 +96,82 @@ def read_response(stream): ``stream`` is an :class:`~asyncio.StreamReader`. Return ``(status, headers)`` where ``status`` is a :class:`int` and - ``headers`` is a :class:`~email.message.Message`. + ``headers`` is a list of ``(name, value)`` tuples. + + Non-ASCII characters are represented with surrogate escapes. Raise an exception if the request isn't well formatted. The response is assumed not to contain a body. """ - status_line, headers = yield from read_message(stream) - version, status, reason = status_line[:-2].decode().split(" ", 2) - if version != 'HTTP/1.1': - raise ValueError("Unsupported HTTP version") - return int(status), headers + # https://tools.ietf.org/html/rfc7230#section-3.1.2 + + # As in read_request, parsing is simple because a fixed value is expected + # for version, status is a 3-digit number, and reason can be ignored. + + # Given the implementation of read_line(), status_line ends with CRLF. + status_line = yield from read_line(stream) + + # This may raise "ValueError: not enough values to unpack" + version, status, reason = status_line[:-2].split(b' ', 2) + + if version != b'HTTP/1.1': + raise ValueError("Unsupported HTTP version: %r" % version) + # This may raise "ValueError: invalid literal for int() with base 10" + status = int(status) + if not 100 <= status < 1000: + raise ValueError("Unsupported HTTP status code: %d" % status) + if not _value_re.match(reason): + raise ValueError("Invalid HTTP reason phrase: %r" % reason) + + headers = yield from read_headers(stream) + + return status, headers @asyncio.coroutine -def read_message(stream): +def read_headers(stream): """ Read an HTTP message from ``stream``. ``stream`` is an :class:`~asyncio.StreamReader`. Return ``(start_line, headers)`` where ``start_line`` is :class:`bytes` - and ``headers`` is a :class:`~email.message.Message`. + and ``headers`` is a list of ``(name, value)`` tuples. + + Non-ASCII characters are represented with surrogate escapes. The message is assumed not to contain a body. """ - start_line = yield from read_line(stream) - header_lines = io.BytesIO() - for num in range(MAX_HEADERS): - header_line = yield from read_line(stream) - header_lines.write(header_line) - if header_line == b'\r\n': + # https://tools.ietf.org/html/rfc7230#section-3.2 + + # We don't attempt to support obsolete line folding. + + headers = [] + for _ in range(MAX_HEADERS): + line = yield from read_line(stream) + if line == b'\r\n': break + + # This may raise "ValueError: not enough values to unpack" + name, value = line[:-2].split(b':', 1) + if not _token_re.match(name): + raise ValueError("Invalid HTTP header name: %r" % name) + value = value.strip(b' \t') + if not _value_re.match(value): + raise ValueError("Invalid HTTP header value: %r" % value) + + headers.append(( + name.decode('ascii'), # guaranteed to be ASCII at this point + value.decode('ascii', 'surrogateescape'), + )) + else: - raise ValueError("Too many headers") - header_lines.seek(0) - headers = email.parser.BytesHeaderParser().parse(header_lines) - return start_line, headers + raise ValueError("Too many HTTP headers") + + return headers @asyncio.coroutine @@ -108,9 +182,12 @@ def read_line(stream): ``stream`` is an :class:`~asyncio.StreamReader`. """ + # Security: this is bounded by the StreamReader's limit (default = 32kB). line = yield from stream.readline() + # Security: this guarantees header values are small (hardcoded = 4kB) if len(line) > MAX_LINE: raise ValueError("Line too long") + # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b'\r\n'): raise ValueError("Line without CRLF") return line diff --git a/websockets/protocol.py b/websockets/protocol.py index 530f998a0..b0fb7c893 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -92,11 +92,13 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): processed, the request path is available in the :attr:`path` attribute, and the request and response HTTP headers are available: - * as a MIME :class:`~email.message.Message` in the :attr:`request_headers` + * as a :class:`~http.client.HTTPMessage` in the :attr:`request_headers` and :attr:`response_headers` attributes * as an iterable of (name, value) pairs in the :attr:`raw_request_headers` and :attr:`raw_response_headers` attributes + These attributes must be treated as immutable. + If a subprotocol was negotiated, it's available in the :attr:`subprotocol` attribute. diff --git a/websockets/server.py b/websockets/server.py index 119e251e5..f4da5c59f 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -6,8 +6,7 @@ import asyncio import collections.abc -import email.message -import http +import http.client import logging from .compatibility import asyncio_ensure_future @@ -155,10 +154,11 @@ def read_http_request(self): raise InvalidMessage("Malformed HTTP message") from exc self.path = path - self.request_headers = headers - self.raw_request_headers = list(headers.raw_items()) + self.request_headers = http.client.HTTPMessage() + self.request_headers._headers = headers # HACK + self.raw_request_headers = headers - return path, headers + return path, self.request_headers @asyncio.coroutine def write_http_response(self, status, headers): @@ -166,9 +166,8 @@ def write_http_response(self, status, headers): Write status line and headers to the HTTP response. """ - self.response_headers = email.message.Message() - for name, value in headers: - self.response_headers[name] = value + self.response_headers = http.client.HTTPMessage() + self.response_headers._headers = headers # HACK self.raw_response_headers = headers # Since the status line and headers only contain ASCII characters, diff --git a/websockets/test_http.py b/websockets/test_http.py index b31bd84d0..8035451c7 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -2,7 +2,7 @@ import unittest from .http import * -from .http import read_message # private API +from .http import read_headers # private API class HTTPTests(unittest.TestCase): @@ -32,7 +32,7 @@ def test_read_request(self): ) path, hdrs = self.loop.run_until_complete(read_request(self.stream)) self.assertEqual(path, '/chat') - self.assertEqual(hdrs['Upgrade'], 'websocket') + self.assertEqual(dict(hdrs)['Upgrade'], 'websocket') def test_read_response(self): # Example from the protocol overview in RFC 6455 @@ -46,32 +46,54 @@ def test_read_response(self): ) status, hdrs = self.loop.run_until_complete(read_response(self.stream)) self.assertEqual(status, 101) - self.assertEqual(hdrs['Upgrade'], 'websocket') + self.assertEqual(dict(hdrs)['Upgrade'], 'websocket') - def test_method(self): + def test_request_method(self): self.stream.feed_data(b'OPTIONS * HTTP/1.1\r\n\r\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_request(self.stream)) - def test_version(self): + def test_request_version(self): self.stream.feed_data(b'GET /chat HTTP/1.0\r\n\r\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_request(self.stream)) + + def test_response_version(self): self.stream.feed_data(b'HTTP/1.0 400 Bad Request\r\n\r\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_response(self.stream)) + def test_response_status(self): + self.stream.feed_data(b'HTTP/1.1 007 My name is Bond\r\n\r\n') + with self.assertRaises(ValueError): + self.loop.run_until_complete(read_response(self.stream)) + + def test_response_reason(self): + self.stream.feed_data(b'HTTP/1.1 200 \x7f\r\n\r\n') + with self.assertRaises(ValueError): + self.loop.run_until_complete(read_response(self.stream)) + + def test_header_name(self): + self.stream.feed_data(b'foo bar: baz qux\r\n\r\n') + with self.assertRaises(ValueError): + self.loop.run_until_complete(read_headers(self.stream)) + + def test_header_value(self): + self.stream.feed_data(b'foo: \x00\x00\x0f\r\n\r\n') + with self.assertRaises(ValueError): + self.loop.run_until_complete(read_headers(self.stream)) + def test_headers_limit(self): self.stream.feed_data(b'foo: bar\r\n' * 500 + b'\r\n') with self.assertRaises(ValueError): - self.loop.run_until_complete(read_message(self.stream)) + self.loop.run_until_complete(read_headers(self.stream)) def test_line_limit(self): self.stream.feed_data(b'a' * 5000 + b'\r\n\r\n') with self.assertRaises(ValueError): - self.loop.run_until_complete(read_message(self.stream)) + self.loop.run_until_complete(read_headers(self.stream)) def test_line_ending(self): - self.stream.feed_data(b'GET / HTTP/1.1\n\n') + self.stream.feed_data(b'foo: bar\n\n') with self.assertRaises(ValueError): - self.loop.run_until_complete(read_message(self.stream)) + self.loop.run_until_complete(read_headers(self.stream)) From 00efcc605fd0f88eb0bef9e5b7bda6fdb7d954ad Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 17 Jul 2017 10:21:20 +0200 Subject: [PATCH 0226/1539] Encapsulate creation of HTTP headers. Since this part is hacky and likely to change in the future (#210), wrap it into a single function and add tests for the public API we really care about. --- websockets/client.py | 9 +++------ websockets/http.py | 13 +++++++++++++ websockets/server.py | 10 ++++------ websockets/test_http.py | 32 ++++++++++++++++++++++++++++++-- 4 files changed, 50 insertions(+), 14 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 9aea4f6b7..411cf37f5 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -5,11 +5,10 @@ import asyncio import collections.abc -import http.client from .exceptions import InvalidHandshake, InvalidMessage from .handshake import build_request, check_response -from .http import USER_AGENT, read_response +from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol from .uri import parse_uri @@ -35,8 +34,7 @@ def write_http_request(self, path, headers): """ self.path = path - self.request_headers = http.client.HTTPMessage() - self.request_headers._headers = headers # HACK + self.request_headers = build_headers(headers) self.raw_request_headers = headers # Since the path and headers only contain ASCII characters, @@ -62,8 +60,7 @@ def read_http_response(self): except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc - self.response_headers = http.client.HTTPMessage() - self.response_headers._headers = headers # HACK + self.response_headers = build_headers(headers) self.raw_response_headers = headers return status_code, self.response_headers diff --git a/websockets/http.py b/websockets/http.py index 48ec2f5e8..e71e8c78d 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -8,6 +8,7 @@ """ import asyncio +import http.client import re import sys @@ -191,3 +192,15 @@ def read_line(stream): if not line.endswith(b'\r\n'): raise ValueError("Line without CRLF") return line + + +def build_headers(raw_headers): + """ + Build a date structure for HTTP headers from a list of name - value pairs. + + See also https://github.com/aaugustin/websockets/issues/210. + + """ + headers = http.client.HTTPMessage() + headers._headers = raw_headers # HACK + return headers diff --git a/websockets/server.py b/websockets/server.py index f4da5c59f..43a9c682b 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -6,13 +6,13 @@ import asyncio import collections.abc -import http.client +import http import logging from .compatibility import asyncio_ensure_future from .exceptions import InvalidHandshake, InvalidMessage, InvalidOrigin from .handshake import build_response, check_request -from .http import USER_AGENT, read_request +from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -154,8 +154,7 @@ def read_http_request(self): raise InvalidMessage("Malformed HTTP message") from exc self.path = path - self.request_headers = http.client.HTTPMessage() - self.request_headers._headers = headers # HACK + self.request_headers = build_headers(headers) self.raw_request_headers = headers return path, self.request_headers @@ -166,8 +165,7 @@ def write_http_response(self, status, headers): Write status line and headers to the HTTP response. """ - self.response_headers = http.client.HTTPMessage() - self.response_headers._headers = headers # HACK + self.response_headers = build_headers(headers) self.raw_response_headers = headers # Since the status line and headers only contain ASCII characters, diff --git a/websockets/test_http.py b/websockets/test_http.py index 8035451c7..28ad4a25e 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -2,10 +2,10 @@ import unittest from .http import * -from .http import read_headers # private API +from .http import build_headers, read_headers -class HTTPTests(unittest.TestCase): +class HTTPAsyncTests(unittest.TestCase): def setUp(self): super().setUp() @@ -97,3 +97,31 @@ def test_line_ending(self): self.stream.feed_data(b'foo: bar\n\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) + + +class HTTPSyncTests(unittest.TestCase): + + def test_build_headers(self): + headers = build_headers([ + ('X-Foo', 'Bar'), + ('X-Baz', 'Quux Quux'), + ]) + + self.assertEqual(headers['X-Foo'], 'Bar') + self.assertEqual(headers['X-Bar'], None) + + self.assertEqual(headers.get('X-Bar', ''), '') + self.assertEqual(headers.get('X-Baz', ''), 'Quux Quux') + + def test_build_headers_multi_value(self): + headers = build_headers([ + ('X-Foo', 'Bar'), + ('X-Foo', 'Baz'), + ]) + + # Getting a single value is non-deterministic. + self.assertIn(headers['X-Foo'], ['Bar', 'Baz']) + self.assertIn(headers.get('X-Foo'), ['Bar', 'Baz']) + + # Ordering is deterministic when getting all values. + self.assertEqual(headers.get_all('X-Foo'), ['Bar', 'Baz']) From ce61a9dac266cac87fa1f1ae20ed37ad7bf60770 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Mon, 10 Jul 2017 19:06:52 -0700 Subject: [PATCH 0227/1539] Address issue #203: make WebSocketServer not inherit. This commit changes WebSocketServer not to inherit from asyncio.AbstractServer. It also improves the documentation to explain more clearly the relationship between the WebSocketServer and the underlying asyncio.Server object. --- docs/api.rst | 5 +++ docs/cheatsheet.rst | 2 +- websockets/server.py | 72 ++++++++++++++++++++++++++++++-------------- 3 files changed, 55 insertions(+), 24 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index d65418d15..62b580129 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -34,6 +34,11 @@ Server .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) + .. autoclass:: WebSocketServer + + .. automethod:: close() + .. automethod:: wait_closed() + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, origins=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 04349badc..cf6897257 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -16,7 +16,7 @@ Server but it isn't needed in general. * Create a server with :func:`~websockets.server.serve` which is similar to - asyncio's :meth:`~asyncio.BaseEventLoop.create_server`. + asyncio's :meth:`~asyncio.AbstractEventLoop.create_server`. * The server takes care of establishing connections, then lets the handler execute the application logic, and finally closes the connection after diff --git a/websockets/server.py b/websockets/server.py index 43a9c682b..16427c670 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -306,9 +306,24 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): return path -class WebSocketServer(asyncio.AbstractServer): +class WebSocketServer: """ - Wrapper for :class:`~asyncio.Server` that triggers the closing handshake. + Wraps an underlying :class:`~asyncio.Server` object. + + This class provides the return type of :func:`~websockets.server.serve`. + This class shouldn't be instantiated directly. + + Objects of this class store a reference to an underlying + :class:`~asyncio.Server` object returned by + :meth:`~asyncio.AbstractEventLoop.create_server`. The class stores a + reference rather than inheriting from :class:`~asyncio.Server` in part + because :meth:`~asyncio.AbstractEventLoop.create_server` doesn't support + passing a custom :class:`~asyncio.Server` class. + + :class:`WebSocketServer` supports cleaning up the underlying + :class:`~asyncio.Server` object and other resources by implementing the + interface of ``asyncio.events.AbstractServer``, namely its ``close()`` + and ``wait_closed()`` methods. """ def __init__(self, loop): @@ -322,13 +337,13 @@ def wrap(self, server): """ Attach to a given :class:`~asyncio.Server`. - Since :meth:`~asyncio.BaseEventLoop.create_server` doesn't support + Since :meth:`~asyncio.AbstractEventLoop.create_server` doesn't support injecting a custom ``Server`` class, a simple solution that doesn't rely on private APIs is to: - instantiate a :class:`WebSocketServer` - give the protocol factory a reference to that instance - - call :meth:`~asyncio.BaseEventLoop.create_server` with the factory + - call :meth:`~asyncio.AbstractEventLoop.create_server` with the factory - attach the resulting :class:`~asyncio.Server` with this method """ @@ -342,7 +357,11 @@ def unregister(self, protocol): def close(self): """ - Stop accepting new connections and close open connections. + Close the underlying server, and clean up connections. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object, closes open connections with + status code 1001, and stops accepting new connections. """ # Make a note that the server is shutting down. Websocket connections @@ -365,7 +384,11 @@ def close(self): @asyncio.coroutine def wait_closed(self): """ - Wait until all connections are closed. + Wait until the underlying server and all connections are closed. + + This calls :meth:`~asyncio.Server.wait_closed` on the underlying + :class:`~asyncio.Server` object and waits until closing handshakes + are complete and all connections are closed. This method must be called after :meth:`close()`. @@ -390,25 +413,31 @@ def serve(ws_handler, host=None, port=None, *, origins=None, subprotocols=None, extra_headers=None, **kwds): """ - This coroutine creates a WebSocket server. + Create, start, and return a :class:`WebSocketServer` object. - It yields a :class:`~asyncio.Server` which provides: + :func:`serve` is a wrapper around the event loop's + :meth:`~asyncio.AbstractEventLoop.create_server` method. + Internally, the function creates and starts a :class:`~asyncio.Server` + object by calling :meth:`~asyncio.AbstractEventLoop.create_server`. The + :class:`WebSocketServer` keeps a reference to this object. - * a :meth:`~asyncio.Server.close` method that closes open connections with - status code 1001 and stops accepting new connections - * a :meth:`~asyncio.Server.wait_closed` coroutine that waits until closing - handshakes complete and connections are closed. + The returned :class:`WebSocketServer` and its resources can be cleaned + up by calling its :meth:`~websockets.server.WebSocketServer.close` and + :meth:`~websockets.server.WebSocketServer.wait_closed` methods. - ``ws_handler`` is the WebSocket handler. It must be a coroutine accepting - two arguments: a :class:`WebSocketServerProtocol` and the request URI. + On Python 3.5 and greater, :func:`serve` can also be used as an + asynchronous context manager. In this case, the server is shut down + when exiting the context. - :func:`serve` is a wrapper around the event loop's - :meth:`~asyncio.BaseEventLoop.create_server` method. ``host``, ``port`` as - well as unknown keyword arguments are passed to - :meth:`~asyncio.BaseEventLoop.create_server`. + The ``ws_handler`` argument is the WebSocket handler. It must be a + coroutine accepting two arguments: a :class:`WebSocketServerProtocol` + and the request URI. - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enable TLS. + The ``host`` and ``port`` arguments, as well as unrecognized keyword + arguments, are passed along to + :meth:`~asyncio.AbstractEventLoop.create_server`. For example, you can + set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable + TLS. The behavior of the ``timeout``, ``max_size``, and ``max_queue``, ``read_limit``, and ``write_limit`` optional arguments is described in the @@ -438,9 +467,6 @@ def serve(ws_handler, host=None, port=None, *, logger.setLevel(logging.ERROR) logger.addHandler(logging.StreamHandler()) - On Python 3.5, :func:`serve` can be used as a asynchronous context - manager. In that case, the server is shut down when exiting the context. - """ if loop is None: loop = asyncio.get_event_loop() From 79949f1859b0ab7a742e234f46b2b1ceb0cd038e Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Tue, 18 Jul 2017 16:24:02 -0700 Subject: [PATCH 0228/1539] Fix CircleCI flake8 error. --- websockets/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/websockets/server.py b/websockets/server.py index 16427c670..33e2f774f 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -343,7 +343,8 @@ def wrap(self, server): - instantiate a :class:`WebSocketServer` - give the protocol factory a reference to that instance - - call :meth:`~asyncio.AbstractEventLoop.create_server` with the factory + - call :meth:`~asyncio.AbstractEventLoop.create_server` with the + factory - attach the resulting :class:`~asyncio.Server` with this method """ From e0915479ce582b0b69b551cc62ab2eb4e6cd39f9 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Fri, 14 Jul 2017 15:15:16 -0700 Subject: [PATCH 0229/1539] Fix issue #207: always set self.origin. --- websockets/server.py | 4 ++-- websockets/test_client_server.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 33e2f774f..279d814df 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -187,11 +187,11 @@ def process_origin(self, get_header, origins=None): acceptable. """ + origin = get_header('Origin') if origins is not None: - origin = get_header('Origin') if origin not in origins: raise InvalidOrigin("Origin not allowed: {}".format(origin)) - return origin + return origin def process_subprotocol(self, get_header, subprotocols=None): """ diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 6ad6ec20a..5b47c6c50 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1,5 +1,6 @@ import asyncio import http +import http.client import logging import os import ssl @@ -217,6 +218,37 @@ def test_protocol_custom_response_headers_list(self): self.stop_client() self.stop_server() + def test_get_response_status_attributes_available(self): + # Save the attribute values to a dict instead of asserting inside + # get_response_status() because assertion errors there do not + # currently bubble up for easy viewing. + attrs = {} + + class SaveAttributesProtocol(WebSocketServerProtocol): + @asyncio.coroutine + def get_response_status(self, set_header): + attrs['origin'] = self.origin + attrs['path'] = self.path + attrs['raw_request_headers'] = self.raw_request_headers.copy() + attrs['request_headers'] = self.request_headers + status = yield from super().get_response_status(set_header) + return status + + self.start_server(klass=SaveAttributesProtocol) + try: + self.start_client(path='foo/bar', origin='http://otherhost') + self.assertEqual(attrs['origin'], 'http://otherhost') + self.assertEqual(attrs['path'], '/foo/bar') + # To reduce test brittleness, only check one nontrivial aspect + # of the request headers. + self.assertIn(('Origin', 'http://otherhost'), + attrs['raw_request_headers']) + request_headers = attrs['request_headers'] + self.assertIsInstance(request_headers, http.client.HTTPMessage) + self.assertEqual(request_headers.get('origin'), 'http://otherhost') + finally: + self.stop_server() + def test_authentication(self): self.start_server(klass=ForbiddenWebSocketServerProtocol) with self.assertRaises(InvalidHandshake): From aa711aa5554814eb625924e5e543191dc3ad5b0a Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Thu, 20 Jul 2017 14:14:04 -0700 Subject: [PATCH 0230/1539] Expose invalid status code during connection. (#209) Fix #198. Thanks @cjerdonek! * Address issue #198: expose invalid status code during connection. * Address review comments. * Fix tests, code coverage, and flake8. * Ignore the response reason per the spec. * Use double quotes for consistency. --- docs/changelog.rst | 4 ++++ websockets/client.py | 4 ++-- websockets/exceptions.py | 15 ++++++++++++++- websockets/test_client_server.py | 20 +++++++++++++++----- websockets/test_http.py | 5 +++-- 5 files changed, 38 insertions(+), 10 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 0bd7910c0..e4929149f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -16,6 +16,10 @@ Changelog * Rewrote HTTP handling for simplicity and performance. +* An invalid response status code during :func:`~websockets.client.connect` + now raises :class:`~websockets.exceptions.InvalidStatus` with a ``code`` + attribute. + 3.3 ... diff --git a/websockets/client.py b/websockets/client.py index 411cf37f5..143ec37a0 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -6,7 +6,7 @@ import asyncio import collections.abc -from .exceptions import InvalidHandshake, InvalidMessage +from .exceptions import InvalidHandshake, InvalidMessage, InvalidStatus from .handshake import build_request, check_response from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -118,7 +118,7 @@ def handshake(self, wsuri, get_header = lambda k: headers.get(k, '') if status_code != 101: - raise InvalidHandshake("Bad status code: {}".format(status_code)) + raise InvalidStatus(status_code) check_response(get_header, key) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 3d3ad46f1..b3917c704 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,6 +1,6 @@ __all__ = [ 'InvalidHandshake', 'InvalidMessage', 'InvalidOrigin', 'InvalidState', - 'InvalidURI', 'ConnectionClosed', 'PayloadTooBig', + 'InvalidStatus', 'InvalidURI', 'ConnectionClosed', 'PayloadTooBig', 'WebSocketProtocolError', ] @@ -26,6 +26,19 @@ class InvalidOrigin(InvalidHandshake): """ +class InvalidStatus(InvalidHandshake): + """ + Exception raised when a handshake response status code is invalid. + + Provides the integer status code in its ``code`` attribute. + + """ + def __init__(self, code): + self.code = code + message = 'Status code not 101: {}'.format(code) + super().__init__(message) + + class InvalidState(Exception): """ Exception raised when an operation is forbidden in the current state. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 5b47c6c50..04c4548d2 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -8,7 +8,7 @@ import unittest.mock from .client import * -from .exceptions import ConnectionClosed, InvalidHandshake +from .exceptions import ConnectionClosed, InvalidHandshake, InvalidStatus from .http import USER_AGENT, read_response from .server import * @@ -251,7 +251,7 @@ def get_response_status(self, set_header): def test_authentication(self): self.start_server(klass=ForbiddenWebSocketServerProtocol) - with self.assertRaises(InvalidHandshake): + with self.assertRaises(InvalidStatus): self.start_client() self.stop_server() @@ -360,7 +360,7 @@ def wrong_read_response(stream): _read_response.side_effect = wrong_read_response self.start_server() - with self.assertRaises(InvalidHandshake): + with self.assertRaises(InvalidStatus): self.start_client() self.run_loop_once() self.stop_server() @@ -418,7 +418,7 @@ def test_server_shuts_down_during_opening_handshake(self, _read_request): self.stop_server() # Opening handshake fails with 503 Service Unavailable - self.assertEqual(str(raised.exception), "Bad status code: 503") + self.assertEqual(str(raised.exception), "Status code not 101: 503") def test_server_shuts_down_during_connection_handling(self): self.start_server() @@ -433,6 +433,15 @@ def test_server_shuts_down_during_connection_handling(self): # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) + def test_invalid_status_error_during_client_connect(self): + self.start_server(klass=ForbiddenWebSocketServerProtocol) + with self.assertRaises(InvalidStatus) as raised: + self.start_client() + exception = raised.exception + self.assertEqual(str(exception), "Status code not 101: 403") + self.assertEqual(exception.code, 403) + self.stop_server() + @unittest.mock.patch('websockets.server.read_request') def test_connection_error_during_opening_handshake(self, _read_request): _read_request.side_effect = ConnectionError @@ -522,7 +531,8 @@ def test_checking_origin_succeeds(self): def test_checking_origin_fails(self): server = self.loop.run_until_complete( serve(handler, 'localhost', 8642, origins=['http://localhost'])) - with self.assertRaisesRegex(InvalidHandshake, "Bad status code: 403"): + with self.assertRaisesRegex(InvalidHandshake, + "Status code not 101: 403"): self.loop.run_until_complete( connect('ws://localhost:8642/', origin='http://otherhost')) diff --git a/websockets/test_http.py b/websockets/test_http.py index 28ad4a25e..0e13c8f5c 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -44,9 +44,10 @@ def test_read_response(self): b'Sec-WebSocket-Protocol: chat\r\n' b'\r\n' ) - status, hdrs = self.loop.run_until_complete(read_response(self.stream)) + status, headers = self.loop.run_until_complete( + read_response(self.stream)) self.assertEqual(status, 101) - self.assertEqual(dict(hdrs)['Upgrade'], 'websocket') + self.assertEqual(dict(headers)['Upgrade'], 'websocket') def test_request_method(self): self.stream.feed_data(b'OPTIONS * HTTP/1.1\r\n\r\n') From 93f258fb455d68ec931a53dde3d138e27ed5ef72 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Fri, 21 Jul 2017 19:16:54 -0700 Subject: [PATCH 0231/1539] Address issue #212: have coverage report missing line numbers. --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 4c1e73138..a92a0b8d6 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ commands = python -m unittest commands = python -m coverage erase python -m coverage run --branch --source=websockets -m unittest - python -m coverage report --fail-under=100 + python -m coverage report --show-missing --fail-under=100 deps = coverage [testenv:flake8] From c0737a73f32632e0d954e7b701e892a20ec1d372 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jul 2017 16:37:20 +0200 Subject: [PATCH 0232/1539] Speed up frame (un)masking with a C extensions. Thanks Messense Lv (@messense) and Matthieu Darbois (@mayeut) for proposing previous iterations of this patch in #175 and #211. --- .gitignore | 1 + Makefile | 2 +- setup.py | 10 +++++ tox.ini | 27 ++++++++++++- websockets/framing.py | 10 ++++- websockets/speedups.c | 80 +++++++++++++++++++++++++++++++++++++ websockets/test_speedups.py | 0 websockets/test_utils.py | 49 +++++++++++++++++++++++ websockets/utils.py | 14 +++++++ 9 files changed, 188 insertions(+), 5 deletions(-) create mode 100644 websockets/speedups.c create mode 100644 websockets/test_speedups.py create mode 100644 websockets/test_utils.py create mode 100644 websockets/utils.py diff --git a/.gitignore b/.gitignore index 1a6a602c4..4dc1216b7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pyc +*.so .coverage .tox build/ diff --git a/Makefile b/Makefile index e942ecd75..fd24aa7ff 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,6 @@ coverage: python -m coverage html clean: - find . -name '*.pyc' -delete + find . -name '*.pyc' -o -name '*.so' -delete find . -name __pycache__ -delete rm -rf .coverage build compliance/reports dist docs/_build htmlcov MANIFEST README websockets.egg-info diff --git a/setup.py b/setup.py index a2fe50a81..6363c6f23 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,14 @@ if py_version >= (3, 5): packages.append('websockets/py35') +ext_modules = [ + setuptools.Extension( + 'websockets.speedups', + sources=['websockets/speedups.c'], + optional=True, + ) +] + setuptools.setup( name='websockets', version=version, @@ -45,8 +53,10 @@ 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', ], packages=packages, + ext_modules=ext_modules, extras_require={ ':python_version=="3.3"': ['asyncio'], }, diff --git a/tox.ini b/tox.ini index a92a0b8d6..972301336 100644 --- a/tox.ini +++ b/tox.ini @@ -1,16 +1,39 @@ [tox] -envlist = py33,py34,py35,py36,coverage,flake8,isort +envlist = {py33,py34,py35,py36}{,-speedups},coverage,flake8,isort [testenv] +commands = + ; Unfortunately tox has no support for building C extensions. + ; Do it manually in the git checkout - that's where tests run. + + ; Remove any existing compiled extension. + sh -c 'rm -f websockets/*.so' + + ; Before testing with speedups, compile the extension. + speedups: python setup.py --quiet build_ext --inplace + + python -m unittest + + ; After testing with speedups, remove the extension. + speedups: sh -c 'rm websockets/*.so' + deps = py33: asyncio -commands = python -m unittest +whitelist_externals = + sh [testenv:coverage] commands = + ; Handle speedups as above. + sh -c 'rm -f websockets/*.so' + python setup.py --quiet build_ext --inplace + python -m coverage erase python -m coverage run --branch --source=websockets -m unittest python -m coverage report --show-missing --fail-under=100 + + speedups: sh -c 'rm websockets/*.so' + deps = coverage [testenv:flake8] diff --git a/websockets/framing.py b/websockets/framing.py index 544ac27ec..135a139e5 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -18,6 +18,12 @@ from .exceptions import PayloadTooBig, WebSocketProtocolError +try: + from .speedups import apply_mask +except ImportError: # pragma: no cover + from .utils import apply_mask + + __all__ = [ 'OP_CONT', 'OP_TEXT', 'OP_BINARY', 'OP_CLOSE', 'OP_PING', 'OP_PONG', 'Frame', 'read_frame', 'write_frame', 'parse_close', 'serialize_close' @@ -101,7 +107,7 @@ def read_frame(reader, mask, *, max_size=None): # Read the data data = yield from reader(length) if mask: - data = bytes(b ^ mask_bits[i % 4] for i, b in enumerate(data)) + data = apply_mask(data, mask_bits) frame = Frame(fin, opcode, data) check_frame(frame) @@ -144,7 +150,7 @@ def write_frame(frame, writer, mask): # Prepare the data if mask: - data = bytes(b ^ mask_bits[i % 4] for i, b in enumerate(frame.data)) + data = apply_mask(frame.data, mask_bits) else: data = frame.data output.write(data) diff --git a/websockets/speedups.c b/websockets/speedups.c new file mode 100644 index 000000000..7d894973c --- /dev/null +++ b/websockets/speedups.c @@ -0,0 +1,80 @@ +/* C implementation of performance sensitive functions. */ + +#define PY_SSIZE_T_CLEAN +#include + +const Py_ssize_t MASK_LEN = 4; + +static PyObject * +apply_mask(PyObject *self, PyObject *args, PyObject *kwds) +{ + + // Inputs are treated as immutable, which causes an extra memory copy. + + static char *kwlist[] = {"data", "mask", NULL}; + const char *input; + Py_ssize_t input_len; + const char *mask; + Py_ssize_t mask_len; + + // Initialize a PyBytesObject then get a pointer to the underlying char * + // in order to avoid an extra memory copy in PyBytes_FromStringAndSize. + + PyObject *result; + char *output; + Py_ssize_t i; + + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "s#s#", kwlist, &input, &input_len, &mask, &mask_len)) + { + return NULL; + } + + if (mask_len != MASK_LEN) + { + PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes"); + return NULL; + } + + result = PyBytes_FromStringAndSize(NULL, input_len); + if (result == NULL) + { + return NULL; + } + + // Since we juste created result, we don't need error checks. + output = PyBytes_AS_STRING(result); + + for (i = 0; i < input_len; i++) + { + output[i] = input[i] ^ mask[i % MASK_LEN]; + } + + return result; + +} + +static PyMethodDef speedups_methods[] = { + { + "apply_mask", + (PyCFunction)apply_mask, + METH_VARARGS | METH_KEYWORDS, + "Apply masking to websocket message.", + }, + {NULL, NULL, 0, NULL}, /* Sentinel */ +}; + +static struct PyModuleDef speedups_module = { + PyModuleDef_HEAD_INIT, + "websocket.speedups", /* m_name */ + "C implementation of performance sensitive functions.", + /* m_doc */ + -1, /* m_size */ + speedups_methods, /* m_methods */ +}; + +PyMODINIT_FUNC +PyInit_speedups(void) +{ + return PyModule_Create(&speedups_module); +} diff --git a/websockets/test_speedups.py b/websockets/test_speedups.py new file mode 100644 index 000000000..e69de29bb diff --git a/websockets/test_utils.py b/websockets/test_utils.py new file mode 100644 index 000000000..7b20284aa --- /dev/null +++ b/websockets/test_utils.py @@ -0,0 +1,49 @@ +import unittest + +from .utils import apply_mask as py_apply_mask + + +class UtilsTests(unittest.TestCase): + + @staticmethod + def apply_mask(*args, **kwargs): + return py_apply_mask(*args, **kwargs) + + def test_apply_mask(self): + for data_in, mask, data_out in [ + (b'', b'1234', b''), + (b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'), + (b'abcdABCD', b'1234', b'PPPPpppp'), + ]: + self.assertEqual(self.apply_mask(data_in, mask), data_out) + + def test_apply_mask_check_input_types(self): + for data_in, mask in [ + (None, None), + (b'abcd', None), + (None, b'abcd'), + ]: + with self.assertRaises(TypeError): + self.apply_mask(data_in, mask) + + def test_apply_mask_check_mask_length(self): + for data_in, mask in [ + (b'', b''), + (b'abcd', b'123'), + (b'', b'aBcDe'), + (b'12345678', b'12345678'), + ]: + with self.assertRaises(ValueError): + self.apply_mask(data_in, mask) + + +try: + from .speedups import apply_mask as c_apply_mask +except ImportError: # pragma: no cover + pass +else: + class SpeedupsTests(UtilsTests): + + @staticmethod + def apply_mask(*args, **kwargs): + return c_apply_mask(*args, **kwargs) diff --git a/websockets/utils.py b/websockets/utils.py new file mode 100644 index 000000000..b4083dff4 --- /dev/null +++ b/websockets/utils.py @@ -0,0 +1,14 @@ +import itertools + + +__all__ = ['apply_mask'] + + +def apply_mask(data, mask): + """ + Apply masking to websocket message. + + """ + if len(mask) != 4: + raise ValueError("mask must contain 4 bytes") + return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask))) From c53221933e8558c349c52b9b5d066520f370b6d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jul 2017 13:06:04 +0200 Subject: [PATCH 0233/1539] Optimize performance with AVX. This is mostly for fun. It won't have an effect unless the extension is compiled with CFLAGS='-march=native' or similar. --- websockets/speedups.c | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/websockets/speedups.c b/websockets/speedups.c index 7d894973c..f634c4024 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -3,6 +3,10 @@ #define PY_SSIZE_T_CLEAN #include +#if __AVX__ +#include +#endif + const Py_ssize_t MASK_LEN = 4; static PyObject * @@ -22,7 +26,7 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) PyObject *result; char *output; - Py_ssize_t i; + Py_ssize_t i = 0; if (!PyArg_ParseTupleAndKeywords( args, kwds, "s#s#", kwlist, &input, &input_len, &mask, &mask_len)) @@ -45,7 +49,27 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // Since we juste created result, we don't need error checks. output = PyBytes_AS_STRING(result); - for (i = 0; i < input_len; i++) + // Apparently GCC cannot figure out the following optimizations by itself. + +#if __AVX__ + + // With AVX support, XOR by blocks of 32 bytes = 256 bits. + + Py_ssize_t input_len_256 = input_len & ~31; + __m256 mask_256 = _mm256_set1_epi32(*(int *)mask); + + for (; i < input_len_256; i += 32) + { + __m256i in_256 = _mm256_loadu_si256((__m256i *)(input + i)); + __m256i out_256 = _mm256_xor_si256(in_256, mask_256); + _mm256_storeu_si256((__m256i *)(output + i), out_256); + } + +#endif + + // XOR the remainder of the input byte by byte. + + for (; i < input_len; i++) { output[i] = input[i] ^ mask[i % MASK_LEN]; } From 7a84580b57bd7c5a4ec7ced92e75f2ab4f2c9ccc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jul 2017 13:55:10 +0200 Subject: [PATCH 0234/1539] Optimize performance with SSE2. Unlike AVX, support for SSE2 is enabled by default in modern compilers. --- websockets/speedups.c | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/websockets/speedups.c b/websockets/speedups.c index f634c4024..823edd55f 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -3,7 +3,7 @@ #define PY_SSIZE_T_CLEAN #include -#if __AVX__ +#if __AVX__ || __SSE2__ #include #endif @@ -65,6 +65,23 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) _mm256_storeu_si256((__m256i *)(output + i), out_256); } +#elif __SSE2__ + + // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. + + // Since we cannot control the 16-bytes alignment of input and output + // buffers, we rely on loadu/storeu rather than load/store. + + Py_ssize_t input_len_128 = input_len & ~15; + __m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask); + + for (; i < input_len_128; i += 16) + { + __m128i in_128 = _mm_loadu_si128((__m128i *)(input + i)); + __m128i out_128 = _mm_xor_si128(in_128, mask_128); + _mm_storeu_si128((__m128i *)(output + i), out_128); + } + #endif // XOR the remainder of the input byte by byte. From 6d7744c5e3d5880043d18021cb345b45d6340914 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jul 2017 14:16:07 +0200 Subject: [PATCH 0235/1539] Remove support for AVX. It doesn't bring significant performance improvements over SSE2. --- websockets/speedups.c | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/websockets/speedups.c b/websockets/speedups.c index 823edd55f..91a8801cb 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -3,8 +3,8 @@ #define PY_SSIZE_T_CLEAN #include -#if __AVX__ || __SSE2__ -#include +#if __SSE2__ +#include #endif const Py_ssize_t MASK_LEN = 4; @@ -51,21 +51,7 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // Apparently GCC cannot figure out the following optimizations by itself. -#if __AVX__ - - // With AVX support, XOR by blocks of 32 bytes = 256 bits. - - Py_ssize_t input_len_256 = input_len & ~31; - __m256 mask_256 = _mm256_set1_epi32(*(int *)mask); - - for (; i < input_len_256; i += 32) - { - __m256i in_256 = _mm256_loadu_si256((__m256i *)(input + i)); - __m256i out_256 = _mm256_xor_si256(in_256, mask_256); - _mm256_storeu_si256((__m256i *)(output + i), out_256); - } - -#elif __SSE2__ +#if __SSE2__ // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. From 457cd0e2b453a50cd956abd91620697038036e77 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jul 2017 14:24:50 +0200 Subject: [PATCH 0236/1539] Add changelog entry. --- docs/changelog.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index e4929149f..6b0241239 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -9,13 +9,15 @@ Changelog * :func:`~websockets.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5. -* Added support rejecting incoming connections by customizing +* Added support for rejecting incoming connections by customizing :meth:`~websockets.server.WebSocketServerProtocol.get_response_status()`. * Made read and write buffer sizes configurable. * Rewrote HTTP handling for simplicity and performance. +* Added an optional C extension to speed up low level operations. + * An invalid response status code during :func:`~websockets.client.connect` now raises :class:`~websockets.exceptions.InvalidStatus` with a ``code`` attribute. From 45bd2762fd0f0cf8dce82425e99bed02cc384c65 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jul 2017 15:00:14 +0200 Subject: [PATCH 0237/1539] Optimize for 64-bit, non SSE2 CPUs (e.g. ARM). --- websockets/speedups.c | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/websockets/speedups.c b/websockets/speedups.c index 91a8801cb..502fd9156 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -68,6 +68,20 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) _mm_storeu_si128((__m128i *)(output + i), out_128); } +#else + + // Without SSE2 support, XOR by blocks of 8 bytes = 64 bits. + + // We assume the memory allocator aligns everything on 8 bytes boundaries. + + Py_ssize_t input_len_64 = input_len & ~7; + uint64_t mask_64 = (*(uint64_t *)mask << 32) | *(uint64_t *)mask; + + for (; i < input_len_64; i += 8) + { + *(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64; + } + #endif // XOR the remainder of the input byte by byte. From b8eda04abb8ab39f89e3d93457e53efdbd71beea Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jul 2017 15:31:28 +0200 Subject: [PATCH 0238/1539] Add a test for apply_mask with more data. This forces the code to go through the 16 bytes blocks path. --- websockets/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/websockets/test_utils.py b/websockets/test_utils.py index 7b20284aa..8259b7490 100644 --- a/websockets/test_utils.py +++ b/websockets/test_utils.py @@ -14,6 +14,7 @@ def test_apply_mask(self): (b'', b'1234', b''), (b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'), (b'abcdABCD', b'1234', b'PPPPpppp'), + (b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10), ]: self.assertEqual(self.apply_mask(data_in, mask), data_out) From 1f7604b7e2272567fbb3e00c7814782018113a94 Mon Sep 17 00:00:00 2001 From: 38elements <38elements@users.noreply.github.com> Date: Sun, 2 Jul 2017 13:37:30 +0900 Subject: [PATCH 0239/1539] Fix api.rst This adds max_queue to WebSocketCommonProtocol in api.rst. --- docs/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api.rst b/docs/api.rst index 62b580129..3fe2763fe 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -61,7 +61,7 @@ Shared .. automodule:: websockets.protocol - .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, loop=None) + .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None) .. automethod:: close(code=1000, reason='') From 0058c0954fc102ba74434c66ee810a078e40789b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jul 2017 15:42:32 +0200 Subject: [PATCH 0240/1539] Add read/write_limit parameters to API docs. --- docs/api.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 3fe2763fe..9dd5f8f88 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,14 +32,14 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) .. autoclass:: WebSocketServer .. automethod:: close() .. automethod:: wait_closed() - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, origins=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) .. automethod:: select_subprotocol(client_protos, server_protos) @@ -50,9 +50,9 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds) - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None) + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) .. automethod:: handshake(wsuri, origin=None, subprotocols=None, extra_headers=None) @@ -61,7 +61,7 @@ Shared .. automodule:: websockets.protocol - .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, loop=None) + .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) .. automethod:: close(code=1000, reason='') From 6173a0d6ad12967224d7a2d588ce8ab93f22b11c Mon Sep 17 00:00:00 2001 From: mayeut Date: Sun, 23 Jul 2017 17:53:54 +0200 Subject: [PATCH 0241/1539] Correctly read mask in non SSE2 cases in speedups.c Fixes aaugustin/websockets#223 --- websockets/speedups.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/websockets/speedups.c b/websockets/speedups.c index 502fd9156..071512541 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -75,7 +75,8 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // We assume the memory allocator aligns everything on 8 bytes boundaries. Py_ssize_t input_len_64 = input_len & ~7; - uint64_t mask_64 = (*(uint64_t *)mask << 32) | *(uint64_t *)mask; + uint32_t mask_32 = *(uint32_t *)mask; + uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32; for (; i < input_len_64; i += 8) { From cb225d7eb13ac6be800d170962b65bd580b00cf2 Mon Sep 17 00:00:00 2001 From: mayeut Date: Sun, 23 Jul 2017 18:00:33 +0200 Subject: [PATCH 0242/1539] Don't rely on Python.h to include stdint.h in speedups.c MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The code in speedups.c relies on the assumption that Python.h includes stdint.h This might change in the future so don’t rely on this assumption. --- websockets/speedups.c | 1 + 1 file changed, 1 insertion(+) diff --git a/websockets/speedups.c b/websockets/speedups.c index 071512541..70806e43c 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -2,6 +2,7 @@ #define PY_SSIZE_T_CLEAN #include +#include /* uint32_t, uint64_t */ #if __SSE2__ #include From 16c8c5386b2d1b5e52b8d2c182bcd80a493a3219 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sat, 22 Jul 2017 17:45:41 -0700 Subject: [PATCH 0243/1539] Add ClientServerTests.temp_server() context manager. --- websockets/test_client_server.py | 446 +++++++++++++++---------------- 1 file changed, 215 insertions(+), 231 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 04c4548d2..75b71c793 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -6,6 +6,7 @@ import ssl import unittest import unittest.mock +from contextlib import contextmanager from .client import * from .exceptions import ConnectionClosed, InvalidHandshake, InvalidStatus @@ -92,131 +93,128 @@ def stop_server(self): except asyncio.TimeoutError: # pragma: no cover self.fail("Server failed to stop") + @contextmanager + def temp_server(self, **kwds): + self.start_server(**kwds) + try: + yield + finally: + self.stop_server() + def test_basic(self): - self.start_server() - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() def test_server_close_while_client_connected(self): - self.start_server() - self.start_client() - self.stop_server() + with self.temp_server(): + self.start_client() def test_explicit_event_loop(self): - self.start_server(loop=self.loop) - self.start_client(loop=self.loop) - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() - self.stop_server() + with self.temp_server(loop=self.loop): + self.start_client(loop=self.loop) + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() def test_protocol_attributes(self): - self.start_server() - self.start_client('attributes') - expected_attrs = ('localhost', 8642, self.secure) - client_attrs = (self.client.host, self.client.port, self.client.secure) - self.assertEqual(client_attrs, expected_attrs) - server_attrs = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_attrs, repr(expected_attrs)) - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client('attributes') + expected_attrs = ('localhost', 8642, self.secure) + client_attrs = (self.client.host, self.client.port, + self.client.secure) + self.assertEqual(client_attrs, expected_attrs) + server_attrs = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_attrs, repr(expected_attrs)) + self.stop_client() def test_protocol_path(self): - self.start_server() - self.start_client('path') - client_path = self.client.path - self.assertEqual(client_path, '/path') - server_path = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_path, '/path') - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client('path') + client_path = self.client.path + self.assertEqual(client_path, '/path') + server_path = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_path, '/path') + self.stop_client() def test_protocol_headers(self): - self.start_server() - self.start_client('headers') - client_req = self.client.request_headers - client_resp = self.client.response_headers - self.assertEqual(client_req['User-Agent'], USER_AGENT) - self.assertEqual(client_resp['Server'], USER_AGENT) - server_req = self.loop.run_until_complete(self.client.recv()) - server_resp = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_req, str(client_req)) - self.assertEqual(server_resp, str(client_resp)) - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client('headers') + client_req = self.client.request_headers + client_resp = self.client.response_headers + self.assertEqual(client_req['User-Agent'], USER_AGENT) + self.assertEqual(client_resp['Server'], USER_AGENT) + server_req = self.loop.run_until_complete(self.client.recv()) + server_resp = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_req, str(client_req)) + self.assertEqual(server_resp, str(client_resp)) + self.stop_client() def test_protocol_raw_headers(self): - self.start_server() - self.start_client('raw_headers') - client_req = self.client.raw_request_headers - client_resp = self.client.raw_response_headers - self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT) - self.assertEqual(dict(client_resp)['Server'], USER_AGENT) - server_req = self.loop.run_until_complete(self.client.recv()) - server_resp = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_req, repr(client_req)) - self.assertEqual(server_resp, repr(client_resp)) - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client('raw_headers') + client_req = self.client.raw_request_headers + client_resp = self.client.raw_response_headers + self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT) + self.assertEqual(dict(client_resp)['Server'], USER_AGENT) + server_req = self.loop.run_until_complete(self.client.recv()) + server_resp = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_req, repr(client_req)) + self.assertEqual(server_resp, repr(client_resp)) + self.stop_client() def test_protocol_custom_request_headers_dict(self): - self.start_server() - self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) - req_headers = self.loop.run_until_complete(self.client.recv()) - self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", req_headers) - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", req_headers) + self.stop_client() def test_protocol_custom_request_headers_list(self): - self.start_server() - self.start_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) - req_headers = self.loop.run_until_complete(self.client.recv()) - self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", req_headers) - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client('raw_headers', + extra_headers=[('X-Spam', 'Eggs')]) + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", req_headers) + self.stop_client() def test_protocol_custom_response_headers_callable_dict(self): - self.start_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) - self.start_client('raw_headers') - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() - self.stop_server() + with self.temp_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}): + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() def test_protocol_custom_response_headers_callable_list(self): - self.start_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]) - self.start_client('raw_headers') - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() - self.stop_server() + with self.temp_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]): + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() def test_protocol_custom_response_headers_dict(self): - self.start_server(extra_headers={'X-Spam': 'Eggs'}) - self.start_client('raw_headers') - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() - self.stop_server() + with self.temp_server(extra_headers={'X-Spam': 'Eggs'}): + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() def test_protocol_custom_response_headers_list(self): - self.start_server(extra_headers=[('X-Spam', 'Eggs')]) - self.start_client('raw_headers') - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() - self.stop_server() + with self.temp_server(extra_headers=[('X-Spam', 'Eggs')]): + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() def test_get_response_status_attributes_available(self): # Save the attribute values to a dict instead of asserting inside @@ -234,8 +232,7 @@ def get_response_status(self, set_header): status = yield from super().get_response_status(set_header) return status - self.start_server(klass=SaveAttributesProtocol) - try: + with self.temp_server(klass=SaveAttributesProtocol): self.start_client(path='foo/bar', origin='http://otherhost') self.assertEqual(attrs['origin'], 'http://otherhost') self.assertEqual(attrs['path'], '/foo/bar') @@ -246,88 +243,84 @@ def get_response_status(self, set_header): request_headers = attrs['request_headers'] self.assertIsInstance(request_headers, http.client.HTTPMessage) self.assertEqual(request_headers.get('origin'), 'http://otherhost') - finally: - self.stop_server() def test_authentication(self): - self.start_server(klass=ForbiddenWebSocketServerProtocol) - with self.assertRaises(InvalidStatus): - self.start_client() - self.stop_server() + with self.temp_server(klass=ForbiddenWebSocketServerProtocol): + with self.assertRaises(InvalidStatus): + self.start_client() def test_no_subprotocol(self): - self.start_server() - self.start_client('subprotocol') - server_subprotocol = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_subprotocol, repr(None)) - self.assertEqual(self.client.subprotocol, None) - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client('subprotocol') + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() def test_subprotocol_found(self): - self.start_server(subprotocols=['superchat', 'chat']) - self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) - server_subprotocol = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_subprotocol, repr('chat')) - self.assertEqual(self.client.subprotocol, 'chat') - self.stop_client() - self.stop_server() + with self.temp_server(subprotocols=['superchat', 'chat']): + self.start_client('subprotocol', + subprotocols=['otherchat', 'chat']) + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr('chat')) + self.assertEqual(self.client.subprotocol, 'chat') + self.stop_client() def test_subprotocol_not_found(self): - self.start_server(subprotocols=['superchat']) - self.start_client('subprotocol', subprotocols=['otherchat']) - server_subprotocol = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_subprotocol, repr(None)) - self.assertEqual(self.client.subprotocol, None) - self.stop_client() - self.stop_server() + with self.temp_server(subprotocols=['superchat']): + self.start_client('subprotocol', subprotocols=['otherchat']) + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() def test_subprotocol_not_offered(self): - self.start_server() - self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) - server_subprotocol = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_subprotocol, repr(None)) - self.assertEqual(self.client.subprotocol, None) - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client('subprotocol', + subprotocols=['otherchat', 'chat']) + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() def test_subprotocol_not_requested(self): - self.start_server(subprotocols=['superchat', 'chat']) - self.start_client('subprotocol') - server_subprotocol = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_subprotocol, repr(None)) - self.assertEqual(self.client.subprotocol, None) - self.stop_client() - self.stop_server() + with self.temp_server(subprotocols=['superchat', 'chat']): + self.start_client('subprotocol') + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() @unittest.mock.patch.object(WebSocketServerProtocol, 'select_subprotocol') def test_subprotocol_error(self, _select_subprotocol): _select_subprotocol.return_value = 'superchat' - self.start_server(subprotocols=['superchat']) - with self.assertRaises(InvalidHandshake): - self.start_client('subprotocol', subprotocols=['otherchat']) - self.run_loop_once() - self.stop_server() + with self.temp_server(subprotocols=['superchat']): + with self.assertRaises(InvalidHandshake): + self.start_client('subprotocol', subprotocols=['otherchat']) + self.run_loop_once() @unittest.mock.patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") - self.start_server() - with self.assertRaises(InvalidHandshake): - self.start_client() - self.stop_server() + with self.temp_server(): + with self.assertRaises(InvalidHandshake): + self.start_client() @unittest.mock.patch('websockets.client.read_response') def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") - self.start_server() - with self.assertRaises(InvalidHandshake): - self.start_client() - self.run_loop_once() - self.stop_server() + with self.temp_server(): + with self.assertRaises(InvalidHandshake): + self.start_client() + self.run_loop_once() @unittest.mock.patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): @@ -335,10 +328,9 @@ def wrong_build_request(set_header): return '42' _build_request.side_effect = wrong_build_request - self.start_server() - with self.assertRaises(InvalidHandshake): - self.start_client() - self.stop_server() + with self.temp_server(): + with self.assertRaises(InvalidHandshake): + self.start_client() @unittest.mock.patch('websockets.server.build_response') def test_server_sends_invalid_handshake_response(self, _build_response): @@ -346,10 +338,9 @@ def wrong_build_response(set_header, key): return build_response(set_header, '42') _build_response.side_effect = wrong_build_response - self.start_server() - with self.assertRaises(InvalidHandshake): - self.start_client() - self.stop_server() + with self.temp_server(): + with self.assertRaises(InvalidHandshake): + self.start_client() @unittest.mock.patch('websockets.client.read_response') def test_server_does_not_switch_protocols(self, _read_response): @@ -359,23 +350,21 @@ def wrong_read_response(stream): return 400, headers _read_response.side_effect = wrong_read_response - self.start_server() - with self.assertRaises(InvalidStatus): - self.start_client() - self.run_loop_once() - self.stop_server() + with self.temp_server(): + with self.assertRaises(InvalidStatus): + self.start_client() + self.run_loop_once() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") - self.start_server() - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.client.recv()) - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) + self.stop_client() # Connection ends with an unexpected error. self.assertEqual(self.client.close_code, 1011) @@ -384,87 +373,83 @@ def test_server_handler_crashes(self, send): def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") - self.start_server() - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake') def test_client_closes_connection_before_handshake(self, handshake): - self.start_server() - self.start_client() - # We have mocked the handshake() method to prevent the client from - # performing the opening handshake. Force it to close the connection. - self.loop.run_until_complete(self.client.close_connection(force=True)) - self.stop_client() - # The server should stop properly anyway. It used to hang because the - # worker handling the connection was waiting for the opening handshake. - self.stop_server() + with self.temp_server(): + self.start_client() + # We have mocked the handshake() method to prevent the client + # from performing the opening handshake. Force it to close the + # connection. + self.loop.run_until_complete( + self.client.close_connection(force=True)) + self.stop_client() + # The server should stop properly anyway. It used to hang because + # the worker handling the connection was waiting for the opening + # handshake. @unittest.mock.patch('websockets.server.read_request') def test_server_shuts_down_during_opening_handshake(self, _read_request): _read_request.side_effect = asyncio.CancelledError - self.start_server() - self.server.closing = True - with self.assertRaises(InvalidHandshake) as raised: - self.start_client() - self.stop_server() + with self.temp_server(): + self.server.closing = True + with self.assertRaises(InvalidHandshake) as raised: + self.start_client() # Opening handshake fails with 503 Service Unavailable self.assertEqual(str(raised.exception), "Status code not 101: 503") def test_server_shuts_down_during_connection_handling(self): - self.start_server() - self.start_client() + with self.temp_server(): + self.start_client() - self.server.close() - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.client.recv()) - self.stop_client() - self.stop_server() + self.server.close() + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) + self.stop_client() # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) def test_invalid_status_error_during_client_connect(self): - self.start_server(klass=ForbiddenWebSocketServerProtocol) - with self.assertRaises(InvalidStatus) as raised: - self.start_client() - exception = raised.exception - self.assertEqual(str(exception), "Status code not 101: 403") - self.assertEqual(exception.code, 403) - self.stop_server() + with self.temp_server(klass=ForbiddenWebSocketServerProtocol): + with self.assertRaises(InvalidStatus) as raised: + self.start_client() + exception = raised.exception + self.assertEqual(str(exception), "Status code not 101: 403") + self.assertEqual(exception.code, 403) @unittest.mock.patch('websockets.server.read_request') def test_connection_error_during_opening_handshake(self, _read_request): _read_request.side_effect = ConnectionError - self.start_server() - # Exception appears to be platform-dependent: InvalidHandshake on - # macOS, ConnectionResetError on Linux. This doesn't matter; this - # test primarily aims at covering a code path on the server side. - with self.assertRaises(Exception): - self.start_client() - self.stop_server() + with self.temp_server(): + # Exception appears to be platform-dependent: InvalidHandshake on + # macOS, ConnectionResetError on Linux. This doesn't matter; this + # test primarily aims at covering a code path on the server side. + with self.assertRaises(Exception): + self.start_client() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError - self.start_server() - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() - self.stop_server() + with self.temp_server(): + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) @@ -499,11 +484,10 @@ def start_client(self, path='', **kwds): self.client = self.loop.run_until_complete(client) def test_ws_uri_is_rejected(self): - self.start_server() - client = connect('ws://localhost:8642/', ssl=self.client_context) - with self.assertRaises(ValueError): - self.loop.run_until_complete(client) - self.stop_server() + with self.temp_server(): + client = connect('ws://localhost:8642/', ssl=self.client_context) + with self.assertRaises(ValueError): + self.loop.run_until_complete(client) class ClientServerOriginTests(unittest.TestCase): From 571caeb914193ce625f7263e055da2226aa9ea6b Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sat, 22 Jul 2017 18:09:04 -0700 Subject: [PATCH 0244/1539] Add with_server() test method decorator. --- websockets/test_client_server.py | 446 ++++++++++++++++--------------- 1 file changed, 233 insertions(+), 213 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 75b71c793..db06c9eee 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1,4 +1,5 @@ import asyncio +import functools import http import http.client import logging @@ -46,6 +47,30 @@ class FORBIDDEN: phrase = 'Forbidden' +@contextmanager +def temp_test_server(test, **kwds): + test.start_server(**kwds) + try: + yield + finally: + test.stop_server() + + +def with_server(**kwds): + """ + Return a decorator for TestCase methods that starts and stops a server. + """ + def decorate(test_func): + @functools.wraps(test_func) + def _decorate(self, *args, **kwargs): + with temp_test_server(self, **kwds): + return test_func(self, *args, **kwargs) + + return _decorate + + return decorate + + class ForbiddenWebSocketServerProtocol(WebSocketServerProtocol): @asyncio.coroutine @@ -95,23 +120,20 @@ def stop_server(self): @contextmanager def temp_server(self, **kwds): - self.start_server(**kwds) - try: + with temp_test_server(self, **kwds): yield - finally: - self.stop_server() + @with_server() def test_basic(self): - with self.temp_server(): - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() + @with_server() def test_server_close_while_client_connected(self): - with self.temp_server(): - self.start_client() + self.start_client() def test_explicit_event_loop(self): with self.temp_server(loop=self.loop): @@ -121,100 +143,99 @@ def test_explicit_event_loop(self): self.assertEqual(reply, "Hello!") self.stop_client() + @with_server() def test_protocol_attributes(self): - with self.temp_server(): - self.start_client('attributes') - expected_attrs = ('localhost', 8642, self.secure) - client_attrs = (self.client.host, self.client.port, - self.client.secure) - self.assertEqual(client_attrs, expected_attrs) - server_attrs = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_attrs, repr(expected_attrs)) - self.stop_client() - + self.start_client('attributes') + expected_attrs = ('localhost', 8642, self.secure) + client_attrs = (self.client.host, self.client.port, + self.client.secure) + self.assertEqual(client_attrs, expected_attrs) + server_attrs = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_attrs, repr(expected_attrs)) + self.stop_client() + + @with_server() def test_protocol_path(self): - with self.temp_server(): - self.start_client('path') - client_path = self.client.path - self.assertEqual(client_path, '/path') - server_path = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_path, '/path') - self.stop_client() - + self.start_client('path') + client_path = self.client.path + self.assertEqual(client_path, '/path') + server_path = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_path, '/path') + self.stop_client() + + @with_server() def test_protocol_headers(self): - with self.temp_server(): - self.start_client('headers') - client_req = self.client.request_headers - client_resp = self.client.response_headers - self.assertEqual(client_req['User-Agent'], USER_AGENT) - self.assertEqual(client_resp['Server'], USER_AGENT) - server_req = self.loop.run_until_complete(self.client.recv()) - server_resp = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_req, str(client_req)) - self.assertEqual(server_resp, str(client_resp)) - self.stop_client() - + self.start_client('headers') + client_req = self.client.request_headers + client_resp = self.client.response_headers + self.assertEqual(client_req['User-Agent'], USER_AGENT) + self.assertEqual(client_resp['Server'], USER_AGENT) + server_req = self.loop.run_until_complete(self.client.recv()) + server_resp = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_req, str(client_req)) + self.assertEqual(server_resp, str(client_resp)) + self.stop_client() + + @with_server() def test_protocol_raw_headers(self): - with self.temp_server(): - self.start_client('raw_headers') - client_req = self.client.raw_request_headers - client_resp = self.client.raw_response_headers - self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT) - self.assertEqual(dict(client_resp)['Server'], USER_AGENT) - server_req = self.loop.run_until_complete(self.client.recv()) - server_resp = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_req, repr(client_req)) - self.assertEqual(server_resp, repr(client_resp)) - self.stop_client() - + self.start_client('raw_headers') + client_req = self.client.raw_request_headers + client_resp = self.client.raw_response_headers + self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT) + self.assertEqual(dict(client_resp)['Server'], USER_AGENT) + server_req = self.loop.run_until_complete(self.client.recv()) + server_resp = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_req, repr(client_req)) + self.assertEqual(server_resp, repr(client_resp)) + self.stop_client() + + @with_server() def test_protocol_custom_request_headers_dict(self): - with self.temp_server(): - self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) - req_headers = self.loop.run_until_complete(self.client.recv()) - self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", req_headers) - self.stop_client() + self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", req_headers) + self.stop_client() + @with_server() def test_protocol_custom_request_headers_list(self): - with self.temp_server(): - self.start_client('raw_headers', - extra_headers=[('X-Spam', 'Eggs')]) - req_headers = self.loop.run_until_complete(self.client.recv()) - self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", req_headers) - self.stop_client() + self.start_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", req_headers) + self.stop_client() + @with_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) def test_protocol_custom_response_headers_callable_dict(self): - with self.temp_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}): - self.start_client('raw_headers') - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() + @with_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]) def test_protocol_custom_response_headers_callable_list(self): - with self.temp_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]): - self.start_client('raw_headers') - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() + @with_server(extra_headers={'X-Spam': 'Eggs'}) def test_protocol_custom_response_headers_dict(self): - with self.temp_server(extra_headers={'X-Spam': 'Eggs'}): - self.start_client('raw_headers') - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() + @with_server(extra_headers={'X-Spam': 'Eggs'}) def test_protocol_custom_response_headers_list(self): - with self.temp_server(extra_headers=[('X-Spam', 'Eggs')]): - self.start_client('raw_headers') - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() + self.start_client('raw_headers') + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + self.stop_client() def test_get_response_status_attributes_available(self): # Save the attribute values to a dict instead of asserting inside @@ -244,104 +265,104 @@ def get_response_status(self, set_header): self.assertIsInstance(request_headers, http.client.HTTPMessage) self.assertEqual(request_headers.get('origin'), 'http://otherhost') + @with_server(klass=ForbiddenWebSocketServerProtocol) def test_authentication(self): - with self.temp_server(klass=ForbiddenWebSocketServerProtocol): - with self.assertRaises(InvalidStatus): - self.start_client() + with self.assertRaises(InvalidStatus): + self.start_client() + @with_server() def test_no_subprotocol(self): - with self.temp_server(): - self.start_client('subprotocol') - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) - self.assertEqual(server_subprotocol, repr(None)) - self.assertEqual(self.client.subprotocol, None) - self.stop_client() - + self.start_client('subprotocol') + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() + + @with_server(subprotocols=['superchat', 'chat']) def test_subprotocol_found(self): - with self.temp_server(subprotocols=['superchat', 'chat']): - self.start_client('subprotocol', - subprotocols=['otherchat', 'chat']) - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) - self.assertEqual(server_subprotocol, repr('chat')) - self.assertEqual(self.client.subprotocol, 'chat') - self.stop_client() - + self.start_client('subprotocol', + subprotocols=['otherchat', 'chat']) + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr('chat')) + self.assertEqual(self.client.subprotocol, 'chat') + self.stop_client() + + @with_server(subprotocols=['superchat']) def test_subprotocol_not_found(self): - with self.temp_server(subprotocols=['superchat']): - self.start_client('subprotocol', subprotocols=['otherchat']) - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) - self.assertEqual(server_subprotocol, repr(None)) - self.assertEqual(self.client.subprotocol, None) - self.stop_client() - + self.start_client('subprotocol', subprotocols=['otherchat']) + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() + + @with_server() def test_subprotocol_not_offered(self): - with self.temp_server(): - self.start_client('subprotocol', - subprotocols=['otherchat', 'chat']) - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) - self.assertEqual(server_subprotocol, repr(None)) - self.assertEqual(self.client.subprotocol, None) - self.stop_client() - + self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() + + @with_server(subprotocols=['superchat', 'chat']) def test_subprotocol_not_requested(self): - with self.temp_server(subprotocols=['superchat', 'chat']): - self.start_client('subprotocol') - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) - self.assertEqual(server_subprotocol, repr(None)) - self.assertEqual(self.client.subprotocol, None) - self.stop_client() - + self.start_client('subprotocol') + server_subprotocol = self.loop.run_until_complete( + self.client.recv()) + self.assertEqual(server_subprotocol, repr(None)) + self.assertEqual(self.client.subprotocol, None) + self.stop_client() + + @with_server(subprotocols=['superchat']) @unittest.mock.patch.object(WebSocketServerProtocol, 'select_subprotocol') def test_subprotocol_error(self, _select_subprotocol): _select_subprotocol.return_value = 'superchat' - with self.temp_server(subprotocols=['superchat']): - with self.assertRaises(InvalidHandshake): - self.start_client('subprotocol', subprotocols=['otherchat']) - self.run_loop_once() + with self.assertRaises(InvalidHandshake): + self.start_client('subprotocol', subprotocols=['otherchat']) + self.run_loop_once() + @with_server() @unittest.mock.patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") - with self.temp_server(): - with self.assertRaises(InvalidHandshake): - self.start_client() + with self.assertRaises(InvalidHandshake): + self.start_client() + @with_server() @unittest.mock.patch('websockets.client.read_response') def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") - with self.temp_server(): - with self.assertRaises(InvalidHandshake): - self.start_client() - self.run_loop_once() + with self.assertRaises(InvalidHandshake): + self.start_client() + self.run_loop_once() + @with_server() @unittest.mock.patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(set_header): return '42' _build_request.side_effect = wrong_build_request - with self.temp_server(): - with self.assertRaises(InvalidHandshake): - self.start_client() + with self.assertRaises(InvalidHandshake): + self.start_client() + @with_server() @unittest.mock.patch('websockets.server.build_response') def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(set_header, key): return build_response(set_header, '42') _build_response.side_effect = wrong_build_response - with self.temp_server(): - with self.assertRaises(InvalidHandshake): - self.start_client() + with self.assertRaises(InvalidHandshake): + self.start_client() + @with_server() @unittest.mock.patch('websockets.client.read_response') def test_server_does_not_switch_protocols(self, _read_response): @asyncio.coroutine @@ -350,106 +371,105 @@ def wrong_read_response(stream): return 400, headers _read_response.side_effect = wrong_read_response - with self.temp_server(): - with self.assertRaises(InvalidStatus): - self.start_client() - self.run_loop_once() + with self.assertRaises(InvalidStatus): + self.start_client() + self.run_loop_once() + @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") - with self.temp_server(): - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.client.recv()) - self.stop_client() + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) + self.stop_client() # Connection ends with an unexpected error. self.assertEqual(self.client.close_code, 1011) + @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") - with self.temp_server(): - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) + @with_server() @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake') def test_client_closes_connection_before_handshake(self, handshake): - with self.temp_server(): - self.start_client() - # We have mocked the handshake() method to prevent the client - # from performing the opening handshake. Force it to close the - # connection. - self.loop.run_until_complete( - self.client.close_connection(force=True)) - self.stop_client() - # The server should stop properly anyway. It used to hang because - # the worker handling the connection was waiting for the opening - # handshake. - + self.start_client() + # We have mocked the handshake() method to prevent the client + # from performing the opening handshake. Force it to close the + # connection. + self.loop.run_until_complete( + self.client.close_connection(force=True)) + self.stop_client() + # The server should stop properly anyway. It used to hang because + # the worker handling the connection was waiting for the opening + # handshake. + + @with_server() @unittest.mock.patch('websockets.server.read_request') def test_server_shuts_down_during_opening_handshake(self, _read_request): _read_request.side_effect = asyncio.CancelledError - with self.temp_server(): - self.server.closing = True - with self.assertRaises(InvalidHandshake) as raised: - self.start_client() + self.server.closing = True + with self.assertRaises(InvalidHandshake) as raised: + self.start_client() # Opening handshake fails with 503 Service Unavailable self.assertEqual(str(raised.exception), "Status code not 101: 503") + @with_server() def test_server_shuts_down_during_connection_handling(self): - with self.temp_server(): - self.start_client() + self.start_client() - self.server.close() - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.client.recv()) - self.stop_client() + self.server.close() + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) + self.stop_client() # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) + @with_server(klass=ForbiddenWebSocketServerProtocol) def test_invalid_status_error_during_client_connect(self): - with self.temp_server(klass=ForbiddenWebSocketServerProtocol): - with self.assertRaises(InvalidStatus) as raised: - self.start_client() - exception = raised.exception - self.assertEqual(str(exception), "Status code not 101: 403") - self.assertEqual(exception.code, 403) + with self.assertRaises(InvalidStatus) as raised: + self.start_client() + exception = raised.exception + self.assertEqual(str(exception), "Status code not 101: 403") + self.assertEqual(exception.code, 403) + @with_server() @unittest.mock.patch('websockets.server.read_request') def test_connection_error_during_opening_handshake(self, _read_request): _read_request.side_effect = ConnectionError - with self.temp_server(): - # Exception appears to be platform-dependent: InvalidHandshake on - # macOS, ConnectionResetError on Linux. This doesn't matter; this - # test primarily aims at covering a code path on the server side. - with self.assertRaises(Exception): - self.start_client() + # Exception appears to be platform-dependent: InvalidHandshake on + # macOS, ConnectionResetError on Linux. This doesn't matter; this + # test primarily aims at covering a code path on the server side. + with self.assertRaises(Exception): + self.start_client() + @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError - with self.temp_server(): - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() + self.start_client() + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + self.stop_client() # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) @@ -483,11 +503,11 @@ def start_client(self, path='', **kwds): client = connect('wss://localhost:8642/' + path, **kwds) self.client = self.loop.run_until_complete(client) + @with_server() def test_ws_uri_is_rejected(self): - with self.temp_server(): - client = connect('ws://localhost:8642/', ssl=self.client_context) - with self.assertRaises(ValueError): - self.loop.run_until_complete(client) + client = connect('ws://localhost:8642/', ssl=self.client_context) + with self.assertRaises(ValueError): + self.loop.run_until_complete(client) class ClientServerOriginTests(unittest.TestCase): From 4703884a5c82452c25ed39224160913162cf448e Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sat, 22 Jul 2017 18:13:59 -0700 Subject: [PATCH 0245/1539] Restore lines to their original lengths. --- websockets/test_client_server.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index db06c9eee..9f3bbbc8c 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -147,8 +147,7 @@ def test_explicit_event_loop(self): def test_protocol_attributes(self): self.start_client('attributes') expected_attrs = ('localhost', 8642, self.secure) - client_attrs = (self.client.host, self.client.port, - self.client.secure) + client_attrs = (self.client.host, self.client.port, self.client.secure) self.assertEqual(client_attrs, expected_attrs) server_attrs = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_attrs, repr(expected_attrs)) @@ -273,18 +272,15 @@ def test_authentication(self): @with_server() def test_no_subprotocol(self): self.start_client('subprotocol') - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) + server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) self.stop_client() @with_server(subprotocols=['superchat', 'chat']) def test_subprotocol_found(self): - self.start_client('subprotocol', - subprotocols=['otherchat', 'chat']) - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) + self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) + server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr('chat')) self.assertEqual(self.client.subprotocol, 'chat') self.stop_client() @@ -292,8 +288,7 @@ def test_subprotocol_found(self): @with_server(subprotocols=['superchat']) def test_subprotocol_not_found(self): self.start_client('subprotocol', subprotocols=['otherchat']) - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) + server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) self.stop_client() @@ -301,8 +296,7 @@ def test_subprotocol_not_found(self): @with_server() def test_subprotocol_not_offered(self): self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) + server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) self.stop_client() @@ -310,8 +304,7 @@ def test_subprotocol_not_offered(self): @with_server(subprotocols=['superchat', 'chat']) def test_subprotocol_not_requested(self): self.start_client('subprotocol') - server_subprotocol = self.loop.run_until_complete( - self.client.recv()) + server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) self.stop_client() @@ -407,15 +400,13 @@ def test_server_close_crashes(self, close): @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake') def test_client_closes_connection_before_handshake(self, handshake): self.start_client() - # We have mocked the handshake() method to prevent the client - # from performing the opening handshake. Force it to close the - # connection. + # We have mocked the handshake() method to prevent the client from + # performing the opening handshake. Force it to close the connection. self.loop.run_until_complete( self.client.close_connection(force=True)) self.stop_client() - # The server should stop properly anyway. It used to hang because - # the worker handling the connection was waiting for the opening - # handshake. + # The server should stop properly anyway. It used to hang because the + # worker handling the connection was waiting for the opening handshake. @with_server() @unittest.mock.patch('websockets.server.read_request') From ccc5d34f068e1e3e6b51c49aa0c69983d9cd104d Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sun, 23 Jul 2017 13:46:45 -0700 Subject: [PATCH 0246/1539] Fix argument mis-copy in PR #219. --- websockets/test_client_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 9f3bbbc8c..1af7bcb3e 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -228,7 +228,7 @@ def test_protocol_custom_response_headers_dict(self): self.assertIn("('X-Spam', 'Eggs')", resp_headers) self.stop_client() - @with_server(extra_headers={'X-Spam': 'Eggs'}) + @with_server(extra_headers=[('X-Spam', 'Eggs')]) def test_protocol_custom_response_headers_list(self): self.start_client('raw_headers') self.loop.run_until_complete(self.client.recv()) From 55dfa10ff0a4c767c4aafc19480b9c4f1a4e30d8 Mon Sep 17 00:00:00 2001 From: mayeut Date: Sun, 23 Jul 2017 18:13:47 +0200 Subject: [PATCH 0247/1539] Replace modulo on signed integers by masking Compiler is not smart enough to know that both operands are positive & optimize this with masking. We do what it should have been doing. --- websockets/speedups.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/websockets/speedups.c b/websockets/speedups.c index 70806e43c..2cc1636b9 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -8,7 +8,7 @@ #include #endif -const Py_ssize_t MASK_LEN = 4; +static const Py_ssize_t MASK_LEN = 4; static PyObject * apply_mask(PyObject *self, PyObject *args, PyObject *kwds) @@ -90,7 +90,7 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) for (; i < input_len; i++) { - output[i] = input[i] ^ mask[i % MASK_LEN]; + output[i] = input[i] ^ mask[i & (MASK_LEN - 1)]; } return result; From 739af7ef5f9bd974f389e50e367d4f9f07f71790 Mon Sep 17 00:00:00 2001 From: mayeut Date: Sun, 23 Jul 2017 18:09:41 +0200 Subject: [PATCH 0248/1539] Don't accept Unicode objects in speedups.apply_mask Only non-mutable bytes-like objects shall be supported --- websockets/speedups.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/speedups.c b/websockets/speedups.c index 2cc1636b9..7a18d6107 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -30,7 +30,7 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) Py_ssize_t i = 0; if (!PyArg_ParseTupleAndKeywords( - args, kwds, "s#s#", kwlist, &input, &input_len, &mask, &mask_len)) + args, kwds, "y#y#", kwlist, &input, &input_len, &mask, &mask_len)) { return NULL; } From 5f096507984b9a5d226652ca12ae411ea8f03235 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sun, 23 Jul 2017 17:17:59 -0700 Subject: [PATCH 0249/1539] Add with_client() test method decorator. --- websockets/test_client_server.py | 105 +++++++++++++++++-------------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 1af7bcb3e..dde6a962e 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -56,21 +56,44 @@ def temp_test_server(test, **kwds): test.stop_server() -def with_server(**kwds): +@contextmanager +def temp_test_client(test, *args, **kwds): + test.start_client(*args, **kwds) + try: + yield + finally: + test.stop_client() + + +def with_manager(manager, *args, **kwds): """ - Return a decorator for TestCase methods that starts and stops a server. + Return a decorator that wraps a function with a context manager. """ - def decorate(test_func): - @functools.wraps(test_func) - def _decorate(self, *args, **kwargs): - with temp_test_server(self, **kwds): - return test_func(self, *args, **kwargs) + def decorate(func): + @functools.wraps(func) + def _decorate(self, *_args, **_kwds): + with manager(self, *args, **kwds): + return func(self, *_args, **_kwds) return _decorate return decorate +def with_server(**kwds): + """ + Return a decorator for TestCase methods that starts and stops a server. + """ + return with_manager(temp_test_server, **kwds) + + +def with_client(*args, **kwds): + """ + Return a decorator for TestCase methods that starts and stops a client. + """ + return with_manager(temp_test_client, *args, **kwds) + + class ForbiddenWebSocketServerProtocol(WebSocketServerProtocol): @asyncio.coroutine @@ -123,13 +146,17 @@ def temp_server(self, **kwds): with temp_test_server(self, **kwds): yield + @contextmanager + def temp_client(self, *args, **kwds): + with temp_test_client(self, *args, **kwds): + yield + @with_server() + @with_client() def test_basic(self): - self.start_client() self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - self.stop_client() @with_server() def test_server_close_while_client_connected(self): @@ -137,34 +164,31 @@ def test_server_close_while_client_connected(self): def test_explicit_event_loop(self): with self.temp_server(loop=self.loop): - self.start_client(loop=self.loop) - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() + with self.temp_client(loop=self.loop): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") @with_server() + @with_client('attributes') def test_protocol_attributes(self): - self.start_client('attributes') expected_attrs = ('localhost', 8642, self.secure) client_attrs = (self.client.host, self.client.port, self.client.secure) self.assertEqual(client_attrs, expected_attrs) server_attrs = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_attrs, repr(expected_attrs)) - self.stop_client() @with_server() + @with_client('path') def test_protocol_path(self): - self.start_client('path') client_path = self.client.path self.assertEqual(client_path, '/path') server_path = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_path, '/path') - self.stop_client() @with_server() + @with_client('headers') def test_protocol_headers(self): - self.start_client('headers') client_req = self.client.request_headers client_resp = self.client.response_headers self.assertEqual(client_req['User-Agent'], USER_AGENT) @@ -173,11 +197,10 @@ def test_protocol_headers(self): server_resp = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_req, str(client_req)) self.assertEqual(server_resp, str(client_resp)) - self.stop_client() @with_server() + @with_client('raw_headers') def test_protocol_raw_headers(self): - self.start_client('raw_headers') client_req = self.client.raw_request_headers client_resp = self.client.raw_response_headers self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT) @@ -186,55 +209,48 @@ def test_protocol_raw_headers(self): server_resp = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_req, repr(client_req)) self.assertEqual(server_resp, repr(client_resp)) - self.stop_client() @with_server() + @with_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) def test_protocol_custom_request_headers_dict(self): - self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) - self.stop_client() @with_server() + @with_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) def test_protocol_custom_request_headers_list(self): - self.start_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) - self.stop_client() @with_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) + @with_client('raw_headers') def test_protocol_custom_response_headers_callable_dict(self): - self.start_client('raw_headers') self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() @with_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]) + @with_client('raw_headers') def test_protocol_custom_response_headers_callable_list(self): - self.start_client('raw_headers') self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() @with_server(extra_headers={'X-Spam': 'Eggs'}) + @with_client('raw_headers') def test_protocol_custom_response_headers_dict(self): - self.start_client('raw_headers') self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() @with_server(extra_headers=[('X-Spam', 'Eggs')]) + @with_client('raw_headers') def test_protocol_custom_response_headers_list(self): - self.start_client('raw_headers') self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - self.stop_client() def test_get_response_status_attributes_available(self): # Save the attribute values to a dict instead of asserting inside @@ -270,44 +286,39 @@ def test_authentication(self): self.start_client() @with_server() + @with_client('subprotocol') def test_no_subprotocol(self): - self.start_client('subprotocol') server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) - self.stop_client() @with_server(subprotocols=['superchat', 'chat']) + @with_client('subprotocol', subprotocols=['otherchat', 'chat']) def test_subprotocol_found(self): - self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr('chat')) self.assertEqual(self.client.subprotocol, 'chat') - self.stop_client() @with_server(subprotocols=['superchat']) + @with_client('subprotocol', subprotocols=['otherchat']) def test_subprotocol_not_found(self): - self.start_client('subprotocol', subprotocols=['otherchat']) server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) - self.stop_client() @with_server() + @with_client('subprotocol', subprotocols=['otherchat', 'chat']) def test_subprotocol_not_offered(self): - self.start_client('subprotocol', subprotocols=['otherchat', 'chat']) server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) - self.stop_client() @with_server(subprotocols=['superchat', 'chat']) + @with_client('subprotocol') def test_subprotocol_not_requested(self): - self.start_client('subprotocol') server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) - self.stop_client() @with_server(subprotocols=['superchat']) @unittest.mock.patch.object(WebSocketServerProtocol, 'select_subprotocol') @@ -397,14 +408,12 @@ def test_server_close_crashes(self, close): self.assertEqual(self.client.close_code, 1006) @with_server() + @with_client() @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake') def test_client_closes_connection_before_handshake(self, handshake): - self.start_client() # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. - self.loop.run_until_complete( - self.client.close_connection(force=True)) - self.stop_client() + self.loop.run_until_complete(self.client.close_connection(force=True)) # The server should stop properly anyway. It used to hang because the # worker handling the connection was waiting for the opening handshake. From c07ff418fa64398d7aab6cc9f72cc44f32ed773e Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sun, 23 Jul 2017 23:56:39 -0700 Subject: [PATCH 0250/1539] Use temp_client() where we can't use the with_client() decorator. --- websockets/test_client_server.py | 37 ++++++++++++++------------------ 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index dde6a962e..980356ee9 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -384,11 +384,10 @@ def wrong_read_response(stream): def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.client.recv()) - self.stop_client() + with self.temp_client(): + self.loop.run_until_complete(self.client.send("Hello!")) + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) # Connection ends with an unexpected error. self.assertEqual(self.client.close_code, 1011) @@ -398,11 +397,10 @@ def test_server_handler_crashes(self, send): def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() + with self.temp_client(): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) @@ -431,12 +429,10 @@ def test_server_shuts_down_during_opening_handshake(self, _read_request): @with_server() def test_server_shuts_down_during_connection_handling(self): - self.start_client() - - self.server.close() - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.client.recv()) - self.stop_client() + with self.temp_client(): + self.server.close() + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) @@ -465,11 +461,10 @@ def test_connection_error_during_opening_handshake(self, _read_request): def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError - self.start_client() - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - self.stop_client() + with self.temp_client(): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. self.assertEqual(self.client.close_code, 1006) From f859e2f017903abf195d3663ba486d6d1ef2cb23 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Sat, 29 Jul 2017 04:05:50 -0700 Subject: [PATCH 0251/1539] Update intersphinx_mapping setting. --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 258b57ede..48006a483 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -256,4 +256,4 @@ # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'http://docs.python.org/3/': None} +intersphinx_mapping = {'https://docs.python.org/3/': None} From 725675eb20eb78e4cc3a9230daea3cdaf3dc7ffd Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Thu, 27 Jul 2017 11:54:20 -0700 Subject: [PATCH 0252/1539] Address issue #216: rename the klass argument to create_protocol. --- docs/api.rst | 4 +- docs/changelog.rst | 5 +++ docs/cheatsheet.rst | 6 ++- websockets/client.py | 14 ++++-- websockets/server.py | 13 +++++- websockets/test_client_server.py | 73 +++++++++++++++++++++++++++++--- 6 files changed, 100 insertions(+), 15 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 9dd5f8f88..26fdc25bc 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,7 +32,7 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, klass=WebSocketServerProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) .. autoclass:: WebSocketServer @@ -50,7 +50,7 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, klass=WebSocketClientProtocol, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds) .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6b0241239..747948ee8 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,11 @@ Changelog *In development* +* Renamed :func:`~websockets.server.serve()` and + :func:`~websockets.client.connect()`'s ``klass`` argument to + ``create_protocol`` to reflect that it can also be a callable. + For backwards compatibility, ``klass`` is still supported. + * :func:`~websockets.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5. diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index cf6897257..5ee2c221f 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -23,7 +23,8 @@ Server the handler exits normally or with an exception. * You may subclass :class:`~websockets.server.WebSocketServerProtocol` and - pass it in the ``klass`` keyword argument for advanced customization. + pass it or a factory function as the ``create_protocol`` argument for + advanced customization. Client ------ @@ -34,7 +35,8 @@ Client * On Python ≥ 3.5, you can also use it as an asynchronous context manager. * You may subclass :class:`~websockets.server.WebSocketClientProtocol` and - pass it in the ``klass`` keyword argument for advanced customization. + pass it or a factory function as the ``create_protocol`` argument for + advanced customization. * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and diff --git a/websockets/client.py b/websockets/client.py index 143ec37a0..4053c2863 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -130,8 +130,7 @@ def handshake(self, wsuri, @asyncio.coroutine -def connect(uri, *, - klass=WebSocketClientProtocol, +def connect(uri, *, create_protocol=None, klass=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, @@ -156,6 +155,13 @@ def connect(uri, *, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. + The ``create_protocol`` parameter allows customizing the + :class:`WebSocketClientProtocol` class used. The argument should be a + callable or class accepting the same arguments as + :class:`WebSocketClientProtocol` and that returns a + :class:`WebSocketClientProtocol` instance. It defaults to + :class:`WebSocketClientProtocol`. + :func:`connect` also accepts the following optional arguments: * ``origin`` sets the Origin HTTP header @@ -175,13 +181,15 @@ def connect(uri, *, if loop is None: loop = asyncio.get_event_loop() + create_protocol = create_protocol or klass or WebSocketClientProtocol + wsuri = parse_uri(uri) if wsuri.secure: kwds.setdefault('ssl', True) elif kwds.get('ssl') is not None: raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") - factory = lambda: klass( + factory = lambda: create_protocol( host=wsuri.host, port=wsuri.port, secure=wsuri.secure, timeout=timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, diff --git a/websockets/server.py b/websockets/server.py index 279d814df..5c938aa25 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -407,7 +407,7 @@ def wait_closed(self): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - klass=WebSocketServerProtocol, + create_protocol=None, klass=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, @@ -440,6 +440,13 @@ def serve(ws_handler, host=None, port=None, *, set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. + The ``create_protocol`` parameter allows customizing the + :class:`WebSocketServerProtocol` class used. The argument should be a + callable or class accepting the same arguments as + :class:`WebSocketServerProtocol` and that returns a + :class:`WebSocketServerProtocol` instance. It defaults to + :class:`WebSocketServerProtocol`. + The behavior of the ``timeout``, ``max_size``, and ``max_queue``, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -472,10 +479,12 @@ def serve(ws_handler, host=None, port=None, *, if loop is None: loop = asyncio.get_event_loop() + create_protocol = create_protocol or klass or WebSocketServerProtocol + ws_server = WebSocketServer(loop) secure = kwds.get('ssl') is not None - factory = lambda: klass( + factory = lambda: create_protocol( ws_handler, ws_server, host=host, port=port, secure=secure, timeout=timeout, max_size=max_size, max_queue=max_queue, diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 980356ee9..0edc1408c 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -40,8 +40,14 @@ def handler(ws, path): try: + # Order by status code. + UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED FORBIDDEN = http.HTTPStatus.FORBIDDEN except AttributeError: # pragma: no cover + class UNAUTHORIZED: + value = 401 + phrase = 'Unauthorized' + class FORBIDDEN: value = 403 phrase = 'Forbidden' @@ -94,13 +100,28 @@ def with_client(*args, **kwds): return with_manager(temp_test_client, *args, **kwds) -class ForbiddenWebSocketServerProtocol(WebSocketServerProtocol): +class UnauthorizedServerProtocol(WebSocketServerProtocol): + + @asyncio.coroutine + def get_response_status(self, set_header): + return UNAUTHORIZED + + +class ForbiddenServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def get_response_status(self, set_header): return FORBIDDEN +class FooClientProtocol(WebSocketClientProtocol): + pass + + +class BarClientProtocol(WebSocketClientProtocol): + pass + + class ClientServerTests(unittest.TestCase): secure = False @@ -268,7 +289,7 @@ def get_response_status(self, set_header): status = yield from super().get_response_status(set_header) return status - with self.temp_server(klass=SaveAttributesProtocol): + with self.temp_server(create_protocol=SaveAttributesProtocol): self.start_client(path='foo/bar', origin='http://otherhost') self.assertEqual(attrs['origin'], 'http://otherhost') self.assertEqual(attrs['path'], '/foo/bar') @@ -280,10 +301,50 @@ def get_response_status(self, set_header): self.assertIsInstance(request_headers, http.client.HTTPMessage) self.assertEqual(request_headers.get('origin'), 'http://otherhost') - @with_server(klass=ForbiddenWebSocketServerProtocol) - def test_authentication(self): - with self.assertRaises(InvalidStatus): + def assert_client_raises_code(self, code): + with self.assertRaises(InvalidStatus) as raised: self.start_client() + self.assertEqual(raised.exception.code, code) + + @with_server(create_protocol=UnauthorizedServerProtocol) + def test_server_create_protocol(self): + self.assert_client_raises_code(401) + + @with_server(create_protocol=(lambda *args, **kwargs: + UnauthorizedServerProtocol(*args, **kwargs))) + def test_server_create_protocol_function(self): + self.assert_client_raises_code(401) + + @with_server(klass=UnauthorizedServerProtocol) + def test_server_klass(self): + self.assert_client_raises_code(401) + + @with_server(create_protocol=ForbiddenServerProtocol, + klass=UnauthorizedServerProtocol) + def test_server_create_protocol_over_klass(self): + self.assert_client_raises_code(403) + + @with_server() + @with_client('path', create_protocol=FooClientProtocol) + def test_client_create_protocol(self): + self.assertIsInstance(self.client, FooClientProtocol) + + @with_server() + @with_client('path', create_protocol=( + lambda *args, **kwargs: FooClientProtocol(*args, **kwargs))) + def test_client_create_protocol_function(self): + self.assertIsInstance(self.client, FooClientProtocol) + + @with_server() + @with_client('path', klass=FooClientProtocol) + def test_client_klass(self): + self.assertIsInstance(self.client, FooClientProtocol) + + @with_server() + @with_client('path', create_protocol=BarClientProtocol, + klass=FooClientProtocol) + def test_client_create_protocol_over_klass(self): + self.assertIsInstance(self.client, BarClientProtocol) @with_server() @with_client('subprotocol') @@ -437,7 +498,7 @@ def test_server_shuts_down_during_connection_handling(self): # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) - @with_server(klass=ForbiddenWebSocketServerProtocol) + @with_server(create_protocol=ForbiddenServerProtocol) def test_invalid_status_error_during_client_connect(self): with self.assertRaises(InvalidStatus) as raised: self.start_client() From 39cec143769a31c92fce282b6270bbf52ceb1206 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Jul 2017 14:55:47 +0200 Subject: [PATCH 0253/1539] Review fix for issue #216. * Clarify documentation wording a bit. * Make the backwards-compatibility logic more explicit (and removable). * Move klass with legacy_recv, the other backwards-compatibility shim. --- docs/cheatsheet.rst | 12 ++++++------ websockets/client.py | 20 +++++++++++++------- websockets/server.py | 19 ++++++++++++------- 3 files changed, 31 insertions(+), 20 deletions(-) diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 5ee2c221f..21509acae 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -22,9 +22,9 @@ Server execute the application logic, and finally closes the connection after the handler exits normally or with an exception. - * You may subclass :class:`~websockets.server.WebSocketServerProtocol` and - pass it or a factory function as the ``create_protocol`` argument for - advanced customization. + * For advanced customization, you may subclass + :class:`~websockets.server.WebSocketServerProtocol` and pass either this + subclass or a factory function as the ``create_protocol`` argument. Client ------ @@ -34,9 +34,9 @@ Client * On Python ≥ 3.5, you can also use it as an asynchronous context manager. - * You may subclass :class:`~websockets.server.WebSocketClientProtocol` and - pass it or a factory function as the ``create_protocol`` argument for - advanced customization. + * For advanced customization, you may subclass + :class:`~websockets.server.WebSocketClientProtocol` and pass either this + subclass or a factory function as the ``create_protocol`` argument. * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and diff --git a/websockets/client.py b/websockets/client.py index 4053c2863..4afe4c35e 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -130,10 +130,11 @@ def handshake(self, wsuri, @asyncio.coroutine -def connect(uri, *, create_protocol=None, klass=None, +def connect(uri, *, + create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, - loop=None, legacy_recv=False, + loop=None, legacy_recv=False, klass=None, origin=None, subprotocols=None, extra_headers=None, **kwds): """ @@ -155,10 +156,9 @@ def connect(uri, *, create_protocol=None, klass=None, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. - The ``create_protocol`` parameter allows customizing the - :class:`WebSocketClientProtocol` class used. The argument should be a - callable or class accepting the same arguments as - :class:`WebSocketClientProtocol` and that returns a + The ``create_protocol`` parameter allows customizing the asyncio protocol + that manages the connection. It should be a callable or class accepting + the same arguments as :class:`WebSocketClientProtocol` and returning a :class:`WebSocketClientProtocol` instance. It defaults to :class:`WebSocketClientProtocol`. @@ -181,7 +181,13 @@ def connect(uri, *, create_protocol=None, klass=None, if loop is None: loop = asyncio.get_event_loop() - create_protocol = create_protocol or klass or WebSocketClientProtocol + # Backwards-compatibility: create_protocol used to be called klass. + # In the unlikely event that both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + if create_protocol is None: + create_protocol = WebSocketClientProtocol wsuri = parse_uri(uri) if wsuri.secure: diff --git a/websockets/server.py b/websockets/server.py index 5c938aa25..cf04d6eaf 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -407,10 +407,10 @@ def wait_closed(self): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - create_protocol=None, klass=None, + create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, - loop=None, legacy_recv=False, + loop=None, legacy_recv=False, klass=None, origins=None, subprotocols=None, extra_headers=None, **kwds): """ @@ -440,10 +440,9 @@ def serve(ws_handler, host=None, port=None, *, set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. - The ``create_protocol`` parameter allows customizing the - :class:`WebSocketServerProtocol` class used. The argument should be a - callable or class accepting the same arguments as - :class:`WebSocketServerProtocol` and that returns a + The ``create_protocol`` parameter allows customizing the asyncio protocol + that manages the connection. It should be a callable or class accepting + the same arguments as :class:`WebSocketServerProtocol` and returning a :class:`WebSocketServerProtocol` instance. It defaults to :class:`WebSocketServerProtocol`. @@ -479,7 +478,13 @@ def serve(ws_handler, host=None, port=None, *, if loop is None: loop = asyncio.get_event_loop() - create_protocol = create_protocol or klass or WebSocketServerProtocol + # Backwards-compatibility: create_protocol used to be called klass. + # In the unlikely event that both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + if create_protocol is None: + create_protocol = WebSocketServerProtocol ws_server = WebSocketServer(loop) From 68e2af5cd9d823f8d9f77bc18bd2556694b5d84e Mon Sep 17 00:00:00 2001 From: mayeut Date: Tue, 1 Aug 2017 22:35:53 +0200 Subject: [PATCH 0254/1539] Fix build of speedups.c on MSVC 2010 --- websockets/speedups.c | 47 +++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/websockets/speedups.c b/websockets/speedups.c index 7a18d6107..4d7622231 100644 --- a/websockets/speedups.c +++ b/websockets/speedups.c @@ -52,39 +52,42 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // Apparently GCC cannot figure out the following optimizations by itself. + // We need a new scope for MSVC 2010 (non C99 friendly) + { #if __SSE2__ - // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. + // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. - // Since we cannot control the 16-bytes alignment of input and output - // buffers, we rely on loadu/storeu rather than load/store. + // Since we cannot control the 16-bytes alignment of input and output + // buffers, we rely on loadu/storeu rather than load/store. - Py_ssize_t input_len_128 = input_len & ~15; - __m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask); + Py_ssize_t input_len_128 = input_len & ~15; + __m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask); - for (; i < input_len_128; i += 16) - { - __m128i in_128 = _mm_loadu_si128((__m128i *)(input + i)); - __m128i out_128 = _mm_xor_si128(in_128, mask_128); - _mm_storeu_si128((__m128i *)(output + i), out_128); - } + for (; i < input_len_128; i += 16) + { + __m128i in_128 = _mm_loadu_si128((__m128i *)(input + i)); + __m128i out_128 = _mm_xor_si128(in_128, mask_128); + _mm_storeu_si128((__m128i *)(output + i), out_128); + } #else - // Without SSE2 support, XOR by blocks of 8 bytes = 64 bits. + // Without SSE2 support, XOR by blocks of 8 bytes = 64 bits. - // We assume the memory allocator aligns everything on 8 bytes boundaries. + // We assume the memory allocator aligns everything on 8 bytes boundaries. - Py_ssize_t input_len_64 = input_len & ~7; - uint32_t mask_32 = *(uint32_t *)mask; - uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32; + Py_ssize_t input_len_64 = input_len & ~7; + uint32_t mask_32 = *(uint32_t *)mask; + uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32; - for (; i < input_len_64; i += 8) - { - *(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64; - } + for (; i < input_len_64; i += 8) + { + *(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64; + } #endif + } // XOR the remainder of the input byte by byte. @@ -114,6 +117,10 @@ static struct PyModuleDef speedups_module = { /* m_doc */ -1, /* m_size */ speedups_methods, /* m_methods */ + NULL, + NULL, + NULL, + NULL }; PyMODINIT_FUNC From 054bc9045c4a324d375e50e402568e7fb886d120 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 8 Aug 2017 09:14:30 +0200 Subject: [PATCH 0255/1539] Change isort multi-line rule. --- setup.cfg | 1 + websockets/protocol.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index e4ece51a8..08644c32e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,3 +7,4 @@ ignore = E731,F403,F405 [isort] known_standard_library = asyncio lines_after_imports = 2 +multi_line_output = 5 diff --git a/websockets/protocol.py b/websockets/protocol.py index b0fb7c893..bfe7e2313 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -15,8 +15,9 @@ import struct from .compatibility import asyncio_ensure_future -from .exceptions import (ConnectionClosed, InvalidState, PayloadTooBig, - WebSocketProtocolError) +from .exceptions import ( + ConnectionClosed, InvalidState, PayloadTooBig, WebSocketProtocolError +) from .framing import * from .handshake import * From 27e06981ff12421da3c2c0e05b6b4abbadcf9937 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 8 Aug 2017 18:04:39 +0200 Subject: [PATCH 0256/1539] Move compatibility code in the adequate module. --- websockets/compatibility.py | 21 +++++++++++++++++++++ websockets/server.py | 10 +--------- websockets/test_client_server.py | 15 +-------------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/websockets/compatibility.py b/websockets/compatibility.py index 90afea9ba..ac9aaf7ad 100644 --- a/websockets/compatibility.py +++ b/websockets/compatibility.py @@ -1,4 +1,5 @@ import asyncio +import http # Replace with BaseEventLoop.create_task when dropping Python < 3.4.2. @@ -6,3 +7,23 @@ asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5 except AttributeError: # pragma: no cover asyncio_ensure_future = asyncio.async # Python < 3.5 + +try: # pragma: no cover + # Python ≥ 3.5 + SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS + # Used only in tests. + UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED + FORBIDDEN = http.HTTPStatus.FORBIDDEN +except AttributeError: # pragma: no cover + # Python < 3.5 + class SWITCHING_PROTOCOLS: + value = 101 + phrase = "Switching Protocols" + + class UNAUTHORIZED: + value = 401 + phrase = "Unauthorized" + + class FORBIDDEN: + value = 403 + phrase = "Forbidden" diff --git a/websockets/server.py b/websockets/server.py index cf04d6eaf..24fb246e9 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -6,10 +6,9 @@ import asyncio import collections.abc -import http import logging -from .compatibility import asyncio_ensure_future +from .compatibility import SWITCHING_PROTOCOLS, asyncio_ensure_future from .exceptions import InvalidHandshake, InvalidMessage, InvalidOrigin from .handshake import build_response, check_request from .http import USER_AGENT, build_headers, read_request @@ -20,13 +19,6 @@ logger = logging.getLogger(__name__) -try: - SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS -except AttributeError: # pragma: no cover - class SWITCHING_PROTOCOLS: - value = 101 - phrase = 'Switching protocols' - class WebSocketServerProtocol(WebSocketCommonProtocol): """ diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 0edc1408c..1fe201029 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -10,6 +10,7 @@ from contextlib import contextmanager from .client import * +from .compatibility import FORBIDDEN, UNAUTHORIZED from .exceptions import ConnectionClosed, InvalidHandshake, InvalidStatus from .http import USER_AGENT, read_response from .server import * @@ -39,20 +40,6 @@ def handler(ws, path): yield from ws.send((yield from ws.recv())) -try: - # Order by status code. - UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED - FORBIDDEN = http.HTTPStatus.FORBIDDEN -except AttributeError: # pragma: no cover - class UNAUTHORIZED: - value = 401 - phrase = 'Unauthorized' - - class FORBIDDEN: - value = 403 - phrase = 'Forbidden' - - @contextmanager def temp_test_server(test, **kwds): test.start_server(**kwds) From 411383fd3c5ccbabd3aab0a5a91fa092529a766a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Aug 2017 11:04:50 +0200 Subject: [PATCH 0257/1539] Rename exception. Providing both InvalidStatus and InvalidState was quite confusing. --- docs/changelog.rst | 2 +- websockets/client.py | 4 ++-- websockets/exceptions.py | 4 ++-- websockets/test_client_server.py | 8 ++++---- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 747948ee8..dd91449c2 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,7 +24,7 @@ Changelog * Added an optional C extension to speed up low level operations. * An invalid response status code during :func:`~websockets.client.connect` - now raises :class:`~websockets.exceptions.InvalidStatus` with a ``code`` + now raises :class:`~websockets.exceptions.InvalidStatusCode` with a ``code`` attribute. 3.3 diff --git a/websockets/client.py b/websockets/client.py index 4afe4c35e..793738972 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -6,7 +6,7 @@ import asyncio import collections.abc -from .exceptions import InvalidHandshake, InvalidMessage, InvalidStatus +from .exceptions import InvalidHandshake, InvalidMessage, InvalidStatusCode from .handshake import build_request, check_response from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -118,7 +118,7 @@ def handshake(self, wsuri, get_header = lambda k: headers.get(k, '') if status_code != 101: - raise InvalidStatus(status_code) + raise InvalidStatusCode(status_code) check_response(get_header, key) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index b3917c704..0681fb33e 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,6 +1,6 @@ __all__ = [ 'InvalidHandshake', 'InvalidMessage', 'InvalidOrigin', 'InvalidState', - 'InvalidStatus', 'InvalidURI', 'ConnectionClosed', 'PayloadTooBig', + 'InvalidStatusCode', 'InvalidURI', 'ConnectionClosed', 'PayloadTooBig', 'WebSocketProtocolError', ] @@ -26,7 +26,7 @@ class InvalidOrigin(InvalidHandshake): """ -class InvalidStatus(InvalidHandshake): +class InvalidStatusCode(InvalidHandshake): """ Exception raised when a handshake response status code is invalid. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 1fe201029..e6ac85a13 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -11,7 +11,7 @@ from .client import * from .compatibility import FORBIDDEN, UNAUTHORIZED -from .exceptions import ConnectionClosed, InvalidHandshake, InvalidStatus +from .exceptions import ConnectionClosed, InvalidHandshake, InvalidStatusCode from .http import USER_AGENT, read_response from .server import * @@ -289,7 +289,7 @@ def get_response_status(self, set_header): self.assertEqual(request_headers.get('origin'), 'http://otherhost') def assert_client_raises_code(self, code): - with self.assertRaises(InvalidStatus) as raised: + with self.assertRaises(InvalidStatusCode) as raised: self.start_client() self.assertEqual(raised.exception.code, code) @@ -423,7 +423,7 @@ def wrong_read_response(stream): return 400, headers _read_response.side_effect = wrong_read_response - with self.assertRaises(InvalidStatus): + with self.assertRaises(InvalidStatusCode): self.start_client() self.run_loop_once() @@ -487,7 +487,7 @@ def test_server_shuts_down_during_connection_handling(self): @with_server(create_protocol=ForbiddenServerProtocol) def test_invalid_status_error_during_client_connect(self): - with self.assertRaises(InvalidStatus) as raised: + with self.assertRaises(InvalidStatusCode) as raised: self.start_client() exception = raised.exception self.assertEqual(str(exception), "Status code not 101: 403") From 3242b20e6ef7fe6bf916d7d7c5026c47e8f99398 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Aug 2017 11:19:16 +0200 Subject: [PATCH 0258/1539] Use status vs. status_code consistently. * status is a HTTPStatus instance. * status_code is a numerical status code, like status.value. --- docs/changelog.rst | 4 ++-- websockets/exceptions.py | 8 ++++---- websockets/http.py | 16 ++++++++-------- websockets/server.py | 2 +- websockets/test_client_server.py | 8 ++++---- websockets/test_http.py | 4 ++-- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index dd91449c2..83a4b15eb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -127,7 +127,7 @@ Also: * Supported running on a non-default event loop. -* Returned a 403 error code instead of 400 when the request Origin isn't +* Returned a 403 status code instead of 400 when the request Origin isn't allowed. * Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no @@ -135,7 +135,7 @@ Also: * Clarified that the closing handshake can be initiated by the client. -* Set the close status code and reason more consistently. +* Set the close code and reason more consistently. * Strengthened connection termination by simplifying the implementation. diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 0681fb33e..bb47b2d14 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -30,12 +30,12 @@ class InvalidStatusCode(InvalidHandshake): """ Exception raised when a handshake response status code is invalid. - Provides the integer status code in its ``code`` attribute. + Provides the integer status code in its ``status_code`` attribute. """ - def __init__(self, code): - self.code = code - message = 'Status code not 101: {}'.format(code) + def __init__(self, status_code): + self.status_code = status_code + message = 'Status code not 101: {}'.format(status_code) super().__init__(message) diff --git a/websockets/http.py b/websockets/http.py index e71e8c78d..4b95fbcbf 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -96,8 +96,8 @@ def read_response(stream): ``stream`` is an :class:`~asyncio.StreamReader`. - Return ``(status, headers)`` where ``status`` is a :class:`int` and - ``headers`` is a list of ``(name, value)`` tuples. + Return ``(status_code, headers)`` where ``status_code`` is a :class:`int` + and ``headers`` is a list of ``(name, value)`` tuples. Non-ASCII characters are represented with surrogate escapes. @@ -109,26 +109,26 @@ def read_response(stream): # https://tools.ietf.org/html/rfc7230#section-3.1.2 # As in read_request, parsing is simple because a fixed value is expected - # for version, status is a 3-digit number, and reason can be ignored. + # for version, status_code is a 3-digit number, and reason can be ignored. # Given the implementation of read_line(), status_line ends with CRLF. status_line = yield from read_line(stream) # This may raise "ValueError: not enough values to unpack" - version, status, reason = status_line[:-2].split(b' ', 2) + version, status_code, reason = status_line[:-2].split(b' ', 2) if version != b'HTTP/1.1': raise ValueError("Unsupported HTTP version: %r" % version) # This may raise "ValueError: invalid literal for int() with base 10" - status = int(status) - if not 100 <= status < 1000: - raise ValueError("Unsupported HTTP status code: %d" % status) + status_code = int(status_code) + if not 100 <= status_code < 1000: + raise ValueError("Unsupported HTTP status_code code: %d" % status_code) if not _value_re.match(reason): raise ValueError("Invalid HTTP reason phrase: %r" % reason) headers = yield from read_headers(stream) - return status, headers + return status_code, headers @asyncio.coroutine diff --git a/websockets/server.py b/websockets/server.py index 24fb246e9..15bb1e622 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -227,7 +227,7 @@ def get_response_status(self, set_header): likely to require network requests. The connection is closed immediately after sending the response when - the status code is not ``HTTPStatus.SWITCHING_PROTOCOLS``. + the status is not ``HTTPStatus.SWITCHING_PROTOCOLS``. Call ``set_header(key, value)`` to set additional response headers. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index e6ac85a13..b41aeeb61 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -288,10 +288,10 @@ def get_response_status(self, set_header): self.assertIsInstance(request_headers, http.client.HTTPMessage) self.assertEqual(request_headers.get('origin'), 'http://otherhost') - def assert_client_raises_code(self, code): + def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() - self.assertEqual(raised.exception.code, code) + self.assertEqual(raised.exception.status_code, status_code) @with_server(create_protocol=UnauthorizedServerProtocol) def test_server_create_protocol(self): @@ -419,7 +419,7 @@ def wrong_build_response(set_header, key): def test_server_does_not_switch_protocols(self, _read_response): @asyncio.coroutine def wrong_read_response(stream): - code, headers = yield from read_response(stream) + status_code, headers = yield from read_response(stream) return 400, headers _read_response.side_effect = wrong_read_response @@ -491,7 +491,7 @@ def test_invalid_status_error_during_client_connect(self): self.start_client() exception = raised.exception self.assertEqual(str(exception), "Status code not 101: 403") - self.assertEqual(exception.code, 403) + self.assertEqual(exception.status_code, 403) @with_server() @unittest.mock.patch('websockets.server.read_request') diff --git a/websockets/test_http.py b/websockets/test_http.py index 0e13c8f5c..a891ad5ea 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -44,9 +44,9 @@ def test_read_response(self): b'Sec-WebSocket-Protocol: chat\r\n' b'\r\n' ) - status, headers = self.loop.run_until_complete( + status_code, headers = self.loop.run_until_complete( read_response(self.stream)) - self.assertEqual(status, 101) + self.assertEqual(status_code, 101) self.assertEqual(dict(headers)['Upgrade'], 'websocket') def test_request_method(self): From 50fd62e333256cd5db43927bc98bebe7bba6e564 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Aug 2017 10:18:43 +0200 Subject: [PATCH 0259/1539] Provide a better way to override request handling. This replaces the get_response_status() API which never made it into a release (so there's no backwards incompatibility). Remove a test that depends on get_response_status() being called after check_request(). The extension point must be before check_request() so it can handle regular HTTP requests. Fix #116. Supersedes #202 #154, #137. --- docs/api.rst | 2 +- docs/changelog.rst | 4 +-- websockets/server.py | 61 +++++++++++++++++--------------- websockets/test_client_server.py | 38 +++----------------- 4 files changed, 40 insertions(+), 65 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 26fdc25bc..fbcec2de8 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -42,8 +42,8 @@ Server .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) + .. automethod:: process_request(path, request_headers) .. automethod:: select_subprotocol(client_protos, server_protos) - .. automethod:: get_response_status() Client ...... diff --git a/docs/changelog.rst b/docs/changelog.rst index 83a4b15eb..772d3bee7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -14,8 +14,8 @@ Changelog * :func:`~websockets.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5. -* Added support for rejecting incoming connections by customizing - :meth:`~websockets.server.WebSocketServerProtocol.get_response_status()`. +* Added support for customizing handling of incoming connections with + :meth:`~websockets.server.WebSocketServerProtocol.process_request()`. * Made read and write buffer sizes configurable. diff --git a/websockets/server.py b/websockets/server.py index 15bb1e622..943ea9711 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -81,8 +81,8 @@ def handler(self): self.writer.write(response.encode()) raise - # Subclasses can customize get_response_status() or handshake() to - # reject the handshake, typically after checking authentication. + # Subclasses can customize process_request() to reject the + # handshake, typically after checking authentication. if path is None: return @@ -211,13 +211,22 @@ def select_subprotocol(client_protos, server_protos): return sorted(common_protos, key=priority)[0] @asyncio.coroutine - def get_response_status(self, set_header): + def process_request(self, path, request_headers): """ - Return a :class:`~http.HTTPStatus` for the HTTP response. + Intercept the HTTP request and return a HTTP response if needed. - (:class:`~http.HTTPStatus` was added in Python 3.5. On earlier - versions, a compatible object must be returned. Check the definition - of ``SWITCHING_PROTOCOLS`` for an example.) + ``request_headers`` are a :class:`~http.client.HTTPMessage`. + + If this coroutine returns ``None``, the WebSocket handshake continues. + If it returns a HTTP status code and HTTP headers, that HTTP response + is sent and the connection is closed immediately. + + The HTTP status must be a :class:`~http.HTTPStatus` and HTTP headers + must be an iterable of ``(name, value)`` pairs. + + (:class:`~http.HTTPStatus` was added in Python 3.5. Use a compatible + object on earlier versions. Look at ``SWITCHING_PROTOCOLS`` in + ``websockets.compatibility`` for an example.) This method may be overridden to check the request headers and set a different status, for example to authenticate the request and return @@ -226,13 +235,7 @@ def get_response_status(self, set_header): It is declared as a coroutine because such authentication checks are likely to require network requests. - The connection is closed immediately after sending the response when - the status is not ``HTTPStatus.SWITCHING_PROTOCOLS``. - - Call ``set_header(key, value)`` to set additional response headers. - """ - return SWITCHING_PROTOCOLS @asyncio.coroutine def handshake(self, origins=None, subprotocols=None, extra_headers=None): @@ -255,29 +258,30 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): Return the URI of the request. """ - path, headers = yield from self.read_http_request() - get_header = lambda k: headers.get(k, '') + path, request_headers = yield from self.read_http_request() + + # Hook for customizing request handling, for example checking + # authentication or treating some paths as plain HTTP endpoints. + + early_response = yield from self.process_request(path, request_headers) + if early_response is not None: + yield from self.write_http_response(*early_response) + self.opening_handshake.set_result(False) + yield from self.close_connection(force=True) + return + + get_header = lambda k: request_headers.get(k, '') key = check_request(get_header) self.origin = self.process_origin(get_header, origins) self.subprotocol = self.process_subprotocol(get_header, subprotocols) - headers = [] - set_header = lambda k, v: headers.append((k, v)) + response_headers = [] + set_header = lambda k, v: response_headers.append((k, v)) set_header('Server', USER_AGENT) - status = yield from self.get_response_status(set_header) - - # Abort the connection if the status code isn't 101. - if status.value != SWITCHING_PROTOCOLS.value: - yield from self.write_http_response(status, headers) - self.opening_handshake.set_result(False) - yield from self.close_connection(force=True) - return - - # Status code is 101, establish the connection. if self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: @@ -289,7 +293,8 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): set_header(name, value) build_response(set_header, key) - yield from self.write_http_response(status, headers) + yield from self.write_http_response( + SWITCHING_PROTOCOLS, response_headers) assert self.state == CONNECTING self.state = OPEN diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index b41aeeb61..794702b01 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1,7 +1,5 @@ import asyncio import functools -import http -import http.client import logging import os import ssl @@ -90,15 +88,15 @@ def with_client(*args, **kwds): class UnauthorizedServerProtocol(WebSocketServerProtocol): @asyncio.coroutine - def get_response_status(self, set_header): - return UNAUTHORIZED + def process_request(self, path, request_headers): + return UNAUTHORIZED, [] class ForbiddenServerProtocol(WebSocketServerProtocol): @asyncio.coroutine - def get_response_status(self, set_header): - return FORBIDDEN + def process_request(self, path, request_headers): + return FORBIDDEN, [] class FooClientProtocol(WebSocketClientProtocol): @@ -260,34 +258,6 @@ def test_protocol_custom_response_headers_list(self): resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - def test_get_response_status_attributes_available(self): - # Save the attribute values to a dict instead of asserting inside - # get_response_status() because assertion errors there do not - # currently bubble up for easy viewing. - attrs = {} - - class SaveAttributesProtocol(WebSocketServerProtocol): - @asyncio.coroutine - def get_response_status(self, set_header): - attrs['origin'] = self.origin - attrs['path'] = self.path - attrs['raw_request_headers'] = self.raw_request_headers.copy() - attrs['request_headers'] = self.request_headers - status = yield from super().get_response_status(set_header) - return status - - with self.temp_server(create_protocol=SaveAttributesProtocol): - self.start_client(path='foo/bar', origin='http://otherhost') - self.assertEqual(attrs['origin'], 'http://otherhost') - self.assertEqual(attrs['path'], '/foo/bar') - # To reduce test brittleness, only check one nontrivial aspect - # of the request headers. - self.assertIn(('Origin', 'http://otherhost'), - attrs['raw_request_headers']) - request_headers = attrs['request_headers'] - self.assertIsInstance(request_headers, http.client.HTTPMessage) - self.assertEqual(request_headers.get('origin'), 'http://otherhost') - def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() From 48372040742db9cc4f3ec436e4ec0f3df035e26b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 Aug 2017 15:11:13 +0200 Subject: [PATCH 0260/1539] Support returning a HTTP body in custom response. Also add a relevant test for process_request(). --- websockets/compatibility.py | 5 +++ websockets/server.py | 18 ++++++++--- websockets/test_client_server.py | 55 ++++++++++++++++++++++++++++---- 3 files changed, 67 insertions(+), 11 deletions(-) diff --git a/websockets/compatibility.py b/websockets/compatibility.py index ac9aaf7ad..1c3290cf4 100644 --- a/websockets/compatibility.py +++ b/websockets/compatibility.py @@ -12,6 +12,7 @@ # Python ≥ 3.5 SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS # Used only in tests. + OK = http.HTTPStatus.OK UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED FORBIDDEN = http.HTTPStatus.FORBIDDEN except AttributeError: # pragma: no cover @@ -20,6 +21,10 @@ class SWITCHING_PROTOCOLS: value = 101 phrase = "Switching Protocols" + class OK: + value = 200 + phrase = "OK" + class UNAUTHORIZED: value = 401 phrase = "Unauthorized" diff --git a/websockets/server.py b/websockets/server.py index 943ea9711..37bfef0a4 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -139,6 +139,8 @@ def read_http_request(self): Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message is malformed or isn't a HTTP/1.1 GET request. + This coroutine assumes that there is no request body. + """ try: path, headers = yield from read_request(self.reader) @@ -152,10 +154,12 @@ def read_http_request(self): return path, self.request_headers @asyncio.coroutine - def write_http_response(self, status, headers): + def write_http_response(self, status, headers, body=None): """ Write status line and headers to the HTTP response. + This coroutine is also able to write a response body. + """ self.response_headers = build_headers(headers) self.raw_response_headers = headers @@ -171,6 +175,9 @@ def write_http_response(self, status, headers): self.writer.write(response) + if body is not None: + self.writer.write(body) + def process_origin(self, get_header, origins=None): """ Handle the Origin HTTP header. @@ -218,11 +225,12 @@ def process_request(self, path, request_headers): ``request_headers`` are a :class:`~http.client.HTTPMessage`. If this coroutine returns ``None``, the WebSocket handshake continues. - If it returns a HTTP status code and HTTP headers, that HTTP response - is sent and the connection is closed immediately. + If it returns a status code, headers and a optionally a response body, + that HTTP response is sent and the connection is closed. - The HTTP status must be a :class:`~http.HTTPStatus` and HTTP headers - must be an iterable of ``(name, value)`` pairs. + The HTTP status must be a :class:`~http.HTTPStatus`. HTTP headers must + be an iterable of ``(name, value)`` pairs. If provided, the HTTP + response body must be :class:`bytes`. (:class:`~http.HTTPStatus` was added in Python 3.5. Use a compatible object on earlier versions. Look at ``SWITCHING_PROTOCOLS`` in diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 794702b01..6609ef2b6 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1,14 +1,16 @@ import asyncio +import contextlib import functools import logging import os import ssl +import sys import unittest import unittest.mock -from contextlib import contextmanager +import urllib.request from .client import * -from .compatibility import FORBIDDEN, UNAUTHORIZED +from .compatibility import FORBIDDEN, OK, UNAUTHORIZED from .exceptions import ConnectionClosed, InvalidHandshake, InvalidStatusCode from .http import USER_AGENT, read_response from .server import * @@ -38,7 +40,7 @@ def handler(ws, path): yield from ws.send((yield from ws.recv())) -@contextmanager +@contextlib.contextmanager def temp_test_server(test, **kwds): test.start_server(**kwds) try: @@ -47,7 +49,7 @@ def temp_test_server(test, **kwds): test.stop_server() -@contextmanager +@contextlib.contextmanager def temp_test_client(test, *args, **kwds): test.start_client(*args, **kwds) try: @@ -99,6 +101,15 @@ def process_request(self, path, request_headers): return FORBIDDEN, [] +class HealthCheckServerProtocol(WebSocketServerProtocol): + + @asyncio.coroutine + def process_request(self, path, request_headers): + if path == '/__health__/': + body = b'status = green\n' + return OK, [('Content-Length', str(len(body)))], body + + class FooClientProtocol(WebSocketClientProtocol): pass @@ -147,12 +158,12 @@ def stop_server(self): except asyncio.TimeoutError: # pragma: no cover self.fail("Server failed to stop") - @contextmanager + @contextlib.contextmanager def temp_server(self, **kwds): with temp_test_server(self, **kwds): yield - @contextmanager + @contextlib.contextmanager def temp_client(self, *args, **kwds): with temp_test_client(self, *args, **kwds): yield @@ -258,6 +269,38 @@ def test_protocol_custom_response_headers_list(self): resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) + @with_server(create_protocol=HealthCheckServerProtocol) + @with_client() + def test_custom_protocol_http_request(self): + # One URL returns a HTTP response. + + if self.secure: + url = 'https://localhost:8642/__health__/' + if sys.version_info[:2] < (3, 4): # pragma: no cover + # Python 3.3 didn't check SSL certificates. + open_health_check = functools.partial( + urllib.request.urlopen, url) + else: # pragma: no cover + open_health_check = functools.partial( + urllib.request.urlopen, url, context=self.client_context) + else: + url = 'http://localhost:8642/__health__/' + open_health_check = functools.partial( + urllib.request.urlopen, url) + + response = self.loop.run_until_complete( + self.loop.run_in_executor(None, open_health_check)) + + with contextlib.closing(response): + self.assertEqual(response.code, 200) + self.assertEqual(response.read(), b'status = green\n') + + # Other URLs create a WebSocket connection. + + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() From 7389cc40c5d4ec2b0f50edb275a99bd07a29ba96 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Aug 2017 09:41:10 +0200 Subject: [PATCH 0261/1539] Clarify handling of HTTP request/response bodies. websockets doesn't do anything with them -- since it doesn't expect any -- but they're still there in the data stream. --- websockets/client.py | 8 ++++++-- websockets/http.py | 16 +++++++++------- websockets/server.py | 10 ++++++---- websockets/test_client_server.py | 2 +- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 793738972..544223e43 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -30,7 +30,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): @asyncio.coroutine def write_http_request(self, path, headers): """ - Write status line and headers to the HTTP request. + Write request line and headers to the HTTP request. """ self.path = path @@ -52,7 +52,11 @@ def read_http_response(self): Read status line and headers from the HTTP response. Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message - is malformed or isn't a HTTP/1.1 GET request. + is malformed or isn't an HTTP/1.1 GET request. + + Don't attempt to read the response body because WebSocket handshake + responses don't have one. If the response contains a body, it may be + read from ``self.reader`` after this coroutine returns. """ try: diff --git a/websockets/http.py b/websockets/http.py index 4b95fbcbf..464e942a7 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -49,7 +49,7 @@ @asyncio.coroutine def read_request(stream): """ - Read an HTTP/1.1 request from ``stream``. + Read an HTTP/1.1 GET request from ``stream``. ``stream`` is an :class:`~asyncio.StreamReader`. @@ -62,7 +62,9 @@ def read_request(stream): Raise an exception if the request isn't well formatted. - The request is assumed not to contain a body. + Don't attempt to read the request body because WebSocket handshake + requests don't have one. If the request contains a body, it may be + read from ``stream`` after this coroutine returns. """ # https://tools.ietf.org/html/rfc7230#section-3.1.1 @@ -101,9 +103,11 @@ def read_response(stream): Non-ASCII characters are represented with surrogate escapes. - Raise an exception if the request isn't well formatted. + Raise an exception if the response isn't well formatted. - The response is assumed not to contain a body. + Don't attempt to read the response body, because WebSocket handshake + responses don't have one. If the response contains a body, it may be + read from ``stream`` after this coroutine returns. """ # https://tools.ietf.org/html/rfc7230#section-3.1.2 @@ -134,7 +138,7 @@ def read_response(stream): @asyncio.coroutine def read_headers(stream): """ - Read an HTTP message from ``stream``. + Read HTTP headers from ``stream``. ``stream`` is an :class:`~asyncio.StreamReader`. @@ -143,8 +147,6 @@ def read_headers(stream): Non-ASCII characters are represented with surrogate escapes. - The message is assumed not to contain a body. - """ # https://tools.ietf.org/html/rfc7230#section-3.2 diff --git a/websockets/server.py b/websockets/server.py index 37bfef0a4..b0c411f1c 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -134,12 +134,14 @@ def _is_server_shutting_down(self, exc): @asyncio.coroutine def read_http_request(self): """ - Read status line and headers from the HTTP request. + Read request line and headers from the HTTP request. Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message - is malformed or isn't a HTTP/1.1 GET request. + is malformed or isn't an HTTP/1.1 GET request. - This coroutine assumes that there is no request body. + Don't attempt to read the request body because WebSocket handshake + requests don't have one. If the request contains a body, it may be + read from ``self.reader`` after this coroutine returns. """ try: @@ -220,7 +222,7 @@ def select_subprotocol(client_protos, server_protos): @asyncio.coroutine def process_request(self, path, request_headers): """ - Intercept the HTTP request and return a HTTP response if needed. + Intercept the HTTP request and return an HTTP response if needed. ``request_headers`` are a :class:`~http.client.HTTPMessage`. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 6609ef2b6..ac020636c 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -272,7 +272,7 @@ def test_protocol_custom_response_headers_list(self): @with_server(create_protocol=HealthCheckServerProtocol) @with_client() def test_custom_protocol_http_request(self): - # One URL returns a HTTP response. + # One URL returns an HTTP response. if self.secure: url = 'https://localhost:8642/__health__/' From e42f8c8f37ccdfee32030e549708e6464844905d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Aug 2017 10:08:00 +0200 Subject: [PATCH 0262/1539] Reorder methods more logically. --- websockets/server.py | 56 ++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index b0c411f1c..b2ba137a2 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -180,6 +180,34 @@ def write_http_response(self, status, headers, body=None): if body is not None: self.writer.write(body) + @asyncio.coroutine + def process_request(self, path, request_headers): + """ + Intercept the HTTP request and return an HTTP response if needed. + + ``request_headers`` are a :class:`~http.client.HTTPMessage`. + + If this coroutine returns ``None``, the WebSocket handshake continues. + If it returns a status code, headers and a optionally a response body, + that HTTP response is sent and the connection is closed. + + The HTTP status must be a :class:`~http.HTTPStatus`. HTTP headers must + be an iterable of ``(name, value)`` pairs. If provided, the HTTP + response body must be :class:`bytes`. + + (:class:`~http.HTTPStatus` was added in Python 3.5. Use a compatible + object on earlier versions. Look at ``SWITCHING_PROTOCOLS`` in + ``websockets.compatibility`` for an example.) + + This method may be overridden to check the request headers and set a + different status, for example to authenticate the request and return + ``HTTPStatus.UNAUTHORIZED`` or ``HTTPStatus.FORBIDDEN``. + + It is declared as a coroutine because such authentication checks are + likely to require network requests. + + """ + def process_origin(self, get_header, origins=None): """ Handle the Origin HTTP header. @@ -219,34 +247,6 @@ def select_subprotocol(client_protos, server_protos): priority = lambda p: client_protos.index(p) + server_protos.index(p) return sorted(common_protos, key=priority)[0] - @asyncio.coroutine - def process_request(self, path, request_headers): - """ - Intercept the HTTP request and return an HTTP response if needed. - - ``request_headers`` are a :class:`~http.client.HTTPMessage`. - - If this coroutine returns ``None``, the WebSocket handshake continues. - If it returns a status code, headers and a optionally a response body, - that HTTP response is sent and the connection is closed. - - The HTTP status must be a :class:`~http.HTTPStatus`. HTTP headers must - be an iterable of ``(name, value)`` pairs. If provided, the HTTP - response body must be :class:`bytes`. - - (:class:`~http.HTTPStatus` was added in Python 3.5. Use a compatible - object on earlier versions. Look at ``SWITCHING_PROTOCOLS`` in - ``websockets.compatibility`` for an example.) - - This method may be overridden to check the request headers and set a - different status, for example to authenticate the request and return - ``HTTPStatus.UNAUTHORIZED`` or ``HTTPStatus.FORBIDDEN``. - - It is declared as a coroutine because such authentication checks are - likely to require network requests. - - """ - @asyncio.coroutine def handshake(self, origins=None, subprotocols=None, extra_headers=None): """ From d45b29062c48c6678d889a63f3d5745994be85d3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Aug 2017 13:10:11 +0200 Subject: [PATCH 0263/1539] Refactor writing of HTTP responses by the server. After this change: 1. either handshake() successfully writes a handshake response 2. or it raises an exception, which is caught and translated to an HTTP response in handler() Custom HTTP responses created by process_request() are handled by 2. --- websockets/compatibility.py | 16 ++++++++++- websockets/exceptions.py | 17 +++++++++-- websockets/server.py | 57 ++++++++++++++++++++++++++----------- 3 files changed, 69 insertions(+), 21 deletions(-) diff --git a/websockets/compatibility.py b/websockets/compatibility.py index 1c3290cf4..c8c301421 100644 --- a/websockets/compatibility.py +++ b/websockets/compatibility.py @@ -11,10 +11,12 @@ try: # pragma: no cover # Python ≥ 3.5 SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS - # Used only in tests. OK = http.HTTPStatus.OK + BAD_REQUEST = http.HTTPStatus.BAD_REQUEST UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED FORBIDDEN = http.HTTPStatus.FORBIDDEN + INTERNAL_SERVER_ERROR = http.HTTPStatus.INTERNAL_SERVER_ERROR + SERVICE_UNAVAILABLE = http.HTTPStatus.SERVICE_UNAVAILABLE except AttributeError: # pragma: no cover # Python < 3.5 class SWITCHING_PROTOCOLS: @@ -25,6 +27,10 @@ class OK: value = 200 phrase = "OK" + class BAD_REQUEST: + value = 400 + phrase = "Bad Request" + class UNAUTHORIZED: value = 401 phrase = "Unauthorized" @@ -32,3 +38,11 @@ class UNAUTHORIZED: class FORBIDDEN: value = 403 phrase = "Forbidden" + + class INTERNAL_SERVER_ERROR: + value = 500 + phrase = "Internal Server Error" + + class SERVICE_UNAVAILABLE: + value = 503 + phrase = "Service Unavailable" diff --git a/websockets/exceptions.py b/websockets/exceptions.py index bb47b2d14..3db569564 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,7 +1,7 @@ __all__ = [ - 'InvalidHandshake', 'InvalidMessage', 'InvalidOrigin', 'InvalidState', - 'InvalidStatusCode', 'InvalidURI', 'ConnectionClosed', 'PayloadTooBig', - 'WebSocketProtocolError', + 'AbortHandshake', 'InvalidHandshake', 'InvalidMessage', 'InvalidOrigin', + 'InvalidState', 'InvalidStatusCode', 'InvalidURI', 'ConnectionClosed', + 'PayloadTooBig', 'WebSocketProtocolError', ] @@ -12,6 +12,17 @@ class InvalidHandshake(Exception): """ +class AbortHandshake(InvalidHandshake): + """ + Exception raised to abort a handshake and return a HTTP response. + + """ + def __init__(self, status, headers, body=None): + self.status = status + self.headers = headers + self.body = body + + class InvalidMessage(InvalidHandshake): """ Exception raised when the HTTP message in a handshake request is malformed. diff --git a/websockets/server.py b/websockets/server.py index b2ba137a2..770609c76 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -8,8 +8,13 @@ import collections.abc import logging -from .compatibility import SWITCHING_PROTOCOLS, asyncio_ensure_future -from .exceptions import InvalidHandshake, InvalidMessage, InvalidOrigin +from .compatibility import ( + BAD_REQUEST, FORBIDDEN, INTERNAL_SERVER_ERROR, SERVICE_UNAVAILABLE, + SWITCHING_PROTOCOLS, asyncio_ensure_future +) +from .exceptions import ( + AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin +) from .handshake import build_response, check_request from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -68,22 +73,43 @@ def handler(self): raise except Exception as exc: if self._is_server_shutting_down(exc): - response = ('HTTP/1.1 503 Service Unavailable\r\n\r\n' - 'Server is shutting down.') + early_response = ( + SERVICE_UNAVAILABLE, + [], + b"Server is shutting down.", + ) + elif isinstance(exc, AbortHandshake): + early_response = ( + exc.status, + exc.headers, + exc.body, + ) elif isinstance(exc, InvalidOrigin): - response = 'HTTP/1.1 403 Forbidden\r\n\r\n' + str(exc) + logger.warning("Invalid origin", exc_info=True) + early_response = ( + FORBIDDEN, + [], + str(exc).encode(), + ) elif isinstance(exc, InvalidHandshake): - response = 'HTTP/1.1 400 Bad Request\r\n\r\n' + str(exc) + logger.warning("Invalid handshake", exc_info=True) + early_response = ( + BAD_REQUEST, + [], + str(exc).encode(), + ) else: logger.warning("Error in opening handshake", exc_info=True) - response = ('HTTP/1.1 500 Internal Server Error\r\n\r\n' - 'See server log for more information.') - self.writer.write(response.encode()) - raise + early_response = ( + INTERNAL_SERVER_ERROR, + [], + b"See server log for more information.", + ) + + yield from self.write_http_response(*early_response) + self.opening_handshake.set_result(False) + yield from self.close_connection(force=True) - # Subclasses can customize process_request() to reject the - # handshake, typically after checking authentication. - if path is None: return try: @@ -275,10 +301,7 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): early_response = yield from self.process_request(path, request_headers) if early_response is not None: - yield from self.write_http_response(*early_response) - self.opening_handshake.set_result(False) - yield from self.close_connection(force=True) - return + raise AbortHandshake(*early_response) get_header = lambda k: request_headers.get(k, '') From 22f8cef96063361cd14ab372b0c7608c4a05c0ad Mon Sep 17 00:00:00 2001 From: Matthieu Darbois Date: Sun, 20 Aug 2017 13:45:05 +0200 Subject: [PATCH 0264/1539] Build wheels on various platforms (#239) * C extension mandatory if '.cibuildwheel' file is present * build wheels on travis-ci * build wheels on appveyor Fix #220. --- .travis.yml | 35 +++++++++++++++++++++++++++++++++++ appveyor.yml | 19 +++++++++++++++++++ setup.py | 2 +- 3 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 .travis.yml create mode 100644 appveyor.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..ddc0fc604 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,35 @@ +env: + global: +# Don't attempt to build Python 2.7 wheels; websockets only works on Python 3. + - CIBW_SKIP=cp27* +# Commented out because tests don't pass reliably on macOS, see #241. +# - CIBW_TEST_COMMAND="python3 -m unittest discover websockets" + +matrix: + include: + - dist: trusty + sudo: required + language: python + python: "3.3" + services: + - docker + - os: osx + osx_image: xcode8.3 + +install: +# Python 3 is needed to run cibuildwheel for websockets. + - if [ "${TRAVIS_OS_NAME:-}" == "osx" ]; then + brew install python3; + fi +# Install cibuildwheel using pip3 to make sure Python 3 is used. + - pip3 install cibuildwheel==0.4.0 +# Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). + - touch .cibuildwheel + +script: + - cibuildwheel --output-dir wheelhouse +# Upload to PyPI on tags + - if [ "${TRAVIS_TAG:-}" != "" ]; then + python -m pip install twine && + python -m twine upload --skip-existing wheelhouse/*; + fi diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 000000000..5f9a079bc --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,19 @@ +environment: +# Don't attempt to build Python 2.7 wheels; websockets only works on Python 3. + CIBW_SKIP: cp27* +# Commented out because tests don't pass reliably on Windows, see #240. +# CIBW_TEST_COMMAND: python -m unittest discover websockets + +# Since Python 2 is still the default, invoke Python 3 explicitly. +install: + - cmd: C:\Python33-x64\python.exe -m pip install cibuildwheel==0.4.0 +# Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). + - cmd: touch .cibuildwheel +build_script: + - cmd: C:\Python33-x64\python.exe -m cibuildwheel --output-dir wheelhouse +# Upload to PyPI on tags + - ps: >- + if ($env:APPVEYOR_REPO_TAG -eq "true") { + Invoke-Expression "python -m pip install twine" + Invoke-Expression "python -m twine upload --skip-existing wheelhouse/*.whl" + } diff --git a/setup.py b/setup.py index 6363c6f23..4ec80bb64 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ setuptools.Extension( 'websockets.speedups', sources=['websockets/speedups.c'], - optional=True, + optional=not os.path.exists(os.path.join(root_dir, '.cibuildwheel')), ) ] From e812f02d30d930eb802e0439a397431880bad5f1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Aug 2017 13:58:20 +0200 Subject: [PATCH 0265/1539] Bump version number. --- docs/changelog.rst | 5 ++++- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 772d3bee7..cfb8d42fb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,11 +1,14 @@ Changelog --------- -3.4 +3.5 ... *In development* +3.4 +... + * Renamed :func:`~websockets.server.serve()` and :func:`~websockets.client.connect()`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. diff --git a/docs/conf.py b/docs/conf.py index 48006a483..885b81690 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '3.3' +version = '3.4' # The full version, including alpha/beta/rc tags. -release = '3.3' +release = '3.4' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index 680144bc4..a0e73377d 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '3.3' +version = '3.4' From a786635863f591106b258716be2d2998692e2505 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Aug 2017 15:09:18 +0200 Subject: [PATCH 0266/1539] Change Upgrade header to lowercase value. This value is technically case-insensitive but sloppy servers may expect it to be lower case. Fix #250. --- websockets/handshake.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/websockets/handshake.py b/websockets/handshake.py index 0b99242c9..cb6d742a6 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -58,7 +58,7 @@ def build_request(set_header): """ rand = bytes(random.getrandbits(8) for _ in range(16)) key = base64.b64encode(rand).decode() - set_header('Upgrade', 'WebSocket') + set_header('Upgrade', 'websocket') set_header('Connection', 'Upgrade') set_header('Sec-WebSocket-Key', key) set_header('Sec-WebSocket-Version', '13') @@ -102,7 +102,7 @@ def build_response(set_header, key): ``key`` comes from :func:`check_request`. """ - set_header('Upgrade', 'WebSocket') + set_header('Upgrade', 'websocket') set_header('Connection', 'Upgrade') set_header('Sec-WebSocket-Accept', accept(key)) From 50c0dae8152d19ae0d6f58e1297967cf883090d1 Mon Sep 17 00:00:00 2001 From: Matthieu Darbois Date: Sun, 20 Aug 2017 20:45:53 +0200 Subject: [PATCH 0267/1539] Add PyPI and CircleCI badges to README This reflects at first glance that this package is available on PyPI and that CI is in place. --- README.rst | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 5a383721b..ef1920af9 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,5 @@ -WebSockets -========== +WebSockets |pypi| |circleci| +============================ ``websockets`` is a library for developing WebSocket servers_ and clients_ in Python. It implements `RFC 6455`_ with a focus on correctness and simplicity. @@ -26,3 +26,8 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _Read the Docs: https://websockets.readthedocs.io/ .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ + +.. |pypi| image:: https://img.shields.io/pypi/v/websockets.svg + :target: https://pypi.python.org/pypi/websockets +.. |circleci| image:: https://circleci.com/gh/aaugustin/websockets/tree/master.svg?style=svg + :target: https://circleci.com/gh/aaugustin/websockets/tree/master From 544d704e79834cd536bb81d82e9195e96a3b39fd Mon Sep 17 00:00:00 2001 From: mayeut Date: Mon, 21 Aug 2017 19:26:54 +0200 Subject: [PATCH 0268/1539] Add codecov to circleci.yml --- circle.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/circle.yml b/circle.yml index 51fd7f60e..665e46fd5 100644 --- a/circle.yml +++ b/circle.yml @@ -6,8 +6,9 @@ machine: dependencies: override: - - pip install tox + - pip install tox codecov test: override: - tox + - codecov From 683ee51859c6d33ebf7a087beb8e4a76cfe0d2d9 Mon Sep 17 00:00:00 2001 From: Matthieu Darbois Date: Mon, 21 Aug 2017 20:52:13 +0200 Subject: [PATCH 0269/1539] Add codecov badge to README Also use "shield" style for circleci one to be consistent. --- README.rst | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index ef1920af9..3c0b013f1 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,5 @@ -WebSockets |pypi| |circleci| -============================ +WebSockets |pypi| |circleci| |codecov| +====================================== ``websockets`` is a library for developing WebSocket servers_ and clients_ in Python. It implements `RFC 6455`_ with a focus on correctness and simplicity. @@ -29,5 +29,7 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. |pypi| image:: https://img.shields.io/pypi/v/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |circleci| image:: https://circleci.com/gh/aaugustin/websockets/tree/master.svg?style=svg +.. |circleci| image:: https://circleci.com/gh/aaugustin/websockets/tree/master.svg?style=shield :target: https://circleci.com/gh/aaugustin/websockets/tree/master +.. |codecov| image:: https://codecov.io/gh/aaugustin/websockets/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aaugustin/websockets From c827cc679ef19e5672e08fa3ffe62e1410103b6c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Aug 2017 22:06:27 +0200 Subject: [PATCH 0270/1539] Restore missing import. Add test for a code path that was exercised by accident to preserve full test coverage. --- websockets/test_client_server.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index ac020636c..ecf575c9d 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -12,6 +12,7 @@ from .client import * from .compatibility import FORBIDDEN, OK, UNAUTHORIZED from .exceptions import ConnectionClosed, InvalidHandshake, InvalidStatusCode +from .handshake import build_response from .http import USER_AGENT, read_response from .server import * @@ -440,6 +441,15 @@ def wrong_read_response(stream): self.start_client() self.run_loop_once() + @with_server() + @unittest.mock.patch( + 'websockets.server.WebSocketServerProtocol.process_request') + def test_server_error_in_handshake(self, _process_request): + _process_request.side_effect = Exception("process_request crashed") + + with self.assertRaises(InvalidHandshake): + self.start_client() + @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send') def test_server_handler_crashes(self, send): From 87638839d54e702fa2c78735c2ba219b3fcfce57 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 May 2016 10:54:18 +0200 Subject: [PATCH 0271/1539] Add reserved bits to frames. --- websockets/framing.py | 51 +++++++++++++++++++++++++++---------- websockets/test_protocol.py | 4 +-- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index 135a139e5..65863b020 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -49,18 +49,30 @@ } -Frame = collections.namedtuple('Frame', ('fin', 'opcode', 'data')) -Frame.__doc__ = """WebSocket frame. +FrameData = collections.namedtuple( + 'FrameData', + ['fin', 'opcode', 'data', 'rsv1', 'rsv2', 'rsv3'], +) -* ``fin`` is the FIN bit -* ``opcode`` is the opcode -* ``data`` is the payload data -Only these three fields are needed by higher level code. The MASK bit, payload -length and masking-key are handled on the fly by :func:`read_frame` and -:func:`write_frame`. +class Frame(FrameData): + """ + WebSocket frame. -""" + * ``fin`` is the FIN bit + * ``rsv1`` is the RSV1 bit + * ``rsv2`` is the RSV2 bit + * ``rsv3`` is the RSV3 bit + * ``opcode`` is the opcode + * ``data`` is the payload data + + Only these fields are needed by higher level code. The MASK bit, payload + length and masking-key are handled on the fly by :func:`read_frame` and + :func:`write_frame`. + + """ + def __new__(cls, fin, opcode, data, rsv1=False, rsv2=False, rsv3=False): + return FrameData.__new__(cls, fin, opcode, data, rsv1, rsv2, rsv3) @asyncio.coroutine @@ -86,8 +98,9 @@ def read_frame(reader, mask, *, max_size=None): data = yield from reader(2) head1, head2 = struct.unpack('!BB', data) fin = bool(head1 & 0b10000000) - if head1 & 0b01110000: - raise WebSocketProtocolError("Reserved bits must be 0") + rsv1 = bool(head1 & 0b01000000) + rsv2 = bool(head1 & 0b00100000) + rsv3 = bool(head1 & 0b00010000) opcode = head1 & 0b00001111 if bool(head2 & 0b10000000) != mask: raise WebSocketProtocolError("Incorrect masking") @@ -109,7 +122,8 @@ def read_frame(reader, mask, *, max_size=None): if mask: data = apply_mask(data, mask_bits) - frame = Frame(fin, opcode, data) + frame = Frame(fin, opcode, data, rsv1, rsv2, rsv3) + check_frame(frame) return frame @@ -134,8 +148,14 @@ def write_frame(frame, writer, mask): output = io.BytesIO() # Prepare the header - head1 = 0b10000000 if frame.fin else 0 - head1 |= frame.opcode + head1 = ( + (0b10000000 if frame.fin else 0) | + (0b01000000 if frame.rsv1 else 0) | + (0b00100000 if frame.rsv2 else 0) | + (0b00010000 if frame.rsv3 else 0) | + frame.opcode + ) + head2 = 0b10000000 if mask else 0 length = len(frame.data) if length < 0x7e: @@ -165,6 +185,9 @@ def check_frame(frame): contains incorrect values. """ + if frame.rsv1 or frame.rsv2 or frame.rsv3: + raise WebSocketProtocolError("Reserved bits must be 0") + if frame.opcode in (OP_CONT, OP_TEXT, OP_BINARY): return elif frame.opcode in (OP_CLOSE, OP_PING, OP_PONG): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index c08c57b47..9c8d9d189 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -225,8 +225,8 @@ def last_sent_frame(self): return frame - def assertOneFrameSent(self, fin, opcode, data): - self.assertEqual(self.last_sent_frame(), Frame(fin, opcode, data)) + def assertOneFrameSent(self, *args): + self.assertEqual(self.last_sent_frame(), Frame(*args)) def assertNoFrameSent(self): self.assertIsNone(self.last_sent_frame()) From c4f786103a978601c39e1d1b64f592711929c83a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 May 2016 11:21:41 +0200 Subject: [PATCH 0272/1539] Miscellaneous small style cleanups. --- websockets/framing.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index 65863b020..efbb0114c 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -97,13 +97,16 @@ def read_frame(reader, mask, *, max_size=None): # Read the header data = yield from reader(2) head1, head2 = struct.unpack('!BB', data) + fin = bool(head1 & 0b10000000) rsv1 = bool(head1 & 0b01000000) rsv2 = bool(head1 & 0b00100000) rsv3 = bool(head1 & 0b00010000) opcode = head1 & 0b00001111 + if bool(head2 & 0b10000000) != mask: raise WebSocketProtocolError("Incorrect masking") + length = head2 & 0b01111111 if length == 126: data = yield from reader(2) @@ -145,6 +148,7 @@ def write_frame(frame, writer, mask): """ check_frame(frame) + output = io.BytesIO() # Prepare the header @@ -157,13 +161,15 @@ def write_frame(frame, writer, mask): ) head2 = 0b10000000 if mask else 0 + length = len(frame.data) - if length < 0x7e: + if length < 126: output.write(struct.pack('!BB', head1, head2 | length)) - elif length < 0x10000: + elif length < 65536: output.write(struct.pack('!BBH', head1, head2 | 126, length)) else: output.write(struct.pack('!BBQ', head1, head2 | 127, length)) + if mask: mask_bits = struct.pack('!I', random.getrandbits(32)) output.write(mask_bits) @@ -188,9 +194,9 @@ def check_frame(frame): if frame.rsv1 or frame.rsv2 or frame.rsv3: raise WebSocketProtocolError("Reserved bits must be 0") - if frame.opcode in (OP_CONT, OP_TEXT, OP_BINARY): + if frame.opcode in [OP_CONT, OP_TEXT, OP_BINARY]: return - elif frame.opcode in (OP_CLOSE, OP_PING, OP_PONG): + elif frame.opcode in [OP_CLOSE, OP_PING, OP_PONG]: if len(frame.data) > 125: raise WebSocketProtocolError("Control frame too long") if not frame.fin: From 6f7a462b8af63930a5ca3e0643968d0da7ef150d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 May 2016 11:29:41 +0200 Subject: [PATCH 0273/1539] Add support for processing frames. This is where I'm planning to hook extensions. --- websockets/framing.py | 17 +++++++++++++++-- websockets/test_framing.py | 34 ++++++++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index efbb0114c..52103d353 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -76,7 +76,7 @@ def __new__(cls, fin, opcode, data, rsv1=False, rsv2=False, rsv3=False): @asyncio.coroutine -def read_frame(reader, mask, *, max_size=None): +def read_frame(reader, mask, *, max_size=None, extensions=()): """ Read a WebSocket frame and return a :class:`Frame` object. @@ -89,6 +89,9 @@ def read_frame(reader, mask, *, max_size=None): If ``max_size`` is set and the payload exceeds this size in bytes, :exc:`~websockets.exceptions.PayloadTooBig` is raised. + If ``extensions`` is provided, it's a list of functions that transform the + frame and return it. They are applied in reverse order. + This function validates the frame before returning it and raises :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains incorrect values. @@ -128,10 +131,14 @@ def read_frame(reader, mask, *, max_size=None): frame = Frame(fin, opcode, data, rsv1, rsv2, rsv3) check_frame(frame) + + for extension in reversed(extensions): + frame = extension(frame) + return frame -def write_frame(frame, writer, mask): +def write_frame(frame, writer, mask, *, extensions=()): """ Write a WebSocket frame. @@ -142,11 +149,17 @@ def write_frame(frame, writer, mask): ``mask`` is a :class:`bool` telling whether the frame should be masked i.e. whether the write happens on the client side. + If ``extensions`` is provided, it's a list of functions that transform the + frame and return it. They are applied in order. + This function validates the frame before sending it and raises :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains incorrect values. """ + for extension in extensions: + frame = extension(frame) + check_frame(frame) output = io.BytesIO() diff --git a/websockets/test_framing.py b/websockets/test_framing.py index f88ee3bcc..cbe454e9e 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -1,4 +1,6 @@ import asyncio +import codecs +import sys import unittest import unittest.mock @@ -15,31 +17,32 @@ def setUp(self): def tearDown(self): self.loop.close() - def decode(self, message, mask=False, max_size=None): + def decode(self, message, mask=False, max_size=None, extensions=()): self.stream = asyncio.StreamReader(loop=self.loop) self.stream.feed_data(message) self.stream.feed_eof() frame = self.loop.run_until_complete(read_frame( - self.stream.readexactly, mask, max_size=max_size)) + self.stream.readexactly, mask, + max_size=max_size, extensions=extensions)) # Make sure all the data was consumed. self.assertTrue(self.stream.at_eof()) return frame - def encode(self, frame, mask=False): + def encode(self, frame, mask=False, extensions=()): writer = unittest.mock.Mock() - write_frame(frame, writer, mask) + write_frame(frame, writer, mask, extensions=extensions) # Ensure the entire frame is sent with a single call to writer(). # Multiple calls cause TCP fragmentation and degrade performance. self.assertEqual(writer.call_count, 1) # The frame data is the single positional argument of that call. return writer.call_args[0][0] - def round_trip(self, message, expected, mask=False): - decoded = self.decode(message, mask) + def round_trip(self, message, expected, mask=False, extensions=()): + decoded = self.decode(message, mask, extensions=extensions) self.assertEqual(decoded, expected) - encoded = self.encode(decoded, mask) + encoded = self.encode(decoded, mask, extensions=extensions) if mask: # non-deterministic encoding - decoded = self.decode(encoded, mask) + decoded = self.decode(encoded, mask, extensions=extensions) self.assertEqual(decoded, expected) else: # deterministic encoding self.assertEqual(encoded, message) @@ -144,3 +147,18 @@ def test_parse_close_errors(self): parse_close(b'\x03\xe7') with self.assertRaises(UnicodeDecodeError): parse_close(b'\x03\xe8\xff\xff') + + @unittest.skipUnless(sys.version_info[:2] >= (3, 4), "rot13 is new in 3.4") + def test_extensions(self): + + # This extensions is symmetrical. + def rot13(frame): + assert frame.opcode == OP_TEXT + text = frame.data.decode() + data = codecs.encode(text, 'rot13').encode() + return frame._replace(data=data) + + self.round_trip( + b'\x81\x05uryyb', + Frame(True, OP_TEXT, b'hello'), + extensions=[rot13]) From 2e5e851c9695b01a10870dced6136c9b50c052e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 May 2016 18:15:41 +0200 Subject: [PATCH 0274/1539] Negligible performance optimization. --- websockets/framing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index 52103d353..e427178d2 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -101,13 +101,14 @@ def read_frame(reader, mask, *, max_size=None, extensions=()): data = yield from reader(2) head1, head2 = struct.unpack('!BB', data) - fin = bool(head1 & 0b10000000) - rsv1 = bool(head1 & 0b01000000) - rsv2 = bool(head1 & 0b00100000) - rsv3 = bool(head1 & 0b00010000) + # While not very Pythonic, this is marginally faster than calling bool(). + fin = True if head1 & 0b10000000 else False + rsv1 = True if head1 & 0b01000000 else False + rsv2 = True if head1 & 0b00100000 else False + rsv3 = True if head1 & 0b00010000 else False opcode = head1 & 0b00001111 - if bool(head2 & 0b10000000) != mask: + if (True if head2 & 0b10000000 else False) != mask: raise WebSocketProtocolError("Incorrect masking") length = head2 & 0b01111111 From 8f1ff2825c5fc541830b0e530efc9a0a9ee49c1e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 May 2016 18:52:09 +0200 Subject: [PATCH 0275/1539] Check close codes of outgoing frames. This can avoid protocol violations. --- websockets/framing.py | 19 ++++++++++++++----- websockets/test_framing.py | 6 +++++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index e427178d2..c0a611e51 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -233,14 +233,13 @@ def parse_close(data): length = len(data) if length == 0: return 1005, '' - elif length == 1: - raise WebSocketProtocolError("Close frame too short") - else: + elif length >= 2: code, = struct.unpack('!H', data[:2]) - if not (code in CLOSE_CODES or 3000 <= code < 5000): - raise WebSocketProtocolError("Invalid status code") + check_close(code) reason = data[2:].decode('utf-8') return code, reason + else: + raise WebSocketProtocolError("Close frame too short") def serialize_close(code, reason): @@ -250,4 +249,14 @@ def serialize_close(code, reason): This is the reverse of :func:`parse_close`. """ + check_close(code) return struct.pack('!H', code) + reason.encode('utf-8') + + +def check_close(code): + """ + Check the close code for a close frame. + + """ + if not (code in CLOSE_CODES or 3000 <= code < 5000): + raise WebSocketProtocolError("Invalid status code") diff --git a/websockets/test_framing.py b/websockets/test_framing.py index cbe454e9e..0da5bb620 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -133,7 +133,7 @@ def test_fragmented_control_frame(self): with self.assertRaises(WebSocketProtocolError): self.decode(b'\x08\x00') - def test_parse_close(self): + def test_parse_close_and_serialize_close(self): self.round_trip_close(b'\x03\xe8', 1000, '') self.round_trip_close(b'\x03\xe8OK', 1000, 'OK') @@ -148,6 +148,10 @@ def test_parse_close_errors(self): with self.assertRaises(UnicodeDecodeError): parse_close(b'\x03\xe8\xff\xff') + def test_serialize_close_errors(self): + with self.assertRaises(WebSocketProtocolError): + serialize_close(999, '') + @unittest.skipUnless(sys.version_info[:2] >= (3, 4), "rot13 is new in 3.4") def test_extensions(self): From 7e119fbe82fbaa883a09e288ae9c496f2977ea4c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 May 2016 18:55:25 +0200 Subject: [PATCH 0276/1539] Refactor read/write/check_frame as Frame methods. --- websockets/framing.py | 268 ++++++++++++++++++------------------ websockets/protocol.py | 27 ++-- websockets/test_framing.py | 6 +- websockets/test_protocol.py | 6 +- 4 files changed, 161 insertions(+), 146 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index c0a611e51..7aa4bfba5 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -26,7 +26,7 @@ __all__ = [ 'OP_CONT', 'OP_TEXT', 'OP_BINARY', 'OP_CLOSE', 'OP_PING', 'OP_PONG', - 'Frame', 'read_frame', 'write_frame', 'parse_close', 'serialize_close' + 'Frame', 'parse_close', 'serialize_close' ] OP_CONT, OP_TEXT, OP_BINARY = range(0x00, 0x03) @@ -67,156 +67,162 @@ class Frame(FrameData): * ``data`` is the payload data Only these fields are needed by higher level code. The MASK bit, payload - length and masking-key are handled on the fly by :func:`read_frame` and - :func:`write_frame`. + length and masking-key are handled on the fly by :meth:`read` and + :meth:`write`. """ def __new__(cls, fin, opcode, data, rsv1=False, rsv2=False, rsv3=False): return FrameData.__new__(cls, fin, opcode, data, rsv1, rsv2, rsv3) + @classmethod + @asyncio.coroutine + def read(cls, reader, *, mask, max_size=None, extensions=()): + """ + Read a WebSocket frame and return a :class:`Frame` object. -@asyncio.coroutine -def read_frame(reader, mask, *, max_size=None, extensions=()): - """ - Read a WebSocket frame and return a :class:`Frame` object. - - ``reader`` is a coroutine taking an integer argument and reading exactly - this number of bytes, unless the end of file is reached. + ``reader`` is a coroutine taking an integer argument and reading + exactly this number of bytes, unless the end of file is reached. - ``mask`` is a :class:`bool` telling whether the frame should be masked - i.e. whether the read happens on the server side. + ``mask`` is a :class:`bool` telling whether the frame should be masked + i.e. whether the read happens on the server side. - If ``max_size`` is set and the payload exceeds this size in bytes, - :exc:`~websockets.exceptions.PayloadTooBig` is raised. + If ``max_size`` is set and the payload exceeds this size in bytes, + :exc:`~websockets.exceptions.PayloadTooBig` is raised. - If ``extensions`` is provided, it's a list of functions that transform the - frame and return it. They are applied in reverse order. + If ``extensions`` is provided, it's a list of functions that transform + the frame and return it. They are applied in reverse order. - This function validates the frame before returning it and raises - :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains - incorrect values. + This function validates the frame before returning it and raises + :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains + incorrect values. - """ - # Read the header - data = yield from reader(2) - head1, head2 = struct.unpack('!BB', data) - - # While not very Pythonic, this is marginally faster than calling bool(). - fin = True if head1 & 0b10000000 else False - rsv1 = True if head1 & 0b01000000 else False - rsv2 = True if head1 & 0b00100000 else False - rsv3 = True if head1 & 0b00010000 else False - opcode = head1 & 0b00001111 - - if (True if head2 & 0b10000000 else False) != mask: - raise WebSocketProtocolError("Incorrect masking") - - length = head2 & 0b01111111 - if length == 126: + """ + # Read the header data = yield from reader(2) - length, = struct.unpack('!H', data) - elif length == 127: - data = yield from reader(8) - length, = struct.unpack('!Q', data) - if max_size is not None and length > max_size: - raise PayloadTooBig("Payload exceeds limit " - "({} > {} bytes)".format(length, max_size)) - if mask: - mask_bits = yield from reader(4) - - # Read the data - data = yield from reader(length) - if mask: - data = apply_mask(data, mask_bits) - - frame = Frame(fin, opcode, data, rsv1, rsv2, rsv3) - - check_frame(frame) - - for extension in reversed(extensions): - frame = extension(frame) - - return frame - - -def write_frame(frame, writer, mask, *, extensions=()): - """ - Write a WebSocket frame. - - ``frame`` is the :class:`Frame` object to write. - - ``writer`` is a function accepting bytes. + head1, head2 = struct.unpack('!BB', data) - ``mask`` is a :class:`bool` telling whether the frame should be masked - i.e. whether the write happens on the client side. + # While not Pythonic, this is marginally faster than calling bool(). + fin = True if head1 & 0b10000000 else False + rsv1 = True if head1 & 0b01000000 else False + rsv2 = True if head1 & 0b00100000 else False + rsv3 = True if head1 & 0b00010000 else False + opcode = head1 & 0b00001111 - If ``extensions`` is provided, it's a list of functions that transform the - frame and return it. They are applied in order. + if (True if head2 & 0b10000000 else False) != mask: + raise WebSocketProtocolError("Incorrect masking") - This function validates the frame before sending it and raises - :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains - incorrect values. + length = head2 & 0b01111111 + if length == 126: + data = yield from reader(2) + length, = struct.unpack('!H', data) + elif length == 127: + data = yield from reader(8) + length, = struct.unpack('!Q', data) + if max_size is not None and length > max_size: + raise PayloadTooBig("Payload exceeds limit " + "({} > {} bytes)".format(length, max_size)) + if mask: + mask_bits = yield from reader(4) + + # Read the data + data = yield from reader(length) + if mask: + data = apply_mask(data, mask_bits) + + frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) + + frame.check() + + for extension in reversed(extensions): + frame = extension(frame) + + return frame + + def write(frame, writer, *, mask, extensions=()): + """ + Write a WebSocket frame. + + ``frame`` is the :class:`Frame` object to write. + + ``writer`` is a function accepting bytes. + + ``mask`` is a :class:`bool` telling whether the frame should be masked + i.e. whether the write happens on the client side. + + If ``extensions`` is provided, it's a list of functions that transform + the frame and return it. They are applied in order. + + This function validates the frame before sending it and raises + :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains + incorrect values. + + """ + + # The first parameter is called `frame` rather than `self`, + # but it's the instance of class to which this method is bound. + + for extension in extensions: + frame = extension(frame) + + frame.check() + + output = io.BytesIO() + + # Prepare the header + head1 = ( + (0b10000000 if frame.fin else 0) | + (0b01000000 if frame.rsv1 else 0) | + (0b00100000 if frame.rsv2 else 0) | + (0b00010000 if frame.rsv3 else 0) | + frame.opcode + ) + + head2 = 0b10000000 if mask else 0 + + length = len(frame.data) + if length < 126: + output.write(struct.pack('!BB', head1, head2 | length)) + elif length < 65536: + output.write(struct.pack('!BBH', head1, head2 | 126, length)) + else: + output.write(struct.pack('!BBQ', head1, head2 | 127, length)) + + if mask: + mask_bits = struct.pack('!I', random.getrandbits(32)) + output.write(mask_bits) + + # Prepare the data + if mask: + data = apply_mask(frame.data, mask_bits) + else: + data = frame.data + output.write(data) + + # Send the frame + writer(output.getvalue()) - """ - for extension in extensions: - frame = extension(frame) - - check_frame(frame) + def check(frame): + """ + Raise :exc:`~websockets.exceptions.WebSocketProtocolError` if the frame + contains incorrect values. - output = io.BytesIO() + """ - # Prepare the header - head1 = ( - (0b10000000 if frame.fin else 0) | - (0b01000000 if frame.rsv1 else 0) | - (0b00100000 if frame.rsv2 else 0) | - (0b00010000 if frame.rsv3 else 0) | - frame.opcode - ) + # The first parameter is called `frame` rather than `self`, + # but it's the instance of class to which this method is bound. - head2 = 0b10000000 if mask else 0 - - length = len(frame.data) - if length < 126: - output.write(struct.pack('!BB', head1, head2 | length)) - elif length < 65536: - output.write(struct.pack('!BBH', head1, head2 | 126, length)) - else: - output.write(struct.pack('!BBQ', head1, head2 | 127, length)) - - if mask: - mask_bits = struct.pack('!I', random.getrandbits(32)) - output.write(mask_bits) - - # Prepare the data - if mask: - data = apply_mask(frame.data, mask_bits) - else: - data = frame.data - output.write(data) + if frame.rsv1 or frame.rsv2 or frame.rsv3: + raise WebSocketProtocolError("Reserved bits must be 0") - # Send the frame - writer(output.getvalue()) - - -def check_frame(frame): - """ - Raise :exc:`~websockets.exceptions.WebSocketProtocolError` if the frame - contains incorrect values. - - """ - if frame.rsv1 or frame.rsv2 or frame.rsv3: - raise WebSocketProtocolError("Reserved bits must be 0") - - if frame.opcode in [OP_CONT, OP_TEXT, OP_BINARY]: - return - elif frame.opcode in [OP_CLOSE, OP_PING, OP_PONG]: - if len(frame.data) > 125: - raise WebSocketProtocolError("Control frame too long") - if not frame.fin: - raise WebSocketProtocolError("Fragmented control frame") - else: - raise WebSocketProtocolError("Invalid opcode") + if frame.opcode in [OP_CONT, OP_TEXT, OP_BINARY]: + return + elif frame.opcode in [OP_CLOSE, OP_PING, OP_PONG]: + if len(frame.data) > 125: + raise WebSocketProtocolError("Control frame too long") + if not frame.fin: + raise WebSocketProtocolError("Fragmented control frame") + else: + raise WebSocketProtocolError("Invalid opcode") def parse_close(data): diff --git a/websockets/protocol.py b/websockets/protocol.py index bfe7e2313..11a1a302c 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -541,11 +541,15 @@ def read_data_frame(self, max_size): @asyncio.coroutine def read_frame(self, max_size): - is_masked = not self.is_client - frame = yield from read_frame( - self.reader.readexactly, is_masked, max_size=max_size) - side = 'client' if self.is_client else 'server' - logger.debug("%s << %s", side, frame) + frame = yield from Frame.read( + self.reader.readexactly, + mask=not self.is_client, + max_size=max_size, + ) + logger.debug( + "%s << %s", + 'client' if self.is_client else 'server', frame, + ) return frame @asyncio.coroutine @@ -559,11 +563,16 @@ def write_frame(self, opcode, data=b''): # before yielding control to avoid sending more than one close frame. if opcode == OP_CLOSE: self.state = CLOSING + frame = Frame(True, opcode, data) - side = 'client' if self.is_client else 'server' - logger.debug("%s >> %s", side, frame) - is_masked = self.is_client - write_frame(frame, self.writer.write, is_masked) + logger.debug( + "%s >> %s", + 'client' if self.is_client else 'server', frame, + ) + frame.write( + self.writer.write, + mask=self.is_client, + ) # Backport of the combined logic of: # https://github.com/python/asyncio/pull/280 diff --git a/websockets/test_framing.py b/websockets/test_framing.py index 0da5bb620..a77332180 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -21,8 +21,8 @@ def decode(self, message, mask=False, max_size=None, extensions=()): self.stream = asyncio.StreamReader(loop=self.loop) self.stream.feed_data(message) self.stream.feed_eof() - frame = self.loop.run_until_complete(read_frame( - self.stream.readexactly, mask, + frame = self.loop.run_until_complete(Frame.read( + self.stream.readexactly, mask=mask, max_size=max_size, extensions=extensions)) # Make sure all the data was consumed. self.assertTrue(self.stream.at_eof()) @@ -30,7 +30,7 @@ def decode(self, message, mask=False, max_size=None, extensions=()): def encode(self, frame, mask=False, extensions=()): writer = unittest.mock.Mock() - write_frame(frame, writer, mask, extensions=extensions) + frame.write(writer, mask=mask, extensions=extensions) # Ensure the entire frame is sent with a single call to writer(). # Multiple calls cause TCP fragmentation and degrade performance. self.assertEqual(writer.call_count, 1) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 9c8d9d189..850f80e12 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -108,7 +108,7 @@ def receive_frame(self, frame): """ writer = self.protocol.data_received mask = not self.protocol.is_client - self.loop.call_soon(write_frame, frame, writer, mask) + self.loop.call_soon(functools.partial(frame.write, writer, mask=mask)) def receive_eof(self): """ @@ -216,8 +216,8 @@ def last_sent_frame(self): if stream.at_eof(): frame = None else: - frame = self.loop.run_until_complete(read_frame( - stream.readexactly, self.protocol.is_client)) + frame = self.loop.run_until_complete(Frame.read( + stream.readexactly, mask=self.protocol.is_client)) if not stream.at_eof(): # pragma: no cover data = self.loop.run_until_complete(stream.read()) From c098d4463496e5b4ff9ac841aed0325ab3754a1c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 4 May 2017 09:15:06 +0200 Subject: [PATCH 0277/1539] Negligible optimization. --- websockets/framing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index 7aa4bfba5..69a00f0e1 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -29,8 +29,8 @@ 'Frame', 'parse_close', 'serialize_close' ] -OP_CONT, OP_TEXT, OP_BINARY = range(0x00, 0x03) -OP_CLOSE, OP_PING, OP_PONG = range(0x08, 0x0b) +DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 +CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG = 0x08, 0x09, 0x0a CLOSE_CODES = { 1000: "OK", @@ -214,9 +214,9 @@ def check(frame): if frame.rsv1 or frame.rsv2 or frame.rsv3: raise WebSocketProtocolError("Reserved bits must be 0") - if frame.opcode in [OP_CONT, OP_TEXT, OP_BINARY]: + if frame.opcode in DATA_OPCODES: return - elif frame.opcode in [OP_CLOSE, OP_PING, OP_PONG]: + elif frame.opcode in CTRL_OPCODES: if len(frame.data) > 125: raise WebSocketProtocolError("Control frame too long") if not frame.fin: From 53c0890a2bb58a93a965fd2b76a3ee2c40a3af3d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 May 2016 22:27:17 +0200 Subject: [PATCH 0278/1539] Add framework for extensions in handshake. --- websockets/client.py | 81 +++++++++++++++++++++++++++++------- websockets/server.py | 98 +++++++++++++++++++++++++++++++++----------- 2 files changed, 140 insertions(+), 39 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 544223e43..94d8eb903 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -27,6 +27,15 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): is_client = True state = CONNECTING + def __init__(self, *, + origin=None, extensions=None, subprotocols=None, + extra_headers=None, **kwds): + self.origin = origin + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + super().__init__(**kwds) + @asyncio.coroutine def write_http_request(self, path, headers): """ @@ -69,28 +78,55 @@ def read_http_response(self): return status_code, self.response_headers - def process_subprotocol(self, get_header, subprotocols=None): + def process_extensions(self, get_header, available_extensions=None): + """ + Handle the Sec-WebSocket-Extensions HTTP response header. + + """ + extensions = get_header('Sec-WebSocket-Extensions') + if extensions: + extensions = [e.strip() for e in extensions.split(',')] + if available_extensions is None: + raise InvalidHandshake("No extensions supported.") + unsupported_extensions = [ + extension + for extension in extensions + if extension not in available_extensions + ] + if unsupported_extensions: + raise InvalidHandshake( + "Unsupported extensions: {}" + .format(', '.join(unsupported_extensions))) + return extensions + + def process_subprotocol(self, get_header, available_subprotocols=None): """ - Handle the Sec-WebSocket-Protocol HTTP header. + Handle the Sec-WebSocket-Protocol HTTP response header. """ subprotocol = get_header('Sec-WebSocket-Protocol') if subprotocol: - if subprotocols is None or subprotocol not in subprotocols: + if available_subprotocols is None: + raise InvalidHandshake("No subprotocols supported.") + if subprotocol not in available_subprotocols: raise InvalidHandshake( - "Unknown subprotocol: {}".format(subprotocol)) + "Unsupported subprotocol: {}".format(subprotocol)) return subprotocol @asyncio.coroutine - def handshake(self, wsuri, - origin=None, subprotocols=None, extra_headers=None): + def handshake(self, wsuri, origin=None, + available_extensions=None, available_subprotocols=None, + extra_headers=None): """ Perform the client side of the opening handshake. If provided, ``origin`` sets the Origin HTTP header. - If provided, ``subprotocols`` is a list of supported subprotocols in - order of decreasing preference. + If provided, ``available_extensions`` is a list of supported + extensions in the order in which they should be used. + + If provided, ``available_subprotocols`` is a list of supported + subprotocols in order of decreasing preference. If provided, ``extra_headers`` sets additional HTTP request headers. It must be a mapping or an iterable of (name, value) pairs. @@ -105,8 +141,12 @@ def handshake(self, wsuri, set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port)) if origin is not None: set_header('Origin', origin) - if subprotocols is not None: - set_header('Sec-WebSocket-Protocol', ', '.join(subprotocols)) + if available_extensions is not None: + set_header( + 'Sec-WebSocket-Extensions', ', '.join(available_extensions)) + if available_subprotocols is not None: + set_header( + 'Sec-WebSocket-Protocol', ', '.join(available_subprotocols)) if extra_headers is not None: if isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() @@ -126,7 +166,11 @@ def handshake(self, wsuri, check_response(get_header, key) - self.subprotocol = self.process_subprotocol(get_header, subprotocols) + self.extensions = self.process_extensions( + get_header, available_extensions) + + self.subprotocol = self.process_subprotocol( + get_header, available_subprotocols) assert self.state == CONNECTING self.state = OPEN @@ -139,8 +183,8 @@ def connect(uri, *, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, - origin=None, subprotocols=None, extra_headers=None, - **kwds): + origin=None, extensions=None, subprotocols=None, + extra_headers=None, **kwds): """ This coroutine connects to a WebSocket server at a given ``uri``. @@ -169,6 +213,8 @@ def connect(uri, *, :func:`connect` also accepts the following optional arguments: * ``origin`` sets the Origin HTTP header + * ``extensions`` is a list of supported extensions in order of decreasing + preference * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference * ``extra_headers`` sets additional HTTP request headers – it can be a @@ -204,6 +250,8 @@ def connect(uri, *, timeout=timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, + origin=origin, extensions=extensions, subprotocols=subprotocols, + extra_headers=extra_headers, ) transport, protocol = yield from loop.create_connection( @@ -211,8 +259,11 @@ def connect(uri, *, try: yield from protocol.handshake( - wsuri, origin=origin, subprotocols=subprotocols, - extra_headers=extra_headers) + wsuri, origin=origin, + available_extensions=extensions, + available_subprotocols=subprotocols, + extra_headers=extra_headers, + ) except Exception: yield from protocol.close_connection(force=True) raise diff --git a/websockets/server.py b/websockets/server.py index 770609c76..86dc884b9 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -39,11 +39,13 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): state = CONNECTING def __init__(self, ws_handler, ws_server, *, - origins=None, subprotocols=None, extra_headers=None, **kwds): + origins=None, extensions=None, subprotocols=None, + extra_headers=None, **kwds): self.ws_handler = ws_handler self.ws_server = ws_server self.origins = origins - self.subprotocols = subprotocols + self.available_extensions = extensions + self.available_subprotocols = subprotocols self.extra_headers = extra_headers super().__init__(**kwds) @@ -65,8 +67,11 @@ def handler(self): try: path = yield from self.handshake( - origins=self.origins, subprotocols=self.subprotocols, - extra_headers=self.extra_headers) + origins=self.origins, + available_extensions=self.available_extensions, + available_subprotocols=self.available_subprotocols, + extra_headers=self.extra_headers, + ) except ConnectionError as exc: logger.debug( "Connection error in opening handshake", exc_info=True) @@ -236,7 +241,7 @@ def process_request(self, path, request_headers): def process_origin(self, get_header, origins=None): """ - Handle the Origin HTTP header. + Handle the Origin HTTP request header. Raise :exc:`~websockets.exceptions.InvalidOrigin` if the origin isn't acceptable. @@ -248,41 +253,77 @@ def process_origin(self, get_header, origins=None): raise InvalidOrigin("Origin not allowed: {}".format(origin)) return origin - def process_subprotocol(self, get_header, subprotocols=None): + def process_extensions(self, get_header, available_extensions=None): + """ + Handle the Sec-WebSocket-Extensions HTTP request header. + + """ + if available_extensions is not None: + extensions = get_header('Sec-WebSocket-Extensions') + if extensions: + return self.select_extensions( + [ + extension.strip() + for extension in extensions.split(',') + ], + available_extensions, + ) + + @staticmethod + def select_extensions(client_extensions, server_extensions): """ - Handle the Sec-WebSocket-Protocol HTTP header. + Pick a subprotocol among those offered by the client. """ - if subprotocols is not None: - subprotocol = get_header('Sec-WebSocket-Protocol') - if subprotocol: + return [ + extension + for extension in client_extensions + if extension in server_extensions + ] + + def process_subprotocol(self, get_header, available_subprotocols=None): + """ + Handle the Sec-WebSocket-Protocol HTTP request header. + + """ + if available_subprotocols is not None: + subprotocols = get_header('Sec-WebSocket-Protocol') + if subprotocols: return self.select_subprotocol( - [p.strip() for p in subprotocol.split(',')], - subprotocols, + [ + subprotocol.strip() + for subprotocol in subprotocols.split(',') + ], + available_subprotocols, ) @staticmethod - def select_subprotocol(client_protos, server_protos): + def select_subprotocol(client_subprotocols, server_subprotocols): """ Pick a subprotocol among those offered by the client. """ - common_protos = set(client_protos) & set(server_protos) - if not common_protos: + subprotocols = set(client_subprotocols) & set(server_subprotocols) + if not subprotocols: return None - priority = lambda p: client_protos.index(p) + server_protos.index(p) - return sorted(common_protos, key=priority)[0] + priority = lambda p: ( + client_subprotocols.index(p) + server_subprotocols.index(p)) + return sorted(subprotocols, key=priority)[0] @asyncio.coroutine - def handshake(self, origins=None, subprotocols=None, extra_headers=None): + def handshake(self, origins=None, available_extensions=None, + available_subprotocols=None, extra_headers=None): """ Perform the server side of the opening handshake. If provided, ``origins`` is a list of acceptable HTTP Origin values. Include ``''`` if the lack of an origin is acceptable. - If provided, ``subprotocols`` is a list of supported subprotocols in - order of decreasing preference. + If provided, ``available_extensions`` is a list of supported + extensions in the order in which they should be used. + + If provided, ``available_subprotocols`` is a list of supported + subprotocols in order of decreasing preference. If provided, ``extra_headers`` sets additional HTTP response headers. It can be a mapping or an iterable of (name, value) pairs. It can also @@ -308,13 +349,20 @@ def handshake(self, origins=None, subprotocols=None, extra_headers=None): key = check_request(get_header) self.origin = self.process_origin(get_header, origins) - self.subprotocol = self.process_subprotocol(get_header, subprotocols) + + self.extensions = self.process_extensions( + get_header, available_extensions) + + self.subprotocol = self.process_subprotocol( + get_header, available_subprotocols) response_headers = [] set_header = lambda k, v: response_headers.append((k, v)) set_header('Server', USER_AGENT) + if self.extensions: + set_header('Sec-WebSocket-Extensions', ', '.join(self.extensions)) if self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: @@ -441,8 +489,8 @@ def serve(ws_handler, host=None, port=None, *, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, - origins=None, subprotocols=None, extra_headers=None, - **kwds): + origins=None, extensions=None, subprotocols=None, + extra_headers=None, **kwds): """ Create, start, and return a :class:`WebSocketServer` object. @@ -484,6 +532,8 @@ def serve(ws_handler, host=None, port=None, *, * ``origins`` defines acceptable Origin HTTP headers — include ``''`` if the lack of an origin is acceptable + * ``extensions`` is a list of supported extensions in order of decreasing + preference * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference * ``extra_headers`` sets additional HTTP response headers — it can be a @@ -525,7 +575,7 @@ def serve(ws_handler, host=None, port=None, *, timeout=timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, - origins=origins, subprotocols=subprotocols, + origins=origins, extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, ) server = yield from loop.create_server(factory, host, port, **kwds) From 2cf56e9af53d8a22ca320d817c657e9bc83dd27d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 9 May 2017 15:58:07 +0200 Subject: [PATCH 0279/1539] Refactor processing of frames. --- websockets/framing.py | 15 +++++++++------ websockets/test_framing.py | 19 ++++++++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index 69a00f0e1..24a11c5c5 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -25,6 +25,7 @@ __all__ = [ + 'DATA_OPCODES', 'CTRL_OPCODES', 'OP_CONT', 'OP_TEXT', 'OP_BINARY', 'OP_CLOSE', 'OP_PING', 'OP_PONG', 'Frame', 'parse_close', 'serialize_close' ] @@ -89,8 +90,9 @@ def read(cls, reader, *, mask, max_size=None, extensions=()): If ``max_size`` is set and the payload exceeds this size in bytes, :exc:`~websockets.exceptions.PayloadTooBig` is raised. - If ``extensions`` is provided, it's a list of functions that transform - the frame and return it. They are applied in reverse order. + If ``extensions`` is provided, it's a list of classes with an + ``decode()`` method that transform the frame and return a new frame. + They are applied in order. This function validates the frame before returning it and raises :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains @@ -134,7 +136,7 @@ def read(cls, reader, *, mask, max_size=None, extensions=()): frame.check() for extension in reversed(extensions): - frame = extension(frame) + frame = extension.decode(frame) return frame @@ -149,8 +151,9 @@ def write(frame, writer, *, mask, extensions=()): ``mask`` is a :class:`bool` telling whether the frame should be masked i.e. whether the write happens on the client side. - If ``extensions`` is provided, it's a list of functions that transform - the frame and return it. They are applied in order. + If ``extensions`` is provided, it's a list of classes with an + ``encode()`` method that transform the frame and return a new frame. + They are applied in order. This function validates the frame before sending it and raises :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains @@ -162,7 +165,7 @@ def write(frame, writer, *, mask, extensions=()): # but it's the instance of class to which this method is bound. for extension in extensions: - frame = extension(frame) + frame = extension.encode(frame) frame.check() diff --git a/websockets/test_framing.py b/websockets/test_framing.py index a77332180..9aec1ea17 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -155,14 +155,19 @@ def test_serialize_close_errors(self): @unittest.skipUnless(sys.version_info[:2] >= (3, 4), "rot13 is new in 3.4") def test_extensions(self): - # This extensions is symmetrical. - def rot13(frame): - assert frame.opcode == OP_TEXT - text = frame.data.decode() - data = codecs.encode(text, 'rot13').encode() - return frame._replace(data=data) + class Rot13: + + @staticmethod + def encode(frame): + assert frame.opcode == OP_TEXT + text = frame.data.decode() + data = codecs.encode(text, 'rot13').encode() + return frame._replace(data=data) + + # This extensions is symmetrical. + decode = encode self.round_trip( b'\x81\x05uryyb', Frame(True, OP_TEXT, b'hello'), - extensions=[rot13]) + extensions=[Rot13()]) From cb6f56e70a5e1a27a85e3ea4529a6e6fed223223 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 9 May 2017 18:06:12 +0200 Subject: [PATCH 0280/1539] Stub per-message deflate extension. --- compliance/fuzzingclient.json | 2 +- compliance/fuzzingserver.json | 2 +- websockets/extensions.py | 102 ++++++++++++++++++++++++++++++++++ websockets/framing.py | 16 ++++-- websockets/protocol.py | 3 + websockets/server.py | 47 +++++++--------- websockets/test_extensions.py | 0 7 files changed, 137 insertions(+), 35 deletions(-) create mode 100644 websockets/extensions.py create mode 100644 websockets/test_extensions.py diff --git a/compliance/fuzzingclient.json b/compliance/fuzzingclient.json index 3aa408a7d..bd754de41 100644 --- a/compliance/fuzzingclient.json +++ b/compliance/fuzzingclient.json @@ -6,6 +6,6 @@ "servers": [{"agent": "websockets", "url": "ws://localhost:8642", "options": {"version": 18}}], "cases": ["*"], - "exclude-cases": [], + "exclude-cases": ["12.*", "13.*"], "exclude-agent-cases": {} } diff --git a/compliance/fuzzingserver.json b/compliance/fuzzingserver.json index 1bdb42723..6d39d86a9 100644 --- a/compliance/fuzzingserver.json +++ b/compliance/fuzzingserver.json @@ -7,6 +7,6 @@ "webport": 8080, "cases": ["*"], - "exclude-cases": [], + "exclude-cases": ["12.*", "13.*"], "exclude-agent-cases": {} } diff --git a/websockets/extensions.py b/websockets/extensions.py new file mode 100644 index 000000000..cea58109e --- /dev/null +++ b/websockets/extensions.py @@ -0,0 +1,102 @@ +import zlib + +from .framing import CTRL_OPCODES, OP_CONT + + +__all__ = ['PerMessageDeflate'] + + +_EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff' + + +class PerMessageDeflate: + """ + Compression Extensions for WebSocket (`RFC 7692`_). + + .. _RFC 7692: http://tools.ietf.org/html/rfc7692 + + """ + + # This class implements the server-side behavior by default. + # To get the client-side behavior, set is_client = True. + is_client = False + + def __init__(self): + # Currently there's no way to customize these parameters. + self.server_no_context_takeover = False + self.client_no_context_takeover = False + self.server_max_window_bits = 15 + self.client_max_window_bits = 15 + # Internal state. + self.decoder = zlib.decompressobj( + wbits=-( + self.server_max_window_bits + if self.is_client else + self.client_max_window_bits + ), + ) + self.encoder = zlib.compressobj( + wbits=-( + self.client_max_window_bits + if self.is_client else + self.server_max_window_bits + ), + ) + self.decode_cont_data = False + self.encode_cont_data = False + + def name(self): + return 'permessage-deflate' + + def decode(self, frame): + """ + Decode an incoming frame. + + """ + # Skip control frames. + if frame.opcode in CTRL_OPCODES: + return frame + # Handle continuation data frames: + # - skip if the initial data frame wasn't encoded + # - reset "decode continuation data" flag if it's a final frame + elif frame.opcode == OP_CONT: + if not self.decode_cont_data: + return frame + if frame.fin: + self.decode_cont_data = False + # Handle text and binary data frames: + # - skip if the frame isn't encoded + # - set "decode continuation data" flag if it's a non-final frame + else: + if not frame.rsv1: + return frame + if not frame.fin: # frame.rsv1 is True at this point + self.decode_cont_data = True + + # Uncompress compressed frames. + data = frame.data + if frame.fin: + data += _EMPTY_UNCOMPRESSED_BLOCK + data = self.decoder.decompress(data) + + return frame._replace(data=data, rsv1=False) + + def encode(self, frame): + """ + Encode an outgoing frame. + + """ + # Skip control frames. + if frame.opcode in CTRL_OPCODES: + return frame + + # Compress data frames. + # Since we don't do fragmentation, this is easy. + data = ( + self.encoder.compress(frame.data) + + self.encoder.flush(zlib.Z_SYNC_FLUSH) + ) + if data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): + data = data[:-4] + + return frame._replace(data=data, rsv1=frame.opcode != OP_CONT) diff --git a/websockets/framing.py b/websockets/framing.py index 24a11c5c5..5c007d6f5 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -77,7 +77,7 @@ def __new__(cls, fin, opcode, data, rsv1=False, rsv2=False, rsv3=False): @classmethod @asyncio.coroutine - def read(cls, reader, *, mask, max_size=None, extensions=()): + def read(cls, reader, *, mask, max_size=None, extensions=None): """ Read a WebSocket frame and return a :class:`Frame` object. @@ -133,14 +133,16 @@ def read(cls, reader, *, mask, max_size=None, extensions=()): frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) - frame.check() - + if extensions is None: + extensions = [] for extension in reversed(extensions): frame = extension.decode(frame) + frame.check() + return frame - def write(frame, writer, *, mask, extensions=()): + def write(frame, writer, *, mask, extensions=None): """ Write a WebSocket frame. @@ -161,14 +163,16 @@ def write(frame, writer, *, mask, extensions=()): """ + frame.check() + # The first parameter is called `frame` rather than `self`, # but it's the instance of class to which this method is bound. + if extensions is None: + extensions = [] for extension in extensions: frame = extension.encode(frame) - frame.check() - output = io.BytesIO() # Prepare the header diff --git a/websockets/protocol.py b/websockets/protocol.py index 11a1a302c..998f99f2b 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -151,6 +151,7 @@ def __init__(self, *, self.response_headers = None self.raw_response_headers = None + self.extensions = [] self.subprotocol = None # Code and reason must be set when the closing handshake completes. @@ -545,6 +546,7 @@ def read_frame(self, max_size): self.reader.readexactly, mask=not self.is_client, max_size=max_size, + extensions=self.extensions, ) logger.debug( "%s << %s", @@ -572,6 +574,7 @@ def write_frame(self, opcode, data=b''): frame.write( self.writer.write, mask=self.is_client, + extensions=self.extensions, ) # Backport of the combined logic of: diff --git a/websockets/server.py b/websockets/server.py index 86dc884b9..aa37f0f2a 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -15,6 +15,7 @@ from .exceptions import ( AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin ) +from .extensions import PerMessageDeflate from .handshake import build_response, check_request from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -258,28 +259,18 @@ def process_extensions(self, get_header, available_extensions=None): Handle the Sec-WebSocket-Extensions HTTP request header. """ - if available_extensions is not None: - extensions = get_header('Sec-WebSocket-Extensions') - if extensions: - return self.select_extensions( - [ - extension.strip() - for extension in extensions.split(',') - ], - available_extensions, - ) - - @staticmethod - def select_extensions(client_extensions, server_extensions): - """ - Pick a subprotocol among those offered by the client. - - """ - return [ - extension - for extension in client_extensions - if extension in server_extensions - ] + # TODO this doesn't allow configuring available extensions. + extensions = get_header('Sec-WebSocket-Extensions') + if extensions: + extensions = [ + extension.strip() + for extension in extensions.split(',') + ] + for extension in extensions: + extension, params = extension.split(';', 1) + if extension == 'permessage-deflate': + return [PerMessageDeflate()] + return [] def process_subprotocol(self, get_header, available_subprotocols=None): """ @@ -289,11 +280,12 @@ def process_subprotocol(self, get_header, available_subprotocols=None): if available_subprotocols is not None: subprotocols = get_header('Sec-WebSocket-Protocol') if subprotocols: + subprotocols = [ + subprotocol.strip() + for subprotocol in subprotocols.split(',') + ] return self.select_subprotocol( - [ - subprotocol.strip() - for subprotocol in subprotocols.split(',') - ], + subprotocols, available_subprotocols, ) @@ -362,7 +354,8 @@ def handshake(self, origins=None, available_extensions=None, set_header('Server', USER_AGENT) if self.extensions: - set_header('Sec-WebSocket-Extensions', ', '.join(self.extensions)) + set_header('Sec-WebSocket-Extensions', ', '.join( + extension.name() for extension in self.extensions)) if self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: diff --git a/websockets/test_extensions.py b/websockets/test_extensions.py new file mode 100644 index 000000000..e69de29bb From 3e377f3d125a00c9dcbfc351f9e8096281046644 Mon Sep 17 00:00:00 2001 From: Jonathan Martin Date: Sun, 21 May 2017 17:32:23 +0200 Subject: [PATCH 0281/1539] Support `permessage-deflate` parameters --- websockets/extensions.py | 102 +++++++++++++++++++++++++++++++-------- websockets/server.py | 4 +- 2 files changed, 84 insertions(+), 22 deletions(-) diff --git a/websockets/extensions.py b/websockets/extensions.py index cea58109e..60ddc7c5c 100644 --- a/websockets/extensions.py +++ b/websockets/extensions.py @@ -21,27 +21,57 @@ class PerMessageDeflate: # To get the client-side behavior, set is_client = True. is_client = False - def __init__(self): - # Currently there's no way to customize these parameters. - self.server_no_context_takeover = False - self.client_no_context_takeover = False - self.server_max_window_bits = 15 - self.client_max_window_bits = 15 + def __init__(self, parameter_string, *, + server_no_context_takeover=False, + client_no_context_takeover=False, + server_max_window_bits=15, + client_max_window_bits=15): + self.server_no_context_takeover = server_no_context_takeover + self.client_no_context_takeover = client_no_context_takeover + self.server_max_window_bits = server_max_window_bits + self.client_max_window_bits = client_max_window_bits + + for param in [p.strip() for p in parameter_string.split(';')]: + if param == 'server_no_context_takeover': + self.server_no_context_takeover = True + elif param == 'client_no_context_takeover': + self.client_no_context_takeover = True + elif param.startswith('client_max_window_bits'): + if '=' in param: + window_bits = int(param.split('=')[1]) + assert 8 <= window_bits <= 15 + self.server_max_window_bits = min( + window_bits, self.server_max_window_bits) + elif param.startswith('server_max_window_bits'): + assert '=' in param + window_bits = int(param.split('=')[1]) + assert 8 <= window_bits <= 15 + self.server_max_window_bits = min( + window_bits, self.server_max_window_bits) + else: + raise ValueError('invalid parameter') + # Internal state. - self.decoder = zlib.decompressobj( - wbits=-( - self.server_max_window_bits - if self.is_client else - self.client_max_window_bits - ), - ) - self.encoder = zlib.compressobj( - wbits=-( - self.client_max_window_bits - if self.is_client else - self.server_max_window_bits - ), - ) + if self.server_no_context_takeover: + self.decoder = None + else: + self.decoder = zlib.decompressobj( + wbits=-( + self.server_max_window_bits + if self.is_client else + self.client_max_window_bits + ), + ) + if self.client_no_context_takeover: + self.encoder = None + else: + self.encoder = zlib.compressobj( + wbits=-( + self.client_max_window_bits + if self.is_client else + self.server_max_window_bits + ), + ) self.decode_cont_data = False self.encode_cont_data = False @@ -73,6 +103,15 @@ def decode(self, frame): if not frame.fin: # frame.rsv1 is True at this point self.decode_cont_data = True + if self.server_no_context_takeover: + self.decoder = zlib.decompressobj( + wbits=-( + self.server_max_window_bits + if self.is_client else + self.client_max_window_bits + ), + ) + # Uncompress compressed frames. data = frame.data if frame.fin: @@ -90,6 +129,15 @@ def encode(self, frame): if frame.opcode in CTRL_OPCODES: return frame + if self.client_no_context_takeover: + self.encoder = zlib.compressobj( + wbits=-( + self.client_max_window_bits + if self.is_client else + self.server_max_window_bits + ), + ) + # Compress data frames. # Since we don't do fragmentation, this is easy. data = ( @@ -100,3 +148,17 @@ def encode(self, frame): data = data[:-4] return frame._replace(data=data, rsv1=frame.opcode != OP_CONT) + + def response(self): + response = self.name() + if self.server_no_context_takeover: + response += '; server_no_context_takeover' + if self.client_no_context_takeover: + response += '; client_no_context_takeover' + if self.client_max_window_bits < 15: + response += '; client_max_window_bits={}'.format( + self.client_max_window_bits) + if self.server_max_window_bits < 15: + response += '; server_max_window_bits={}'.format( + self.server_max_window_bits) + return response diff --git a/websockets/server.py b/websockets/server.py index aa37f0f2a..72e5d2a1f 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -269,7 +269,7 @@ def process_extensions(self, get_header, available_extensions=None): for extension in extensions: extension, params = extension.split(';', 1) if extension == 'permessage-deflate': - return [PerMessageDeflate()] + return [PerMessageDeflate(params)] return [] def process_subprotocol(self, get_header, available_subprotocols=None): @@ -355,7 +355,7 @@ def handshake(self, origins=None, available_extensions=None, if self.extensions: set_header('Sec-WebSocket-Extensions', ', '.join( - extension.name() for extension in self.extensions)) + extension.response() for extension in self.extensions)) if self.subprotocol: set_header('Sec-WebSocket-Protocol', self.subprotocol) if extra_headers is not None: From f79999fbc906ab5037fcc53df019f104689763a2 Mon Sep 17 00:00:00 2001 From: Jonathan Martin Date: Sun, 21 May 2017 23:45:06 +0200 Subject: [PATCH 0282/1539] Add parse_extensions API and make PerMessageDeflate more generic (client, server) --- compliance/fuzzingclient.json | 2 +- websockets/extensions.py | 80 ++++++++++++++++++++++++++--------- websockets/server.py | 11 ++--- websockets/test_extensions.py | 71 +++++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 29 deletions(-) diff --git a/compliance/fuzzingclient.json b/compliance/fuzzingclient.json index bd754de41..9c4dd6342 100644 --- a/compliance/fuzzingclient.json +++ b/compliance/fuzzingclient.json @@ -6,6 +6,6 @@ "servers": [{"agent": "websockets", "url": "ws://localhost:8642", "options": {"version": 18}}], "cases": ["*"], - "exclude-cases": ["12.*", "13.*"], + "exclude-cases": ["12.5.*"], "exclude-agent-cases": {} } diff --git a/websockets/extensions.py b/websockets/extensions.py index 60ddc7c5c..72f114ccc 100644 --- a/websockets/extensions.py +++ b/websockets/extensions.py @@ -3,12 +3,43 @@ from .framing import CTRL_OPCODES, OP_CONT -__all__ = ['PerMessageDeflate'] - +__all__ = ['PerMessageDeflate', 'parse_extensions'] _EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff' +def unquote_value(value): + if value and value[0] == value[-1] == '"': + return value[1:-1] + return value + + +def parse_extensions(header): + """ + Parse an extension header and return a list of extension/parameters + :param header: str + :return: [('extension name', {parameters dict}), ...] + """ + extensions = [] + header = header.replace('\n', ',') + for ext_string in header.split(','): + ext_name, *params_list = ext_string.strip().split(';') + ext_name = ext_name.strip() + if not ext_name: + # Can happen with an initial carriage return + continue + parameters = {} + for param in params_list: + if '=' in param: + param, param_value = param.split('=', 1) + param_value = unquote_value(param_value.strip()) + else: + param_value = None + parameters[param.strip()] = param_value + extensions.append((ext_name, parameters)) + return extensions + + class PerMessageDeflate: """ Compression Extensions for WebSocket (`RFC 7692`_). @@ -17,42 +48,48 @@ class PerMessageDeflate: """ - # This class implements the server-side behavior by default. - # To get the client-side behavior, set is_client = True. - is_client = False - - def __init__(self, parameter_string, *, + def __init__(self, is_client, parameters, *, server_no_context_takeover=False, client_no_context_takeover=False, server_max_window_bits=15, client_max_window_bits=15): + self.is_client = is_client self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits - for param in [p.strip() for p in parameter_string.split(';')]: + for param, value in parameters.items(): if param == 'server_no_context_takeover': + assert value is None self.server_no_context_takeover = True elif param == 'client_no_context_takeover': + assert value is None self.client_no_context_takeover = True elif param.startswith('client_max_window_bits'): - if '=' in param: - window_bits = int(param.split('=')[1]) + if value: + window_bits = int(value) assert 8 <= window_bits <= 15 - self.server_max_window_bits = min( - window_bits, self.server_max_window_bits) + window_bits = min(window_bits, self.client_max_window_bits) + self.client_max_window_bits = window_bits elif param.startswith('server_max_window_bits'): - assert '=' in param - window_bits = int(param.split('=')[1]) + assert value is not None + window_bits = int(value) assert 8 <= window_bits <= 15 - self.server_max_window_bits = min( - window_bits, self.server_max_window_bits) + window_bits = min(window_bits, self.server_max_window_bits) + self.server_max_window_bits = window_bits else: raise ValueError('invalid parameter') # Internal state. - if self.server_no_context_takeover: + if self.is_client: + self.transient_encoder = self.client_no_context_takeover + self.transient_decoder = self.server_no_context_takeover + else: + self.transient_encoder = self.server_no_context_takeover + self.transient_decoder = self.client_no_context_takeover + + if self.transient_decoder: self.decoder = None else: self.decoder = zlib.decompressobj( @@ -62,7 +99,8 @@ def __init__(self, parameter_string, *, self.client_max_window_bits ), ) - if self.client_no_context_takeover: + + if self.transient_encoder: self.encoder = None else: self.encoder = zlib.compressobj( @@ -100,10 +138,10 @@ def decode(self, frame): else: if not frame.rsv1: return frame - if not frame.fin: # frame.rsv1 is True at this point + if not frame.fin: # frame.rsv1 is True at this point self.decode_cont_data = True - if self.server_no_context_takeover: + if self.transient_decoder: self.decoder = zlib.decompressobj( wbits=-( self.server_max_window_bits @@ -129,7 +167,7 @@ def encode(self, frame): if frame.opcode in CTRL_OPCODES: return frame - if self.client_no_context_takeover: + if self.transient_encoder: self.encoder = zlib.compressobj( wbits=-( self.client_max_window_bits diff --git a/websockets/server.py b/websockets/server.py index 72e5d2a1f..7b63e2168 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -15,7 +15,7 @@ from .exceptions import ( AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin ) -from .extensions import PerMessageDeflate +from .extensions import PerMessageDeflate, parse_extensions from .handshake import build_response, check_request from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -262,14 +262,11 @@ def process_extensions(self, get_header, available_extensions=None): # TODO this doesn't allow configuring available extensions. extensions = get_header('Sec-WebSocket-Extensions') if extensions: - extensions = [ - extension.strip() - for extension in extensions.split(',') - ] + extensions = parse_extensions(extensions) for extension in extensions: - extension, params = extension.split(';', 1) + extension, params = extension if extension == 'permessage-deflate': - return [PerMessageDeflate(params)] + return [PerMessageDeflate(False, params)] return [] def process_subprotocol(self, get_header, available_subprotocols=None): diff --git a/websockets/test_extensions.py b/websockets/test_extensions.py index e69de29bb..15a0c7d5b 100644 --- a/websockets/test_extensions.py +++ b/websockets/test_extensions.py @@ -0,0 +1,71 @@ +import unittest + +from websockets.extensions import parse_extensions + + +class ExtensionParsingTests(unittest.TestCase): + def test_simple(self): + self.assert_parse_extensions('permessage-deflate', [ + ('permessage-deflate', {}) + ]) + + def test_one_extension_no_value(self): + self.assert_parse_extensions( + 'permessage-deflate; client_max_window_bits', [ + ('permessage-deflate', {'client_max_window_bits': None}) + ]) + + def test_one_extension_value(self): + self.assert_parse_extensions( + 'permessage-deflate; server_max_window_bits=10', [ + ('permessage-deflate', {'server_max_window_bits': '10'}) + ]) + + def test_one_extension_quoted_value(self): + self.assert_parse_extensions( + 'permessage-deflate; server_max_window_bits="10"', [ + ('permessage-deflate', {'server_max_window_bits': '10'}) + ]) + + def test_one_extension_multiple_params(self): + self.assert_parse_extensions( + 'permessage-deflate; option_a;option_b="10";option_c=foo', + [ + ('permessage-deflate', { + 'option_a': None, + 'option_b': '10', + 'option_c': 'foo' + }) + ]) + + def test_multi_extensions(self): + self.assert_parse_extensions( + 'ext_one; option_a;option_b="10", ext_two, ext_three; foo; bar=42', + [ + ('ext_one', { + 'option_a': None, + 'option_b': '10' + }), + ('ext_two', {}), + ('ext_three', { + 'foo': None, + 'bar': '42' + }) + ]) + + def test_multi_line(self): + self.assert_parse_extensions( + '\next_one, \next_two, \n\next_three; foo; bar=42', + [ + ('ext_one', {}), + ('ext_two', {}), + ('ext_three', { + 'foo': None, + 'bar': '42' + }) + ]) + + @staticmethod + def assert_parse_extensions(header, expected): + result = parse_extensions(header) + assert result == expected From 9ee35e1e01b2c5d3c2897609d52aeecac7505dde Mon Sep 17 00:00:00 2001 From: Jonathan Martin Date: Sun, 21 May 2017 23:45:59 +0200 Subject: [PATCH 0283/1539] Add support for client side compression --- compliance/fuzzingserver.json | 4 ++-- websockets/client.py | 33 ++++++++++++++++----------------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/compliance/fuzzingserver.json b/compliance/fuzzingserver.json index 6d39d86a9..b83bd5968 100644 --- a/compliance/fuzzingserver.json +++ b/compliance/fuzzingserver.json @@ -3,10 +3,10 @@ "url": "ws://localhost:8642", "options": {"failByDrop": false}, - "outdir": "./reports/clients", + "outdir": "./reports/servers", "webport": 8080, "cases": ["*"], - "exclude-cases": ["12.*", "13.*"], + "exclude-cases": ["12.5.*"], "exclude-agent-cases": {} } diff --git a/websockets/client.py b/websockets/client.py index 94d8eb903..652e10634 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -7,6 +7,7 @@ import collections.abc from .exceptions import InvalidHandshake, InvalidMessage, InvalidStatusCode +from .extensions import PerMessageDeflate, parse_extensions from .handshake import build_request, check_response from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -31,7 +32,12 @@ def __init__(self, *, origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds): self.origin = origin - self.available_extensions = extensions + self.available_extensions = [ + 'permessage-deflate' + '; client_no_context_takeover; client_max_window_bits' + ] + if extensions: + self.available_extensions.append(extensions) self.available_subprotocols = subprotocols self.extra_headers = extra_headers super().__init__(**kwds) @@ -85,19 +91,12 @@ def process_extensions(self, get_header, available_extensions=None): """ extensions = get_header('Sec-WebSocket-Extensions') if extensions: - extensions = [e.strip() for e in extensions.split(',')] - if available_extensions is None: - raise InvalidHandshake("No extensions supported.") - unsupported_extensions = [ - extension - for extension in extensions - if extension not in available_extensions - ] - if unsupported_extensions: - raise InvalidHandshake( - "Unsupported extensions: {}" - .format(', '.join(unsupported_extensions))) - return extensions + extensions = parse_extensions(extensions) + for extension in extensions: + extension, params = extension + if extension == 'permessage-deflate': + return [PerMessageDeflate(True, params)] + return [] def process_subprotocol(self, get_header, available_subprotocols=None): """ @@ -260,9 +259,9 @@ def connect(uri, *, try: yield from protocol.handshake( wsuri, origin=origin, - available_extensions=extensions, - available_subprotocols=subprotocols, - extra_headers=extra_headers, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, ) except Exception: yield from protocol.close_connection(force=True) From 58bc7cd2af7e90ac31f6b347a688098e2f90d81e Mon Sep 17 00:00:00 2001 From: Jonathan Martin Date: Fri, 26 May 2017 21:11:59 +0200 Subject: [PATCH 0284/1539] Add parameter to force client not to use compression --- websockets/client.py | 17 ++++++++++------- websockets/test_client_server.py | 7 +++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 652e10634..8fb678ba7 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -30,12 +30,14 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__(self, *, origin=None, extensions=None, subprotocols=None, - extra_headers=None, **kwds): + extra_headers=None, use_compression=True, **kwds): self.origin = origin - self.available_extensions = [ - 'permessage-deflate' - '; client_no_context_takeover; client_max_window_bits' - ] + self.available_extensions = [] + if use_compression: + self.available_extensions.append( + 'permessage-deflate' + '; client_no_context_takeover; client_max_window_bits' + ) if extensions: self.available_extensions.append(extensions) self.available_subprotocols = subprotocols @@ -183,7 +185,7 @@ def connect(uri, *, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, origin=None, extensions=None, subprotocols=None, - extra_headers=None, **kwds): + extra_headers=None, use_compression=True, **kwds): """ This coroutine connects to a WebSocket server at a given ``uri``. @@ -218,6 +220,7 @@ def connect(uri, *, decreasing preference * ``extra_headers`` sets additional HTTP request headers – it can be a mapping or an iterable of (name, value) pairs + * ``use_compression`` allow client to force compression to be disabled :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening @@ -250,7 +253,7 @@ def connect(uri, *, read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, origin=origin, extensions=extensions, subprotocols=subprotocols, - extra_headers=extra_headers, + extra_headers=extra_headers, use_compression=use_compression ) transport, protocol = yield from loop.create_connection( diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index ecf575c9d..8ab3742cb 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -176,6 +176,13 @@ def test_basic(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") + @with_server() + @with_client(use_compression=False) + def test_basic_no_compression(self): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + @with_server() def test_server_close_while_client_connected(self): self.start_client() From bec5517cd6169da1b387c31280e1527a3b840a04 Mon Sep 17 00:00:00 2001 From: Jonathan Martin Date: Fri, 26 May 2017 21:20:23 +0200 Subject: [PATCH 0285/1539] More deflate extension testing --- websockets/extensions.py | 2 +- websockets/test_extensions.py | 154 +++++++++++++++++++++++++++++++++- 2 files changed, 154 insertions(+), 2 deletions(-) diff --git a/websockets/extensions.py b/websockets/extensions.py index 72f114ccc..efae2f5d9 100644 --- a/websockets/extensions.py +++ b/websockets/extensions.py @@ -182,7 +182,7 @@ def encode(self, frame): self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) ) - if data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): + if data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): # pragma: no cover data = data[:-4] return frame._replace(data=data, rsv1=frame.opcode != OP_CONT) diff --git a/websockets/test_extensions.py b/websockets/test_extensions.py index 15a0c7d5b..f178cdead 100644 --- a/websockets/test_extensions.py +++ b/websockets/test_extensions.py @@ -1,6 +1,7 @@ import unittest -from websockets.extensions import parse_extensions +from .extensions import PerMessageDeflate, parse_extensions +from .framing import OP_CONT, OP_PING, OP_TEXT, Frame class ExtensionParsingTests(unittest.TestCase): @@ -69,3 +70,154 @@ def test_multi_line(self): def assert_parse_extensions(header, expected): result = parse_extensions(header) assert result == expected + + +class PerMessageDeflateTests(unittest.TestCase): + def test_deflate_default(self): + server_deflate = PerMessageDeflate(False, {}) + data = "Hello world".encode('utf-8') + + frame = Frame(True, OP_TEXT, data) + frame = server_deflate.encode(frame) + self.assertTrue(frame.rsv1) + self.assertNotEqual(frame.data, data) + + frame = server_deflate.decode(frame) + self.assertFalse(frame.rsv1) + self.assertEqual(frame.data, data) + + def test_deflate_control(self): + server_deflate = PerMessageDeflate(False, {}) + + frame = Frame(True, OP_PING, b'foo') + encoded = server_deflate.encode(frame) + self.assertEqual(frame, encoded) + + decoded = server_deflate.decode(encoded) + self.assertEqual(frame, decoded) + + def test_deflate_decode_uncompressed(self): + server_deflate = PerMessageDeflate(False, {}) + data = "Hello world".encode('utf-8') + + frame = Frame(True, OP_TEXT, data) + frame = server_deflate.decode(frame) + self.assertEqual(frame.data, data) + + def test_deflate_decode_uncompressed_fragments(self): + server_deflate = PerMessageDeflate(False, {}) + data = "Hello world".encode('utf-8') + + frame = Frame(True, OP_TEXT, data) + frag1 = server_deflate.decode( + frame._replace(fin=False, data=frame.data[:5]) + ) + frag2 = server_deflate.decode( + frame._replace(opcode=OP_CONT, data=frame.data[5:]) + ) + result = frag1.data + frag2.data + self.assertEqual(result, data) + + def test_deflate_fragment(self): + server_deflate = PerMessageDeflate(False, {}) + data = "I love websockets, especially RFC 7692".encode('utf-8') + + frame = server_deflate.encode(Frame(True, OP_TEXT, data)) + frag1 = server_deflate.decode( + frame._replace(fin=False, data=frame.data[:5]) + ) + frag2 = server_deflate.decode( + frame._replace(fin=False, rsv1=False, opcode=OP_CONT, + data=frame.data[5:10]) + ) + frag3 = server_deflate.decode( + frame._replace(rsv1=False, opcode=OP_CONT, data=frame.data[10:]) + ) + result = frag1.data + frag2.data + frag3.data + self.assertEqual(result, data) + + # Manually configured items + + def test_deflate_response_server_no_context_takeover(self): + deflate = PerMessageDeflate(False, {}, server_no_context_takeover=True) + self.assertIn('server_no_context_takeover', deflate.response()) + + def test_deflate_response_client_no_context_takeover(self): + deflate = PerMessageDeflate(False, {}, client_no_context_takeover=True) + self.assertIn('client_no_context_takeover', deflate.response()) + + def test_deflate_response_client_max_window_bits(self): + deflate = PerMessageDeflate(False, {}, client_max_window_bits=10) + self.assertIn('client_max_window_bits=10', deflate.response()) + + def test_deflate_response_server_max_window_bits(self): + deflate = PerMessageDeflate(False, {}, server_max_window_bits=8) + self.assertIn('server_max_window_bits=8', deflate.response()) + + # Taking requested params into account + + def test_deflate_server_max_window_bits_same(self): + deflate = PerMessageDeflate(False, { + 'server_max_window_bits': 10 + }, server_max_window_bits=10) + self.assertIn('server_max_window_bits=10', deflate.response()) + + def test_deflate_server_max_window_bits_higher(self): + deflate = PerMessageDeflate(False, { + 'server_max_window_bits': 12 + }, server_max_window_bits=10) + self.assertIn('server_max_window_bits=10', deflate.response()) + + def test_deflate_server_max_window_bits_lower(self): + deflate = PerMessageDeflate(False, { + 'server_max_window_bits': 8 + }, server_max_window_bits=10) + self.assertIn('server_max_window_bits=8', deflate.response()) + + def test_deflate_client_max_window_bits_same(self): + deflate = PerMessageDeflate(False, { + 'client_max_window_bits': 10 + }, client_max_window_bits=10) + self.assertIn('client_max_window_bits=10', deflate.response()) + + def test_deflate_client_max_window_bits_higher(self): + deflate = PerMessageDeflate(False, { + 'client_max_window_bits': 12 + }, client_max_window_bits=10) + self.assertIn('client_max_window_bits=10', deflate.response()) + + def test_deflate_client_max_window_bits_lower(self): + deflate = PerMessageDeflate(False, { + 'client_max_window_bits': 8 + }, client_max_window_bits=10) + self.assertIn('client_max_window_bits=8', deflate.response()) + + def test_deflate_server_no_context_takeover(self): + deflate = PerMessageDeflate(False, { + 'server_no_context_takeover': None + }) + self.assertIn('server_no_context_takeover', deflate.response()) + + def test_deflate_server_no_context_takeover_invalid(self): + with self.assertRaises(Exception): + PerMessageDeflate(False, { + 'server_no_context_takeover': 42 + }) + + def test_deflate_client_no_context_takeover(self): + deflate = PerMessageDeflate(False, { + 'client_no_context_takeover': None + }) + self.assertIn('client_no_context_takeover', deflate.response()) + + def test_deflate_client_no_context_takeover_invalid(self): + with self.assertRaises(Exception): + PerMessageDeflate(False, { + 'client_no_context_takeover': 42 + }) + + def test_deflate_invalid_parameter(self): + with self.assertRaises(Exception): + PerMessageDeflate(False, { + 'websockets_are_great': 42 + }) From 9fd534a53563903bace696c1407301310951a8b8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Jun 2017 17:23:12 +0200 Subject: [PATCH 0286/1539] Turn extensions module into a package. --- setup.py | 2 +- websockets/client.py | 3 +- websockets/extensions/__init__.py | 0 .../permessage_deflate.py} | 50 +++---------- .../test_permessage_deflate.py} | 73 +------------------ websockets/extensions/test_utils.py | 72 ++++++++++++++++++ websockets/extensions/utils.py | 33 +++++++++ websockets/server.py | 3 +- 8 files changed, 123 insertions(+), 113 deletions(-) create mode 100644 websockets/extensions/__init__.py rename websockets/{extensions.py => extensions/permessage_deflate.py} (82%) rename websockets/{test_extensions.py => extensions/test_permessage_deflate.py} (71%) create mode 100644 websockets/extensions/test_utils.py create mode 100644 websockets/extensions/utils.py diff --git a/setup.py b/setup.py index 4ec80bb64..14742973a 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ if py_version < (3, 3): raise Exception("websockets requires Python >= 3.3.") -packages = ['websockets'] +packages = ['websockets', 'websockets/extensions'] if py_version >= (3, 5): packages.append('websockets/py35') diff --git a/websockets/client.py b/websockets/client.py index 8fb678ba7..f01134b10 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -7,7 +7,8 @@ import collections.abc from .exceptions import InvalidHandshake, InvalidMessage, InvalidStatusCode -from .extensions import PerMessageDeflate, parse_extensions +from .extensions.permessage_deflate import PerMessageDeflate +from .extensions.utils import parse_extensions from .handshake import build_request, check_response from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol diff --git a/websockets/extensions/__init__.py b/websockets/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/websockets/extensions.py b/websockets/extensions/permessage_deflate.py similarity index 82% rename from websockets/extensions.py rename to websockets/extensions/permessage_deflate.py index efae2f5d9..ac572c93c 100644 --- a/websockets/extensions.py +++ b/websockets/extensions/permessage_deflate.py @@ -1,52 +1,22 @@ -import zlib - -from .framing import CTRL_OPCODES, OP_CONT +""" +The :mod:`websockets.extensions.permessage_deflate` module implements the +Compression Extensions for WebSocket as specified in `RFC 7692`_. +.. _RFC 7692: http://tools.ietf.org/html/rfc7692 -__all__ = ['PerMessageDeflate', 'parse_extensions'] +""" -_EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff' +import zlib +from ..framing import CTRL_OPCODES, OP_CONT -def unquote_value(value): - if value and value[0] == value[-1] == '"': - return value[1:-1] - return value - - -def parse_extensions(header): - """ - Parse an extension header and return a list of extension/parameters - :param header: str - :return: [('extension name', {parameters dict}), ...] - """ - extensions = [] - header = header.replace('\n', ',') - for ext_string in header.split(','): - ext_name, *params_list = ext_string.strip().split(';') - ext_name = ext_name.strip() - if not ext_name: - # Can happen with an initial carriage return - continue - parameters = {} - for param in params_list: - if '=' in param: - param, param_value = param.split('=', 1) - param_value = unquote_value(param_value.strip()) - else: - param_value = None - parameters[param.strip()] = param_value - extensions.append((ext_name, parameters)) - return extensions +__all__ = ['PerMessageDeflate'] -class PerMessageDeflate: - """ - Compression Extensions for WebSocket (`RFC 7692`_). +_EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff' - .. _RFC 7692: http://tools.ietf.org/html/rfc7692 - """ +class PerMessageDeflate: def __init__(self, is_client, parameters, *, server_no_context_takeover=False, diff --git a/websockets/test_extensions.py b/websockets/extensions/test_permessage_deflate.py similarity index 71% rename from websockets/test_extensions.py rename to websockets/extensions/test_permessage_deflate.py index f178cdead..154f0278e 100644 --- a/websockets/test_extensions.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -1,78 +1,11 @@ import unittest -from .extensions import PerMessageDeflate, parse_extensions -from .framing import OP_CONT, OP_PING, OP_TEXT, Frame - - -class ExtensionParsingTests(unittest.TestCase): - def test_simple(self): - self.assert_parse_extensions('permessage-deflate', [ - ('permessage-deflate', {}) - ]) - - def test_one_extension_no_value(self): - self.assert_parse_extensions( - 'permessage-deflate; client_max_window_bits', [ - ('permessage-deflate', {'client_max_window_bits': None}) - ]) - - def test_one_extension_value(self): - self.assert_parse_extensions( - 'permessage-deflate; server_max_window_bits=10', [ - ('permessage-deflate', {'server_max_window_bits': '10'}) - ]) - - def test_one_extension_quoted_value(self): - self.assert_parse_extensions( - 'permessage-deflate; server_max_window_bits="10"', [ - ('permessage-deflate', {'server_max_window_bits': '10'}) - ]) - - def test_one_extension_multiple_params(self): - self.assert_parse_extensions( - 'permessage-deflate; option_a;option_b="10";option_c=foo', - [ - ('permessage-deflate', { - 'option_a': None, - 'option_b': '10', - 'option_c': 'foo' - }) - ]) - - def test_multi_extensions(self): - self.assert_parse_extensions( - 'ext_one; option_a;option_b="10", ext_two, ext_three; foo; bar=42', - [ - ('ext_one', { - 'option_a': None, - 'option_b': '10' - }), - ('ext_two', {}), - ('ext_three', { - 'foo': None, - 'bar': '42' - }) - ]) - - def test_multi_line(self): - self.assert_parse_extensions( - '\next_one, \next_two, \n\next_three; foo; bar=42', - [ - ('ext_one', {}), - ('ext_two', {}), - ('ext_three', { - 'foo': None, - 'bar': '42' - }) - ]) - - @staticmethod - def assert_parse_extensions(header, expected): - result = parse_extensions(header) - assert result == expected +from ..framing import OP_CONT, OP_PING, OP_TEXT, Frame +from .permessage_deflate import * class PerMessageDeflateTests(unittest.TestCase): + def test_deflate_default(self): server_deflate = PerMessageDeflate(False, {}) data = "Hello world".encode('utf-8') diff --git a/websockets/extensions/test_utils.py b/websockets/extensions/test_utils.py new file mode 100644 index 000000000..54a54cdb6 --- /dev/null +++ b/websockets/extensions/test_utils.py @@ -0,0 +1,72 @@ +import unittest + +from .utils import * + + +class ExtensionParsingTests(unittest.TestCase): + + def test_simple(self): + self.assert_parse_extensions('permessage-deflate', [ + ('permessage-deflate', {}) + ]) + + def test_one_extension_no_value(self): + self.assert_parse_extensions( + 'permessage-deflate; client_max_window_bits', [ + ('permessage-deflate', {'client_max_window_bits': None}) + ]) + + def test_one_extension_value(self): + self.assert_parse_extensions( + 'permessage-deflate; server_max_window_bits=10', [ + ('permessage-deflate', {'server_max_window_bits': '10'}) + ]) + + def test_one_extension_quoted_value(self): + self.assert_parse_extensions( + 'permessage-deflate; server_max_window_bits="10"', [ + ('permessage-deflate', {'server_max_window_bits': '10'}) + ]) + + def test_one_extension_multiple_params(self): + self.assert_parse_extensions( + 'permessage-deflate; option_a;option_b="10";option_c=foo', + [ + ('permessage-deflate', { + 'option_a': None, + 'option_b': '10', + 'option_c': 'foo' + }) + ]) + + def test_multi_extensions(self): + self.assert_parse_extensions( + 'ext_one; option_a;option_b="10", ext_two, ext_three; foo; bar=42', + [ + ('ext_one', { + 'option_a': None, + 'option_b': '10' + }), + ('ext_two', {}), + ('ext_three', { + 'foo': None, + 'bar': '42' + }) + ]) + + def test_multi_line(self): + self.assert_parse_extensions( + '\next_one, \next_two, \n\next_three; foo; bar=42', + [ + ('ext_one', {}), + ('ext_two', {}), + ('ext_three', { + 'foo': None, + 'bar': '42' + }) + ]) + + @staticmethod + def assert_parse_extensions(header, expected): + result = parse_extensions(header) + assert result == expected diff --git a/websockets/extensions/utils.py b/websockets/extensions/utils.py new file mode 100644 index 000000000..6753c540d --- /dev/null +++ b/websockets/extensions/utils.py @@ -0,0 +1,33 @@ +__all__ = ['parse_extensions'] + + +def unquote_value(value): + if value and value[0] == value[-1] == '"': + return value[1:-1] + return value + + +def parse_extensions(header): + """ + Parse an extension header and return a list of extension/parameters + :param header: str + :return: [('extension name', {parameters dict}), ...] + """ + extensions = [] + header = header.replace('\n', ',') + for ext_string in header.split(','): + ext_name, *params_list = ext_string.strip().split(';') + ext_name = ext_name.strip() + if not ext_name: + # Can happen with an initial carriage return + continue + parameters = {} + for param in params_list: + if '=' in param: + param, param_value = param.split('=', 1) + param_value = unquote_value(param_value.strip()) + else: + param_value = None + parameters[param.strip()] = param_value + extensions.append((ext_name, parameters)) + return extensions diff --git a/websockets/server.py b/websockets/server.py index 7b63e2168..cf83f9255 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -15,7 +15,8 @@ from .exceptions import ( AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin ) -from .extensions import PerMessageDeflate, parse_extensions +from .extensions.permessage_deflate import PerMessageDeflate +from .extensions.utils import parse_extensions from .handshake import build_response, check_request from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol From 842ef415fab8c15b1136117f7a3d2d1a84336f2b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jul 2017 22:44:56 +0200 Subject: [PATCH 0287/1539] Implement correct parsing of Sec-WebSocket-Extensions. This requires a lot more code that one might expect. The previous version failed when quoted strings contained delimiters. --- websockets/client.py | 6 +- websockets/exceptions.py | 18 +++- websockets/extensions/test_utils.py | 124 +++++++++++----------- websockets/extensions/utils.py | 159 +++++++++++++++++++++++----- websockets/server.py | 7 +- 5 files changed, 215 insertions(+), 99 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index f01134b10..5e2eac92b 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -8,7 +8,7 @@ from .exceptions import InvalidHandshake, InvalidMessage, InvalidStatusCode from .extensions.permessage_deflate import PerMessageDeflate -from .extensions.utils import parse_extensions +from .extensions.utils import parse_extension_list from .handshake import build_request, check_response from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -94,11 +94,11 @@ def process_extensions(self, get_header, available_extensions=None): """ extensions = get_header('Sec-WebSocket-Extensions') if extensions: - extensions = parse_extensions(extensions) + extensions = parse_extension_list(extensions) for extension in extensions: extension, params = extension if extension == 'permessage-deflate': - return [PerMessageDeflate(True, params)] + return [PerMessageDeflate(True, dict(params))] return [] def process_subprotocol(self, get_header, available_subprotocols=None): diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 3db569564..c465c0fe7 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,7 +1,7 @@ __all__ = [ - 'AbortHandshake', 'InvalidHandshake', 'InvalidMessage', 'InvalidOrigin', - 'InvalidState', 'InvalidStatusCode', 'InvalidURI', 'ConnectionClosed', - 'PayloadTooBig', 'WebSocketProtocolError', + 'AbortHandshake', 'InvalidHandshake', 'InvalidHeader', 'InvalidMessage', + 'InvalidOrigin', 'InvalidState', 'InvalidStatusCode', 'InvalidURI', + 'ConnectionClosed', 'PayloadTooBig', 'WebSocketProtocolError', ] @@ -30,6 +30,18 @@ class InvalidMessage(InvalidHandshake): """ +class InvalidHeader(InvalidHandshake): + """ + Exception raised when a HTTP header doesn't have the expected format. + + """ + def __init__(self, message, string, pos): + self.string = string + self.pos = pos + message = '{} at {} in {}'.format(message, pos, string) + super().__init__(message) + + class InvalidOrigin(InvalidHandshake): """ Exception raised when the origin in a handshake request is forbidden. diff --git a/websockets/extensions/test_utils.py b/websockets/extensions/test_utils.py index 54a54cdb6..b135554d0 100644 --- a/websockets/extensions/test_utils.py +++ b/websockets/extensions/test_utils.py @@ -1,72 +1,68 @@ import unittest +from ..exceptions import InvalidHeader from .utils import * -class ExtensionParsingTests(unittest.TestCase): +class UtilsTests(unittest.TestCase): - def test_simple(self): - self.assert_parse_extensions('permessage-deflate', [ - ('permessage-deflate', {}) - ]) + def test_parse_extension_list(self): + for header, parsed in [ + # Synthetic examples + ( + 'foo', + [('foo', [])], + ), + ( + 'foo, bar', + [('foo', []), ('bar', [])], + ), + ( + 'foo; name; token=token; quoted-string="quoted string", ' + 'bar; quux; quuux', + [ + ('foo', [('name', None), ('token', 'token'), + ('quoted-string', 'quoted string')]), + ('bar', [('quux', None), ('quuux', None)]), + ], + ), + # Pathological examples + ( + 'a; b="q,s;1\\"2\'3\\\\4="; c="q;s,6=7\\\\8\'9\\\""', + [('a', [('b', 'q,s;1"2\'3\\4='), ('c', 'q;s,6=7\\8\'9"')])] + ), + ( + ',\t, , ,foo ;bar = 42,, baz,,', + [('foo', [('bar', '42')]), ('baz', [])], + ), + # Realistic use cases for permessage-deflate + ( + 'permessage-deflate', + [('permessage-deflate', [])], + ), + ( + 'permessage-deflate; client_max_window_bits', + [('permessage-deflate', [('client_max_window_bits', None)])], + ), + ( + 'permessage-deflate; server_max_window_bits=10', + [('permessage-deflate', [('server_max_window_bits', '10')])], + ), + ]: + self.assertEqual(parse_extension_list(header), parsed) - def test_one_extension_no_value(self): - self.assert_parse_extensions( - 'permessage-deflate; client_max_window_bits', [ - ('permessage-deflate', {'client_max_window_bits': None}) - ]) + def test_parse_extension_list_invalid_header(self): + for header in [ + # Truncated examples + '', + ',\t,' + 'foo;', + 'foo; bar;', + 'foo; bar=', + 'foo; bar="baz', + # Wrong delimiter + 'foo, bar, baz=quux; quuux', - def test_one_extension_value(self): - self.assert_parse_extensions( - 'permessage-deflate; server_max_window_bits=10', [ - ('permessage-deflate', {'server_max_window_bits': '10'}) - ]) - - def test_one_extension_quoted_value(self): - self.assert_parse_extensions( - 'permessage-deflate; server_max_window_bits="10"', [ - ('permessage-deflate', {'server_max_window_bits': '10'}) - ]) - - def test_one_extension_multiple_params(self): - self.assert_parse_extensions( - 'permessage-deflate; option_a;option_b="10";option_c=foo', - [ - ('permessage-deflate', { - 'option_a': None, - 'option_b': '10', - 'option_c': 'foo' - }) - ]) - - def test_multi_extensions(self): - self.assert_parse_extensions( - 'ext_one; option_a;option_b="10", ext_two, ext_three; foo; bar=42', - [ - ('ext_one', { - 'option_a': None, - 'option_b': '10' - }), - ('ext_two', {}), - ('ext_three', { - 'foo': None, - 'bar': '42' - }) - ]) - - def test_multi_line(self): - self.assert_parse_extensions( - '\next_one, \next_two, \n\next_three; foo; bar=42', - [ - ('ext_one', {}), - ('ext_two', {}), - ('ext_three', { - 'foo': None, - 'bar': '42' - }) - ]) - - @staticmethod - def assert_parse_extensions(header, expected): - result = parse_extensions(header) - assert result == expected + ]: + with self.assertRaises(InvalidHeader): + parse_extension_list(header) diff --git a/websockets/extensions/utils.py b/websockets/extensions/utils.py index 6753c540d..b848261be 100644 --- a/websockets/extensions/utils.py +++ b/websockets/extensions/utils.py @@ -1,33 +1,142 @@ -__all__ = ['parse_extensions'] +import re +from ..exceptions import InvalidHeader -def unquote_value(value): - if value and value[0] == value[-1] == '"': - return value[1:-1] - return value +__all__ = ['parse_extension_list'] -def parse_extensions(header): + +# To avoid a dependency on a parsing library, we implement manually the ABNF +# described in https://tools.ietf.org/html/rfc6455#section-9.1 with the +# definitions from https://tools.ietf.org/html/rfc7230#appendix-B. + +def peek_ahead(string, pos): + # We never peek more than one character ahead. + return None if pos == len(string) else string[pos] + + +_OWS_re = re.compile(r'[\t ]*') + + +def parse_OWS(string, pos): + # There's always a match, possibly empty, whose content doesn't matter. + match = _OWS_re.match(string, pos) + return match.end() + + +_token_re = re.compile(r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+') + + +def parse_token(string, pos): + match = _token_re.match(string, pos) + if match is None: + raise InvalidHeader("expected token", string=string, pos=pos) + return match.group(), match.end() + + +_quoted_string_re = re.compile( + r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"') + + +_unquote_re = re.compile(r'\\([\x09\x20-\x7e\x80-\xff])') + + +def parse_quoted_string(string, pos): + match = _quoted_string_re.match(string, pos) + if match is None: + raise InvalidHeader("expected quoted string", string=string, pos=pos) + return _unquote_re.sub(r'\1', match.group()[1:-1]), match.end() + + +def parse_extension_param(string, pos): + # Extract parameter name. + name, pos = parse_token(string, pos) + pos = parse_OWS(string, pos) + # Extract parameter string, if there is one. + if peek_ahead(string, pos) == '=': + pos = parse_OWS(string, pos + 1) + if peek_ahead(string, pos) == '"': + value, pos = parse_quoted_string(string, pos) + else: + value, pos = parse_token(string, pos) + pos = parse_OWS(string, pos) + else: + value = None + + return (name, value), pos + + +def parse_extension(string, pos): + # Extract extension name. + name, pos = parse_token(string, pos) + pos = parse_OWS(string, pos) + # Extract all parameters. + parameters = [] + while peek_ahead(string, pos) == ';': + pos = parse_OWS(string, pos + 1) + parameter, pos = parse_extension_param(string, pos) + parameters.append(parameter) + return (name, parameters), pos + + +def parse_extension_list(string, pos=0): """ - Parse an extension header and return a list of extension/parameters - :param header: str - :return: [('extension name', {parameters dict}), ...] + Parse a Sec-WebSocket-Extensions header. + + The string is assumed not to start or end with whitespace. + + The return value has the following format:: + + [ + ( + 'extension name', + [ + ('parameter name', 'parameter value'), + .... + ] + ), + ... + ] + + Parameter values are ``None`` when no value is provided. + + Raise InvalidHeader if the header cannot be parsed. + """ + # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST + # parse and ignore a reasonable number of empty list elements"; hence + # while loops that remove extra delimiters. + + # Remove extra delimiters before the first extension. + while peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + extensions = [] - header = header.replace('\n', ',') - for ext_string in header.split(','): - ext_name, *params_list = ext_string.strip().split(';') - ext_name = ext_name.strip() - if not ext_name: - # Can happen with an initial carriage return - continue - parameters = {} - for param in params_list: - if '=' in param: - param, param_value = param.split('=', 1) - param_value = unquote_value(param_value.strip()) - else: - param_value = None - parameters[param.strip()] = param_value - extensions.append((ext_name, parameters)) + while True: + # Loop invariant: an extension starts at pos in string. + extension, pos = parse_extension(string, pos) + extensions.append(extension) + + # We may have reached the end of the string. + if pos == len(string): + break + + # There must be a delimiter after each element except the last one. + if peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + else: + raise InvalidHeader("expected comma", string=string, pos=pos) + + # Remove extra delimiters before the next extension. + while peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + + # We may have reached the end of the string. + if pos == len(string): + break + + # Since we only advance in the string by one character with peek_ahead() + # or with the end position of a regex match, we can't overshoot the end. + assert pos == len(string) + return extensions diff --git a/websockets/server.py b/websockets/server.py index cf83f9255..f42bce009 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -1,4 +1,3 @@ - """ The :mod:`websockets.server` module defines a simple WebSocket server API. @@ -16,7 +15,7 @@ AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin ) from .extensions.permessage_deflate import PerMessageDeflate -from .extensions.utils import parse_extensions +from .extensions.utils import parse_extension_list from .handshake import build_response, check_request from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -263,11 +262,11 @@ def process_extensions(self, get_header, available_extensions=None): # TODO this doesn't allow configuring available extensions. extensions = get_header('Sec-WebSocket-Extensions') if extensions: - extensions = parse_extensions(extensions) + extensions = parse_extension_list(extensions) for extension in extensions: extension, params = extension if extension == 'permessage-deflate': - return [PerMessageDeflate(False, params)] + return [PerMessageDeflate(False, dict(params))] return [] def process_subprotocol(self, get_header, available_subprotocols=None): From e7773be7c69b4e52f0c20de1846ef7b7a4e3a374 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Jul 2017 15:13:10 +0200 Subject: [PATCH 0288/1539] Implement serialization of Sec-WebSocket-Extensions. --- websockets/extensions/test_utils.py | 3 +++ websockets/extensions/utils.py | 38 ++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/websockets/extensions/test_utils.py b/websockets/extensions/test_utils.py index b135554d0..2ff34734d 100644 --- a/websockets/extensions/test_utils.py +++ b/websockets/extensions/test_utils.py @@ -50,6 +50,9 @@ def test_parse_extension_list(self): ), ]: self.assertEqual(parse_extension_list(header), parsed) + # Also ensure that build_extension_list round-trips cleanly. + unparsed = build_extension_list(parsed) + self.assertEqual(parse_extension_list(unparsed), parsed) def test_parse_extension_list_invalid_header(self): for header in [ diff --git a/websockets/extensions/utils.py b/websockets/extensions/utils.py index b848261be..a6e46d9dd 100644 --- a/websockets/extensions/utils.py +++ b/websockets/extensions/utils.py @@ -3,7 +3,7 @@ from ..exceptions import InvalidHeader -__all__ = ['parse_extension_list'] +__all__ = ['build_extension_list', 'parse_extension_list'] # To avoid a dependency on a parsing library, we implement manually the ABNF @@ -140,3 +140,39 @@ def parse_extension_list(string, pos=0): assert pos == len(string) return extensions + + +_quote_re = re.compile(r'([\x22\x5c])') + + +# Workaround for the lack of re.fullmatch in older Pythons +_exact_token_re = re.compile(r'^[-!#$%&\'*+.^_`|~0-9a-zA-Z]+$') + + +def build_extension_param(name, value): + if value is None: + return name + elif _exact_token_re.match(value): + return '{}={}'.format(name, value) + else: + return '{}="{}"'.format(name, _quote_re.sub(r'\\\1', value)) + + +def build_extension(name, parameters): + return '; '.join([name] + [ + build_extension_param(name, value) + for name, value in parameters + ]) + + +def build_extension_list(extensions): + """ + Parse a Sec-WebSocket-Extensions header. + + This is the reverse of parse_extension_list. + + """ + return ', '.join( + build_extension(name, parameters) + for name, parameters in extensions + ) From 7e878e6d0e479b7df251aaa65c58922fb71e8a74 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 26 Jul 2017 23:36:59 +0200 Subject: [PATCH 0289/1539] Implement permessage-deflate negotiation. WIP - needs tests --- websockets/client.py | 118 ++++- websockets/exceptions.py | 58 ++- websockets/extensions/base.py | 92 ++++ websockets/extensions/permessage_deflate.py | 475 ++++++++++++++---- .../extensions/test_permessage_deflate.py | 280 ++++++----- websockets/extensions/utils.py | 2 +- websockets/server.py | 108 +++- websockets/test_client_server.py | 9 +- 8 files changed, 869 insertions(+), 273 deletions(-) create mode 100644 websockets/extensions/base.py diff --git a/websockets/client.py b/websockets/client.py index 5e2eac92b..4387bcce8 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -6,9 +6,11 @@ import asyncio import collections.abc -from .exceptions import InvalidHandshake, InvalidMessage, InvalidStatusCode -from .extensions.permessage_deflate import PerMessageDeflate -from .extensions.utils import parse_extension_list +from .exceptions import ( + InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError +) +from .extensions.permessage_deflate import ClientPerMessageDeflateFactory +from .extensions.utils import build_extension_list, parse_extension_list from .handshake import build_request, check_response from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -31,16 +33,9 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__(self, *, origin=None, extensions=None, subprotocols=None, - extra_headers=None, use_compression=True, **kwds): + extra_headers=None, **kwds): self.origin = origin - self.available_extensions = [] - if use_compression: - self.available_extensions.append( - 'permessage-deflate' - '; client_no_context_takeover; client_max_window_bits' - ) - if extensions: - self.available_extensions.append(extensions) + self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers super().__init__(**kwds) @@ -93,12 +88,57 @@ def process_extensions(self, get_header, available_extensions=None): """ extensions = get_header('Sec-WebSocket-Extensions') + if extensions: - extensions = parse_extension_list(extensions) - for extension in extensions: - extension, params = extension - if extension == 'permessage-deflate': - return [PerMessageDeflate(True, dict(params))] + + if available_extensions is None: + raise InvalidHandshake("No extensions supported.") + + # For each extension selected in the server response, check that + # it matches an extension in our list of available extensions. + + # RFC 6455 leaves the exact process up to the specification of + # each extension. To provide this flexibility, we tell each + # extension which extensions were accepted up to this point. + + # Such flexibility prevents us from providing any guarantees + # against reordered or duplicated extensions in the response. + # Extensions must implement ther own requirements, based on the + # list of previously accepted extensions. + + accepted_extensions = [] + + for name, response_params in parse_extension_list(extensions): + + for extension_factory in available_extensions: + + # Skip non-matching extensions based on their name. + if extension_factory.name != name: + continue + + # This is allowed to raise NegotiationError. + extension = extension_factory.process_response_params( + response_params, accepted_extensions) + + # Skip non-matching extensions based on their params. + if extension is None: + continue + + # Add matching extension to the final list. + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the server sent. Fail the connection. + else: + raise NegotiationError( + "Unsupported extension: name={}, params={}".format( + name, response_params)) + + return accepted_extensions + return [] def process_subprotocol(self, get_header, available_subprotocols=None): @@ -107,14 +147,20 @@ def process_subprotocol(self, get_header, available_subprotocols=None): """ subprotocol = get_header('Sec-WebSocket-Protocol') + if subprotocol: + if available_subprotocols is None: raise InvalidHandshake("No subprotocols supported.") + if subprotocol not in available_subprotocols: - raise InvalidHandshake( + raise NegotiationError( "Unsupported subprotocol: {}".format(subprotocol)) + return subprotocol + return None + @asyncio.coroutine def handshake(self, wsuri, origin=None, available_extensions=None, available_subprotocols=None, @@ -141,19 +187,30 @@ def handshake(self, wsuri, origin=None, set_header('Host', wsuri.host) else: set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port)) + if origin is not None: set_header('Origin', origin) + if available_extensions is not None: - set_header( - 'Sec-WebSocket-Extensions', ', '.join(available_extensions)) + extensions_header = build_extension_list([ + ( + extension_factory.name, + extension_factory.get_request_params(), + ) + for extension_factory in available_extensions + ]) + set_header('Sec-WebSocket-Extensions', extensions_header) + if available_subprotocols is not None: - set_header( - 'Sec-WebSocket-Protocol', ', '.join(available_subprotocols)) + protocol_header = ', '.join(available_subprotocols) + set_header('Sec-WebSocket-Protocol', protocol_header) + if extra_headers is not None: if isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: set_header(name, value) + set_header('User-Agent', USER_AGENT) key = build_request(set_header) @@ -221,7 +278,8 @@ def connect(uri, *, decreasing preference * ``extra_headers`` sets additional HTTP request headers – it can be a mapping or an iterable of (name, value) pairs - * ``use_compression`` allow client to force compression to be disabled + * ``use_compression`` is a shortcut to enable compression of messages with + the "permessage-deflate" extension; it is enabled by default :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening @@ -248,13 +306,25 @@ def connect(uri, *, elif kwds.get('ssl') is not None: raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") + + if use_compression: + if extensions is None: + extensions = [] + if not any( + extension_factory.name == ClientPerMessageDeflateFactory.name + for extension_factory in extensions + ): + extensions.append(ClientPerMessageDeflateFactory( + client_max_window_bits=True, + )) + factory = lambda: create_protocol( host=wsuri.host, port=wsuri.port, secure=wsuri.secure, timeout=timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, origin=origin, extensions=extensions, subprotocols=subprotocols, - extra_headers=extra_headers, use_compression=use_compression + extra_headers=extra_headers, ) transport, protocol = yield from loop.create_connection( diff --git a/websockets/exceptions.py b/websockets/exceptions.py index c465c0fe7..543f2579b 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,7 +1,10 @@ __all__ = [ + 'AbortHandshake', 'InvalidHandshake', 'InvalidHeader', 'InvalidMessage', - 'InvalidOrigin', 'InvalidState', 'InvalidStatusCode', 'InvalidURI', - 'ConnectionClosed', 'PayloadTooBig', 'WebSocketProtocolError', + 'InvalidOrigin', 'InvalidState', 'InvalidStatusCode', 'NegotiationError', + 'InvalidParameterName', 'InvalidParameterValue', 'DuplicateParameter', + 'InvalidURI', 'ConnectionClosed', 'PayloadTooBig', + 'WebSocketProtocolError', ] @@ -38,7 +41,7 @@ class InvalidHeader(InvalidHandshake): def __init__(self, message, string, pos): self.string = string self.pos = pos - message = '{} at {} in {}'.format(message, pos, string) + message = "{} at {} in {}".format(message, pos, string) super().__init__(message) @@ -58,7 +61,48 @@ class InvalidStatusCode(InvalidHandshake): """ def __init__(self, status_code): self.status_code = status_code - message = 'Status code not 101: {}'.format(status_code) + message = "Status code not 101: {}".format(status_code) + super().__init__(message) + + +class NegotiationError(InvalidHandshake): + """ + Exception raised when negociating an extension fails. + + """ + + +class InvalidParameterName(NegotiationError): + """ + Exception raised when a parameter name in an extension header is invalid. + + """ + def __init__(self, name): + self.name = name + message = "Invalid parameter name: {}".format(name) + super().__init__(message) + + +class InvalidParameterValue(NegotiationError): + """ + Exception raised when a parameter value in an extension header is invalid. + + """ + def __init__(self, name, value): + self.name = name + self.value = value + message = "Invalid value for parameter {}: {}".format(name, value) + super().__init__(message) + + +class DuplicateParameter(NegotiationError): + """ + Exception raised when a parameter name is repeated in an extension header. + + """ + def __init__(self, name): + self.name = name + message = "Duplicate parameter: {}".format(name) super().__init__(message) @@ -80,9 +124,9 @@ class ConnectionClosed(InvalidState): def __init__(self, code, reason): self.code = code self.reason = reason - message = 'WebSocket connection is closed: ' - message += 'code = {}, '.format(code) if code else 'no code, ' - message += 'reason = {}.'.format(reason) if reason else 'no reason.' + message = "WebSocket connection is closed: " + message += "code = {}, ".format(code) if code else "no code, " + message += "reason = {}.".format(reason) if reason else "no reason." super().__init__(message) diff --git a/websockets/extensions/base.py b/websockets/extensions/base.py new file mode 100644 index 000000000..679646ec5 --- /dev/null +++ b/websockets/extensions/base.py @@ -0,0 +1,92 @@ +""" +The :mod:`websockets.extensions.base` defines abstract classes for extensions. + +See https://tools.ietf.org/html/rfc6455#section-9. + +""" + + +class ClientExtensionFactory: + """ + Abstract class for client-side extension factories. + + Extension factories handle configuration and negotiation. + + """ + name = ... + + def get_request_params(self): + """ + Build request parameters. + + Return a list of (name, value) pairs. + + """ + + def process_response_params(self, params, accepted_extensions): + """" + Process response parameters. + + ``params`` are a list of (name, value) pairs. + + ``accepted_extensions`` is a list of previously accepted extensions, + represented by extension instances. + + Return an extension instance (an instance of a subclass of + :class:`Extension`) to accept this response or ``None`` to reject it. + + Raise :exc:`~websockets.exceptions.NegotiationError` to abort the + handshake and fail the WebSocket connection. + + """ + + +class ServerExtensionFactory: + """ + Abstract class for server-side extension factories. + + Extension factories handle configuration and negotiation. + + """ + name = ... + + def process_request_params(self, params, accepted_extensions): + """" + Process request parameters. + + ``accepted_extensions`` is a list of previously accepted extensions, + represented by extension instances. + + Return response params and an extension instance to accept this + extension or ``None, None`` to reject it. + + Return response params (a list of (name, value) pairs) and an + extension instance (an instance of a subclass of :class:`Extension`) + to accept this response or ``None, None`` to reject it. + + Raise :exc:`~websockets.exceptions.NegotiationError` to abort the + handshake and fail the websocket connection. + + """ + + +class Extension: + """ + Abstract class for extensions. + + """ + name = ... + + def decode(self, frame): + """ + Decode an incoming frame. + + """ + return frame + + def encode(self, frame): + """ + Encode an outgoing frame. + + """ + return frame diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index ac572c93c..38ae098df 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -8,6 +8,10 @@ import zlib +from ..exceptions import ( + DuplicateParameter, InvalidParameterName, InvalidParameterValue, + NegotiationError +) from ..framing import CTRL_OPCODES, OP_CONT @@ -15,76 +19,390 @@ _EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff' +_MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)] -class PerMessageDeflate: - def __init__(self, is_client, parameters, *, - server_no_context_takeover=False, - client_no_context_takeover=False, - server_max_window_bits=15, - client_max_window_bits=15): - self.is_client = is_client +def _build_parameters( + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits, +): + params = [] + if server_no_context_takeover: + params.append(('server_no_context_takeover', None)) + if client_no_context_takeover: + params.append(('client_no_context_takeover', None)) + if server_max_window_bits: + params.append(('server_max_window_bits', str(server_max_window_bits))) + if client_max_window_bits is True: # only in handshake requests + params.append(('client_max_window_bits', None)) + elif client_max_window_bits: + params.append(('client_max_window_bits', str(client_max_window_bits))) + return params + + +def _extract_parameters(params, *, is_server): + server_no_context_takeover = False + client_no_context_takeover = False + server_max_window_bits = None + client_max_window_bits = None + + for name, value in params: + + if name == 'server_no_context_takeover': + if server_no_context_takeover: + raise DuplicateParameter(name) + if value is None: + server_no_context_takeover = True + else: + raise InvalidParameterValue(name, value) + + elif name == 'client_no_context_takeover': + if client_no_context_takeover: + raise DuplicateParameter(name) + if value is None: + client_no_context_takeover = True + else: + raise InvalidParameterValue(name, value) + + elif name == 'server_max_window_bits': + if server_max_window_bits is not None: + raise DuplicateParameter(name) + if value in _MAX_WINDOW_BITS_VALUES: + server_max_window_bits = int(value) + else: + raise InvalidParameterValue(name, value) + + elif name == 'client_max_window_bits': + if client_max_window_bits is not None: + raise DuplicateParameter(name) + if is_server and value is None: # only in handshake responses + client_max_window_bits = True + elif value in _MAX_WINDOW_BITS_VALUES: + client_max_window_bits = int(value) + else: + raise InvalidParameterValue(name, value) + + else: + raise InvalidParameterName(name) + + return ( + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits, + ) + + +class ClientPerMessageDeflateFactory: + """ + Client-side extension factory for permessage-deflate extension. + + """ + name = 'permessage-deflate' + + def __init__( + self, + server_no_context_takeover=False, + client_no_context_takeover=False, + server_max_window_bits=None, + client_max_window_bits=None, + ): + """ + Configure permessage-deflate extension factory. + + See https://tools.ietf.org/html/rfc7692#section-7.1. + + """ + if not (server_max_window_bits is None or + 8 <= server_max_window_bits <= 15): + raise ValueError("server_max_window_bits must be between 8 and 15") + if not (client_max_window_bits is None or + client_max_window_bits is True or + 8 <= client_max_window_bits <= 15): + raise ValueError("client_max_window_bits must be between 8 and 15") + self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits - for param, value in parameters.items(): - if param == 'server_no_context_takeover': - assert value is None - self.server_no_context_takeover = True - elif param == 'client_no_context_takeover': - assert value is None - self.client_no_context_takeover = True - elif param.startswith('client_max_window_bits'): - if value: - window_bits = int(value) - assert 8 <= window_bits <= 15 - window_bits = min(window_bits, self.client_max_window_bits) - self.client_max_window_bits = window_bits - elif param.startswith('server_max_window_bits'): - assert value is not None - window_bits = int(value) - assert 8 <= window_bits <= 15 - window_bits = min(window_bits, self.server_max_window_bits) - self.server_max_window_bits = window_bits - else: - raise ValueError('invalid parameter') + def get_request_params(self): + """ + Build request parameters. + + """ + return _build_parameters( + self.server_no_context_takeover, self.client_no_context_takeover, + self.server_max_window_bits, self.client_max_window_bits, + ) + + def process_response_params(self, params, accepted_extensions): + """" + Process response parameters. + + Return an extension instance. + + """ + # Request parameters are available in instance variables. + + # Load response parameters in local variables. + ( + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits, + ) = _extract_parameters(params, is_server=False) + + # After comparing the request and the response, the final + # configuration must be available in the local variables. + + # server_no_context_takeover + # + # Req. Resp. Result + # ------ ------ -------------------------------------------------- + # False False False + # False True True + # True False Error! + # True True True + + if self.server_no_context_takeover: + if not server_no_context_takeover: + raise NegotiationError("Expected server_no_context_takeover") + + # client_no_context_takeover + # + # Req. Resp. Result + # ------ ------ -------------------------------------------------- + # False False False + # False True True + # True False True - must change value + # True True True + + if self.client_no_context_takeover: + if not client_no_context_takeover: + client_no_context_takeover = True + + # server_max_window_bits + + # Req. Resp. Result + # ------ ------ -------------------------------------------------- + # None None None + # None 8≤M≤15 M + # 8≤N≤15 None Error! + # 8≤N≤15 8≤M≤N M + # 8≤N≤15 N self.server_max_window_bits: + raise NegotiationError("Unsupported server_max_window_bits") + + # client_max_window_bits + + # Req. Resp. Result + # ------ ------ -------------------------------------------------- + # None None None + # None 8≤M≤15 Error! + # True None None + # True 8≤M≤15 M + # 8≤N≤15 None N - must change value + # 8≤N≤15 8≤M≤N M + # 8≤N≤15 N"M≤15 Error! + + if self.client_max_window_bits is None: + if client_max_window_bits is not None: + raise NegotiationError("Unexpected client_max_window_bits") + + elif self.client_max_window_bits is True: + pass - if self.transient_decoder: - self.decoder = None else: - self.decoder = zlib.decompressobj( - wbits=-( - self.server_max_window_bits - if self.is_client else - self.client_max_window_bits - ), - ) + if client_max_window_bits is None: + client_max_window_bits = self.client_max_window_bits + elif client_max_window_bits > self.client_max_window_bits: + raise NegotiationError("Unsupported client_max_window_bits") + + return PerMessageDeflate( + server_no_context_takeover, # remote_no_context_takeover + client_no_context_takeover, # local_no_context_takeover + server_max_window_bits or 15, # remote_max_window_bits + client_max_window_bits or 15, # local_max_window_bits + ) + + +class ServerPerMessageDeflateFactory: + """ + Server-side extension factory for permessage-deflate extension. + + """ + name = 'permessage-deflate' + + def __init__( + self, + server_no_context_takeover=False, + client_no_context_takeover=False, + server_max_window_bits=None, + client_max_window_bits=None, + ): + """ + Configure permessage-deflate extension factory. + + See https://tools.ietf.org/html/rfc7692#section-7.1. + + """ + if not (server_max_window_bits is None or + 8 <= server_max_window_bits <= 15): + raise ValueError("server_max_window_bits must be between 8 and 15") + if not (client_max_window_bits is None or + 8 <= client_max_window_bits <= 15): + raise ValueError("client_max_window_bits must be between 8 and 15") + + self.server_no_context_takeover = server_no_context_takeover + self.client_no_context_takeover = client_no_context_takeover + self.server_max_window_bits = server_max_window_bits + self.client_max_window_bits = client_max_window_bits + + def process_request_params(self, params, accepted_extensions): + """" + Process request parameters. + + Return response params and an extension instance. + + """ + # Load request parameters in local variables. + ( + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits, + ) = _extract_parameters(params, is_server=True) + + # Configuration parameters are available in instance variables. + + # After comparing the request and the configuration, the response must + # be available in the local variables. + + # server_no_context_takeover + # + # Config Req. Resp. + # ------ ------ -------------------------------------------------- + # False False False + # False True True + # True False True - must change value to True + # True True True + + if self.server_no_context_takeover: + if not server_no_context_takeover: + server_no_context_takeover = True + + # client_no_context_takeover + # + # Config Req. Resp. + # ------ ------ -------------------------------------------------- + # False False False + # False True True (or False) + # True False True - must change value to True + # True True True (or False) + + if self.client_no_context_takeover: + if not client_no_context_takeover: + client_no_context_takeover = True + + # server_max_window_bits + + # Config Req. Resp. + # ------ ------ -------------------------------------------------- + # None None None + # None 8≤M≤15 M + # 8≤N≤15 None N - must change value + # 8≤N≤15 8≤M≤N M + # 8≤N≤15 N self.server_max_window_bits: + server_max_window_bits = self.server_max_window_bits + + # client_max_window_bits + + # Config Req. Resp. + # ------ ------ -------------------------------------------------- + # None None None + # None True None - must change value + # None 8≤M≤15 M (or None) + # 8≤N≤15 None Error! + # 8≤N≤15 True N - must change value + # 8≤N≤15 8≤M≤N M (or None) + # 8≤N≤15 N Date: Tue, 8 Aug 2017 11:16:49 +0200 Subject: [PATCH 0290/1539] Avoid sending empty `Sec-WebSocket-Extensions` when no extension was negotiated --- websockets/extensions/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/extensions/utils.py b/websockets/extensions/utils.py index 989877389..3ba9f7148 100644 --- a/websockets/extensions/utils.py +++ b/websockets/extensions/utils.py @@ -175,4 +175,4 @@ def build_extension_list(extensions): return ', '.join( build_extension(name, parameters) for name, parameters in extensions - ) + ) if extensions else None From 79abe9961c136ae04356337bb57f609c9a41ab47 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 8 Aug 2017 18:14:27 +0200 Subject: [PATCH 0291/1539] Add test coverage for websockets.extensions.base. --- websockets/extensions/base.py | 6 ++++-- websockets/extensions/test_base.py | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 websockets/extensions/test_base.py diff --git a/websockets/extensions/base.py b/websockets/extensions/base.py index 679646ec5..453ab18ef 100644 --- a/websockets/extensions/base.py +++ b/websockets/extensions/base.py @@ -81,12 +81,14 @@ def decode(self, frame): """ Decode an incoming frame. + Return a frame. + """ - return frame def encode(self, frame): """ Encode an outgoing frame. + Return a frame. + """ - return frame diff --git a/websockets/extensions/test_base.py b/websockets/extensions/test_base.py new file mode 100644 index 000000000..9dd15c857 --- /dev/null +++ b/websockets/extensions/test_base.py @@ -0,0 +1,4 @@ +from .base import * # noqa + + +# Abstract classes don't provide any behavior to test. From 09d443971b6b126eab55d6b961880c54dabdbcdd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 9 Aug 2017 15:38:21 +0200 Subject: [PATCH 0292/1539] Add full test coverage for permessage-deflate. --- websockets/extensions/permessage_deflate.py | 27 +- .../extensions/test_permessage_deflate.py | 900 +++++++++++++++--- websockets/extensions/test_utils.py | 13 + websockets/test_client_server.py | 1 - 4 files changed, 796 insertions(+), 145 deletions(-) diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index 38ae098df..91df654bf 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -15,7 +15,11 @@ from ..framing import CTRL_OPCODES, OP_CONT -__all__ = ['PerMessageDeflate'] +__all__ = [ + 'ClientPerMessageDeflateFactory', + 'ServerPerMessageDeflateFactory', + 'PerMessageDeflate', +] _EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff' @@ -213,7 +217,7 @@ def process_response_params(self, params, accepted_extensions): # True 8≤M≤15 M # 8≤N≤15 None N - must change value # 8≤N≤15 8≤M≤N M - # 8≤N≤15 N"M≤15 Error! + # 8≤N≤15 N 15 + (True, True, 15, 16), # client_max_window_bits > 15 + (False, False, True, None), # server_max_window_bits + ]: + with self.assertRaises(ValueError): + ClientPerMessageDeflateFactory(*config) + + def test_get_request_params(self): + for config, result in [ + # Test without any parameter + ( + (False, False, None, None), + [], + ), + # Test server_no_context_takeover + ( + (True, False, None, None), + [('server_no_context_takeover', None)], + ), + # Test client_no_context_takeover + ( + (False, True, None, None), + [('client_no_context_takeover', None)], + ), + # Test server_max_window_bits + ( + (False, False, 10, None), + [('server_max_window_bits', '10')], + ), + # Test client_max_window_bits + ( + (False, False, None, 10), + [('client_max_window_bits', '10')], + ), + ( + (False, False, None, True), + [('client_max_window_bits', None)], + ), + # Test all parameters together + ( + (True, True, 12, 12), + [ + ('server_no_context_takeover', None), + ('client_no_context_takeover', None), + ('server_max_window_bits', '12'), + ('client_max_window_bits', '12'), + ], + ), + ]: + factory = ClientPerMessageDeflateFactory(*config) + self.assertEqual(factory.get_request_params(), result) + + def test_process_response_params(self): + for config, response_params, result in [ + # Test without any parameter + ( + (False, False, None, None), + [], + (False, False, 15, 15), + ), + ( + (False, False, None, None), + [('unknown', None)], + InvalidParameterName, + ), + # Test server_no_context_takeover + ( + (False, False, None, None), + [('server_no_context_takeover', None)], + (True, False, 15, 15), + ), + ( + (True, False, None, None), + [], + NegotiationError, + ), + ( + (True, False, None, None), + [('server_no_context_takeover', None)], + (True, False, 15, 15), + ), + ( + (True, False, None, None), + [('server_no_context_takeover', None)] * 2, + DuplicateParameter, + ), + ( + (True, False, None, None), + [('server_no_context_takeover', '42')], + InvalidParameterValue, + ), + # Test client_no_context_takeover + ( + (False, False, None, None), + [('client_no_context_takeover', None)], + (False, True, 15, 15), + ), + ( + (False, True, None, None), + [], + (False, True, 15, 15), + ), + ( + (False, True, None, None), + [('client_no_context_takeover', None)], + (False, True, 15, 15), + ), + ( + (False, True, None, None), + [('client_no_context_takeover', None)] * 2, + DuplicateParameter, + ), + ( + (False, True, None, None), + [('client_no_context_takeover', '42')], + InvalidParameterValue, + ), + # Test server_max_window_bits + ( + (False, False, None, None), + [('server_max_window_bits', '7')], + NegotiationError, + ), + ( + (False, False, None, None), + [('server_max_window_bits', '10')], + (False, False, 10, 15), + ), + ( + (False, False, None, None), + [('server_max_window_bits', '16')], + NegotiationError, + ), + ( + (False, False, 12, None), + [], + NegotiationError, + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '10')], + (False, False, 10, 15), + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '12')], + (False, False, 12, 15), + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '13')], + NegotiationError, + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '12')] * 2, + DuplicateParameter, + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '42')], + InvalidParameterValue, + ), + # Test client_max_window_bits + ( + (False, False, None, None), + [('client_max_window_bits', '10')], + NegotiationError, + ), + ( + (False, False, None, True), + [], + (False, False, 15, 15), + ), + ( + (False, False, None, True), + [('client_max_window_bits', '7')], + NegotiationError, + ), + ( + (False, False, None, True), + [('client_max_window_bits', '10')], + (False, False, 15, 10), + ), + ( + (False, False, None, True), + [('client_max_window_bits', '16')], + NegotiationError, + ), + ( + (False, False, None, 12), + [], + (False, False, 15, 12), + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '10')], + (False, False, 15, 10), + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '12')], + (False, False, 15, 12), + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '13')], + NegotiationError, + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '12')] * 2, + DuplicateParameter, + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '42')], + InvalidParameterValue, + ), + # Test all parameters together + ( + (True, True, 12, 12), + [ + ('server_no_context_takeover', None), + ('client_no_context_takeover', None), + ('server_max_window_bits', '10'), + ('client_max_window_bits', '10'), + ], + (True, True, 10, 10), + ), + ( + (False, False, None, True), + [ + ('server_no_context_takeover', None), + ('client_no_context_takeover', None), + ('server_max_window_bits', '10'), + ('client_max_window_bits', '10'), + ], + (True, True, 10, 10), + ), + ( + (True, True, 12, 12), + [ + ('server_no_context_takeover', None), + ('server_max_window_bits', '12'), + ], + (True, True, 12, 12), + ), + ]: + factory = ClientPerMessageDeflateFactory(*config) + if isinstance(result, type) and issubclass(result, Exception): + with self.assertRaises(result): + factory.process_response_params(response_params, []) + else: + extension = factory.process_response_params( + response_params, []) + expected = PerMessageDeflate(*result) + self.assertExtensionEqual(extension, expected) + + +class ServerPerMessageDeflateFactoryTests(unittest.TestCase, + ExtensionTestsMixin): - def test_deflate_encode_decode_text_frame(self): - deflate = PerMessageDeflate(False, False, 15, 15) - data = "Hello world".encode('utf-8') - frame = Frame(True, OP_TEXT, data) + def test_name(self): + assert ServerPerMessageDeflateFactory.name == 'permessage-deflate' + + def test_init(self): + for config in [ + (False, False, 8, None), # server_max_window_bits ≥ 8 + (False, True, 15, None), # server_max_window_bits ≤ 15 + (True, False, None, 8), # client_max_window_bits ≥ 8 + (True, True, None, 15), # client_max_window_bits ≤ 15 + ]: + # This does not raise an exception. + ServerPerMessageDeflateFactory(*config) + + def test_init_error(self): + for config in [ + (False, False, 7, 8), # server_max_window_bits < 8 + (False, True, 8, 7), # client_max_window_bits < 8 + (True, False, 16, 15), # server_max_window_bits > 15 + (True, True, 15, 16), # client_max_window_bits > 15 + (False, False, None, True), # client_max_window_bits + (False, False, True, None), # server_max_window_bits + ]: + with self.assertRaises(ValueError): + ServerPerMessageDeflateFactory(*config) + + def test_process_request_params(self): + # Parameters in result appear swapped vs. config because the order is + # (remote, local) vs. (server, client). + for config, request_params, response_params, result in [ + # Test without any parameter + ( + (False, False, None, None), + [], + [], + (False, False, 15, 15), + ), + ( + (False, False, None, None), + [('unknown', None)], + None, + InvalidParameterName, + ), + # Test server_no_context_takeover + ( + (False, False, None, None), + [('server_no_context_takeover', None)], + [('server_no_context_takeover', None)], + (False, True, 15, 15), + ), + ( + (True, False, None, None), + [], + [('server_no_context_takeover', None)], + (False, True, 15, 15), + ), + ( + (True, False, None, None), + [('server_no_context_takeover', None)], + [('server_no_context_takeover', None)], + (False, True, 15, 15), + ), + ( + (True, False, None, None), + [('server_no_context_takeover', None)] * 2, + None, + DuplicateParameter, + ), + ( + (True, False, None, None), + [('server_no_context_takeover', '42')], + None, + InvalidParameterValue, + ), + # Test client_no_context_takeover + ( + (False, False, None, None), + [('client_no_context_takeover', None)], + [('client_no_context_takeover', None)], # doesn't matter + (True, False, 15, 15), + ), + ( + (False, True, None, None), + [], + [('client_no_context_takeover', None)], + (True, False, 15, 15), + ), + ( + (False, True, None, None), + [('client_no_context_takeover', None)], + [('client_no_context_takeover', None)], # doesn't matter + (True, False, 15, 15), + ), + ( + (False, True, None, None), + [('client_no_context_takeover', None)] * 2, + None, + DuplicateParameter, + ), + ( + (False, True, None, None), + [('client_no_context_takeover', '42')], + None, + InvalidParameterValue, + ), + # Test server_max_window_bits + ( + (False, False, None, None), + [('server_max_window_bits', '7')], + None, + NegotiationError, + ), + ( + (False, False, None, None), + [('server_max_window_bits', '10')], + [('server_max_window_bits', '10')], + (False, False, 15, 10), + ), + ( + (False, False, None, None), + [('server_max_window_bits', '16')], + None, + NegotiationError, + ), + ( + (False, False, 12, None), + [], + [('server_max_window_bits', '12')], + (False, False, 15, 12), + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '10')], + [('server_max_window_bits', '10')], + (False, False, 15, 10), + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '12')], + [('server_max_window_bits', '12')], + (False, False, 15, 12), + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '13')], + [('server_max_window_bits', '12')], + (False, False, 15, 12), + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '12')] * 2, + None, + DuplicateParameter, + ), + ( + (False, False, 12, None), + [('server_max_window_bits', '42')], + None, + InvalidParameterValue, + ), + # Test client_max_window_bits + ( + (False, False, None, None), + [('client_max_window_bits', None)], + [], + (False, False, 15, 15), + ), + ( + (False, False, None, None), + [('client_max_window_bits', '7')], + None, + InvalidParameterValue, + ), + ( + (False, False, None, None), + [('client_max_window_bits', '10')], + [('client_max_window_bits', '10')], # doesn't matter + (False, False, 10, 15), + ), + ( + (False, False, None, None), + [('client_max_window_bits', '16')], + None, + InvalidParameterValue, + ), + ( + (False, False, None, 12), + [], + None, + NegotiationError, + ), + ( + (False, False, None, 12), + [('client_max_window_bits', None)], + [('client_max_window_bits', '12')], + (False, False, 12, 15), + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '10')], + [('client_max_window_bits', '10')], + (False, False, 10, 15), + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '12')], + [('client_max_window_bits', '12')], # doesn't matter + (False, False, 12, 15), + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '13')], + [('client_max_window_bits', '12')], # doesn't matter + (False, False, 12, 15), + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '12')] * 2, + None, + DuplicateParameter, + ), + ( + (False, False, None, 12), + [('client_max_window_bits', '42')], + None, + InvalidParameterValue, + ), + # # Test all parameters together + ( + (True, True, 12, 12), + [ + ('server_no_context_takeover', None), + ('client_no_context_takeover', None), + ('server_max_window_bits', '10'), + ('client_max_window_bits', '10'), + ], + [ + ('server_no_context_takeover', None), + ('client_no_context_takeover', None), + ('server_max_window_bits', '10'), + ('client_max_window_bits', '10'), + ], + (True, True, 10, 10), + ), + ( + (False, False, None, None), + [ + ('server_no_context_takeover', None), + ('client_no_context_takeover', None), + ('server_max_window_bits', '10'), + ('client_max_window_bits', '10'), + ], + [ + ('server_no_context_takeover', None), + ('client_no_context_takeover', None), + ('server_max_window_bits', '10'), + ('client_max_window_bits', '10'), + ], + (True, True, 10, 10), + ), + ( + (True, True, 12, 12), + [ + ('client_max_window_bits', None), + ], + [ + ('server_no_context_takeover', None), + ('client_no_context_takeover', None), + ('server_max_window_bits', '12'), + ('client_max_window_bits', '12'), + ], + (True, True, 12, 12), + ), + ]: + factory = ServerPerMessageDeflateFactory(*config) + if isinstance(result, type) and issubclass(result, Exception): + with self.assertRaises(result): + factory.process_request_params(request_params, []) + else: + params, extension = factory.process_request_params( + request_params, []) + self.assertEqual(params, response_params) + expected = PerMessageDeflate(*result) + self.assertExtensionEqual(extension, expected) - enc_frame = deflate.encode(frame) - self.assertTrue(enc_frame.rsv1) - self.assertNotEqual(enc_frame.data, data) +class PerMessageDeflateTests(unittest.TestCase): - dec_frame = deflate.decode(enc_frame) + def setUp(self): + # Set up an instance of the permessage-deflate extension with the most + # common settings. Since the extension is symmetrical, this instance + # may be used for testing both encoding and decoding. + self.extension = PerMessageDeflate(False, False, 15, 15) - self.assertFalse(dec_frame.rsv1) - self.assertEqual(dec_frame.data, data) + def test_name(self): + assert self.extension.name == 'permessage-deflate' + + # Control frames aren't encoded or decoded. - def test_deflate_no_encode_decode_control_frame(self): - deflate = PerMessageDeflate(False, False, 15, 15) + def test_no_encode_decode_ping_frame(self): frame = Frame(True, OP_PING, b'') - enc_frame = deflate.encode(frame) - self.assertEqual(enc_frame, frame) + self.assertEqual(self.extension.encode(frame), frame) + + self.assertEqual(self.extension.decode(frame), frame) + + def test_no_encode_decode_pong_frame(self): + frame = Frame(True, OP_PONG, b'') + + self.assertEqual(self.extension.encode(frame), frame) + + self.assertEqual(self.extension.decode(frame), frame) + + def test_no_encode_decode_close_frame(self): + frame = Frame(True, OP_CLOSE, serialize_close(1000, '')) + + self.assertEqual(self.extension.encode(frame), frame) + + self.assertEqual(self.extension.decode(frame), frame) + + # Data frames are encoded and decoded. + + def test_encode_decode_text_frame(self): + frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + + enc_frame = self.extension.encode(frame) + + self.assertEqual(enc_frame, frame._replace( + rsv1=True, + data=b'JNL;\xbc\x12\x00', + )) + + dec_frame = self.extension.decode(enc_frame) - dec_frame = deflate.decode(frame) self.assertEqual(dec_frame, frame) - def test_deflate_no_decode_uncompressed_text_frame(self): - deflate = PerMessageDeflate(False, False, 15, 15) - data = "Hello world".encode('utf-8') - frame = Frame(True, OP_TEXT, data) + def test_encode_decode_binary_frame(self): + frame = Frame(True, OP_BINARY, b'tea') - dec_frame = deflate.decode(frame) + enc_frame = self.extension.encode(frame) + + self.assertEqual(enc_frame, frame._replace( + rsv1=True, + data=b'*IM\x04\x00', + )) + + dec_frame = self.extension.decode(enc_frame) self.assertEqual(dec_frame, frame) - # def test_deflate_decode_uncompressed_fragments(self): - # deflate = PerMessageDeflate(False, False, 15, 15) - # data = "Hello world".encode('utf-8') - - # frame = Frame(True, OP_TEXT, data) - # frag1 = deflate.decode( - # frame._replace(fin=False, data=frame.data[:5]) - # ) - # frag2 = deflate.decode( - # frame._replace(opcode=OP_CONT, data=frame.data[5:]) - # ) - # result = frag1.data + frag2.data - # self.assertEqual(result, data) - - # def test_deflate_fragment(self): - # deflate = PerMessageDeflate(False, False, 15, 15) - # data = "I love websockets, especially RFC 7692".encode('utf-8') - - # frame = deflate.encode(Frame(True, OP_TEXT, data)) - # frag1 = deflate.decode( - # frame._replace(fin=False, data=frame.data[:5]) - # ) - # frag2 = deflate.decode( - # frame._replace(fin=False, rsv1=False, opcode=OP_CONT, - # data=frame.data[5:10]) - # ) - # frag3 = deflate.decode( - # frame._replace(rsv1=False, opcode=OP_CONT, data=frame.data[10:]) - # ) - # result = frag1.data + frag2.data + frag3.data - # self.assertEqual(result, data) - - # # Manually configured items - - # def test_deflate_response_server_no_context_takeover(self): - # deflate = PerMessageDeflate(False, False, None, None, server_no_context_takeover=True) - # self.assertIn('server_no_context_takeover', deflate.response()) - - # def test_deflate_response_client_no_context_takeover(self): - # deflate = PerMessageDeflate(False, False, None, None, client_no_context_takeover=True) - # self.assertIn('client_no_context_takeover', deflate.response()) - - # def test_deflate_response_client_max_window_bits(self): - # deflate = PerMessageDeflate(False, False, None, None, client_max_window_bits=10) - # self.assertIn('client_max_window_bits=10', deflate.response()) - - # def test_deflate_response_server_max_window_bits(self): - # deflate = PerMessageDeflate(False, False, None, None, server_max_window_bits=8) - # self.assertIn('server_max_window_bits=8', deflate.response()) - - # # Taking requested params into account - - # def test_deflate_server_max_window_bits_same(self): - # deflate = PerMessageDeflate(False, { - # 'server_max_window_bits': 10 - # }, server_max_window_bits=10) - # self.assertIn('server_max_window_bits=10', deflate.response()) - - # def test_deflate_server_max_window_bits_higher(self): - # deflate = PerMessageDeflate(False, { - # 'server_max_window_bits': 12 - # }, server_max_window_bits=10) - # self.assertIn('server_max_window_bits=10', deflate.response()) - - # def test_deflate_server_max_window_bits_lower(self): - # deflate = PerMessageDeflate(False, { - # 'server_max_window_bits': 8 - # }, server_max_window_bits=10) - # self.assertIn('server_max_window_bits=8', deflate.response()) - - # def test_deflate_client_max_window_bits_same(self): - # deflate = PerMessageDeflate(False, { - # 'client_max_window_bits': 10 - # }, client_max_window_bits=10) - # self.assertIn('client_max_window_bits=10', deflate.response()) - - # def test_deflate_client_max_window_bits_higher(self): - # deflate = PerMessageDeflate(False, { - # 'client_max_window_bits': 12 - # }, client_max_window_bits=10) - # self.assertIn('client_max_window_bits=10', deflate.response()) - - # def test_deflate_client_max_window_bits_lower(self): - # deflate = PerMessageDeflate(False, { - # 'client_max_window_bits': 8 - # }, client_max_window_bits=10) - # self.assertIn('client_max_window_bits=8', deflate.response()) - - # def test_deflate_server_no_context_takeover(self): - # deflate = PerMessageDeflate(False, { - # 'server_no_context_takeover': None - # }) - # self.assertIn('server_no_context_takeover', deflate.response()) - - # def test_deflate_server_no_context_takeover_invalid(self): - # with self.assertRaises(Exception): - # PerMessageDeflate(False, { - # 'server_no_context_takeover': 42 - # }) - - # def test_deflate_client_no_context_takeover(self): - # deflate = PerMessageDeflate(False, { - # 'client_no_context_takeover': None - # }) - # self.assertIn('client_no_context_takeover', deflate.response()) - - # def test_deflate_client_no_context_takeover_invalid(self): - # with self.assertRaises(Exception): - # PerMessageDeflate(False, { - # 'client_no_context_takeover': 42 - # }) - - # def test_deflate_invalid_parameter(self): - # with self.assertRaises(Exception): - # PerMessageDeflate(False, { - # 'websockets_are_great': 42 - # }) + def test_encode_decode_fragmented_text_frame(self): + frame1 = Frame(False, OP_TEXT, 'café'.encode('utf-8')) + frame2 = Frame(False, OP_CONT, ' & '.encode('utf-8')) + frame3 = Frame(True, OP_CONT, 'croissants'.encode('utf-8')) + + enc_frame1 = self.extension.encode(frame1) + enc_frame2 = self.extension.encode(frame2) + enc_frame3 = self.extension.encode(frame3) + + self.assertEqual(enc_frame1, frame1._replace( + rsv1=True, + data=b'JNL;\xbc\x12\x00\x00\x00\xff\xff', + )) + self.assertEqual(enc_frame2, frame2._replace( + rsv1=True, + data=b'RPS\x00\x00\x00\x00\xff\xff', + )) + self.assertEqual(enc_frame3, frame3._replace( + rsv1=True, + data=b'J.\xca\xcf,.N\xcc+)\x06\x00', + )) + + dec_frame1 = self.extension.decode(enc_frame1) + dec_frame2 = self.extension.decode(enc_frame2) + dec_frame3 = self.extension.decode(enc_frame3) + + self.assertEqual(dec_frame1, frame1) + self.assertEqual(dec_frame2, frame2) + self.assertEqual(dec_frame3, frame3) + + def test_encode_decode_fragmented_binary_frame(self): + frame1 = Frame(False, OP_TEXT, b'tea ') + frame2 = Frame(True, OP_CONT, b'time') + + enc_frame1 = self.extension.encode(frame1) + enc_frame2 = self.extension.encode(frame2) + + self.assertEqual(enc_frame1, frame1._replace( + rsv1=True, + data=b'*IMT\x00\x00\x00\x00\xff\xff', + )) + self.assertEqual(enc_frame2, frame2._replace( + rsv1=True, + data=b'*\xc9\xccM\x05\x00', + )) + + dec_frame1 = self.extension.decode(enc_frame1) + dec_frame2 = self.extension.decode(enc_frame2) + + self.assertEqual(dec_frame1, frame1) + self.assertEqual(dec_frame2, frame2) + + def test_no_decode_text_frame(self): + frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + + # Try decoding a frame that wasn't encoded. + self.assertEqual(self.extension.decode(frame), frame) + + def test_no_decode_binary_frame(self): + frame = Frame(True, OP_TEXT, b'tea') + + # Try decoding a frame that wasn't encoded. + self.assertEqual(self.extension.decode(frame), frame) + + def test_no_decode_fragmented_text_frame(self): + frame1 = Frame(False, OP_TEXT, 'café'.encode('utf-8')) + frame2 = Frame(False, OP_CONT, ' & '.encode('utf-8')) + frame3 = Frame(True, OP_CONT, 'croissants'.encode('utf-8')) + + dec_frame1 = self.extension.decode(frame1) + dec_frame2 = self.extension.decode(frame2) + dec_frame3 = self.extension.decode(frame3) + + self.assertEqual(dec_frame1, frame1) + self.assertEqual(dec_frame2, frame2) + self.assertEqual(dec_frame3, frame3) + + def test_no_decode_fragmented_binary_frame(self): + frame1 = Frame(False, OP_TEXT, b'tea ') + frame2 = Frame(True, OP_CONT, b'time') + + dec_frame1 = self.extension.decode(frame1) + dec_frame2 = self.extension.decode(frame2) + + self.assertEqual(dec_frame1, frame1) + self.assertEqual(dec_frame2, frame2) + + def test_context_takeover(self): + frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + + enc_frame1 = self.extension.encode(frame) + enc_frame2 = self.extension.encode(frame) + + self.assertEqual(enc_frame1.data, b'JNL;\xbc\x12\x00') + self.assertEqual(enc_frame2.data, b'J\x06\x11\x00\x00') + + def test_remote_no_context_takeover(self): + # No context takeover when decoding messages. + self.extension = PerMessageDeflate(True, False, 15, 15) + + frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + + enc_frame1 = self.extension.encode(frame) + enc_frame2 = self.extension.encode(frame) + + self.assertEqual(enc_frame1.data, b'JNL;\xbc\x12\x00') + self.assertEqual(enc_frame2.data, b'J\x06\x11\x00\x00') + + dec_frame1 = self.extension.decode(enc_frame1) + self.assertEqual(dec_frame1, frame) + + with self.assertRaises(zlib.error) as exc: + self.extension.decode(enc_frame2) + self.assertIn("invalid distance too far back", str(exc.exception)) + + def test_local_no_context_takeover(self): + # No context takeover when encoding and decoding messages. + self.extension = PerMessageDeflate(True, True, 15, 15) + + frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + + enc_frame1 = self.extension.encode(frame) + enc_frame2 = self.extension.encode(frame) + + self.assertEqual(enc_frame1.data, b'JNL;\xbc\x12\x00') + self.assertEqual(enc_frame2.data, b'JNL;\xbc\x12\x00') + + dec_frame1 = self.extension.decode(enc_frame1) + dec_frame2 = self.extension.decode(enc_frame2) + + self.assertEqual(dec_frame1, frame) + self.assertEqual(dec_frame2, frame) diff --git a/websockets/extensions/test_utils.py b/websockets/extensions/test_utils.py index 2ff34734d..571752f40 100644 --- a/websockets/extensions/test_utils.py +++ b/websockets/extensions/test_utils.py @@ -69,3 +69,16 @@ def test_parse_extension_list_invalid_header(self): ]: with self.assertRaises(InvalidHeader): parse_extension_list(header) + + +class ExtensionTestsMixin: + + def assertExtensionEqual(self, extension1, extension2): + self.assertEqual(extension1.remote_no_context_takeover, + extension2.remote_no_context_takeover) + self.assertEqual(extension1.local_no_context_takeover, + extension2.local_no_context_takeover) + self.assertEqual(extension1.remote_max_window_bits, + extension2.remote_max_window_bits) + self.assertEqual(extension1.local_max_window_bits, + extension2.local_max_window_bits) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index a835ba692..ce842dd6d 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -11,7 +11,6 @@ from .client import * from .compatibility import FORBIDDEN, OK, UNAUTHORIZED -from .exceptions import ConnectionClosed, InvalidHandshake, InvalidStatusCode from .exceptions import ( ConnectionClosed, InvalidHandshake, InvalidStatusCode, NegotiationError ) From 1e81db7d07412bf521d4be99200b41169b5f4236 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Aug 2017 22:02:57 +0200 Subject: [PATCH 0293/1539] Add tests for extension negotiation. --- websockets/client.py | 4 +- websockets/extensions/permessage_deflate.py | 12 ++ websockets/server.py | 6 +- websockets/test_client_server.py | 218 ++++++++++++++++++-- 4 files changed, 223 insertions(+), 17 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 4387bcce8..b38e3fb2b 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -121,7 +121,9 @@ def process_extensions(self, get_header, available_extensions=None): response_params, accepted_extensions) # Skip non-matching extensions based on their params. - if extension is None: + # There are no tests because the only extension currently + # built in, permessage-deflate, doesn't need this feature. + if extension is None: # pragma: no cover continue # Add matching extension to the final list. diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index 91df654bf..3f7f36ee2 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -417,6 +417,18 @@ def __init__( # There's no need for self.encode_cont_data because we always encode # outgoing frames, so it would always be True. + def __repr__(self): + return 'PerMessageDeflate({})'.format(', '.join([ + 'remote_no_context_takeover={}'.format( + self.remote_no_context_takeover), + 'local_no_context_takeover={}'.format( + self.local_no_context_takeover), + 'remote_max_window_bits={}'.format( + self.remote_max_window_bits), + 'local_max_window_bits={}'.format( + self.local_max_window_bits), + ])) + def decode(self, frame): """ Decode an incoming frame. diff --git a/websockets/server.py b/websockets/server.py index 8a814bc33..d21f5e345 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -297,7 +297,9 @@ def process_extensions(self, get_header, available_extensions=None): assert (response_params is None) == (extension is None) # Skip non-matching extensions based on their params. - if extension is None: + # There are no tests because the only extension currently + # built in, permessage-deflate, doesn't need this feature. + if extension is None: # pragma: no cover continue # Add matching extension to the final list. @@ -314,7 +316,7 @@ def process_extensions(self, get_header, available_extensions=None): return extensions_header, accepted_extensions - return None, None + return None, [] def process_subprotocol(self, get_header, available_subprotocols=None): """ diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index ce842dd6d..6a6b60e06 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -14,6 +14,10 @@ from .exceptions import ( ConnectionClosed, InvalidHandshake, InvalidStatusCode, NegotiationError ) +from .extensions.permessage_deflate import ( + ClientPerMessageDeflateFactory, PerMessageDeflate, + ServerPerMessageDeflateFactory +) from .handshake import build_response from .http import USER_AGENT, read_response from .server import * @@ -37,6 +41,8 @@ def handler(ws, path): elif path == '/raw_headers': yield from ws.send(repr(ws.raw_request_headers)) yield from ws.send(repr(ws.raw_response_headers)) + elif path == '/extensions': + yield from ws.send(repr(ws.extensions)) elif path == '/subprotocol': yield from ws.send(repr(ws.subprotocol)) else: @@ -121,6 +127,50 @@ class BarClientProtocol(WebSocketClientProtocol): pass +class ClientNoOpExtensionFactory: + name = 'x-no-op' + + def __init__(self, params=None): + if params is None: + params = [] + self.params = params + + def get_request_params(self): + return self.params + + def process_response_params(self, params, accepted_extensions): + if params: + raise NegotiationError() + return NoOpExtension() + + +class ServerNoOpExtensionFactory: + name = 'x-no-op' + + def __init__(self, params=None): + if params is None: + params = [] + self.params = params + + def process_request_params(self, params, accepted_extensions): + if params: + raise NegotiationError() + return self.params, NoOpExtension() + + +class NoOpExtension: + name = 'x-no-op' + + def __repr__(self): + return 'NoOpExtension()' + + def decode(self, frame): + return frame + + def encode(self, frame): + return frame + + class ClientServerTests(unittest.TestCase): secure = False @@ -182,13 +232,6 @@ def test_basic(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - @with_server() - @with_client(use_compression=False) - def test_basic_no_compression(self): - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - @with_server() def test_server_close_while_client_connected(self): self.start_client() @@ -360,6 +403,140 @@ def test_client_klass(self): def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) + @with_server() + @with_client('extensions') + def test_no_extension(self): + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([])) + self.assertEqual(repr(self.client.extensions), repr([])) + + @with_server(extensions=[ServerNoOpExtensionFactory()]) + @with_client('extensions', extensions=[ClientNoOpExtensionFactory()]) + def test_extension(self): + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([NoOpExtension()])) + self.assertEqual(repr(self.client.extensions), repr([NoOpExtension()])) + + @with_server() + @with_client('extensions', extensions=[ClientNoOpExtensionFactory()]) + def test_extension_not_accepted(self): + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([])) + self.assertEqual(repr(self.client.extensions), repr([])) + + @with_server(extensions=[ServerNoOpExtensionFactory()]) + @with_client('extensions') + def test_extension_not_requested(self): + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([])) + self.assertEqual(repr(self.client.extensions), repr([])) + + @with_server(extensions=[ServerNoOpExtensionFactory()]) + def test_extension_server_rejection(self): + with self.assertRaises(InvalidStatusCode): + self.start_client( + 'extensions', + extensions=[ClientNoOpExtensionFactory([('foo', None)])], + ) + + @with_server(extensions=[ServerNoOpExtensionFactory([('foo', None)])]) + def test_extension_client_rejection(self): + with self.assertRaises(NegotiationError): + self.start_client( + 'extensions', + extensions=[ClientNoOpExtensionFactory()], + ) + + @with_server(extensions=[ServerPerMessageDeflateFactory()]) + @with_client('extensions', extensions=[ClientNoOpExtensionFactory()]) + def test_extension_mismatch(self): + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([])) + self.assertEqual(repr(self.client.extensions), repr([])) + + @with_server( + extensions=[ + ServerNoOpExtensionFactory(), + ServerPerMessageDeflateFactory(), + ], + ) + @with_client( + 'extensions', + extensions=[ + ClientPerMessageDeflateFactory(), + ClientNoOpExtensionFactory(), + ], + ) + def test_extension_order(self): + # The order requested by the client has priority. + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([ + PerMessageDeflate(False, False, 15, 15), + NoOpExtension(), + ])) + self.assertEqual(repr(self.client.extensions), repr([ + PerMessageDeflate(False, False, 15, 15), + NoOpExtension(), + ])) + + @with_server(extensions=[ServerNoOpExtensionFactory()]) + @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions') + def test_extensions_error(self, _process_extensions): + _process_extensions.return_value = 'x-no-op', [NoOpExtension()] + + with self.assertRaises(NegotiationError): + self.start_client( + 'extensions', + extensions=[ClientPerMessageDeflateFactory()], + ) + + @with_server(extensions=[ServerNoOpExtensionFactory()]) + @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions') + def test_extensions_error_no_extensions(self, _process_extensions): + _process_extensions.return_value = 'x-no-op', [NoOpExtension()] + + with self.assertRaises(InvalidHandshake): + self.start_client('extensions') + + @with_server(use_compression=True) + @with_client('extensions', use_compression=True) + def test_use_compression(self): + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([ + PerMessageDeflate(False, False, 15, 15), + ])) + self.assertEqual(repr(self.client.extensions), repr([ + PerMessageDeflate(False, False, 15, 15), + ])) + + @with_server( + extensions=[ + ServerPerMessageDeflateFactory( + client_no_context_takeover=True, + server_max_window_bits=10, + ), + ], + use_compression=True, # overridden by explicit config + ) + @with_client( + 'extensions', + extensions=[ + ClientPerMessageDeflateFactory( + server_no_context_takeover=True, + client_max_window_bits=12, + ), + ], + use_compression=True, # overridden by explicit config + ) + def test_use_compression_explicit_config(self): + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([ + PerMessageDeflate(True, True, 12, 10), + ])) + self.assertEqual(repr(self.client.extensions), repr([ + PerMessageDeflate(True, True, 10, 12), + ])) + @with_server() @with_client('subprotocol') def test_no_subprotocol(self): @@ -369,14 +546,14 @@ def test_no_subprotocol(self): @with_server(subprotocols=['superchat', 'chat']) @with_client('subprotocol', subprotocols=['otherchat', 'chat']) - def test_subprotocol_found(self): + def test_subprotocol(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr('chat')) self.assertEqual(self.client.subprotocol, 'chat') @with_server(subprotocols=['superchat']) @with_client('subprotocol', subprotocols=['otherchat']) - def test_subprotocol_not_found(self): + def test_subprotocol_not_accepted(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @@ -396,14 +573,23 @@ def test_subprotocol_not_requested(self): self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=['superchat']) - @unittest.mock.patch.object(WebSocketServerProtocol, 'select_subprotocol') - def test_subprotocol_error(self, _select_subprotocol): - _select_subprotocol.return_value = 'superchat' + @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') + def test_subprotocol_error(self, _process_subprotocol): + _process_subprotocol.return_value = 'superchat' with self.assertRaises(NegotiationError): self.start_client('subprotocol', subprotocols=['otherchat']) self.run_loop_once() + @with_server(subprotocols=['superchat']) + @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') + def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): + _process_subprotocol.return_value = 'superchat' + + with self.assertRaises(InvalidHandshake): + self.start_client('subprotocol') + self.run_loop_once() + @with_server() @unittest.mock.patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): @@ -573,12 +759,16 @@ def client_context(self): return ssl_context def start_server(self, *args, **kwds): - kwds['ssl'] = self.server_context + kwds.setdefault('ssl', self.server_context) + # Don't enable compression by default in tests. + kwds.setdefault('use_compression', False) server = serve(handler, 'localhost', 8642, **kwds) self.server = self.loop.run_until_complete(server) def start_client(self, path='', **kwds): - kwds['ssl'] = self.client_context + kwds.setdefault('ssl', self.client_context) + # Don't enable compression by default in tests. + kwds.setdefault('use_compression', False) client = connect('wss://localhost:8642/' + path, **kwds) self.client = self.loop.run_until_complete(client) From cb669973247f5a6df4d07128759607043f071f92 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Aug 2017 16:05:52 +0200 Subject: [PATCH 0294/1539] Require parameter values to match the token ABNF. --- websockets/extensions/test_utils.py | 11 ++++------- websockets/extensions/utils.py | 28 +++++++++++----------------- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/websockets/extensions/test_utils.py b/websockets/extensions/test_utils.py index 571752f40..811a387cc 100644 --- a/websockets/extensions/test_utils.py +++ b/websockets/extensions/test_utils.py @@ -18,19 +18,15 @@ def test_parse_extension_list(self): [('foo', []), ('bar', [])], ), ( - 'foo; name; token=token; quoted-string="quoted string", ' + 'foo; name; token=token; quoted-string="quoted-string", ' 'bar; quux; quuux', [ ('foo', [('name', None), ('token', 'token'), - ('quoted-string', 'quoted string')]), + ('quoted-string', 'quoted-string')]), ('bar', [('quux', None), ('quuux', None)]), ], ), # Pathological examples - ( - 'a; b="q,s;1\\"2\'3\\\\4="; c="q;s,6=7\\\\8\'9\\\""', - [('a', [('b', 'q,s;1"2\'3\\4='), ('c', 'q;s,6=7\\8\'9"')])] - ), ( ',\t, , ,foo ;bar = 42,, baz,,', [('foo', [('bar', '42')]), ('baz', [])], @@ -65,7 +61,8 @@ def test_parse_extension_list_invalid_header(self): 'foo; bar="baz', # Wrong delimiter 'foo, bar, baz=quux; quuux', - + # Value in quoted string parameter that isn't a token + 'foo; bar=" "', ]: with self.assertRaises(InvalidHeader): parse_extension_list(header) diff --git a/websockets/extensions/utils.py b/websockets/extensions/utils.py index 3ba9f7148..021b7fa65 100644 --- a/websockets/extensions/utils.py +++ b/websockets/extensions/utils.py @@ -26,6 +26,9 @@ def parse_OWS(string, pos): _token_re = re.compile(r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+') +# Workaround for the lack of re.fullmatch in older Pythons +_exact_token_re = re.compile(r'^[-!#$%&\'*+.^_`|~0-9a-zA-Z]+$') + def parse_token(string, pos): match = _token_re.match(string, pos) @@ -56,7 +59,13 @@ def parse_extension_param(string, pos): if peek_ahead(string, pos) == '=': pos = parse_OWS(string, pos + 1) if peek_ahead(string, pos) == '"': + pos_before = pos # for proper error reporting below value, pos = parse_quoted_string(string, pos) + # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value + # after quoted-string unescaping MUST conform to the 'token' ABNF. + if _exact_token_re.match(value) is None: + raise InvalidHeader("invalid quoted string content", + string=string, pos=pos_before) else: value, pos = parse_token(string, pos) pos = parse_OWS(string, pos) @@ -142,25 +151,10 @@ def parse_extension_list(string, pos=0): return extensions -_quote_re = re.compile(r'([\x22\x5c])') - - -# Workaround for the lack of re.fullmatch in older Pythons -_exact_token_re = re.compile(r'^[-!#$%&\'*+.^_`|~0-9a-zA-Z]+$') - - -def build_extension_param(name, value): - if value is None: - return name - elif _exact_token_re.match(value): - return '{}={}'.format(name, value) - else: - return '{}="{}"'.format(name, _quote_re.sub(r'\\\1', value)) - - def build_extension(name, parameters): return '; '.join([name] + [ - build_extension_param(name, value) + # Quoted strings aren't necessary because values are always tokens. + name if value is None else '{}={}'.format(name, value) for name, value in parameters ]) From d1c4981a82b6c255cbee75fcd11f3f2dc4c70d5f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Aug 2017 16:07:47 +0200 Subject: [PATCH 0295/1539] Stop special casing [] in build_extension_list. This makes build_extension_list and parse_extension_list more exactly the opposite of each other. Also it's more natural to handle this in process_extensions(), which looks better with a single return statement. --- websockets/extensions/utils.py | 2 +- websockets/server.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/websockets/extensions/utils.py b/websockets/extensions/utils.py index 021b7fa65..843de3ee4 100644 --- a/websockets/extensions/utils.py +++ b/websockets/extensions/utils.py @@ -169,4 +169,4 @@ def build_extension_list(extensions): return ', '.join( build_extension(name, parameters) for name, parameters in extensions - ) if extensions else None + ) diff --git a/websockets/server.py b/websockets/server.py index d21f5e345..51feac9b4 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -261,6 +261,9 @@ def process_extensions(self, get_header, available_extensions=None): """ extensions = get_header('Sec-WebSocket-Extensions') + extensions_header = [] + accepted_extensions = [] + if extensions and available_extensions is not None: # For each extension proposed in the client request, check if it @@ -277,9 +280,6 @@ def process_extensions(self, get_header, available_extensions=None): # The current implementation doesn't allow reordering extensions. - extensions_header = [] - accepted_extensions = [] - for name, request_params in parse_extension_list(extensions): for extension_factory in available_extensions: @@ -312,11 +312,13 @@ def process_extensions(self, get_header, available_extensions=None): # If we didn't break from the loop, no extension in our list # matched what the client sent. Ignore that extension. + # Serialize extension header. + if extensions_header: extensions_header = build_extension_list(extensions_header) + else: + extensions_header = None - return extensions_header, accepted_extensions - - return None, [] + return extensions_header, accepted_extensions def process_subprotocol(self, get_header, available_subprotocols=None): """ From 2907f80c8832a2d60261abc3c8e7161a06af130f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Aug 2017 11:42:13 +0200 Subject: [PATCH 0296/1539] Improve the simple configuration for compression. Make it a bit more future-proof. --- websockets/client.py | 11 +++++++---- websockets/server.py | 11 +++++++---- websockets/test_client_server.py | 26 ++++++++++++++++---------- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index b38e3fb2b..1a3368afc 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -245,7 +245,7 @@ def connect(uri, *, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, origin=None, extensions=None, subprotocols=None, - extra_headers=None, use_compression=True, **kwds): + extra_headers=None, compression='deflate', **kwds): """ This coroutine connects to a WebSocket server at a given ``uri``. @@ -280,8 +280,9 @@ def connect(uri, *, decreasing preference * ``extra_headers`` sets additional HTTP request headers – it can be a mapping or an iterable of (name, value) pairs - * ``use_compression`` is a shortcut to enable compression of messages with - the "permessage-deflate" extension; it is enabled by default + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening @@ -309,7 +310,7 @@ def connect(uri, *, raise ValueError("connect() received a SSL context for a ws:// URI. " "Use a wss:// URI to enable TLS.") - if use_compression: + if compression == 'deflate': if extensions is None: extensions = [] if not any( @@ -319,6 +320,8 @@ def connect(uri, *, extensions.append(ClientPerMessageDeflateFactory( client_max_window_bits=True, )) + elif compression is not None: + raise ValueError("Unsupported compression: {}".format(compression)) factory = lambda: create_protocol( host=wsuri.host, port=wsuri.port, secure=wsuri.secure, diff --git a/websockets/server.py b/websockets/server.py index 51feac9b4..8a600601d 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -548,7 +548,7 @@ def serve(ws_handler, host=None, port=None, *, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, origins=None, extensions=None, subprotocols=None, - extra_headers=None, use_compression=True, **kwds): + extra_headers=None, compression='deflate', **kwds): """ Create, start, and return a :class:`WebSocketServer` object. @@ -597,8 +597,9 @@ def serve(ws_handler, host=None, port=None, *, * ``extra_headers`` sets additional HTTP response headers — it can be a mapping, an iterable of (name, value) pairs, or a callable taking the request path and headers in arguments. - * ``use_compression`` is a shortcut to enable compression of messages with - the "permessage-deflate" extension; it is enabled by default + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression Whenever a client connects, the server accepts the connection, creates a :class:`WebSocketServerProtocol`, performs the opening handshake, and @@ -630,7 +631,7 @@ def serve(ws_handler, host=None, port=None, *, secure = kwds.get('ssl') is not None - if use_compression: + if compression == 'deflate': if extensions is None: extensions = [] if not any( @@ -638,6 +639,8 @@ def serve(ws_handler, host=None, port=None, *, for extension_factory in extensions ): extensions.append(ServerPerMessageDeflateFactory()) + elif compression is not None: + raise ValueError("Unsupported compression: {}".format(compression)) factory = lambda: create_protocol( ws_handler, ws_server, diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 6a6b60e06..c68d00ebd 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -190,13 +190,13 @@ def run_loop_once(self): def start_server(self, **kwds): # Don't enable compression by default in tests. - kwds.setdefault('use_compression', False) + kwds.setdefault('compression', None) server = serve(handler, 'localhost', 8642, **kwds) self.server = self.loop.run_until_complete(server) def start_client(self, path='', **kwds): # Don't enable compression by default in tests. - kwds.setdefault('use_compression', False) + kwds.setdefault('compression', None) client = connect('ws://localhost:8642/' + path, **kwds) self.client = self.loop.run_until_complete(client) @@ -498,9 +498,9 @@ def test_extensions_error_no_extensions(self, _process_extensions): with self.assertRaises(InvalidHandshake): self.start_client('extensions') - @with_server(use_compression=True) - @with_client('extensions', use_compression=True) - def test_use_compression(self): + @with_server(compression='deflate') + @with_client('extensions', compression='deflate') + def test_compression_deflate(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([ PerMessageDeflate(False, False, 15, 15), @@ -516,7 +516,7 @@ def test_use_compression(self): server_max_window_bits=10, ), ], - use_compression=True, # overridden by explicit config + compression='deflate', # overridden by explicit config ) @with_client( 'extensions', @@ -526,9 +526,9 @@ def test_use_compression(self): client_max_window_bits=12, ), ], - use_compression=True, # overridden by explicit config + compression='deflate', # overridden by explicit config ) - def test_use_compression_explicit_config(self): + def test_compression_deflate_and_explicit_config(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([ PerMessageDeflate(True, True, 12, 10), @@ -537,6 +537,12 @@ def test_use_compression_explicit_config(self): PerMessageDeflate(True, True, 10, 12), ])) + def test_compression_unsupported(self): + with self.assertRaises(ValueError): + self.loop.run_until_complete(self.start_server(compression='xz')) + with self.assertRaises(ValueError): + self.loop.run_until_complete(self.start_client(compression='xz')) + @with_server() @with_client('subprotocol') def test_no_subprotocol(self): @@ -761,14 +767,14 @@ def client_context(self): def start_server(self, *args, **kwds): kwds.setdefault('ssl', self.server_context) # Don't enable compression by default in tests. - kwds.setdefault('use_compression', False) + kwds.setdefault('compression', None) server = serve(handler, 'localhost', 8642, **kwds) self.server = self.loop.run_until_complete(server) def start_client(self, path='', **kwds): kwds.setdefault('ssl', self.client_context) # Don't enable compression by default in tests. - kwds.setdefault('use_compression', False) + kwds.setdefault('compression', None) client = connect('wss://localhost:8642/' + path, **kwds) self.client = self.loop.run_until_complete(client) From 89725778f6f6ff43dac49aadd5c600fc683d35c3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Aug 2017 12:17:32 +0200 Subject: [PATCH 0297/1539] Make it possible to tweak compression settings. --- websockets/extensions/permessage_deflate.py | 24 +++++++++++++++++-- .../extensions/test_permessage_deflate.py | 19 +++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index 3f7f36ee2..ce8bfb528 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -112,6 +112,7 @@ def __init__( client_no_context_takeover=False, server_max_window_bits=None, client_max_window_bits=None, + compress_settings=None, ): """ Configure permessage-deflate extension factory. @@ -126,11 +127,15 @@ def __init__( client_max_window_bits is True or 8 <= client_max_window_bits <= 15): raise ValueError("client_max_window_bits must be between 8 and 15") + if compress_settings is not None and 'wbits' in compress_settings: + raise ValueError("compress_settings must not include wbits, " + "set client_max_window_bits instead") self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits + self.compress_settings = compress_settings def get_request_params(self): """ @@ -237,6 +242,7 @@ def process_response_params(self, params, accepted_extensions): client_no_context_takeover, # local_no_context_takeover server_max_window_bits or 15, # remote_max_window_bits client_max_window_bits or 15, # local_max_window_bits + self.compress_settings, ) @@ -253,6 +259,7 @@ def __init__( client_no_context_takeover=False, server_max_window_bits=None, client_max_window_bits=None, + compress_settings=None, ): """ Configure permessage-deflate extension factory. @@ -266,11 +273,15 @@ def __init__( if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15): raise ValueError("client_max_window_bits must be between 8 and 15") + if compress_settings is not None and 'wbits' in compress_settings: + raise ValueError("compress_settings must not include wbits, " + "set server_max_window_bits instead") self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits + self.compress_settings = compress_settings def process_request_params(self, params, accepted_extensions): """" @@ -371,6 +382,7 @@ def process_request_params(self, params, accepted_extensions): server_no_context_takeover, # local_no_context_takeover client_max_window_bits or 15, # remote_max_window_bits server_max_window_bits or 15, # local_max_window_bits + self.compress_settings, ) ) @@ -388,20 +400,26 @@ def __init__( local_no_context_takeover, remote_max_window_bits, local_max_window_bits, + compress_settings=None, ): """ Configure permessage-deflate extension. """ + if compress_settings is None: + compress_settings = {} + assert remote_no_context_takeover in [False, True] assert local_no_context_takeover in [False, True] assert 8 <= remote_max_window_bits <= 15 assert 8 <= local_max_window_bits <= 15 + assert 'wbits' not in compress_settings self.remote_no_context_takeover = remote_no_context_takeover self.local_no_context_takeover = local_no_context_takeover self.remote_max_window_bits = remote_max_window_bits self.local_max_window_bits = local_max_window_bits + self.compress_settings = compress_settings if not self.remote_no_context_takeover: self.decoder = zlib.decompressobj( @@ -409,7 +427,8 @@ def __init__( if not self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits) + wbits=-self.local_max_window_bits, + **self.compress_settings) # To handle continuation frames properly, we must keep track of # whether that initial frame was encoded. @@ -489,7 +508,8 @@ def encode(self, frame): # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits) + wbits=-self.local_max_window_bits, + **self.compress_settings) # Compress data frames. data = ( diff --git a/websockets/extensions/test_permessage_deflate.py b/websockets/extensions/test_permessage_deflate.py index eb5c51ea6..bad291fed 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -26,6 +26,7 @@ def test_init(self): (True, False, None, 8), # client_max_window_bits ≥ 8 (True, True, None, 15), # client_max_window_bits ≤ 15 (False, False, None, True), # client_max_window_bits + (False, False, None, None, {'memLevel': 4}), ]: # This does not raise an exception. ClientPerMessageDeflateFactory(*config) @@ -37,6 +38,7 @@ def test_init_error(self): (True, False, 16, 15), # server_max_window_bits > 15 (True, True, 15, 16), # client_max_window_bits > 15 (False, False, True, None), # server_max_window_bits + (False, False, None, None, {'wbits': 11}), ]: with self.assertRaises(ValueError): ClientPerMessageDeflateFactory(*config) @@ -306,6 +308,7 @@ def test_init(self): (False, True, 15, None), # server_max_window_bits ≤ 15 (True, False, None, 8), # client_max_window_bits ≥ 8 (True, True, None, 15), # client_max_window_bits ≤ 15 + (False, False, None, None, {'memLevel': 4}), ]: # This does not raise an exception. ServerPerMessageDeflateFactory(*config) @@ -318,6 +321,7 @@ def test_init_error(self): (True, True, 15, 16), # client_max_window_bits > 15 (False, False, None, True), # client_max_window_bits (False, False, True, None), # server_max_window_bits + (False, False, None, None, {'wbits': 11}), ]: with self.assertRaises(ValueError): ServerPerMessageDeflateFactory(*config) @@ -778,3 +782,18 @@ def test_local_no_context_takeover(self): self.assertEqual(dec_frame1, frame) self.assertEqual(dec_frame2, frame) + + # Compression settings can be customized. + + def test_compress_settings(self): + # Configure an extension so that no compression actually occurs. + extension = PerMessageDeflate(False, False, 15, 15, {'level': 0}) + + frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + + enc_frame = extension.encode(frame) + + self.assertEqual(enc_frame, frame._replace( + rsv1=True, + data=b'\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00', # not compressed + )) From d8f5df1f8929af0ca62cfd12d3d91c8beac37b2f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Aug 2017 14:45:06 +0200 Subject: [PATCH 0298/1539] Specify more closely extension negotiation. Remove the unclear distinction between "not accepting an exception" and "raising a negotiation error" on the server side. A negotiation error means the extension isn't accepted, that's all. --- websockets/client.py | 60 +++++++++-------- websockets/extensions/base.py | 13 ++-- websockets/server.py | 107 +++++++++++++++++-------------- websockets/test_client_server.py | 44 +++++++------ 4 files changed, 121 insertions(+), 103 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 1a3368afc..2e93bc657 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -82,32 +82,41 @@ def read_http_response(self): return status_code, self.response_headers - def process_extensions(self, get_header, available_extensions=None): + def process_extensions(self, get_header, available_extensions): """ Handle the Sec-WebSocket-Extensions HTTP response header. + Check that each extension is supported, as well as its parameters. + + Return the list of accepted extensions. + + Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the + connection. + + RFC 6455 leaves the rules up to the specification of each extension. + + To provide this level of flexibility, for each extension accepted by + the server, we check for a match with each extension available in the + client configuration. If no match is found, an exception is raised. + + If several variants of the same extension are accepted by the server, + it may be configured severel times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + """ extensions = get_header('Sec-WebSocket-Extensions') + accepted_extensions = [] + if extensions: if available_extensions is None: raise InvalidHandshake("No extensions supported.") - # For each extension selected in the server response, check that - # it matches an extension in our list of available extensions. - - # RFC 6455 leaves the exact process up to the specification of - # each extension. To provide this flexibility, we tell each - # extension which extensions were accepted up to this point. - - # Such flexibility prevents us from providing any guarantees - # against reordered or duplicated extensions in the response. - # Extensions must implement ther own requirements, based on the - # list of previously accepted extensions. - - accepted_extensions = [] - for name, response_params in parse_extension_list(extensions): for extension_factory in available_extensions: @@ -116,14 +125,11 @@ def process_extensions(self, get_header, available_extensions=None): if extension_factory.name != name: continue - # This is allowed to raise NegotiationError. - extension = extension_factory.process_response_params( - response_params, accepted_extensions) - # Skip non-matching extensions based on their params. - # There are no tests because the only extension currently - # built in, permessage-deflate, doesn't need this feature. - if extension is None: # pragma: no cover + try: + extension = extension_factory.process_response_params( + response_params, accepted_extensions) + except NegotiationError: continue # Add matching extension to the final list. @@ -139,14 +145,16 @@ def process_extensions(self, get_header, available_extensions=None): "Unsupported extension: name={}, params={}".format( name, response_params)) - return accepted_extensions - - return [] + return accepted_extensions - def process_subprotocol(self, get_header, available_subprotocols=None): + def process_subprotocol(self, get_header, available_subprotocols): """ Handle the Sec-WebSocket-Protocol HTTP response header. + Check that it contains a supported subprotocol. + + Return the selected subprotocol. + """ subprotocol = get_header('Sec-WebSocket-Protocol') diff --git a/websockets/extensions/base.py b/websockets/extensions/base.py index 453ab18ef..3ec7c4321 100644 --- a/websockets/extensions/base.py +++ b/websockets/extensions/base.py @@ -33,10 +33,9 @@ def process_response_params(self, params, accepted_extensions): represented by extension instances. Return an extension instance (an instance of a subclass of - :class:`Extension`) to accept this response or ``None`` to reject it. + :class:`Extension`) if these parameters are acceptable. - Raise :exc:`~websockets.exceptions.NegotiationError` to abort the - handshake and fail the WebSocket connection. + Raise :exc:`~websockets.exceptions.NegotiationError` if they aren't. """ @@ -57,15 +56,11 @@ def process_request_params(self, params, accepted_extensions): ``accepted_extensions`` is a list of previously accepted extensions, represented by extension instances. - Return response params and an extension instance to accept this - extension or ``None, None`` to reject it. - Return response params (a list of (name, value) pairs) and an extension instance (an instance of a subclass of :class:`Extension`) - to accept this response or ``None, None`` to reject it. + to accept this extension. - Raise :exc:`~websockets.exceptions.NegotiationError` to abort the - handshake and fail the websocket connection. + Raise :exc:`~websockets.exceptions.NegotiationError` to reject it. """ diff --git a/websockets/server.py b/websockets/server.py index 8a600601d..ff18312d9 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -12,7 +12,8 @@ SWITCHING_PROTOCOLS, asyncio_ensure_future ) from .exceptions import ( - AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin + AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin, + NegotiationError ) from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .extensions.utils import build_extension_list, parse_extension_list @@ -254,31 +255,44 @@ def process_origin(self, get_header, origins=None): raise InvalidOrigin("Origin not allowed: {}".format(origin)) return origin - def process_extensions(self, get_header, available_extensions=None): + def process_extensions(self, get_header, available_extensions): """ Handle the Sec-WebSocket-Extensions HTTP request header. - """ - extensions = get_header('Sec-WebSocket-Extensions') + Accept or reject each extension proposed in the client request. + Negotiate parameters for accepted extensions. - extensions_header = [] - accepted_extensions = [] + Return the Sec-WebSocket-Extensions HTTP response header and the list + of accepted extensions. - if extensions and available_extensions is not None: + Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the + handshake with an HTTP 400 error code. (The default implementation + never does this.) + + RFC 6455 leaves the rules up to the specification of each extension. - # For each extension proposed in the client request, check if it - # matches an extension in our list of available extensions. + To provide this level of flexibility, for each extension proposed by + the client, we check for a match with each extension available in the + server configuration. If no match is found, the extension is ignored. - # RFC 6455 leaves the exact process up to the specification of - # each extension. To provide this flexibility, we tell each - # extension which extensions were accepted up to this point. + If several variants of the same extension are proposed by the client, + it may be accepted severel times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. - # Such flexibility prevents us from providing any guarantees - # against duplicated extensions in the request. Extensions must - # implement ther own requirements, based on the list of previously - # accepted extensions. + This process doesn't allow the server to reorder extensions. It can + only select a subset of the extensions proposed by the client. - # The current implementation doesn't allow reordering extensions. + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + """ + extensions = get_header('Sec-WebSocket-Extensions') + + response_header = [] + accepted_extensions = [] + + if extensions and available_extensions is not None: for name, request_params in parse_extension_list(extensions): @@ -288,56 +302,51 @@ def process_extensions(self, get_header, available_extensions=None): if extension_factory.name != name: continue - # This is allowed to raise NegotiationError. - response_params, extension = ( - extension_factory.process_request_params( - request_params, accepted_extensions) - ) - - assert (response_params is None) == (extension is None) - # Skip non-matching extensions based on their params. - # There are no tests because the only extension currently - # built in, permessage-deflate, doesn't need this feature. - if extension is None: # pragma: no cover + try: + response_params, extension = ( + extension_factory.process_request_params( + request_params, accepted_extensions)) + except NegotiationError: continue # Add matching extension to the final list. - extensions_header.append((name, response_params)) + response_header.append((name, response_params)) accepted_extensions.append(extension) # Break out of the loop once we have a match. break # If we didn't break from the loop, no extension in our list - # matched what the client sent. Ignore that extension. + # matched what the client sent. The extension is declined. # Serialize extension header. - if extensions_header: - extensions_header = build_extension_list(extensions_header) + if response_header: + response_header = build_extension_list(response_header) else: - extensions_header = None + response_header = None - return extensions_header, accepted_extensions + return response_header, accepted_extensions - def process_subprotocol(self, get_header, available_subprotocols=None): + def process_subprotocol(self, get_header, available_subprotocols): """ Handle the Sec-WebSocket-Protocol HTTP request header. + Return Sec-WebSocket-Protocol HTTP response header, which is the same + as the selected subprotocol. + """ - if available_subprotocols is not None: - - subprotocols = get_header('Sec-WebSocket-Protocol') - - if subprotocols: - subprotocols = [ - subprotocol.strip() - for subprotocol in subprotocols.split(',') - ] - return self.select_subprotocol( - subprotocols, - available_subprotocols, - ) + subprotocols = get_header('Sec-WebSocket-Protocol') + + if subprotocols and available_subprotocols is not None: + subprotocols = [ + subprotocol.strip() + for subprotocol in subprotocols.split(',') + ] + return self.select_subprotocol( + subprotocols, + available_subprotocols, + ) return None @@ -348,7 +357,7 @@ def select_subprotocol(client_subprotocols, server_subprotocols): If several subprotocols are supported by the client and the server, the default implementation selects the preferred subprotocols by - giving equal valueto the priorities of the client and the server. + giving equal value to the priorities of the client and the server. If no subprotocols are supported by the client and the server, it proceeds without a subprotocol. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index c68d00ebd..28dcbed33 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -130,13 +130,8 @@ class BarClientProtocol(WebSocketClientProtocol): class ClientNoOpExtensionFactory: name = 'x-no-op' - def __init__(self, params=None): - if params is None: - params = [] - self.params = params - def get_request_params(self): - return self.params + return [] def process_response_params(self, params, accepted_extensions): if params: @@ -148,13 +143,9 @@ class ServerNoOpExtensionFactory: name = 'x-no-op' def __init__(self, params=None): - if params is None: - params = [] - self.params = params + self.params = params or [] def process_request_params(self, params, accepted_extensions): - if params: - raise NegotiationError() return self.params, NoOpExtension() @@ -431,14 +422,6 @@ def test_extension_not_requested(self): self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) - @with_server(extensions=[ServerNoOpExtensionFactory()]) - def test_extension_server_rejection(self): - with self.assertRaises(InvalidStatusCode): - self.start_client( - 'extensions', - extensions=[ClientNoOpExtensionFactory([('foo', None)])], - ) - @with_server(extensions=[ServerNoOpExtensionFactory([('foo', None)])]) def test_extension_client_rejection(self): with self.assertRaises(NegotiationError): @@ -447,6 +430,29 @@ def test_extension_client_rejection(self): extensions=[ClientNoOpExtensionFactory()], ) + @with_server( + extensions=[ + # No match because the client doesn't send client_max_window_bits. + ServerPerMessageDeflateFactory(client_max_window_bits=10), + ServerPerMessageDeflateFactory(), + ], + ) + @with_client( + 'extensions', + extensions=[ + ClientPerMessageDeflateFactory(), + ], + ) + def test_extension_no_match_then_match(self): + # The order requested by the client has priority. + server_extensions = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_extensions, repr([ + PerMessageDeflate(False, False, 15, 15), + ])) + self.assertEqual(repr(self.client.extensions), repr([ + PerMessageDeflate(False, False, 15, 15), + ])) + @with_server(extensions=[ServerPerMessageDeflateFactory()]) @with_client('extensions', extensions=[ClientNoOpExtensionFactory()]) def test_extension_mismatch(self): From cf4db037a15cc9e21e74e3399d5268e8f98fd2bb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Aug 2017 14:46:09 +0200 Subject: [PATCH 0299/1539] Add changelog for extensions. --- README.rst | 5 +++-- docs/changelog.rst | 13 ++++++++++++- docs/index.rst | 5 +++-- setup.py | 2 +- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index 3c0b013f1..8183778c9 100644 --- a/README.rst +++ b/README.rst @@ -2,8 +2,8 @@ WebSockets |pypi| |circleci| |codecov| ====================================== ``websockets`` is a library for developing WebSocket servers_ and clients_ in -Python. It implements `RFC 6455`_ with a focus on correctness and simplicity. -It passes the `Autobahn Testsuite`_. +Python. It implements `RFC 6455`_ and `RFC 7692`_ with a focus on correctness +and simplicity. It passes the `Autobahn Testsuite`_. Built on top of Python's asynchronous I/O support introduced in `PEP 3156`_, it provides an API based on coroutines, making it easy to write highly @@ -21,6 +21,7 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py .. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py .. _RFC 6455: http://tools.ietf.org/html/rfc6455 +.. _RFC 7692: http://tools.ietf.org/html/rfc7692 .. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst .. _PEP 3156: http://www.python.org/dev/peps/pep-3156/ .. _Read the Docs: https://websockets.readthedocs.io/ diff --git a/docs/changelog.rst b/docs/changelog.rst index cfb8d42fb..39d5ceaf7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,11 +1,22 @@ Changelog --------- -3.5 +4.0 ... *In development* +.. warning:: + + **Version 4.0 enables compression with the permessage-deflate extension.** + + In August 2017, Firefox and Chrome support it, but not Safari and IE. + + Compression should improve performance but it increases RAM and CPU use. + + If you want to disable compression, add ``compression=None`` when calling + :func:`~websockets.server.serve` or :func:`~websockets.client.connect`. + 3.4 ... diff --git a/docs/index.rst b/docs/index.rst index ffcd7afe1..2a875792e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -2,8 +2,8 @@ WebSockets ========== ``websockets`` is a library for developing WebSocket servers_ and clients_ in -Python. It implements `RFC 6455`_ with a focus on correctness and simplicity. -It passes the `Autobahn Testsuite`_. +Python. It implements `RFC 6455`_ and `RFC 7692`_ with a focus on correctness +and simplicity. It passes the `Autobahn Testsuite`_. Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, it provides a straightforward API based on coroutines, making it easy to write @@ -39,6 +39,7 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py .. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py .. _RFC 6455: http://tools.ietf.org/html/rfc6455 +.. _RFC 7692: http://tools.ietf.org/html/rfc7692 .. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst .. _PEP 3156: http://www.python.org/dev/peps/pep-3156/ .. _issue: https://github.com/aaugustin/websockets/issues/new diff --git a/setup.py b/setup.py index 14742973a..10da2faa7 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ root_dir = os.path.abspath(os.path.dirname(__file__)) -description = "An implementation of the WebSocket Protocol (RFC 6455)" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" readme_file = os.path.join(root_dir, 'README.rst') with open(readme_file, encoding='utf-8') as f: From d2b61f7055a4eae7c8187a818731f30e3bd2c7fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Aug 2017 15:10:45 +0200 Subject: [PATCH 0300/1539] Reject multiple instances of permessage-deflate. --- websockets/extensions/permessage_deflate.py | 6 ++++++ websockets/extensions/test_permessage_deflate.py | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index ce8bfb528..7d46a7887 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -154,6 +154,9 @@ def process_response_params(self, params, accepted_extensions): Return an extension instance. """ + if any(other.name == self.name for other in accepted_extensions): + raise NegotiationError("Received duplicate {}".format(self.name)) + # Request parameters are available in instance variables. # Load response parameters in local variables. @@ -290,6 +293,9 @@ def process_request_params(self, params, accepted_extensions): Return response params and an extension instance. """ + if any(other.name == self.name for other in accepted_extensions): + raise NegotiationError("Skipped duplicate {}".format(self.name)) + # Load request parameters in local variables. ( server_no_context_takeover, diff --git a/websockets/extensions/test_permessage_deflate.py b/websockets/extensions/test_permessage_deflate.py index bad291fed..864edc357 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -295,6 +295,12 @@ def test_process_response_params(self): expected = PerMessageDeflate(*result) self.assertExtensionEqual(extension, expected) + def test_process_response_params_deduplication(self): + factory = ClientPerMessageDeflateFactory(False, False, None, None) + with self.assertRaises(NegotiationError): + factory.process_response_params( + [], [PerMessageDeflate(False, False, 15, 15)]) + class ServerPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): @@ -585,6 +591,12 @@ def test_process_request_params(self): expected = PerMessageDeflate(*result) self.assertExtensionEqual(extension, expected) + def test_process_response_params_deduplication(self): + factory = ServerPerMessageDeflateFactory(False, False, None, None) + with self.assertRaises(NegotiationError): + factory.process_request_params( + [], [PerMessageDeflate(False, False, 15, 15)]) + class PerMessageDeflateTests(unittest.TestCase): From efd5d086c8fa5b593fcf7df31ba55dac9f012033 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Aug 2017 15:13:56 +0200 Subject: [PATCH 0301/1539] Update compliance notes. --- compliance/README.rst | 19 +++++++++++++++---- compliance/fuzzingclient.json | 4 ++-- compliance/fuzzingserver.json | 4 ++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index cfaaafca7..c9605f675 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -4,21 +4,28 @@ Autobahn Testsuite General information and installation instructions are available at http://autobahn.ws/testsuite. +To improve performance, you should compile the C extension first:: + + $ python setup.py build_ext --inplace + Running the test suite ---------------------- +All commands below must be run from the directory containing this file. + To test the server:: - $ python test_server.py + $ PYTHONPATH=.. python test_server.py $ wstest -m fuzzingclient To test the client:: $ wstest -m fuzzingserver - $ python test_client.py + $ PYTHONPATH=.. python test_client.py -Run the first command in a shell. Run the second command in another shell. It -should take about one minute to complete. Then kill the first one with Ctrl-C. +Run the first command in a shell. Run the second command in another shell. +It should take about ten minutes to complete — wstest is the bottleneck. +Then kill the first one with Ctrl-C. The test client or server shouldn't display any exceptions. The results are stored in reports/clients/index.html. @@ -40,3 +47,7 @@ the previous frame. In 6.4.3 and 6.4.4, even though it uses an incremental decoder, ``websockets`` doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests are more strict than the RFC. + +12.4.* are skipped: https://github.com/crossbario/autobahn-testsuite/issues/77 + +12.5.* are skipped: https://github.com/crossbario/autobahn-testsuite/issues/77 diff --git a/compliance/fuzzingclient.json b/compliance/fuzzingclient.json index 9c4dd6342..c572d02e8 100644 --- a/compliance/fuzzingclient.json +++ b/compliance/fuzzingclient.json @@ -1,11 +1,11 @@ { "options": {"failByDrop": false}, - "outdir": "./reports/clients", + "outdir": "./reports/servers", "servers": [{"agent": "websockets", "url": "ws://localhost:8642", "options": {"version": 18}}], "cases": ["*"], - "exclude-cases": ["12.5.*"], + "exclude-cases": ["12.4.*", "12.5.*"], "exclude-agent-cases": {} } diff --git a/compliance/fuzzingserver.json b/compliance/fuzzingserver.json index b83bd5968..d7abd94c1 100644 --- a/compliance/fuzzingserver.json +++ b/compliance/fuzzingserver.json @@ -3,10 +3,10 @@ "url": "ws://localhost:8642", "options": {"failByDrop": false}, - "outdir": "./reports/servers", + "outdir": "./reports/clients", "webport": 8080, "cases": ["*"], - "exclude-cases": ["12.5.*"], + "exclude-cases": ["12.4.*", "12.5.*"], "exclude-agent-cases": {} } From 3de517876de82f8cac827eb8dc890d3769ab995f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Aug 2017 15:44:25 +0200 Subject: [PATCH 0302/1539] Handle multiple HTTP headers with the same name. --- websockets/client.py | 47 +++++++++++++++++++++++++++----------------- websockets/server.py | 44 +++++++++++++++++++++++++---------------- 2 files changed, 56 insertions(+), 35 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 2e93bc657..8002ac0d8 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -82,7 +82,8 @@ def read_http_response(self): return status_code, self.response_headers - def process_extensions(self, get_header, available_extensions): + @staticmethod + def process_extensions(headers, available_extensions): """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -108,16 +109,21 @@ def process_extensions(self, get_header, available_extensions): order of extensions, may be implemented by overriding this method. """ - extensions = get_header('Sec-WebSocket-Extensions') - accepted_extensions = [] - if extensions: + header_values = headers.get_all('Sec-WebSocket-Extensions') + + if header_values is not None: if available_extensions is None: raise InvalidHandshake("No extensions supported.") - for name, response_params in parse_extension_list(extensions): + parsed_header_values = sum([ + parse_extension_list(header_value) + for header_value in header_values + ], []) + + for name, response_params in parsed_header_values: for extension_factory in available_extensions: @@ -147,7 +153,8 @@ def process_extensions(self, get_header, available_extensions): return accepted_extensions - def process_subprotocol(self, get_header, available_subprotocols): + @staticmethod + def process_subprotocol(headers, available_subprotocols): """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -156,20 +163,23 @@ def process_subprotocol(self, get_header, available_subprotocols): Return the selected subprotocol. """ - subprotocol = get_header('Sec-WebSocket-Protocol') + subprotocol = None - if subprotocol: + header_values = headers.get_all('Sec-WebSocket-Protocol') + + if header_values is not None: if available_subprotocols is None: raise InvalidHandshake("No subprotocols supported.") + # TODO - handle the case when len(header_values) != 1 + subprotocol = header_values[0] + if subprotocol not in available_subprotocols: raise NegotiationError( "Unsupported subprotocol: {}".format(subprotocol)) - return subprotocol - - return None + return subprotocol @asyncio.coroutine def handshake(self, wsuri, origin=None, @@ -190,8 +200,8 @@ def handshake(self, wsuri, origin=None, It must be a mapping or an iterable of (name, value) pairs. """ - headers = [] - set_header = lambda k, v: headers.append((k, v)) + request_headers = [] + set_header = lambda k, v: request_headers.append((k, v)) if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover set_header('Host', wsuri.host) @@ -225,10 +235,11 @@ def handshake(self, wsuri, origin=None, key = build_request(set_header) - yield from self.write_http_request(wsuri.resource_name, headers) + yield from self.write_http_request( + wsuri.resource_name, request_headers) - status_code, headers = yield from self.read_http_response() - get_header = lambda k: headers.get(k, '') + status_code, response_headers = yield from self.read_http_response() + get_header = lambda k: response_headers.get(k, '') if status_code != 101: raise InvalidStatusCode(status_code) @@ -236,10 +247,10 @@ def handshake(self, wsuri, origin=None, check_response(get_header, key) self.extensions = self.process_extensions( - get_header, available_extensions) + response_headers, available_extensions) self.subprotocol = self.process_subprotocol( - get_header, available_subprotocols) + response_headers, available_subprotocols) assert self.state == CONNECTING self.state = OPEN diff --git a/websockets/server.py b/websockets/server.py index ff18312d9..c52bb3837 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -255,7 +255,8 @@ def process_origin(self, get_header, origins=None): raise InvalidOrigin("Origin not allowed: {}".format(origin)) return origin - def process_extensions(self, get_header, available_extensions): + @staticmethod + def process_extensions(headers, available_extensions): """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -287,14 +288,19 @@ def process_extensions(self, get_header, available_extensions): order of extensions, may be implemented by overriding this method. """ - extensions = get_header('Sec-WebSocket-Extensions') - response_header = [] accepted_extensions = [] - if extensions and available_extensions is not None: + header_values = headers.get_all('Sec-WebSocket-Extensions') + + if header_values is not None and available_extensions is not None: - for name, request_params in parse_extension_list(extensions): + parsed_header_values = sum([ + parse_extension_list(header_value) + for header_value in header_values + ], []) + + for name, request_params in parsed_header_values: for extension_factory in available_extensions: @@ -328,7 +334,8 @@ def process_extensions(self, get_header, available_extensions): return response_header, accepted_extensions - def process_subprotocol(self, get_header, available_subprotocols): + # Not @staticmethod because it calls self.select_subprotocol() + def process_subprotocol(self, headers, available_subprotocols): """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -336,19 +343,22 @@ def process_subprotocol(self, get_header, available_subprotocols): as the selected subprotocol. """ - subprotocols = get_header('Sec-WebSocket-Protocol') + subprotocols = None + + header_values = headers.get_all('Sec-WebSocket-Protocol') - if subprotocols and available_subprotocols is not None: - subprotocols = [ + if header_values is not None and available_subprotocols is not None: + parsed_header_values = [ subprotocol.strip() - for subprotocol in subprotocols.split(',') + for header_value in header_values + for subprotocol in header_value.split(',') ] - return self.select_subprotocol( - subprotocols, + subprotocols = self.select_subprotocol( + parsed_header_values, available_subprotocols, ) - return None + return subprotocols @staticmethod def select_subprotocol(client_subprotocols, server_subprotocols): @@ -415,10 +425,10 @@ def handshake(self, origins=None, available_extensions=None, self.origin = self.process_origin(get_header, origins) extensions_header, self.extensions = self.process_extensions( - get_header, available_extensions) + request_headers, available_extensions) - self.subprotocol = self.process_subprotocol( - get_header, available_subprotocols) + protocol_header = self.subprotocol = self.process_subprotocol( + request_headers, available_subprotocols) response_headers = [] set_header = lambda k, v: response_headers.append((k, v)) @@ -429,7 +439,7 @@ def handshake(self, origins=None, available_extensions=None, set_header('Sec-WebSocket-Extensions', extensions_header) if self.subprotocol is not None: - set_header('Sec-WebSocket-Protocol', self.subprotocol) + set_header('Sec-WebSocket-Protocol', protocol_header) if extra_headers is not None: if callable(extra_headers): From 22bff9c9d563d04679b81de023aab5ae424016aa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 2 Sep 2017 10:14:02 +0200 Subject: [PATCH 0303/1539] Drop support for Python 3.3. Python 3.3 is reaching EOL at the end of the month. --- .travis.yml | 10 +++++----- README.rst | 6 +++--- appveyor.yml | 8 ++++---- circle.yml | 2 +- docs/index.rst | 3 +-- docs/intro.rst | 6 +++--- setup.cfg | 2 +- setup.py | 8 ++------ tox.ini | 4 +--- websockets/test_client_server.py | 10 ++-------- websockets/test_framing.py | 2 -- 11 files changed, 23 insertions(+), 38 deletions(-) diff --git a/.travis.yml b/.travis.yml index ddc0fc604..4c0abc0b5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,16 +1,16 @@ env: global: -# Don't attempt to build Python 2.7 wheels; websockets only works on Python 3. - - CIBW_SKIP=cp27* -# Commented out because tests don't pass reliably on macOS, see #241. -# - CIBW_TEST_COMMAND="python3 -m unittest discover websockets" + # websockets only works on Python >= 3.4. + - CIBW_SKIP="cp27-* cp33-*" + # Commented out because tests don't pass reliably on macOS, see #241. + # - CIBW_TEST_COMMAND="python3 -m unittest discover websockets" matrix: include: - dist: trusty sudo: required language: python - python: "3.3" + python: "3.6" services: - docker - os: osx diff --git a/README.rst b/README.rst index 8183778c9..30ce838e7 100644 --- a/README.rst +++ b/README.rst @@ -9,9 +9,9 @@ Built on top of Python's asynchronous I/O support introduced in `PEP 3156`_, it provides an API based on coroutines, making it easy to write highly concurrent applications. -Installation is as simple as ``pip install websockets``. It requires Python ≥ -3.4 or Python 3.3 with the ``asyncio`` module, which is available with ``pip -install asyncio``. +Installation is as simple as ``pip install websockets``. + +It requires Python ≥ 3.4. Documentation is available on `Read the Docs`_. diff --git a/appveyor.yml b/appveyor.yml index 5f9a079bc..73ffb93ee 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,16 +1,16 @@ environment: -# Don't attempt to build Python 2.7 wheels; websockets only works on Python 3. - CIBW_SKIP: cp27* +# websockets only works on Python >= 3.4. + CIBW_SKIP: cp27-* cp33-* # Commented out because tests don't pass reliably on Windows, see #240. # CIBW_TEST_COMMAND: python -m unittest discover websockets # Since Python 2 is still the default, invoke Python 3 explicitly. install: - - cmd: C:\Python33-x64\python.exe -m pip install cibuildwheel==0.4.0 + - cmd: C:\Python36-x64\python.exe -m pip install cibuildwheel==0.4.0 # Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). - cmd: touch .cibuildwheel build_script: - - cmd: C:\Python33-x64\python.exe -m cibuildwheel --output-dir wheelhouse + - cmd: C:\Python36-x64\python.exe -m cibuildwheel --output-dir wheelhouse # Upload to PyPI on tags - ps: >- if ($env:APPVEYOR_REPO_TAG -eq "true") { diff --git a/circle.yml b/circle.yml index 665e46fd5..1726ea432 100644 --- a/circle.yml +++ b/circle.yml @@ -1,6 +1,6 @@ machine: post: - - pyenv global 3.6.1 3.5.3 3.4.4 3.3.6 + - pyenv global 3.6.1 3.5.3 3.4.4 python: version: 3.6.1 diff --git a/docs/index.rst b/docs/index.rst index 2a875792e..30f4878b6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -14,8 +14,7 @@ Installation Installation is as simple as ``pip install websockets``. -It requires Python ≥ 3.4 or Python 3.3 with the ``asyncio`` module, which is -available with ``pip install asyncio``. +It requires Python ≥ 3.4. User guide ---------- diff --git a/docs/intro.rst b/docs/intro.rst index 38df55ce7..be27aa3e8 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -3,8 +3,8 @@ Getting started .. warning:: - This documentation is written for Python ≥ 3.5. If you're using Python 3.4 - or 3.3, you will have to :ref:`adapt the code samples `. + This documentation is written for Python ≥ 3.5. If you're using Python + 3.4, you will have to :ref:`adapt the code samples `. Basic example ------------- @@ -143,7 +143,7 @@ Python < 3.5 This documentation uses the ``await`` and ``async`` syntax introduced in Python 3.5. -If you're using Python 3.4 or 3.3, you must substitute:: +If you're using Python 3.4, you must substitute:: async def ... diff --git a/setup.cfg b/setup.cfg index 08644c32e..dc625d46c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py33.py34.py35.py36 +python-tag = py34.py35.py36 [flake8] ignore = E731,F403,F405 diff --git a/setup.py b/setup.py index 10da2faa7..0911d0203 100644 --- a/setup.py +++ b/setup.py @@ -17,8 +17,8 @@ py_version = sys.version_info[:2] -if py_version < (3, 3): - raise Exception("websockets requires Python >= 3.3.") +if py_version < (3, 4): + raise Exception("websockets requires Python >= 3.4.") packages = ['websockets', 'websockets/extensions'] @@ -50,14 +50,10 @@ 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', ], packages=packages, ext_modules=ext_modules, - extras_require={ - ':python_version=="3.3"': ['asyncio'], - }, ) diff --git a/tox.ini b/tox.ini index 972301336..d23a2c217 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = {py33,py34,py35,py36}{,-speedups},coverage,flake8,isort +envlist = {py34,py35,py36}{,-speedups},coverage,flake8,isort [testenv] commands = @@ -17,8 +17,6 @@ commands = ; After testing with speedups, remove the extension. speedups: sh -c 'rm websockets/*.so' -deps = - py33: asyncio whitelist_externals = sh diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 28dcbed33..8affc8dd4 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -4,7 +4,6 @@ import logging import os import ssl -import sys import unittest import unittest.mock import urllib.request @@ -324,13 +323,8 @@ def test_custom_protocol_http_request(self): if self.secure: url = 'https://localhost:8642/__health__/' - if sys.version_info[:2] < (3, 4): # pragma: no cover - # Python 3.3 didn't check SSL certificates. - open_health_check = functools.partial( - urllib.request.urlopen, url) - else: # pragma: no cover - open_health_check = functools.partial( - urllib.request.urlopen, url, context=self.client_context) + open_health_check = functools.partial( + urllib.request.urlopen, url, context=self.client_context) else: url = 'http://localhost:8642/__health__/' open_health_check = functools.partial( diff --git a/websockets/test_framing.py b/websockets/test_framing.py index 9aec1ea17..04f8acda3 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -1,6 +1,5 @@ import asyncio import codecs -import sys import unittest import unittest.mock @@ -152,7 +151,6 @@ def test_serialize_close_errors(self): with self.assertRaises(WebSocketProtocolError): serialize_close(999, '') - @unittest.skipUnless(sys.version_info[:2] >= (3, 4), "rot13 is new in 3.4") def test_extensions(self): class Rot13: From fb71c7d8f81a5a36a0054ebbf4912922b0ed219f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 2 Sep 2017 10:14:30 +0200 Subject: [PATCH 0304/1539] Take advantage of TestCase.subTest. --- .../extensions/test_permessage_deflate.py | 74 +++++++++++-------- websockets/extensions/test_utils.py | 13 ++-- websockets/test_uri.py | 10 +-- websockets/test_utils.py | 13 ++-- 4 files changed, 65 insertions(+), 45 deletions(-) diff --git a/websockets/extensions/test_permessage_deflate.py b/websockets/extensions/test_permessage_deflate.py index 864edc357..a72beec42 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -28,8 +28,9 @@ def test_init(self): (False, False, None, True), # client_max_window_bits (False, False, None, None, {'memLevel': 4}), ]: - # This does not raise an exception. - ClientPerMessageDeflateFactory(*config) + with self.subTest(config=config): + # This does not raise an exception. + ClientPerMessageDeflateFactory(*config) def test_init_error(self): for config in [ @@ -40,8 +41,9 @@ def test_init_error(self): (False, False, True, None), # server_max_window_bits (False, False, None, None, {'wbits': 11}), ]: - with self.assertRaises(ValueError): - ClientPerMessageDeflateFactory(*config) + with self.subTest(config=config): + with self.assertRaises(ValueError): + ClientPerMessageDeflateFactory(*config) def test_get_request_params(self): for config, result in [ @@ -85,8 +87,9 @@ def test_get_request_params(self): ], ), ]: - factory = ClientPerMessageDeflateFactory(*config) - self.assertEqual(factory.get_request_params(), result) + with self.subTest(config=config, result=result): + factory = ClientPerMessageDeflateFactory(*config) + self.assertEqual(factory.get_request_params(), result) def test_process_response_params(self): for config, response_params, result in [ @@ -285,15 +288,20 @@ def test_process_response_params(self): (True, True, 12, 12), ), ]: - factory = ClientPerMessageDeflateFactory(*config) - if isinstance(result, type) and issubclass(result, Exception): - with self.assertRaises(result): - factory.process_response_params(response_params, []) - else: - extension = factory.process_response_params( - response_params, []) - expected = PerMessageDeflate(*result) - self.assertExtensionEqual(extension, expected) + with self.subTest( + config=config, + response_params=response_params, + result=result, + ): + factory = ClientPerMessageDeflateFactory(*config) + if isinstance(result, type) and issubclass(result, Exception): + with self.assertRaises(result): + factory.process_response_params(response_params, []) + else: + extension = factory.process_response_params( + response_params, []) + expected = PerMessageDeflate(*result) + self.assertExtensionEqual(extension, expected) def test_process_response_params_deduplication(self): factory = ClientPerMessageDeflateFactory(False, False, None, None) @@ -316,8 +324,9 @@ def test_init(self): (True, True, None, 15), # client_max_window_bits ≤ 15 (False, False, None, None, {'memLevel': 4}), ]: - # This does not raise an exception. - ServerPerMessageDeflateFactory(*config) + with self.subTest(config=config): + # This does not raise an exception. + ServerPerMessageDeflateFactory(*config) def test_init_error(self): for config in [ @@ -329,8 +338,9 @@ def test_init_error(self): (False, False, True, None), # server_max_window_bits (False, False, None, None, {'wbits': 11}), ]: - with self.assertRaises(ValueError): - ServerPerMessageDeflateFactory(*config) + with self.subTest(config=config): + with self.assertRaises(ValueError): + ServerPerMessageDeflateFactory(*config) def test_process_request_params(self): # Parameters in result appear swapped vs. config because the order is @@ -580,16 +590,22 @@ def test_process_request_params(self): (True, True, 12, 12), ), ]: - factory = ServerPerMessageDeflateFactory(*config) - if isinstance(result, type) and issubclass(result, Exception): - with self.assertRaises(result): - factory.process_request_params(request_params, []) - else: - params, extension = factory.process_request_params( - request_params, []) - self.assertEqual(params, response_params) - expected = PerMessageDeflate(*result) - self.assertExtensionEqual(extension, expected) + with self.subTest( + config=config, + request_params=request_params, + response_params=response_params, + result=result, + ): + factory = ServerPerMessageDeflateFactory(*config) + if isinstance(result, type) and issubclass(result, Exception): + with self.assertRaises(result): + factory.process_request_params(request_params, []) + else: + params, extension = factory.process_request_params( + request_params, []) + self.assertEqual(params, response_params) + expected = PerMessageDeflate(*result) + self.assertExtensionEqual(extension, expected) def test_process_response_params_deduplication(self): factory = ServerPerMessageDeflateFactory(False, False, None, None) diff --git a/websockets/extensions/test_utils.py b/websockets/extensions/test_utils.py index 811a387cc..8e1f888be 100644 --- a/websockets/extensions/test_utils.py +++ b/websockets/extensions/test_utils.py @@ -45,10 +45,10 @@ def test_parse_extension_list(self): [('permessage-deflate', [('server_max_window_bits', '10')])], ), ]: - self.assertEqual(parse_extension_list(header), parsed) - # Also ensure that build_extension_list round-trips cleanly. - unparsed = build_extension_list(parsed) - self.assertEqual(parse_extension_list(unparsed), parsed) + with self.subTest(header=header, parsed=parsed): + self.assertEqual(parse_extension_list(header), parsed) + unparsed = build_extension_list(parsed) + self.assertEqual(parse_extension_list(unparsed), parsed) def test_parse_extension_list_invalid_header(self): for header in [ @@ -64,8 +64,9 @@ def test_parse_extension_list_invalid_header(self): # Value in quoted string parameter that isn't a token 'foo; bar=" "', ]: - with self.assertRaises(InvalidHeader): - parse_extension_list(header) + with self.subTest(header=header): + with self.assertRaises(InvalidHeader): + parse_extension_list(header) class ExtensionTestsMixin: diff --git a/websockets/test_uri.py b/websockets/test_uri.py index d1102ca65..d15df3b63 100644 --- a/websockets/test_uri.py +++ b/websockets/test_uri.py @@ -23,11 +23,11 @@ class URITests(unittest.TestCase): def test_success(self): for uri, parsed in VALID_URIS: - # wrap in `with self.subTest():` when dropping Python 3.3 - self.assertEqual(parse_uri(uri), parsed) + with self.subTest(uri=uri, parsed=parsed): + self.assertEqual(parse_uri(uri), parsed) def test_error(self): for uri in INVALID_URIS: - # wrap in `with self.subTest():` when dropping Python 3.3 - with self.assertRaises(InvalidURI): - parse_uri(uri) + with self.subTest(uri=uri): + with self.assertRaises(InvalidURI): + parse_uri(uri) diff --git a/websockets/test_utils.py b/websockets/test_utils.py index 8259b7490..7772dce72 100644 --- a/websockets/test_utils.py +++ b/websockets/test_utils.py @@ -16,7 +16,8 @@ def test_apply_mask(self): (b'abcdABCD', b'1234', b'PPPPpppp'), (b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10), ]: - self.assertEqual(self.apply_mask(data_in, mask), data_out) + with self.subTest(data_in=data_in, mask=mask, data_out=data_out): + self.assertEqual(self.apply_mask(data_in, mask), data_out) def test_apply_mask_check_input_types(self): for data_in, mask in [ @@ -24,8 +25,9 @@ def test_apply_mask_check_input_types(self): (b'abcd', None), (None, b'abcd'), ]: - with self.assertRaises(TypeError): - self.apply_mask(data_in, mask) + with self.subTest(data_in=data_in, mask=mask): + with self.assertRaises(TypeError): + self.apply_mask(data_in, mask) def test_apply_mask_check_mask_length(self): for data_in, mask in [ @@ -34,8 +36,9 @@ def test_apply_mask_check_mask_length(self): (b'', b'aBcDe'), (b'12345678', b'12345678'), ]: - with self.assertRaises(ValueError): - self.apply_mask(data_in, mask) + with self.subTest(data_in=data_in, mask=mask): + with self.assertRaises(ValueError): + self.apply_mask(data_in, mask) try: From cf5b5e74521c84527162d7b8932dbbaaa6f2d4c0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 2 Sep 2017 10:29:06 +0200 Subject: [PATCH 0305/1539] Take advantage of re.fullmatch. --- websockets/extensions/utils.py | 5 +---- websockets/http.py | 10 +++++----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/websockets/extensions/utils.py b/websockets/extensions/utils.py index 843de3ee4..d0c64b746 100644 --- a/websockets/extensions/utils.py +++ b/websockets/extensions/utils.py @@ -26,9 +26,6 @@ def parse_OWS(string, pos): _token_re = re.compile(r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+') -# Workaround for the lack of re.fullmatch in older Pythons -_exact_token_re = re.compile(r'^[-!#$%&\'*+.^_`|~0-9a-zA-Z]+$') - def parse_token(string, pos): match = _token_re.match(string, pos) @@ -63,7 +60,7 @@ def parse_extension_param(string, pos): value, pos = parse_quoted_string(string, pos) # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value # after quoted-string unescaping MUST conform to the 'token' ABNF. - if _exact_token_re.match(value) is None: + if _token_re.fullmatch(value) is None: raise InvalidHeader("invalid quoted string content", string=string, pos=pos_before) else: diff --git a/websockets/http.py b/websockets/http.py index 464e942a7..4b082c8f3 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -30,7 +30,7 @@ # Regex for validating header names. -_token_re = re.compile(rb'^[-!#$%&\'*+.^_`|~0-9a-zA-Z]+$') +_token_re = re.compile(rb'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+') # Regex for validating header values. @@ -43,7 +43,7 @@ # See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 -_value_re = re.compile(rb'^[\x09\x20-\x7e\x80-\xff]*$') +_value_re = re.compile(rb'[\x09\x20-\x7e\x80-\xff]*') @asyncio.coroutine @@ -127,7 +127,7 @@ def read_response(stream): status_code = int(status_code) if not 100 <= status_code < 1000: raise ValueError("Unsupported HTTP status_code code: %d" % status_code) - if not _value_re.match(reason): + if not _value_re.fullmatch(reason): raise ValueError("Invalid HTTP reason phrase: %r" % reason) headers = yield from read_headers(stream) @@ -160,10 +160,10 @@ def read_headers(stream): # This may raise "ValueError: not enough values to unpack" name, value = line[:-2].split(b':', 1) - if not _token_re.match(name): + if not _token_re.fullmatch(name): raise ValueError("Invalid HTTP header name: %r" % name) value = value.strip(b' \t') - if not _value_re.match(value): + if not _value_re.fullmatch(value): raise ValueError("Invalid HTTP header value: %r" % value) headers.append(( From 96bcc64e9098ce6c1c3597deca05c8aaf7f03f2b Mon Sep 17 00:00:00 2001 From: Edward Betts Date: Fri, 1 Sep 2017 21:56:24 +0100 Subject: [PATCH 0306/1539] correct spelling mistake --- docs/limitations.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/limitations.rst b/docs/limitations.rst index 8cf5314d9..c7f6e688b 100644 --- a/docs/limitations.rst +++ b/docs/limitations.rst @@ -5,7 +5,7 @@ Extensions_ aren't implemented. No extensions are registered_ at the time of writing. The client doesn't attempt to guarantee that there is no more than one -connection to a given IP adress in a CONNECTING state. +connection to a given IP address in a CONNECTING state. The client doesn't support connecting through a proxy. From 9f8539f3060c8c5c19e1ddeda146716b49382d34 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Sep 2017 18:13:10 +0200 Subject: [PATCH 0307/1539] Extensions are implemented now. --- docs/limitations.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/limitations.rst b/docs/limitations.rst index c7f6e688b..d0b9743fc 100644 --- a/docs/limitations.rst +++ b/docs/limitations.rst @@ -1,13 +1,7 @@ Limitations ----------- -Extensions_ aren't implemented. No extensions are registered_ at the time of -writing. - The client doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. The client doesn't support connecting through a proxy. - -.. _Extensions: http://tools.ietf.org/html/rfc6455#section-9 -.. _registered: http://www.iana.org/assignments/websocket/websocket.xml From 37650ae60cc92f2073f271e57af2bfc27cdc9636 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Aug 2017 16:27:06 +0200 Subject: [PATCH 0308/1539] Parse Sec-WebSocket-Protocol robustly. Fix #257. --- docs/api.rst | 3 + websockets/client.py | 20 ++++- .../extensions/test_permessage_deflate.py | 14 +++- .../{extensions/utils.py => headers.py} | 84 ++++++++++++++++++- websockets/http.py | 8 +- websockets/server.py | 19 +++-- websockets/test_client_server.py | 10 +++ .../test_utils.py => test_headers.py} | 51 +++++++---- 8 files changed, 175 insertions(+), 34 deletions(-) rename websockets/{extensions/utils.py => headers.py} (67%) rename websockets/{extensions/test_utils.py => test_headers.py} (62%) diff --git a/docs/api.rst b/docs/api.rst index fbcec2de8..302fe9829 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -108,5 +108,8 @@ URI parser Utilities ......... +.. automodule:: websockets.headers + :members: + .. automodule:: websockets.http :members: diff --git a/websockets/client.py b/websockets/client.py index 8002ac0d8..5e2eeb848 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -10,8 +10,11 @@ InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError ) from .extensions.permessage_deflate import ClientPerMessageDeflateFactory -from .extensions.utils import build_extension_list, parse_extension_list from .handshake import build_request, check_response +from .headers import ( + build_extension_list, build_protocol_list, parse_extension_list, + parse_protocol_list +) from .http import USER_AGENT, build_headers, read_response from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol from .uri import parse_uri @@ -172,8 +175,17 @@ def process_subprotocol(headers, available_subprotocols): if available_subprotocols is None: raise InvalidHandshake("No subprotocols supported.") - # TODO - handle the case when len(header_values) != 1 - subprotocol = header_values[0] + parsed_header_values = sum([ + parse_protocol_list(header_value) + for header_value in header_values + ], []) + + if len(parsed_header_values) > 1: + raise InvalidHandshake( + "Multiple subprotocols: {}".format( + ', '.join(parsed_header_values))) + + subprotocol = parsed_header_values[0] if subprotocol not in available_subprotocols: raise NegotiationError( @@ -222,7 +234,7 @@ def handshake(self, wsuri, origin=None, set_header('Sec-WebSocket-Extensions', extensions_header) if available_subprotocols is not None: - protocol_header = ', '.join(available_subprotocols) + protocol_header = build_protocol_list(available_subprotocols) set_header('Sec-WebSocket-Protocol', protocol_header) if extra_headers is not None: diff --git a/websockets/extensions/test_permessage_deflate.py b/websockets/extensions/test_permessage_deflate.py index a72beec42..b034a5089 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -10,7 +10,19 @@ serialize_close ) from .permessage_deflate import * -from .test_utils import ExtensionTestsMixin + + +class ExtensionTestsMixin: + + def assertExtensionEqual(self, extension1, extension2): + self.assertEqual(extension1.remote_no_context_takeover, + extension2.remote_no_context_takeover) + self.assertEqual(extension1.local_no_context_takeover, + extension2.local_no_context_takeover) + self.assertEqual(extension1.remote_max_window_bits, + extension2.remote_max_window_bits) + self.assertEqual(extension1.local_max_window_bits, + extension2.local_max_window_bits) class ClientPerMessageDeflateFactoryTests(unittest.TestCase, diff --git a/websockets/extensions/utils.py b/websockets/headers.py similarity index 67% rename from websockets/extensions/utils.py rename to websockets/headers.py index d0c64b746..b1459b2b8 100644 --- a/websockets/extensions/utils.py +++ b/websockets/headers.py @@ -1,9 +1,21 @@ +""" +The :mod:`websockets.headers` module provides parsers and serializers for HTTP +headers used in WebSocket handshake messages. + +Its functions cannot be imported from :mod:`websockets`. They must be imported +from :mod:`websockets.headers`. + +""" + import re -from ..exceptions import InvalidHeader +from .exceptions import InvalidHeader -__all__ = ['build_extension_list', 'parse_extension_list'] +__all__ = [ + 'parse_extension_list', 'build_extension_list', + 'parse_protocol_list', 'build_protocol_list', +] # To avoid a dependency on a parsing library, we implement manually the ABNF @@ -91,7 +103,7 @@ def parse_extension_list(string, pos=0): The string is assumed not to start or end with whitespace. - The return value has the following format:: + Return a value with the following format:: [ ( @@ -167,3 +179,69 @@ def build_extension_list(extensions): build_extension(name, parameters) for name, parameters in extensions ) + + +def parse_protocol(string, pos): + name, pos = parse_token(string, pos) + pos = parse_OWS(string, pos) + return name, pos + + +def parse_protocol_list(string, pos=0): + """ + Parse a Sec-WebSocket-Protocol header. + + The string is assumed not to start or end with whitespace. + + Return a list of protocols. + + Raise InvalidHeader if the header cannot be parsed. + + """ + # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST + # parse and ignore a reasonable number of empty list elements"; hence + # while loops that remove extra delimiters. + + # Remove extra delimiters before the first extension. + while peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + + protocols = [] + while True: + # Loop invariant: a protocol starts at pos in string. + protocol, pos = parse_protocol(string, pos) + protocols.append(protocol) + + # We may have reached the end of the string. + if pos == len(string): + break + + # There must be a delimiter after each element except the last one. + if peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + else: + raise InvalidHeader("expected comma", string=string, pos=pos) + + # Remove extra delimiters before the next protocol. + while peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + + # We may have reached the end of the string. + if pos == len(string): + break + + # Since we only advance in the string by one character with peek_ahead() + # or with the end position of a regex match, we can't overshoot the end. + assert pos == len(string) + + return protocols + + +def build_protocol_list(protocols): + """ + Unparse a Sec-WebSocket-Protocol header. + + This is the reverse of parse_protocol_list. + + """ + return ', '.join(protocols) diff --git a/websockets/http.py b/websockets/http.py index 4b082c8f3..9d2316b70 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -1,9 +1,9 @@ """ -The :mod:`websockets.http` module provides HTTP parsing functions. They're -merely adequate for the WebSocket handshake messages. +The :mod:`websockets.http` module provides basic HTTP parsing and +serialization. It is merely adequate for WebSocket handshake messages. -These functions cannot be imported from :mod:`websockets`; they must be -imported from :mod:`websockets.http`. +Its functions cannot be imported from :mod:`websockets`. They must be imported +from :mod:`websockets.http`. """ diff --git a/websockets/server.py b/websockets/server.py index c52bb3837..6674da4ca 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -16,8 +16,10 @@ NegotiationError ) from .extensions.permessage_deflate import ServerPerMessageDeflateFactory -from .extensions.utils import build_extension_list, parse_extension_list from .handshake import build_response, check_request +from .headers import ( + build_extension_list, parse_extension_list, parse_protocol_list +) from .http import USER_AGENT, build_headers, read_request from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol @@ -343,22 +345,23 @@ def process_subprotocol(self, headers, available_subprotocols): as the selected subprotocol. """ - subprotocols = None + subprotocol = None header_values = headers.get_all('Sec-WebSocket-Protocol') if header_values is not None and available_subprotocols is not None: - parsed_header_values = [ - subprotocol.strip() + + parsed_header_values = sum([ + parse_protocol_list(header_value) for header_value in header_values - for subprotocol in header_value.split(',') - ] - subprotocols = self.select_subprotocol( + ], []) + + subprotocol = self.select_subprotocol( parsed_header_values, available_subprotocols, ) - return subprotocols + return subprotocol @staticmethod def select_subprotocol(client_subprotocols, server_subprotocols): diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 8affc8dd4..b2816af6f 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -596,6 +596,16 @@ def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): self.start_client('subprotocol') self.run_loop_once() + @with_server(subprotocols=['superchat', 'chat']) + @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') + def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): + _process_subprotocol.return_value = 'superchat, chat' + + with self.assertRaises(InvalidHandshake): + self.start_client( + 'subprotocol', subprotocols=['superchat', 'chat']) + self.run_loop_once() + @with_server() @unittest.mock.patch('websockets.server.read_request') def test_server_receives_malformed_request(self, _read_request): diff --git a/websockets/extensions/test_utils.py b/websockets/test_headers.py similarity index 62% rename from websockets/extensions/test_utils.py rename to websockets/test_headers.py index 8e1f888be..9311a4a68 100644 --- a/websockets/extensions/test_utils.py +++ b/websockets/test_headers.py @@ -1,10 +1,10 @@ import unittest -from ..exceptions import InvalidHeader -from .utils import * +from .exceptions import InvalidHeader +from .headers import * -class UtilsTests(unittest.TestCase): +class HeadersTests(unittest.TestCase): def test_parse_extension_list(self): for header, parsed in [ @@ -47,6 +47,7 @@ def test_parse_extension_list(self): ]: with self.subTest(header=header, parsed=parsed): self.assertEqual(parse_extension_list(header), parsed) + # Also ensure that build_extension_list round-trips cleanly. unparsed = build_extension_list(parsed) self.assertEqual(parse_extension_list(unparsed), parsed) @@ -68,15 +69,37 @@ def test_parse_extension_list_invalid_header(self): with self.assertRaises(InvalidHeader): parse_extension_list(header) + def test_parse_protocol_list(self): + for header, parsed in [ + # Synthetic examples + ( + 'foo', + ['foo'], + ), + ( + 'foo, bar', + ['foo', 'bar'], + ), + # Pathological examples + ( + ',\t, , ,foo ,, bar,baz,,', + ['foo', 'bar', 'baz'], + ), + ]: + with self.subTest(header=header, parsed=parsed): + self.assertEqual(parse_protocol_list(header), parsed) + # Also ensure that build_protocol_list round-trips cleanly. + unparsed = build_protocol_list(parsed) + self.assertEqual(parse_protocol_list(unparsed), parsed) -class ExtensionTestsMixin: - - def assertExtensionEqual(self, extension1, extension2): - self.assertEqual(extension1.remote_no_context_takeover, - extension2.remote_no_context_takeover) - self.assertEqual(extension1.local_no_context_takeover, - extension2.local_no_context_takeover) - self.assertEqual(extension1.remote_max_window_bits, - extension2.remote_max_window_bits) - self.assertEqual(extension1.local_max_window_bits, - extension2.local_max_window_bits) + def test_parse_protocol_list_invalid_header(self): + for header in [ + # Truncated examples + '', + ',\t,' + # Wrong delimiter + 'foo; bar', + ]: + with self.subTest(header=header): + with self.assertRaises(InvalidHeader): + parse_protocol_list(header) From 0ae899ced8e38f1222aaa8dcd2da19a193c7f5c6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Sep 2017 20:45:39 +0200 Subject: [PATCH 0309/1539] Avoid leaking tasks when close() is cancelled. Fix #142. Thanks @cjerdonek for diagnosing this tricky issue. Also add a test for a code path that was exercised by accident before this change. --- docs/changelog.rst | 5 +++++ websockets/protocol.py | 2 ++ websockets/test_client_server.py | 14 ++++++++++++++ websockets/test_protocol.py | 22 ++++++++++++++++++++++ 4 files changed, 43 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 39d5ceaf7..016a81f6c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,6 +17,11 @@ Changelog If you want to disable compression, add ``compression=None`` when calling :func:`~websockets.server.serve` or :func:`~websockets.client.connect`. +Also: + +* Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on + a connection while it's being closed. + 3.4 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index 998f99f2b..854081ff0 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -261,6 +261,8 @@ def close(self, code=1000, reason=''): try: yield from asyncio.wait_for( self.worker_task, self.timeout, loop=self.loop) + except asyncio.CancelledError: + pass except asyncio.TimeoutError: self.worker_task.cancel() diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index b2816af6f..29b656347 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -723,6 +723,20 @@ def test_server_shuts_down_during_connection_handling(self): # Websocket connection terminates with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) + @with_server() + @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') + def test_server_shuts_down_during_connection_close(self, _close): + _close.side_effect = asyncio.CancelledError + + self.server.closing = True + with self.temp_client(): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + + # Websocket connection terminates abnormally. + self.assertEqual(self.client.close_code, 1006) + @with_server(create_protocol=ForbiddenServerProtocol) def test_invalid_status_error_during_client_connect(self): with self.assertRaises(InvalidStatusCode) as raised: diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 850f80e12..7faaddc37 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -742,6 +742,28 @@ def test_remote_close_during_send(self): # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. + def test_cancelled_close_waits_for_worker(self): + # Regression test for #142. + + # Start the closing handshake. + close_task = self.ensure_future(self.protocol.close(reason='close')) + self.run_loop_once() + self.assertOneFrameSent(*self.close_frame) + + # Now close_task is waiting for worker_task which is waiting for the + # closing handshake to complete. + + # Cancelling close_task throws a CancelledError into worker_task, + # which catches that exception and waits for close_connection(). + self.loop.call_later(MS, close_task.cancel) + # close_task resumes waiting for worker_task. Drop the connection so + # that close_connection(), worker_task and close_task terminate. + self.loop.call_later(2 * MS, self.receive_eof) + + # Make sure the worker task terminated before close(). + self.loop.run_until_complete(close_task) + self.assertTrue(self.protocol.worker_task.done()) + class ServerTests(CommonTests, unittest.TestCase): From 570661453b7c69d133c0d97f45ebc09809d0ee87 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Sep 2017 21:59:47 +0200 Subject: [PATCH 0310/1539] Reduce verbosity of fail connection log. "Failing the WebSocket connection" is an essential concept in the RFC, but it's only relevant for developers of the library, not end users. Fix #167. --- docs/changelog.rst | 2 ++ websockets/protocol.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 016a81f6c..88dd14e54 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -22,6 +22,8 @@ Also: * Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on a connection while it's being closed. +* Reduced verbosity of "Failing the WebSocket connection" logs. + 3.4 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index 854081ff0..89862e884 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -649,7 +649,7 @@ def close_connection(self, force=False): @asyncio.coroutine def fail_connection(self, code=1011, reason=''): # 7.1.7. Fail the WebSocket Connection - logger.info("Failing the WebSocket connection: %d %s", code, reason) + logger.debug("Failing the WebSocket connection: %d %s", code, reason) if self.state == OPEN: if code == 1006: # Don't send a close frame if the connection is broken. Set From 32081781e99d9a8244c262da838f5c2b0bbf6b54 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Sep 2017 22:57:53 +0200 Subject: [PATCH 0311/1539] Support asynchronous iteration of protocol instances. Kudos to @cjerdonek for the idea and proof of concept. Fix #192. Ref #195. --- MANIFEST.in | 1 + docs/changelog.rst | 3 + docs/intro.rst | 9 +++ setup.py | 3 + websockets/client.py | 6 +- websockets/protocol.py | 19 +++++ ...lient_server.py => _test_client_server.py} | 4 +- websockets/py36/__init__.py | 2 + websockets/py36/_test_client_server.py | 81 +++++++++++++++++++ websockets/py36/protocol.py | 12 +++ websockets/server.py | 5 +- websockets/test_client_server.py | 12 +-- 12 files changed, 144 insertions(+), 13 deletions(-) rename websockets/py35/{client_server.py => _test_client_server.py} (92%) create mode 100644 websockets/py36/__init__.py create mode 100644 websockets/py36/_test_client_server.py create mode 100644 websockets/py36/protocol.py diff --git a/MANIFEST.in b/MANIFEST.in index 09205fb4b..9f4f1787e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ include LICENSE graft websockets/py35 +graft websockets/py36 diff --git a/docs/changelog.rst b/docs/changelog.rst index 88dd14e54..dfa9c3346 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -19,6 +19,9 @@ Changelog Also: +* :class:`~websockets.protocol.WebSocketCommonProtocol` instances can be used + as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. + * Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on a connection while it's being closed. diff --git a/docs/intro.rst b/docs/intro.rst index be27aa3e8..0b71f571c 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -59,6 +59,15 @@ Consumer For receiving messages and passing them to a ``consumer`` coroutine:: + async def consumer_handler(websocket, path): + async for message in websocket: + await consumer(message) + +Iteration terminates when the client disconnects. + +Asynchronous iteration isn't available in Python < 3.6; here's the same code +for earlier Python versions:: + async def consumer_handler(websocket, path): while True: message = await websocket.recv() diff --git a/setup.py b/setup.py index 0911d0203..bdb8f85cd 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,9 @@ if py_version >= (3, 5): packages.append('websockets/py35') +if py_version >= (3, 6): + packages.append('websockets/py36') + ext_modules = [ setuptools.Extension( 'websockets.speedups', diff --git a/websockets/client.py b/websockets/client.py index 5e2eeb848..521a3c66e 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -283,6 +283,9 @@ def connect(uri, *, It yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. + On Python ≥ 3.5, :func:`connect` can be used as a asynchronous context + manager. In that case, the connection is closed when exiting the context. + :func:`connect` is a wrapper around the event loop's :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. @@ -319,9 +322,6 @@ def connect(uri, *, invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening handshake fails. - On Python 3.5, :func:`connect` can be used as a asynchronous context - manager. In that case, the connection is closed when exiting the context. - """ if loop is None: loop = asyncio.get_event_loop() diff --git a/websockets/protocol.py b/websockets/protocol.py index 89862e884..71a14b14d 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -49,6 +49,17 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): control frames automatically. It sends outgoing data frames and performs the closing handshake. + On Python ≥ 3.6, :class:`WebSocketCommonProtocol` instances support + asynchronous iteration:: + + async for message in websocket: + await process(message) + + The iterator yields incoming messages. It exits normally when the + connection is closed with the status code 1000 OK. It raises a + :exc:`~websockets.exceptions.ConnectionClosed` exception when the + connection is closed with any other status code. + The ``host``, ``port`` and ``secure`` parameters are simply stored as attributes for handlers that need them. @@ -701,3 +712,11 @@ def connection_lost(self, exc): if self.writer is not None: self.writer.close() super().connection_lost(exc) + + +try: + from .py36.protocol import __aiter__ +except (SyntaxError, ImportError): # pragma: no cover + pass +else: + WebSocketCommonProtocol.__aiter__ = __aiter__ diff --git a/websockets/py35/client_server.py b/websockets/py35/_test_client_server.py similarity index 92% rename from websockets/py35/client_server.py rename to websockets/py35/_test_client_server.py index 624824aa1..1e69a8675 100644 --- a/websockets/py35/client_server.py +++ b/websockets/py35/_test_client_server.py @@ -1,14 +1,14 @@ # Tests containing Python 3.5+ syntax, extracted from test_client_server.py. -# To avoid test discovery, this module's name must not start with test_. import asyncio +import unittest from ..client import * from ..server import * from ..test_client_server import handler -class ClientServerContextManager: +class ContextManagerTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() diff --git a/websockets/py36/__init__.py b/websockets/py36/__init__.py new file mode 100644 index 000000000..396f34968 --- /dev/null +++ b/websockets/py36/__init__.py @@ -0,0 +1,2 @@ +# This package contains code using async iteratino added in Python 3.6. +# It cannot be imported on Python < 3.6 because it triggers syntax errors. diff --git a/websockets/py36/_test_client_server.py b/websockets/py36/_test_client_server.py new file mode 100644 index 000000000..cfa2760fe --- /dev/null +++ b/websockets/py36/_test_client_server.py @@ -0,0 +1,81 @@ +# Tests containing Python 3.6+ syntax, extracted from test_client_server.py. + +import asyncio +import sys +import unittest + +from ..client import * +from ..exceptions import ConnectionClosed +from ..server import * + + +# Fail at import time, not just at run time, to prevent test +# discovery. +if sys.version_info[:2] < (3, 6): # pragma: no cover + raise ImportError("Python 3.6+ only") + + +MESSAGES = ['3', '2', '1', 'Fire!'] + + +class AsyncIteratorTests(unittest.TestCase): + + # This is a protocol-level feature, but since it's a high-level API, it is + # much easier to exercise at the client or server level. + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_iterate_on_messages(self): + + async def handler(ws, path): + for message in MESSAGES: + await ws.send(message) + + server = serve(handler, 'localhost', 8642) + self.server = self.loop.run_until_complete(server) + + messages = [] + + async def run_client(): + nonlocal messages + async with connect('ws://localhost:8642/') as ws: + async for message in ws: + messages.append(message) + + self.loop.run_until_complete(run_client()) + + self.assertEqual(messages, MESSAGES) + + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + + def test_iterate_on_messages_exit_not_ok(self): + + async def handler(ws, path): + for message in MESSAGES: + await ws.send(message) + await ws.close(1001) + + server = serve(handler, 'localhost', 8642) + self.server = self.loop.run_until_complete(server) + + messages = [] + + async def run_client(): + nonlocal messages + async with connect('ws://localhost:8642/') as ws: + async for message in ws: + messages.append(message) + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(run_client()) + + self.assertEqual(messages, MESSAGES) + + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) diff --git a/websockets/py36/protocol.py b/websockets/py36/protocol.py new file mode 100644 index 000000000..37b7b3477 --- /dev/null +++ b/websockets/py36/protocol.py @@ -0,0 +1,12 @@ +from ..exceptions import ConnectionClosed + + +async def __aiter__(self): + try: + while True: + yield await self.recv() + except ConnectionClosed as exc: + if exc.code == 1000: + return + else: + raise diff --git a/websockets/server.py b/websockets/server.py index 6674da4ca..499d863f0 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -584,9 +584,8 @@ def serve(ws_handler, host=None, port=None, *, up by calling its :meth:`~websockets.server.WebSocketServer.close` and :meth:`~websockets.server.WebSocketServer.wait_closed` methods. - On Python 3.5 and greater, :func:`serve` can also be used as an - asynchronous context manager. In this case, the server is shut down - when exiting the context. + On Python ≥ 3.5, :func:`serve` can also be used as an asynchronous context + manager. In this case, the server is shut down when exiting the context. The ``ws_handler`` argument is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`WebSocketServerProtocol` diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 29b656347..97a700aec 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -856,10 +856,12 @@ def test_checking_lack_of_origin_succeeds(self): try: - from .py35.client_server import ClientServerContextManager + from .py35._test_client_server import ContextManagerTests # noqa +except (SyntaxError, ImportError): # pragma: no cover + pass + + +try: + from .py36._test_client_server import AsyncIteratorTests # noqa except (SyntaxError, ImportError): # pragma: no cover pass -else: - class ClientServerContextManagerTests(ClientServerContextManager, - unittest.TestCase): - pass From 3ecd5475ee3b39c9698e46e9b53418c74b1c720d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Sep 2017 13:58:09 +0200 Subject: [PATCH 0312/1539] Refactor asyncio.StreamReaderProtocol methods. * Move client_connected() next to __init__(): that's where it's * referenced. * Set write transport limits in connection_made(): that's more obvious. * Add comments and docstrings. --- websockets/protocol.py | 75 +++++++++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 71a14b14d..d92f3f795 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -148,7 +148,10 @@ def __init__(self, *, self.legacy_recv = legacy_recv - # This limit is both the line length limit and half the buffer limit. + # Configure read buffer limits. The high-water limit is defined by + # ``self.read_limit``. The ``limit`` argument controls the line length + # limit and half the buffer limit of :class:`~asyncio.StreamReader`. + # That's why it must be set to half of ``self.read_limit``. stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) super().__init__(stream_reader, self.client_connected, loop) @@ -192,6 +195,19 @@ def __init__(self, *, if self.state == OPEN: self.opening_handshake.set_result(True) + def client_connected(self, reader, writer): + """ + Callback for :class:`~asyncio.StreamReaderProtocol`. + + Record references to the stream reader and the stream writer to avoid + using private APIs``self._stream_reader`` and ``self._stream_writer``. + + """ + self.reader = reader + self.writer = writer + # Start the task that handles incoming messages. + self.worker_task = asyncio_ensure_future(self.run(), loop=self.loop) + # Public API @property @@ -674,32 +690,51 @@ def fail_connection(self, code=1011, reason=''): self.closing_handshake.set_result(False) yield from self.close_connection() - # asyncio StreamReaderProtocol methods + # asyncio.StreamReaderProtocol methods - def client_connected(self, reader, writer): - self.reader = reader - self.writer = writer - # Configure write buffer limit. - self.writer._transport.set_write_buffer_limits(self.write_limit) - # Start the task that handles incoming messages. - self.worker_task = asyncio_ensure_future(self.run(), loop=self.loop) + def connection_made(self, transport): + """ + Configure write buffer limits. + + The high-water limit is defined by ``self.write_limit``. + + The low-water limit currently defaults to ``self.write_limit // 4`` in + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should + be all right for reasonable use cases of this library. + + """ + transport.set_write_buffer_limits(self.write_limit) + super().connection_made(transport) def eof_received(self): + """ + Ensure the transport is closed after receiving EOF. + + Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns + ``True``. See http://bugs.python.org/issue24539 for details. + + This is inappropriate for websockets for at least three reasons: + + 1. The use case is to read data until EOF with self.reader.read(-1). + Since websockets is a TLV protocol, this never happens. + + 2. It doesn't work on SSL connections. A falsy value must be + returned to have the same behavior on SSL and plain connections. + + 3. The websockets protocol has its own closing handshake. Endpoints + close the TCP connection after sending a Close frame. + + As a consequence we revert to the previous, more useful behavior. + + """ super().eof_received() - # Since Python 3.5, StreamReaderProtocol.eof_received() returns True - # to leave the transport open (http://bugs.python.org/issue24539). - # This is inappropriate for websockets for at least three reasons. - # 1. The use case is to read data until EOF with self.reader.read(-1). - # Since websockets is a TLV protocol, this never happens. - # 2. It doesn't work on SSL connections. A falsy value must be - # returned to have the same behavior on SSL and plain connections. - # 3. The websockets protocol has its own closing handshake. Endpoints - # close the TCP connection after sending a Close frame. - # As a consequence we revert to the previous, more useful behavior. return def connection_lost(self, exc): - # 7.1.4. The WebSocket Connection is Closed + """ + Implement section 7.1.4. The WebSocket Connection is Closed. + + """ self.state = CLOSED if not self.opening_handshake.done(): self.opening_handshake.set_result(False) From 4ad3b1a7f9860e4dd7916b0f8a4385b6435225ac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Sep 2017 17:45:43 +0200 Subject: [PATCH 0313/1539] Only set close code when receiving close frame. The RFC is clear about this: > _The WebSocket Connection Close Code_ is defined as the status code > (Section 7.4) contained in the first Close control frame received by > the application implementing this protocol. Also: * Differentiate between closing and failing the connection in tests. * Remove a test for setting the close code based on a local close code. --- websockets/protocol.py | 6 ++-- websockets/test_protocol.py | 64 +++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index d92f3f795..078d9b994 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -544,10 +544,10 @@ def read_data_frame(self, max_size): if frame.opcode == OP_CLOSE: # Make sure the close frame is valid before echoing it. code, reason = parse_close(frame.data) + self.close_code, self.close_reason = code, reason if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started yield from self.write_frame(OP_CLOSE, frame.data) - self.close_code, self.close_reason = code, reason self.closing_handshake.set_result(True) return @@ -686,7 +686,6 @@ def fail_connection(self, code=1011, reason=''): frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) if not self.closing_handshake.done(): - self.close_code, self.close_reason = code, reason self.closing_handshake.set_result(False) yield from self.close_connection() @@ -738,8 +737,9 @@ def connection_lost(self, exc): self.state = CLOSED if not self.opening_handshake.done(): self.opening_handshake.set_result(False) + if self.close_code is None: + self.close_code = 1006 if not self.closing_handshake.done(): - self.close_code, self.close_reason = 1006, '' self.closing_handshake.set_result(False) if not self.connection_closed.done(): self.connection_closed.set_result(None) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 7faaddc37..3c1cff7e4 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -234,9 +234,23 @@ def assertNoFrameSent(self): def assertConnectionClosed(self, code, message): # The following line guarantees that connection_lost was called. self.assertEqual(self.protocol.state, CLOSED) + # A close frame was received. self.assertEqual(self.protocol.close_code, code) self.assertEqual(self.protocol.close_reason, message) + def assertConnectionFailed(self, code, message): + # The following line guarantees that connection_lost was called. + self.assertEqual(self.protocol.state, CLOSED) + # No close frame was received. + self.assertEqual(self.protocol.close_code, 1006) + self.assertEqual(self.protocol.close_reason, '') + # A close frame was sent -- unless the connection was already lost. + if code == 1006: + self.assertNoFrameSent() + else: + self.assertOneFrameSent( + True, OP_CLOSE, serialize_close(code, message)) + @contextlib.contextmanager def assertCompletesWithin(self, min_time, max_time): t0 = self.loop.time() @@ -307,24 +321,24 @@ def test_recv_on_closed_connection(self): def test_recv_protocol_error(self): self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8'))) self.process_invalid_frames() - self.assertConnectionClosed(1002, '') + self.assertConnectionFailed(1002, '') def test_recv_unicode_error(self): self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) self.process_invalid_frames() - self.assertConnectionClosed(1007, '') + self.assertConnectionFailed(1007, '') def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) self.process_invalid_frames() - self.assertConnectionClosed(1009, '') + self.assertConnectionFailed(1009, '') def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) self.process_invalid_frames() - self.assertConnectionClosed(1009, '') + self.assertConnectionFailed(1009, '') def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage @@ -346,7 +360,7 @@ def read_message(): self.process_invalid_frames() with self.assertRaises(Exception): self.loop.run_until_complete(self.protocol.worker_task) - self.assertConnectionClosed(1011, '') + self.assertConnectionFailed(1011, '') def test_recv_cancelled(self): recv = self.ensure_future(self.protocol.recv()) @@ -534,14 +548,14 @@ def test_fragmented_text_payload_too_big(self): self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) self.process_invalid_frames() - self.assertConnectionClosed(1009, '') + self.assertConnectionFailed(1009, '') def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) self.process_invalid_frames() - self.assertConnectionClosed(1009, '') + self.assertConnectionFailed(1009, '') def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage @@ -570,18 +584,22 @@ def test_unterminated_fragmented_text(self): # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b'tea')) self.process_invalid_frames() - self.assertConnectionClosed(1002, '') + self.assertConnectionFailed(1002, '') def test_close_handshake_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.receive_frame(Frame(True, OP_CLOSE, b'')) self.process_invalid_frames() + # The RFC may have overlooked this case: it says that control frames + # can be interjected in the middle of a fragmented message and that a + # close frame must be echoed. Even though there's an unterminated + # message, technically, the closing handshake was successful. self.assertConnectionClosed(1005, '') def test_connection_close_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) self.process_invalid_frames() - self.assertConnectionClosed(1006, '') + self.assertConnectionFailed(1006, '') # Test miscellaneous code paths to ensure full coverage. @@ -589,7 +607,7 @@ def test_connection_lost(self): # Test calling connection_lost without going through close_connection. self.protocol.connection_lost(None) - self.assertConnectionClosed(1006, '') + self.assertConnectionFailed(1006, '') def test_ensure_connection_before_opening_handshake(self): self.protocol.state = CONNECTING @@ -683,33 +701,17 @@ def test_close_protocol_error(self): invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') self.receive_frame(invalid_close_frame) self.receive_eof_if_client() + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1002, '') + self.assertConnectionFailed(1002, '') def test_close_connection_lost(self): self.receive_eof() + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1006, '') - - def test_remote_close_race_with_failing_connection(self): - self.make_drain_slow() - - # Fail the connection while answering a close frame from the client. - self.loop.call_soon(self.receive_frame, self.remote_close) - self.loop.call_later( - MS, self.ensure_future, self.protocol.fail_connection()) - # The client expects the server to close the connection. - # Simulate it instead of waiting for the connection timeout. - self.loop.call_later(MS, self.receive_eof_if_client) - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - # The closing handshake was completed by fail_connection. - self.assertConnectionClosed(1011, '') - self.assertOneFrameSent(*self.remote_close) + self.assertConnectionFailed(1006, '') def test_local_close_during_recv(self): recv = self.ensure_future(self.protocol.recv()) @@ -737,7 +739,7 @@ def test_remote_close_during_send(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(send) - self.assertConnectionClosed(1006, '') + self.assertConnectionClosed(1000, 'close') # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. From 5592994e8a8f5b418a1b6b2522b71ca7778fe102 Mon Sep 17 00:00:00 2001 From: Michael Sverdlik Date: Thu, 14 Sep 2017 06:30:50 +1000 Subject: [PATCH 0314/1539] Allow customizing User-Agent and Server headers (#266) Fix #262. --- docs/changelog.rst | 2 ++ websockets/client.py | 4 +++- websockets/server.py | 6 ++++-- websockets/test_client_server.py | 16 ++++++++++++++++ 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index dfa9c3346..d41898abe 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -27,6 +27,8 @@ Also: * Reduced verbosity of "Failing the WebSocket connection" logs. +* Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. + 3.4 ... diff --git a/websockets/client.py b/websockets/client.py index 521a3c66e..c801c3b74 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -214,6 +214,7 @@ def handshake(self, wsuri, origin=None, """ request_headers = [] set_header = lambda k, v: request_headers.append((k, v)) + is_header_set = lambda k: k in dict(request_headers).keys() if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover set_header('Host', wsuri.host) @@ -243,7 +244,8 @@ def handshake(self, wsuri, origin=None, for name, value in extra_headers: set_header(name, value) - set_header('User-Agent', USER_AGENT) + if not is_header_set('User-Agent'): + set_header('User-Agent', USER_AGENT) key = build_request(set_header) diff --git a/websockets/server.py b/websockets/server.py index 499d863f0..ac087cfd3 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -435,8 +435,7 @@ def handshake(self, origins=None, available_extensions=None, response_headers = [] set_header = lambda k, v: response_headers.append((k, v)) - - set_header('Server', USER_AGENT) + is_header_set = lambda k: k in dict(response_headers).keys() if extensions_header is not None: set_header('Sec-WebSocket-Extensions', extensions_header) @@ -452,6 +451,9 @@ def handshake(self, origins=None, available_extensions=None, for name, value in extra_headers: set_header(name, value) + if not is_header_set('Server'): + set_header('Server', USER_AGENT) + build_response(set_header, key) yield from self.write_http_response( diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 97a700aec..06c689947 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -288,6 +288,14 @@ def test_protocol_custom_request_headers_list(self): self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) + @with_server() + @with_client('raw_headers', extra_headers=[('User-Agent', 'Eggs')]) + def test_protocol_custom_request_user_agent(self): + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertEqual(req_headers.count("User-Agent"), 1) + self.assertIn("('User-Agent', 'Eggs')", req_headers) + @with_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) @with_client('raw_headers') def test_protocol_custom_response_headers_callable_dict(self): @@ -316,6 +324,14 @@ def test_protocol_custom_response_headers_list(self): resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) + @with_server(extra_headers=[('Server', 'Eggs')]) + @with_client('raw_headers') + def test_protocol_custom_response_user_agent(self): + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(resp_headers.count("Server"), 1) + self.assertIn("('Server', 'Eggs')", resp_headers) + @with_server(create_protocol=HealthCheckServerProtocol) @with_client() def test_custom_protocol_http_request(self): From fa59b1648f80f950a15f6a5d9d75f1cd67910ac0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Sep 2017 13:51:54 +0200 Subject: [PATCH 0315/1539] Add debug logs on asyncio Protocol callbacks. data_received isn't logged because that would duplicate the logs of incoming frames, which are more useful. Also: * Increase symmetry between the client and server side, * Improve some docstrings. --- websockets/client.py | 1 + websockets/protocol.py | 45 +++++++++++++++++++++---------------- websockets/server.py | 2 ++ websockets/test_protocol.py | 6 +++++ 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index c801c3b74..1b0b07e70 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -32,6 +32,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ is_client = True + side = 'client' state = CONNECTING def __init__(self, *, diff --git a/websockets/protocol.py b/websockets/protocol.py index 078d9b994..08743abbd 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -120,10 +120,9 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): """ # There are only two differences between the client-side and the server- # side behavior: masking the payload and closing the underlying TCP - # connection. This class implements the server-side behavior by default. - # To get the client-side behavior, set is_client = True. - - is_client = False + # connection. Set is_client and side to pick a side. + is_client = None + side = 'undefined' state = OPEN def __init__(self, *, @@ -577,10 +576,7 @@ def read_frame(self, max_size): max_size=max_size, extensions=self.extensions, ) - logger.debug( - "%s << %s", - 'client' if self.is_client else 'server', frame, - ) + logger.debug("%s < %s", self.side, frame) return frame @asyncio.coroutine @@ -596,10 +592,7 @@ def write_frame(self, opcode, data=b''): self.state = CLOSING frame = Frame(True, opcode, data) - logger.debug( - "%s >> %s", - 'client' if self.is_client else 'server', frame, - ) + logger.debug("%s > %s", self.side, frame) frame.write( self.writer.write, mask=self.is_client, @@ -675,8 +668,14 @@ def close_connection(self, force=False): @asyncio.coroutine def fail_connection(self, code=1011, reason=''): - # 7.1.7. Fail the WebSocket Connection - logger.debug("Failing the WebSocket connection: %d %s", code, reason) + """ + 7.1.7. Fail the WebSocket Connection + + """ + logger.debug( + "%s ! failing WebSocket connection: %d %s", + self.side, code, reason, + ) if self.state == OPEN: if code == 1006: # Don't send a close frame if the connection is broken. Set @@ -701,24 +700,30 @@ def connection_made(self, transport): :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should be all right for reasonable use cases of this library. + This is the earliest point where we can get hold of the transport, + which means it's the best point for configuring it. + """ + logger.debug("%s - connection_made(%s)", self.side, transport) transport.set_write_buffer_limits(self.write_limit) super().connection_made(transport) def eof_received(self): """ - Ensure the transport is closed after receiving EOF. + Close the transport after receiving EOF. Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns - ``True``. See http://bugs.python.org/issue24539 for details. + ``True`` on non-TLS connections. + + See http://bugs.python.org/issue24539 for more information. This is inappropriate for websockets for at least three reasons: 1. The use case is to read data until EOF with self.reader.read(-1). Since websockets is a TLV protocol, this never happens. - 2. It doesn't work on SSL connections. A falsy value must be - returned to have the same behavior on SSL and plain connections. + 2. It doesn't work on TLS connections. A falsy value must be + returned to have the same behavior on TLS and plain connections. 3. The websockets protocol has its own closing handshake. Endpoints close the TCP connection after sending a Close frame. @@ -726,14 +731,16 @@ def eof_received(self): As a consequence we revert to the previous, more useful behavior. """ + logger.debug("%s - eof_received()", self.side) super().eof_received() return def connection_lost(self, exc): """ - Implement section 7.1.4. The WebSocket Connection is Closed. + 7.1.4. The WebSocket Connection is Closed. """ + logger.debug("%s - connection_lost(%s)", self.side, exc) self.state = CLOSED if not self.opening_handshake.done(): self.opening_handshake.set_result(False) diff --git a/websockets/server.py b/websockets/server.py index ac087cfd3..059847f2b 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -40,6 +40,8 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): Its support for HTTP responses is very limited. """ + is_client = False + side = 'server' state = CONNECTING def __init__(self, ws_handler, ws_server, *, diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 3c1cff7e4..20bc8fa89 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -769,6 +769,11 @@ def test_cancelled_close_waits_for_worker(self): class ServerTests(CommonTests, unittest.TestCase): + def setUp(self): + super().setUp() + self.protocol.is_client = False + self.protocol.side = 'server' + def test_close_handshake_timeout(self): # Timeout is expected in 10ms. self.protocol.timeout = 10 * MS @@ -786,6 +791,7 @@ class ClientTests(CommonTests, unittest.TestCase): def setUp(self): super().setUp() self.protocol.is_client = True + self.protocol.side = 'client' def test_close_handshake_timeout(self): # Timeout is expected in 2 * 10 = 20ms. From 303b35590b080cf27e026cd45c42ee5d6c12d7cb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Sep 2017 14:08:23 +0200 Subject: [PATCH 0316/1539] Factor out test setup. --- websockets/test_client_server.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 06c689947..64d168ef6 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -187,7 +187,8 @@ def start_server(self, **kwds): def start_client(self, path='', **kwds): # Don't enable compression by default in tests. kwds.setdefault('compression', None) - client = connect('ws://localhost:8642/' + path, **kwds) + proto = 'ws' if kwds.get('ssl') is None else 'wss' + client = connect(proto + '://localhost:8642/' + path, **kwds) self.client = self.loop.run_until_complete(client) def stop_client(self): @@ -804,19 +805,13 @@ def client_context(self): ssl_context.verify_mode = ssl.CERT_REQUIRED return ssl_context - def start_server(self, *args, **kwds): + def start_server(self, **kwds): kwds.setdefault('ssl', self.server_context) - # Don't enable compression by default in tests. - kwds.setdefault('compression', None) - server = serve(handler, 'localhost', 8642, **kwds) - self.server = self.loop.run_until_complete(server) + super().start_server(**kwds) def start_client(self, path='', **kwds): kwds.setdefault('ssl', self.client_context) - # Don't enable compression by default in tests. - kwds.setdefault('compression', None) - client = connect('wss://localhost:8642/' + path, **kwds) - self.client = self.loop.run_until_complete(client) + super().start_client(path, **kwds) @with_server() def test_ws_uri_is_rejected(self): From f0981a589ab94971274a2315b4588fff5a0a5e12 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 19 Sep 2017 22:00:18 +0200 Subject: [PATCH 0317/1539] Review coding style in framing module. Also in related tests. --- websockets/framing.py | 42 ++++++++------- websockets/test_framing.py | 103 ++++++++++++++++++++++++++----------- 2 files changed, 98 insertions(+), 47 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index 5c007d6f5..9a72fcd56 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -92,14 +92,14 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): If ``extensions`` is provided, it's a list of classes with an ``decode()`` method that transform the frame and return a new frame. - They are applied in order. + They are applied in reverse order. This function validates the frame before returning it and raises :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains incorrect values. """ - # Read the header + # Read the header. data = yield from reader(2) head1, head2 = struct.unpack('!BB', data) @@ -121,12 +121,13 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): data = yield from reader(8) length, = struct.unpack('!Q', data) if max_size is not None and length > max_size: - raise PayloadTooBig("Payload exceeds limit " - "({} > {} bytes)".format(length, max_size)) + raise PayloadTooBig( + "Payload length exceeds limit ({} > {} bytes)" + .format(length, max_size)) if mask: mask_bits = yield from reader(4) - # Read the data + # Read the data. data = yield from reader(length) if mask: data = apply_mask(data, mask_bits) @@ -162,12 +163,11 @@ def write(frame, writer, *, mask, extensions=None): incorrect values. """ - - frame.check() - # The first parameter is called `frame` rather than `self`, # but it's the instance of class to which this method is bound. + frame.check() + if extensions is None: extensions = [] for extension in extensions: @@ -175,7 +175,7 @@ def write(frame, writer, *, mask, extensions=None): output = io.BytesIO() - # Prepare the header + # Prepare the header. head1 = ( (0b10000000 if frame.fin else 0) | (0b01000000 if frame.rsv1 else 0) | @@ -198,23 +198,27 @@ def write(frame, writer, *, mask, extensions=None): mask_bits = struct.pack('!I', random.getrandbits(32)) output.write(mask_bits) - # Prepare the data + # Prepare the data. if mask: data = apply_mask(frame.data, mask_bits) else: data = frame.data output.write(data) - # Send the frame + # Send the frame. + + # The frame is written in a single call to writer in order to prevent + # TCP fragmentation. See #68 for details. writer(output.getvalue()) def check(frame): """ - Raise :exc:`~websockets.exceptions.WebSocketProtocolError` if the frame - contains incorrect values. + Check that this frame contains acceptable values. - """ + Raise :exc:`~websockets.exceptions.WebSocketProtocolError` if this + frame contains incorrect values. + """ # The first parameter is called `frame` rather than `self`, # but it's the instance of class to which this method is bound. @@ -229,7 +233,8 @@ def check(frame): if not frame.fin: raise WebSocketProtocolError("Fragmented control frame") else: - raise WebSocketProtocolError("Invalid opcode") + raise WebSocketProtocolError( + "Invalid opcode ({})".format(frame.opcode)) def parse_close(data): @@ -244,14 +249,15 @@ def parse_close(data): """ length = len(data) - if length == 0: - return 1005, '' - elif length >= 2: + if length >= 2: code, = struct.unpack('!H', data[:2]) check_close(code) reason = data[2:].decode('utf-8') return code, reason + elif length == 0: + return 1005, '' else: + assert length == 1 raise WebSocketProtocolError("Close frame too short") diff --git a/websockets/test_framing.py b/websockets/test_framing.py index 04f8acda3..4271b5a18 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -16,27 +16,30 @@ def setUp(self): def tearDown(self): self.loop.close() - def decode(self, message, mask=False, max_size=None, extensions=()): + def decode(self, message, mask=False, max_size=None, extensions=None): self.stream = asyncio.StreamReader(loop=self.loop) self.stream.feed_data(message) self.stream.feed_eof() frame = self.loop.run_until_complete(Frame.read( self.stream.readexactly, mask=mask, - max_size=max_size, extensions=extensions)) + max_size=max_size, extensions=extensions, + )) # Make sure all the data was consumed. self.assertTrue(self.stream.at_eof()) return frame - def encode(self, frame, mask=False, extensions=()): + def encode(self, frame, mask=False, extensions=None): writer = unittest.mock.Mock() frame.write(writer, mask=mask, extensions=extensions) # Ensure the entire frame is sent with a single call to writer(). # Multiple calls cause TCP fragmentation and degrade performance. self.assertEqual(writer.call_count, 1) # The frame data is the single positional argument of that call. + self.assertEqual(len(writer.call_args[0]), 1) + self.assertEqual(len(writer.call_args[1]), 0) return writer.call_args[0][0] - def round_trip(self, message, expected, mask=False, extensions=()): + def round_trip(self, message, expected, mask=False, extensions=None): decoded = self.decode(message, mask, extensions=extensions) self.assertEqual(decoded, expected) encoded = self.encode(decoded, mask, extensions=extensions) @@ -53,82 +56,123 @@ def round_trip_close(self, data, code, reason): self.assertEqual(serialized, data) def test_text(self): - self.round_trip(b'\x81\x04Spam', Frame(True, OP_TEXT, b'Spam')) + self.round_trip( + b'\x81\x04Spam', + Frame(True, OP_TEXT, b'Spam'), + ) def test_text_masked(self): self.round_trip( b'\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5', - Frame(True, OP_TEXT, b'Spam'), mask=True) + Frame(True, OP_TEXT, b'Spam'), + mask=True, + ) def test_binary(self): - self.round_trip(b'\x82\x04Eggs', Frame(True, OP_BINARY, b'Eggs')) + self.round_trip( + b'\x82\x04Eggs', + Frame(True, OP_BINARY, b'Eggs'), + ) def test_binary_masked(self): self.round_trip( b'\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa', - Frame(True, OP_BINARY, b'Eggs'), mask=True) + Frame(True, OP_BINARY, b'Eggs'), + mask=True, + ) def test_non_ascii_text(self): self.round_trip( b'\x81\x05caf\xc3\xa9', - Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + Frame(True, OP_TEXT, 'café'.encode('utf-8')), + ) def test_non_ascii_text_masked(self): self.round_trip( b'\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd', - Frame(True, OP_TEXT, 'café'.encode('utf-8')), mask=True) + Frame(True, OP_TEXT, 'café'.encode('utf-8')), + mask=True, + ) def test_close(self): - self.round_trip(b'\x88\x00', Frame(True, OP_CLOSE, b'')) + self.round_trip( + b'\x88\x00', + Frame(True, OP_CLOSE, b''), + ) def test_ping(self): - self.round_trip(b'\x89\x04ping', Frame(True, OP_PING, b'ping')) + self.round_trip( + b'\x89\x04ping', + Frame(True, OP_PING, b'ping'), + ) def test_pong(self): - self.round_trip(b'\x8a\x04pong', Frame(True, OP_PONG, b'pong')) + self.round_trip( + b'\x8a\x04pong', + Frame(True, OP_PONG, b'pong'), + ) def test_long(self): self.round_trip( b'\x82\x7e\x00\x7e' + 126 * b'a', - Frame(True, OP_BINARY, 126 * b'a')) + Frame(True, OP_BINARY, 126 * b'a'), + ) def test_very_long(self): self.round_trip( b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + 65536 * b'a', - Frame(True, OP_BINARY, 65536 * b'a')) + Frame(True, OP_BINARY, 65536 * b'a'), + ) def test_payload_too_big(self): with self.assertRaises(PayloadTooBig): - self.decode(b'\x82\x7e\x04\x01' + 1025 * b'a', max_size=1024) + self.decode( + b'\x82\x7e\x04\x01' + 1025 * b'a', + max_size=1024, + ) def test_bad_reserved_bits(self): - with self.assertRaises(WebSocketProtocolError): - self.decode(b'\xc0\x00') - with self.assertRaises(WebSocketProtocolError): - self.decode(b'\xa0\x00') - with self.assertRaises(WebSocketProtocolError): - self.decode(b'\x90\x00') + for encoded in [b'\xc0\x00', b'\xa0\x00', b'\x90\x00']: + with self.subTest(encoded=encoded): + with self.assertRaises(WebSocketProtocolError): + self.decode(encoded) - def test_bad_opcode(self): + def test_good_opcode(self): for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0b)): - self.decode(bytes([0x80 | opcode, 0])) + encoded = bytes([0x80 | opcode, 0]) + with self.subTest(encoded=encoded): + self.decode(encoded) # does not raise an exception + + def test_bad_opcode(self): for opcode in list(range(0x03, 0x08)) + list(range(0x0b, 0x10)): - with self.assertRaises(WebSocketProtocolError): - self.decode(bytes([0x80 | opcode, 0])) + encoded = bytes([0x80 | opcode, 0]) + with self.subTest(encoded=encoded): + with self.assertRaises(WebSocketProtocolError): + self.decode(encoded) - def test_bad_mask_flag(self): + def test_mask_flag(self): + # Mask flag correctly set. self.decode(b'\x80\x80\x00\x00\x00\x00', mask=True) + # Mask flag incorrectly unset. with self.assertRaises(WebSocketProtocolError): self.decode(b'\x80\x80\x00\x00\x00\x00') + # Mask flag correctly unset. self.decode(b'\x80\x00') + # Mask flag incorrectly set. with self.assertRaises(WebSocketProtocolError): self.decode(b'\x80\x00', mask=True) - def test_control_frame_too_long(self): + def test_control_frame_max_length(self): + # At maximum allowed length. + self.decode(b'\x88\x7e\x00\x7d' + 125 * b'a') + # Above maximum allowed length. with self.assertRaises(WebSocketProtocolError): self.decode(b'\x88\x7e\x00\x7e' + 126 * b'a') def test_fragmented_control_frame(self): + # Fin bit correctly set. + self.decode(b'\x88\x00') + # Fin bit incorrectly unset. with self.assertRaises(WebSocketProtocolError): self.decode(b'\x08\x00') @@ -168,4 +212,5 @@ def encode(frame): self.round_trip( b'\x81\x05uryyb', Frame(True, OP_TEXT, b'hello'), - extensions=[Rot13()]) + extensions=[Rot13()], + ) From 73a11ed80cab88ea24ec44b5f272590b8c4719e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 19 Sep 2017 22:18:05 +0200 Subject: [PATCH 0318/1539] Review http module and tests. --- websockets/http.py | 2 +- websockets/test_http.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/websockets/http.py b/websockets/http.py index 9d2316b70..99eef8482 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -153,7 +153,7 @@ def read_headers(stream): # We don't attempt to support obsolete line folding. headers = [] - for _ in range(MAX_HEADERS): + for _ in range(MAX_HEADERS + 1): line = yield from read_line(stream) if line == b'\r\n': break diff --git a/websockets/test_http.py b/websockets/test_http.py index a891ad5ea..a6d61299b 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -30,9 +30,10 @@ def test_read_request(self): b'Sec-WebSocket-Version: 13\r\n' b'\r\n' ) - path, hdrs = self.loop.run_until_complete(read_request(self.stream)) + path, headers = self.loop.run_until_complete( + read_request(self.stream)) self.assertEqual(path, '/chat') - self.assertEqual(dict(hdrs)['Upgrade'], 'websocket') + self.assertEqual(dict(headers)['Upgrade'], 'websocket') def test_read_response(self): # Example from the protocol overview in RFC 6455 @@ -85,12 +86,13 @@ def test_header_value(self): self.loop.run_until_complete(read_headers(self.stream)) def test_headers_limit(self): - self.stream.feed_data(b'foo: bar\r\n' * 500 + b'\r\n') + self.stream.feed_data(b'foo: bar\r\n' * 257 + b'\r\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) def test_line_limit(self): - self.stream.feed_data(b'a' * 5000 + b'\r\n\r\n') + # Header line contains 5 + 4090 + 2 = 4097 bytes. + self.stream.feed_data(b'foo: ' + b'a' * 4090 + b'\r\n\r\n') with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) From b893f89a530acf5f70d7e77499eedfb0285ab8b4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 19 Sep 2017 22:22:54 +0200 Subject: [PATCH 0319/1539] Review handshake module and tests. --- websockets/handshake.py | 2 +- websockets/test_handshake.py | 36 ++++++++++++++++++------------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/websockets/handshake.py b/websockets/handshake.py index cb6d742a6..b84325001 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -46,7 +46,7 @@ 'build_response', 'check_response', ] -GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' def build_request(set_header): diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index 2642d3855..fee35bc9d 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -24,7 +24,7 @@ def test_round_trip(self): check_response(response_headers.__getitem__, request_key) @contextlib.contextmanager - def assert_invalid_request_headers(self): + def assertInvalidRequestHeaders(self): """ Provide request headers for corruption. @@ -38,47 +38,47 @@ def assert_invalid_request_headers(self): check_request(headers.__getitem__) def test_request_invalid_upgrade(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: headers['Upgrade'] = 'socketweb' def test_request_missing_upgrade(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: del headers['Upgrade'] def test_request_invalid_connection(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: headers['Connection'] = 'Downgrade' def test_request_missing_connection(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: del headers['Connection'] def test_request_invalid_key_not_base64(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: headers['Sec-WebSocket-Key'] = "!@#$%^&*()" def test_request_invalid_key_not_well_padded(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: headers['Sec-WebSocket-Key'] = "CSIRmL8dWYxeAdr/XpEHRw" def test_request_invalid_key_not_16_bytes_long(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: headers['Sec-WebSocket-Key'] = "ZLpprpvK4PE=" def test_request_missing_key(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: del headers['Sec-WebSocket-Key'] def test_request_invalid_version(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: headers['Sec-WebSocket-Version'] = '42' def test_request_missing_version(self): - with self.assert_invalid_request_headers() as headers: + with self.assertInvalidRequestHeaders() as headers: del headers['Sec-WebSocket-Version'] @contextlib.contextmanager - def assert_invalid_response_headers(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): + def assertInvalidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): """ Provide response headers for corruption. @@ -92,26 +92,26 @@ def assert_invalid_response_headers(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): check_response(headers.__getitem__, key) def test_response_invalid_upgrade(self): - with self.assert_invalid_response_headers() as headers: + with self.assertInvalidResponseHeaders() as headers: headers['Upgrade'] = 'socketweb' def test_response_missing_upgrade(self): - with self.assert_invalid_response_headers() as headers: + with self.assertInvalidResponseHeaders() as headers: del headers['Upgrade'] def test_response_invalid_connection(self): - with self.assert_invalid_response_headers() as headers: + with self.assertInvalidResponseHeaders() as headers: headers['Connection'] = 'Downgrade' def test_response_missing_connection(self): - with self.assert_invalid_response_headers() as headers: + with self.assertInvalidResponseHeaders() as headers: del headers['Connection'] def test_response_invalid_accept(self): - with self.assert_invalid_response_headers() as headers: + with self.assertInvalidResponseHeaders() as headers: other_key = "1Eq4UDEFQYg3YspNgqxv5g==" headers['Sec-WebSocket-Accept'] = accept(other_key) def test_response_missing_accept(self): - with self.assert_invalid_response_headers() as headers: + with self.assertInvalidResponseHeaders() as headers: del headers['Sec-WebSocket-Accept'] From d37433902cc959b050925581d6e9bd361262e9be Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 19 Sep 2017 22:24:19 +0200 Subject: [PATCH 0320/1539] Review uri module and tests. --- websockets/uri.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/websockets/uri.py b/websockets/uri.py index 48c39c1a5..6d9f9124c 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -15,7 +15,7 @@ __all__ = ['parse_uri', 'WebSocketURI'] WebSocketURI = collections.namedtuple( - 'WebSocketURI', ('secure', 'host', 'port', 'resource_name')) + 'WebSocketURI', ['secure', 'host', 'port', 'resource_name']) WebSocketURI.__doc__ = """WebSocket URI. * ``secure`` is the secure flag @@ -37,7 +37,7 @@ def parse_uri(uri): """ uri = urllib.parse.urlparse(uri) try: - assert uri.scheme in ('ws', 'wss') + assert uri.scheme in ['ws', 'wss'] assert uri.params == '' assert uri.fragment == '' assert uri.username is None From 7d7e2630188f965999908085da9b85293517174f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 Sep 2017 22:00:29 +0200 Subject: [PATCH 0321/1539] Log unexpected exceptions at the warning level. This is consistent with how errors in the opening handshake are handled. --- websockets/protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/websockets/protocol.py b/websockets/protocol.py index 08743abbd..98adf5ebb 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -475,6 +475,7 @@ def run(self): except PayloadTooBig: yield from self.fail_connection(1009) except Exception: + logger.warning("Error in data transfer", exc_info=True) yield from self.fail_connection(1011) raise yield from self.close_connection() From 1b8ab39be670b3e8f46839caafb6824133cbd59b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 Sep 2017 22:07:45 +0200 Subject: [PATCH 0322/1539] Handle cancellation of pong waiters more elegantly. Fix #269. --- websockets/protocol.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 98adf5ebb..600da9b45 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -399,8 +399,10 @@ def ping(self, data=None): data = struct.pack('!I', random.getrandbits(32)) self.pings[data] = asyncio.Future(loop=self.loop) + yield from self.write_frame(OP_PING, data) - return self.pings[data] + + return asyncio.shield(self.pings[data]) @asyncio.coroutine def pong(self, data=b''): @@ -561,9 +563,8 @@ def read_data_frame(self, max_size): # Acknowledge all pings up to the one matching this pong. ping_id = None while ping_id != frame.data: - ping_id, waiter = self.pings.popitem(0) - if not waiter.cancelled(): - waiter.set_result(None) + ping_id, pong_waiter = self.pings.popitem(0) + pong_waiter.set_result(None) # 5.6. Data Frames else: From 16b09f801a010b17759cff5a2a99828a2c7f7cae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 Sep 2017 22:08:08 +0200 Subject: [PATCH 0323/1539] Fix documentation of how to wait for a pong. --- docs/cheatsheet.rst | 3 ++- websockets/protocol.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 21509acae..f73812a58 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -86,7 +86,8 @@ idle connections after some time:: except asyncio.TimeoutError: # No data in 20 seconds, check the connection. try: - await asyncio.wait_for(ws.ping(), timeout=10) + pong_waiter = await ws.ping() + await asyncio.wait_for(pong_waiter, timeout=10) except asyncio.TimeoutError: # No response to ping in 10 seconds, disconnect. break diff --git a/websockets/protocol.py b/websockets/protocol.py index 600da9b45..0c3f5a228 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -378,7 +378,10 @@ def ping(self, data=None): want to wait. A ping may serve as a keepalive or as a check that the remote endpoint - received all messages up to this point, with ``yield from ws.ping()``. + received all messages up to this point:: + + pong_waiter = await ws.ping() + await pong_waiter # only if you want to wait for the pong By default, the ping contains four random bytes. The content may be overridden with the optional ``data`` argument which must be of type From 9c67144b41ab90903eb6a0dc11658d6180455f4c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 24 Sep 2017 09:23:29 +0200 Subject: [PATCH 0324/1539] Update class and method signatures. They were out of sync after extensions were implemented. --- docs/api.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 302fe9829..742268f3f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,29 +32,29 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) .. autoclass:: WebSocketServer .. automethod:: close() .. automethod:: wait_closed() - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None) - .. automethod:: handshake(origins=None, subprotocols=None, extra_headers=None) + .. automethod:: handshake(origins=None, available_extensions=None, available_subprotocols=None, extra_headers=None) .. automethod:: process_request(path, request_headers) - .. automethod:: select_subprotocol(client_protos, server_protos) + .. automethod:: select_subprotocol(client_subprotocols, server_subprotocols) Client ...... .. automodule:: websockets.client - .. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) - .. automethod:: handshake(wsuri, origin=None, subprotocols=None, extra_headers=None) + .. automethod:: handshake(wsuri, origin=None, available_extensions=None, available_subprotocols=None, extra_headers=None) Shared ...... From 37b596dd17f0e4c58dfb15ec214dba557f781dd1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 6 Sep 2017 23:00:26 +0200 Subject: [PATCH 0325/1539] Get rid of self.closing_handshake. --- websockets/protocol.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 0c3f5a228..98ef68098 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -167,16 +167,14 @@ def __init__(self, *, self.extensions = [] self.subprotocol = None - # Code and reason must be set when the closing handshake completes. + # The close code and reason are set when receiving a close frame or + # losing the TCP connection. self.close_code = None self.close_reason = '' # Futures tracking steps in the connection's lifecycle. # Set to True when the opening handshake has completed properly. self.opening_handshake = asyncio.Future(loop=loop) - # Set to True when the closing handshake has completed properly and to - # False when the connection terminates abnormally. - self.closing_handshake = asyncio.Future(loop=loop) # Set to None when the connection state becomes CLOSED. self.connection_closed = asyncio.Future(loop=loop) @@ -463,7 +461,7 @@ def ensure_open(self): def run(self): # This coroutine guarantees that the connection is closed at exit. yield from self.opening_handshake - while not self.closing_handshake.done(): + while True: try: msg = yield from self.read_message() if msg is None: @@ -473,12 +471,16 @@ def run(self): break except WebSocketProtocolError: yield from self.fail_connection(1002) + break except asyncio.IncompleteReadError: yield from self.fail_connection(1006) + break except UnicodeDecodeError: yield from self.fail_connection(1007) + break except PayloadTooBig: yield from self.fail_connection(1009) + break except Exception: logger.warning("Error in data transfer", exc_info=True) yield from self.fail_connection(1011) @@ -553,7 +555,6 @@ def read_data_frame(self, max_size): if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started yield from self.write_frame(OP_CLOSE, frame.data) - self.closing_handshake.set_result(True) return elif frame.opcode == OP_PING: @@ -689,8 +690,6 @@ def fail_connection(self, code=1011, reason=''): else: frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) - if not self.closing_handshake.done(): - self.closing_handshake.set_result(False) yield from self.close_connection() # asyncio.StreamReaderProtocol methods @@ -751,8 +750,6 @@ def connection_lost(self, exc): self.opening_handshake.set_result(False) if self.close_code is None: self.close_code = 1006 - if not self.closing_handshake.done(): - self.closing_handshake.set_result(False) if not self.connection_closed.done(): self.connection_closed.set_result(None) # Close the transport in case close_connection() wasn't executed. From dd16c0e56289d0cd7593ce7c83399613e95f740e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Sep 2017 11:12:02 +0200 Subject: [PATCH 0326/1539] Get rid of self.opening_handshake. --- websockets/client.py | 7 ++----- websockets/protocol.py | 30 ++++++++++++++---------------- websockets/server.py | 8 ++------ websockets/test_protocol.py | 14 ++++++++++++++ 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 1b0b07e70..b400f94eb 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -16,7 +16,7 @@ parse_protocol_list ) from .http import USER_AGENT, build_headers, read_response -from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol +from .protocol import WebSocketCommonProtocol from .uri import parse_uri @@ -33,7 +33,6 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ is_client = True side = 'client' - state = CONNECTING def __init__(self, *, origin=None, extensions=None, subprotocols=None, @@ -267,9 +266,7 @@ def handshake(self, wsuri, origin=None, self.subprotocol = self.process_subprotocol( response_headers, available_subprotocols) - assert self.state == CONNECTING - self.state = OPEN - self.opening_handshake.set_result(True) + self.connection_open() @asyncio.coroutine diff --git a/websockets/protocol.py b/websockets/protocol.py index 98ef68098..7605827c1 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -123,7 +123,6 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): # connection. Set is_client and side to pick a side. is_client = None side = 'undefined' - state = OPEN def __init__(self, *, host=None, port=None, secure=None, @@ -158,6 +157,12 @@ def __init__(self, *, self.writer = None self._drain_lock = asyncio.Lock(loop=loop) + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. Subclasses + # implement the opening handshake and execute connection_open() to + # change the state to OPEN. + self.state = CONNECTING + self.path = None self.request_headers = None self.raw_request_headers = None @@ -172,9 +177,6 @@ def __init__(self, *, self.close_code = None self.close_reason = '' - # Futures tracking steps in the connection's lifecycle. - # Set to True when the opening handshake has completed properly. - self.opening_handshake = asyncio.Future(loop=loop) # Set to None when the connection state becomes CLOSED. self.connection_closed = asyncio.Future(loop=loop) @@ -184,26 +186,19 @@ def __init__(self, *, # Mapping of ping IDs to waiters, in chronological order. self.pings = collections.OrderedDict() - # Task managing the connection, initalized in self.client_connected. + # Task managing the connection, initialized in connection_open. self.worker_task = None - # In a subclass implementing the opening handshake, the state will be - # CONNECTING at this point. - if self.state == OPEN: - self.opening_handshake.set_result(True) - def client_connected(self, reader, writer): """ Callback for :class:`~asyncio.StreamReaderProtocol`. Record references to the stream reader and the stream writer to avoid - using private APIs``self._stream_reader`` and ``self._stream_writer``. + using private attributes ``_stream_reader`` and ``_stream_writer``. """ self.reader = reader self.writer = writer - # Start the task that handles incoming messages. - self.worker_task = asyncio_ensure_future(self.run(), loop=self.loop) # Public API @@ -434,6 +429,12 @@ def encode_data(self, data): else: raise TypeError("data must be bytes or str") + def connection_open(self): + assert self.state == CONNECTING + self.state = OPEN + # Start the task that handles incoming messages. + self.worker_task = asyncio_ensure_future(self.run(), loop=self.loop) + @asyncio.coroutine def ensure_open(self): # Raise a suitable exception if the connection isn't open. @@ -460,7 +461,6 @@ def ensure_open(self): @asyncio.coroutine def run(self): # This coroutine guarantees that the connection is closed at exit. - yield from self.opening_handshake while True: try: msg = yield from self.read_message() @@ -746,8 +746,6 @@ def connection_lost(self, exc): """ logger.debug("%s - connection_lost(%s)", self.side, exc) self.state = CLOSED - if not self.opening_handshake.done(): - self.opening_handshake.set_result(False) if self.close_code is None: self.close_code = 1006 if not self.connection_closed.done(): diff --git a/websockets/server.py b/websockets/server.py index 059847f2b..515ac508c 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -21,7 +21,7 @@ build_extension_list, parse_extension_list, parse_protocol_list ) from .http import USER_AGENT, build_headers, read_request -from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol +from .protocol import WebSocketCommonProtocol __all__ = ['serve', 'WebSocketServerProtocol'] @@ -42,7 +42,6 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ is_client = False side = 'server' - state = CONNECTING def __init__(self, ws_handler, ws_server, *, origins=None, extensions=None, subprotocols=None, @@ -118,7 +117,6 @@ def handler(self): ) yield from self.write_http_response(*early_response) - self.opening_handshake.set_result(False) yield from self.close_connection(force=True) return @@ -461,9 +459,7 @@ def handshake(self, origins=None, available_extensions=None, yield from self.write_http_response( SWITCHING_PROTOCOLS, response_headers) - assert self.state == CONNECTING - self.state = OPEN - self.opening_handshake.set_result(True) + self.connection_open() return path diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 20bc8fa89..db9637256 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -31,6 +31,9 @@ class TransportMock(unittest.mock.Mock): It calls the protocol's connection_made and connection_lost methods like actual transports. + It also calls the protocol's connection_open method to bypass the + WebSocket handshake. + To simulate incoming data, tests call the protocol's data_received and eof_received methods directly. @@ -43,7 +46,10 @@ def connect(self, loop, protocol): self.protocol = protocol # Remove when dropping support for Python < 3.6. self._closing = False + # Simulate a successful TCP handshake. self.loop.call_soon(self.protocol.connection_made, self) + # Simulate a successful WebSocket handshake. + self.loop.call_soon(self.protocol.connection_open) def close(self): # Remove when dropping support for Python < 3.6. @@ -285,11 +291,15 @@ def test_remote_address(self): get_extra_info.assert_called_with('peername', None) def test_open(self): + self.assertFalse(self.protocol.open) + self.run_loop_once() self.assertTrue(self.protocol.open) self.close_connection() self.assertFalse(self.protocol.open) def test_state_name(self): + self.assertEqual(self.protocol.state_name, 'CONNECTING') + self.run_loop_once() self.assertEqual(self.protocol.state_name, 'OPEN') self.close_connection() self.assertEqual(self.protocol.state_name, 'CLOSED') @@ -610,6 +620,10 @@ def test_connection_lost(self): self.assertConnectionFailed(1006, '') def test_ensure_connection_before_opening_handshake(self): + # Finalize the connection opening sequence. + self.run_loop_once() + + # Simulate a bug by forcibly reverting the protocol state. self.protocol.state = CONNECTING with self.assertRaises(InvalidState): From 0827eeee5bfe8b776dfae70642a336d324435027 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Sep 2017 16:44:16 +0200 Subject: [PATCH 0327/1539] Rework connection close handling. * Split the task that runs the data transfer part from the task that closes the TCP connection. * Explain why self.connection_closed is needed. * Fix usage of wait_for to avoid accidental cancellations. * Simplify connection_lost() because it's called exactly once. * Don't call close() immediately after write_eof() -- the goal is to wait until the other end closes the connection too and then close() in connection_lost(). * Refactor tests to reduce asynchrony. * Make the transport mock in tests more realistic. --- websockets/client.py | 2 +- websockets/protocol.py | 243 +++++++++++++++++++------------ websockets/server.py | 14 +- websockets/test_client_server.py | 14 +- websockets/test_protocol.py | 214 ++++++++++++++------------- 5 files changed, 278 insertions(+), 209 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index b400f94eb..70e74ecd2 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -374,7 +374,7 @@ def connect(uri, *, extra_headers=protocol.extra_headers, ) except Exception: - yield from protocol.close_connection(force=True) + yield from protocol.close_connection(after_handshake=False) raise return protocol diff --git a/websockets/protocol.py b/websockets/protocol.py index 7605827c1..cd8f3c9a5 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -66,7 +66,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``timeout`` parameter defines the maximum wait time in seconds for completing the closing handshake and, only on the client side, for terminating the TCP connection. :meth:`close()` will complete in at most - this time on the server side and twice this time on the client side. + ``2 * timeout`` on the server side and ``3 * timeout`` on the client side. The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1MB. ``None`` disables the limit. If a @@ -158,17 +158,19 @@ def __init__(self, *, self._drain_lock = asyncio.Lock(loop=loop) # This class implements the data transfer and closing handshake, which - # are shared between the client-side and the server-side. Subclasses - # implement the opening handshake and execute connection_open() to - # change the state to OPEN. + # are shared between the client-side and the server-side. + # Subclasses implement the opening handshake and, on success, execute + # :meth:`connection_open()` to change the state to OPEN. self.state = CONNECTING + # HTTP protocol parameters. self.path = None self.request_headers = None self.raw_request_headers = None self.response_headers = None self.raw_response_headers = None + # WebSocket protocol parameters. self.extensions = [] self.subprotocol = None @@ -177,8 +179,11 @@ def __init__(self, *, self.close_code = None self.close_reason = '' - # Set to None when the connection state becomes CLOSED. - self.connection_closed = asyncio.Future(loop=loop) + # Completed when the connection state becomes CLOSED. Translates the + # :meth:`connection_lost()` callback to a :class:`~asyncio.Future` + # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are + # translated by ``self.stream_reader``). + self.connection_lost_waiter = asyncio.Future(loop=loop) # Queue of received messages. self.messages = asyncio.queues.Queue(max_queue, loop=loop) @@ -186,8 +191,11 @@ def __init__(self, *, # Mapping of ping IDs to waiters, in chronological order. self.pings = collections.OrderedDict() - # Task managing the connection, initialized in connection_open. - self.worker_task = None + # Task running the data transfer. + self.transfer_data_task = None + + # Task closing the TCP connection. + self.close_connection_task = None def client_connected(self, reader, writer): """ @@ -275,18 +283,19 @@ def close(self, code=1000, reason=''): frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) - # If the connection doesn't terminate within the timeout, break out of - # the worker loop. + # If no close frame is received within the timeout, cancel the data + # transfer task in order to exit the infinite loop. try: yield from asyncio.wait_for( - self.worker_task, self.timeout, loop=self.loop) - except asyncio.CancelledError: + asyncio.shield(self.transfer_data_task), + self.timeout, loop=self.loop) + except (asyncio.CancelledError, asyncio.TimeoutError): pass - except asyncio.TimeoutError: - self.worker_task.cancel() + if self.close_code is None: + self.transfer_data_task.cancel() - # The worker should terminate quickly once it has been cancelled. - yield from self.worker_task + # Wait for the close connection task to close the TCP connection. + yield from asyncio.shield(self.close_connection_task) @asyncio.coroutine def recv(self): @@ -323,10 +332,10 @@ def recv(self): self.messages.get(), loop=self.loop) try: done, pending = yield from asyncio.wait( - [next_message, self.worker_task], + [next_message, self.transfer_data_task], loop=self.loop, return_when=asyncio.FIRST_COMPLETED) except asyncio.CancelledError: - # Handle the Task.cancel() + # Propagate cancellation to avoid leaking the next_message Task. next_message.cancel() raise @@ -432,8 +441,12 @@ def encode_data(self, data): def connection_open(self): assert self.state == CONNECTING self.state = OPEN - # Start the task that handles incoming messages. - self.worker_task = asyncio_ensure_future(self.run(), loop=self.loop) + # Start the task that receives incoming WebSocket messages. + self.transfer_data_task = asyncio_ensure_future( + self.transfer_data(), loop=self.loop) + # Start the task that eventually closes the TCP connection. + self.close_connection_task = asyncio_ensure_future( + self.close_connection(), loop=self.loop) @asyncio.coroutine def ensure_open(self): @@ -447,11 +460,12 @@ def ensure_open(self): raise ConnectionClosed(self.close_code, self.close_reason) # If the closing handshake is in progress, let it complete to get the - # proper close status and code. As an safety measure, the timeout is - # longer than the worst case (2 * self.timeout) but not unlimited. + # proper close status and code. As a safety measure, the timeout is + # longer than the worst case (3 * self.timeout) but not unlimited. if self.state == CLOSING: yield from asyncio.wait_for( - self.worker_task, 3 * self.timeout, loop=self.loop) + asyncio.shield(self.transfer_data_task), + 4 * self.timeout, loop=self.loop) raise ConnectionClosed(self.close_code, self.close_reason) # Control may only reach this point in buggy third-party subclasses. @@ -459,40 +473,38 @@ def ensure_open(self): raise InvalidState("WebSocket connection isn't established yet.") @asyncio.coroutine - def run(self): - # This coroutine guarantees that the connection is closed at exit. - while True: - try: + def transfer_data(self): + try: + while True: msg = yield from self.read_message() + # Exit the loop when receiving a close frame. if msg is None: break yield from self.messages.put(msg) - except asyncio.CancelledError: - break - except WebSocketProtocolError: - yield from self.fail_connection(1002) - break - except asyncio.IncompleteReadError: - yield from self.fail_connection(1006) - break - except UnicodeDecodeError: - yield from self.fail_connection(1007) - break - except PayloadTooBig: - yield from self.fail_connection(1009) - break - except Exception: - logger.warning("Error in data transfer", exc_info=True) - yield from self.fail_connection(1011) - raise - yield from self.close_connection() + except asyncio.CancelledError: + # This happens if self.close() cancels self.transfer_data_task. + pass + except WebSocketProtocolError: + yield from self.fail_connection(1002) + except asyncio.IncompleteReadError: + yield from self.fail_connection(1006) + except UnicodeDecodeError: + yield from self.fail_connection(1007) + except PayloadTooBig: + yield from self.fail_connection(1009) + except Exception: + logger.warning("Error in data transfer", exc_info=True) + yield from self.fail_connection(1011) @asyncio.coroutine def read_message(self): # Reassemble fragmented messages. frame = yield from self.read_data_frame(max_size=self.max_size) + + # A close frame was received. if frame is None: return + if frame.opcode == OP_TEXT: text = True elif frame.opcode == OP_BINARY: @@ -551,6 +563,8 @@ def read_data_frame(self, max_size): if frame.opcode == OP_CLOSE: # Make sure the close frame is valid before echoing it. code, reason = parse_close(frame.data) + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason self.close_code, self.close_reason = code, reason if self.state == OPEN: # 7.1.3. The WebSocket Closing Handshake is Started @@ -605,21 +619,10 @@ def write_frame(self, opcode, data=b''): extensions=self.extensions, ) - # Backport of the combined logic of: - # https://github.com/python/asyncio/pull/280 - # https://github.com/python/asyncio/pull/291 + # Backport of https://github.com/python/asyncio/pull/280. # Remove when dropping support for Python < 3.6. - transport = self.writer._transport - if transport is not None: # pragma: no cover - # PR 291 added the is_closing method to transports shortly after - # PR 280 fixed the bug we're trying to work around in this block. - if not hasattr(transport, 'is_closing'): - # This emulates what is_closing would return if it existed. - try: - is_closing = transport._closing - except AttributeError: - is_closing = transport._closed - if is_closing: + if self.writer.transport is not None: # pragma: no cover + if self.writer_is_closing(): yield try: @@ -635,42 +638,94 @@ def write_frame(self, opcode, data=b''): # And raise an exception, since the frame couldn't be sent. raise ConnectionClosed(self.close_code, self.close_reason) + def writer_is_closing(self): + """ + Backport of https://github.com/python/asyncio/pull/291. + + Replace with ``self.writer.transport.is_closing()`` when dropping + support for Python < 3.6 and with ``self.writer.is_closing()`` when + https://bugs.python.org/issue31491 is fixed. + + """ + transport = self.writer.transport + try: + return transport.is_closing() + except AttributeError: # pragma: no cover + # This emulates what is_closing would return if it existed. + try: + return transport._closing + except AttributeError: + return transport._closed + @asyncio.coroutine - def close_connection(self, force=False): - # 7.1.1. Close the WebSocket Connection - if self.state == CLOSED: - return + def wait_for_connection_lost(self): + """ + Wait until the TCP connection is closed or ``self.timeout`` elapses. - # Defensive assertion for protocol compliance. - if self.state != CLOSING and not force: # pragma: no cover - raise InvalidState("Cannot close a WebSocket connection " - "in the {} state".format(self.state_name)) + Return ``True`` if the connection is closed and ``False`` otherwise. - if self.is_client and not force: + """ + if not self.connection_lost_waiter.done(): try: yield from asyncio.wait_for( - self.connection_closed, self.timeout, loop=self.loop) - except (asyncio.CancelledError, asyncio.TimeoutError): + asyncio.shield(self.connection_lost_waiter), + self.timeout, loop=self.loop) + except asyncio.TimeoutError: pass + # Re-check self.connection_lost_waiter.done() synchronously because + # connection_lost() could run between the moment the timeout occurs + # and the moment this coroutine resumes running. + return self.connection_lost_waiter.done() - if self.state == CLOSED: - return + @asyncio.coroutine + def close_connection(self, after_handshake=True): + """ + 7.1.1. Close the WebSocket Connection - # Attempt to terminate the TCP connection properly. - # If the socket is already closed, this may crash. + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + + When the opening handshake fails, the client or the server runs this + coroutine with ``after_handshake=False`` to close the TCP connection. + + """ try: + # Wait for the data transfer phase to complete. + if after_handshake: + yield from self.transfer_data_task + + # A client should wait for a TCP Close from the server. + if self.is_client and after_handshake: + if (yield from self.wait_for_connection_lost()): + return + logger.debug( + "%s ! timed out waiting for TCP close", self.side) + + # Half-close the TCP connection if possible (when there's no TLS). if self.writer.can_write_eof(): + logger.debug( + "%s x half-closing TCP connection", self.side) self.writer.write_eof() - except Exception: # pragma: no cover - pass - self.writer.close() + if (yield from self.wait_for_connection_lost()): + return + logger.debug( + "%s ! timed out waiting for TCP close", self.side) - try: - yield from asyncio.wait_for( - self.connection_closed, self.timeout, loop=self.loop) - except (asyncio.CancelledError, asyncio.TimeoutError): - pass + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is cancelled (for example). + + # Closing a transport is idempotent. If the transport was already + # closed, for example from eof_received(), it's fine. + + # Close the TCP connection. + logger.debug( + "%s x closing TCP connection", self.side) + self.writer.close() + # There's little need to await self.wait_for_connection_lost() + # here. Closing the transport triggers self.connection_lost(). @asyncio.coroutine def fail_connection(self, code=1011, reason=''): @@ -682,15 +737,10 @@ def fail_connection(self, code=1011, reason=''): "%s ! failing WebSocket connection: %d %s", self.side, code, reason, ) - if self.state == OPEN: - if code == 1006: - # Don't send a close frame if the connection is broken. Set - # the state to CLOSING to allow close_connection to proceed. - self.state = CLOSING - else: - frame_data = serialize_close(code, reason) - yield from self.write_frame(OP_CLOSE, frame_data) - yield from self.close_connection() + # Don't send a close frame if the connection is broken. + if self.state == OPEN and code != 1006: + frame_data = serialize_close(code, reason) + yield from self.write_frame(OP_CLOSE, frame_data) # asyncio.StreamReaderProtocol methods @@ -748,11 +798,10 @@ def connection_lost(self, exc): self.state = CLOSED if self.close_code is None: self.close_code = 1006 - if not self.connection_closed.done(): - self.connection_closed.set_result(None) - # Close the transport in case close_connection() wasn't executed. - if self.writer is not None: - self.writer.close() + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be cancelled. + self.connection_lost_waiter.set_result(None) super().connection_lost(exc) diff --git a/websockets/server.py b/websockets/server.py index 515ac508c..a6368e728 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -117,7 +117,7 @@ def handler(self): ) yield from self.write_http_response(*early_response) - yield from self.close_connection(force=True) + yield from self.close_connection(after_handshake=False) return @@ -531,12 +531,12 @@ def close(self): self.server.close() # Close open connections. For each connection, two tasks are running: - # 1. self.worker_task shuffles messages between the network and queues + # 1. self.transfer_data_task receives incoming WebSocket messages # 2. self.handler_task runs the opening handshake, the handler provided # by the user and the closing handshake # In the general case, cancelling the handler task will cause the # handler provided by the user to exit with a CancelledError, which - # will then cause the worker task to terminate. + # will then cause the transfer data task to terminate. for websocket in self.websockets: websocket.handler_task.cancel() @@ -554,11 +554,13 @@ def wait_closed(self): """ # asyncio.wait doesn't accept an empty first argument. if self.websockets: - # The handler or the worker task can terminate first, depending - # on how the client behaves and the server is implemented. + # Either the handler or the connection can terminate first, + # depending on how the client behaves and the server is + # implemented. yield from asyncio.wait( [websocket.handler_task for websocket in self.websockets] + - [websocket.worker_task for websocket in self.websockets], + [websocket.close_connection_task + for websocket in self.websockets], loop=self.loop) yield from self.server.wait_closed() diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 64d168ef6..2395b6cc8 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -194,7 +194,7 @@ def start_client(self, path='', **kwds): def stop_client(self): try: self.loop.run_until_complete( - asyncio.wait_for(self.client.worker_task, timeout=1)) + asyncio.wait_for(self.client.close_connection_task, timeout=1)) except asyncio.TimeoutError: # pragma: no cover self.fail("Client failed to stop") @@ -223,9 +223,13 @@ def test_basic(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - @with_server() def test_server_close_while_client_connected(self): - self.start_client() + with self.temp_server(loop=self.loop): + self.start_client() + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.recv()) + # Connection ends with 1001 going away. + self.assertEqual(self.client.close_code, 1001) def test_explicit_event_loop(self): with self.temp_server(loop=self.loop): @@ -714,9 +718,9 @@ def test_server_close_crashes(self, close): def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. - self.loop.run_until_complete(self.client.close_connection(force=True)) + self.client.writer.close() # The server should stop properly anyway. It used to hang because the - # worker handling the connection was waiting for the opening handshake. + # task handling the connection was waiting for the opening handshake. @with_server() @unittest.mock.patch('websockets.server.read_request') diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index db9637256..c0476084a 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -41,20 +41,34 @@ class TransportMock(unittest.mock.Mock): """ # This should happen in __init__ but overriding Mock.__init__ is hard. - def connect(self, loop, protocol): + def setup_mock(self, loop, protocol): self.loop = loop self.protocol = protocol - # Remove when dropping support for Python < 3.6. + self._eof = False self._closing = False # Simulate a successful TCP handshake. - self.loop.call_soon(self.protocol.connection_made, self) + self.protocol.connection_made(self) # Simulate a successful WebSocket handshake. - self.loop.call_soon(self.protocol.connection_open) + self.protocol.connection_open() + + def can_write_eof(self): + return True + + def write_eof(self): + # When the protocol half-closes the TCP connection, it expects the + # other end to close it. Simulate that. + if not self._eof: + self.loop.call_soon(self.close) + self._eof = True + + def is_closing(self): + return self._closing def close(self): - # Remove when dropping support for Python < 3.6. + # Simulate how actual transports drop the connection. + if not self._closing: + self.loop.call_soon(self.protocol.connection_lost, None) self._closing = True - self.loop.call_soon(self.protocol.connection_lost, None) class CommonTests: @@ -70,11 +84,11 @@ def setUp(self): asyncio.set_event_loop(self.loop) self.protocol = WebSocketCommonProtocol() self.transport = TransportMock() - self.transport.connect(self.loop, self.protocol) + self.transport.setup_mock(self.loop, self.protocol) def tearDown(self): - self.loop.run_until_complete( - self.protocol.close_connection(force=True)) + self.transport.close() + self.loop.run_until_complete(self.protocol.close()) self.loop.close() super().tearDown() @@ -114,24 +128,19 @@ def receive_frame(self, frame): """ writer = self.protocol.data_received mask = not self.protocol.is_client - self.loop.call_soon(functools.partial(frame.write, writer, mask=mask)) + frame.write(writer, mask=mask) def receive_eof(self): """ - Make the protocol receive the end of stream. - - WebSocketCommonProtocol.eof_received returns None — it is inherited - from StreamReaderProtocol. (Returning True wouldn't work on secure - connections anyway.) As a consequence, actual transports close - themselves after calling it. + Make the protocol receive the end of the data stream. - To emulate this behavior, this function closes the transport just - after calling the protocol's eof_received. Closing the transport has - the side-effect calling the protocol's connection_lost. + Since ``WebSocketCommonProtocol.eof_received`` returns ``None``, an + actual transport would close itself after calling it. This function + emulates that behavior. """ - self.loop.call_soon(self.protocol.eof_received) - self.loop.call_soon(self.loop.call_soon, self.transport.close) + self.protocol.eof_received() + self.loop.call_soon(self.transport.close) def receive_eof_if_client(self): """ @@ -183,26 +192,12 @@ def process_invalid_frames(self): Make the protocol fail quickly after simulating invalid data. To achieve this, this function triggers the protocol's eof_received, - which interrupts pending reads waiting for more data. It delays this - operation with call_later because the protocol must start processing - frames first. Otherwise it will see a closed connection and no data. - - """ - self.loop.call_later(MS, self.receive_eof) - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) + which interrupts pending reads waiting for more data. - def process_control_frames(self): """ - Process control frames received by the protocol. - - To ensure that recv completes quickly, receive an additional dummy - frame, which recv() will drop. - - """ - self.receive_frame(Frame(True, OP_TEXT, b'')) - next_message = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(next_message, '') + self.run_loop_once() + self.receive_eof() + self.loop.run_until_complete(self.protocol.close_connection_task) def last_sent_frame(self): """ @@ -273,33 +268,41 @@ def assertCompletesWithin(self, min_time, max_time): def test_local_address(self): get_extra_info = unittest.mock.Mock(return_value=('host', 4312)) self.transport.get_extra_info = get_extra_info - # The connection isn't established yet. - self.assertEqual(self.protocol.local_address, None) - self.run_loop_once() - # The connection is established. + self.assertEqual(self.protocol.local_address, ('host', 4312)) get_extra_info.assert_called_with('sockname', None) + def test_local_address_before_connection(self): + # Emulate the situation before connection_open() runs. + self.protocol.writer, _writer = None, self.protocol.writer + + try: + self.assertEqual(self.protocol.local_address, None) + finally: + self.protocol.writer = _writer + def test_remote_address(self): get_extra_info = unittest.mock.Mock(return_value=('host', 4312)) self.transport.get_extra_info = get_extra_info - # The connection isn't established yet. - self.assertEqual(self.protocol.remote_address, None) - self.run_loop_once() - # The connection is established. + self.assertEqual(self.protocol.remote_address, ('host', 4312)) get_extra_info.assert_called_with('peername', None) + def test_remote_address_before_connection(self): + # Emulate the situation before connection_open() runs. + self.protocol.writer, _writer = None, self.protocol.writer + + try: + self.assertEqual(self.protocol.remote_address, None) + finally: + self.protocol.writer = _writer + def test_open(self): - self.assertFalse(self.protocol.open) - self.run_loop_once() self.assertTrue(self.protocol.open) self.close_connection() self.assertFalse(self.protocol.open) def test_state_name(self): - self.assertEqual(self.protocol.state_name, 'CONNECTING') - self.run_loop_once() self.assertEqual(self.protocol.state_name, 'OPEN') self.close_connection() self.assertEqual(self.protocol.state_name, 'CLOSED') @@ -368,8 +371,6 @@ def read_message(): raise Exception("BOOM") self.protocol.read_message = read_message self.process_invalid_frames() - with self.assertRaises(Exception): - self.loop.run_until_complete(self.protocol.worker_task) self.assertConnectionFailed(1011, '') def test_recv_cancelled(self): @@ -487,12 +488,12 @@ def test_pong_on_closed_connection(self): def test_answer_ping(self): self.receive_frame(Frame(True, OP_PING, b'test')) - self.process_control_frames() + self.run_loop_once() self.assertOneFrameSent(True, OP_PONG, b'test') def test_ignore_pong(self): self.receive_frame(Frame(True, OP_PONG, b'test')) - self.process_control_frames() + self.run_loop_once() self.assertNoFrameSent() def test_acknowledge_ping(self): @@ -501,7 +502,8 @@ def test_acknowledge_ping(self): ping_frame = self.last_sent_frame() pong_frame = Frame(True, OP_PONG, ping_frame.data) self.receive_frame(pong_frame) - self.process_control_frames() + self.run_loop_once() + self.run_loop_once() self.assertTrue(ping.done()) def test_acknowledge_previous_pings(self): @@ -511,13 +513,15 @@ def test_acknowledge_previous_pings(self): ) for i in range(3)] # Unsolicited pong doesn't acknowledge pings self.receive_frame(Frame(True, OP_PONG, b'')) - self.process_control_frames() + self.run_loop_once() + self.run_loop_once() self.assertFalse(pings[0][0].done()) self.assertFalse(pings[1][0].done()) self.assertFalse(pings[2][0].done()) # Pong acknowledges all previous pings self.receive_frame(Frame(True, OP_PONG, pings[1][1].data)) - self.process_control_frames() + self.run_loop_once() + self.run_loop_once() self.assertTrue(pings[0][0].done()) self.assertTrue(pings[1][0].done()) self.assertFalse(pings[2][0].done()) @@ -528,7 +532,8 @@ def test_cancel_ping(self): ping.cancel() pong_frame = Frame(True, OP_PONG, ping_frame.data) self.receive_frame(pong_frame) - self.process_control_frames() + self.run_loop_once() + self.run_loop_once() self.assertTrue(ping.cancelled()) def test_duplicate_ping(self): @@ -620,9 +625,6 @@ def test_connection_lost(self): self.assertConnectionFailed(1006, '') def test_ensure_connection_before_opening_handshake(self): - # Finalize the connection opening sequence. - self.run_loop_once() - # Simulate a bug by forcibly reverting the protocol state. self.protocol.state = CONNECTING @@ -647,16 +649,16 @@ def test_connection_closed_attributes(self): with self.assertRaises(ConnectionClosed) as context: self.loop.run_until_complete(self.protocol.recv()) - connection_closed = context.exception - self.assertEqual(connection_closed.code, 1000) - self.assertEqual(connection_closed.reason, 'close') + connection_closed_exc = context.exception + self.assertEqual(connection_closed_exc.code, 1000) + self.assertEqual(connection_closed_exc.reason, 'close') # Test the protocol logic for closing the connection. def test_local_close(self): # Emulate how the remote endpoint answers the closing handshake. - self.receive_frame(self.close_frame) - self.receive_eof_if_client() + self.loop.call_soon(self.receive_frame, self.close_frame) + self.loop.call_soon(self.receive_eof_if_client) # Run the closing handshake. self.loop.run_until_complete(self.protocol.close(reason='close')) @@ -672,8 +674,8 @@ def test_local_close(self): def test_remote_close(self): # Emulate how the remote endpoint initiates the closing handshake. - self.receive_frame(self.close_frame) - self.receive_eof_if_client() + self.loop.call_soon(self.receive_frame, self.close_frame) + self.loop.call_soon(self.receive_eof_if_client) # Wait for some data in order to process the handshake. # After recv() raises ConnectionClosed, the connection is closed. @@ -690,8 +692,10 @@ def test_remote_close(self): self.assertNoFrameSent() def test_simultaneous_close(self): - self.receive_frame(self.remote_close) - self.receive_eof_if_client() + # Delay the incoming close frame until after we send the outgoing one. + self.loop.call_soon(self.receive_frame, self.remote_close) + self.loop.call_soon(self.receive_eof_if_client) + self.loop.run_until_complete(self.protocol.close(reason='local')) # The close code and reason are taken from the remote side because @@ -701,8 +705,9 @@ def test_simultaneous_close(self): def test_close_preserves_incoming_frames(self): self.receive_frame(Frame(True, OP_TEXT, b'hello')) - self.receive_frame(self.close_frame) - self.receive_eof_if_client() + + self.loop.call_soon(self.receive_frame, self.close_frame) + self.loop.call_soon(self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1000, 'close') @@ -730,8 +735,8 @@ def test_close_connection_lost(self): def test_local_close_during_recv(self): recv = self.ensure_future(self.protocol.recv()) - self.receive_frame(self.close_frame) - self.receive_eof_if_client() + self.loop.call_soon(self.receive_frame, self.close_frame) + self.loop.call_soon(self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason='close')) @@ -758,7 +763,7 @@ def test_remote_close_during_send(self): # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. - def test_cancelled_close_waits_for_worker(self): + def test_cancelled_close_waits_for_transfer_data_task(self): # Regression test for #142. # Start the closing handshake. @@ -766,19 +771,20 @@ def test_cancelled_close_waits_for_worker(self): self.run_loop_once() self.assertOneFrameSent(*self.close_frame) - # Now close_task is waiting for worker_task which is waiting for the - # closing handshake to complete. + # Now close_task is waiting for transfer_data_task which is waiting + # for the closing handshake to complete. - # Cancelling close_task throws a CancelledError into worker_task, + # Cancelling close_task throws a CancelledError in transfer_data_task, # which catches that exception and waits for close_connection(). self.loop.call_later(MS, close_task.cancel) - # close_task resumes waiting for worker_task. Drop the connection so - # that close_connection(), worker_task and close_task terminate. + # close_task resumes waiting for transfer_data_task. Drop the + # connection so that close_connection(), transfer_data_task and + # close_connection_task terminate. self.loop.call_later(2 * MS, self.receive_eof) # Make sure the worker task terminated before close(). self.loop.run_until_complete(close_task) - self.assertTrue(self.protocol.worker_task.done()) + self.assertTrue(self.protocol.transfer_data_task.done()) class ServerTests(CommonTests, unittest.TestCase): @@ -788,17 +794,26 @@ def setUp(self): self.protocol.is_client = False self.protocol.side = 'server' - def test_close_handshake_timeout(self): - # Timeout is expected in 10ms. + def test_local_close_timeout(self): self.protocol.timeout = 10 * MS + # If the client doesn't send a close frame, time out in 10ms. # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): - # Unlike previous tests, no close frame will be received in - # response. The server will stop waiting for the close frame and - # timeout. self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1006, '') + def test_local_close_connection_lost_timeout(self): + self.protocol.timeout = 10 * MS + # If the client doesn't close its side of the TCP connection after we + # half-close our side with write_eof(), time out in 10ms. + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(9 * MS, 19 * MS): + # HACK: disable write_eof => other end drops connection emulation. + self.transport._eof = True + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1000, 'close') + class ClientTests(CommonTests, unittest.TestCase): @@ -807,27 +822,26 @@ def setUp(self): self.protocol.is_client = True self.protocol.side = 'client' - def test_close_handshake_timeout(self): - # Timeout is expected in 2 * 10 = 20ms. + def test_local_close_timeout(self): self.protocol.timeout = 10 * MS + # If the server doesn't send a close frame, time out in 30ms: + # - 10ms waiting for a close frame + # - 10ms waiting for a half-close # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): - # Unlike previous tests, no close frame will be received in - # response and the connection will not be closed. The client will - # stop waiting for the close frame and timeout, then stop waiting - # for the connection close and timeout again. self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1006, '') - def test_eof_received_timeout(self): - # Timeout is expected in 10ms. + def test_local_close_connection_lost_timeout(self): self.protocol.timeout = 10 * MS + # If the server doesn't half-close its side of the TCP connection + # after we send a close frame, time out in 20ms: + # - 10ms waiting for a half-close + # - 10ms waiting for a close # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(9 * MS, 19 * MS): - # Unlike previous tests, the close frame will be received in - # response but the connection will not be closed. The client will - # stop waiting for the connection close and timeout. + with self.assertCompletesWithin(19 * MS, 29 * MS): + # HACK: disable write_eof => other end drops connection emulation. + self.transport._eof = True self.receive_frame(self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'close') From fcc73bcdc0cb69b67e0acea3c74a871214e1aba4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Sep 2017 23:03:13 +0200 Subject: [PATCH 0328/1539] Stop swallowing cancellations in close(). Fix #264. Ref #142. Remove a test that no longer makes sense now that close() no longer swallows CancelledError. --- websockets/protocol.py | 17 +++++++++++------ websockets/test_protocol.py | 23 ----------------------- 2 files changed, 11 insertions(+), 29 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index cd8f3c9a5..b93b0dfc1 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -284,15 +284,20 @@ def close(self, code=1000, reason=''): yield from self.write_frame(OP_CLOSE, frame_data) # If no close frame is received within the timeout, cancel the data - # transfer task in order to exit the infinite loop. + # transfer task in order to exit the infinite loop. transfer_data() + # will catch CancelledError and exit without an exception. However + # wait_for() will raise CancelledError anyway. As a consequence, if + # close() is called several times concurrently and one of these calls + # is cancelled, other calls will see that the data transfer task has + # completed. This is why there's no need to catch CancelledError here. try: + # If close() is cancelled during the wait, self.transfer_data_task + # is cancelled before the timeout elapses (on Python ≥ 3.4.3). + # This helps closing connections when shutting down a server. yield from asyncio.wait_for( - asyncio.shield(self.transfer_data_task), - self.timeout, loop=self.loop) - except (asyncio.CancelledError, asyncio.TimeoutError): + self.transfer_data_task, self.timeout, loop=self.loop) + except asyncio.TimeoutError: pass - if self.close_code is None: - self.transfer_data_task.cancel() # Wait for the close connection task to close the TCP connection. yield from asyncio.shield(self.close_connection_task) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index c0476084a..8eebf659c 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -763,29 +763,6 @@ def test_remote_close_during_send(self): # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. - def test_cancelled_close_waits_for_transfer_data_task(self): - # Regression test for #142. - - # Start the closing handshake. - close_task = self.ensure_future(self.protocol.close(reason='close')) - self.run_loop_once() - self.assertOneFrameSent(*self.close_frame) - - # Now close_task is waiting for transfer_data_task which is waiting - # for the closing handshake to complete. - - # Cancelling close_task throws a CancelledError in transfer_data_task, - # which catches that exception and waits for close_connection(). - self.loop.call_later(MS, close_task.cancel) - # close_task resumes waiting for transfer_data_task. Drop the - # connection so that close_connection(), transfer_data_task and - # close_connection_task terminate. - self.loop.call_later(2 * MS, self.receive_eof) - - # Make sure the worker task terminated before close(). - self.loop.run_until_complete(close_task) - self.assertTrue(self.protocol.transfer_data_task.done()) - class ServerTests(CommonTests, unittest.TestCase): From 115b466964404dfc83558ec54db1a05f66ce4d8e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 4 Sep 2017 22:27:49 +0200 Subject: [PATCH 0329/1539] Add some security considerations to the docs. --- docs/index.rst | 1 + docs/security.rst | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 docs/security.rst diff --git a/docs/index.rst b/docs/index.rst index 30f4878b6..adfb3d168 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -51,6 +51,7 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a cheatsheet api deployment + security limitations changelog license diff --git a/docs/security.rst b/docs/security.rst new file mode 100644 index 000000000..aae067f34 --- /dev/null +++ b/docs/security.rst @@ -0,0 +1,44 @@ +Security +======== + +Memory use +---------- +.. warning:: + + An attacker who can open an arbitrary number of connections will be able + to perform a denial of service by memory exhaustion. If you're concerned + by denial of service attacks, you must reject suspicious connections + before they reach ``websockets``, typically in a reverse proxy. + +The baseline memory use for a connection is about 20kB. + +The incoming bytes buffer, incoming messages queue and outgoing bytes buffer +contribute to the memory use of a connection. By default, each bytes buffer +takes up to 64kB and the messages queue up to 128MB, which is very large. + +Most applications use small messages. Setting ``max_size`` according to the +application's requirements is strongly recommended. See :ref:`buffers` for +details about tuning buffers. + +When compression is enabled, additional memory may be allocated for carrying +the compression context across messages, depending on the context takeover and +window size parameters. With the default configuration, this adds 320kB to the +memory use for a connection. + +You can reduce this amount by configuring the ``PerMessageDeflate`` extension +with lower ``server_max_window_bits`` and ``client_max_window_bits`` values. +These parameters default is 15. Lowering them to 11 is a good choice. + +Finally, memory consumed by your application code also counts towards the +memory use of a connection. + +Other limits +------------ + +``websockets`` implements additional limits on the amount of data it accepts +in order to mimimize exposure to security vulnerabilities. + +In the opening handshake, ``websockets`` limits the number of HTTP headers to +256 and the size of an individual header to 4096 bytes. These limits are 10 to +20 times larger than what's expected in standard use cases. They're hardcoded. +If you need to change them, monkey-patch the constants in ``websockets.http``. From fe6761ef0daab82fc370f1dc1771faea87a59e9b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 24 Sep 2017 12:16:00 +0200 Subject: [PATCH 0330/1539] Add docstrings for all classes and methods. --- websockets/client.py | 8 ++- websockets/compatibility.py | 6 ++ websockets/extensions/permessage_deflate.py | 13 +++- websockets/headers.py | 80 ++++++++++++++++++--- websockets/protocol.py | 52 ++++++++++++-- websockets/py36/protocol.py | 8 +++ websockets/server.py | 33 ++++++--- 7 files changed, 172 insertions(+), 28 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 70e74ecd2..096c7cf94 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -194,9 +194,8 @@ def process_subprotocol(headers, available_subprotocols): return subprotocol @asyncio.coroutine - def handshake(self, wsuri, origin=None, - available_extensions=None, available_subprotocols=None, - extra_headers=None): + def handshake(self, wsuri, origin=None, available_extensions=None, + available_subprotocols=None, extra_headers=None): """ Perform the client side of the opening handshake. @@ -211,6 +210,9 @@ def handshake(self, wsuri, origin=None, If provided, ``extra_headers`` sets additional HTTP request headers. It must be a mapping or an iterable of (name, value) pairs. + Raise :exc:`~websockets.exceptions.InvalidHandshake` if the handshake + fails. + """ request_headers = [] set_header = lambda k, v: request_headers.append((k, v)) diff --git a/websockets/compatibility.py b/websockets/compatibility.py index c8c301421..c7e08126f 100644 --- a/websockets/compatibility.py +++ b/websockets/compatibility.py @@ -1,3 +1,9 @@ +""" +The :mod:`websockets.compatibility` module provides helpers for bridging +compatibility issues across Python versions. + +""" + import asyncio import http diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index 7d46a7887..f7fb4f021 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -32,6 +32,10 @@ def _build_parameters( server_max_window_bits, client_max_window_bits, ): + """ + Build a list of ``(name, value)`` pairs for some compression parameters. + + """ params = [] if server_no_context_takeover: params.append(('server_no_context_takeover', None)) @@ -47,6 +51,13 @@ def _build_parameters( def _extract_parameters(params, *, is_server): + """ + Extract compression parameters from a list of ``(name, value)`` pairs. + + If ``is_server`` is ``True``, ``client_max_window_bits`` may be provided + without a value. This is only allow in handshake requests. + + """ server_no_context_takeover = False client_no_context_takeover = False server_max_window_bits = None @@ -81,7 +92,7 @@ def _extract_parameters(params, *, is_server): elif name == 'client_max_window_bits': if client_max_window_bits is not None: raise DuplicateParameter(name) - if is_server and value is None: # only in handshake responses + if is_server and value is None: # only in handshake requests client_max_window_bits = True elif value in _MAX_WINDOW_BITS_VALUES: client_max_window_bits = int(value) diff --git a/websockets/headers.py b/websockets/headers.py index b1459b2b8..276ec850a 100644 --- a/websockets/headers.py +++ b/websockets/headers.py @@ -23,7 +23,14 @@ # definitions from https://tools.ietf.org/html/rfc7230#appendix-B. def peek_ahead(string, pos): - # We never peek more than one character ahead. + """ + Return the next character from ``string`` at the given position. + + Return ``None`` at the end of ``string``. + + We never need to peek more than one character ahead. + + """ return None if pos == len(string) else string[pos] @@ -31,6 +38,14 @@ def peek_ahead(string, pos): def parse_OWS(string, pos): + """ + Parse optional whitespace from ``string`` at the given position. + + Return the new position. + + The whitespace itself isn't returned because it isn't significant. + + """ # There's always a match, possibly empty, whose content doesn't matter. match = _OWS_re.match(string, pos) return match.end() @@ -40,6 +55,14 @@ def parse_OWS(string, pos): def parse_token(string, pos): + """ + Parse a token from ``string`` at the given position. + + Return the token value and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + + """ match = _token_re.match(string, pos) if match is None: raise InvalidHeader("expected token", string=string, pos=pos) @@ -54,6 +77,14 @@ def parse_token(string, pos): def parse_quoted_string(string, pos): + """ + Parse a quoted string from ``string`` at the given position. + + Return the unquoted value and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + + """ match = _quoted_string_re.match(string, pos) if match is None: raise InvalidHeader("expected quoted string", string=string, pos=pos) @@ -61,6 +92,14 @@ def parse_quoted_string(string, pos): def parse_extension_param(string, pos): + """ + Parse a single extension parameter from ``string`` at the given position. + + Return a ``(name, value)`` pair and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + + """ # Extract parameter name. name, pos = parse_token(string, pos) pos = parse_OWS(string, pos) @@ -85,6 +124,15 @@ def parse_extension_param(string, pos): def parse_extension(string, pos): + """ + Parse an extension definition from ``string`` at the given position. + + Return an ``(extension name, parameters)`` pair, where ``parameters`` is a + list of ``(name, value)`` pairs, and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + + """ # Extract extension name. name, pos = parse_token(string, pos) pos = parse_OWS(string, pos) @@ -99,7 +147,7 @@ def parse_extension(string, pos): def parse_extension_list(string, pos=0): """ - Parse a Sec-WebSocket-Extensions header. + Parse a ``Sec-WebSocket-Extensions`` header. The string is assumed not to start or end with whitespace. @@ -118,7 +166,7 @@ def parse_extension_list(string, pos=0): Parameter values are ``None`` when no value is provided. - Raise InvalidHeader if the header cannot be parsed. + Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. """ # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST @@ -161,6 +209,12 @@ def parse_extension_list(string, pos=0): def build_extension(name, parameters): + """ + Build an extension definition. + + This is the reverse of :func:`parse_extension`. + + """ return '; '.join([name] + [ # Quoted strings aren't necessary because values are always tokens. name if value is None else '{}={}'.format(name, value) @@ -170,9 +224,9 @@ def build_extension(name, parameters): def build_extension_list(extensions): """ - Unparse a Sec-WebSocket-Extensions header. + Unparse a ``Sec-WebSocket-Extensions`` header. - This is the reverse of parse_extension_list. + This is the reverse of :func:`parse_extension_list`. """ return ', '.join( @@ -182,6 +236,14 @@ def build_extension_list(extensions): def parse_protocol(string, pos): + """ + Parse a protocol definition from ``string`` at the given position. + + Return the protocol and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + + """ name, pos = parse_token(string, pos) pos = parse_OWS(string, pos) return name, pos @@ -189,13 +251,13 @@ def parse_protocol(string, pos): def parse_protocol_list(string, pos=0): """ - Parse a Sec-WebSocket-Protocol header. + Parse a ``Sec-WebSocket-Protocol`` header. The string is assumed not to start or end with whitespace. Return a list of protocols. - Raise InvalidHeader if the header cannot be parsed. + Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. """ # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST @@ -239,9 +301,9 @@ def parse_protocol_list(string, pos=0): def build_protocol_list(protocols): """ - Unparse a Sec-WebSocket-Protocol header. + Unparse a ``Sec-WebSocket-Protocol`` header. - This is the reverse of parse_protocol_list. + This is the reverse of :func:`parse_protocol_list`. """ return ', '.join(protocols) diff --git a/websockets/protocol.py b/websockets/protocol.py index b93b0dfc1..242a168db 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -98,7 +98,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``write_limit`` argument sets the high-water limit of the buffer for outgoing bytes. The low-water limit is a quarter of the high-water limit. The default value is 64kB, equal to asyncio's default (based on the - current implementation of ``_FlowControlMixin``). + current implementation of ``FlowControlMixin``). As soon as the HTTP request and response in the opening handshake are processed, the request path is available in the :attr:`path` attribute, @@ -139,7 +139,7 @@ def __init__(self, *, self.write_limit = write_limit # Store a reference to loop to avoid relying on self._loop, a private - # attribute of StreamReaderProtocol, inherited from _FlowControlMixin. + # attribute of StreamReaderProtocol, inherited from FlowControlMixin. if loop is None: loop = asyncio.get_event_loop() self.loop = loop @@ -256,7 +256,7 @@ def state_name(self): Current connection state, as a string. Possible states are defined in the WebSocket specification: - CONNECTING, OPEN, CLOSING, or CLOSED. + ``CONNECTING``, ``OPEN``, ``CLOSING``, or ``CLOSED``. To check if the connection is open, use :attr:`open` instead. @@ -435,6 +435,12 @@ def pong(self, data=b''): # Private methods - no guarantees. def encode_data(self, data): + """ + Expect :class:`str` or :class:`bytes`. Return :class:`bytes`. + + :class:`str` are encoded with UTF-8. + + """ # Expect str or bytes, return bytes. if isinstance(data, str): return data.encode('utf-8') @@ -444,6 +450,10 @@ def encode_data(self, data): raise TypeError("data must be bytes or str") def connection_open(self): + """ + Callback when the opening handshake completes. + + """ assert self.state == CONNECTING self.state = OPEN # Start the task that receives incoming WebSocket messages. @@ -455,9 +465,13 @@ def connection_open(self): @asyncio.coroutine def ensure_open(self): - # Raise a suitable exception if the connection isn't open. - # Handle cases from the most common to the least common. + """ + Check that the WebSocket connection is open. + + Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't. + """ + # Handle cases from the most common to the least common. if self.state == OPEN: return @@ -479,6 +493,12 @@ def ensure_open(self): @asyncio.coroutine def transfer_data(self): + """ + Read incoming messages and put them in a queue. + + This coroutine runs in a task until the closing handshake is started. + + """ try: while True: msg = yield from self.read_message() @@ -503,7 +523,14 @@ def transfer_data(self): @asyncio.coroutine def read_message(self): - # Reassemble fragmented messages. + """ + Read a single message from the connection. + + Re-assemble data frames if the message is fragmented. + + Return ``None`` when the closing handshake is started. + + """ frame = yield from self.read_data_frame(max_size=self.max_size) # A close frame was received. @@ -559,7 +586,14 @@ def append(frame): @asyncio.coroutine def read_data_frame(self, max_size): - # Deal with control frames automatically and return next data frame. + """ + Read a single data frame from the connection. + + Process control frames received before the next data frame. + + Return ``None`` if a close frame is encountered before any data frame. + + """ # 6.2. Receiving Data while True: frame = yield from self.read_frame(max_size) @@ -595,6 +629,10 @@ def read_data_frame(self, max_size): @asyncio.coroutine def read_frame(self, max_size): + """ + Read a single frame from the connection. + + """ frame = yield from Frame.read( self.reader.readexactly, mask=not self.is_client, diff --git a/websockets/py36/protocol.py b/websockets/py36/protocol.py index 37b7b3477..919f9a038 100644 --- a/websockets/py36/protocol.py +++ b/websockets/py36/protocol.py @@ -2,6 +2,14 @@ async def __aiter__(self): + """ + Iterate on received messages. + + Exit normally when the connection is closed with code 1000. + + Raise an exception in other cases. + + """ try: while True: yield await self.recv() diff --git a/websockets/server.py b/websockets/server.py index a6368e728..d7c115829 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -55,19 +55,28 @@ def __init__(self, ws_handler, ws_server, *, super().__init__(**kwds) def connection_made(self, transport): + """ + Register connection and initialize a task to handle it. + + """ super().connection_made(transport) - # Register the connection with the server when creating the handler - # task. (Registering at the beginning of the handler coroutine would + # Register the connection with the server before creating the handler + # task. Registering at the beginning of the handler coroutine would # create a race condition between the creation of the task, which - # schedules its execution, and the moment the handler starts running.) + # schedules its execution, and the moment the handler starts running. self.ws_server.register(self) self.handler_task = asyncio_ensure_future( self.handler(), loop=self.loop) @asyncio.coroutine def handler(self): - # Since this method doesn't have a caller able to handle exceptions, - # it attemps to log relevant ones and close the connection properly. + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller able to handle exceptions, it + attemps to log relevant ones and close the connection properly. + + """ try: try: @@ -406,10 +415,10 @@ def handshake(self, origins=None, available_extensions=None, It can be a mapping or an iterable of (name, value) pairs. It can also be a callable taking the request path and headers in arguments. - Raise :exc:`~websockets.exceptions.InvalidHandshake` or a subclass if - the handshake fails. + Raise :exc:`~websockets.exceptions.InvalidHandshake` if the handshake + fails. - Return the URI of the request. + Return the path of the URI of the request. """ path, request_headers = yield from self.read_http_request() @@ -509,9 +518,17 @@ def wrap(self, server): self.server = server def register(self, protocol): + """ + Register a connection with this server. + + """ self.websockets.add(protocol) def unregister(self, protocol): + """ + Unregister a connection with this server. + + """ self.websockets.remove(protocol) def close(self): From ec450f53e7f0bb6fda9793430186507662025788 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 4 Sep 2017 23:26:50 +0200 Subject: [PATCH 0331/1539] Document the design. --- docs/deployment.rst | 98 ++-------- docs/design.rst | 411 +++++++++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + docs/lifecycle.graffle | Bin 0 -> 3073 bytes docs/lifecycle.svg | 3 + docs/limitations.rst | 3 + docs/protocol.graffle | Bin 0 -> 4656 bytes docs/protocol.svg | 3 + 8 files changed, 440 insertions(+), 79 deletions(-) create mode 100644 docs/design.rst create mode 100644 docs/lifecycle.graffle create mode 100644 docs/lifecycle.svg create mode 100644 docs/protocol.graffle create mode 100644 docs/protocol.svg diff --git a/docs/deployment.rst b/docs/deployment.rst index 9ce0745f6..8cdc09152 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -1,81 +1,11 @@ Deployment ========== -Backpressure ------------- - -.. note:: - - This section discusses the concept of backpressure from the perspective of - a server but the concepts also apply to clients. The issue is symmetrical. - -With a naive implementation, if a server receives inputs faster than it can -process them, or if it generates outputs faster than it can send them, data -accumulates in buffers, eventually causing the server to run out of memory and -crash. - -The solution to this problem is backpressure. Any part of the server that -receives inputs faster than it can it can process them and send the outputs -must propagate that information back to the previous part in the chain. - -``websockets`` is designed to make it easy to get backpressure right. - -For incoming data, ``websockets`` builds upon :class:`~asyncio.StreamReader` -which propagates backpressure to its own buffer and to the TCP stream. Frames -are parsed from the input stream and added to a bounded queue. If the queue -fills up, parsing halts until some the application reads a frame. - -For outgoing data, ``websockets`` builds upon :class:`~asyncio.StreamWriter` -which implements flow control. If the output buffers grow too large, it waits -until they're drained. That's why all APIs that write frames are asynchronous -in websockets (since version 2.0). - -Of course, it's still possible for an application to create its own unbounded -buffers and break the backpressure. Be careful with queues. - -Buffers -------- - -An asynchronous systems works best when its buffers are almost always empty. - -For example, if a client sends frames too fast for a server, the queue of -incoming frames will be constantly full. The server will always be 32 frames -(by default) behind the client. This consumes memory and adds latency for no -good reason. - -If buffers are almost always full and that problem cannot be solved by adding -capacity (typically because the system is bottlenecked by the output and -constantly regulated by backpressure), reducing the size of buffers minimizes -negative consequences. - -By default ``websockets`` has rather high limits. You can decrease them -according to your application's characteristics. - -Bufferbloat can happen at every level in the stack where there is a buffer. -The receiving side contains these buffers: - -- OS buffers: you shouldn't need to tune them in general. -- :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64kB. - You can set another limit by passing a ``read_limit`` keyword argument to - :func:`~websockets.client.connect` or :func:`~websockets.server.serve`. -- ``websockets`` frame buffer: its size depends both on the size and the - number of frames it contains. By default the maximum size is 1MB and the - maximum number is 32. You can adjust these limits by setting the - ``max_size`` and ``max_queue`` keyword arguments of - :func:`~websockets.client.connect` or :func:`~websockets.server.serve`. - -The sending side contains these buffers: - -- :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64kB. - You can set another limit by passing a ``write_limit`` keyword argument to - :func:`~websockets.client.connect` or :func:`~websockets.server.serve`. -- OS buffers: you shouldn't need to tune them in general. - -Deployment ----------- +Application server +------------------ The author of ``websockets`` isn't aware of best practices for deploying -network services based on :mod:`asyncio`. +network services based on :mod:`asyncio`, let alone application servers. You can run a script similar to the :ref:`server example `, inside a supervisor if you deem that useful. @@ -98,9 +28,9 @@ with the object returned by :func:`~websockets.server.serve`: - calling its ``close()`` method, then waiting for its ``wait_closed()`` method to complete. -Tasks that handle connections will be cancelled, in the sense that -:meth:`~websockets.protocol.WebSocketCommonProtocol.recv` raises -:exc:`~asyncio.CancelledError`. +Tasks that handle connections will be cancelled. For example, if the handler +is awaiting :meth:`~websockets.protocol.WebSocketCommonProtocol.recv`, that +call will raise :exc:`~asyncio.CancelledError`. On Unix systems, shutdown is usually triggered by sending a signal. @@ -119,6 +49,16 @@ projects try to help with this problem. If your server doesn't run in the main thread, look at :func:`~asyncio.AbstractEventLoop.call_soon_threadsafe`. +Memory use +---------- + +In order to avoid excessive memory use caused by buffer bloat, it is strongly +recommended to :ref:`tune buffer sizes `. + +Most importantly ``max_size`` should be lowered according to the expected size +of messages. It is also suggested to lower ``max_queue``, ``read_limit`` and +``write_limit`` if memory use is a concern. + Port sharing ------------ @@ -128,6 +68,6 @@ serve both HTTP and WebSocket on the same port. The author of ``websockets`` doesn't think that's a good idea, due to the widely different operational characteristics of HTTP and WebSocket. -If you need to respond to requests with a protocol other than WebSocket, for -example TCP or HTTP health checks, run a server for that protocol on another -port, within the same Python process, with :func:`~asyncio.start_server`. +``websockets`` provide minimal support for responding to HTTP requests with +the :meth:`~server.WebSocketServerProtocol.process_request()` hook. Typical +use cases include health checks. diff --git a/docs/design.rst b/docs/design.rst new file mode 100644 index 000000000..b114c66db --- /dev/null +++ b/docs/design.rst @@ -0,0 +1,411 @@ +Design +====== + +.. currentmodule:: websockets + +This document describes the design of ``websockets``. It assumes familiarity +with the specification of the WebSocket protocol in :rfc:`6455`. + +It's primarily intended at maintainers. It may also be useful for users who +wish to understand what happens under the hood. + +.. warning: + + Internals described in this document may change at any time. + + Backwards compatibility is only guaranteed for `public APIs `_. + + +Lifecycle +--------- + +State +..... + +WebSocket connections go through a trivial state machine: + +- ``CONNECTING``: initial state, +- ``OPEN``: when the opening handshake is complete, +- ``CLOSING``: when the closing handshake is started, +- ``CLOSED``: when the TCP connection is closed. + +Transitions happen in the following places: + +- ``CONNECTING -> OPEN``: in + :meth:`~protocol.WebSocketCommonProtocol.connection_open()`, which runs + when the :ref:`opening handshake ` completes and the + WebSocket connection is established — not to be confused with + :meth:`~asyncio.Protocol.connection_made` which runs earlier, when the TCP + connection is established; +- ``OPEN -> CLOSING``: in + :meth:`~protocol.WebSocketCommonProtocol.write_frame()` immediately before + sending a close frame; since receiving a close frame triggers sending a + close frame, this does the right thing regardless of which side started the + :ref:`closing handshake `; +- ``* -> CLOSED``: in + :meth:`~protocol.WebSocketCommonProtocol.connection_lost()` which is always + called exactly once when the TCP connection is closed. + +Coroutines +.......... + +The following diagram shows which coroutines are running at each stage of the +connection lifecycle on the client side. + +.. image:: lifecycle.svg + :target: _images/lifecycle.svg + +The lifecycle is identical on the server side, except inversion of control +makes the equivalent of :meth:`~client.connect()` implicit. + +Coroutines shown in green are called by the application. Multiple coroutines +may interact with the WebSocket connection concurrently. + +Coroutines shown in gray manage the connection. When the opening handshake +succeeds, :meth:`~protocol.WebSocketCommonProtocol.connection_open()` starts +two tasks: + +- :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs + :meth:`~protocol.WebSocketCommonProtocol.transfer_data()` which handles + incoming data and lets :meth:`~protocol.WebSocketCommonProtocol.recv()` + consume it. It never exits with an exception but it may be cancelled. + See :ref:`data transfer ` below. +- :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs + :meth:`~protocol.WebSocketCommonProtocol.close_connection()` which waits for + the data transfer to terminate, then takes care of closing the TCP + connection. It never exits with an exception and it is never cancelled. See + :ref:`connection termination ` below. + + +.. _opening-handshake: + +Opening handshake +----------------- + +``websockets`` performs the opening handshake when establishing a WebSocket +connection. On the client side, :meth:`~client.connect()` executes it before +returning the protocol to the caller. On the server side, it's executed before +passing the protocol to the ``ws_handler`` coroutine handling the connection. + +While the opening handshake is asymmetrical — the client sends an HTTP Upgrade +request and the server replies with an HTTP Switching Protocols response — +``websockets`` aims at keepping the implementation of both sides consistent +with one another. + +On the client side, :meth:`~client.WebSocketClientProtocol.handshake()`: + +- builds a HTTP request based on the ``uri`` and parameters passed to + :meth:`~client.connect()`; +- writes the HTTP request to the network; +- reads a HTTP response from the network; +- checks the HTTP response, validates ``extensions`` and ``subprotocol``, and + configures the protocol accordingly; +- moves to the ``OPEN`` state. + +On the server side, :meth:`~server.WebSocketServerProtocol.handshake()`: + +- reads a HTTP request from the network; +- calls :meth:`~server.WebSocketServerProtocol.process_request()` which may + abort the WebSocket handshake and return a HTTP response instead; this + hook only makes sense on the server side; +- checks the HTTP request, negociates ``extensions`` and ``subprotocol``, and + configures the protocol accordingly; +- builds a HTTP response based on the above and parameters passed to + :meth:`~server.serve()`; +- writes the HTTP response to the network; +- moves to the ``OPEN`` state; +- returns the ``path`` part of the ``uri``. + +The most significant assymetry between the two sides of the opening handshake +lies in the negociation of extensions and, to a lesser extent, of the +subprotocol. The server knows everything about both sides and decides what the +parameters should be for the connection. The client merely applies them. + +If anything goes wrong during the opening handshake, ``websockets`` closes the +TCP connection. This is the proper way to fail the WebSocket connection before +it's established. + + +.. _data-transfer: + +Data transfer +------------- + +Symmetry +........ + +Once the opening handshake has completed, the WebSocket protocol enters the +data transfer phase. This part is almost symmetrical. There are only two +differences between a server and a client: + +- `client-to-server masking`_: the client masks outgoing frames; the server + unmasks incoming frames; +- `closing the TCP connection`_: the server closes the connection immediately; + the client waits for the server to do it. + +.. _client-to-server masking: https://tools.ietf.org/html/rfc6455#section-5.3 +.. _closing the TCP connection: https://tools.ietf.org/html/rfc6455#section-5.5.1 + +These differences are so minor that all the logic for `data framing`_, for +`sending and receiving data`_ and for `closing the connection`_ is implemented +in the same class, :class:`~protocol.WebSocketCommonProtocol`. + +.. _data framing: https://tools.ietf.org/html/rfc6455#section-5 +.. _sending and receiving data: https://tools.ietf.org/html/rfc6455#section-6 +.. _closing the connection: https://tools.ietf.org/html/rfc6455#section-7 + +The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which +side a protocol instance is managing. This attribute is defined on the +:attr:`~server.WebSocketServerProtocol` and +:attr:`~client.WebSocketClientProtocol` classes. + +Data flow +......... + +The following diagram shows how data flows between an application built on top +of ``websockets`` and a remote endpoint. It applies regardless of which side +is the server or the client. + +.. image:: protocol.svg + :target: _images/protocol.svg + +Public methods are shown in green, private methods in yellow, and buffers in +orange. Methods related to connection termination are omitted; connection +termination is discussed in another section below. + +Receiving data +.............. + +The left side of the diagram shows how ``websockets`` receives data. + +Incoming data is written to a :class:`~asyncio.StreamReader` in order to +implement flow control and provide backpressure on the TCP connection. + +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which is started +when the WebSocket connection is established, processes this data. + +When it receives data frames, it reassembles fragments and puts the resulting +messages in the :attr:`~protocol.WebSocketCommonProtocol.messages` queue. + +When it encounters a control frame: + +- if it's a close frame, it starts the closing handshake; +- if it's a ping frame, it anwsers with a pong frame; +- if it's a pong frame, it acknowledges the corresponding ping (unless it's an + unsolicited pong). + +Running this process in a task guarantees that control frames are processed +promptly. Without such a task, ``websockets`` would depend on the application +to drive the connection by having exactly one coroutine awaiting +:meth:`~protocol.WebSocketCommonProtocol.recv()` at any time. While this +happens naturally in many use cases, it cannot be relied upon. + +Then :meth:`~protocol.WebSocketCommonProtocol.recv()` fetches the next message +from the :attr:`~protocol.WebSocketCommonProtocol.messages` queue, with some +complexity added for handling termination correctly. + +Sending data +............ + +The right side of the diagram shows how ``websockets`` sends data. + +:meth:`~protocol.WebSocketCommonProtocol.send()` writes a single data frame +containing the message. Fragmentation isn't supported at this time. + +:meth:`~protocol.WebSocketCommonProtocol.ping()` writes a ping frame and +returns a :class:`~asyncio.Future` which will be completed when a matching +pong frame is received. + +:meth:`~protocol.WebSocketCommonProtocol.pong()` writes a pong frame. + +:meth:`~protocol.WebSocketCommonProtocol.close()` writes a close frame and +waits for the TCP connection to terminate. + +Outgoing data is written to a :class:`~asyncio.StreamWriter` in order to +implement flow control and provide backpressure from the TCP connection. + +.. _closing-handshake: + +Closing handshake +................. + +When the other side of the connection initiates the closing handshake, +:meth:`~protocol.WebSocketCommonProtocol.read_message()` receives a close +frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a +close frame, and returns ``None``, causing +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. + +When this side of the connection initiates the closing handshake with +:meth:`~protocol.WebSocketCommonProtocol.close()`, it moves to the ``CLOSING`` +state and sends a close frame. When the other side sends a close frame, +:meth:`~protocol.WebSocketCommonProtocol.read_message()` receives it in the +``CLOSING`` state and returns ``None``, also causing +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. + +If the other side doesn't send a close frame within the connection's timeout, +:meth:`~protocol.WebSocketCommonProtocol.close()` cancels +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which has the +same effect. + +Then ``websockets`` terminates the TCP connection. + + +.. _connection-termination: + +Connection termination +---------------------- + +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which is +started when the WebSocket connection is established, is responsible for +eventually closing the TCP connection. + +First :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` waits +for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, +which may happen as a result of: + +- a successful closing handshake: as explained above, this exits the infinite + loop in :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; +- a timeout while waiting for the closing handshake to complete: this cancels + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; +- a protocol error, including connection errors: depending on the exception, + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` fails the + WebSocket connection with a suitable code and exits. + +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` is separate +from :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to make it +easier to implement the timeout on the closing handshake. Cancelling +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk +of cancelling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` +and failing to close the TCP connection, thus leaking resources. + + +.. _cancellation: + +Cancellation +------------ + +Most :doc:`public APIs ` of ``websockets`` are coroutines. They may be +cancelled. ``websockets`` must handle this situation. + +Cancellation during the opening handshake is handled like any other exception: +the TCP connection is closed and the exception is re-raised or logged. + +Once the WebSocket connection is established, +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get +accidentally cancelled if a coroutine that awaits them is cancelled. They must +be shielded from cancellation. + +:meth:`~protocol.WebSocketCommonProtocol.recv()` waits for the next message in +the queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` +to terminate, whichever comes first. It relies on :func:`~asyncio.wait()` for +waiting on two tasks in parallel. As a consequence, even though it's waiting +on the transfer data task, it doesn't propagate cancellation to that task. + +:meth:`~protocol.WebSocketCommonProtocol.ensure_open()` is called by +:meth:`~protocol.WebSocketCommonProtocol.send()`, +:meth:`~protocol.WebSocketCommonProtocol.ping()`, and +:meth:`~protocol.WebSocketCommonProtocol.pong()`. When the connection state is +``CLOSING``, it waits for +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to +prevent cancellation. + +:meth:`~protocol.WebSocketCommonProtocol.close()` waits for the data transfer +task to terminate with :func:`~asyncio.wait_for`. If it's cancelled or if the +timout elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` +is cancelled. :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is +expected to catch the cancellation and terminate properly. This is the only +point where it may be cancelled. + +:meth:`~protocol.WebSocketCommonProtocol.close()` then waits for +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it +to prevent cancellation. + +:attr:`~protocol.WebSocketCommonProtocol.close_connnection_task` starts by +waiting for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`. +Since :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` handles +:exc:`~asyncio.CancelledError`, cancellation doesn't propagate to +:attr:`~protocol.WebSocketCommonProtocol.close_connnection_task`. + + +.. _backpressure: + +Backpressure +------------ + +.. note:: + + This section discusses backpressure from the perspective of a server but + the concept applies to clients symmetrically. + +With a naive implementation, if a server receives inputs faster than it can +process them, or if it generates outputs faster than it can send them, data +accumulates in buffers, eventually causing the server to run out of memory and +crash. + +The solution to this problem is backpressure. Any part of the server that +receives inputs faster than it can it can process them and send the outputs +must propagate that information back to the previous part in the chain. + +``websockets`` is designed to make it easy to get backpressure right. + +For incoming data, ``websockets`` builds upon :class:`~asyncio.StreamReader` +which propagates backpressure to its own buffer and to the TCP stream. Frames +are parsed from the input stream and added to a bounded queue. If the queue +fills up, parsing halts until some the application reads a frame. + +For outgoing data, ``websockets`` builds upon :class:`~asyncio.StreamWriter` +which implements flow control. If the output buffers grow too large, it waits +until they're drained. That's why all APIs that write frames are asynchronous. + +Of course, it's still possible for an application to create its own unbounded +buffers and break the backpressure. Be careful with queues. + + +.. _buffers: + +Buffers +------- + +.. note:: + + This section discusses buffers from the perspective of a server but it + applies to clients as well. + +An asynchronous systems works best when its buffers are almost always empty. + +For example, if a client sends data too fast for a server, the queue of +incoming messages will be constantly full. The server will always be 32 +messages (by default) behind the client. This consumes memory and increases +latency for no good reason. The problem is called bufferbloat. + +If buffers are almost always full and that problem cannot be solved by adding +capacity — typically because the system is bottlenecked by the output and +constantly regulated by backpressure — reducing the size of buffers minimizes +negative consequences. + +By default ``websockets`` has rather high limits. You can decrease them +according to your application's characteristics. + +Bufferbloat can happen at every level in the stack where there is a buffer. +For each connection, the receiving side contains these buffers: + +- OS buffers: tuning them is an advanced optimization. +- :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64kB. + You can set another limit by passing a ``read_limit`` keyword argument to + :func:`~client.connect()` or :func:`~server.serve()`. +- Incoming messages :class:`~asyncio.queues.Queue`: its size depends both on + the size and the number of messages it contains. By default the maximum + UTF-8 encoded size is 1MB and the maximum number is 32. In the worst case, + after UTF-8 decoding, a single message could take up to 4MB of memory and + the overall memory consumption could reach 128MB. You should adjust these + limits by setting the ``max_size`` and ``max_queue`` keyword arguments of + :func:`~client.connect()` or :func:`~server.serve()` according to your + application's requirements. + +For each connection, the sending side contains these buffers: + +- :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64kB. + You can set another limit by passing a ``write_limit`` keyword argument to + :func:`~client.connect()` or :func:`~server.serve()`. +- OS buffers: tuning them is an advanced optimization. diff --git a/docs/index.rst b/docs/index.rst index adfb3d168..5f22282ce 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,6 +52,7 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a api deployment security + design limitations changelog license diff --git a/docs/lifecycle.graffle b/docs/lifecycle.graffle new file mode 100644 index 0000000000000000000000000000000000000000..b1df4afd045d331a7492f39c3855171566d6a700 GIT binary patch literal 3073 zcmV+c4F2;UiwFP!000030PS5{SKGK2elEYlhllfU1{_JYY^!CM6#`S3LkUAtdfJt> zW(iwKJUDi0I|ZhT|9-Zdi?4x%LR(t%P&?L^v?c5NwzMVh?)-K=@{~^^jNHKgbrY*- zQxU!$IIchVb@Ofa_ibzQx804Mf9I3B-jG_J0$R4%1dI~Ym9sBszwV-dzT$57ZdWKf;h*@SAQ zrRUUvr;cmKyBjZdF2&7m(~jLw;+Wrv@W^+>^-e=Rg+$kn#Xy9+Xs0p1oo9i%U7N>H zfjK(x63iAU|w9k}UGnD=GU zmDx9$=IFTph?0ET<)uu4HyQa&&mH(BrEpoPM7lHPHWaat4+ZyjHL9}YkLZjNYB9Fc zND}5*;fEaE`#dh`M-Vo*5!*((g7Hfoztjxn7X*JLZDk=rSj^wN8Hv!fmF8psqgVD< zKHI(kfS?+43=xs<3F<%h!iBw}EvT$wXiWrmx4 zmb`M=y4kzWf}(`iv_Du505i#U7!QL%;B#+}hmxH}FoNt%?m0bG$)5&pwlNDR=$fQM zO*aVAO==nzAOo0D)l8F7jCG1lf{fd8_Z3Qy+&J@M=PlpH*_O++M33yNTD;P%QiQ6Q zfi5P))Xh*BLhX0qD~m@Q$p&XA9Vvy|H^NyTCnq;k(V@4KG&M5FwhBIqGy zm&TvCj+ggbJ~bURy)f?M9`_?R`Ntj)FrlPp2X??^5)jt(M>M&B--c-|ZH(CNk3)Ijhr#w58Imb?(6Cl?FJi>4(a8cRfgN#zB!M$_E#V&Fu?0Mn6_ z>xMp$Cn5KGA8q*{b_c^)_;-LXB<+fbuJr6aQXC$0B@Ur?`XcOY{F{%)uN2z@ z6XOq*A!rUm=Th9>H`QDn$3=6aXS^Pl zYpZov6=q>(J(JBfCY9EgEkW+04*hF61PCT71%4zNqozf4Q?9*>3!Mi$1f<(lBLsLD z*q9=uk%xH*Ftb!WnU@(%gN2tV{dt8)%lrey)|dpW2F9udrh#A!0ia+qSo;TlM*o0? z*Fmzjwlm=;BS!r6#Hf8MVl?QJ5u?rwM9ptbjD~i<+qEXfniwC07}xW%VoO)$h&D`J zLm2BOk&C>d5vMTtiKS|~p;O3inusw2Ki(qm!NlAOCZQ|B6LEFgdDC5)ozc46NEWFe z$2WZ7i+gzlK_@Tc2}s*snurT~vDvika^kxDao>8lC6`n=aAYK=A{W`@1^wzM|Ccp| z^PGc)8dV)921z$5qJ+v#wDPmtFJ|Oc5c(oK zQ-2>pyFAAvbg9d}ED(6e_v zTx6UNp|C$G4J8u31D6fi0T)OaOMgT|83*z*&7&lH0z!F0f8ugeQ}QSN^ri;are2ZZ+RwM2BGvRUhO+yo2WR3L++bDG@sZ5#n38+|!==Od%jmn9TayQcnE^BgXvCTY3qhvSA9jNj zJ$PUAl*ch3wsKn(<8Qw&`tu-%JQO*U64j&>5JMQ!bd4J8Ag3yKjhq@e&j>mAp~%rO zRuRz*%nXWQpqq3Z_pQjPkX`WmUv2dPIVuOAVrrKuSHuAy3gHeW;ac|f&&dT{cX zN0*9F)4yk&GE(ET#_97gsqK@KgI0GHt6B&y`8sr+tiL87Gg77&q6*G;e^o1w#&oef zJKOMg5LRYiEZ|W^FvwwW70FzsP?PTuI;9Ac9iNY*Zm@!UKHB5!iR*~yIIu5;Q;Zy# z8_yZ(b2oA?%25gn)g-=SpiGb8Qi=rpCD|uI?>%*#HM^ zgGs!6h&O$AlnH&gM3F?)c0)UOnNoav6Hc zV4EF@T0=ep^mK+IbB<~{@E53f6F6>v7Q*+gAg9<=Fng(!ml~=Layl(YtR3a_B#6c8 zEoy^p^5a||-@zf%IJ7GlE~#k`Uejf?f~-XIa@?GB_JTMLMsIkSiG-VVpU;c#djXI0 zsGi;HIFI0zv8FW~2BG_p{9YaOo&{~Xt0EgLy;E7yUNip3IsC(W{|+PZzdi56(P7Zk z|J?iYsB7ygTTN~YrmC7Gy`qxO z2F9}p#7H+ZLJez6F&VkNr&2lk&V-ESEh-@K9A5vKoU}dV`|k$hrzk$p-^q(;`$AI3 zl6l9O;7;53Ztmxlv*2pgkk%kt9nadx$FEMfzw5?j`kThx&P`ONIYuy{PSlW*pr_Nj zX=L~nW_L+uIRW6lE!NEY)^|TnM7q|zqxEkiap1>pcvrWdS5GvxQep30AC`{Gsk}7_ zL+FX5C(HBnifQ`6Yz6e!w2L$IxYYK}HhaSyJ*6JIErlir#fb@YK(IzCg3im;KyEKN z4VSV=Kbv@|MJ@@z>$%|fCFeH)Q)7!NS`Luwc>Ll}NNgA8K$=#VGR@5LnHc1lD?OaA zQlP@*Q!R^ZDaC|qYx18l P>~8!Y*@LYtBVqsmQ?K-A literal 0 HcmV?d00001 diff --git a/docs/lifecycle.svg b/docs/lifecycle.svg new file mode 100644 index 000000000..d783421f9 --- /dev/null +++ b/docs/lifecycle.svg @@ -0,0 +1,3 @@ + + + Produced by OmniGraffle 6.6.2 2017-09-17 19:42:30 +0000Canvas 1Layer 1CONNECTINGOPENCLOSINGCLOSEDtransfer_dataclose_connectionconnectrecv / send / ping / pong / close opening handshakeconnectionterminationdata transfer& closing handshake diff --git a/docs/limitations.rst b/docs/limitations.rst index d0b9743fc..bd6d32b2f 100644 --- a/docs/limitations.rst +++ b/docs/limitations.rst @@ -5,3 +5,6 @@ The client doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. The client doesn't support connecting through a proxy. + +There is no way to fragment outgoing messages. A message is always sent in a +single frame. diff --git a/docs/protocol.graffle b/docs/protocol.graffle new file mode 100644 index 0000000000000000000000000000000000000000..98b4cdb581d83dbd8aa3041d77ba003c72a71a96 GIT binary patch literal 4656 zcmV-063^`)iwFP!000030PS7rbK5ww|2+9C^z!j-Oyeez&1B25olRzw*b_O)o0O}j zXbG0Np-2r$#qp%_zrO}`@c=DK@h!+z*1^U>5dG@|jmEQoeH(_IuPRA{IQsQ5cF^OF zih6NBhz7qt{&4hi%X|E<=Z~KK&)&iA(I1EVolzL1S?BP>i#M-#JCC=vw|7RPP;GDT z9qn}v-@NV~b)b&z?frL;JCDz@Z1i+{`~3Xe@impxi-(#p-9Aj>kxH_QH&EFY6ma@k z|1mT(slC)4__ZJOvgeO}dG zBX<4htDhbu!5~=7h1p={V_AHZixS=7Z&jl&jCrk89FB+4P8bZLTCZ@uS2=Ze;uzOY4eG?7x%_P6=hX!o)0aljqP-Try6h!++Uut6*epI8ROL9j8eoH{O6K-Pdd!Q zj^{~NAn9TxBxRZtJ?Da0(=&ghOpH1m z#5l@cN2hT!pJ@sMo9)KqBv8pab$-}=dsM1;%G`xXuYb$u9Y*sBSY*u_YAha#5GC9& z0Y|ONr$HDNn_gTtLq`*K6ni)Yr8B=DpI=km6~Oz6x+mnHf_`>ZBA+}1hZi0~=_ynm z$dUS%HIH%$pNfwRm7myz`q1y!W`kGiN4>=b2OTHbDL(e2G|2z;MgvSZIqt>1*w>dh zVv{fF_yoRt6g~QWd|2X@`1%m$=ll)D1`i&5lzj;W% zQJRxbnUgSW^c zi=7f6h_ujQo0{L~NA z=~SH&0jm~m#1edy3VctI*$=|ideR_%kmI(NKhS3LlPU7@qW9b=eag_Ypsp$x8wh^H z0DkRoqsHaSxo=UL5$*t$$6SfIV3OgD;@vJh)i2Ylb=J{EOCF#`?QI37zgZ4m!%89^ zbtK~iCWRz~PzqCmO5u1~8erFyxlp}dYwKBCZ$oUoGzpc>dT>e*x+x3Lxa3h$dxW@Q3lOHegRzH^=Q8LKbYpp^n)N_N>Z>@sv2>d? zplWGr!`-qCw}Mz%!-OMzViIGHvA?C_e5Xulld4=Zp<)xv^4xxfzmr2ic<)*ITY`786#P#B8^D>et zo`GVR(ikIxJ;Xc&ayyUjcaei!&!rLzK|rLFd$5eO^&$@Y?bnjZnA8f5O)J80{XwLb zA{heUYQhS*w$*qMS9S>$uWi*Yt_mW{7KwJXcfZ6T1&RsDN%8VrHDPG-64Uq6>esa7 zz)%t!=n`WK-lSrT<>Uh`ul^Px?QSh#UGs|6?W`V6neXl{A^p{BS}Quyt3_@K!rg@2D(}}O@`+K>wH9ivf93)EXHpgQR|eb%x8%>)UGisAI7krbNkW+< zjDpcD~UZSPEK174fMKWuIOsVyI`#6O18RcEm0Sq6G= zJNq6S1JSh^NVtxR1*gPCF2Rh7l^DoSy6QIbtOT+QbkAeORtpg!YXp|DwYO0iBb(ZD zu%UF_!(rp1tVVdF0*$S`o1yB7bJ@}feMw<=IWJjuxQ{V^{7L}b%ExtvbDh9z=B^WX zU6`LMY8Qh#3eO`~0=sW$7l*lM(mIV0H}$r(A* zdyk=whZyNwRE~!#P5psNk8kQogQjlbyUjl`{C<0p_=^_I&FYpoKqX46WvQhAlDUWSC!ZI``T zoNZuL^aqfy(k^@b9Fy~PnEXnYwyW_Q%?UG&r7K4p2xgi5{xdnXOn!w-j@fnl7pZ%j zNyX~Ft?OQN9YP9AcM^MRxf@#^LD-aJteVto`5hL=XS&QU2olmoSVC#Z9pZBCxyS>P z8SbjFv>`j1K+*v-!^501uB*ltVK>p!Ntwmj)x=gi=)A#vjMXr}8gnx&ZZj%~0_wQS z5vSpn2ts-6e`~bEm^y&-fOi-pk4kSPq%~@zn{$|g#mdVxE^=EDNv1R~^(>(_BB=?D znT<%gD-lU{wZ_ID{qz#cc@kvm&-|dSE6l25+GSyIH`_Dj4G2r^8T0nsL3_r0%d-ai ze&*lML4&Koox$;L;m!be*utHKI}3L$;qER1UHyCy3v>pcLl)>P&{?2s33MkH*)8In zY2RY?Xkk}aH(?Xl5t|m&x(ThD@F8te^!tBK!?CZ>PiVy1dQhvy*(RK|8^kTnTAY1g zIIDe{SH2`;g4${vp4DiZk(OGdwMc7`_F*CIO{{whmpK|b1qXy>pw3E-vc#iK8Bu}K z1Bk{<0cv<%GUh!alr(QLKferzlp zM?qM?;weMk7f)}SGo3(*I**et8-L0OAvYE4P6>CU#|Q^4%@ZJpc#DlJ*B$F#jS!Q3 zl~{MvExBTn+l2_iJeb>di}QxYo3rbUH{Y}kljOQLOcH!kNNCE(cI!xJ>eMnUyF4G8 z-KF@DkZ=vt)U*|>YJ_@u=*}W_A0b${QfE%9{wCxGh3;T!eSd9ze|lRu!Z1GosKAo( zaQmcXElMV9l-y8bcQ;NVmAv=+!8omh%*3G&sxp;evw0kg`e{WE+IxAw)2FpHv5OAR zEP>%XRmo&|Df81W)5K<-TTibxkNP@-UNw0E7PICs^m}TkO6r_BW^WBSS&iFg0JhX4 zj?m(dVi*h$)G;k5P5Z?SN69=$Ir-{k7#K3e{juMz_uDK}-FqlgAA07{O<90zbu!+4 zPZatP70GuXCiL!#U;ffoEVGGUA7e@pG^1oJk_NZ zHz$j)W~s-{69*Kj^b|5FZ5y_c0LI?_b*hkXnF2#u82l``C($*A7PB>>Vg@f(steW7^w55)izCV=xsdCny2_q6b$`L)h1BIb)>KMf}|I} z4Whju%_{6rJjH53+f~56We=9Yqn6W$c{lf?K@ipZ(vP!RXCmcivEF3xMn!`HG|w8C z`P;zgLi)fhR%tLVzd{GI^LP}rqAxQWbn#0Q_tcqfx@5BgeG1&uJyV5sp2@_YbH&@Z zADm8Y-o3(;o>fW^cDIn+0zk#`DV0wN?fgTJT&=CVq5`i@FYjCoRTA_%JLADP&4Q>} z?j7KB`#U;F^rrCA7@N0^%&1|;w60u3(+hY^U(yz|nwy*X=3VzioMrLwt)C2n!gMGs za!ujt+z)(zk{eU@9GWpz}D&;8+x&m{oQ}{TlmK( zjMV?V2tU7m74NX$U;O^MyMMNGcJKjleDdkHus7I+@BM=h`19GxtB>IhT;JX2M@Qth zzdnC@hjw0n+TZE^1vD?eh3C8fzP$Ty)M@rJ!=pjvvic9foJdL-ZK!_oW~@K*w?dU= zuuNh z{Ey<%byQ!*@BBd=L`p+PgWjQ6l@6ltVmYR~kIze~axYPLfV84IP9zOoXqr=%5wQO# z$m;O7>qlSxw4|s^N(@1#&by&q3={Qkr>B>%T4;B9VXXl8QBSQ{_Cpl>HC77?`;xo= zkgEMC3zEx*{bO52^HUIk5!npoyW=E*k;q4~uBRK+^!=$x^yg%V3wqo__D($(NeP}7 z(WeYhfbe{?Aod7+!P+0^jPx@Hvi$RST@rk|z1+gXx+Y m;lL33VqlP847HnlFGVn_0mtHf6auF_fAs$g=`ROJ@Bjcs1u7i? literal 0 HcmV?d00001 diff --git a/docs/protocol.svg b/docs/protocol.svg new file mode 100644 index 000000000..7108927b8 --- /dev/null +++ b/docs/protocol.svg @@ -0,0 +1,3 @@ + + + Produced by OmniGraffle 6.6.2 2017-09-24 19:39:13 +0000Canvas 1Layer 1remote endpointwebsocketsWebSocketCommonProtocolapplication logicreaderStreamReaderwriterStreamWriterpingsdicttransfer_data_taskTasknetworkread_frameread_data_frameread_messagebytesframesdataframeswrite_framemessagesQueuerecvsendpingpongclosecontrolframesbytesframes From 4f2a6ac7a2840d305812ff642f32b20bdcba9640 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 13:38:57 +0200 Subject: [PATCH 0332/1539] Run tests on macOS and Windows. --- .travis.yml | 3 +-- appveyor.yml | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 4c0abc0b5..ea53a1f54 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,8 +2,7 @@ env: global: # websockets only works on Python >= 3.4. - CIBW_SKIP="cp27-* cp33-*" - # Commented out because tests don't pass reliably on macOS, see #241. - # - CIBW_TEST_COMMAND="python3 -m unittest discover websockets" + - CIBW_TEST_COMMAND="python3 -m unittest websockets" matrix: include: diff --git a/appveyor.yml b/appveyor.yml index 73ffb93ee..0114320f9 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,8 +1,7 @@ environment: # websockets only works on Python >= 3.4. CIBW_SKIP: cp27-* cp33-* -# Commented out because tests don't pass reliably on Windows, see #240. -# CIBW_TEST_COMMAND: python -m unittest discover websockets + CIBW_TEST_COMMAND: python -m unittest websockets # Since Python 2 is still the default, invoke Python 3 explicitly. install: From bbf9563a6387ef6eaaab0becd4eed273e3f41915 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 13:45:02 +0200 Subject: [PATCH 0333/1539] Add test certificate so SSL tests run on CI. --- MANIFEST.in | 1 + setup.py | 1 + websockets/test_client_server.py | 2 +- websockets/test_localhost.pem | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 websockets/test_localhost.pem diff --git a/MANIFEST.in b/MANIFEST.in index 9f4f1787e..7e96fd0d7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include LICENSE +include websockets/test_localhost.pem graft websockets/py35 graft websockets/py36 diff --git a/setup.py b/setup.py index bdb8f85cd..df1b516dd 100644 --- a/setup.py +++ b/setup.py @@ -59,4 +59,5 @@ ], packages=packages, ext_modules=ext_modules, + include_package_data=True, ) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 2395b6cc8..37cf25398 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -25,7 +25,7 @@ # Avoid displaying stack traces at the ERROR logging level. logging.basicConfig(level=logging.CRITICAL) -testcert = os.path.join(os.path.dirname(__file__), 'testcert.pem') +testcert = os.path.join(os.path.dirname(__file__), 'test_localhost.pem') @asyncio.coroutine diff --git a/websockets/test_localhost.pem b/websockets/test_localhost.pem new file mode 100644 index 000000000..1ed5fd458 --- /dev/null +++ b/websockets/test_localhost.pem @@ -0,0 +1,32 @@ +-----BEGIN PRIVATE KEY----- +MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBANSBDRjLau8ur0s1 +WNVJdpa1x6PMdistb9VU9lBqxJzu8sgWnuzvy1Nt+1lCl6j6QtQxma99bPjbcZ9S +rXJUwtBLq067Zy01VQ/lpBfjqRZShYUVimg4We9KB5DFvWzP52L8Oj0U3sm46mek +vcddtJQz6WwbPiROOSvF80W206fNAgMBAAECgYAfSKBU9h1X+Nd1ivT48Ue0CC7L +vl3nHVlJXqikThODxumW6z2aQ/L65UYLbfJFvhH4ixTE8QIJ4MRpYBKIslG7c3DX +cX6MP6KPaUjxSbjB9RlS9VdKbovxxeecbWzfSY+Cz/alyg++J0iOwbJVGL+RlaJw +g8hQM+UWyJLN764/QQJBAP/NeBHChjU7QyA36lv2Lm/lUpkYy3Zy4ZTGPyiuBjLC +SNqF1PMxrvuHHL05NaE6R02VFXztxJf2ci1rZKDG2N8CQQDUqwdsWZFlmTA5hqTB +mEYw3feCij3t4sy0KDV1wV851WJRbVrzrbxN+rHL5MKwd3qcxs1TXCfF1A9qbPXS +phjTAkBtd/KgNwzUDu5lBUjH3gx1WkAEwHWh1PvwfP5eXErOwhIHYiqFgIePoHyO +BcOLobMN4nT1p5LwLUkjYsgHfdElAkBgbBL3izyjBeuZiXSV2gapDVq1MxyVCOmr +HTfv5fbY7+id5qkAJttjt7B5M4UaIXHUN0bM7tGRnm5G4JQsJ+bFAkAQ/pYfrC9l +2hXI29YTSYTsw4iDjgJF6RAxw2108M8KybSJdyvQ43N4U40BQx8BRQmxZwSyG5QX +s+j9Cb63orCr +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICijCCAfOgAwIBAgIJANEJitrPxb96MA0GCSqGSIb3DQEBBQUAMF0xCzAJBgNV +BAYTAkZSMQ8wDQYDVQQIDAZGcmFuY2UxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQK +DBBBeW1lcmljIEF1Z3VzdGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMTQwNDE1 +MjEzMjI5WhgPMjExNDA0MTYyMTMyMjlaMF0xCzAJBgNVBAYTAkZSMQ8wDQYDVQQI +DAZGcmFuY2UxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQKDBBBeW1lcmljIEF1Z3Vz +dGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJ +AoGBANSBDRjLau8ur0s1WNVJdpa1x6PMdistb9VU9lBqxJzu8sgWnuzvy1Nt+1lC +l6j6QtQxma99bPjbcZ9SrXJUwtBLq067Zy01VQ/lpBfjqRZShYUVimg4We9KB5DF +vWzP52L8Oj0U3sm46mekvcddtJQz6WwbPiROOSvF80W206fNAgMBAAGjUDBOMB0G +A1UdDgQWBBRcFzeirOD3zMnjCptlc0sh9VWZJjAfBgNVHSMEGDAWgBRcFzeirOD3 +zMnjCptlc0sh9VWZJjAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAFyv +MGP9hnrMbDnwRtCYX/g99nvxjc5KXJyDw91Vo3hmHjdVRXY/oJbjiUtOBf1OsgoN +rv7KsaMb9+060K+uDtQIIiwPcxF1nQOZDtv6Nyzj8hwM2XFl+XiVgUD2pg++scWF +PDfbpmeEDQnUMEqHETM7JTMLB349/s5UUQqsSBE0 +-----END CERTIFICATE----- From fcedaf59514f7447f84e69bf226393496480ab48 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 13:53:52 +0200 Subject: [PATCH 0334/1539] Increase time unit for timeout tests. This should make test more reliable. --- .travis.yml | 1 + appveyor.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index ea53a1f54..db6596e76 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,7 @@ env: # websockets only works on Python >= 3.4. - CIBW_SKIP="cp27-* cp33-*" - CIBW_TEST_COMMAND="python3 -m unittest websockets" + - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 matrix: include: diff --git a/appveyor.yml b/appveyor.yml index 0114320f9..75b32b118 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -2,6 +2,7 @@ environment: # websockets only works on Python >= 3.4. CIBW_SKIP: cp27-* cp33-* CIBW_TEST_COMMAND: python -m unittest websockets + WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 # Since Python 2 is still the default, invoke Python 3 explicitly. install: From 43784eac2932fee3b0e87d579608000f2be30cfb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 17:50:22 +0200 Subject: [PATCH 0335/1539] Add a timeout when writing a close frame. Fix #112. --- websockets/protocol.py | 28 +++++++++++------ websockets/test_protocol.py | 61 +++++++++++++++++++++++++------------ 2 files changed, 60 insertions(+), 29 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 242a168db..b177674e0 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -66,7 +66,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``timeout`` parameter defines the maximum wait time in seconds for completing the closing handshake and, only on the client side, for terminating the TCP connection. :meth:`close()` will complete in at most - ``2 * timeout`` on the server side and ``3 * timeout`` on the client side. + ``3 * timeout`` on the server side and ``4 * timeout`` on the client side. The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1MB. ``None`` disables the limit. If a @@ -281,15 +281,25 @@ def close(self, code=1000, reason=''): # 7.1.2. Start the WebSocket Closing Handshake # 7.1.3. The WebSocket Closing Handshake is Started frame_data = serialize_close(code, reason) - yield from self.write_frame(OP_CLOSE, frame_data) + try: + yield from asyncio.wait_for( + self.write_frame(OP_CLOSE, frame_data), + self.timeout, loop=self.loop) + except asyncio.TimeoutError: + # If the close frame cannot be sent because the send buffers + # are full, the closing handshake won't complete anyway. + # Cancel the data transfer task to shut down faster. + # Cancelling a task is idempotent. + self.transfer_data_task.cancel() + + # If no close frame is received within the timeout, wait_for() cancels + # the data transfer task and raises TimeoutError. Then transfer_data() + # catches CancelledError and exits without an exception. + + # If close() is called multiple times concurrently and one of these + # calls hits the timeout, other calls will resume executing without an + # exception, so there's no need to catch CancelledError here. - # If no close frame is received within the timeout, cancel the data - # transfer task in order to exit the infinite loop. transfer_data() - # will catch CancelledError and exit without an exception. However - # wait_for() will raise CancelledError anyway. As a consequence, if - # close() is called several times concurrently and one of these calls - # is cancelled, other calls will see that the data transfer task has - # completed. This is why there's no need to catch CancelledError here. try: # If close() is cancelled during the wait, self.transfer_data_task # is cancelled before the timeout elapses (on Python ≥ 3.4.3). diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 8eebf659c..b5723a71f 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -100,7 +100,7 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - def make_drain_slow(self): + def make_drain_slow(self, delay=3 * MS): # Process connection_made in order to initialize self.protocol.writer. self.run_loop_once() @@ -108,7 +108,7 @@ def make_drain_slow(self): @asyncio.coroutine def delayed_drain(): - yield from asyncio.sleep(3 * MS, loop=self.loop) + yield from asyncio.sleep(delay, loop=self.loop) yield from original_drain() self.protocol.writer.drain = delayed_drain @@ -180,7 +180,8 @@ def close_connection_partial(self, code=1000, reason='close'): close_frame_data = serialize_close(code, reason) # Trigger the closing handshake from the local side. self.ensure_future(self.protocol.close(code, reason)) - self.run_loop_once() + self.run_loop_once() # wait_for executes + self.run_loop_once() # write_frame executes # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) # Prepare the response to the closing handshake from the remote side. @@ -657,8 +658,8 @@ def test_connection_closed_attributes(self): def test_local_close(self): # Emulate how the remote endpoint answers the closing handshake. - self.loop.call_soon(self.receive_frame, self.close_frame) - self.loop.call_soon(self.receive_eof_if_client) + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(MS, self.receive_eof_if_client) # Run the closing handshake. self.loop.run_until_complete(self.protocol.close(reason='close')) @@ -674,8 +675,8 @@ def test_local_close(self): def test_remote_close(self): # Emulate how the remote endpoint initiates the closing handshake. - self.loop.call_soon(self.receive_frame, self.close_frame) - self.loop.call_soon(self.receive_eof_if_client) + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(MS, self.receive_eof_if_client) # Wait for some data in order to process the handshake. # After recv() raises ConnectionClosed, the connection is closed. @@ -693,8 +694,8 @@ def test_remote_close(self): def test_simultaneous_close(self): # Delay the incoming close frame until after we send the outgoing one. - self.loop.call_soon(self.receive_frame, self.remote_close) - self.loop.call_soon(self.receive_eof_if_client) + self.loop.call_later(MS, self.receive_frame, self.remote_close) + self.loop.call_later(MS, self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason='local')) @@ -706,8 +707,8 @@ def test_simultaneous_close(self): def test_close_preserves_incoming_frames(self): self.receive_frame(Frame(True, OP_TEXT, b'hello')) - self.loop.call_soon(self.receive_frame, self.close_frame) - self.loop.call_soon(self.receive_eof_if_client) + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(MS, self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1000, 'close') @@ -735,8 +736,8 @@ def test_close_connection_lost(self): def test_local_close_during_recv(self): recv = self.ensure_future(self.protocol.recv()) - self.loop.call_soon(self.receive_frame, self.close_frame) - self.loop.call_soon(self.receive_eof_if_client) + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(MS, self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason='close')) @@ -771,7 +772,16 @@ def setUp(self): self.protocol.is_client = False self.protocol.side = 'server' - def test_local_close_timeout(self): + def test_local_close_send_close_frame_timeout(self): + self.protocol.timeout = 10 * MS + self.make_drain_slow(50 * MS) + # If we can't send a close frame, time out in 10ms. + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(9 * MS, 19 * MS): + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1006, '') + + def test_local_close_receive_close_frame_timeout(self): self.protocol.timeout = 10 * MS # If the client doesn't send a close frame, time out in 10ms. # Check the timing within -1/+9ms for robustness. @@ -799,11 +809,22 @@ def setUp(self): self.protocol.is_client = True self.protocol.side = 'client' - def test_local_close_timeout(self): + def test_local_close_send_close_frame_timeout(self): + self.protocol.timeout = 10 * MS + self.make_drain_slow(50 * MS) + # If we can't send a close frame, time out in 20ms. + # - 10ms waiting for sending a close frame + # - 10ms waiting for receiving a half-close + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(19 * MS, 29 * MS): + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1006, '') + + def test_local_close_receive_close_frame_timeout(self): self.protocol.timeout = 10 * MS - # If the server doesn't send a close frame, time out in 30ms: - # - 10ms waiting for a close frame - # - 10ms waiting for a half-close + # If the server doesn't send a close frame, time out in 20ms: + # - 10ms waiting for receiving a close frame + # - 10ms waiting for receiving a half-close # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason='close')) @@ -813,8 +834,8 @@ def test_local_close_connection_lost_timeout(self): self.protocol.timeout = 10 * MS # If the server doesn't half-close its side of the TCP connection # after we send a close frame, time out in 20ms: - # - 10ms waiting for a half-close - # - 10ms waiting for a close + # - 10ms waiting for receiving a half-close + # - 10ms waiting for receiving a close # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): # HACK: disable write_eof => other end drops connection emulation. From f9721dec3700987cff03410110ddd7b6c8c69315 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 18:10:58 +0200 Subject: [PATCH 0336/1539] Abort TCP connection if it doesn't close fast enough. Fix #112. --- docs/design.rst | 11 ++++++++++ websockets/protocol.py | 19 ++++++++++++---- websockets/test_protocol.py | 44 ++++++++++++++++++++++++++++++++++--- 3 files changed, 67 insertions(+), 7 deletions(-) diff --git a/docs/design.rst b/docs/design.rst index b114c66db..f8132c7a9 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -247,6 +247,9 @@ If the other side doesn't send a close frame within the connection's timeout, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which has the same effect. +The closing handshake can take up to ``2 * timeout``: one ``timeout`` to write +a close frame and one ``timeout`` to receive a close frame. + Then ``websockets`` terminates the TCP connection. @@ -278,6 +281,14 @@ easier to implement the timeout on the closing handshake. Cancelling of cancelling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` and failing to close the TCP connection, thus leaking resources. +Terminating the TCP connection can take up to ``2 * timeout`` on the server +side and ``3 * timeout`` on the client side. Clients start by waiting for the +server to close the connection, hence the extra ``timeout``. Then both sides +go through the following steps until the TCP connection is lost: half-closing +the connection (only for non-TLS connections), closing the connection, +aborting the connection. At this point the connection drops regardless of what +happens on the network. + .. _cancellation: diff --git a/websockets/protocol.py b/websockets/protocol.py index b177674e0..9edd1fbc4 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -66,7 +66,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``timeout`` parameter defines the maximum wait time in seconds for completing the closing handshake and, only on the client side, for terminating the TCP connection. :meth:`close()` will complete in at most - ``3 * timeout`` on the server side and ``4 * timeout`` on the client side. + ``4 * timeout`` on the server side and ``5 * timeout`` on the client side. The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1MB. ``None`` disables the limit. If a @@ -773,12 +773,23 @@ def close_connection(self, after_handshake=True): # Closing a transport is idempotent. If the transport was already # closed, for example from eof_received(), it's fine. - # Close the TCP connection. + # Close the TCP connection. Buffers are flushed asynchronously. logger.debug( "%s x closing TCP connection", self.side) self.writer.close() - # There's little need to await self.wait_for_connection_lost() - # here. Closing the transport triggers self.connection_lost(). + + if (yield from self.wait_for_connection_lost()): + return + logger.debug( + "%s ! timed out waiting for TCP close", self.side) + + # Abort the TCP connection. Buffers are discarded. + logger.debug( + "%s x aborting TCP connection", self.side) + self.writer.transport.abort() + + # connection_lost() is called quickly after aborting. + yield from self.wait_for_connection_lost() @asyncio.coroutine def fail_connection(self, code=1011, reason=''): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index b5723a71f..67cf3d114 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -70,6 +70,11 @@ def close(self): self.loop.call_soon(self.protocol.connection_lost, None) self._closing = True + def abort(self): + # Change this to an `if` if tests call abort() multiple times. + assert self.protocol.state != CLOSED + self.loop.call_soon(self.protocol.connection_lost, None) + class CommonTests: """ @@ -789,7 +794,7 @@ def test_local_close_receive_close_frame_timeout(self): self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1006, '') - def test_local_close_connection_lost_timeout(self): + def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof(), time out in 10ms. @@ -801,6 +806,21 @@ def test_local_close_connection_lost_timeout(self): self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1000, 'close') + def test_local_close_connection_lost_timeout_after_close(self): + self.protocol.timeout = 10 * MS + # If the client doesn't close its side of the TCP connection after we + # half-close our side with write_eof() and close it with close(), time + # out in 20ms. + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(19 * MS, 29 * MS): + # HACK: disable write_eof => other end drops connection emulation. + self.transport._eof = True + # HACK: disable close => other end drops connection emulation. + self.transport._closing = True + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1000, 'close') + class ClientTests(CommonTests, unittest.TestCase): @@ -830,12 +850,12 @@ def test_local_close_receive_close_frame_timeout(self): self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1006, '') - def test_local_close_connection_lost_timeout(self): + def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.timeout = 10 * MS # If the server doesn't half-close its side of the TCP connection # after we send a close frame, time out in 20ms: # - 10ms waiting for receiving a half-close - # - 10ms waiting for receiving a close + # - 10ms waiting for receiving a close after write_eof # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): # HACK: disable write_eof => other end drops connection emulation. @@ -843,3 +863,21 @@ def test_local_close_connection_lost_timeout(self): self.receive_frame(self.close_frame) self.loop.run_until_complete(self.protocol.close(reason='close')) self.assertConnectionClosed(1000, 'close') + + def test_local_close_connection_lost_timeout_after_close(self): + self.protocol.timeout = 10 * MS + # If the client doesn't close its side of the TCP connection after we + # half-close our side with write_eof() and close it with close(), time + # out in 20ms. + # - 10ms waiting for receiving a half-close + # - 10ms waiting for receiving a close after write_eof + # - 10ms waiting for receiving a close after close + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(29 * MS, 39 * MS): + # HACK: disable write_eof => other end drops connection emulation. + self.transport._eof = True + # HACK: disable close => other end drops connection emulation. + self.transport._closing = True + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason='close')) + self.assertConnectionClosed(1000, 'close') From 712e48d34c4991011f776272e60d90dcac19aa16 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 18:34:52 +0200 Subject: [PATCH 0337/1539] Add changelog for recent changes. --- docs/changelog.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index d41898abe..ad4ae02f9 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -22,6 +22,10 @@ Also: * :class:`~websockets.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. +* Aborted connections if they don't close within the configured ``timeout``. + +* Rewrote connection termination to increase robustness in edge cases. + * Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on a connection while it's being closed. From 124e2762ec165a15809eda24d2db1793fb958fab Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 20:34:22 +0200 Subject: [PATCH 0338/1539] Reorganize docs by use case. Fix #275. --- LICENSE | 2 +- docs/changelog.rst | 2 ++ docs/contributing.rst | 11 ++++++ docs/index.rst | 84 ++++++++++++++++++++++++++----------------- docs/intro.rst | 7 ++++ example/echo.py | 12 +++++++ example/hello.py | 11 ++++++ 7 files changed, 95 insertions(+), 34 deletions(-) create mode 100644 docs/contributing.rst create mode 100644 example/echo.py create mode 100644 example/hello.py diff --git a/LICENSE b/LICENSE index f46af27b5..7101662c8 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013-2015 Aymeric Augustin and contributors. +Copyright (c) 2013-2017 Aymeric Augustin and contributors. All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/docs/changelog.rst b/docs/changelog.rst index ad4ae02f9..f8d6380b4 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -22,6 +22,8 @@ Also: * :class:`~websockets.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. +* Reorganized and extended documentation. + * Aborted connections if they don't close within the configured ``timeout``. * Rewrote connection termination to increase robustness in edge cases. diff --git a/docs/contributing.rst b/docs/contributing.rst new file mode 100644 index 000000000..4b869dcac --- /dev/null +++ b/docs/contributing.rst @@ -0,0 +1,11 @@ +Contributing +============ + +Bug reports, patches and suggestions are welcome! Please open an issue_ or +send a `pull request`_. + +Feedback about this documentation is especially valuable — the authors of +``websockets`` feel more confident about writing code than writing docs :-) + +.. _issue: https://github.com/aaugustin/websockets/issues/new +.. _pull request: https://github.com/aaugustin/websockets/compare/ diff --git a/docs/index.rst b/docs/index.rst index 5f22282ce..9833e09d7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,58 +1,76 @@ WebSockets ========== -``websockets`` is a library for developing WebSocket servers_ and clients_ in -Python. It implements `RFC 6455`_ and `RFC 7692`_ with a focus on correctness -and simplicity. It passes the `Autobahn Testsuite`_. +``websockets`` is a library for building WebSocket servers_ and clients_ in +Python with a focus on correctness and simplicity. + +.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py +.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, -it provides a straightforward API based on coroutines, making it easy to write -highly concurrent applications. +it provides an elegant coroutine-based API. -Installation ------------- +Here's a client that says "Hello world!": -Installation is as simple as ``pip install websockets``. +.. literalinclude:: ../example/hello.py -It requires Python ≥ 3.4. +And here's an echo server: -User guide ----------- +.. literalinclude:: ../example/echo.py -If you're new to ``websockets``, :doc:`intro` describes usage patterns and -provides examples. +Do you like it? Let's dive in! -If you've used ``websockets`` before and just need a quick reference, have a -look at :doc:`cheatsheet`. +Tutorials +--------- -If you need more details, the :doc:`api` documentation is for you. +If you're new to ``websockets``, this is the place to start. -If you're upgrading ``websockets``, check the :doc:`changelog`. +.. toctree:: + :maxdepth: 2 -Contributing ------------- + intro -Bug reports, patches and suggestions welcome! Just open an issue_ or send a -`pull request`_. +How-to guides +------------- -.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py -.. _RFC 6455: http://tools.ietf.org/html/rfc6455 -.. _RFC 7692: http://tools.ietf.org/html/rfc7692 -.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst -.. _PEP 3156: http://www.python.org/dev/peps/pep-3156/ -.. _issue: https://github.com/aaugustin/websockets/issues/new -.. _pull request: https://github.com/aaugustin/websockets/compare/ +These guides will help you build and deploy a ``websockets`` application. .. toctree:: - :hidden: + :maxdepth: 2 - intro cheatsheet - api deployment - security + +Reference +--------- + +Find all the details you could ask for, and then some. + +.. toctree:: + :maxdepth: 2 + + api + +Discussions +----------- + +Get a deeper understanding of how ``websockets`` is built and why. + +.. toctree:: + :maxdepth: 2 + design limitations + security + +Project +------- + +This is about websockets-the-project rather than websockets-the-software. + +.. toctree:: + :maxdepth: 2 + + contributing changelog license diff --git a/docs/intro.rst b/docs/intro.rst index 0b71f571c..881bc610e 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -6,6 +6,13 @@ Getting started This documentation is written for Python ≥ 3.5. If you're using Python 3.4, you will have to :ref:`adapt the code samples `. +Installation +------------ + +``websockets`` requires Python ≥ 3.4. Install it with:: + + pip install websockets + Basic example ------------- diff --git a/example/echo.py b/example/echo.py new file mode 100644 index 000000000..8fa307dd7 --- /dev/null +++ b/example/echo.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python + +import asyncio +import websockets + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + +asyncio.get_event_loop().run_until_complete( + websockets.serve(echo, 'localhost', 8765)) +asyncio.get_event_loop().run_forever() diff --git a/example/hello.py b/example/hello.py new file mode 100644 index 000000000..bbb3d9a0e --- /dev/null +++ b/example/hello.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python + +import asyncio +import websockets + +async def hello(uri): + async with websockets.connect(uri) as websocket: + await websocket.send("Hello world!") + +asyncio.get_event_loop().run_until_complete( + hello('ws://localhost:8765')) From c4cf94f8c776841b1794eecdf614a03942972b76 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 21:06:01 +0200 Subject: [PATCH 0339/1539] Update headline in docs. --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 885b81690..fd886a842 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -98,7 +98,7 @@ # documentation. html_theme_options = { 'logo': 'websockets.svg', - 'description': 'WebSockets for Python 3', + 'description': 'A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.', 'github_button': True, 'github_user': 'aaugustin', 'github_repo': 'websockets', From 037d56a67f241843fdb175cdf7a5cef34cd202ed Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 21:06:22 +0200 Subject: [PATCH 0340/1539] More badges! --- README.rst | 35 ++++++++++++++++++++++++++--------- docs/index.rst | 20 ++++++++++++++++++++ 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index 30ce838e7..d852b1018 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,29 @@ -WebSockets |pypi| |circleci| |codecov| -====================================== +WebSockets +========== + +|rtd| |pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |circleci| |codecov| + +.. |rtd| image:: https://readthedocs.org/projects/websockets/badge/?version=latest + :target: https://websockets.readthedocs.io/ + +.. |pypi-v| image:: https://img.shields.io/pypi/v/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |pypi-l| image:: https://img.shields.io/pypi/l/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |circleci| image:: https://img.shields.io/circleci/project/github/aaugustin/websockets.svg + :target: https://circleci.com/gh/aaugustin/websockets + +.. |codecov| image:: https://codecov.io/gh/aaugustin/websockets/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aaugustin/websockets + ``websockets`` is a library for developing WebSocket servers_ and clients_ in Python. It implements `RFC 6455`_ and `RFC 7692`_ with a focus on correctness @@ -27,10 +51,3 @@ Bug reports, patches and suggestions welcome! Just open an issue_ or send a .. _Read the Docs: https://websockets.readthedocs.io/ .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ - -.. |pypi| image:: https://img.shields.io/pypi/v/websockets.svg - :target: https://pypi.python.org/pypi/websockets -.. |circleci| image:: https://circleci.com/gh/aaugustin/websockets/tree/master.svg?style=shield - :target: https://circleci.com/gh/aaugustin/websockets/tree/master -.. |codecov| image:: https://codecov.io/gh/aaugustin/websockets/branch/master/graph/badge.svg - :target: https://codecov.io/gh/aaugustin/websockets diff --git a/docs/index.rst b/docs/index.rst index 9833e09d7..7ccd9463e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,6 +1,26 @@ WebSockets ========== +|pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |circleci| |codecov| + +.. |pypi-v| image:: https://img.shields.io/pypi/v/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |pypi-l| image:: https://img.shields.io/pypi/l/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |circleci| image:: https://img.shields.io/circleci/project/github/aaugustin/websockets.svg + :target: https://circleci.com/gh/aaugustin/websockets + +.. |codecov| image:: https://codecov.io/gh/aaugustin/websockets/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aaugustin/websockets + ``websockets`` is a library for building WebSocket servers_ and clients_ in Python with a focus on correctness and simplicity. From 77fc08303eca8f8cbc43ea99f00191fa1fadadbf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 30 Sep 2017 22:27:08 +0200 Subject: [PATCH 0341/1539] Add marketing copy to the README. Fix #245. Thanks @cjerdonek for your feedback! --- README.rst | 116 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 100 insertions(+), 16 deletions(-) diff --git a/README.rst b/README.rst index d852b1018..d05a5bb14 100644 --- a/README.rst +++ b/README.rst @@ -24,30 +24,114 @@ WebSockets .. |codecov| image:: https://codecov.io/gh/aaugustin/websockets/branch/master/graph/badge.svg :target: https://codecov.io/gh/aaugustin/websockets +What is ``websockets``? +----------------------- -``websockets`` is a library for developing WebSocket servers_ and clients_ in -Python. It implements `RFC 6455`_ and `RFC 7692`_ with a focus on correctness -and simplicity. It passes the `Autobahn Testsuite`_. +``websockets`` is a library for building WebSocket servers_ and clients_ in +Python with a focus on correctness and simplicity. -Built on top of Python's asynchronous I/O support introduced in `PEP 3156`_, -it provides an API based on coroutines, making it easy to write highly -concurrent applications. +.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py +.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py + +Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it +provides an elegant coroutine-based API. + +Here's a client that says "Hello world!": + +.. copy-pasted because GitHub doesn't support the include directive + +.. code:: python + + #!/usr/bin/env python + + import asyncio + import websockets + + async def hello(uri): + async with websockets.connect(uri) as websocket: + await websocket.send("Hello world!") + + asyncio.get_event_loop().run_until_complete( + hello('ws://localhost:8765')) + +And here's an echo server: + +.. code:: python + + #!/usr/bin/env python + + import asyncio + import websockets + + async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + + asyncio.get_event_loop().run_until_complete( + websockets.serve(echo, 'localhost', 8765)) + asyncio.get_event_loop().run_forever() + +Does that look good? `Start here`_. + +.. _Start here: https://websockets.readthedocs.io/en/stable/intro.html -Installation is as simple as ``pip install websockets``. +Why should I use ``websockets``? +-------------------------------- -It requires Python ≥ 3.4. +The development of ``websockets`` is shaped by four principles: -Documentation is available on `Read the Docs`_. +1. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and + ``await ws.send(msg)``; ``websockets`` takes care of managing connections + so you can focus on your application. + +2. **Robustness**: ``websockets`` is built for production; for example it was + the only library to `handle backpressure correctly`_ before the issue + became widely known in the Python community. + +3. **Quality**: ``websockets`` is heavily tested. Continuous integration fails + under 100% branch coverage. Also it passes the industry-standard `Autobahn + Testsuite`_. + +4. **Performance**: memory use is configurable. An extension written in C + accelerates expensive operations. It's pre-compiled for Linux, macOS and + Windows and packaged in the wheel format for each system and Python version. + +Documentation is a first class concern in the project. Head over to `Read the +Docs`_ and see for yourself. + +Professional support is available if you — or your company — are so inclined. +`Get in touch`_. + +(If you contribute to ``websockets`` and would like to become an official +support provider, let me know.) + +.. _Read the Docs: https://websockets.readthedocs.io/ +.. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers +.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst +.. _Get in touch: https://fractalideas.com/ + +Why shouldn't I use ``websockets``? +----------------------------------- + +* If you prefer callbacks over coroutines: ``websockets`` was created to + provide the best corountine-based API to manage WebSocket connections in + Python. Pick another library for a callback-based API. +* If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims + at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol + and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP + is minimal — just enough for a HTTP health check. +* If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which + only works on Python 3. ``websockets`` requires Python ≥ 3.4. + +What else? +---------- Bug reports, patches and suggestions welcome! Just open an issue_ or send a `pull request`_. -.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py -.. _RFC 6455: http://tools.ietf.org/html/rfc6455 -.. _RFC 7692: http://tools.ietf.org/html/rfc7692 -.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst -.. _PEP 3156: http://www.python.org/dev/peps/pep-3156/ -.. _Read the Docs: https://websockets.readthedocs.io/ .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ + +``websockets`` is released under the `BSD license`_. + +.. _BSD license: https://websockets.readthedocs.io/en/stable/license.html From 51b26f8ff65912a611a2f169b352c671f0afae01 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2017 14:44:32 +0200 Subject: [PATCH 0342/1539] Add support for sock argument in client. Fix #221. Adapted from PR #229. --- docs/changelog.rst | 3 +++ websockets/client.py | 8 +++++- websockets/test_client_server.py | 45 ++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index f8d6380b4..aa0cf00e0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -59,6 +59,9 @@ Also: now raises :class:`~websockets.exceptions.InvalidStatusCode` with a ``code`` attribute. +* Providing a ``sock`` argument to :func:`~websockets.client.connect()` no + longer crashes. + 3.3 ... diff --git a/websockets/client.py b/websockets/client.py index 096c7cf94..cc4d1cf49 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -365,8 +365,14 @@ def connect(uri, *, extra_headers=extra_headers, ) + if kwds.get('sock') is None: + host, port = wsuri.host, wsuri.port + else: + # If sock is given, host and port mustn't be specified. + host, port = None, None + transport, protocol = yield from loop.create_connection( - factory, wsuri.host, wsuri.port, **kwds) + factory, host, port, **kwds) try: yield from protocol.handshake( diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 37cf25398..3d4ab0071 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -3,7 +3,9 @@ import functools import logging import os +import socket import ssl +import sys import unittest import unittest.mock import urllib.request @@ -69,6 +71,7 @@ def temp_test_client(test, *args, **kwds): def with_manager(manager, *args, **kwds): """ Return a decorator that wraps a function with a context manager. + """ def decorate(func): @functools.wraps(func) @@ -84,6 +87,7 @@ def _decorate(self, *_args, **_kwds): def with_server(**kwds): """ Return a decorator for TestCase methods that starts and stops a server. + """ return with_manager(temp_test_server, **kwds) @@ -91,6 +95,7 @@ def with_server(**kwds): def with_client(*args, **kwds): """ Return a decorator for TestCase methods that starts and stops a client. + """ return with_manager(temp_test_client, *args, **kwds) @@ -238,6 +243,46 @@ def test_explicit_event_loop(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") + # The way the legacy SSL implementation wraps sockets makes it extremely + # hard to write a test for Python 3.4. + @unittest.skipIf( + sys.version_info[:2] <= (3, 4), 'this test requires Python 3.5+') + @with_server() + def test_explicit_socket(self): + + class TrackedSocket(socket.socket): + def __init__(self, *args, **kwargs): + self.used_for_read = False + self.used_for_write = False + super().__init__(*args, **kwargs) + + def recv(self, *args, **kwargs): + self.used_for_read = True + return super().recv(*args, **kwargs) + + def send(self, *args, **kwargs): + self.used_for_write = True + return super().send(*args, **kwargs) + + sock = TrackedSocket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(('localhost', 8642)) + server_hostname = 'localhost' if self.secure else None + + try: + self.assertFalse(sock.used_for_read) + self.assertFalse(sock.used_for_write) + + with self.temp_client(sock=sock, server_hostname=server_hostname): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + + self.assertTrue(sock.used_for_read) + self.assertTrue(sock.used_for_write) + + finally: + sock.close() + @with_server() @with_client('attributes') def test_protocol_attributes(self): From a1405fef94a59b90b7434dd6853e8a1244b257d9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2017 21:08:56 +0200 Subject: [PATCH 0343/1539] Make markup easier to read with currentmodule. --- docs/changelog.rst | 74 ++++++++++++++++++++++----------------------- docs/cheatsheet.rst | 51 ++++++++++++++++--------------- docs/deployment.rst | 8 +++-- docs/intro.rst | 14 +++++---- 4 files changed, 76 insertions(+), 71 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index aa0cf00e0..8fb174964 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,8 @@ Changelog --------- +.. currentmodule:: websockets + 4.0 ... @@ -15,12 +17,12 @@ Changelog Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~websockets.server.serve` or :func:`~websockets.client.connect`. + :func:`~server.serve()` or :func:`~client.connect()`. Also: -* :class:`~websockets.protocol.WebSocketCommonProtocol` instances can be used - as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. +* :class:`~protocol.WebSocketCommonProtocol` instances can be used as + asynchronous iterators on Python ≥ 3.6. They yield incoming messages. * Reorganized and extended documentation. @@ -38,16 +40,15 @@ Also: 3.4 ... -* Renamed :func:`~websockets.server.serve()` and - :func:`~websockets.client.connect()`'s ``klass`` argument to - ``create_protocol`` to reflect that it can also be a callable. +* Renamed :func:`~server.serve()` and :func:`~client.connect()`'s ``klass`` + argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. -* :func:`~websockets.server.serve` can be used as an asynchronous context - manager on Python ≥ 3.5. +* :func:`~server.serve` can be used as an asynchronous context manager on + Python ≥ 3.5. * Added support for customizing handling of incoming connections with - :meth:`~websockets.server.WebSocketServerProtocol.process_request()`. + :meth:`~server.WebSocketServerProtocol.process_request()`. * Made read and write buffer sizes configurable. @@ -55,12 +56,11 @@ Also: * Added an optional C extension to speed up low level operations. -* An invalid response status code during :func:`~websockets.client.connect` - now raises :class:`~websockets.exceptions.InvalidStatusCode` with a ``code`` - attribute. +* An invalid response status code during :func:`~client.connect()` now raises + :class:`~exceptions.InvalidStatusCode` with a ``code`` attribute. -* Providing a ``sock`` argument to :func:`~websockets.client.connect()` no - longer crashes. +* Providing a ``sock`` argument to :func:`~client.connect()` no longer + crashes. 3.3 ... @@ -73,7 +73,7 @@ Also: ... * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~websockets.client.connect()` and :func:`~websockets.server.serve()`. + :func:`~client.connect()` and :func:`~server.serve()`. * Made server shutdown more robust. @@ -90,20 +90,20 @@ Also: .. warning:: **Version 3.0 introduces a backwards-incompatible change in the** - :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` **API.** + :meth:`~protocol.WebSocketCommonProtocol.recv` **API.** **If you're upgrading from 2.x or earlier, please read this carefully.** - :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` used to return - ``None`` when the connection was closed. This required checking the return - value of every call:: + :meth:`~protocol.WebSocketCommonProtocol.recv` used to return ``None`` + when the connection was closed. This required checking the return value of + every call:: message = await websocket.recv() if message is None: return - Now it raises a :exc:`~websockets.exceptions.ConnectionClosed` exception - instead. This is more Pythonic. The previous code can be simplified to:: + Now it raises a :exc:`~exceptions.ConnectionClosed` exception instead. + This is more Pythonic. The previous code can be simplified to:: message = await websocket.recv() @@ -113,21 +113,21 @@ Also: In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to - :func:`~websockets.server.serve`, :func:`~websockets.client.connect`, - :class:`~websockets.server.WebSocketServerProtocol`, or - :class:`~websockets.client.WebSocketClientProtocol`. ``legacy_recv`` isn't - documented in their signatures but isn't scheduled for deprecation either. + :func:`~server.serve`, :func:`~client.connect`, + :class:`~server.WebSocketServerProtocol`, or + :class:`~client.WebSocketClientProtocol`. ``legacy_recv`` isn't documented + in their signatures but isn't scheduled for deprecation either. Also: -* :func:`~websockets.client.connect` can be used as an asynchronous context - manager on Python ≥ 3.5. +* :func:`~client.connect` can be used as an asynchronous context manager on + Python ≥ 3.5. * Updated documentation with ``await`` and ``async`` syntax from Python 3.5. -* :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` and - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` supports - data passed as :class:`str` in addition to :class:`bytes`. +* :meth:`~protocol.WebSocketCommonProtocol.ping` and + :meth:`~protocol.WebSocketCommonProtocol.pong` support data passed as + :class:`str` in addition to :class:`bytes`. * Worked around an asyncio bug affecting connection termination under load. @@ -165,8 +165,8 @@ Also: * Returned a 403 status code instead of 400 when the request Origin isn't allowed. -* Cancelling :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` no - longer drops the next message. +* Cancelling :meth:`~protocol.WebSocketCommonProtocol.recv` no longer drops + the next message. * Clarified that the closing handshake can be initiated by the client. @@ -183,8 +183,8 @@ Also: * Supported non-default event loop. -* Added ``loop`` argument to :func:`~websockets.client.connect` and - :func:`~websockets.server.serve`. +* Added ``loop`` argument to :func:`~client.connect` and + :func:`~server.serve`. 2.3 ... @@ -211,9 +211,9 @@ Also: .. warning:: **Version 2.0 introduces a backwards-incompatible change in the** - :meth:`~websockets.protocol.WebSocketCommonProtocol.send`, - :meth:`~websockets.protocol.WebSocketCommonProtocol.ping`, and - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` **APIs.** + :meth:`~protocol.WebSocketCommonProtocol.send`, + :meth:`~protocol.WebSocketCommonProtocol.ping`, and + :meth:`~protocol.WebSocketCommonProtocol.pong` **APIs.** **If you're upgrading from 1.x or earlier, please read this carefully.** diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index f73812a58..610f27973 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -1,54 +1,55 @@ Cheat sheet =========== +.. currentmodule:: websockets + Server ------ * Write a coroutine that handles a single connection. It receives a websocket protocol instance and the URI path in argument. - * Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and - :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and - send messages at any time. + * Call :meth:`~protocol.WebSocketCommonProtocol.recv` and + :meth:`~protocol.WebSocketCommonProtocol.send` to receive and send + messages at any time. - * You may :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` or - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish - but it isn't needed in general. + * You may :meth:`~protocol.WebSocketCommonProtocol.ping` or + :meth:`~protocol.WebSocketCommonProtocol.pong` if you wish but it isn't + needed in general. -* Create a server with :func:`~websockets.server.serve` which is similar to - asyncio's :meth:`~asyncio.AbstractEventLoop.create_server`. +* Create a server with :func:`~server.serve` which is similar to asyncio's + :meth:`~asyncio.AbstractEventLoop.create_server`. * The server takes care of establishing connections, then lets the handler - execute the application logic, and finally closes the connection after - the handler exits normally or with an exception. + execute the application logic, and finally closes the connection after the + handler exits normally or with an exception. * For advanced customization, you may subclass - :class:`~websockets.server.WebSocketServerProtocol` and pass either this - subclass or a factory function as the ``create_protocol`` argument. + :class:`~server.WebSocketServerProtocol` and pass either this subclass or + a factory function as the ``create_protocol`` argument. Client ------ -* Create a client with :func:`~websockets.client.connect` which is similar to - asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. +* Create a client with :func:`~client.connect` which is similar to asyncio's + :meth:`~asyncio.BaseEventLoop.create_connection`. * On Python ≥ 3.5, you can also use it as an asynchronous context manager. * For advanced customization, you may subclass - :class:`~websockets.server.WebSocketClientProtocol` and pass either this - subclass or a factory function as the ``create_protocol`` argument. + :class:`~server.WebSocketClientProtocol` and pass either this subclass or + a factory function as the ``create_protocol`` argument. -* Call :meth:`~websockets.protocol.WebSocketCommonProtocol.recv` and - :meth:`~websockets.protocol.WebSocketCommonProtocol.send` to receive and - send messages at any time. +* Call :meth:`~protocol.WebSocketCommonProtocol.recv` and + :meth:`~protocol.WebSocketCommonProtocol.send` to receive and send messages + at any time. -* You may :meth:`~websockets.protocol.WebSocketCommonProtocol.ping` or - :meth:`~websockets.protocol.WebSocketCommonProtocol.pong` if you wish but it - isn't needed in general. +* You may :meth:`~protocol.WebSocketCommonProtocol.ping` or + :meth:`~protocol.WebSocketCommonProtocol.pong` if you wish but it isn't + needed in general. -* If you aren't using :func:`~websockets.client.connect` as a context manager, - call :meth:`~websockets.protocol.WebSocketCommonProtocol.close` to terminate - the connection. +* If you aren't using :func:`~client.connect` as a context manager, call + :meth:`~protocol.WebSocketCommonProtocol.close` to terminate the connection. Debugging --------- diff --git a/docs/deployment.rst b/docs/deployment.rst index 8cdc09152..c5e3dab28 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -1,6 +1,8 @@ Deployment ========== +.. currentmodule:: websockets + Application server ------------------ @@ -22,15 +24,15 @@ Graceful shutdown You may want to close connections gracefully when shutting down the server, perhaps after executing some cleanup logic. There are two ways to achieve this -with the object returned by :func:`~websockets.server.serve`: +with the object returned by :func:`~server.serve`: - using it as a asynchronous context manager, or - calling its ``close()`` method, then waiting for its ``wait_closed()`` method to complete. Tasks that handle connections will be cancelled. For example, if the handler -is awaiting :meth:`~websockets.protocol.WebSocketCommonProtocol.recv`, that -call will raise :exc:`~asyncio.CancelledError`. +is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, that call will +raise :exc:`~asyncio.CancelledError`. On Unix systems, shutdown is usually triggered by sending a signal. diff --git a/docs/intro.rst b/docs/intro.rst index 881bc610e..ab77f9c50 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -1,6 +1,8 @@ Getting started =============== +.. currentmodule:: websockets + .. warning:: This documentation is written for Python ≥ 3.5. If you're using Python @@ -80,9 +82,9 @@ for earlier Python versions:: message = await websocket.recv() await consumer(message) -:meth:`~websockets.protocol.WebSocketCommonProtocol.recv` raises a -:exc:`~websockets.exceptions.ConnectionClosed` exception when the client -disconnects, which breaks out of the ``while True`` loop. +:meth:`~protocol.WebSocketCommonProtocol.recv` raises a +:exc:`~exceptions.ConnectionClosed` exception when the client disconnects, +which breaks out of the ``while True`` loop. Producer ........ @@ -94,9 +96,9 @@ For getting messages from a ``producer`` coroutine and sending them:: message = await producer() await websocket.send(message) -:meth:`~websockets.protocol.WebSocketCommonProtocol.send` raises a -:exc:`~websockets.exceptions.ConnectionClosed` exception when the client -disconnects, which breaks out of the ``while True`` loop. +:meth:`~protocol.WebSocketCommonProtocol.send` raises a +:exc:`~exceptions.ConnectionClosed` exception when the client disconnects, +which breaks out of the ``while True`` loop. Both .... From 97dea5258bc0cdf11ce0bfa6bae82e85eee6e41c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2017 21:10:18 +0200 Subject: [PATCH 0344/1539] Link to RFCs consistently. --- websockets/client.py | 3 ++- websockets/extensions/permessage_deflate.py | 4 +--- websockets/server.py | 3 ++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index cc4d1cf49..49b98bcc4 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -97,7 +97,8 @@ def process_extensions(headers, available_extensions): Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the connection. - RFC 6455 leaves the rules up to the specification of each extension. + :rfc:`6455` leaves the rules up to the specification of each + :extension. To provide this level of flexibility, for each extension accepted by the server, we check for a match with each extension available in the diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index f7fb4f021..a5a0cb857 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -1,8 +1,6 @@ """ The :mod:`websockets.extensions.permessage_deflate` module implements the -Compression Extensions for WebSocket as specified in `RFC 7692`_. - -.. _RFC 7692: http://tools.ietf.org/html/rfc7692 +Compression Extensions for WebSocket as specified in :rfc:`7692`. """ diff --git a/websockets/server.py b/websockets/server.py index d7c115829..0735cf7db 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -281,7 +281,8 @@ def process_extensions(headers, available_extensions): handshake with an HTTP 400 error code. (The default implementation never does this.) - RFC 6455 leaves the rules up to the specification of each extension. + :rfc:`6455` leaves the rules up to the specification of each + :extension. To provide this level of flexibility, for each extension proposed by the client, we check for a match with each extension available in the From 91be27e9c4d23b3902ac77cdfe648e480ec63985 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2017 21:23:15 +0200 Subject: [PATCH 0345/1539] Move encode_data utility outside of protocol. --- websockets/framing.py | 18 +++++++++++++++++- websockets/protocol.py | 19 ++----------------- websockets/test_framing.py | 10 ++++++++++ 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/websockets/framing.py b/websockets/framing.py index 9a72fcd56..c76778315 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -27,7 +27,7 @@ __all__ = [ 'DATA_OPCODES', 'CTRL_OPCODES', 'OP_CONT', 'OP_TEXT', 'OP_BINARY', 'OP_CLOSE', 'OP_PING', 'OP_PONG', - 'Frame', 'parse_close', 'serialize_close' + 'Frame', 'encode_data', 'parse_close', 'serialize_close' ] DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 @@ -237,6 +237,22 @@ def check(frame): "Invalid opcode ({})".format(frame.opcode)) +def encode_data(data): + """ + Helper that converts :class:`str` or :class:`bytes` to :class:`bytes`. + + :class:`str` are encoded with UTF-8. + + """ + # Expect str or bytes, return bytes. + if isinstance(data, str): + return data.encode('utf-8') + elif isinstance(data, bytes): + return data + else: + raise TypeError("data must be bytes or str") + + def parse_close(data): """ Parse the data in a close frame. diff --git a/websockets/protocol.py b/websockets/protocol.py index 9edd1fbc4..73b19679e 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -408,7 +408,7 @@ def ping(self, data=None): yield from self.ensure_open() if data is not None: - data = self.encode_data(data) + data = encode_data(data) # Protect against duplicates if a payload is explicitly set. if data in self.pings: @@ -438,27 +438,12 @@ def pong(self, data=b''): """ yield from self.ensure_open() - data = self.encode_data(data) + data = encode_data(data) yield from self.write_frame(OP_PONG, data) # Private methods - no guarantees. - def encode_data(self, data): - """ - Expect :class:`str` or :class:`bytes`. Return :class:`bytes`. - - :class:`str` are encoded with UTF-8. - - """ - # Expect str or bytes, return bytes. - if isinstance(data, str): - return data.encode('utf-8') - elif isinstance(data, bytes): - return data - else: - raise TypeError("data must be bytes or str") - def connection_open(self): """ Callback when the opening handshake completes. diff --git a/websockets/test_framing.py b/websockets/test_framing.py index 4271b5a18..ba14603b1 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -169,6 +169,16 @@ def test_control_frame_max_length(self): with self.assertRaises(WebSocketProtocolError): self.decode(b'\x88\x7e\x00\x7e' + 126 * b'a') + def test_encode_data_str(self): + self.assertEqual(encode_data('café'), b'caf\xc3\xa9') + + def test_encode_data_bytes(self): + self.assertEqual(encode_data(b'tea'), b'tea') + + def test_encode_data_other(self): + with self.assertRaises(TypeError): + encode_data(None) + def test_fragmented_control_frame(self): # Fin bit correctly set. self.decode(b'\x88\x00') From 5769cbf1a06be91a272eff54f7c140bea85d0c3f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2017 21:47:13 +0200 Subject: [PATCH 0346/1539] Optimize ensure_open when the connection is closing. * If the remote side started the closing handsake, the close code and reason are already known. * If they aren't, there's no need to add a timeout because we know that self.transfer_data_task will complete quickly anyway. Extend protocol tests to cover the two cases of CLOSING connections: initiated by the local or the remote side. --- websockets/protocol.py | 15 ++++----- websockets/test_protocol.py | 63 ++++++++++++++++++++++++++++++------- 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 73b19679e..1cc64217a 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -466,20 +466,21 @@ def ensure_open(self): Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't. """ - # Handle cases from the most common to the least common. + # Handle cases from most common to least common for performance. if self.state == OPEN: return if self.state == CLOSED: raise ConnectionClosed(self.close_code, self.close_reason) - # If the closing handshake is in progress, let it complete to get the - # proper close status and code. As a safety measure, the timeout is - # longer than the worst case (3 * self.timeout) but not unlimited. if self.state == CLOSING: - yield from asyncio.wait_for( - asyncio.shield(self.transfer_data_task), - 4 * self.timeout, loop=self.loop) + # If we started the closing handshake, wait for its completion to + # get the proper close code and status. self.transfer_data_task + # will complete within 2 * timeout after calling close(). + # If we moved to the CLOSING state because we're failing the + # connection, self.transfer_data_task will complete immediately. + if self.close_code is None: + yield from asyncio.shield(self.transfer_data_task) raise ConnectionClosed(self.close_code, self.close_reason) # Control may only reach this point in buggy third-party subclasses. diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 67cf3d114..82d078290 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -160,7 +160,7 @@ def receive_eof_if_client(self): def close_connection(self, code=1000, reason='close'): """ - Close the connection with a standard closing handshake. + Execute a closing handshake. This puts the connection in the CLOSED state. @@ -174,9 +174,9 @@ def close_connection(self, code=1000, reason='close'): # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) - def close_connection_partial(self, code=1000, reason='close'): + def half_close_connection_local(self, code=1000, reason='close'): """ - Initiate a standard closing handshake but do not complete it. + Start a closing handshake but do not complete it. The main difference with `close_connection` is that the connection is left in the CLOSING state until the event loop runs again. @@ -190,6 +190,20 @@ def close_connection_partial(self, code=1000, reason='close'): # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) # Prepare the response to the closing handshake from the remote side. + self.loop.call_soon( + self.receive_frame, Frame(True, OP_CLOSE, close_frame_data)) + self.loop.call_soon(self.receive_eof_if_client) + + def half_close_connection_remote(self, code=1000, reason='close'): + """ + Receive a closing handshake. + + The main difference with `close_connection` is that the connection is + left in the CLOSING state until the event loop runs again. + + """ + close_frame_data = serialize_close(code, reason) + # Trigger the closing handshake from the remote side. self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) self.receive_eof_if_client() @@ -325,8 +339,14 @@ def test_recv_binary(self): data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea') - def test_recv_on_closing_connection(self): - self.close_connection_partial() + def test_recv_on_closing_connection_local(self): + self.half_close_connection_local() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + + def test_recv_on_closing_connection_remote(self): + self.half_close_connection_remote() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) @@ -405,13 +425,20 @@ def test_send_type_error(self): self.loop.run_until_complete(self.protocol.send(42)) self.assertNoFrameSent() - def test_send_on_closing_connection(self): - self.close_connection_partial() + def test_send_on_closing_connection_local(self): + self.half_close_connection_local() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.send('foobar')) self.assertNoFrameSent() + def test_send_on_closing_connection_remote(self): + self.half_close_connection_remote() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.send('foobar')) + self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1000, 'close')) + def test_send_on_closed_connection(self): self.close_connection() @@ -443,13 +470,20 @@ def test_ping_type_error(self): self.loop.run_until_complete(self.protocol.ping(42)) self.assertNoFrameSent() - def test_ping_on_closing_connection(self): - self.close_connection_partial() + def test_ping_on_closing_connection_local(self): + self.half_close_connection_local() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ping()) self.assertNoFrameSent() + def test_ping_on_closing_connection_remote(self): + self.half_close_connection_remote() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.ping()) + self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1000, 'close')) + def test_ping_on_closed_connection(self): self.close_connection() @@ -476,13 +510,20 @@ def test_pong_type_error(self): self.loop.run_until_complete(self.protocol.pong(42)) self.assertNoFrameSent() - def test_pong_on_closing_connection(self): - self.close_connection_partial() + def test_pong_on_closing_connection_local(self): + self.half_close_connection_local() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.pong()) self.assertNoFrameSent() + def test_pong_on_closing_connection_remote(self): + self.half_close_connection_remote() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.pong()) + self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1000, 'close')) + def test_pong_on_closed_connection(self): self.close_connection() From 68553b77af7aa8f880e18471d1b93f33392bc769 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2017 19:12:38 +0200 Subject: [PATCH 0347/1539] Add unix_server to create a Unix server. Based on #273 by @PrettyWood. --- docs/api.rst | 3 +++ docs/changelog.rst | 2 ++ websockets/server.py | 31 +++++++++++++++++++++++++------ websockets/test_client_server.py | 29 ++++++++++++++++++++++++++++- 4 files changed, 58 insertions(+), 7 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 742268f3f..83426959e 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -34,6 +34,9 @@ Server .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + + .. autoclass:: WebSocketServer .. automethod:: close() diff --git a/docs/changelog.rst b/docs/changelog.rst index 8fb174964..557780195 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,6 +24,8 @@ Also: * :class:`~protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. +* Added :func:`~websockets.server.unix_serve` for listening on Unix sockets. + * Reorganized and extended documentation. * Aborted connections if they don't close within the configured ``timeout``. diff --git a/websockets/server.py b/websockets/server.py index 0735cf7db..1aa823037 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -24,7 +24,7 @@ from .protocol import WebSocketCommonProtocol -__all__ = ['serve', 'WebSocketServerProtocol'] +__all__ = ['serve', 'unix_serve', 'WebSocketServerProtocol'] logger = logging.getLogger(__name__) @@ -585,7 +585,7 @@ def wait_closed(self): @asyncio.coroutine def serve(ws_handler, host=None, port=None, *, - create_protocol=None, + path=None, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, @@ -657,9 +657,6 @@ def serve(ws_handler, host=None, port=None, *, logger.addHandler(logging.StreamHandler()) """ - if loop is None: - loop = asyncio.get_event_loop() - # Backwards-compatibility: create_protocol used to be called klass. # In the unlikely event that both are specified, klass is ignored. if create_protocol is None: @@ -668,6 +665,9 @@ def serve(ws_handler, host=None, port=None, *, if create_protocol is None: create_protocol = WebSocketServerProtocol + if loop is None: + loop = asyncio.get_event_loop() + ws_server = WebSocketServer(loop) secure = kwds.get('ssl') is not None @@ -692,13 +692,32 @@ def serve(ws_handler, host=None, port=None, *, origins=origins, extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, ) - server = yield from loop.create_server(factory, host, port, **kwds) + if path is None: + server = yield from loop.create_server(factory, host, port, **kwds) + else: + server = yield from loop.create_unix_server(factory, path, **kwds) ws_server.wrap(server) return ws_server +@asyncio.coroutine +def unix_serve(ws_handler, path, **kwargs): + """ + Similar to :func:`serve()`, but for listening on Unix sockets. + + This function calls the event loop's + :meth:`~asyncio.AbstractEventLoop.create_unix_server` method. + + It is only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + """ + return serve(ws_handler, path=path, **kwargs) + + try: from .py35.server import Serve except (SyntaxError, ImportError): # pragma: no cover diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 3d4ab0071..33b246de4 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -2,10 +2,11 @@ import contextlib import functools import logging -import os +import os.path import socket import ssl import sys +import tempfile import unittest import unittest.mock import urllib.request @@ -283,6 +284,29 @@ def send(self, *args, **kwargs): finally: sock.close() + @unittest.skipUnless( + hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') + def test_unix_socket(self): + with tempfile.TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, 'websockets') + + # Like self.start_server() but with unix_serve(). + unix_server = unix_serve(handler, path) + self.server = self.loop.run_until_complete(unix_server) + + sock = socket.socket(socket.AF_UNIX) + sock.connect(path) + + try: + with self.temp_client(sock=sock): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + + finally: + sock.close() + self.stop_server() + @with_server() @with_client('attributes') def test_protocol_attributes(self): @@ -862,6 +886,9 @@ def start_client(self, path='', **kwds): kwds.setdefault('ssl', self.client_context) super().start_client(path, **kwds) + # TLS over Unix sockets doesn't make sense. + test_unix_socket = None + @with_server() def test_ws_uri_is_rejected(self): client = connect('ws://localhost:8642/', ssl=self.client_context) From 8cdf09b8f122a339188905a58182148eed05778d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 2 Oct 2017 21:54:57 +0200 Subject: [PATCH 0348/1539] Document concurrency-safety of public APIs. Fix #238. --- docs/design.rst | 24 ++++++++++++++++++++++++ websockets/framing.py | 3 ++- websockets/protocol.py | 7 +++++-- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/docs/design.rst b/docs/design.rst index f8132c7a9..f885b0047 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -420,3 +420,27 @@ For each connection, the sending side contains these buffers: You can set another limit by passing a ``write_limit`` keyword argument to :func:`~client.connect()` or :func:`~server.serve()`. - OS buffers: tuning them is an advanced optimization. + +Concurrency +----------- + +Calling any combination of :meth:`~protocol.WebSocketCommonProtocol.recv()`, +:meth:`~protocol.WebSocketCommonProtocol.send()`, +:meth:`~protocol.WebSocketCommonProtocol.close()` +:meth:`~protocol.WebSocketCommonProtocol.ping()`, or +:meth:`~protocol.WebSocketCommonProtocol.pong()` concurrently is safe, +including multiple calls to the same method. + +As shown above, receiving frames is independent from sending frames. That +isolates :meth:`~protocol.WebSocketCommonProtocol.recv()`, which receives +frames, from the other methods, which send frames. + +While :meth:`~protocol.WebSocketCommonProtocol.recv()` supports being called +multiple times concurrently, this is unlikely to be useful: when multiple +callers are waiting for the next message, exactly one of them will get it, but +there is no guarantee about which one. + +Methods that send frames also support concurrent calls. While the connection +is open, each frame is sent with a single write. Combined with the concurrency +model of :mod:`asyncio`, this enforces serialization. After the connection is +closed, sending a frame raises :exc:`~websockets.exceptions.ConnectionClosed`. diff --git a/websockets/framing.py b/websockets/framing.py index c76778315..a45c78838 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -208,7 +208,8 @@ def write(frame, writer, *, mask, extensions=None): # Send the frame. # The frame is written in a single call to writer in order to prevent - # TCP fragmentation. See #68 for details. + # TCP fragmentation. See #68 for details. This also makes it safe to + # send frames concurrently from multiple coroutines. writer(output.getvalue()) def check(frame): diff --git a/websockets/protocol.py b/websockets/protocol.py index 1cc64217a..c638fd705 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -268,8 +268,11 @@ def close(self, code=1000, reason=''): """ This coroutine performs the closing handshake. - It waits for the other end to complete the handshake. It doesn't do - anything once the connection is closed. Thus it's idemptotent. + It waits for the other end to complete the handshake and for the TCP + connection to terminate. + + It doesn't do anything once the connection is closed. In other words + it's idemptotent. It's safe to wrap this coroutine in :func:`~asyncio.ensure_future` since errors during connection termination aren't particularly useful. From 050a6aeb90225fe54b05f8c8cef9ed05596f3553 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 2 Oct 2017 22:13:34 +0200 Subject: [PATCH 0349/1539] Documented passing extra args to the handler. Fix #271. --- docs/cheatsheet.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 610f27973..eb29ad23d 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -95,3 +95,26 @@ idle connections after some time:: else: # do something with msg ... + +Passing additional arguments to the connection handler +------------------------------------------------------ + +When writing a server, if you need to pass additional arguments to the +connection handler, you can bind them with :func:`functools.partial`:: + + import asyncio + import functools + import websockets + + async def handler(websocket, path, extra_argument): + ... + + bound_handler = functools.partial(handler, extra_argument='spam') + start_server = websockets.serve(bound_handler, '127.0.0.1', 8765) + + asyncio.get_event_loop().run_until_complete(start_server) + asyncio.get_event_loop().run_forever() + +Another way to achieve this result is to define the ``handler`` corountine in +a scope where the ``extra_argument`` variable exists instead of injecting it +through an argument. From 5a9d2e0038acee5ffca60c047b7a28cd72b4324e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 2 Oct 2017 23:03:59 +0200 Subject: [PATCH 0350/1539] Remove superfluous args from subTest calls. --- websockets/extensions/test_permessage_deflate.py | 4 +--- websockets/test_headers.py | 4 ++-- websockets/test_uri.py | 2 +- websockets/test_utils.py | 2 +- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/websockets/extensions/test_permessage_deflate.py b/websockets/extensions/test_permessage_deflate.py index b034a5089..8fbcdee58 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -99,7 +99,7 @@ def test_get_request_params(self): ], ), ]: - with self.subTest(config=config, result=result): + with self.subTest(config=config): factory = ClientPerMessageDeflateFactory(*config) self.assertEqual(factory.get_request_params(), result) @@ -303,7 +303,6 @@ def test_process_response_params(self): with self.subTest( config=config, response_params=response_params, - result=result, ): factory = ClientPerMessageDeflateFactory(*config) if isinstance(result, type) and issubclass(result, Exception): @@ -606,7 +605,6 @@ def test_process_request_params(self): config=config, request_params=request_params, response_params=response_params, - result=result, ): factory = ServerPerMessageDeflateFactory(*config) if isinstance(result, type) and issubclass(result, Exception): diff --git a/websockets/test_headers.py b/websockets/test_headers.py index 9311a4a68..230aadfac 100644 --- a/websockets/test_headers.py +++ b/websockets/test_headers.py @@ -45,7 +45,7 @@ def test_parse_extension_list(self): [('permessage-deflate', [('server_max_window_bits', '10')])], ), ]: - with self.subTest(header=header, parsed=parsed): + with self.subTest(header=header): self.assertEqual(parse_extension_list(header), parsed) # Also ensure that build_extension_list round-trips cleanly. unparsed = build_extension_list(parsed) @@ -86,7 +86,7 @@ def test_parse_protocol_list(self): ['foo', 'bar', 'baz'], ), ]: - with self.subTest(header=header, parsed=parsed): + with self.subTest(header=header): self.assertEqual(parse_protocol_list(header), parsed) # Also ensure that build_protocol_list round-trips cleanly. unparsed = build_protocol_list(parsed) diff --git a/websockets/test_uri.py b/websockets/test_uri.py index d15df3b63..1b9928007 100644 --- a/websockets/test_uri.py +++ b/websockets/test_uri.py @@ -23,7 +23,7 @@ class URITests(unittest.TestCase): def test_success(self): for uri, parsed in VALID_URIS: - with self.subTest(uri=uri, parsed=parsed): + with self.subTest(uri=uri): self.assertEqual(parse_uri(uri), parsed) def test_error(self): diff --git a/websockets/test_utils.py b/websockets/test_utils.py index 7772dce72..9611ee777 100644 --- a/websockets/test_utils.py +++ b/websockets/test_utils.py @@ -16,7 +16,7 @@ def test_apply_mask(self): (b'abcdABCD', b'1234', b'PPPPpppp'), (b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10), ]: - with self.subTest(data_in=data_in, mask=mask, data_out=data_out): + with self.subTest(data_in=data_in, mask=mask): self.assertEqual(self.apply_mask(data_in, mask), data_out) def test_apply_mask_check_input_types(self): From 4a8415da441ebab15dfebd798838c293f07d38d7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 2 Oct 2017 23:04:44 +0200 Subject: [PATCH 0351/1539] Add explanation of status codes in error messages. Fix #249. --- websockets/exceptions.py | 34 ++++++++++-- websockets/framing.py | 31 +++++------ websockets/test_exceptions.py | 98 +++++++++++++++++++++++++++++++++++ websockets/uri.py | 2 +- 4 files changed, 144 insertions(+), 21 deletions(-) create mode 100644 websockets/test_exceptions.py diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 543f2579b..51525eb93 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,5 +1,4 @@ __all__ = [ - 'AbortHandshake', 'InvalidHandshake', 'InvalidHeader', 'InvalidMessage', 'InvalidOrigin', 'InvalidState', 'InvalidStatusCode', 'NegotiationError', 'InvalidParameterName', 'InvalidParameterValue', 'DuplicateParameter', @@ -24,6 +23,9 @@ def __init__(self, status, headers, body=None): self.status = status self.headers = headers self.body = body + message = "HTTP {}, {} headers, {} bytes".format( + status, len(headers), 0 if body is None else len(body)) + super().__init__(message) class InvalidMessage(InvalidHandshake): @@ -113,6 +115,23 @@ class InvalidState(Exception): """ +CLOSE_CODES = { + 1000: "OK", + 1001: "going away", + 1002: "protocol error", + 1003: "unsupported type", + # 1004 is reserved + 1005: "no status code [internal]", + 1006: "connection closed abnormally [internal]", + 1007: "invalid data", + 1008: "policy violation", + 1009: "message too big", + 1010: "extension required", + 1011: "unexpected error", + 1015: "TLS failure [internal]", +} + + class ConnectionClosed(InvalidState): """ Exception raised when trying to read or write on a closed connection. @@ -125,8 +144,17 @@ def __init__(self, code, reason): self.code = code self.reason = reason message = "WebSocket connection is closed: " - message += "code = {}, ".format(code) if code else "no code, " - message += "reason = {}.".format(reason) if reason else "no reason." + if 3000 <= code < 4000: + explanation = "registered" + elif 4000 <= code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODES.get(code, "unknown") + message += "code = {} ({}), ".format(code, explanation) + if reason: + message += "reason = {}.".format(reason) + else: + message += "no reason." super().__init__(message) diff --git a/websockets/framing.py b/websockets/framing.py index a45c78838..e9f5006b1 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -33,22 +33,19 @@ DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG = 0x08, 0x09, 0x0a -CLOSE_CODES = { - 1000: "OK", - 1001: "going away", - 1002: "protocol error", - 1003: "unsupported type", - # 1004: - (reserved) - # 1005: no status code (internal) - # 1006: connection closed abnormally (internal) - 1007: "invalid data", - 1008: "policy violation", - 1009: "message too big", - 1010: "extension required", - 1011: "unexpected error", - # 1015: TLS failure (internal) -} - +# Close code that are allowed in a close frame. +# Using a list optimizes `code in EXTERNAL_CLOSE_CODES`. +EXTERNAL_CLOSE_CODES = [ + 1000, + 1001, + 1002, + 1003, + 1007, + 1008, + 1009, + 1010, + 1011, +] FrameData = collections.namedtuple( 'FrameData', @@ -294,5 +291,5 @@ def check_close(code): Check the close code for a close frame. """ - if not (code in CLOSE_CODES or 3000 <= code < 5000): + if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): raise WebSocketProtocolError("Invalid status code") diff --git a/websockets/test_exceptions.py b/websockets/test_exceptions.py new file mode 100644 index 000000000..b064e9a8f --- /dev/null +++ b/websockets/test_exceptions.py @@ -0,0 +1,98 @@ +import unittest + +from .exceptions import * + + +class ExceptionsTests(unittest.TestCase): + + def test_str(self): + for exception, exception_str in [ + ( + InvalidHandshake("Invalid request"), + "Invalid request", + ), + ( + AbortHandshake(200, [], b'OK\n'), + "HTTP 200, 0 headers, 3 bytes", + ), + ( + InvalidMessage("Malformed HTTP message"), + "Malformed HTTP message", + ), + ( + InvalidHeader("expected token", "a=|", 3), + "expected token at 3 in a=|", + ), + ( + InvalidOrigin("Origin not allowed: ''"), + "Origin not allowed: ''", + ), + ( + InvalidStatusCode(403), + "Status code not 101: 403", + ), + ( + NegotiationError("Unsupported subprotocol: spam"), + "Unsupported subprotocol: spam", + ), + ( + InvalidParameterName('|'), + "Invalid parameter name: |", + ), + ( + InvalidParameterValue('a', '|'), + "Invalid value for parameter a: |", + ), + ( + DuplicateParameter('a'), + "Duplicate parameter: a", + ), + ( + InvalidState("WebSocket connection isn't established yet."), + "WebSocket connection isn't established yet.", + ), + ( + ConnectionClosed(1000, ''), + "WebSocket connection is closed: code = 1000 " + "(OK), no reason.", + ), + ( + ConnectionClosed(1001, 'bye'), + "WebSocket connection is closed: code = 1001 " + "(going away), reason = bye.", + ), + ( + ConnectionClosed(1006, None), + "WebSocket connection is closed: code = 1006 " + "(connection closed abnormally [internal]), no reason." + ), + ( + ConnectionClosed(1016, None), + "WebSocket connection is closed: code = 1016 " + "(unknown), no reason." + ), + ( + ConnectionClosed(3000, None), + "WebSocket connection is closed: code = 3000 " + "(registered), no reason." + ), + ( + ConnectionClosed(4000, None), + "WebSocket connection is closed: code = 4000 " + "(private use), no reason." + ), + ( + InvalidURI("| isn't a valid URI"), + "| isn't a valid URI", + ), + ( + PayloadTooBig("Payload length exceeds limit (2 > 1 bytes)"), + "Payload length exceeds limit (2 > 1 bytes)", + ), + ( + WebSocketProtocolError("Invalid opcode (7)"), + "Invalid opcode (7)", + ), + ]: + with self.subTest(exception=exception): + self.assertEqual(str(exception), exception_str) diff --git a/websockets/uri.py b/websockets/uri.py index 6d9f9124c..84c3f3b87 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -44,7 +44,7 @@ def parse_uri(uri): assert uri.password is None assert uri.hostname is not None except AssertionError as exc: - raise InvalidURI() from exc + raise InvalidURI("{} isn't a valid URI".format(uri)) from exc secure = uri.scheme == 'wss' host = uri.hostname From 5ce75688f59b017c2f0d4efebff8e9022f0e3d67 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2017 09:37:19 +0200 Subject: [PATCH 0352/1539] Test ConnectionError handling. Fix #215. Thanks @cjerdonek for reporting and investigating this issue. --- websockets/test_client_server.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 33b246de4..519621779 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -836,16 +836,25 @@ def test_invalid_status_error_during_client_connect(self): self.assertEqual(exception.status_code, 403) @with_server() - @unittest.mock.patch('websockets.server.read_request') - def test_connection_error_during_opening_handshake(self, _read_request): - _read_request.side_effect = ConnectionError - - # Exception appears to be platform-dependent: InvalidHandshake on - # macOS, ConnectionResetError on Linux. This doesn't matter; this - # test primarily aims at covering a code path on the server side. + @unittest.mock.patch( + 'websockets.server.WebSocketServerProtocol.write_http_response') + @unittest.mock.patch( + 'websockets.server.WebSocketServerProtocol.read_http_request') + def test_connection_error_during_opening_handshake( + self, _read_http_request, _write_http_response): + _read_http_request.side_effect = ConnectionError + + # This exception is currently platform-dependent. It was observed to + # be ConnectionResetError on Linux in the non-SSL case, and + # InvalidMessage otherwise (including both Linux and macOS). This + # doesn't matter though since this test is primarily for testing a + # code path on the server side. with self.assertRaises(Exception): self.start_client() + # No response must not be written if the network connection is broken. + _write_http_response.assert_not_called() + @with_server() @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') def test_connection_error_during_closing_handshake(self, close): From f83362ef91a6ac2ea89fd8bda4d18ff22ffb7a40 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Oct 2017 20:44:00 +0200 Subject: [PATCH 0353/1539] Normalize exception message style. * Initial capital * No period * message: value rather than message (value) --- websockets/client.py | 10 +++++----- websockets/exceptions.py | 4 ++-- websockets/framing.py | 4 ++-- websockets/headers.py | 10 +++++----- websockets/protocol.py | 2 +- websockets/test_exceptions.py | 28 ++++++++++++++-------------- 6 files changed, 29 insertions(+), 29 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 49b98bcc4..24fbdd9d7 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -120,7 +120,7 @@ def process_extensions(headers, available_extensions): if header_values is not None: if available_extensions is None: - raise InvalidHandshake("No extensions supported.") + raise InvalidHandshake("No extensions supported") parsed_header_values = sum([ parse_extension_list(header_value) @@ -152,7 +152,7 @@ def process_extensions(headers, available_extensions): # matched what the server sent. Fail the connection. else: raise NegotiationError( - "Unsupported extension: name={}, params={}".format( + "Unsupported extension: name = {}, params = {}".format( name, response_params)) return accepted_extensions @@ -174,7 +174,7 @@ def process_subprotocol(headers, available_subprotocols): if header_values is not None: if available_subprotocols is None: - raise InvalidHandshake("No subprotocols supported.") + raise InvalidHandshake("No subprotocols supported") parsed_header_values = sum([ parse_protocol_list(header_value) @@ -341,8 +341,8 @@ def connect(uri, *, if wsuri.secure: kwds.setdefault('ssl', True) elif kwds.get('ssl') is not None: - raise ValueError("connect() received a SSL context for a ws:// URI. " - "Use a wss:// URI to enable TLS.") + raise ValueError("connect() received a SSL context for a ws:// URI, " + "use a wss:// URI to enable TLS") if compression == 'deflate': if extensions is None: diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 51525eb93..210aa40e0 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -152,9 +152,9 @@ def __init__(self, code, reason): explanation = CLOSE_CODES.get(code, "unknown") message += "code = {} ({}), ".format(code, explanation) if reason: - message += "reason = {}.".format(reason) + message += "reason = {}".format(reason) else: - message += "no reason." + message += "no reason" super().__init__(message) diff --git a/websockets/framing.py b/websockets/framing.py index e9f5006b1..8236017da 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -119,7 +119,7 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): length, = struct.unpack('!Q', data) if max_size is not None and length > max_size: raise PayloadTooBig( - "Payload length exceeds limit ({} > {} bytes)" + "Payload length exceeds limit: {} > {} bytes" .format(length, max_size)) if mask: mask_bits = yield from reader(4) @@ -232,7 +232,7 @@ def check(frame): raise WebSocketProtocolError("Fragmented control frame") else: raise WebSocketProtocolError( - "Invalid opcode ({})".format(frame.opcode)) + "Invalid opcode: {}".format(frame.opcode)) def encode_data(data): diff --git a/websockets/headers.py b/websockets/headers.py index 276ec850a..c5b228cb9 100644 --- a/websockets/headers.py +++ b/websockets/headers.py @@ -65,7 +65,7 @@ def parse_token(string, pos): """ match = _token_re.match(string, pos) if match is None: - raise InvalidHeader("expected token", string=string, pos=pos) + raise InvalidHeader("Expected token", string=string, pos=pos) return match.group(), match.end() @@ -87,7 +87,7 @@ def parse_quoted_string(string, pos): """ match = _quoted_string_re.match(string, pos) if match is None: - raise InvalidHeader("expected quoted string", string=string, pos=pos) + raise InvalidHeader("Expected quoted string", string=string, pos=pos) return _unquote_re.sub(r'\1', match.group()[1:-1]), match.end() @@ -112,7 +112,7 @@ def parse_extension_param(string, pos): # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value # after quoted-string unescaping MUST conform to the 'token' ABNF. if _token_re.fullmatch(value) is None: - raise InvalidHeader("invalid quoted string content", + raise InvalidHeader("Invalid quoted string content", string=string, pos=pos_before) else: value, pos = parse_token(string, pos) @@ -191,7 +191,7 @@ def parse_extension_list(string, pos=0): if peek_ahead(string, pos) == ',': pos = parse_OWS(string, pos + 1) else: - raise InvalidHeader("expected comma", string=string, pos=pos) + raise InvalidHeader("Expected comma", string=string, pos=pos) # Remove extra delimiters before the next extension. while peek_ahead(string, pos) == ',': @@ -282,7 +282,7 @@ def parse_protocol_list(string, pos=0): if peek_ahead(string, pos) == ',': pos = parse_OWS(string, pos + 1) else: - raise InvalidHeader("expected comma", string=string, pos=pos) + raise InvalidHeader("Expected comma", string=string, pos=pos) # Remove extra delimiters before the next protocol. while peek_ahead(string, pos) == ',': diff --git a/websockets/protocol.py b/websockets/protocol.py index c638fd705..5323dd329 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -488,7 +488,7 @@ def ensure_open(self): # Control may only reach this point in buggy third-party subclasses. assert self.state == CONNECTING - raise InvalidState("WebSocket connection isn't established yet.") + raise InvalidState("WebSocket connection isn't established yet") @asyncio.coroutine def transfer_data(self): diff --git a/websockets/test_exceptions.py b/websockets/test_exceptions.py index b064e9a8f..031490f9e 100644 --- a/websockets/test_exceptions.py +++ b/websockets/test_exceptions.py @@ -20,8 +20,8 @@ def test_str(self): "Malformed HTTP message", ), ( - InvalidHeader("expected token", "a=|", 3), - "expected token at 3 in a=|", + InvalidHeader("Expected token", "a=|", 3), + "Expected token at 3 in a=|", ), ( InvalidOrigin("Origin not allowed: ''"), @@ -48,50 +48,50 @@ def test_str(self): "Duplicate parameter: a", ), ( - InvalidState("WebSocket connection isn't established yet."), - "WebSocket connection isn't established yet.", + InvalidState("WebSocket connection isn't established yet"), + "WebSocket connection isn't established yet", ), ( ConnectionClosed(1000, ''), "WebSocket connection is closed: code = 1000 " - "(OK), no reason.", + "(OK), no reason", ), ( ConnectionClosed(1001, 'bye'), "WebSocket connection is closed: code = 1001 " - "(going away), reason = bye.", + "(going away), reason = bye", ), ( ConnectionClosed(1006, None), "WebSocket connection is closed: code = 1006 " - "(connection closed abnormally [internal]), no reason." + "(connection closed abnormally [internal]), no reason" ), ( ConnectionClosed(1016, None), "WebSocket connection is closed: code = 1016 " - "(unknown), no reason." + "(unknown), no reason" ), ( ConnectionClosed(3000, None), "WebSocket connection is closed: code = 3000 " - "(registered), no reason." + "(registered), no reason" ), ( ConnectionClosed(4000, None), "WebSocket connection is closed: code = 4000 " - "(private use), no reason." + "(private use), no reason" ), ( InvalidURI("| isn't a valid URI"), "| isn't a valid URI", ), ( - PayloadTooBig("Payload length exceeds limit (2 > 1 bytes)"), - "Payload length exceeds limit (2 > 1 bytes)", + PayloadTooBig("Payload length exceeds limit: 2 > 1 bytes"), + "Payload length exceeds limit: 2 > 1 bytes", ), ( - WebSocketProtocolError("Invalid opcode (7)"), - "Invalid opcode (7)", + WebSocketProtocolError("Invalid opcode: 7"), + "Invalid opcode: 7", ), ]: with self.subTest(exception=exception): From 27549c4b390443b7504e937d4d974bd0855b4c7f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Oct 2017 20:56:57 +0200 Subject: [PATCH 0354/1539] Await close_connection_task in ensure_open. This fixes the following edge case: - this side starts the closing handshake - the connection hangs and the other end never completes the handshake - meanwhile, another coroutine attempts to send a message Before this change, that coroutine would receive a CloseConnection exception with code = None. After this change, it gets code = 1006, set in connection_lost. --- websockets/protocol.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 5323dd329..c57be4024 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -478,12 +478,12 @@ def ensure_open(self): if self.state == CLOSING: # If we started the closing handshake, wait for its completion to - # get the proper close code and status. self.transfer_data_task - # will complete within 2 * timeout after calling close(). - # If we moved to the CLOSING state because we're failing the - # connection, self.transfer_data_task will complete immediately. + # get the proper close code and status. self.close_connection_task + # will complete within 4 or 5 * timeout after calling close(). + # The CLOSING state also occurs when failing the connection. In + # that case self.close_connection_task will complete even faster. if self.close_code is None: - yield from asyncio.shield(self.transfer_data_task) + yield from asyncio.shield(self.close_connection_task) raise ConnectionClosed(self.close_code, self.close_reason) # Control may only reach this point in buggy third-party subclasses. From b82e5703b5261759807d7b4310c4798688fe9a07 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Oct 2017 21:05:03 +0200 Subject: [PATCH 0355/1539] Reorder protocol methods more consistently. --- websockets/protocol.py | 177 +++++++++++++++++++++-------------------- 1 file changed, 89 insertions(+), 88 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index c57be4024..c58c907a2 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -199,15 +199,30 @@ def __init__(self, *, def client_connected(self, reader, writer): """ - Callback for :class:`~asyncio.StreamReaderProtocol`. + Callback when the TCP connection is established. Record references to the stream reader and the stream writer to avoid - using private attributes ``_stream_reader`` and ``_stream_writer``. + using private attributes ``_stream_reader`` and ``_stream_writer`` of + :class:`~asyncio.StreamReaderProtocol`. """ self.reader = reader self.writer = writer + def connection_open(self): + """ + Callback when the WebSocket opening handshake completes. + + """ + assert self.state == CONNECTING + self.state = OPEN + # Start the task that receives incoming WebSocket messages. + self.transfer_data_task = asyncio_ensure_future( + self.transfer_data(), loop=self.loop) + # Start the task that eventually closes the TCP connection. + self.close_connection_task = asyncio_ensure_future( + self.close_connection(), loop=self.loop) + # Public API @property @@ -263,58 +278,6 @@ def state_name(self): """ return ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED'][self.state] - @asyncio.coroutine - def close(self, code=1000, reason=''): - """ - This coroutine performs the closing handshake. - - It waits for the other end to complete the handshake and for the TCP - connection to terminate. - - It doesn't do anything once the connection is closed. In other words - it's idemptotent. - - It's safe to wrap this coroutine in :func:`~asyncio.ensure_future` - since errors during connection termination aren't particularly useful. - - ``code`` must be an :class:`int` and ``reason`` a :class:`str`. - - """ - if self.state == OPEN: - # 7.1.2. Start the WebSocket Closing Handshake - # 7.1.3. The WebSocket Closing Handshake is Started - frame_data = serialize_close(code, reason) - try: - yield from asyncio.wait_for( - self.write_frame(OP_CLOSE, frame_data), - self.timeout, loop=self.loop) - except asyncio.TimeoutError: - # If the close frame cannot be sent because the send buffers - # are full, the closing handshake won't complete anyway. - # Cancel the data transfer task to shut down faster. - # Cancelling a task is idempotent. - self.transfer_data_task.cancel() - - # If no close frame is received within the timeout, wait_for() cancels - # the data transfer task and raises TimeoutError. Then transfer_data() - # catches CancelledError and exits without an exception. - - # If close() is called multiple times concurrently and one of these - # calls hits the timeout, other calls will resume executing without an - # exception, so there's no need to catch CancelledError here. - - try: - # If close() is cancelled during the wait, self.transfer_data_task - # is cancelled before the timeout elapses (on Python ≥ 3.4.3). - # This helps closing connections when shutting down a server. - yield from asyncio.wait_for( - self.transfer_data_task, self.timeout, loop=self.loop) - except asyncio.TimeoutError: - pass - - # Wait for the close connection task to close the TCP connection. - yield from asyncio.shield(self.close_connection_task) - @asyncio.coroutine def recv(self): """ @@ -388,6 +351,58 @@ def send(self, data): yield from self.write_frame(opcode, data) + @asyncio.coroutine + def close(self, code=1000, reason=''): + """ + This coroutine performs the closing handshake. + + It waits for the other end to complete the handshake and for the TCP + connection to terminate. + + It doesn't do anything once the connection is closed. In other words + it's idemptotent. + + It's safe to wrap this coroutine in :func:`~asyncio.ensure_future` + since errors during connection termination aren't particularly useful. + + ``code`` must be an :class:`int` and ``reason`` a :class:`str`. + + """ + if self.state == OPEN: + # 7.1.2. Start the WebSocket Closing Handshake + # 7.1.3. The WebSocket Closing Handshake is Started + frame_data = serialize_close(code, reason) + try: + yield from asyncio.wait_for( + self.write_frame(OP_CLOSE, frame_data), + self.timeout, loop=self.loop) + except asyncio.TimeoutError: + # If the close frame cannot be sent because the send buffers + # are full, the closing handshake won't complete anyway. + # Cancel the data transfer task to shut down faster. + # Cancelling a task is idempotent. + self.transfer_data_task.cancel() + + # If no close frame is received within the timeout, wait_for() cancels + # the data transfer task and raises TimeoutError. Then transfer_data() + # catches CancelledError and exits without an exception. + + # If close() is called multiple times concurrently and one of these + # calls hits the timeout, other calls will resume executing without an + # exception, so there's no need to catch CancelledError here. + + try: + # If close() is cancelled during the wait, self.transfer_data_task + # is cancelled before the timeout elapses (on Python ≥ 3.4.3). + # This helps closing connections when shutting down a server. + yield from asyncio.wait_for( + self.transfer_data_task, self.timeout, loop=self.loop) + except asyncio.TimeoutError: + pass + + # Wait for the close connection task to close the TCP connection. + yield from asyncio.shield(self.close_connection_task) + @asyncio.coroutine def ping(self, data=None): """ @@ -447,20 +462,6 @@ def pong(self, data=b''): # Private methods - no guarantees. - def connection_open(self): - """ - Callback when the opening handshake completes. - - """ - assert self.state == CONNECTING - self.state = OPEN - # Start the task that receives incoming WebSocket messages. - self.transfer_data_task = asyncio_ensure_future( - self.transfer_data(), loop=self.loop) - # Start the task that eventually closes the TCP connection. - self.close_connection_task = asyncio_ensure_future( - self.close_connection(), loop=self.loop) - @asyncio.coroutine def ensure_open(self): """ @@ -699,26 +700,6 @@ def writer_is_closing(self): except AttributeError: return transport._closed - @asyncio.coroutine - def wait_for_connection_lost(self): - """ - Wait until the TCP connection is closed or ``self.timeout`` elapses. - - Return ``True`` if the connection is closed and ``False`` otherwise. - - """ - if not self.connection_lost_waiter.done(): - try: - yield from asyncio.wait_for( - asyncio.shield(self.connection_lost_waiter), - self.timeout, loop=self.loop) - except asyncio.TimeoutError: - pass - # Re-check self.connection_lost_waiter.done() synchronously because - # connection_lost() could run between the moment the timeout occurs - # and the moment this coroutine resumes running. - return self.connection_lost_waiter.done() - @asyncio.coroutine def close_connection(self, after_handshake=True): """ @@ -780,6 +761,26 @@ def close_connection(self, after_handshake=True): # connection_lost() is called quickly after aborting. yield from self.wait_for_connection_lost() + @asyncio.coroutine + def wait_for_connection_lost(self): + """ + Wait until the TCP connection is closed or ``self.timeout`` elapses. + + Return ``True`` if the connection is closed and ``False`` otherwise. + + """ + if not self.connection_lost_waiter.done(): + try: + yield from asyncio.wait_for( + asyncio.shield(self.connection_lost_waiter), + self.timeout, loop=self.loop) + except asyncio.TimeoutError: + pass + # Re-check self.connection_lost_waiter.done() synchronously because + # connection_lost() could run between the moment the timeout occurs + # and the moment this coroutine resumes running. + return self.connection_lost_waiter.done() + @asyncio.coroutine def fail_connection(self, code=1011, reason=''): """ From 69cd26dcf3aea3ebfa9750ada894283a7c53d2d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Oct 2017 11:09:00 +0200 Subject: [PATCH 0356/1539] Improve some docstrings. --- websockets/protocol.py | 3 +++ websockets/server.py | 27 +++++++++++++-------------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index c58c907a2..b89399de3 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -213,7 +213,10 @@ def connection_open(self): """ Callback when the WebSocket opening handshake completes. + Enter the OPEN state and start the data transfer phase. + """ + # 4.1. The WebSocket Connection is Established. assert self.state == CONNECTING self.state = OPEN # Start the task that receives incoming WebSocket messages. diff --git a/websockets/server.py b/websockets/server.py index 1aa823037..c18a9704d 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -476,22 +476,21 @@ def handshake(self, origins=None, available_extensions=None, class WebSocketServer: """ - Wraps an underlying :class:`~asyncio.Server` object. + Wrapper for :class:`~asyncio.Server` that closes connections on exit. This class provides the return type of :func:`~websockets.server.serve`. - This class shouldn't be instantiated directly. - Objects of this class store a reference to an underlying - :class:`~asyncio.Server` object returned by - :meth:`~asyncio.AbstractEventLoop.create_server`. The class stores a - reference rather than inheriting from :class:`~asyncio.Server` in part - because :meth:`~asyncio.AbstractEventLoop.create_server` doesn't support - passing a custom :class:`~asyncio.Server` class. + It mimics the interface of :class:`~asyncio.AbstractServer`, namely its + :meth:`~asyncio.AbstractServer.close()` and + :meth:`~asyncio.AbstractServer.wait_closed()` methods, to close WebSocket + connections properly on exit, in addition to closing the underlying + :class:`~asyncio.Server`. - :class:`WebSocketServer` supports cleaning up the underlying - :class:`~asyncio.Server` object and other resources by implementing the - interface of ``asyncio.events.AbstractServer``, namely its ``close()`` - and ``wait_closed()`` methods. + Instances of this class store a reference to the :class:`~asyncio.Server` + object returned by :meth:`~asyncio.AbstractEventLoop.create_server` rather + than inherit from :class:`~asyncio.Server` in part because + :meth:`~asyncio.AbstractEventLoop.create_server` doesn't support passing a + custom :class:`~asyncio.Server` class. """ def __init__(self, loop): @@ -506,8 +505,8 @@ def wrap(self, server): Attach to a given :class:`~asyncio.Server`. Since :meth:`~asyncio.AbstractEventLoop.create_server` doesn't support - injecting a custom ``Server`` class, a simple solution that doesn't - rely on private APIs is to: + injecting a custom ``Server`` class, the easiest solution that doesn't + rely on private :mod:`asyncio` APIs is to: - instantiate a :class:`WebSocketServer` - give the protocol factory a reference to that instance From 4cbb18c21a66fb406e7cbd36eca3329312367a54 Mon Sep 17 00:00:00 2001 From: Chris Jerdonek Date: Wed, 1 Nov 2017 10:25:38 -0700 Subject: [PATCH 0357/1539] Eliminate the need to wrap serve(). (#294) * Eliminate the need to wrap serve(). See issue #197 for related discussion. * Fix Python 3.4. * Choose better names per @aaugustin's comment. * Update client.py. * Fix flake8. * Fix test for Python 3.4. --- websockets/client.py | 263 ++++++++++++++++-------------- websockets/py35/client.py | 21 --- websockets/py35/server.py | 22 --- websockets/server.py | 271 +++++++++++++++++-------------- websockets/test_client_server.py | 7 +- 5 files changed, 297 insertions(+), 287 deletions(-) delete mode 100644 websockets/py35/client.py delete mode 100644 websockets/py35/server.py diff --git a/websockets/client.py b/websockets/client.py index 24fbdd9d7..66e984c53 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -5,6 +5,7 @@ import asyncio import collections.abc +import sys from .exceptions import ( InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError @@ -272,129 +273,153 @@ def handshake(self, wsuri, origin=None, available_extensions=None, self.connection_open() -@asyncio.coroutine -def connect(uri, *, - create_protocol=None, - timeout=10, max_size=2 ** 20, max_queue=2 ** 5, - read_limit=2 ** 16, write_limit=2 ** 16, - loop=None, legacy_recv=False, klass=None, - origin=None, extensions=None, subprotocols=None, - extra_headers=None, compression='deflate', **kwds): - """ - This coroutine connects to a WebSocket server at a given ``uri``. - - It yields a :class:`WebSocketClientProtocol` which can then be used to - send and receive messages. - - On Python ≥ 3.5, :func:`connect` can be used as a asynchronous context - manager. In that case, the connection is closed when exiting the context. - - :func:`connect` is a wrapper around the event loop's - :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword - arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. - - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to - a ``wss://`` URI, if this argument isn't provided explicitly, it's set to - ``True``, which means Python's default :class:`~ssl.SSLContext` is used. - - The behavior of the ``timeout``, ``max_size``, and ``max_queue``, - ``read_limit``, and ``write_limit`` optional arguments is described in the - documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. - - The ``create_protocol`` parameter allows customizing the asyncio protocol - that manages the connection. It should be a callable or class accepting - the same arguments as :class:`WebSocketClientProtocol` and returning a - :class:`WebSocketClientProtocol` instance. It defaults to - :class:`WebSocketClientProtocol`. - - :func:`connect` also accepts the following optional arguments: - - * ``origin`` sets the Origin HTTP header - * ``extensions`` is a list of supported extensions in order of decreasing - preference - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference - * ``extra_headers`` sets additional HTTP request headers – it can be a - mapping or an iterable of (name, value) pairs - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression - - :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is - invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening - handshake fails. +class Connect: - """ - if loop is None: - loop = asyncio.get_event_loop() - - # Backwards-compatibility: create_protocol used to be called klass. - # In the unlikely event that both are specified, klass is ignored. - if create_protocol is None: - create_protocol = klass - - if create_protocol is None: - create_protocol = WebSocketClientProtocol - - wsuri = parse_uri(uri) - if wsuri.secure: - kwds.setdefault('ssl', True) - elif kwds.get('ssl') is not None: - raise ValueError("connect() received a SSL context for a ws:// URI, " - "use a wss:// URI to enable TLS") - - if compression == 'deflate': - if extensions is None: - extensions = [] - if not any( - extension_factory.name == ClientPerMessageDeflateFactory.name - for extension_factory in extensions - ): - extensions.append(ClientPerMessageDeflateFactory( - client_max_window_bits=True, - )) - elif compression is not None: - raise ValueError("Unsupported compression: {}".format(compression)) - - factory = lambda: create_protocol( - host=wsuri.host, port=wsuri.port, secure=wsuri.secure, - timeout=timeout, max_size=max_size, max_queue=max_queue, - read_limit=read_limit, write_limit=write_limit, - loop=loop, legacy_recv=legacy_recv, - origin=origin, extensions=extensions, subprotocols=subprotocols, - extra_headers=extra_headers, - ) - - if kwds.get('sock') is None: - host, port = wsuri.host, wsuri.port - else: - # If sock is given, host and port mustn't be specified. - host, port = None, None - - transport, protocol = yield from loop.create_connection( - factory, host, port, **kwds) - - try: - yield from protocol.handshake( - wsuri, origin=origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, + def __init__(self, uri, *, + create_protocol=None, + timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + read_limit=2 ** 16, write_limit=2 ** 16, + loop=None, legacy_recv=False, klass=None, + origin=None, extensions=None, subprotocols=None, + extra_headers=None, compression='deflate', **kwds): + """ + This coroutine connects to a WebSocket server at a given ``uri``. + + It yields a :class:`WebSocketClientProtocol` which can then be used to + send and receive messages. + + On Python ≥ 3.5, :func:`connect` can be used as a asynchronous + context manager. In that case, the connection is closed when exiting + the context. + + :func:`connect` is a wrapper around the event loop's + :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown + keyword arguments are passed to + :meth:`~asyncio.BaseEventLoop.create_connection`. + + For example, you can set the ``ssl`` keyword argument to a + :class:`~ssl.SSLContext` to enforce some TLS settings. When + connecting to a ``wss://`` URI, if this argument isn't provided + explicitly, it's set to ``True``, which means Python's default + :class:`~ssl.SSLContext` is used. + + The behavior of the ``timeout``, ``max_size``, and ``max_queue``, + ``read_limit``, and ``write_limit`` optional arguments is described + in the documentation of + :class:`~websockets.protocol.WebSocketCommonProtocol`. + + The ``create_protocol`` parameter allows customizing the asyncio + protocol that manages the connection. It should be a callable or + class accepting the same arguments as + :class:`WebSocketClientProtocol` and returning a + :class:`WebSocketClientProtocol` instance. It defaults to + :class:`WebSocketClientProtocol`. + + :func:`connect` also accepts the following optional arguments: + + * ``origin`` sets the Origin HTTP header + * ``extensions`` is a list of supported extensions in order of + decreasing preference + * ``subprotocols`` is a list of supported subprotocols in order of + decreasing preference + * ``extra_headers`` sets additional HTTP request headers – it can be a + mapping or an iterable of (name, value) pairs + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression + + :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` + is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the + opening handshake fails. + + """ + if loop is None: + loop = asyncio.get_event_loop() + + # Backwards-compatibility: create_protocol used to be called klass. + # In the unlikely event that both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + if create_protocol is None: + create_protocol = WebSocketClientProtocol + + wsuri = parse_uri(uri) + if wsuri.secure: + kwds.setdefault('ssl', True) + elif kwds.get('ssl') is not None: + raise ValueError("connect() received a SSL context for a ws:// " + "URI, use a wss:// URI to enable TLS") + + if compression == 'deflate': + if extensions is None: + extensions = [] + if not any( + extension_factory.name == ClientPerMessageDeflateFactory.name + for extension_factory in extensions + ): + extensions.append(ClientPerMessageDeflateFactory( + client_max_window_bits=True, + )) + elif compression is not None: + raise ValueError("Unsupported compression: {}".format(compression)) + + factory = lambda: create_protocol( + host=wsuri.host, port=wsuri.port, secure=wsuri.secure, + timeout=timeout, max_size=max_size, max_queue=max_queue, + read_limit=read_limit, write_limit=write_limit, + loop=loop, legacy_recv=legacy_recv, + origin=origin, extensions=extensions, subprotocols=subprotocols, + extra_headers=extra_headers, ) - except Exception: - yield from protocol.close_connection(after_handshake=False) - raise - return protocol + if kwds.get('sock') is None: + host, port = wsuri.host, wsuri.port + else: + # If sock is given, host and port mustn't be specified. + host, port = None, None + + self._wsuri = wsuri + self._origin = origin + + # This is a coroutine object. + self._creating_connection = loop.create_connection( + factory, host, port, **kwds) + + @asyncio.coroutine + def __aenter__(self): + self.websocket = yield from self + return self.websocket + + @asyncio.coroutine + def __aexit__(self, exc_type, exc_value, traceback): + yield from self.websocket.close() + + def __await__(self): + transport, protocol = yield from self._creating_connection + + try: + yield from protocol.handshake( + self._wsuri, origin=self._origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + except Exception: + yield from protocol.close_connection(after_handshake=False) + raise + + return protocol + + __iter__ = __await__ -try: - from .py35.client import Connect -except (SyntaxError, ImportError): # pragma: no cover - pass +if sys.version_info[:2] <= (3, 4): # pragma: no cover + import functools + + @asyncio.coroutine + @functools.wraps(Connect.__init__) + def connect(*args, **kwds): + return Connect(*args, **kwds).__await__() else: - Connect.__wrapped__ = connect - # Copy over docstring to support building documentation on Python 3.5. - Connect.__doc__ = connect.__doc__ connect = Connect diff --git a/websockets/py35/client.py b/websockets/py35/client.py deleted file mode 100644 index 5ab7af034..000000000 --- a/websockets/py35/client.py +++ /dev/null @@ -1,21 +0,0 @@ -class Connect: - """ - This class wraps :func:`~websockets.client.connect` on Python ≥ 3.5. - - This allows using it as an asynchronous context manager. - - """ - def __init__(self, *args, **kwargs): - self.client = self.__class__.__wrapped__(*args, **kwargs) - - async def __aenter__(self): - self.websocket = await self - return self.websocket - - async def __aexit__(self, exc_type, exc_value, traceback): - await self.websocket.close() - - def __await__(self): - return (yield from self.client) - - __iter__ = __await__ diff --git a/websockets/py35/server.py b/websockets/py35/server.py deleted file mode 100644 index 3aba1c84e..000000000 --- a/websockets/py35/server.py +++ /dev/null @@ -1,22 +0,0 @@ -class Serve: - """ - This class wraps :func:`~websockets.server.serve` on Python ≥ 3.5. - - This allows using it as an asynchronous context manager. - - """ - def __init__(self, *args, **kwargs): - self.server = self.__class__.__wrapped__(*args, **kwargs) - - async def __aenter__(self): - self.server = await self - return self.server - - async def __aexit__(self, exc_type, exc_value, traceback): - self.server.close() - await self.server.wait_closed() - - def __await__(self): - return (yield from self.server) - - __iter__ = __await__ diff --git a/websockets/server.py b/websockets/server.py index c18a9704d..277218d99 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -6,6 +6,7 @@ import asyncio import collections.abc import logging +import sys from .compatibility import ( BAD_REQUEST, FORBIDDEN, INTERNAL_SERVER_ERROR, SERVICE_UNAVAILABLE, @@ -582,123 +583,147 @@ def wait_closed(self): yield from self.server.wait_closed() -@asyncio.coroutine -def serve(ws_handler, host=None, port=None, *, - path=None, create_protocol=None, - timeout=10, max_size=2 ** 20, max_queue=2 ** 5, - read_limit=2 ** 16, write_limit=2 ** 16, - loop=None, legacy_recv=False, klass=None, - origins=None, extensions=None, subprotocols=None, - extra_headers=None, compression='deflate', **kwds): - """ - Create, start, and return a :class:`WebSocketServer` object. - - :func:`serve` is a wrapper around the event loop's - :meth:`~asyncio.AbstractEventLoop.create_server` method. - Internally, the function creates and starts a :class:`~asyncio.Server` - object by calling :meth:`~asyncio.AbstractEventLoop.create_server`. The - :class:`WebSocketServer` keeps a reference to this object. - - The returned :class:`WebSocketServer` and its resources can be cleaned - up by calling its :meth:`~websockets.server.WebSocketServer.close` and - :meth:`~websockets.server.WebSocketServer.wait_closed` methods. - - On Python ≥ 3.5, :func:`serve` can also be used as an asynchronous context - manager. In this case, the server is shut down when exiting the context. - - The ``ws_handler`` argument is the WebSocket handler. It must be a - coroutine accepting two arguments: a :class:`WebSocketServerProtocol` - and the request URI. - - The ``host`` and ``port`` arguments, as well as unrecognized keyword - arguments, are passed along to - :meth:`~asyncio.AbstractEventLoop.create_server`. For example, you can - set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable - TLS. - - The ``create_protocol`` parameter allows customizing the asyncio protocol - that manages the connection. It should be a callable or class accepting - the same arguments as :class:`WebSocketServerProtocol` and returning a - :class:`WebSocketServerProtocol` instance. It defaults to - :class:`WebSocketServerProtocol`. - - The behavior of the ``timeout``, ``max_size``, and ``max_queue``, - ``read_limit``, and ``write_limit`` optional arguments is described in the - documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. - - :func:`serve` also accepts the following optional arguments: - - * ``origins`` defines acceptable Origin HTTP headers — include - ``''`` if the lack of an origin is acceptable - * ``extensions`` is a list of supported extensions in order of decreasing - preference - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference - * ``extra_headers`` sets additional HTTP response headers — it can be a - mapping, an iterable of (name, value) pairs, or a callable taking the - request path and headers in arguments. - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression - - Whenever a client connects, the server accepts the connection, creates a - :class:`WebSocketServerProtocol`, performs the opening handshake, and - delegates to the WebSocket handler. Once the handler completes, the server - performs the closing handshake and closes the connection. - - Since there's no useful way to propagate exceptions triggered in handlers, - they're sent to the ``'websockets.server'`` logger instead. Debugging is - much easier if you configure logging to print them:: - - import logging - logger = logging.getLogger('websockets.server') - logger.setLevel(logging.ERROR) - logger.addHandler(logging.StreamHandler()) +class Serve: - """ - # Backwards-compatibility: create_protocol used to be called klass. - # In the unlikely event that both are specified, klass is ignored. - if create_protocol is None: - create_protocol = klass - - if create_protocol is None: - create_protocol = WebSocketServerProtocol - - if loop is None: - loop = asyncio.get_event_loop() - - ws_server = WebSocketServer(loop) - - secure = kwds.get('ssl') is not None - - if compression == 'deflate': - if extensions is None: - extensions = [] - if not any( - extension_factory.name == ServerPerMessageDeflateFactory.name - for extension_factory in extensions - ): - extensions.append(ServerPerMessageDeflateFactory()) - elif compression is not None: - raise ValueError("Unsupported compression: {}".format(compression)) - - factory = lambda: create_protocol( - ws_handler, ws_server, - host=host, port=port, secure=secure, - timeout=timeout, max_size=max_size, max_queue=max_queue, - read_limit=read_limit, write_limit=write_limit, - loop=loop, legacy_recv=legacy_recv, - origins=origins, extensions=extensions, subprotocols=subprotocols, - extra_headers=extra_headers, - ) - if path is None: - server = yield from loop.create_server(factory, host, port, **kwds) - else: - server = yield from loop.create_unix_server(factory, path, **kwds) - - ws_server.wrap(server) - - return ws_server + def __init__(self, ws_handler, host=None, port=None, *, + path=None, create_protocol=None, + timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + read_limit=2 ** 16, write_limit=2 ** 16, + loop=None, legacy_recv=False, klass=None, + origins=None, extensions=None, subprotocols=None, + extra_headers=None, compression='deflate', **kwds): + """ + Create, start, and return a :class:`WebSocketServer` object. + + :func:`serve` is a wrapper around the event loop's + :meth:`~asyncio.AbstractEventLoop.create_server` method. Internally, + the function creates and starts a :class:`~asyncio.Server` object by + calling :meth:`~asyncio.AbstractEventLoop.create_server`. The + :class:`WebSocketServer` keeps a reference to this object. + + The returned :class:`WebSocketServer` and its resources can be + cleaned up by calling its + :meth:`~websockets.server.WebSocketServer.close` and + :meth:`~websockets.server.WebSocketServer.wait_closed` methods. + + On Python ≥ 3.5, :func:`serve` can also be used as an asynchronous + context manager. In this case, the server is shut down when exiting + the context. + + The ``ws_handler`` argument is the WebSocket handler. It must be a + coroutine accepting two arguments: a :class:`WebSocketServerProtocol` + and the request URI. + + The ``host`` and ``port`` arguments, as well as unrecognized keyword + arguments, are passed along to + :meth:`~asyncio.AbstractEventLoop.create_server`. For example, you + can set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to + enable TLS. + + The ``create_protocol`` parameter allows customizing the asyncio + protocol that manages the connection. It should be a callable or + class accepting the same arguments as + :class:`WebSocketServerProtocol` and returning a + :class:`WebSocketServerProtocol` instance. It defaults to + :class:`WebSocketServerProtocol`. + + The behavior of the ``timeout``, ``max_size``, and ``max_queue``, + ``read_limit``, and ``write_limit`` optional arguments is described + in the documentation of + :class:`~websockets.protocol.WebSocketCommonProtocol`. + + :func:`serve` also accepts the following optional arguments: + + * ``origins`` defines acceptable Origin HTTP headers — include + ``''`` if the lack of an origin is acceptable + * ``extensions`` is a list of supported extensions in order of + decreasing preference + * ``subprotocols`` is a list of supported subprotocols in order of + decreasing preference + * ``extra_headers`` sets additional HTTP response headers — it can be a + mapping, an iterable of (name, value) pairs, or a callable taking the + request path and headers in arguments. + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression + + Whenever a client connects, the server accepts the connection, + creates a :class:`WebSocketServerProtocol`, performs the opening + handshake, and delegates to the WebSocket handler. Once the handler + completes, the server performs the closing handshake and closes the + connection. + + Since there's no useful way to propagate exceptions triggered in + handlers, they're sent to the ``'websockets.server'`` logger instead. + Debugging is much easier if you configure logging to print them:: + + import logging + logger = logging.getLogger('websockets.server') + logger.setLevel(logging.ERROR) + logger.addHandler(logging.StreamHandler()) + + """ + # Backwards-compatibility: create_protocol used to be called klass. + # In the unlikely event that both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + if create_protocol is None: + create_protocol = WebSocketServerProtocol + + if loop is None: + loop = asyncio.get_event_loop() + + ws_server = WebSocketServer(loop) + + secure = kwds.get('ssl') is not None + + if compression == 'deflate': + if extensions is None: + extensions = [] + if not any( + extension_factory.name == ServerPerMessageDeflateFactory.name + for extension_factory in extensions + ): + extensions.append(ServerPerMessageDeflateFactory()) + elif compression is not None: + raise ValueError("Unsupported compression: {}".format(compression)) + + factory = lambda: create_protocol( + ws_handler, ws_server, + host=host, port=port, secure=secure, + timeout=timeout, max_size=max_size, max_queue=max_queue, + read_limit=read_limit, write_limit=write_limit, + loop=loop, legacy_recv=legacy_recv, + origins=origins, extensions=extensions, subprotocols=subprotocols, + extra_headers=extra_headers, + ) + + if path is None: + creating_server = loop.create_server(factory, host, port, **kwds) + else: + creating_server = loop.create_unix_server(factory, path, **kwds) + + # This is a coroutine object. + self._creating_server = creating_server + self.ws_server = ws_server + + @asyncio.coroutine + def __aenter__(self): + return (yield from self) + + @asyncio.coroutine + def __aexit__(self, exc_type, exc_value, traceback): + self.ws_server.close() + yield from self.ws_server.wait_closed() + + def __await__(self): + server = yield from self._creating_server + self.ws_server.wrap(server) + + return self.ws_server + + __iter__ = __await__ @asyncio.coroutine @@ -717,12 +742,12 @@ def unix_serve(ws_handler, path, **kwargs): return serve(ws_handler, path=path, **kwargs) -try: - from .py35.server import Serve -except (SyntaxError, ImportError): # pragma: no cover - pass +if sys.version_info[:2] <= (3, 4): # pragma: no cover + import functools + + @asyncio.coroutine + @functools.wraps(Serve.__init__) + def serve(*args, **kwds): + return Serve(*args, **kwds).__await__() else: - Serve.__wrapped__ = serve - # Copy over docstring to support building documentation on Python 3.5. - Serve.__doc__ = serve.__doc__ serve = Serve diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 519621779..6499fd0ae 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -900,9 +900,12 @@ def start_client(self, path='', **kwds): @with_server() def test_ws_uri_is_rejected(self): - client = connect('ws://localhost:8642/', ssl=self.client_context) with self.assertRaises(ValueError): - self.loop.run_until_complete(client) + client = connect('ws://localhost:8642/', ssl=self.client_context) + # With Python ≥ 3.5, the exception is raised by connect() even + # before awaiting. However, with Python 3.4 the exception is + # raised only when awaiting. + self.loop.run_until_complete(client) # pragma: no cover class ClientServerOriginTests(unittest.TestCase): From fa96c583d178ff79a1536ce597b1a4ceef4d3d2e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 1 Nov 2017 21:01:34 +0100 Subject: [PATCH 0358/1539] Manage connection states with an enum. Fix #284. --- docs/api.rst | 2 -- docs/changelog.rst | 6 +++++ websockets/protocol.py | 47 ++++++++++++++----------------------- websockets/test_protocol.py | 15 ++++-------- 4 files changed, 29 insertions(+), 41 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 83426959e..59463f032 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -78,8 +78,6 @@ Shared .. autoattribute:: remote_address .. autoattribute:: open - .. autoattribute:: state_name - Exceptions .......... diff --git a/docs/changelog.rst b/docs/changelog.rst index 557780195..e27a5fa7f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -19,6 +19,12 @@ Changelog If you want to disable compression, add ``compression=None`` when calling :func:`~server.serve()` or :func:`~client.connect()`. +.. warning:: + + **Version 4.0 removes the ``state_name`` attribute of protocols.** + + Use ``protocol.state.name`` instead of ``protocol.state_name``. + Also: * :class:`~protocol.WebSocketCommonProtocol` instances can be used as diff --git a/websockets/protocol.py b/websockets/protocol.py index b89399de3..fb8ad0d75 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -10,6 +10,7 @@ import asyncio.queues import codecs import collections +import enum import logging import random import struct @@ -29,7 +30,8 @@ # A WebSocket connection goes through the following four states, in order: -CONNECTING, OPEN, CLOSING, CLOSED = range(4) +class State(enum.IntEnum): + CONNECTING, OPEN, CLOSING, CLOSED = range(4) # In order to ensure consistency, the code always checks the current value of # WebSocketCommonProtocol.state before assigning a new value and never yields @@ -161,7 +163,7 @@ def __init__(self, *, # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute # :meth:`connection_open()` to change the state to OPEN. - self.state = CONNECTING + self.state = State.CONNECTING # HTTP protocol parameters. self.path = None @@ -217,8 +219,8 @@ def connection_open(self): """ # 4.1. The WebSocket Connection is Established. - assert self.state == CONNECTING - self.state = OPEN + assert self.state == State.CONNECTING + self.state = State.OPEN # Start the task that receives incoming WebSocket messages. self.transfer_data_task = asyncio_ensure_future( self.transfer_data(), loop=self.loop) @@ -266,20 +268,7 @@ def open(self): .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp """ - return self.state == OPEN - - @property - def state_name(self): - """ - Current connection state, as a string. - - Possible states are defined in the WebSocket specification: - ``CONNECTING``, ``OPEN``, ``CLOSING``, or ``CLOSED``. - - To check if the connection is open, use :attr:`open` instead. - - """ - return ['CONNECTING', 'OPEN', 'CLOSING', 'CLOSED'][self.state] + return self.state == State.OPEN @asyncio.coroutine def recv(self): @@ -371,7 +360,7 @@ def close(self, code=1000, reason=''): ``code`` must be an :class:`int` and ``reason`` a :class:`str`. """ - if self.state == OPEN: + if self.state == State.OPEN: # 7.1.2. Start the WebSocket Closing Handshake # 7.1.3. The WebSocket Closing Handshake is Started frame_data = serialize_close(code, reason) @@ -474,13 +463,13 @@ def ensure_open(self): """ # Handle cases from most common to least common for performance. - if self.state == OPEN: + if self.state == State.OPEN: return - if self.state == CLOSED: + if self.state == State.CLOSED: raise ConnectionClosed(self.close_code, self.close_reason) - if self.state == CLOSING: + if self.state == State.CLOSING: # If we started the closing handshake, wait for its completion to # get the proper close code and status. self.close_connection_task # will complete within 4 or 5 * timeout after calling close(). @@ -491,7 +480,7 @@ def ensure_open(self): raise ConnectionClosed(self.close_code, self.close_reason) # Control may only reach this point in buggy third-party subclasses. - assert self.state == CONNECTING + assert self.state == State.CONNECTING raise InvalidState("WebSocket connection isn't established yet") @asyncio.coroutine @@ -608,7 +597,7 @@ def read_data_frame(self, max_size): # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason self.close_code, self.close_reason = code, reason - if self.state == OPEN: + if self.state == State.OPEN: # 7.1.3. The WebSocket Closing Handshake is Started yield from self.write_frame(OP_CLOSE, frame.data) return @@ -648,14 +637,14 @@ def read_frame(self, max_size): @asyncio.coroutine def write_frame(self, opcode, data=b''): # Defensive assertion for protocol compliance. - if self.state != OPEN: # pragma: no cover + if self.state != State.OPEN: # pragma: no cover raise InvalidState("Cannot write to a WebSocket " - "in the {} state".format(self.state_name)) + "in the {} state".format(self.state.name)) # Make sure no other frame will be sent after a close frame. Do this # before yielding control to avoid sending more than one close frame. if opcode == OP_CLOSE: - self.state = CLOSING + self.state = State.CLOSING frame = Frame(True, opcode, data) logger.debug("%s > %s", self.side, frame) @@ -795,7 +784,7 @@ def fail_connection(self, code=1011, reason=''): self.side, code, reason, ) # Don't send a close frame if the connection is broken. - if self.state == OPEN and code != 1006: + if self.state == State.OPEN and code != 1006: frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) @@ -852,7 +841,7 @@ def connection_lost(self, exc): """ logger.debug("%s - connection_lost(%s)", self.side, exc) - self.state = CLOSED + self.state = State.CLOSED if self.close_code is None: self.close_code = 1006 # If self.connection_lost_waiter isn't pending, that's a bug, because: diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 82d078290..544a38158 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -9,7 +9,7 @@ from .compatibility import asyncio_ensure_future from .exceptions import ConnectionClosed, InvalidState from .framing import * -from .protocol import CLOSED, CONNECTING, WebSocketCommonProtocol +from .protocol import State, WebSocketCommonProtocol # Unit for timeouts. May be increased on slow machines by setting the @@ -72,7 +72,7 @@ def close(self): def abort(self): # Change this to an `if` if tests call abort() multiple times. - assert self.protocol.state != CLOSED + assert self.protocol.state != State.CLOSED self.loop.call_soon(self.protocol.connection_lost, None) @@ -254,14 +254,14 @@ def assertNoFrameSent(self): def assertConnectionClosed(self, code, message): # The following line guarantees that connection_lost was called. - self.assertEqual(self.protocol.state, CLOSED) + self.assertEqual(self.protocol.state, State.CLOSED) # A close frame was received. self.assertEqual(self.protocol.close_code, code) self.assertEqual(self.protocol.close_reason, message) def assertConnectionFailed(self, code, message): # The following line guarantees that connection_lost was called. - self.assertEqual(self.protocol.state, CLOSED) + self.assertEqual(self.protocol.state, State.CLOSED) # No close frame was received. self.assertEqual(self.protocol.close_code, 1006) self.assertEqual(self.protocol.close_reason, '') @@ -322,11 +322,6 @@ def test_open(self): self.close_connection() self.assertFalse(self.protocol.open) - def test_state_name(self): - self.assertEqual(self.protocol.state_name, 'OPEN') - self.close_connection() - self.assertEqual(self.protocol.state_name, 'CLOSED') - # Test the recv coroutine. def test_recv_text(self): @@ -673,7 +668,7 @@ def test_connection_lost(self): def test_ensure_connection_before_opening_handshake(self): # Simulate a bug by forcibly reverting the protocol state. - self.protocol.state = CONNECTING + self.protocol.state = State.CONNECTING with self.assertRaises(InvalidState): self.loop.run_until_complete(self.protocol.ensure_open()) From 2c0b6618a71453cf1ff39a3bf769081a25475aaf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 1 Nov 2017 21:21:10 +0100 Subject: [PATCH 0359/1539] Fix doc generation on Python 3.5+. --- websockets/client.py | 102 +++++++++++++++---------------- websockets/server.py | 141 ++++++++++++++++++++----------------------- 2 files changed, 115 insertions(+), 128 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 66e984c53..a04fae8f4 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -274,6 +274,52 @@ def handshake(self, wsuri, origin=None, available_extensions=None, class Connect: + """ + This coroutine connects to a WebSocket server at a given ``uri``. + + It yields a :class:`WebSocketClientProtocol` which can then be used to + send and receive messages. + + On Python ≥ 3.5, :func:`connect` can be used as a asynchronous context + manager. In that case, the connection is closed when exiting the context. + + :func:`connect` is a wrapper around the event loop's + :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword + arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. + + For example, you can set the ``ssl`` keyword argument to a + :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to + a ``wss://`` URI, if this argument isn't provided explicitly, it's set to + ``True``, which means Python's default :class:`~ssl.SSLContext` is used. + + The behavior of the ``timeout``, ``max_size``, and ``max_queue``, + ``read_limit``, and ``write_limit`` optional arguments is described in the + documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. + + The ``create_protocol`` parameter allows customizing the asyncio protocol + that manages the connection. It should be a callable or class accepting + the same arguments as :class:`WebSocketClientProtocol` and returning a + :class:`WebSocketClientProtocol` instance. It defaults to + :class:`WebSocketClientProtocol`. + + :func:`connect` also accepts the following optional arguments: + + * ``origin`` sets the Origin HTTP header + * ``extensions`` is a list of supported extensions in order of + decreasing preference + * ``subprotocols`` is a list of supported subprotocols in order of + decreasing preference + * ``extra_headers`` sets additional HTTP request headers – it can be a + mapping or an iterable of (name, value) pairs + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression + + :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is + invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening + handshake fails. + + """ def __init__(self, uri, *, create_protocol=None, @@ -282,57 +328,6 @@ def __init__(self, uri, *, loop=None, legacy_recv=False, klass=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds): - """ - This coroutine connects to a WebSocket server at a given ``uri``. - - It yields a :class:`WebSocketClientProtocol` which can then be used to - send and receive messages. - - On Python ≥ 3.5, :func:`connect` can be used as a asynchronous - context manager. In that case, the connection is closed when exiting - the context. - - :func:`connect` is a wrapper around the event loop's - :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown - keyword arguments are passed to - :meth:`~asyncio.BaseEventLoop.create_connection`. - - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enforce some TLS settings. When - connecting to a ``wss://`` URI, if this argument isn't provided - explicitly, it's set to ``True``, which means Python's default - :class:`~ssl.SSLContext` is used. - - The behavior of the ``timeout``, ``max_size``, and ``max_queue``, - ``read_limit``, and ``write_limit`` optional arguments is described - in the documentation of - :class:`~websockets.protocol.WebSocketCommonProtocol`. - - The ``create_protocol`` parameter allows customizing the asyncio - protocol that manages the connection. It should be a callable or - class accepting the same arguments as - :class:`WebSocketClientProtocol` and returning a - :class:`WebSocketClientProtocol` instance. It defaults to - :class:`WebSocketClientProtocol`. - - :func:`connect` also accepts the following optional arguments: - - * ``origin`` sets the Origin HTTP header - * ``extensions`` is a list of supported extensions in order of - decreasing preference - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference - * ``extra_headers`` sets additional HTTP request headers – it can be a - mapping or an iterable of (name, value) pairs - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression - - :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` - is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the - opening handshake fails. - - """ if loop is None: loop = asyncio.get_event_loop() @@ -415,11 +410,10 @@ def __await__(self): if sys.version_info[:2] <= (3, 4): # pragma: no cover - import functools - @asyncio.coroutine - @functools.wraps(Connect.__init__) def connect(*args, **kwds): return Connect(*args, **kwds).__await__() + connect.__doc__ = Connect.__doc__ + else: connect = Connect diff --git a/websockets/server.py b/websockets/server.py index 277218d99..135bd5ce5 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -584,6 +584,71 @@ def wait_closed(self): class Serve: + """ + This coroutine creates, starts, and return a :class:`WebSocketServer`. + + :func:`serve` is a wrapper around the event loop's + :meth:`~asyncio.AbstractEventLoop.create_server` method. Internally, the + function creates and starts a :class:`~asyncio.Server` object by calling + :meth:`~asyncio.AbstractEventLoop.create_server`. The + :class:`WebSocketServer` keeps a reference to this object. + + The returned :class:`WebSocketServer` and its resources can be cleaned up + by calling its :meth:`~websockets.server.WebSocketServer.close` and + :meth:`~websockets.server.WebSocketServer.wait_closed` methods. + + On Python ≥ 3.5, :func:`serve` can also be used as an asynchronous context + manager. In this case, the server is shut down when exiting the context. + + The ``ws_handler`` argument is the WebSocket handler. It must be a + coroutine accepting two arguments: a :class:`WebSocketServerProtocol` and + the request URI. + + The ``host`` and ``port`` arguments, as well as unrecognized keyword + arguments, are passed along to + :meth:`~asyncio.AbstractEventLoop.create_server`. For example, you can set + the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. + + The ``create_protocol`` parameter allows customizing the asyncio protocol + that manages the connection. It should be a callable or class accepting + the same arguments as :class:`WebSocketServerProtocol` and returning a + :class:`WebSocketServerProtocol` instance. It defaults to + :class:`WebSocketServerProtocol`. + + The behavior of the ``timeout``, ``max_size``, and ``max_queue``, + ``read_limit``, and ``write_limit`` optional arguments is described in the + documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. + + :func:`serve` also accepts the following optional arguments: + + * ``origins`` defines acceptable Origin HTTP headers — include ``''`` if + the lack of an origin is acceptable + * ``extensions`` is a list of supported extensions in order of + decreasing preference + * ``subprotocols`` is a list of supported subprotocols in order of + decreasing preference + * ``extra_headers`` sets additional HTTP response headers — it can be a + mapping, an iterable of (name, value) pairs, or a callable taking the + request path and headers in arguments. + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression + + Whenever a client connects, the server accepts the connection, creates a + :class:`WebSocketServerProtocol`, performs the opening handshake, and + delegates to the WebSocket handler. Once the handler completes, the server + performs the closing handshake and closes the connection. + + Since there's no useful way to propagate exceptions triggered in handlers, + they're sent to the ``'websockets.server'`` logger instead. Debugging is + much easier if you configure logging to print them:: + + import logging + logger = logging.getLogger('websockets.server') + logger.setLevel(logging.ERROR) + logger.addHandler(logging.StreamHandler()) + + """ def __init__(self, ws_handler, host=None, port=None, *, path=None, create_protocol=None, @@ -592,77 +657,6 @@ def __init__(self, ws_handler, host=None, port=None, *, loop=None, legacy_recv=False, klass=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds): - """ - Create, start, and return a :class:`WebSocketServer` object. - - :func:`serve` is a wrapper around the event loop's - :meth:`~asyncio.AbstractEventLoop.create_server` method. Internally, - the function creates and starts a :class:`~asyncio.Server` object by - calling :meth:`~asyncio.AbstractEventLoop.create_server`. The - :class:`WebSocketServer` keeps a reference to this object. - - The returned :class:`WebSocketServer` and its resources can be - cleaned up by calling its - :meth:`~websockets.server.WebSocketServer.close` and - :meth:`~websockets.server.WebSocketServer.wait_closed` methods. - - On Python ≥ 3.5, :func:`serve` can also be used as an asynchronous - context manager. In this case, the server is shut down when exiting - the context. - - The ``ws_handler`` argument is the WebSocket handler. It must be a - coroutine accepting two arguments: a :class:`WebSocketServerProtocol` - and the request URI. - - The ``host`` and ``port`` arguments, as well as unrecognized keyword - arguments, are passed along to - :meth:`~asyncio.AbstractEventLoop.create_server`. For example, you - can set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to - enable TLS. - - The ``create_protocol`` parameter allows customizing the asyncio - protocol that manages the connection. It should be a callable or - class accepting the same arguments as - :class:`WebSocketServerProtocol` and returning a - :class:`WebSocketServerProtocol` instance. It defaults to - :class:`WebSocketServerProtocol`. - - The behavior of the ``timeout``, ``max_size``, and ``max_queue``, - ``read_limit``, and ``write_limit`` optional arguments is described - in the documentation of - :class:`~websockets.protocol.WebSocketCommonProtocol`. - - :func:`serve` also accepts the following optional arguments: - - * ``origins`` defines acceptable Origin HTTP headers — include - ``''`` if the lack of an origin is acceptable - * ``extensions`` is a list of supported extensions in order of - decreasing preference - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference - * ``extra_headers`` sets additional HTTP response headers — it can be a - mapping, an iterable of (name, value) pairs, or a callable taking the - request path and headers in arguments. - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression - - Whenever a client connects, the server accepts the connection, - creates a :class:`WebSocketServerProtocol`, performs the opening - handshake, and delegates to the WebSocket handler. Once the handler - completes, the server performs the closing handshake and closes the - connection. - - Since there's no useful way to propagate exceptions triggered in - handlers, they're sent to the ``'websockets.server'`` logger instead. - Debugging is much easier if you configure logging to print them:: - - import logging - logger = logging.getLogger('websockets.server') - logger.setLevel(logging.ERROR) - logger.addHandler(logging.StreamHandler()) - - """ # Backwards-compatibility: create_protocol used to be called klass. # In the unlikely event that both are specified, klass is ignored. if create_protocol is None: @@ -743,11 +737,10 @@ def unix_serve(ws_handler, path, **kwargs): if sys.version_info[:2] <= (3, 4): # pragma: no cover - import functools - @asyncio.coroutine - @functools.wraps(Serve.__init__) def serve(*args, **kwds): return Serve(*args, **kwds).__await__() + serve.__doc__ = Serve.__doc__ + else: serve = Serve From c8057350a08d697a669fe875f744e398f8f54d74 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 1 Nov 2017 21:25:39 +0100 Subject: [PATCH 0360/1539] Increase symmetry between connect and serve. --- websockets/client.py | 6 +++--- websockets/server.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index a04fae8f4..4290e007b 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -383,12 +383,11 @@ def __init__(self, uri, *, @asyncio.coroutine def __aenter__(self): - self.websocket = yield from self - return self.websocket + return (yield from self) @asyncio.coroutine def __aexit__(self, exc_type, exc_value, traceback): - yield from self.websocket.close() + yield from self.ws_client.close() def __await__(self): transport, protocol = yield from self._creating_connection @@ -404,6 +403,7 @@ def __await__(self): yield from protocol.close_connection(after_handshake=False) raise + self.ws_client = protocol return protocol __iter__ = __await__ diff --git a/websockets/server.py b/websockets/server.py index 135bd5ce5..3ca6d2066 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -714,7 +714,6 @@ def __aexit__(self, exc_type, exc_value, traceback): def __await__(self): server = yield from self._creating_server self.ws_server.wrap(server) - return self.ws_server __iter__ = __await__ From 2599b8a3df918a3718422e7ccebb3f316ef602b0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 1 Nov 2017 21:41:36 +0100 Subject: [PATCH 0361/1539] Workaround for https://github.com/travis-ci/travis-ci/issues/8552. --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index db6596e76..389349a4e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,6 +19,7 @@ matrix: install: # Python 3 is needed to run cibuildwheel for websockets. - if [ "${TRAVIS_OS_NAME:-}" == "osx" ]; then + brew update; brew install python3; fi # Install cibuildwheel using pip3 to make sure Python 3 is used. From 11a16d9799cfeb899c362660ed4a98233e3985b5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 1 Nov 2017 19:12:10 +0100 Subject: [PATCH 0362/1539] Add WebSocketServer.sockets. Also refactor tests to bind to an arbitrary available port instead of hardcoding port 8642. --- docs/api.rst | 1 + docs/changelog.rst | 2 + websockets/py35/_test_client_server.py | 16 +-- websockets/py36/_test_client_server.py | 21 +-- websockets/server.py | 10 ++ websockets/test_client_server.py | 186 ++++++++++++++++--------- 6 files changed, 149 insertions(+), 87 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 59463f032..df68764c3 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -41,6 +41,7 @@ Server .. automethod:: close() .. automethod:: wait_closed() + .. autoattribute:: sockets .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None) diff --git a/docs/changelog.rst b/docs/changelog.rst index e27a5fa7f..a494043b1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -32,6 +32,8 @@ Also: * Added :func:`~websockets.server.unix_serve` for listening on Unix sockets. +* Added the :attr:`~websockets.server.WebSocketServer.sockets` attribute. + * Reorganized and extended documentation. * Aborted connections if they don't close within the configured ``timeout``. diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py index 1e69a8675..e9d7a6221 100644 --- a/websockets/py35/_test_client_server.py +++ b/websockets/py35/_test_client_server.py @@ -5,7 +5,7 @@ from ..client import * from ..server import * -from ..test_client_server import handler +from ..test_client_server import get_server_uri, handler class ContextManagerTests(unittest.TestCase): @@ -18,24 +18,24 @@ def tearDown(self): self.loop.close() def test_client(self): - server = serve(handler, 'localhost', 8642) - self.server = self.loop.run_until_complete(server) + start_server = serve(handler, 'localhost', 0) + server = self.loop.run_until_complete(start_server) async def run_client(): - async with connect('ws://localhost:8642/') as client: + async with connect(get_server_uri(server)) as client: await client.send("Hello!") reply = await client.recv() self.assertEqual(reply, "Hello!") self.loop.run_until_complete(run_client()) - self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) + server.close() + self.loop.run_until_complete(server.wait_closed()) def test_server(self): async def run_server(): - async with serve(handler, 'localhost', 8642): - client = await connect('ws://localhost:8642/') + async with serve(handler, 'localhost', 0) as server: + client = await connect(get_server_uri(server)) await client.send("Hello!") reply = await client.recv() self.assertEqual(reply, "Hello!") diff --git a/websockets/py36/_test_client_server.py b/websockets/py36/_test_client_server.py index cfa2760fe..0bd0d8938 100644 --- a/websockets/py36/_test_client_server.py +++ b/websockets/py36/_test_client_server.py @@ -7,6 +7,7 @@ from ..client import * from ..exceptions import ConnectionClosed from ..server import * +from ..test_client_server import get_server_uri # Fail at import time, not just at run time, to prevent test @@ -36,14 +37,14 @@ async def handler(ws, path): for message in MESSAGES: await ws.send(message) - server = serve(handler, 'localhost', 8642) - self.server = self.loop.run_until_complete(server) + start_server = serve(handler, 'localhost', 0) + server = self.loop.run_until_complete(start_server) messages = [] async def run_client(): nonlocal messages - async with connect('ws://localhost:8642/') as ws: + async with connect(get_server_uri(server)) as ws: async for message in ws: messages.append(message) @@ -51,8 +52,8 @@ async def run_client(): self.assertEqual(messages, MESSAGES) - self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) + server.close() + self.loop.run_until_complete(server.wait_closed()) def test_iterate_on_messages_exit_not_ok(self): @@ -61,14 +62,14 @@ async def handler(ws, path): await ws.send(message) await ws.close(1001) - server = serve(handler, 'localhost', 8642) - self.server = self.loop.run_until_complete(server) + start_server = serve(handler, 'localhost', 0) + server = self.loop.run_until_complete(start_server) messages = [] async def run_client(): nonlocal messages - async with connect('ws://localhost:8642/') as ws: + async with connect(get_server_uri(server)) as ws: async for message in ws: messages.append(message) @@ -77,5 +78,5 @@ async def run_client(): self.assertEqual(messages, MESSAGES) - self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) + server.close() + self.loop.run_until_complete(server.wait_closed()) diff --git a/websockets/server.py b/websockets/server.py index 3ca6d2066..c0aa9e9f8 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -582,6 +582,16 @@ def wait_closed(self): loop=self.loop) yield from self.server.wait_closed() + @property + def sockets(self): + """ + List of :class:`~socket.socket` objects the server is listening to. + + ``None`` if the server is closed. + + """ + return self.server.sockets + class Serve: """ diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 6499fd0ae..291827714 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -3,6 +3,7 @@ import functools import logging import os.path +import random import socket import ssl import sys @@ -101,6 +102,33 @@ def with_client(*args, **kwds): return with_manager(temp_test_client, *args, **kwds) +def get_server_uri(server, secure=False, resource_name='/'): + """ + Return a WebSocket URI for connecting to the given server. + + """ + proto = 'wss' if secure else 'ws' + + # Pick a random socket in order to test both IPv4 and IPv6 on systems + # where both are available. Randomizing tests is usually a bad idea. If + # needed, either use the first socket, or test separately IPv4 and IPv6. + server_socket = random.choice(server.sockets) + + # That case + if server_socket.family == socket.AF_INET6: # pragma: no cover + host, port = server_socket.getsockname()[:2] + host = '[{}]'.format(host) + elif server_socket.family == socket.AF_INET: + host, port = server_socket.getsockname() + elif server_socket.family == socket.AF_UNIX: + # The host and port are ignored when connecting to a Unix socket. + host, port = 'localhost', 0 + else: # pragma: no cover + raise ValueError("Expected an IPv6, IPv4, or Unix socket") + + return '{}://{}:{}{}'.format(proto, host, port, resource_name) + + class UnauthorizedServerProtocol(WebSocketServerProtocol): @asyncio.coroutine @@ -187,15 +215,16 @@ def run_loop_once(self): def start_server(self, **kwds): # Don't enable compression by default in tests. kwds.setdefault('compression', None) - server = serve(handler, 'localhost', 8642, **kwds) - self.server = self.loop.run_until_complete(server) + start_server = serve(handler, 'localhost', 0, **kwds) + self.server = self.loop.run_until_complete(start_server) - def start_client(self, path='', **kwds): + def start_client(self, resource_name='/', **kwds): # Don't enable compression by default in tests. kwds.setdefault('compression', None) - proto = 'ws' if kwds.get('ssl') is None else 'wss' - client = connect(proto + '://localhost:8642/' + path, **kwds) - self.client = self.loop.run_until_complete(client) + secure = kwds.get('ssl') is not None + server_uri = get_server_uri(self.server, secure, resource_name) + start_client = connect(server_uri, **kwds) + self.client = self.loop.run_until_complete(start_client) def stop_client(self): try: @@ -265,24 +294,28 @@ def send(self, *args, **kwargs): self.used_for_write = True return super().send(*args, **kwargs) - sock = TrackedSocket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(('localhost', 8642)) - server_hostname = 'localhost' if self.secure else None + server_socket = [ + s for s in self.server.sockets if s.family == socket.AF_INET][0] + client_socket = TrackedSocket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.connect(server_socket.getsockname()) try: - self.assertFalse(sock.used_for_read) - self.assertFalse(sock.used_for_write) + self.assertFalse(client_socket.used_for_read) + self.assertFalse(client_socket.used_for_write) - with self.temp_client(sock=sock, server_hostname=server_hostname): + with self.temp_client( + sock=client_socket, + server_hostname='localhost' if self.secure else None, + ): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - self.assertTrue(sock.used_for_read) - self.assertTrue(sock.used_for_write) + self.assertTrue(client_socket.used_for_read) + self.assertTrue(client_socket.used_for_write) finally: - sock.close() + client_socket.close() @unittest.skipUnless( hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') @@ -294,30 +327,36 @@ def test_unix_socket(self): unix_server = unix_serve(handler, path) self.server = self.loop.run_until_complete(unix_server) - sock = socket.socket(socket.AF_UNIX) - sock.connect(path) + client_socket = socket.socket(socket.AF_UNIX) + client_socket.connect(path) try: - with self.temp_client(sock=sock): + with self.temp_client(sock=client_socket): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") finally: - sock.close() + client_socket.close() self.stop_server() @with_server() - @with_client('attributes') + @with_client('/attributes') def test_protocol_attributes(self): - expected_attrs = ('localhost', 8642, self.secure) + # The test could be connecting with IPv6 or IPv4. + expected_client_attrs = [ + server_socket.getsockname()[:2] + (self.secure,) + for server_socket in self.server.sockets + ] client_attrs = (self.client.host, self.client.port, self.client.secure) - self.assertEqual(client_attrs, expected_attrs) + self.assertIn(client_attrs, expected_client_attrs) + + expected_server_attrs = ('localhost', 0, self.secure) server_attrs = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_attrs, repr(expected_attrs)) + self.assertEqual(server_attrs, repr(expected_server_attrs)) @with_server() - @with_client('path') + @with_client('/path') def test_protocol_path(self): client_path = self.client.path self.assertEqual(client_path, '/path') @@ -325,7 +364,7 @@ def test_protocol_path(self): self.assertEqual(server_path, '/path') @with_server() - @with_client('headers') + @with_client('/headers') def test_protocol_headers(self): client_req = self.client.request_headers client_resp = self.client.response_headers @@ -337,7 +376,7 @@ def test_protocol_headers(self): self.assertEqual(server_resp, str(client_resp)) @with_server() - @with_client('raw_headers') + @with_client('/raw_headers') def test_protocol_raw_headers(self): client_req = self.client.raw_request_headers client_resp = self.client.raw_response_headers @@ -349,21 +388,21 @@ def test_protocol_raw_headers(self): self.assertEqual(server_resp, repr(client_resp)) @with_server() - @with_client('raw_headers', extra_headers={'X-Spam': 'Eggs'}) + @with_client('/raw_headers', extra_headers={'X-Spam': 'Eggs'}) def test_protocol_custom_request_headers_dict(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')]) + @with_client('/raw_headers', extra_headers=[('X-Spam', 'Eggs')]) def test_protocol_custom_request_headers_list(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client('raw_headers', extra_headers=[('User-Agent', 'Eggs')]) + @with_client('/raw_headers', extra_headers=[('User-Agent', 'Eggs')]) def test_protocol_custom_request_user_agent(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) @@ -371,35 +410,35 @@ def test_protocol_custom_request_user_agent(self): self.assertIn("('User-Agent', 'Eggs')", req_headers) @with_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) - @with_client('raw_headers') + @with_client('/raw_headers') def test_protocol_custom_response_headers_callable_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]) - @with_client('raw_headers') + @with_client('/raw_headers') def test_protocol_custom_response_headers_callable_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers={'X-Spam': 'Eggs'}) - @with_client('raw_headers') + @with_client('/raw_headers') def test_protocol_custom_response_headers_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=[('X-Spam', 'Eggs')]) - @with_client('raw_headers') + @with_client('/raw_headers') def test_protocol_custom_response_headers_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=[('Server', 'Eggs')]) - @with_client('raw_headers') + @with_client('/raw_headers') def test_protocol_custom_response_user_agent(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) @@ -411,12 +450,15 @@ def test_protocol_custom_response_user_agent(self): def test_custom_protocol_http_request(self): # One URL returns an HTTP response. + # Set url to 'https?://:/__health__/'. + url = get_server_uri( + self.server, resource_name='/__health__/', secure=self.secure) + url = url.replace('ws', 'http') + if self.secure: - url = 'https://localhost:8642/__health__/' open_health_check = functools.partial( urllib.request.urlopen, url, context=self.client_context) else: - url = 'http://localhost:8642/__health__/' open_health_check = functools.partial( urllib.request.urlopen, url) @@ -457,50 +499,50 @@ def test_server_create_protocol_over_klass(self): self.assert_client_raises_code(403) @with_server() - @with_client('path', create_protocol=FooClientProtocol) + @with_client('/path', create_protocol=FooClientProtocol) def test_client_create_protocol(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client('path', create_protocol=( + @with_client('/path', create_protocol=( lambda *args, **kwargs: FooClientProtocol(*args, **kwargs))) def test_client_create_protocol_function(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client('path', klass=FooClientProtocol) + @with_client('/path', klass=FooClientProtocol) def test_client_klass(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client('path', create_protocol=BarClientProtocol, + @with_client('/path', create_protocol=BarClientProtocol, klass=FooClientProtocol) def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) @with_server() - @with_client('extensions') + @with_client('/extensions') def test_no_extension(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @with_client('extensions', extensions=[ClientNoOpExtensionFactory()]) + @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) def test_extension(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([NoOpExtension()])) self.assertEqual(repr(self.client.extensions), repr([NoOpExtension()])) @with_server() - @with_client('extensions', extensions=[ClientNoOpExtensionFactory()]) + @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) def test_extension_not_accepted(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @with_client('extensions') + @with_client('/extensions') def test_extension_not_requested(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) @@ -510,7 +552,7 @@ def test_extension_not_requested(self): def test_extension_client_rejection(self): with self.assertRaises(NegotiationError): self.start_client( - 'extensions', + '/extensions', extensions=[ClientNoOpExtensionFactory()], ) @@ -522,7 +564,7 @@ def test_extension_client_rejection(self): ], ) @with_client( - 'extensions', + '/extensions', extensions=[ ClientPerMessageDeflateFactory(), ], @@ -538,7 +580,7 @@ def test_extension_no_match_then_match(self): ])) @with_server(extensions=[ServerPerMessageDeflateFactory()]) - @with_client('extensions', extensions=[ClientNoOpExtensionFactory()]) + @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) def test_extension_mismatch(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) @@ -551,7 +593,7 @@ def test_extension_mismatch(self): ], ) @with_client( - 'extensions', + '/extensions', extensions=[ ClientPerMessageDeflateFactory(), ClientNoOpExtensionFactory(), @@ -576,7 +618,7 @@ def test_extensions_error(self, _process_extensions): with self.assertRaises(NegotiationError): self.start_client( - 'extensions', + '/extensions', extensions=[ClientPerMessageDeflateFactory()], ) @@ -586,10 +628,10 @@ def test_extensions_error_no_extensions(self, _process_extensions): _process_extensions.return_value = 'x-no-op', [NoOpExtension()] with self.assertRaises(InvalidHandshake): - self.start_client('extensions') + self.start_client('/extensions') @with_server(compression='deflate') - @with_client('extensions', compression='deflate') + @with_client('/extensions', compression='deflate') def test_compression_deflate(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([ @@ -609,7 +651,7 @@ def test_compression_deflate(self): compression='deflate', # overridden by explicit config ) @with_client( - 'extensions', + '/extensions', extensions=[ ClientPerMessageDeflateFactory( server_no_context_takeover=True, @@ -627,42 +669,45 @@ def test_compression_deflate_and_explicit_config(self): PerMessageDeflate(True, True, 10, 12), ])) - def test_compression_unsupported(self): + def test_compression_unsupported_server(self): with self.assertRaises(ValueError): self.loop.run_until_complete(self.start_server(compression='xz')) + + @with_server() + def test_compression_unsupported_client(self): with self.assertRaises(ValueError): self.loop.run_until_complete(self.start_client(compression='xz')) @with_server() - @with_client('subprotocol') + @with_client('/subprotocol') def test_no_subprotocol(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=['superchat', 'chat']) - @with_client('subprotocol', subprotocols=['otherchat', 'chat']) + @with_client('/subprotocol', subprotocols=['otherchat', 'chat']) def test_subprotocol(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr('chat')) self.assertEqual(self.client.subprotocol, 'chat') @with_server(subprotocols=['superchat']) - @with_client('subprotocol', subprotocols=['otherchat']) + @with_client('/subprotocol', subprotocols=['otherchat']) def test_subprotocol_not_accepted(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server() - @with_client('subprotocol', subprotocols=['otherchat', 'chat']) + @with_client('/subprotocol', subprotocols=['otherchat', 'chat']) def test_subprotocol_not_offered(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=['superchat', 'chat']) - @with_client('subprotocol') + @with_client('/subprotocol') def test_subprotocol_not_requested(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) @@ -674,7 +719,7 @@ def test_subprotocol_error(self, _process_subprotocol): _process_subprotocol.return_value = 'superchat' with self.assertRaises(NegotiationError): - self.start_client('subprotocol', subprotocols=['otherchat']) + self.start_client('/subprotocol', subprotocols=['otherchat']) self.run_loop_once() @with_server(subprotocols=['superchat']) @@ -683,7 +728,7 @@ def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = 'superchat' with self.assertRaises(InvalidHandshake): - self.start_client('subprotocol') + self.start_client('/subprotocol') self.run_loop_once() @with_server(subprotocols=['superchat', 'chat']) @@ -693,7 +738,7 @@ def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): with self.assertRaises(InvalidHandshake): self.start_client( - 'subprotocol', subprotocols=['superchat', 'chat']) + '/subprotocol', subprotocols=['superchat', 'chat']) self.run_loop_once() @with_server() @@ -891,7 +936,7 @@ def start_server(self, **kwds): kwds.setdefault('ssl', self.server_context) super().start_server(**kwds) - def start_client(self, path='', **kwds): + def start_client(self, path='/', **kwds): kwds.setdefault('ssl', self.client_context) super().start_client(path, **kwds) @@ -901,7 +946,10 @@ def start_client(self, path='', **kwds): @with_server() def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): - client = connect('ws://localhost:8642/', ssl=self.client_context) + client = connect( + get_server_uri(self.server, secure=False), + ssl=self.client_context, + ) # With Python ≥ 3.5, the exception is raised by connect() even # before awaiting. However, with Python 3.4 the exception is # raised only when awaiting. @@ -919,9 +967,9 @@ def tearDown(self): def test_checking_origin_succeeds(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 8642, origins=['http://localhost'])) + serve(handler, 'localhost', 0, origins=['http://localhost'])) client = self.loop.run_until_complete( - connect('ws://localhost:8642/', origin='http://localhost')) + connect(get_server_uri(server), origin='http://localhost')) self.loop.run_until_complete(client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") @@ -932,19 +980,19 @@ def test_checking_origin_succeeds(self): def test_checking_origin_fails(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 8642, origins=['http://localhost'])) + serve(handler, 'localhost', 0, origins=['http://localhost'])) with self.assertRaisesRegex(InvalidHandshake, "Status code not 101: 403"): self.loop.run_until_complete( - connect('ws://localhost:8642/', origin='http://otherhost')) + connect(get_server_uri(server), origin='http://otherhost')) server.close() self.loop.run_until_complete(server.wait_closed()) def test_checking_lack_of_origin_succeeds(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 8642, origins=[''])) - client = self.loop.run_until_complete(connect('ws://localhost:8642/')) + serve(handler, 'localhost', 0, origins=[''])) + client = self.loop.run_until_complete(connect(get_server_uri(server))) self.loop.run_until_complete(client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") From 93c7bc1446f1f9e99b0919dceacf84fd92832811 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 2 Nov 2017 07:50:44 +0100 Subject: [PATCH 0363/1539] Bump version number. --- docs/changelog.rst | 5 ++++- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index a494043b1..de9911d11 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,14 @@ Changelog .. currentmodule:: websockets -4.0 +4.1 ... *In development* +4.0 +... + .. warning:: **Version 4.0 enables compression with the permessage-deflate extension.** diff --git a/docs/conf.py b/docs/conf.py index fd886a842..f68a0de28 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '3.4' +version = '4.0' # The full version, including alpha/beta/rc tags. -release = '3.4' +release = '4.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index a0e73377d..14058a263 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '3.4' +version = '4.0' From 4e67f829cfe4e9c23046415d6cd38cd2825cd6b7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 2 Nov 2017 21:05:12 +0100 Subject: [PATCH 0364/1539] Remove unnecessary lines in MANIFEST.in. `packages` takes care of this in setup.py. --- MANIFEST.in | 3 --- 1 file changed, 3 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 7e96fd0d7..d0cff2af7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,2 @@ include LICENSE include websockets/test_localhost.pem - -graft websockets/py35 -graft websockets/py36 From cb9b268c60a96f5bdcc8988587e0460ca4edbcdc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 2 Nov 2017 21:07:59 +0100 Subject: [PATCH 0365/1539] Attempt to fix wheel upload on Travis CI. --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 389349a4e..3978574ab 100644 --- a/.travis.yml +++ b/.travis.yml @@ -31,6 +31,6 @@ script: - cibuildwheel --output-dir wheelhouse # Upload to PyPI on tags - if [ "${TRAVIS_TAG:-}" != "" ]; then - python -m pip install twine && - python -m twine upload --skip-existing wheelhouse/*; + pip3 install twine; + twine upload --skip-existing wheelhouse/*; fi From f730fe091fdb9aab8b608e710d260efdf4e8a077 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 2 Nov 2017 21:11:32 +0100 Subject: [PATCH 0366/1539] Test enum identity instead of equality. The docs say this is recommended. --- websockets/protocol.py | 20 ++++++++++---------- websockets/test_protocol.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index fb8ad0d75..3d769c547 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -219,7 +219,7 @@ def connection_open(self): """ # 4.1. The WebSocket Connection is Established. - assert self.state == State.CONNECTING + assert self.state is State.CONNECTING self.state = State.OPEN # Start the task that receives incoming WebSocket messages. self.transfer_data_task = asyncio_ensure_future( @@ -268,7 +268,7 @@ def open(self): .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp """ - return self.state == State.OPEN + return self.state is State.OPEN @asyncio.coroutine def recv(self): @@ -360,7 +360,7 @@ def close(self, code=1000, reason=''): ``code`` must be an :class:`int` and ``reason`` a :class:`str`. """ - if self.state == State.OPEN: + if self.state is State.OPEN: # 7.1.2. Start the WebSocket Closing Handshake # 7.1.3. The WebSocket Closing Handshake is Started frame_data = serialize_close(code, reason) @@ -463,13 +463,13 @@ def ensure_open(self): """ # Handle cases from most common to least common for performance. - if self.state == State.OPEN: + if self.state is State.OPEN: return - if self.state == State.CLOSED: + if self.state is State.CLOSED: raise ConnectionClosed(self.close_code, self.close_reason) - if self.state == State.CLOSING: + if self.state is State.CLOSING: # If we started the closing handshake, wait for its completion to # get the proper close code and status. self.close_connection_task # will complete within 4 or 5 * timeout after calling close(). @@ -480,7 +480,7 @@ def ensure_open(self): raise ConnectionClosed(self.close_code, self.close_reason) # Control may only reach this point in buggy third-party subclasses. - assert self.state == State.CONNECTING + assert self.state is State.CONNECTING raise InvalidState("WebSocket connection isn't established yet") @asyncio.coroutine @@ -597,7 +597,7 @@ def read_data_frame(self, max_size): # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason self.close_code, self.close_reason = code, reason - if self.state == State.OPEN: + if self.state is State.OPEN: # 7.1.3. The WebSocket Closing Handshake is Started yield from self.write_frame(OP_CLOSE, frame.data) return @@ -637,7 +637,7 @@ def read_frame(self, max_size): @asyncio.coroutine def write_frame(self, opcode, data=b''): # Defensive assertion for protocol compliance. - if self.state != State.OPEN: # pragma: no cover + if self.state is not State.OPEN: # pragma: no cover raise InvalidState("Cannot write to a WebSocket " "in the {} state".format(self.state.name)) @@ -784,7 +784,7 @@ def fail_connection(self, code=1011, reason=''): self.side, code, reason, ) # Don't send a close frame if the connection is broken. - if self.state == State.OPEN and code != 1006: + if self.state is State.OPEN and code != 1006: frame_data = serialize_close(code, reason) yield from self.write_frame(OP_CLOSE, frame_data) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 544a38158..bf64b9d11 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -72,7 +72,7 @@ def close(self): def abort(self): # Change this to an `if` if tests call abort() multiple times. - assert self.protocol.state != State.CLOSED + assert self.protocol.state is not State.CLOSED self.loop.call_soon(self.protocol.connection_lost, None) From a6211d48c57a0886ee1eefa87651c99bd08d4fd2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 2 Nov 2017 21:14:59 +0100 Subject: [PATCH 0367/1539] Bump version number. --- websockets/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/version.py b/websockets/version.py index 14058a263..dc4e88832 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '4.0' +version = '4.0.1' From 05907cd7e390617d6c80f2b2a0bf222e622dd641 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 2 Nov 2017 22:58:58 +0100 Subject: [PATCH 0368/1539] Attempt #2 to fix wheel upload on Travis CI. --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 3978574ab..d964a1343 100644 --- a/.travis.yml +++ b/.travis.yml @@ -32,5 +32,5 @@ script: # Upload to PyPI on tags - if [ "${TRAVIS_TAG:-}" != "" ]; then pip3 install twine; - twine upload --skip-existing wheelhouse/*; + python3 -m twine upload --skip-existing wheelhouse/*; fi From 136d9bc9a917f7f5fdde4001e801b591ee1b4a0b Mon Sep 17 00:00:00 2001 From: Anton Lakotka Date: Tue, 7 Nov 2017 11:06:33 +0100 Subject: [PATCH 0369/1539] Fixed infinite loop in "Graceful Shutdown" Examples when client closed the connection. Replaced "pass" statement by "break" inside echo handler loop. If client closes connection it causes infinite loop and took 100% CPU. "websocket.recv()" always return throws "websockets.ConnectionClosed" so pass statement just continue loop over and over again. --- example/oldshutdown.py | 2 +- example/shutdown.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/oldshutdown.py b/example/oldshutdown.py index b95fa91a3..6618aade0 100644 --- a/example/oldshutdown.py +++ b/example/oldshutdown.py @@ -9,7 +9,7 @@ async def echo(websocket, path): try: msg = await websocket.recv() except websockets.ConnectionClosed: - pass + break else: await websocket.send(msg) diff --git a/example/shutdown.py b/example/shutdown.py index 1f686d160..663dbd58a 100644 --- a/example/shutdown.py +++ b/example/shutdown.py @@ -9,7 +9,7 @@ async def echo(websocket, path): try: msg = await websocket.recv() except websockets.ConnectionClosed: - pass + break else: await websocket.send(msg) From 9f566c5d12eae1258aa6f1991fe8cf0a7d390f97 Mon Sep 17 00:00:00 2001 From: MysterialPy Date: Thu, 14 Dec 2017 23:41:48 +1000 Subject: [PATCH 0370/1539] Fix Syntax error thrown in 3.7 Maintains backward compat. --- websockets/compatibility.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/websockets/compatibility.py b/websockets/compatibility.py index c7e08126f..c3bb61321 100644 --- a/websockets/compatibility.py +++ b/websockets/compatibility.py @@ -9,10 +9,10 @@ # Replace with BaseEventLoop.create_task when dropping Python < 3.4.2. -try: # pragma: no cover - asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5 -except AttributeError: # pragma: no cover - asyncio_ensure_future = asyncio.async # Python < 3.5 +try: # pragma: no cover + asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5 +except AttributeError: # pragma: no cover + asyncio_ensure_future = getattr(asyncio, 'async') # Python < 3.5 try: # pragma: no cover # Python ≥ 3.5 From d4a6a24fcf9786b1917668bfaee946e85bc22478 Mon Sep 17 00:00:00 2001 From: pv2b Date: Wed, 3 Jan 2018 16:06:00 +0100 Subject: [PATCH 0371/1539] Fix typo in README.rst courountine -> coroutine --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index d05a5bb14..5d38481d8 100644 --- a/README.rst +++ b/README.rst @@ -114,7 +114,7 @@ Why shouldn't I use ``websockets``? ----------------------------------- * If you prefer callbacks over coroutines: ``websockets`` was created to - provide the best corountine-based API to manage WebSocket connections in + provide the best coroutine-based API to manage WebSocket connections in Python. Pick another library for a callback-based API. * If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol From f7331b04c524f2ebfd2371a1a59852d3d3439779 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Mon, 12 Feb 2018 02:45:49 +1000 Subject: [PATCH 0372/1539] Add zip_safe=True to setup.py (#342) Add `zip_safe=True` to setup.py to tell easy_install and setuptools that the code should not be analyzed before installing. This avoids the problem caused by evaluating code in the websockets/py36/ directory when running under Python 3.5. Fixes #341. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index df1b516dd..e483dbecf 100644 --- a/setup.py +++ b/setup.py @@ -60,4 +60,5 @@ packages=packages, ext_modules=ext_modules, include_package_data=True, + zip_safe=True, ) From 13eb4a09815b7edf3ac6270cbdf131b89b4ccc9d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 12 Feb 2018 21:46:17 +0100 Subject: [PATCH 0373/1539] =?UTF-8?q?Mention=20that=20README=20requires=20?= =?UTF-8?q?Python=20=E2=89=A5=203.6.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Python < 3.6 version is provided in the documentation. Fix #340. --- README.rst | 2 +- websockets/py36/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 5d38481d8..adc3a9210 100644 --- a/README.rst +++ b/README.rst @@ -54,7 +54,7 @@ Here's a client that says "Hello world!": asyncio.get_event_loop().run_until_complete( hello('ws://localhost:8765')) -And here's an echo server: +And here's an echo server (for Python ≥ 3.6): .. code:: python diff --git a/websockets/py36/__init__.py b/websockets/py36/__init__.py index 396f34968..b9211bf87 100644 --- a/websockets/py36/__init__.py +++ b/websockets/py36/__init__.py @@ -1,2 +1,2 @@ -# This package contains code using async iteratino added in Python 3.6. +# This package contains code using async iteration added in Python 3.6. # It cannot be imported on Python < 3.6 because it triggers syntax errors. From 99e6a4d9dbbab01f514bd84d21723ba98d6f8634 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 17:49:23 +0200 Subject: [PATCH 0374/1539] Use the latest version of cibuildwheel. It's more likely to fix issues than introduce them. --- .travis.yml | 2 +- appveyor.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index d964a1343..ca3c8dff6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,7 +23,7 @@ install: brew install python3; fi # Install cibuildwheel using pip3 to make sure Python 3 is used. - - pip3 install cibuildwheel==0.4.0 + - pip3 install --upgrade cibuildwheel # Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). - touch .cibuildwheel diff --git a/appveyor.yml b/appveyor.yml index 75b32b118..42856df38 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -6,7 +6,7 @@ environment: # Since Python 2 is still the default, invoke Python 3 explicitly. install: - - cmd: C:\Python36-x64\python.exe -m pip install cibuildwheel==0.4.0 + - cmd: C:\Python36-x64\python.exe -m pip install --upgrade cibuildwheel # Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). - cmd: touch .cibuildwheel build_script: From ad13e489a305b28d78e30460e694fa735a180890 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 18:05:04 +0200 Subject: [PATCH 0375/1539] Adapt to changes for Python 3 in Homebrew. Python 3 is now the default (yay!) --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index ca3c8dff6..da1ab8b01 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,7 +20,7 @@ install: # Python 3 is needed to run cibuildwheel for websockets. - if [ "${TRAVIS_OS_NAME:-}" == "osx" ]; then brew update; - brew install python3; + brew upgrade python; fi # Install cibuildwheel using pip3 to make sure Python 3 is used. - pip3 install --upgrade cibuildwheel From 61badeb0a87db175157429647eae5d47fd596393 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 18:05:52 +0200 Subject: [PATCH 0376/1539] Fix AppVeyor builds. --- appveyor.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 42856df38..31e48fe2d 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -4,13 +4,15 @@ environment: CIBW_TEST_COMMAND: python -m unittest websockets WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 -# Since Python 2 is still the default, invoke Python 3 explicitly. install: - - cmd: C:\Python36-x64\python.exe -m pip install --upgrade cibuildwheel +# Ensure python is Python 3. + - set PATH=C:\Python34;%PATH% + - cmd: python -m pip install --upgrade cibuildwheel # Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). - cmd: touch .cibuildwheel + build_script: - - cmd: C:\Python36-x64\python.exe -m cibuildwheel --output-dir wheelhouse + - cmd: python -m cibuildwheel --output-dir wheelhouse # Upload to PyPI on tags - ps: >- if ($env:APPVEYOR_REPO_TAG -eq "true") { From 4579c2b93eff053c624eebf44cba09f45c140f1a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 18:23:23 +0200 Subject: [PATCH 0377/1539] Change AppVeyor config to a dotfile. This is consistent with Travis config. --- appveyor.yml => .appveyor.yml | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename appveyor.yml => .appveyor.yml (100%) diff --git a/appveyor.yml b/.appveyor.yml similarity index 100% rename from appveyor.yml rename to .appveyor.yml From 38145985e37474fe5c1145731ec7022836f12a41 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 20:42:08 +0200 Subject: [PATCH 0378/1539] Make unix_serve an asynchronous context manager. Fix #355. --- docs/changelog.rst | 7 +++-- websockets/py35/_test_client_server.py | 39 +++++++++++++++++++++----- websockets/server.py | 1 - 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index de9911d11..cde0c462c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,9 @@ Changelog *In development* +* :func:`~server.unix_serve` can be used as an asynchronous context manager on + Python ≥ 3.5. + 4.0 ... @@ -33,9 +36,9 @@ Also: * :class:`~protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. -* Added :func:`~websockets.server.unix_serve` for listening on Unix sockets. +* Added :func:`~server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~websockets.server.WebSocketServer.sockets` attribute. +* Added the :attr:`~server.WebSocketServer.sockets` attribute. * Reorganized and extended documentation. diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py index e9d7a6221..ad9d83d03 100644 --- a/websockets/py35/_test_client_server.py +++ b/websockets/py35/_test_client_server.py @@ -1,9 +1,14 @@ # Tests containing Python 3.5+ syntax, extracted from test_client_server.py. import asyncio +import os +import socket +import sys +import tempfile import unittest from ..client import * +from ..protocol import State from ..server import * from ..test_client_server import get_server_uri, handler @@ -22,10 +27,12 @@ def test_client(self): server = self.loop.run_until_complete(start_server) async def run_client(): + # Use connect as an asynchronous context manager. async with connect(get_server_uri(server)) as client: - await client.send("Hello!") - reply = await client.recv() - self.assertEqual(reply, "Hello!") + self.assertEqual(client.state, State.OPEN) + + # Check that exiting the context manager closed the connection. + self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) @@ -34,10 +41,28 @@ async def run_client(): def test_server(self): async def run_server(): + # Use serve as an asynchronous context manager. async with serve(handler, 'localhost', 0) as server: - client = await connect(get_server_uri(server)) - await client.send("Hello!") - reply = await client.recv() - self.assertEqual(reply, "Hello!") + self.assertTrue(server.sockets) + + # Check that exiting the context manager closed the server. + self.assertFalse(server.sockets) self.loop.run_until_complete(run_server()) + + # Asynchronous context managers are only enabled on Python ≥ 3.5.1. + @unittest.skipIf( + sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+') + @unittest.skipUnless( + hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') + def test_unix_server(self): + async def run_server(path): + async with unix_serve(handler, path) as server: + self.assertTrue(server.sockets) + + # Check that exiting the context manager closed the server. + self.assertFalse(server.sockets) + + with tempfile.TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, 'websockets') + self.loop.run_until_complete(run_server(path)) diff --git a/websockets/server.py b/websockets/server.py index c0aa9e9f8..7243b9c94 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -729,7 +729,6 @@ def __await__(self): __iter__ = __await__ -@asyncio.coroutine def unix_serve(ws_handler, path, **kwargs): """ Similar to :func:`serve()`, but for listening on Unix sockets. From 33341afda0c1917fb4a37e3459e3c39d66668d11 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 21:03:55 +0200 Subject: [PATCH 0379/1539] Restore full test coverage. wait_closed() was accidentally covered by ContextManagerTests. --- websockets/test_client_server.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 291827714..e397558ae 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -24,6 +24,7 @@ from .handshake import build_response from .http import USER_AGENT, read_response from .server import * +from .test_protocol import MS # Avoid displaying stack traces at the ERROR logging level. @@ -48,6 +49,12 @@ def handler(ws, path): yield from ws.send(repr(ws.extensions)) elif path == '/subprotocol': yield from ws.send(repr(ws.subprotocol)) + elif path == '/slow_stop': + try: + yield from asyncio.sleep(1000 * MS) + except asyncio.CancelledError: + yield from asyncio.sleep(MS) + raise else: yield from ws.send((yield from ws.recv())) @@ -260,7 +267,9 @@ def test_basic(self): def test_server_close_while_client_connected(self): with self.temp_server(loop=self.loop): - self.start_client() + # This endpoint waits just a bit when the connection is cancelled + # in order to test that wait_closed() really waits for completion. + self.start_client('/slow_stop') with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.client.recv()) # Connection ends with 1001 going away. From 4199fc83e5ddb717b67681e4bb5ad6c6c2df4428 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 22:02:56 +0200 Subject: [PATCH 0380/1539] Allow running a single test with tox. Fix #297. --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index d23a2c217..c4123c20a 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ commands = ; Before testing with speedups, compile the extension. speedups: python setup.py --quiet build_ext --inplace - python -m unittest + python -m unittest {posargs} ; After testing with speedups, remove the extension. speedups: sh -c 'rm websockets/*.so' From f3f3fd0c19843f5e4359738ce298932e93ce2758 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 19:03:34 +0200 Subject: [PATCH 0381/1539] Disable async context manager functionality on Python 3.5. Clarify documentation about unsupported features in older Pythons. Fix #318. --- docs/changelog.rst | 6 ++-- docs/cheatsheet.rst | 4 ++- docs/deployment.rst | 5 +-- docs/intro.rst | 42 +++++++++++++++++--------- websockets/client.py | 7 +++-- websockets/py35/_test_client_server.py | 6 ++++ websockets/server.py | 5 ++- 7 files changed, 51 insertions(+), 24 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index cde0c462c..a961f317d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -9,7 +9,7 @@ Changelog *In development* * :func:`~server.unix_serve` can be used as an asynchronous context manager on - Python ≥ 3.5. + Python ≥ 3.5.1. 4.0 ... @@ -61,7 +61,7 @@ Also: For backwards compatibility, ``klass`` is still supported. * :func:`~server.serve` can be used as an asynchronous context manager on - Python ≥ 3.5. + Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with :meth:`~server.WebSocketServerProtocol.process_request()`. @@ -137,7 +137,7 @@ Also: Also: * :func:`~client.connect` can be used as an asynchronous context manager on - Python ≥ 3.5. + Python ≥ 3.5.1. * Updated documentation with ``await`` and ``async`` syntax from Python 3.5. diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index eb29ad23d..259f85d50 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -20,6 +20,8 @@ Server * Create a server with :func:`~server.serve` which is similar to asyncio's :meth:`~asyncio.AbstractEventLoop.create_server`. + * On Python ≥ 3.5.1, you can also use it as an asynchronous context manager. + * The server takes care of establishing connections, then lets the handler execute the application logic, and finally closes the connection after the handler exits normally or with an exception. @@ -34,7 +36,7 @@ Client * Create a client with :func:`~client.connect` which is similar to asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. - * On Python ≥ 3.5, you can also use it as an asynchronous context manager. + * On Python ≥ 3.5.1, you can also use it as an asynchronous context manager. * For advanced customization, you may subclass :class:`~server.WebSocketClientProtocol` and pass either this subclass or diff --git a/docs/deployment.rst b/docs/deployment.rst index c5e3dab28..beb8c7474 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -40,8 +40,9 @@ Here's a full example (Unix-only): .. literalinclude:: ../example/shutdown.py -``async``, ``await``, and asynchronous context managers aren't available in -Python < 3.5. Here's the equivalent for older Python versions: +``async`` and ``await`` were introduced in Python 3.5. websockets supports +asynchronous context managers on Python ≥ 3.5.1. Here's the equivalent for +older Python versions: .. literalinclude:: ../example/oldshutdown.py diff --git a/docs/intro.rst b/docs/intro.rst index ab77f9c50..4c5551ba5 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -3,22 +3,27 @@ Getting started .. currentmodule:: websockets -.. warning:: - - This documentation is written for Python ≥ 3.5. If you're using Python - 3.4, you will have to :ref:`adapt the code samples `. - Installation ------------ -``websockets`` requires Python ≥ 3.4. Install it with:: +Install ``websockets`` with:: pip install websockets +``websockets`` requires Python ≥ 3.4. We recommend using the latest version. + +If you're using an older version, be aware that for each minor version (3.x), +only the latest bugfix release (3.x.y) is officially supported. + +.. warning:: + + This documentation is written for Python ≥ 3.5.1. If you're using an older + Python version, you need to :ref:`adapt the code samples `. + Basic example ------------- -*This section assumes Python ≥ 3.5. For older versions, read below.* +*This section assumes Python ≥ 3.5.1. For older versions, read below.* .. _server-example: @@ -37,8 +42,9 @@ Here's a corresponding client example. .. literalinclude:: ../example/client.py -``async`` and ``await`` aren't available in Python < 3.5. Here's how to adapt -the client example for older Python versions. +``async`` and ``await`` were introduced in Python 3.5. websockets supports ++asynchronous context managers on Python ≥ 3.5.1. Here's how to adapt the +client example for older Python versions. .. literalinclude:: ../example/oldclient.py @@ -74,8 +80,8 @@ For receiving messages and passing them to a ``consumer`` coroutine:: Iteration terminates when the client disconnects. -Asynchronous iteration isn't available in Python < 3.6; here's the same code -for earlier Python versions:: +Asynchronous iteration was introduced in Python 3.6; here's the same code for +earlier Python versions:: async def consumer_handler(websocket, path): while True: @@ -153,15 +159,15 @@ answering pings, or any other behavior required by the specification. ``websockets`` handles all this under the hood so you don't have to. -.. _python-lt-35: +.. _python-lt-351: -Python < 3.5 ------------- +Python < 3.5.1 +-------------- This documentation uses the ``await`` and ``async`` syntax introduced in Python 3.5. -If you're using Python 3.4, you must substitute:: +If you're using Python < 3.5, you must substitute:: async def ... @@ -179,3 +185,9 @@ with:: yield from ... Otherwise you will encounter a :exc:`SyntaxError`. + +websockets supports asynchronous context managers only on Python ≥ 3.5.1 +because :func:`~asyncio.ensure_future` was changed to accept arbitrary +awaitables in that version. + +If you're using Python ≤ 3.5, you can't use this feature. diff --git a/websockets/client.py b/websockets/client.py index 4290e007b..2f69fcdcb 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -280,7 +280,7 @@ class Connect: It yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. - On Python ≥ 3.5, :func:`connect` can be used as a asynchronous context + On Python ≥ 3.5.1, :func:`connect` can be used as a asynchronous context manager. In that case, the connection is closed when exiting the context. :func:`connect` is a wrapper around the event loop's @@ -409,7 +409,10 @@ def __await__(self): __iter__ = __await__ -if sys.version_info[:2] <= (3, 4): # pragma: no cover +# Disable asynchronous context manager functionality only on Python < 3.5.1 +# because it doesn't exist on Python < 3.5 and asyncio.ensure_future didn't +# accept arbitrary awaitables in Python 3.5; that was fixed in Python 3.5.1. +if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover @asyncio.coroutine def connect(*args, **kwds): return Connect(*args, **kwds).__await__() diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py index ad9d83d03..d0c0cfa13 100644 --- a/websockets/py35/_test_client_server.py +++ b/websockets/py35/_test_client_server.py @@ -22,6 +22,9 @@ def setUp(self): def tearDown(self): self.loop.close() + # Asynchronous context managers are only enabled on Python ≥ 3.5.1. + @unittest.skipIf( + sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+') def test_client(self): start_server = serve(handler, 'localhost', 0) server = self.loop.run_until_complete(start_server) @@ -39,6 +42,9 @@ async def run_client(): server.close() self.loop.run_until_complete(server.wait_closed()) + # Asynchronous context managers are only enabled on Python ≥ 3.5.1. + @unittest.skipIf( + sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+') def test_server(self): async def run_server(): # Use serve as an asynchronous context manager. diff --git a/websockets/server.py b/websockets/server.py index 7243b9c94..410087bad 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -744,7 +744,10 @@ def unix_serve(ws_handler, path, **kwargs): return serve(ws_handler, path=path, **kwargs) -if sys.version_info[:2] <= (3, 4): # pragma: no cover +# Disable asynchronous context manager functionality only on Python < 3.5.1 +# because it doesn't exist on Python < 3.5 and asyncio.ensure_future didn't +# accept arbitrary awaitables in Python 3.5; that was fixed in Python 3.5.1. +if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover @asyncio.coroutine def serve(*args, **kwds): return Serve(*args, **kwds).__await__() From 35d6754ce32f36dab013e0273470be1529578317 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Apr 2018 17:42:00 +0200 Subject: [PATCH 0382/1539] Add support for HTTP Basic Auth in the client side. Fix #373. --- docs/changelog.rst | 15 ++++++++++++++- websockets/client.py | 5 ++++- websockets/http.py | 13 +++++++++++++ websockets/test_client_server.py | 19 +++++++++++++++---- websockets/test_http.py | 9 ++++++++- websockets/test_uri.py | 25 ++++++++++++++++++++----- websockets/uri.py | 13 +++++++++---- 7 files changed, 83 insertions(+), 16 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index a961f317d..745174f21 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,24 @@ Changelog .. currentmodule:: websockets -4.1 +5.0 ... *In development* +.. warning:: + + **Version 5.0 adds a ``user_info`` field to the return value of + :func:`~websockets.parse_uri`, :class:`~websockets.WebSocketURI`.** + + If you're unpacking :class:`~websockets.WebSocketURI` into four variables, + adjust your code to account for that fifth field. + +Also: + +* :func:`~client.connect()` performs HTTP Basic Auth when the URI contains + credentials. + * :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. diff --git a/websockets/client.py b/websockets/client.py index 2f69fcdcb..6922a70c2 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -16,7 +16,7 @@ build_extension_list, build_protocol_list, parse_extension_list, parse_protocol_list ) -from .http import USER_AGENT, build_headers, read_response +from .http import USER_AGENT, basic_auth_header, build_headers, read_response from .protocol import WebSocketCommonProtocol from .uri import parse_uri @@ -225,6 +225,9 @@ def handshake(self, wsuri, origin=None, available_extensions=None, else: set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port)) + if wsuri.user_info: + set_header(*basic_auth_header(*wsuri.user_info)) + if origin is not None: set_header('Origin', origin) diff --git a/websockets/http.py b/websockets/http.py index 99eef8482..3fef6d34b 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -8,6 +8,7 @@ """ import asyncio +import base64 import http.client import re import sys @@ -206,3 +207,15 @@ def build_headers(raw_headers): headers = http.client.HTTPMessage() headers._headers = raw_headers # HACK return headers + + +def basic_auth_header(username, password): + """ + Build an Authorization header for HTTP Basic Auth. + + """ + # https://tools.ietf.org/html/rfc7617#section-2 + assert ':' not in username + user_pass = '{}:{}'.format(username, password) + basic_credentials = base64.b64encode(user_pass.encode()).decode() + return ('Authorization', 'Basic ' + basic_credentials) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index e397558ae..7ee8b78d9 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -109,13 +109,15 @@ def with_client(*args, **kwds): return with_manager(temp_test_client, *args, **kwds) -def get_server_uri(server, secure=False, resource_name='/'): +def get_server_uri(server, secure=False, resource_name='/', user_info=None): """ Return a WebSocket URI for connecting to the given server. """ proto = 'wss' if secure else 'ws' + user_info = ':'.join(user_info) + '@' if user_info else '' + # Pick a random socket in order to test both IPv4 and IPv6 on systems # where both are available. Randomizing tests is usually a bad idea. If # needed, either use the first socket, or test separately IPv4 and IPv6. @@ -133,7 +135,7 @@ def get_server_uri(server, secure=False, resource_name='/'): else: # pragma: no cover raise ValueError("Expected an IPv6, IPv4, or Unix socket") - return '{}://{}:{}{}'.format(proto, host, port, resource_name) + return '{}://{}{}:{}{}'.format(proto, user_info, host, port, resource_name) class UnauthorizedServerProtocol(WebSocketServerProtocol): @@ -225,11 +227,12 @@ def start_server(self, **kwds): start_server = serve(handler, 'localhost', 0, **kwds) self.server = self.loop.run_until_complete(start_server) - def start_client(self, resource_name='/', **kwds): + def start_client(self, resource_name='/', user_info=None, **kwds): # Don't enable compression by default in tests. kwds.setdefault('compression', None) secure = kwds.get('ssl') is not None - server_uri = get_server_uri(self.server, secure, resource_name) + server_uri = get_server_uri( + self.server, secure, resource_name, user_info) start_client = connect(server_uri, **kwds) self.client = self.loop.run_until_complete(start_client) @@ -372,6 +375,14 @@ def test_protocol_path(self): server_path = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_path, '/path') + @with_server() + @with_client('/headers', user_info=('user', 'pass')) + def test_protocol_basic_auth(self): + self.assertEqual( + self.client.request_headers['Authorization'], + 'Basic dXNlcjpwYXNz', + ) + @with_server() @with_client('/headers') def test_protocol_headers(self): diff --git a/websockets/test_http.py b/websockets/test_http.py index a6d61299b..38f6363da 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -2,7 +2,7 @@ import unittest from .http import * -from .http import build_headers, read_headers +from .http import basic_auth_header, build_headers, read_headers class HTTPAsyncTests(unittest.TestCase): @@ -128,3 +128,10 @@ def test_build_headers_multi_value(self): # Ordering is deterministic when getting all values. self.assertEqual(headers.get_all('X-Foo'), ['Bar', 'Baz']) + + def test_basic_auth_header(self): + # Test vector from RFC 7617. + self.assertEqual( + basic_auth_header("Aladdin", "open sesame"), + ('Authorization', 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='), + ) diff --git a/websockets/test_uri.py b/websockets/test_uri.py index 1b9928007..86e305ae2 100644 --- a/websockets/test_uri.py +++ b/websockets/test_uri.py @@ -5,17 +5,32 @@ VALID_URIS = [ - ('ws://localhost/', (False, 'localhost', 80, '/')), - ('wss://localhost/', (True, 'localhost', 443, '/')), - ('ws://localhost/path?query', (False, 'localhost', 80, '/path?query')), - ('WS://LOCALHOST/PATH?QUERY', (False, 'localhost', 80, '/PATH?QUERY')), + ( + 'ws://localhost/', + (False, 'localhost', 80, '/', None), + ), + ( + 'wss://localhost/', + (True, 'localhost', 443, '/', None), + ), + ( + 'ws://localhost/path?query', + (False, 'localhost', 80, '/path?query', None), + ), + ( + 'WS://LOCALHOST/PATH?QUERY', + (False, 'localhost', 80, '/PATH?QUERY', None), + ), + ( + 'ws://user:pass@localhost/', + (False, 'localhost', 80, '/', ('user', 'pass')), + ), ] INVALID_URIS = [ 'http://localhost/', 'https://localhost/', 'ws://localhost/path#fragment', - 'ws://user:pass@localhost/', ] diff --git a/websockets/uri.py b/websockets/uri.py index 84c3f3b87..21f757f8a 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -15,13 +15,17 @@ __all__ = ['parse_uri', 'WebSocketURI'] WebSocketURI = collections.namedtuple( - 'WebSocketURI', ['secure', 'host', 'port', 'resource_name']) + 'WebSocketURI', ['secure', 'host', 'port', 'resource_name', 'user_info']) WebSocketURI.__doc__ = """WebSocket URI. * ``secure`` is the secure flag * ``host`` is the lower-case host * ``port`` if the integer port, it's always provided even if it's the default * ``resource_name`` is the resource name, that is, the path and optional query +* ``user_info`` is an ``(username, password)`` tuple when the URI contains + `User Information`_, else ``None``. + +.. _User Information: https://tools.ietf.org/html/rfc3986#section-3.2.1 """ @@ -40,8 +44,6 @@ def parse_uri(uri): assert uri.scheme in ['ws', 'wss'] assert uri.params == '' assert uri.fragment == '' - assert uri.username is None - assert uri.password is None assert uri.hostname is not None except AssertionError as exc: raise InvalidURI("{} isn't a valid URI".format(uri)) from exc @@ -52,4 +54,7 @@ def parse_uri(uri): resource_name = uri.path or '/' if uri.query: resource_name += '?' + uri.query - return WebSocketURI(secure, host, port, resource_name) + user_info = None + if uri.username or uri.password: + user_info = (uri.username, uri.password) + return WebSocketURI(secure, host, port, resource_name, user_info) From a3e1680a7a9bc5ec9f68f339a69afe1d24aca89c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 May 2018 18:14:34 +0200 Subject: [PATCH 0383/1539] Add debug logs in ping / pong handling. Fix #315. --- websockets/protocol.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 3d769c547..1f566938a 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -8,6 +8,7 @@ import asyncio import asyncio.queues +import binascii import codecs import collections import enum @@ -604,16 +605,41 @@ def read_data_frame(self, max_size): elif frame.opcode == OP_PING: # Answer pings. + # Replace by frame.data.hex() when dropping Python < 3.5. + ping_hex = binascii.hexlify(frame.data).decode() or '[empty]' + logger.debug("%s - received ping, sending pong: %s", + self.side, ping_hex) yield from self.pong(frame.data) elif frame.opcode == OP_PONG: - # Do not acknowledge pings on unsolicited pongs. + # Acknowledge pings on solicited pongs. if frame.data in self.pings: # Acknowledge all pings up to the one matching this pong. ping_id = None + ping_ids = [] while ping_id != frame.data: ping_id, pong_waiter = self.pings.popitem(0) + ping_ids.append(ping_id) pong_waiter.set_result(None) + pong_hex = ( + binascii.hexlify(frame.data).decode() or '[empty]') + logger.debug("%s - received solicited pong: %s", + self.side, pong_hex) + ping_ids = ping_ids[:-1] + if ping_ids: + pings_hex = ', '.join( + binascii.hexlify(ping_id).decode() or '[empty]' + for ping_id in ping_ids + ) + plural = 's' if len(ping_ids) > 1 else '' + logger.debug( + "%s - acknowledged previous ping%s: %s", + self.side, plural, pings_hex) + else: + pong_hex = ( + binascii.hexlify(frame.data).decode() or '[empty]') + logger.debug("%s - received unsolicited pong: %s", + self.side, pong_hex) # 5.6. Data Frames else: From 481d7fb10f6a521a28414b190f9ae0cd7b4219b5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 May 2018 09:18:07 +0200 Subject: [PATCH 0384/1539] Cancel unacknowledged pings. --- docs/changelog.rst | 3 +++ websockets/protocol.py | 13 +++++++++++++ websockets/test_protocol.py | 10 +++++++++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 745174f21..cea1ef499 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,6 +24,9 @@ Also: * :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. +* If a :meth:`~protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, + it's cancelled when the connection is closed. + 4.0 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index 1f566938a..caf42ebbf 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -736,6 +736,19 @@ def close_connection(self, after_handshake=True): if after_handshake: yield from self.transfer_data_task + # Cancel all pending pings because they'll never receive a pong. + for ping in self.pings.values(): + ping.cancel() + if self.pings: + pings_hex = ', '.join( + binascii.hexlify(ping_id).decode() or '[empty]' + for ping_id in self.pings + ) + plural = 's' if len(self.pings) > 1 else '' + logger.debug( + "%s - cancelled pending ping%s: %s", + self.side, plural, pings_hex) + # A client should wait for a TCP Close from the server. if self.is_client and after_handshake: if (yield from self.wait_for_connection_lost()): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index bf64b9d11..70348fb30 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -548,6 +548,14 @@ def test_acknowledge_ping(self): self.run_loop_once() self.assertTrue(ping.done()) + def test_cancel_ping(self): + ping = self.loop.run_until_complete(self.protocol.ping()) + # Remove the frame from the buffer, else close_connection() complains. + self.last_sent_frame() + self.assertFalse(ping.cancelled()) + self.close_connection() + self.assertTrue(ping.cancelled()) + def test_acknowledge_previous_pings(self): pings = [( self.loop.run_until_complete(self.protocol.ping()), @@ -568,7 +576,7 @@ def test_acknowledge_previous_pings(self): self.assertTrue(pings[1][0].done()) self.assertFalse(pings[2][0].done()) - def test_cancel_ping(self): + def test_cancelled_ping(self): ping = self.loop.run_until_complete(self.protocol.ping()) ping_frame = self.last_sent_frame() ping.cancel() From f2432937bd13c64df23bf14d336d9336d1251964 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 May 2018 15:07:29 +0200 Subject: [PATCH 0385/1539] Log connection state changes for debugging. --- websockets/protocol.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index caf42ebbf..9c3c98719 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -165,6 +165,7 @@ def __init__(self, *, # Subclasses implement the opening handshake and, on success, execute # :meth:`connection_open()` to change the state to OPEN. self.state = State.CONNECTING + logger.debug("%s - state = CONNECTING", self.side) # HTTP protocol parameters. self.path = None @@ -222,6 +223,7 @@ def connection_open(self): # 4.1. The WebSocket Connection is Established. assert self.state is State.CONNECTING self.state = State.OPEN + logger.debug("%s - state = OPEN", self.side) # Start the task that receives incoming WebSocket messages. self.transfer_data_task = asyncio_ensure_future( self.transfer_data(), loop=self.loop) @@ -671,6 +673,7 @@ def write_frame(self, opcode, data=b''): # before yielding control to avoid sending more than one close frame. if opcode == OP_CLOSE: self.state = State.CLOSING + logger.debug("%s - state = CLOSING", self.side) frame = Frame(True, opcode, data) logger.debug("%s > %s", self.side, frame) @@ -843,7 +846,7 @@ def connection_made(self, transport): which means it's the best point for configuring it. """ - logger.debug("%s - connection_made(%s)", self.side, transport) + logger.debug("%s - event = connection_made(%s)", self.side, transport) transport.set_write_buffer_limits(self.write_limit) super().connection_made(transport) @@ -870,7 +873,7 @@ def eof_received(self): As a consequence we revert to the previous, more useful behavior. """ - logger.debug("%s - eof_received()", self.side) + logger.debug("%s - event = eof_received()", self.side) super().eof_received() return @@ -879,8 +882,9 @@ def connection_lost(self, exc): 7.1.4. The WebSocket Connection is Closed. """ - logger.debug("%s - connection_lost(%s)", self.side, exc) + logger.debug("%s - event = connection_lost(%s)", self.side, exc) self.state = State.CLOSED + logger.debug("%s - state = CLOSED", self.side) if self.close_code is None: self.close_code = 1006 # If self.connection_lost_waiter isn't pending, that's a bug, because: From 0a69dac1cf49a802ea36a10b5d1d1a58f1a504fc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 May 2018 15:10:50 +0200 Subject: [PATCH 0386/1539] Don't close TCP connection if it's already closed. This had no effect because closing a TCP connection is idemptotent. However it emitted a log line which could be confusing. --- websockets/protocol.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index 9c3c98719..f8121a111 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -774,8 +774,11 @@ def close_connection(self, after_handshake=True): # The try/finally ensures that the transport never remains open, # even if this coroutine is cancelled (for example). - # Closing a transport is idempotent. If the transport was already - # closed, for example from eof_received(), it's fine. + # If connection_lost() was called, the TCP connection is closed. + # However, if TLS is enabled, the transport still needs closing. + # Else asyncio complains: ResourceWarning: unclosed transport. + if self.connection_lost_waiter.done() and not self.secure: + return # Close the TCP connection. Buffers are flushed asynchronously. logger.debug( From 402059e4a46a764632eba8a669f5b012f173ee7b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 May 2018 17:05:05 +0200 Subject: [PATCH 0387/1539] Fix behavior of recv() in the CLOSING state. The behavior wasn't tested correctly: in some test cases, the connection had already moved to the CLOSED state, where the close code and reason are already known. Refactor half_close_connection_{local,remote} to allow multiple runs of the event loop while remaining in the CLOSING state. Refactor affected tests accordingly. I verified that all tests in the CLOSING state were behaving is intended by inserting debug statements in recv/send/ping/pong and running: $ PYTHONASYNCIODEBUG=1 python -m unittest -v websockets.test_protocol.{Client,Server}Tests.test_{recv,send,ping,pong}_on_closing_connection_{local,remote} Fix #317, #327, #350, #357. --- websockets/protocol.py | 10 ++--- websockets/test_protocol.py | 78 +++++++++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 22 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index f8121a111..7583fe9c7 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -303,7 +303,7 @@ def recv(self): # Don't yield from self.ensure_open() here because messages could be # received before the closing frame even if the connection is closing. - # Wait for a message until the connection is closed + # Wait for a message until the connection is closed. next_message = asyncio_ensure_future( self.messages.get(), loop=self.loop) try: @@ -315,15 +315,15 @@ def recv(self): next_message.cancel() raise - # Now there's no need to yield from self.ensure_open(). Either a - # message was received or the connection was closed. - if next_message in done: return next_message.result() else: next_message.cancel() if not self.legacy_recv: - raise ConnectionClosed(self.close_code, self.close_reason) + assert self.state in [State.CLOSING, State.CLOSED] + # Wait until the connection is closed to raise + # ConnectionClosed with the correct code and reason. + yield from self.ensure_open() @asyncio.coroutine def send(self, data): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 70348fb30..bfd4e3b0f 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -105,7 +105,7 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - def make_drain_slow(self, delay=3 * MS): + def make_drain_slow(self, delay=MS): # Process connection_made in order to initialize self.protocol.writer. self.run_loop_once() @@ -174,6 +174,8 @@ def close_connection(self, code=1000, reason='close'): # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) + assert self.protocol.state is State.CLOSED + def half_close_connection_local(self, code=1000, reason='close'): """ Start a closing handshake but do not complete it. @@ -181,31 +183,56 @@ def half_close_connection_local(self, code=1000, reason='close'): The main difference with `close_connection` is that the connection is left in the CLOSING state until the event loop runs again. + The current implementation returns a task that must be awaited or + cancelled, else asyncio complains about destroying a pending task. + """ close_frame_data = serialize_close(code, reason) - # Trigger the closing handshake from the local side. - self.ensure_future(self.protocol.close(code, reason)) + # Trigger the closing handshake from the local endpoint. + close_task = self.ensure_future(self.protocol.close(code, reason)) self.run_loop_once() # wait_for executes self.run_loop_once() # write_frame executes # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) - # Prepare the response to the closing handshake from the remote side. - self.loop.call_soon( - self.receive_frame, Frame(True, OP_CLOSE, close_frame_data)) - self.loop.call_soon(self.receive_eof_if_client) + + assert self.protocol.state is State.CLOSING + + # Complete the closing sequence at 1ms intervals so the test can run + # at each point even it goes back to the event loop several times. + self.loop.call_later( + MS, self.receive_frame, Frame(True, OP_CLOSE, close_frame_data)) + self.loop.call_later(2 * MS, self.receive_eof_if_client) + + # This task must be awaited or cancelled by the caller. + return close_task def half_close_connection_remote(self, code=1000, reason='close'): """ - Receive a closing handshake. + Receive a closing handshake but do not complete it. The main difference with `close_connection` is that the connection is left in the CLOSING state until the event loop runs again. """ + # On the server side, websockets completes the closing handshake and + # closes the TCP connection immediately. Yield to the event loop after + # sending the close frame to run the test while the connection is in + # the CLOSING state. + if not self.protocol.is_client: + self.make_drain_slow() + close_frame_data = serialize_close(code, reason) - # Trigger the closing handshake from the remote side. + # Trigger the closing handshake from the remote endpoint. self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) - self.receive_eof_if_client() + self.run_loop_once() # read_frame executes + # Empty the outgoing data stream so we can make assertions later on. + self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) + + assert self.protocol.state is State.CLOSING + + # Complete the closing sequence at 1ms intervals so the test can run + # at each point even it goes back to the event loop several times. + self.loop.call_later(2 * MS, self.receive_eof_if_client) def process_invalid_frames(self): """ @@ -335,11 +362,13 @@ def test_recv_binary(self): self.assertEqual(data, b'tea') def test_recv_on_closing_connection_local(self): - self.half_close_connection_local() + close_task = self.half_close_connection_local() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) + self.loop.run_until_complete(close_task) # cleanup + def test_recv_on_closing_connection_remote(self): self.half_close_connection_remote() @@ -421,24 +450,29 @@ def test_send_type_error(self): self.assertNoFrameSent() def test_send_on_closing_connection_local(self): - self.half_close_connection_local() + close_task = self.half_close_connection_local() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.send('foobar')) + self.assertNoFrameSent() + self.loop.run_until_complete(close_task) # cleanup + def test_send_on_closing_connection_remote(self): self.half_close_connection_remote() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.send('foobar')) - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1000, 'close')) + + self.assertNoFrameSent() def test_send_on_closed_connection(self): self.close_connection() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.send('foobar')) + self.assertNoFrameSent() # Test the ping coroutine. @@ -466,24 +500,29 @@ def test_ping_type_error(self): self.assertNoFrameSent() def test_ping_on_closing_connection_local(self): - self.half_close_connection_local() + close_task = self.half_close_connection_local() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ping()) + self.assertNoFrameSent() + self.loop.run_until_complete(close_task) # cleanup + def test_ping_on_closing_connection_remote(self): self.half_close_connection_remote() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ping()) - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1000, 'close')) + + self.assertNoFrameSent() def test_ping_on_closed_connection(self): self.close_connection() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ping()) + self.assertNoFrameSent() # Test the pong coroutine. @@ -506,24 +545,29 @@ def test_pong_type_error(self): self.assertNoFrameSent() def test_pong_on_closing_connection_local(self): - self.half_close_connection_local() + close_task = self.half_close_connection_local() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.pong()) + self.assertNoFrameSent() + self.loop.run_until_complete(close_task) # cleanup + def test_pong_on_closing_connection_remote(self): self.half_close_connection_remote() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.pong()) - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1000, 'close')) + + self.assertNoFrameSent() def test_pong_on_closed_connection(self): self.close_connection() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.pong()) + self.assertNoFrameSent() # Test the protocol's logic for acknowledging pings with pongs. From faa1a8aec785ea62e3c595ecf72b6802691a1d6e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 May 2018 19:38:24 +0200 Subject: [PATCH 0388/1539] Improve description of connect and serve. Fix #300. --- websockets/client.py | 7 ++++--- websockets/server.py | 22 ++++++++++++---------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index 6922a70c2..f873903cb 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -278,10 +278,11 @@ def handshake(self, wsuri, origin=None, available_extensions=None, class Connect: """ - This coroutine connects to a WebSocket server at a given ``uri``. + Connect to the WebSocket server at the given ``uri``. - It yields a :class:`WebSocketClientProtocol` which can then be used to - send and receive messages. + :func:`connect` returns an awaitable. Awaiting it yields an instance of + :class:`WebSocketClientProtocol` which can then be used to send and + receive messages. On Python ≥ 3.5.1, :func:`connect` can be used as a asynchronous context manager. In that case, the connection is closed when exiting the context. diff --git a/websockets/server.py b/websockets/server.py index 410087bad..e58c0acf5 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -595,21 +595,23 @@ def sockets(self): class Serve: """ - This coroutine creates, starts, and return a :class:`WebSocketServer`. + Create, start, and return a :class:`WebSocketServer`. - :func:`serve` is a wrapper around the event loop's - :meth:`~asyncio.AbstractEventLoop.create_server` method. Internally, the - function creates and starts a :class:`~asyncio.Server` object by calling - :meth:`~asyncio.AbstractEventLoop.create_server`. The - :class:`WebSocketServer` keeps a reference to this object. - - The returned :class:`WebSocketServer` and its resources can be cleaned up - by calling its :meth:`~websockets.server.WebSocketServer.close` and - :meth:`~websockets.server.WebSocketServer.wait_closed` methods. + :func:`serve` returns an awaitable. Awaiting it yields an instance of + :class:`WebSocketServer` which provides + :meth:`~websockets.server.WebSocketServer.close` and + :meth:`~websockets.server.WebSocketServer.wait_closed` methods for + terminating the server and cleaning up its resources. On Python ≥ 3.5, :func:`serve` can also be used as an asynchronous context manager. In this case, the server is shut down when exiting the context. + :func:`serve` is a wrapper around the event loop's + :meth:`~asyncio.AbstractEventLoop.create_server` method. Internally, it + creates and starts a :class:`~asyncio.Server` object by calling + :meth:`~asyncio.AbstractEventLoop.create_server`. The + :class:`WebSocketServer` it returns keeps a reference to this object. + The ``ws_handler`` argument is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`WebSocketServerProtocol` and the request URI. From 08d5e1892f27cac39d050d60e7b2404530f3ba54 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 May 2018 19:47:53 +0200 Subject: [PATCH 0389/1539] Improve advice about Python version. Group advice about older Python versions at the bottom of the page. --- docs/intro.rst | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index 4c5551ba5..bab7c003b 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -3,32 +3,38 @@ Getting started .. currentmodule:: websockets -Installation +Requirements ------------ -Install ``websockets`` with:: +``websockets`` requires Python ≥ 3.4. - pip install websockets +You should use the latest version of Python if possible. If you're using an +older version, be aware that for each minor version (3.x), only the latest +bugfix release (3.x.y) is officially supported. -``websockets`` requires Python ≥ 3.4. We recommend using the latest version. - -If you're using an older version, be aware that for each minor version (3.x), -only the latest bugfix release (3.x.y) is officially supported. +For the best experience, you should start with Python ≥ 3.6. :mod:`asyncio` +received interesting improvements between Python 3.4 and 3.6. .. warning:: This documentation is written for Python ≥ 3.5.1. If you're using an older Python version, you need to :ref:`adapt the code samples `. +Installation +------------ + +Install ``websockets`` with:: + + pip install websockets + Basic example ------------- -*This section assumes Python ≥ 3.5.1. For older versions, read below.* - .. _server-example: -Here's a WebSocket server example. It reads a name from the client, sends a -greeting, and closes the connection. +Here's a WebSocket server example. + +It reads a name from the client, sends a greeting, and closes the connection. .. literalinclude:: ../example/server.py @@ -36,17 +42,14 @@ greeting, and closes the connection. On the server side, the handler coroutine ``hello`` is executed once for each WebSocket connection. The connection is automatically closed when the handler -returns. +coroutine returns. Here's a corresponding client example. .. literalinclude:: ../example/client.py -``async`` and ``await`` were introduced in Python 3.5. websockets supports -+asynchronous context managers on Python ≥ 3.5.1. Here's how to adapt the -client example for older Python versions. - -.. literalinclude:: ../example/oldclient.py +Using :func:`connect` as an asynchronous context manager ensures the +connection is closed before exiting the ``hello`` coroutine. Browser-based example --------------------- @@ -190,4 +193,7 @@ websockets supports asynchronous context managers only on Python ≥ 3.5.1 because :func:`~asyncio.ensure_future` was changed to accept arbitrary awaitables in that version. -If you're using Python ≤ 3.5, you can't use this feature. +If you're using Python < 3.5.1, you can't use this feature. Here's how to +adapt the basic client example. + +.. literalinclude:: ../example/oldclient.py From d851f57f56cbf4fceb6743d3afafb12780fe6132 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 May 2018 21:14:15 +0200 Subject: [PATCH 0390/1539] Simplify shutdown example on Python 3.6+. --- docs/deployment.rst | 6 ++++-- example/oldshutdown.py | 2 +- example/shutdown.py | 9 ++------- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/docs/deployment.rst b/docs/deployment.rst index beb8c7474..0203cfcf9 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -39,12 +39,14 @@ On Unix systems, shutdown is usually triggered by sending a signal. Here's a full example (Unix-only): .. literalinclude:: ../example/shutdown.py + :emphasize-lines: 13,17-19 ``async`` and ``await`` were introduced in Python 3.5. websockets supports -asynchronous context managers on Python ≥ 3.5.1. Here's the equivalent for -older Python versions: +asynchronous context managers on Python ≥ 3.5.1. ``async for`` was introduced +in Python 3.6. Here's the equivalent for older Python versions: .. literalinclude:: ../example/oldshutdown.py + :emphasize-lines: 22-25 It's more difficult to achieve the same effect on Windows. Some third-party projects try to help with this problem. diff --git a/example/oldshutdown.py b/example/oldshutdown.py index 6618aade0..180da9059 100644 --- a/example/oldshutdown.py +++ b/example/oldshutdown.py @@ -19,7 +19,7 @@ async def echo(websocket, path): start_server = websockets.serve(echo, 'localhost', 8765) server = loop.run_until_complete(start_server) -# Run the server until SIGTERM. +# Run the server until receiving SIGTERM. stop = asyncio.Future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) loop.run_until_complete(stop) diff --git a/example/shutdown.py b/example/shutdown.py index 663dbd58a..dd3e8f6a4 100644 --- a/example/shutdown.py +++ b/example/shutdown.py @@ -5,13 +5,8 @@ import websockets async def echo(websocket, path): - while True: - try: - msg = await websocket.recv() - except websockets.ConnectionClosed: - break - else: - await websocket.send(msg) + async for message in websocket: + await websocket.send(message) async def echo_server(stop): async with websockets.serve(echo, 'localhost', 8765): From 9b52a8df0daf0ed7e226674fa4bae49a572bf65c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 May 2018 09:51:40 +0200 Subject: [PATCH 0391/1539] Add closed property on protocols. Fix #286. --- docs/api.rst | 1 + docs/changelog.rst | 2 ++ websockets/protocol.py | 11 +++++++++++ websockets/test_protocol.py | 5 +++++ 4 files changed, 19 insertions(+) diff --git a/docs/api.rst b/docs/api.rst index df68764c3..d41075ad8 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -79,6 +79,7 @@ Shared .. autoattribute:: remote_address .. autoattribute:: open + .. autoattribute:: closed Exceptions .......... diff --git a/docs/changelog.rst b/docs/changelog.rst index cea1ef499..01bb00b40 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,6 +24,8 @@ Also: * :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. +* Added :meth:`~protocol.WebSocketCommonProtocol.closed` property. + * If a :meth:`~protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, it's cancelled when the connection is closed. diff --git a/websockets/protocol.py b/websockets/protocol.py index 7583fe9c7..541aa75de 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -273,6 +273,17 @@ def open(self): """ return self.state is State.OPEN + @property + def closed(self): + """ + This property is ``True`` once the connection is closed. + + Be aware that :attr:`open` and :attr`closed` are ``False`` when the + connection is in the OPENING or CLOSING state. + + """ + return self.state is State.CLOSED + @asyncio.coroutine def recv(self): """ diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index bfd4e3b0f..aabc65ed3 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -349,6 +349,11 @@ def test_open(self): self.close_connection() self.assertFalse(self.protocol.open) + def test_closed(self): + self.assertFalse(self.protocol.closed) + self.close_connection() + self.assertTrue(self.protocol.closed) + # Test the recv coroutine. def test_recv_text(self): From 63f2ff7bcd86f66bdb0d59a708bff94ea69d26a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 May 2018 10:00:48 +0200 Subject: [PATCH 0392/1539] Fix a couple formatting issues. --- docs/changelog.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 01bb00b40..5ecc0cbbd 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,8 +10,8 @@ Changelog .. warning:: - **Version 5.0 adds a ``user_info`` field to the return value of - :func:`~websockets.parse_uri`, :class:`~websockets.WebSocketURI`.** + **Version 5.0 adds a** ``user_info`` **field to the return value of** + :func:`~uri.parse_uri` **and** :class:`~uri.WebSocketURI` **.** If you're unpacking :class:`~websockets.WebSocketURI` into four variables, adjust your code to account for that fifth field. @@ -45,7 +45,7 @@ Also: .. warning:: - **Version 4.0 removes the ``state_name`` attribute of protocols.** + **Version 4.0 removes the** ``state_name`` **attribute of protocols.** Use ``protocol.state.name`` instead of ``protocol.state_name``. From ac1323e11a9c1d64058088c8ffafd270b131113a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 May 2018 18:39:30 +0200 Subject: [PATCH 0393/1539] Add missing entry in changelog. --- docs/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5ecc0cbbd..0128a5c60 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -29,6 +29,8 @@ Also: * If a :meth:`~protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, it's cancelled when the connection is closed. +* Fixed missing close code, which caused :exc:`TypeError` on connection close. + 4.0 ... From 30db1d1c0043e9eb1c6c032dcc31ffefa0bc23e3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 May 2018 17:37:49 +0200 Subject: [PATCH 0394/1539] Remove obsolete file. It should have been removed in bbf9563a6387ef6eaaab0becd4eed273e3f41915. --- websockets/testcert.pem | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 websockets/testcert.pem diff --git a/websockets/testcert.pem b/websockets/testcert.pem deleted file mode 100644 index 1ed5fd458..000000000 --- a/websockets/testcert.pem +++ /dev/null @@ -1,32 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBANSBDRjLau8ur0s1 -WNVJdpa1x6PMdistb9VU9lBqxJzu8sgWnuzvy1Nt+1lCl6j6QtQxma99bPjbcZ9S -rXJUwtBLq067Zy01VQ/lpBfjqRZShYUVimg4We9KB5DFvWzP52L8Oj0U3sm46mek -vcddtJQz6WwbPiROOSvF80W206fNAgMBAAECgYAfSKBU9h1X+Nd1ivT48Ue0CC7L -vl3nHVlJXqikThODxumW6z2aQ/L65UYLbfJFvhH4ixTE8QIJ4MRpYBKIslG7c3DX -cX6MP6KPaUjxSbjB9RlS9VdKbovxxeecbWzfSY+Cz/alyg++J0iOwbJVGL+RlaJw -g8hQM+UWyJLN764/QQJBAP/NeBHChjU7QyA36lv2Lm/lUpkYy3Zy4ZTGPyiuBjLC -SNqF1PMxrvuHHL05NaE6R02VFXztxJf2ci1rZKDG2N8CQQDUqwdsWZFlmTA5hqTB -mEYw3feCij3t4sy0KDV1wV851WJRbVrzrbxN+rHL5MKwd3qcxs1TXCfF1A9qbPXS -phjTAkBtd/KgNwzUDu5lBUjH3gx1WkAEwHWh1PvwfP5eXErOwhIHYiqFgIePoHyO -BcOLobMN4nT1p5LwLUkjYsgHfdElAkBgbBL3izyjBeuZiXSV2gapDVq1MxyVCOmr -HTfv5fbY7+id5qkAJttjt7B5M4UaIXHUN0bM7tGRnm5G4JQsJ+bFAkAQ/pYfrC9l -2hXI29YTSYTsw4iDjgJF6RAxw2108M8KybSJdyvQ43N4U40BQx8BRQmxZwSyG5QX -s+j9Cb63orCr ------END PRIVATE KEY----- ------BEGIN CERTIFICATE----- -MIICijCCAfOgAwIBAgIJANEJitrPxb96MA0GCSqGSIb3DQEBBQUAMF0xCzAJBgNV -BAYTAkZSMQ8wDQYDVQQIDAZGcmFuY2UxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQK -DBBBeW1lcmljIEF1Z3VzdGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMTQwNDE1 -MjEzMjI5WhgPMjExNDA0MTYyMTMyMjlaMF0xCzAJBgNVBAYTAkZSMQ8wDQYDVQQI -DAZGcmFuY2UxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQKDBBBeW1lcmljIEF1Z3Vz -dGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJ -AoGBANSBDRjLau8ur0s1WNVJdpa1x6PMdistb9VU9lBqxJzu8sgWnuzvy1Nt+1lC -l6j6QtQxma99bPjbcZ9SrXJUwtBLq067Zy01VQ/lpBfjqRZShYUVimg4We9KB5DF -vWzP52L8Oj0U3sm46mekvcddtJQz6WwbPiROOSvF80W206fNAgMBAAGjUDBOMB0G -A1UdDgQWBBRcFzeirOD3zMnjCptlc0sh9VWZJjAfBgNVHSMEGDAWgBRcFzeirOD3 -zMnjCptlc0sh9VWZJjAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAFyv -MGP9hnrMbDnwRtCYX/g99nvxjc5KXJyDw91Vo3hmHjdVRXY/oJbjiUtOBf1OsgoN -rv7KsaMb9+060K+uDtQIIiwPcxF1nQOZDtv6Nyzj8hwM2XFl+XiVgUD2pg++scWF -PDfbpmeEDQnUMEqHETM7JTMLB349/s5UUQqsSBE0 ------END CERTIFICATE----- From f0b997a5960786d37a999ea808e2c71f8c24e2ce Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 May 2018 18:56:37 +0200 Subject: [PATCH 0395/1539] Enable TLS hostname checking in tests. This required generating a better self-signed certificate. --- websockets/test_client_server.py | 15 ++++++- websockets/test_localhost.cnf | 26 ++++++++++++ websockets/test_localhost.pem | 72 +++++++++++++++++++------------- 3 files changed, 83 insertions(+), 30 deletions(-) create mode 100644 websockets/test_localhost.cnf diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 7ee8b78d9..478bf4639 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -30,6 +30,13 @@ # Avoid displaying stack traces at the ERROR logging level. logging.basicConfig(level=logging.CRITICAL) + +# Generate TLS certificate with: +# $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ +# -out test_localhost.crt -keyout test_localhost.key +# $ cat test_localhost.key test_localhost.crt > test_localhost.pem +# $ rm test_localhost.key test_localhost.crt + testcert = os.path.join(os.path.dirname(__file__), 'test_localhost.pem') @@ -123,9 +130,8 @@ def get_server_uri(server, secure=False, resource_name='/', user_info=None): # needed, either use the first socket, or test separately IPv4 and IPv6. server_socket = random.choice(server.sockets) - # That case if server_socket.family == socket.AF_INET6: # pragma: no cover - host, port = server_socket.getsockname()[:2] + host, port = server_socket.getsockname()[:2] # (no IPv6 on CI) host = '[{}]'.format(host) elif server_socket.family == socket.AF_INET: host, port = server_socket.getsockname() @@ -317,6 +323,7 @@ def send(self, *args, **kwargs): with self.temp_client( sock=client_socket, + # "You must set server_hostname when using ssl without a host" server_hostname='localhost' if self.secure else None, ): self.loop.run_until_complete(self.client.send("Hello!")) @@ -941,15 +948,19 @@ class SSLClientServerTests(ClientServerTests): @property def server_context(self): + # Change to ssl.PROTOCOL_TLS_SERVER when dropping Python < 3.6. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ssl_context.load_cert_chain(testcert) return ssl_context @property def client_context(self): + # Change to ssl.PROTOCOL_TLS_CLIENT when dropping Python < 3.6. + # Then remove verify_mode and check_hostname below. ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ssl_context.load_verify_locations(testcert) ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = True return ssl_context def start_server(self, **kwds): diff --git a/websockets/test_localhost.cnf b/websockets/test_localhost.cnf new file mode 100644 index 000000000..6dc331ac6 --- /dev/null +++ b/websockets/test_localhost.cnf @@ -0,0 +1,26 @@ +[ req ] + +default_md = sha256 +encrypt_key = no + +prompt = no + +distinguished_name = dn +x509_extensions = ext + +[ dn ] + +C = "FR" +L = "Paris" +O = "Aymeric Augustin" +CN = "localhost" + +[ ext ] + +subjectAltName = @san + +[ san ] + +DNS.1 = localhost +IP.2 = 127.0.0.1 +IP.3 = ::1 diff --git a/websockets/test_localhost.pem b/websockets/test_localhost.pem index 1ed5fd458..b8a9ea9ab 100644 --- a/websockets/test_localhost.pem +++ b/websockets/test_localhost.pem @@ -1,32 +1,48 @@ -----BEGIN PRIVATE KEY----- -MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBANSBDRjLau8ur0s1 -WNVJdpa1x6PMdistb9VU9lBqxJzu8sgWnuzvy1Nt+1lCl6j6QtQxma99bPjbcZ9S -rXJUwtBLq067Zy01VQ/lpBfjqRZShYUVimg4We9KB5DFvWzP52L8Oj0U3sm46mek -vcddtJQz6WwbPiROOSvF80W206fNAgMBAAECgYAfSKBU9h1X+Nd1ivT48Ue0CC7L -vl3nHVlJXqikThODxumW6z2aQ/L65UYLbfJFvhH4ixTE8QIJ4MRpYBKIslG7c3DX -cX6MP6KPaUjxSbjB9RlS9VdKbovxxeecbWzfSY+Cz/alyg++J0iOwbJVGL+RlaJw -g8hQM+UWyJLN764/QQJBAP/NeBHChjU7QyA36lv2Lm/lUpkYy3Zy4ZTGPyiuBjLC -SNqF1PMxrvuHHL05NaE6R02VFXztxJf2ci1rZKDG2N8CQQDUqwdsWZFlmTA5hqTB -mEYw3feCij3t4sy0KDV1wV851WJRbVrzrbxN+rHL5MKwd3qcxs1TXCfF1A9qbPXS -phjTAkBtd/KgNwzUDu5lBUjH3gx1WkAEwHWh1PvwfP5eXErOwhIHYiqFgIePoHyO -BcOLobMN4nT1p5LwLUkjYsgHfdElAkBgbBL3izyjBeuZiXSV2gapDVq1MxyVCOmr -HTfv5fbY7+id5qkAJttjt7B5M4UaIXHUN0bM7tGRnm5G4JQsJ+bFAkAQ/pYfrC9l -2hXI29YTSYTsw4iDjgJF6RAxw2108M8KybSJdyvQ43N4U40BQx8BRQmxZwSyG5QX -s+j9Cb63orCr +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCUgrQVkNbAWRlo +zZUj14Ufz7YEp2MXmvmhdlfOGLwjy+xPO98aJRv5/nYF2eWM3llcmLe8FbBSK+QF +To4su7ZVnc6qITOHqcSDUw06WarQUMs94bhHUvQp1u8+b2hNiMeGw6+QiBI6OJRO +iGpLRbkN6Uj3AKwi8SYVoLyMiztuwbNyGf8fF3DDpHZtBitGtMSBCMsQsfB465pl +2UoyBrWa2lsbLt3VvBZZvHqfEuPjpjjKN5USIXnaf0NizaR6ps3EyfftWy4i7zIQ +N5uTExvaPDyPn9nH3q/dkT99mSMSU1AvTTpX8PN7DlqE6wZMbQsBPRGW7GElQ+Ox +IKdKOLk5AgMBAAECggEAd3kqzQqnaTiEs4ZoC9yPUUc1pErQ8iWP27Ar9TZ67MVa +B2ggFJV0C0sFwbFI9WnPNCn77gj4vzJmD0riH+SnS/tXThDFtscBu7BtvNp0C4Bj +8RWMvXxjxuENuQnBPFbkRWtZ6wk8uK/Zx9AAyyt9M07Qjz1wPfAIdm/IH7zHBFMA +gsqjnkLh1r0FvjNEbLiuGqYU/GVxaZYd+xy+JU52IxjHUUL9yD0BPWb+Szar6AM2 +gUpmTX6+BcCZwwZ//DzCoWYZ9JbP8akn6edBeZyuMPqYgLzZkPyQ+hRW46VPPw89 +yg4LR9nzgQiBHlac0laB4NrWa+d9QRRLitl1O3gVAQKBgQDDkptxXu7w9Lpc+HeE +N/pJfpCzUuF7ZC4vatdoDzvfB5Ky6W88Poq+I7bB9m7StXdFAbDyUBxvisjTBMVA +OtYqpAk/rhX8MjSAtjoFe2nH+eEiQriuZmtA5CdKEXS4hNbc/HhEPWhk7Zh8OV5v +y7l4r6l4UHqaN9QyE0vlFdmcmQKBgQDCZZR/trJ2/g2OquaS+Zd2h/3NXw0NBq4z +4OBEWqNa/R35jdK6WlWJH7+tKOacr+xtswLpPeZHGwMdk64/erbYWBuJWAjpH72J +DM9+1H5fFHANWpWTNn94enQxwfzZRvdkxq4IWzGhesptYnHIzoAmaqC3lbn/e3u0 +Flng32hFoQKBgQCF3D4K3hib0lYQtnxPgmUMktWF+A+fflViXTWs4uhu4mcVkFNz +n7clJ5q6reryzAQjtmGfqRedfRex340HRn46V2aBMK2Znd9zzcZu5CbmGnFvGs3/ +iNiWZNNDjike9sV+IkxLIODoW/vH4xhxWrbLFSjg0ezoy5ew4qZK2abF2QKBgQC5 +M5efeQpbjTyTUERtf/aKCZOGZmkDoPq0GCjxVjzNQdqd1z0NJ2TYR/QP36idXIlu +FZ7PYZaS5aw5MGpQtfOe94n8dm++0et7t0WzunRO1yTNxCA+aSxWNquegAcJZa/q +RdKlyWPmSRqzzZdDzWCPuQQ3AyF5wkYfUy/7qjwoIQKBgB2v96BV7+lICviIKzzb +1o3A3VzAX5MGd98uLGjlK4qsBC+s7mk2eQztiNZgbA0W6fhQ5Dz3HcXJ5ppy8Okc +jeAktrNRzz15hvi/XkWdO+VMqiHW4l+sWYukjhCyod1oO1KGHq0LYYvv076syxGw +vRKLq7IJ4WIp1VtfaBlrIogq -----END PRIVATE KEY----- -----BEGIN CERTIFICATE----- -MIICijCCAfOgAwIBAgIJANEJitrPxb96MA0GCSqGSIb3DQEBBQUAMF0xCzAJBgNV -BAYTAkZSMQ8wDQYDVQQIDAZGcmFuY2UxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQK -DBBBeW1lcmljIEF1Z3VzdGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMTQwNDE1 -MjEzMjI5WhgPMjExNDA0MTYyMTMyMjlaMF0xCzAJBgNVBAYTAkZSMQ8wDQYDVQQI -DAZGcmFuY2UxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQKDBBBeW1lcmljIEF1Z3Vz -dGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJ -AoGBANSBDRjLau8ur0s1WNVJdpa1x6PMdistb9VU9lBqxJzu8sgWnuzvy1Nt+1lC -l6j6QtQxma99bPjbcZ9SrXJUwtBLq067Zy01VQ/lpBfjqRZShYUVimg4We9KB5DF -vWzP52L8Oj0U3sm46mekvcddtJQz6WwbPiROOSvF80W206fNAgMBAAGjUDBOMB0G -A1UdDgQWBBRcFzeirOD3zMnjCptlc0sh9VWZJjAfBgNVHSMEGDAWgBRcFzeirOD3 -zMnjCptlc0sh9VWZJjAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAFyv -MGP9hnrMbDnwRtCYX/g99nvxjc5KXJyDw91Vo3hmHjdVRXY/oJbjiUtOBf1OsgoN -rv7KsaMb9+060K+uDtQIIiwPcxF1nQOZDtv6Nyzj8hwM2XFl+XiVgUD2pg++scWF -PDfbpmeEDQnUMEqHETM7JTMLB349/s5UUQqsSBE0 +MIIDTTCCAjWgAwIBAgIJAJ6VG2cQlsepMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV +BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp +bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTE4MDUwNTE2NTc1NloYDzIwNjAwNTA0 +MTY1NzU2WjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM +EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAJSCtBWQ1sBZGWjNlSPXhR/PtgSnYxea+aF2 +V84YvCPL7E873xolG/n+dgXZ5YzeWVyYt7wVsFIr5AVOjiy7tlWdzqohM4epxINT +DTpZqtBQyz3huEdS9CnW7z5vaE2Ix4bDr5CIEjo4lE6IaktFuQ3pSPcArCLxJhWg +vIyLO27Bs3IZ/x8XcMOkdm0GK0a0xIEIyxCx8HjrmmXZSjIGtZraWxsu3dW8Flm8 +ep8S4+OmOMo3lRIhedp/Q2LNpHqmzcTJ9+1bLiLvMhA3m5MTG9o8PI+f2cfer92R +P32ZIxJTUC9NOlfw83sOWoTrBkxtCwE9EZbsYSVD47Egp0o4uTkCAwEAAaMwMC4w +LAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0G +CSqGSIb3DQEBCwUAA4IBAQA0imKp/rflfbDCCx78NdsR5rt0jKem2t3YPGT6tbeU ++FQz62SEdeD2OHWxpvfPf+6h3iTXJbkakr2R4lP3z7GHUe61lt3So9VHAvgbtPTH +aB1gOdThA83o0fzQtnIv67jCvE9gwPQInViZLEcm2iQEZLj6AuSvBKmluTR7vNRj +8/f2R4LsDfCWGrzk2W+deGRvSow7irS88NQ8BW8S8otgMiBx4D2UlOmQwqr6X+/r +jYIDuMb6GDKRXtBUGDokfE94hjj9u2mrNRwt8y4tqu8ZNa//yLEQ0Ow2kP3QJPLY +941VZpwRi2v/+JvI7OBYlvbOTFwM8nAk79k+Dgviygd9 -----END CERTIFICATE----- From 8bbff9dd1a515288063ef47ce7e49cf527ea0760 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 May 2018 19:28:57 +0200 Subject: [PATCH 0396/1539] Explain how to secure websockets with TLS. Fix #363. --- docs/intro.rst | 32 +++++++++++++++++++++ example/client.py | 3 +- example/localhost.pem | 48 ++++++++++++++++++++++++++++++++ example/oldclient.py | 3 +- example/secure_client.py | 24 ++++++++++++++++ example/secure_server.py | 26 +++++++++++++++++ websockets/test_client_server.py | 5 +++- 7 files changed, 138 insertions(+), 3 deletions(-) create mode 100644 example/localhost.pem create mode 100644 example/secure_client.py create mode 100755 example/secure_server.py diff --git a/docs/intro.rst b/docs/intro.rst index bab7c003b..abca2dca6 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -37,6 +37,7 @@ Here's a WebSocket server example. It reads a name from the client, sends a greeting, and closes the connection. .. literalinclude:: ../example/server.py + :emphasize-lines: 6,14 .. _client-example: @@ -47,10 +48,40 @@ coroutine returns. Here's a corresponding client example. .. literalinclude:: ../example/client.py + :emphasize-lines: 7-8 Using :func:`connect` as an asynchronous context manager ensures the connection is closed before exiting the ``hello`` coroutine. +Secure example +-------------- + +Secure WebSocket connections improve confidentiality and also reliability +because they reduce the risk of interference by bad proxies. + +The WSS protocol is to WS what HTTPS is to HTTP: the connection is encrypted +with TLS. WSS requires TLS certificates like HTTPS. + +Here's how to adapt the server example to provide secure connections, using +APIs available in Python ≥ 3.6. + +Refer to the documentation of the :mod:`ssl` module for configuring the +context securely or adapting the code to older Python versions. + +.. literalinclude:: ../example/secure_server.py + :emphasize-lines: 18,22-23 + +Here's how to adapt the client, also on Python ≥ 3.6. + +.. literalinclude:: ../example/secure_client.py + :emphasize-lines: 10,15-16 + +This client needs a context because the server uses a self-signed certificate. + +A client connecting to a secure WebSocket server with a valid certificate +(i.e. signed by a CA that your Python installation trusts) can simply pass +``ssl=True`` to :func:`connect`` instead of building a context. + Browser-based example --------------------- @@ -197,3 +228,4 @@ If you're using Python < 3.5.1, you can't use this feature. Here's how to adapt the basic client example. .. literalinclude:: ../example/oldclient.py + :emphasize-lines: 8-9 diff --git a/example/client.py b/example/client.py index 5a3a026b4..3ce8273d3 100644 --- a/example/client.py +++ b/example/client.py @@ -4,7 +4,8 @@ import websockets async def hello(): - async with websockets.connect('ws://localhost:8765') as websocket: + async with websockets.connect( + 'ws://localhost:8765') as websocket: name = input("What's your name? ") await websocket.send(name) print("> {}".format(name)) diff --git a/example/localhost.pem b/example/localhost.pem new file mode 100644 index 000000000..f9a30ba8f --- /dev/null +++ b/example/localhost.pem @@ -0,0 +1,48 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDG8iDak4UBpurI +TWjSfqJ0YVG/S56nhswehupCaIzu0xQ8wqPSs36h5t1jMexJPZfvwyvFjcV+hYpj +LMM0wMJPx9oBQEe0bsmlC66e8aF0UpSQw1aVfYoxA9BejgEyrFNE7cRbQNYFEb/5 +3HfqZKdEQA2fgQSlZ0RTRmLrD+l72iO5o2xl5bttXpqYZB2XOkyO79j/xWdu9zFE +sgZJ5ysWbqoRAGgnxjdYYr9DARd8bIE/hN3SW7mDt5v4LqCIhGn1VmrwtT3d5AuG +QPz4YEbm0t6GOlmFjIMYH5Y7pALRVfoJKRj6DGNIR1JicL+wqLV66kcVnj8WKbla +20i7fR7NAgMBAAECggEAG5yvgqbG5xvLqlFUIyMAWTbIqcxNEONcoUAIc38fUGZr +gKNjKXNQOBha0dG0AdZSqCxmftzWdGEEfA9SaJf4YCpUz6ekTB60Tfv5GIZg6kwr +4ou6ELWD4Jmu6fC7qdTRGdgGUMQG8F0uT/eRjS67KHXbbi/x/SMAEK7MO+PRfCbj ++JGzS9Ym9mUweINPotgjHdDGwwd039VWYS+9A+QuNK27p3zq4hrWRb4wshSC8fKy +oLoe4OQt81aowpX9k6mAU6N8vOmP8/EcQHYC+yFIIDZB2EmDP07R1LUEH3KJnzo7 +plCK1/kYPhX0a05cEdTpXdKa74AlvSRkS11sGqfUAQKBgQDj1SRv0AUGsHSA0LWx +a0NT1ZLEXCG0uqgdgh0sTqIeirQsPROw3ky4lH5MbjkfReArFkhHu3M6KoywEPxE +wanSRh/t1qcNjNNZUvFoUzAKVpb33RLkJppOTVEWPt+wtyDlfz1ZAXzMV66tACrx +H2a3v0ZWUz6J+x/dESH5TTNL4QKBgQDfirmknp408pwBE+bulngKy0QvU09En8H0 +uvqr8q4jCXqJ1tXon4wsHg2yF4Fa37SCpSmvONIDwJvVWkkYLyBHKOns/fWCkW3n +hIcYx0q2jgcoOLU0uoaM9ArRXhIxoWqV/KGkQzN+3xXC1/MxZ5OhyxBxfPCPIYIN +YN3M1t/QbQKBgDImhsC+D30rdlmsl3IYZFed2ZKznQ/FTqBANd+8517FtWdPgnga +VtUCitKUKKrDnNafLwXrMzAIkbNn6b/QyWrp2Lln2JnY9+TfpxgJx7de3BhvZ2sl +PC4kQsccy+yAQxOBcKWY+Dmay251bP5qpRepWPhDlq6UwqzMyqev4KzBAoGAWDMi +IEO9ZGK9DufNXCHeZ1PgKVQTmJ34JxmHQkTUVFqvEKfFaq1Y3ydUfAouLa7KSCnm +ko42vuhGFB41bOdbMvh/o9RoBAZheNGfhDVN002ioUoOpSlbYU4A3q7hOtfXeCpf +lLI3JT3cFi6ic8HMTDAU4tJLEA5GhATOPr4hPNkCgYB8jTYGcLvoeFaLEveg0kS2 +cz6ZXGLJx5m1AOQy5g9FwGaW+10lr8TF2k3AldwoiwX0R6sHAf/945aGU83ms5v9 +PB9/x66AYtSRUos9MwB4y1ur4g6FiXZUBgTJUqzz2nehPCyGjYhh49WucjszqcjX +chS1bKZOY+1knWq8xj5Qyg== +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIDTTCCAjWgAwIBAgIJAOjte6l+03jvMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV +BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp +bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTE4MDUwNTE2NTkyOVoYDzIwNjAwNTA0 +MTY1OTI5WjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM +EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAMbyINqThQGm6shNaNJ+onRhUb9LnqeGzB6G +6kJojO7TFDzCo9KzfqHm3WMx7Ek9l+/DK8WNxX6FimMswzTAwk/H2gFAR7RuyaUL +rp7xoXRSlJDDVpV9ijED0F6OATKsU0TtxFtA1gURv/ncd+pkp0RADZ+BBKVnRFNG +YusP6XvaI7mjbGXlu21emphkHZc6TI7v2P/FZ273MUSyBknnKxZuqhEAaCfGN1hi +v0MBF3xsgT+E3dJbuYO3m/guoIiEafVWavC1Pd3kC4ZA/PhgRubS3oY6WYWMgxgf +ljukAtFV+gkpGPoMY0hHUmJwv7CotXrqRxWePxYpuVrbSLt9Hs0CAwEAAaMwMC4w +LAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0G +CSqGSIb3DQEBCwUAA4IBAQC9TsTxTEvqHPUS6sfvF77eG0D6HLOONVN91J+L7LiX +v3bFeS1xbUS6/wIxZi5EnAt/te5vaHk/5Q1UvznQP4j2gNoM6lH/DRkSARvRitVc +H0qN4Xp2Yk1R9VEx4ZgArcyMpI+GhE4vJRx1LE/hsuAzw7BAdsTt9zicscNg2fxO +3ao/eBcdaC6n9aFYdE6CADMpB1lCX2oWNVdj6IavQLu7VMc+WJ3RKncwC9th+5OP +ISPvkVZWf25rR2STmvvb0qEm3CZjk4Xd7N+gxbKKUvzEgPjrLSWzKKJAWHjCLugI +/kQqhpjWVlTbtKzWz5bViqCjSbrIPpU2MgG9AUV9y3iV +-----END CERTIFICATE----- diff --git a/example/oldclient.py b/example/oldclient.py index 763627a4b..71b11ece9 100755 --- a/example/oldclient.py +++ b/example/oldclient.py @@ -5,7 +5,8 @@ @asyncio.coroutine def hello(): - websocket = yield from websockets.connect('ws://localhost:8765/') + websocket = yield from websockets.connect( + 'ws://localhost:8765/') try: name = input("What's your name? ") diff --git a/example/secure_client.py b/example/secure_client.py new file mode 100644 index 000000000..63619bafb --- /dev/null +++ b/example/secure_client.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python + +# WSS (WS over TLS) client example, with a self-signed certificate + +import asyncio +import pathlib +import ssl +import websockets + +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +ssl_context.load_verify_locations( + pathlib.Path(__file__).with_name('localhost.pem')) + +async def hello(): + async with websockets.connect( + 'wss://localhost:8765', ssl=ssl_context) as websocket: + name = input("What's your name? ") + await websocket.send(name) + print("> {}".format(name)) + + greeting = await websocket.recv() + print("< {}".format(greeting)) + +asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/secure_server.py b/example/secure_server.py new file mode 100755 index 000000000..fabbd58b0 --- /dev/null +++ b/example/secure_server.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python + +# WSS (WS over TLS) server example, with a self-signed certificate + +import asyncio +import pathlib +import ssl +import websockets + +async def hello(websocket, path): + name = await websocket.recv() + print("< {}".format(name)) + + greeting = "Hello {}!".format(name) + await websocket.send(greeting) + print("> {}".format(greeting)) + +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +ssl_context.load_cert_chain( + pathlib.Path(__file__).with_name('localhost.pem')) + +start_server = websockets.serve( + hello, 'localhost', 8765, ssl=ssl_context) + +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 478bf4639..8d36452e8 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -960,7 +960,10 @@ def client_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ssl_context.load_verify_locations(testcert) ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.check_hostname = True + # ssl.match_hostname can't match IP addresses on Python < 3.5. + # We're using IP addresses to enforce testing of IPv4 and IPv6. + if sys.version_info[:2] >= (3, 5): # pragma: no cover + ssl_context.check_hostname = True return ssl_context def start_server(self, **kwds): From d8c5887afbe3e1c21efe5276810f248ab6f141e4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 May 2018 20:35:06 +0200 Subject: [PATCH 0397/1539] Update examples for Python 3.6. Improve documentation about supporting older versions. --- docs/changelog.rst | 2 + docs/deployment.rst | 2 +- docs/intro.rst | 75 ++++++++++++++------- example/client.py | 7 +- example/echo.py | 0 example/hello.py | 0 example/{oldclient.py => old_client.py} | 3 + example/old_server.py | 21 ++++++ example/{oldshutdown.py => old_shutdown.py} | 0 example/secure_client.py | 5 +- example/secure_server.py | 7 +- example/{sendtime.py => send_time.py} | 2 + example/server.py | 9 ++- example/{showtime.html => show_time.html} | 0 example/shutdown.py | 0 15 files changed, 97 insertions(+), 36 deletions(-) mode change 100644 => 100755 example/client.py mode change 100644 => 100755 example/echo.py mode change 100644 => 100755 example/hello.py rename example/{oldclient.py => old_client.py} (91%) create mode 100755 example/old_server.py rename example/{oldshutdown.py => old_shutdown.py} (100%) mode change 100644 => 100755 mode change 100644 => 100755 example/secure_client.py rename example/{sendtime.py => send_time.py} (89%) mode change 100644 => 100755 rename example/{showtime.html => show_time.html} (100%) mode change 100644 => 100755 example/shutdown.py diff --git a/docs/changelog.rst b/docs/changelog.rst index 0128a5c60..1ca5c4ca3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -29,6 +29,8 @@ Also: * If a :meth:`~protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, it's cancelled when the connection is closed. +* Updated documentation with new features from Python 3.6. + * Fixed missing close code, which caused :exc:`TypeError` on connection close. 4.0 diff --git a/docs/deployment.rst b/docs/deployment.rst index 0203cfcf9..8cec48df0 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -45,7 +45,7 @@ Here's a full example (Unix-only): asynchronous context managers on Python ≥ 3.5.1. ``async for`` was introduced in Python 3.6. Here's the equivalent for older Python versions: -.. literalinclude:: ../example/oldshutdown.py +.. literalinclude:: ../example/old_shutdown.py :emphasize-lines: 22-25 It's more difficult to achieve the same effect on Windows. Some third-party diff --git a/docs/intro.rst b/docs/intro.rst index abca2dca6..126f9ddbc 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -17,8 +17,8 @@ received interesting improvements between Python 3.4 and 3.6. .. warning:: - This documentation is written for Python ≥ 3.5.1. If you're using an older - Python version, you need to :ref:`adapt the code samples `. + This documentation is written for Python ≥ 3.6. If you're using an older + Python version, you need to :ref:`adapt the code samples `. Installation ------------ @@ -37,18 +37,18 @@ Here's a WebSocket server example. It reads a name from the client, sends a greeting, and closes the connection. .. literalinclude:: ../example/server.py - :emphasize-lines: 6,14 + :emphasize-lines: 8,17 .. _client-example: -On the server side, the handler coroutine ``hello`` is executed once for each -WebSocket connection. The connection is automatically closed when the handler +On the server side, ``websockets`` executes the handler coroutine ``hello`` +once for each WebSocket connection. It closes the connection when the handler coroutine returns. -Here's a corresponding client example. +Here's a corresponding WebSocket client example. .. literalinclude:: ../example/client.py - :emphasize-lines: 7-8 + :emphasize-lines: 8-10 Using :func:`connect` as an asynchronous context manager ensures the connection is closed before exiting the ``hello`` coroutine. @@ -69,7 +69,7 @@ Refer to the documentation of the :mod:`ssl` module for configuring the context securely or adapting the code to older Python versions. .. literalinclude:: ../example/secure_server.py - :emphasize-lines: 18,22-23 + :emphasize-lines: 19,23-24 Here's how to adapt the client, also on Python ≥ 3.6. @@ -89,11 +89,11 @@ Here's an example of how to run a WebSocket server and connect from a browser. Run this script in a console: -.. literalinclude:: ../example/sendtime.py +.. literalinclude:: ../example/send_time.py Then open this HTML file in a browser. -.. literalinclude:: ../example/showtime.html +.. literalinclude:: ../example/show_time.html :language: html Common patterns @@ -112,6 +112,9 @@ For receiving messages and passing them to a ``consumer`` coroutine:: async for message in websocket: await consumer(message) +In this example, ``consumer`` represents your business logic for processing +messages received on the WebSocket connection. + Iteration terminates when the client disconnects. Asynchronous iteration was introduced in Python 3.6; here's the same code for @@ -136,6 +139,9 @@ For getting messages from a ``producer`` coroutine and sending them:: message = await producer() await websocket.send(message) +In this example, ``producer`` represents your business logic for generating +messages to send on the WebSocket connection. + :meth:`~protocol.WebSocketCommonProtocol.send` raises a :exc:`~exceptions.ConnectionClosed` exception when the client disconnects, which breaks out of the ``while True`` loop. @@ -147,13 +153,14 @@ You can read and write messages on the same connection by combining the two patterns shown above and running the two tasks in parallel:: async def handler(websocket, path): - consumer_task = asyncio.ensure_future(consumer_handler(websocket)) - producer_task = asyncio.ensure_future(producer_handler(websocket)) + consumer_task = asyncio.ensure_future( + consumer_handler(websocket, path)) + producer_task = asyncio.ensure_future( + producer_handler(websocket, path)) done, pending = await asyncio.wait( [consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED, ) - for task in pending: task.cancel() @@ -193,13 +200,31 @@ answering pings, or any other behavior required by the specification. ``websockets`` handles all this under the hood so you don't have to. -.. _python-lt-351: +.. _python-lt-36: -Python < 3.5.1 --------------- +Python < 3.6 +------------ + +This documentation takes advantage of several features that aren't available +in Python < 3.6: + +- ``await`` and ``async`` were added in Python 3.5; +- Asynchronous context managers didn't work well until Python 3.5.1; +- f-strings were introduced in Python 3.6 (unrelated to :mod:`asyncio` + :mod:`websockets`). + +Here's how to adapt the basic server example. + +.. literalinclude:: ../example/old_server.py + :emphasize-lines: 8-9,18 + +And here's the basic client example. + +.. literalinclude:: ../example/old_client.py + :emphasize-lines: 8-11,13,22-23 -This documentation uses the ``await`` and ``async`` syntax introduced in -Python 3.5. +``await`` and ``async`` +....................... If you're using Python < 3.5, you must substitute:: @@ -220,12 +245,12 @@ with:: Otherwise you will encounter a :exc:`SyntaxError`. -websockets supports asynchronous context managers only on Python ≥ 3.5.1 -because :func:`~asyncio.ensure_future` was changed to accept arbitrary -awaitables in that version. +Asynchronous context managers +............................. -If you're using Python < 3.5.1, you can't use this feature. Here's how to -adapt the basic client example. +Asynchronous context managers were added in Python 3.5. However, +``websockets`` only supports them on Python ≥ 3.5.1, where +:func:`~asyncio.ensure_future` accepts any awaitable. -.. literalinclude:: ../example/oldclient.py - :emphasize-lines: 8-9 +If you're using Python < 3.5.1, you must rely on ``try: ... finally: ...`` +instead. diff --git a/example/client.py b/example/client.py old mode 100644 new mode 100755 index 3ce8273d3..e71595ff5 --- a/example/client.py +++ b/example/client.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +# WS client example + import asyncio import websockets @@ -7,10 +9,11 @@ async def hello(): async with websockets.connect( 'ws://localhost:8765') as websocket: name = input("What's your name? ") + await websocket.send(name) - print("> {}".format(name)) + print(f"> {name}") greeting = await websocket.recv() - print("< {}".format(greeting)) + print(f"< {greeting}") asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/echo.py b/example/echo.py old mode 100644 new mode 100755 diff --git a/example/hello.py b/example/hello.py old mode 100644 new mode 100755 diff --git a/example/oldclient.py b/example/old_client.py similarity index 91% rename from example/oldclient.py rename to example/old_client.py index 71b11ece9..c44d6edff 100755 --- a/example/oldclient.py +++ b/example/old_client.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +# WS client example for old Python versions + import asyncio import websockets @@ -10,6 +12,7 @@ def hello(): try: name = input("What's your name? ") + yield from websocket.send(name) print("> {}".format(name)) diff --git a/example/old_server.py b/example/old_server.py new file mode 100755 index 000000000..bb19bdabc --- /dev/null +++ b/example/old_server.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# WS server example for old Python versions + +import asyncio +import websockets + +@asyncio.coroutine +def hello(websocket, path): + name = yield from websocket.recv() + print("< {}".format(name)) + + greeting = "Hello {}!".format(name) + + yield from websocket.send(greeting) + print("> {}".format(greeting)) + +start_server = websockets.serve(hello, 'localhost', 8765) + +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() diff --git a/example/oldshutdown.py b/example/old_shutdown.py old mode 100644 new mode 100755 similarity index 100% rename from example/oldshutdown.py rename to example/old_shutdown.py diff --git a/example/secure_client.py b/example/secure_client.py old mode 100644 new mode 100755 index 63619bafb..8e7f57ff9 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -15,10 +15,11 @@ async def hello(): async with websockets.connect( 'wss://localhost:8765', ssl=ssl_context) as websocket: name = input("What's your name? ") + await websocket.send(name) - print("> {}".format(name)) + print(f"> {name}") greeting = await websocket.recv() - print("< {}".format(greeting)) + print(f"< {greeting}") asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/secure_server.py b/example/secure_server.py index fabbd58b0..5cbed46c0 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -9,11 +9,12 @@ async def hello(websocket, path): name = await websocket.recv() - print("< {}".format(name)) + print(f"< {name}") + + greeting = f"Hello {name}!" - greeting = "Hello {}!".format(name) await websocket.send(greeting) - print("> {}".format(greeting)) + print(f"> {greeting}") ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context.load_cert_chain( diff --git a/example/sendtime.py b/example/send_time.py old mode 100644 new mode 100755 similarity index 89% rename from example/sendtime.py rename to example/send_time.py index 2b14827c8..6d196deb3 --- a/example/sendtime.py +++ b/example/send_time.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +# WS server that sends messages at random intervals + import asyncio import datetime import random diff --git a/example/server.py b/example/server.py index 37744b815..cc5c8fea8 100755 --- a/example/server.py +++ b/example/server.py @@ -1,15 +1,18 @@ #!/usr/bin/env python +# WS server example + import asyncio import websockets async def hello(websocket, path): name = await websocket.recv() - print("< {}".format(name)) + print(f"< {name}") + + greeting = f"Hello {name}!" - greeting = "Hello {}!".format(name) await websocket.send(greeting) - print("> {}".format(greeting)) + print(f"> {greeting}") start_server = websockets.serve(hello, 'localhost', 8765) diff --git a/example/showtime.html b/example/show_time.html similarity index 100% rename from example/showtime.html rename to example/show_time.html diff --git a/example/shutdown.py b/example/shutdown.py old mode 100644 new mode 100755 From f192a4ad37ce72f9047ce4527721eb6fdcc792fd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 May 2018 20:48:22 +0200 Subject: [PATCH 0398/1539] Use pathlib instead of os.path. --- setup.py | 15 ++++++--------- websockets/py35/_test_client_server.py | 4 ++-- websockets/test_client_server.py | 7 +++---- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/setup.py b/setup.py index e483dbecf..7ed5de214 100644 --- a/setup.py +++ b/setup.py @@ -1,19 +1,16 @@ -import os.path +import pathlib import sys import setuptools -root_dir = os.path.abspath(os.path.dirname(__file__)) + +root_dir = pathlib.Path(__file__).parent.resolve() description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -readme_file = os.path.join(root_dir, 'README.rst') -with open(readme_file, encoding='utf-8') as f: - long_description = f.read() +long_description = (root_dir / 'README.rst').read_text(encoding='utf-8') -version_module = os.path.join(root_dir, 'websockets', 'version.py') -with open(version_module, encoding='utf-8') as f: - exec(f.read()) +exec((root_dir / 'websockets' / 'version.py').read_text(encoding='utf-8')) py_version = sys.version_info[:2] @@ -32,7 +29,7 @@ setuptools.Extension( 'websockets.speedups', sources=['websockets/speedups.c'], - optional=not os.path.exists(os.path.join(root_dir, '.cibuildwheel')), + optional=not (root_dir / '.cibuildwheel').exists(), ) ] diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py index d0c0cfa13..437524885 100644 --- a/websockets/py35/_test_client_server.py +++ b/websockets/py35/_test_client_server.py @@ -1,7 +1,7 @@ # Tests containing Python 3.5+ syntax, extracted from test_client_server.py. import asyncio -import os +import pathlib import socket import sys import tempfile @@ -70,5 +70,5 @@ async def run_server(path): self.assertFalse(server.sockets) with tempfile.TemporaryDirectory() as temp_dir: - path = os.path.join(temp_dir, 'websockets') + path = bytes(pathlib.Path(temp_dir) / 'websockets') self.loop.run_until_complete(run_server(path)) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 8d36452e8..ae30afffb 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -2,7 +2,7 @@ import contextlib import functools import logging -import os.path +import pathlib import random import socket import ssl @@ -37,7 +37,7 @@ # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt -testcert = os.path.join(os.path.dirname(__file__), 'test_localhost.pem') +testcert = bytes(pathlib.Path(__file__).with_name('test_localhost.pem')) @asyncio.coroutine @@ -340,7 +340,7 @@ def send(self, *args, **kwargs): hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') def test_unix_socket(self): with tempfile.TemporaryDirectory() as temp_dir: - path = os.path.join(temp_dir, 'websockets') + path = bytes(pathlib.Path(temp_dir) / 'websockets') # Like self.start_server() but with unix_serve(). unix_server = unix_serve(handler, path) @@ -941,7 +941,6 @@ def test_connection_error_during_closing_handshake(self, close): self.assertEqual(self.client.close_code, 1006) -@unittest.skipUnless(os.path.exists(testcert), "test certificate is missing") class SSLClientServerTests(ClientServerTests): secure = True From 0e8d2d5cffb2b9bdcc73d340618962a6e8a35192 Mon Sep 17 00:00:00 2001 From: mayeut Date: Sun, 6 May 2018 09:25:52 +0200 Subject: [PATCH 0399/1539] Add python_requires to setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 7ed5de214..81f70b2fc 100644 --- a/setup.py +++ b/setup.py @@ -58,4 +58,5 @@ ext_modules=ext_modules, include_package_data=True, zip_safe=True, + python_requires='>=3.4', ) From 9b0929582874cab3b3d18e283a841c43fd198eaf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 May 2018 21:05:41 +0200 Subject: [PATCH 0400/1539] Adjust f192a4ad for compatibility with Python 3.4. --- setup.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 81f70b2fc..df7aa24f5 100644 --- a/setup.py +++ b/setup.py @@ -4,13 +4,19 @@ import setuptools -root_dir = pathlib.Path(__file__).parent.resolve() +root_dir = pathlib.Path(__file__).parent description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -long_description = (root_dir / 'README.rst').read_text(encoding='utf-8') +# When dropping Python < 3.5, change to: +# long_description = (root_dir / 'README.rst').read_text(encoding='utf-8') +with (root_dir / 'README.rst').open(encoding='utf-8') as f: + long_description = f.read() -exec((root_dir / 'websockets' / 'version.py').read_text(encoding='utf-8')) +# When dropping Python < 3.5, change to: +# exec((root_dir / 'websockets' / 'version.py').read_text(encoding='utf-8')) +with (root_dir / 'websockets' / 'version.py').open(encoding='utf-8') as f: + exec(f.read()) py_version = sys.version_info[:2] From efad8562cb9fcd51199ba0504a90a3cdbec350e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 8 May 2018 08:26:46 +0200 Subject: [PATCH 0401/1539] Log connection close codes and reasons. For debugging. --- websockets/protocol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/websockets/protocol.py b/websockets/protocol.py index 541aa75de..4ad43dbb1 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -901,6 +901,8 @@ def connection_lost(self, exc): logger.debug("%s - state = CLOSED", self.side) if self.close_code is None: self.close_code = 1006 + logger.debug("%s x code = %d, reason = %s", self.side, + self.close_code, self.close_reason or '[empty]') # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be cancelled. From 45067908efb81169bfb6437a35887df5a5c52764 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 8 May 2018 23:21:51 +0200 Subject: [PATCH 0402/1539] Prevent spurious logs when running the protocol tests. This already exists in test_client_server.py. There's the same need in test_protocol.py. --- websockets/test_protocol.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index aabc65ed3..9048c6418 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -1,6 +1,7 @@ import asyncio import contextlib import functools +import logging import os import time import unittest @@ -12,6 +13,10 @@ from .protocol import State, WebSocketCommonProtocol +# Avoid displaying stack traces at the ERROR logging level. +logging.basicConfig(level=logging.CRITICAL) + + # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. MS = 0.001 * int(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1)) From cd304ff803377255dbd587068dc5d46a9964c6c2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 May 2018 21:09:01 +0200 Subject: [PATCH 0403/1539] Add an example of health check. Fix #354. --- docs/deployment.rst | 5 ++++- example/health_check_server.py | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 example/health_check_server.py diff --git a/docs/deployment.rst b/docs/deployment.rst index 8cec48df0..ed4453cd6 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -75,4 +75,7 @@ widely different operational characteristics of HTTP and WebSocket. ``websockets`` provide minimal support for responding to HTTP requests with the :meth:`~server.WebSocketServerProtocol.process_request()` hook. Typical -use cases include health checks. +use cases include health checks. Here's an example: + +.. literalinclude:: ../example/health_check_server.py + :emphasize-lines: 9-13,19-20 diff --git a/example/health_check_server.py b/example/health_check_server.py new file mode 100644 index 000000000..89fd1e2ff --- /dev/null +++ b/example/health_check_server.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +# WS echo server with HTTP endpoint at /health/ + +import asyncio +import http +import websockets + +class ServerProtocol(websockets.WebSocketServerProtocol): + + async def process_request(self, path, request_headers): + if path == '/health/': + return http.HTTPStatus.OK, [], b'OK\n' + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + +start_server = websockets.serve( + echo, 'localhost', 8765, create_protocol=ServerProtocol) + +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() From 4f5e66fd02f102d9d5faea11b8fda884e3bb8d67 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 May 2018 21:56:26 +0200 Subject: [PATCH 0404/1539] Better handle hard TCP connection termination. There's no need to print a full stack trace when the network connection dropped earlier than we expected. A 1006 close code is likely sufficient to identify the cause if this is an issue in actual deployments. Fix #348, #349. --- docs/changelog.rst | 2 ++ websockets/protocol.py | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 1ca5c4ca3..8886441bb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -33,6 +33,8 @@ Also: * Fixed missing close code, which caused :exc:`TypeError` on connection close. +* Stopped logging stack traces when the TCP connection dies prematurely. + 4.0 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index 4ad43dbb1..46ce54767 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -517,7 +517,12 @@ def transfer_data(self): pass except WebSocketProtocolError: yield from self.fail_connection(1002) - except asyncio.IncompleteReadError: + except (ConnectionError, EOFError): + # Reading data with self.reader.readexactly may raise: + # - most subclasses of ConnectionError if the TCP connection + # breaks, is reset, or is aborted; + # - IncompleteReadError, a subclass of EOFError, if fewer + # bytes are available than requested. yield from self.fail_connection(1006) except UnicodeDecodeError: yield from self.fail_connection(1007) From 9eefbc25831837aac3fa7370ba5faff9b8c17bb1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 19:00:13 +0200 Subject: [PATCH 0405/1539] Prevent TypeError in WebSocketServer.wait_closed(). I think there's room for further improvement. Until then this will help. Fix #309. --- websockets/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/websockets/server.py b/websockets/server.py index e58c0acf5..b265bddd5 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -578,7 +578,8 @@ def wait_closed(self): yield from asyncio.wait( [websocket.handler_task for websocket in self.websockets] + [websocket.close_connection_task - for websocket in self.websockets], + for websocket in self.websockets + if websocket.close_connection_task], loop=self.loop) yield from self.server.wait_closed() From 4c2c93a30cf9a843959dcfbb9b7cc94ebca50825 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 16:16:46 +0200 Subject: [PATCH 0406/1539] Downgrade log severity in handshake. Logging InvalidHandshake errors at the debug level avoids dumping stack traces into server logs whenever a plain HTTP client hits a WS endpoint. Fix #369. --- websockets/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index b265bddd5..1a1406fb8 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -105,14 +105,14 @@ def handler(self): exc.body, ) elif isinstance(exc, InvalidOrigin): - logger.warning("Invalid origin", exc_info=True) + logger.debug("Invalid origin", exc_info=True) early_response = ( FORBIDDEN, [], str(exc).encode(), ) elif isinstance(exc, InvalidHandshake): - logger.warning("Invalid handshake", exc_info=True) + logger.debug("Invalid handshake", exc_info=True) early_response = ( BAD_REQUEST, [], From 43f4a6db0408908a8b60b9701b15f184510642ff Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 16:37:43 +0200 Subject: [PATCH 0407/1539] Simplify validation of Connection header. The ABNF for the value of the Connection header is: Connection = 1#connection-option connection-option = token --- websockets/handshake.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/websockets/handshake.py b/websockets/handshake.py index b84325001..1a9bb8564 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -83,9 +83,7 @@ def check_request(get_header): """ try: assert get_header('Upgrade').lower() == 'websocket' - assert any( - token.strip() == 'upgrade' - for token in get_header('Connection').lower().split(',')) + assert get_header('Connection').lower() == 'upgrade' key = get_header('Sec-WebSocket-Key') assert len(base64.b64decode(key.encode(), validate=True)) == 16 assert get_header('Sec-WebSocket-Version') == '13' @@ -125,9 +123,7 @@ def check_response(get_header, key): """ try: assert get_header('Upgrade').lower() == 'websocket' - assert any( - token.strip() == 'upgrade' - for token in get_header('Connection').lower().split(',')) + assert get_header('Connection').lower() == 'upgrade' assert get_header('Sec-WebSocket-Accept') == accept(key) except Exception as exc: raise InvalidHandshake("Invalid response") from exc From 9e8090ea943dc8431a0d7d3fa7fe77a65236f5c3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 17:44:47 +0200 Subject: [PATCH 0408/1539] Include header name in parsing exception messages. --- websockets/exceptions.py | 37 +++++++++++++------- websockets/headers.py | 64 +++++++++++++++++++---------------- websockets/server.py | 2 +- websockets/test_exceptions.py | 18 +++++++--- websockets/test_headers.py | 6 ++-- 5 files changed, 78 insertions(+), 49 deletions(-) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 210aa40e0..2dfdac14e 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,9 +1,9 @@ __all__ = [ - 'AbortHandshake', 'InvalidHandshake', 'InvalidHeader', 'InvalidMessage', - 'InvalidOrigin', 'InvalidState', 'InvalidStatusCode', 'NegotiationError', - 'InvalidParameterName', 'InvalidParameterValue', 'DuplicateParameter', - 'InvalidURI', 'ConnectionClosed', 'PayloadTooBig', - 'WebSocketProtocolError', + 'AbortHandshake', 'InvalidHandshake', 'InvalidHeader', + 'InvalidHeaderFormat', 'InvalidMessage', 'InvalidOrigin', 'InvalidState', + 'InvalidStatusCode', 'NegotiationError', 'InvalidParameterName', + 'InvalidParameterValue', 'DuplicateParameter', 'InvalidURI', + 'ConnectionClosed', 'PayloadTooBig', 'WebSocketProtocolError', ] @@ -37,21 +37,34 @@ class InvalidMessage(InvalidHandshake): class InvalidHeader(InvalidHandshake): """ - Exception raised when a HTTP header doesn't have the expected format. + Exception raised when a HTTP header doesn't have a valid format or value. """ - def __init__(self, message, string, pos): - self.string = string - self.pos = pos - message = "{} at {} in {}".format(message, pos, string) + def __init__(self, name, value): + if value: + message = "Invalid {} header: {}".format(name, value) + else: + message = "Missing or empty {} header".format(name) super().__init__(message) -class InvalidOrigin(InvalidHandshake): +class InvalidHeaderFormat(InvalidHeader): + """ + Exception raised when a Sec-WebSocket-* HTTP header cannot be parsed. + + """ + def __init__(self, name, error, string, pos): + error = "{} at {} in {}".format(error, pos, string) + super().__init__(name, error) + + +class InvalidOrigin(InvalidHeader): """ - Exception raised when the origin in a handshake request is forbidden. + Exception raised when the Origin header in a request isn't allowed. """ + def __init__(self, origin): + super().__init__('Origin', origin) class InvalidStatusCode(InvalidHandshake): diff --git a/websockets/headers.py b/websockets/headers.py index c5b228cb9..873984f0f 100644 --- a/websockets/headers.py +++ b/websockets/headers.py @@ -9,7 +9,7 @@ import re -from .exceptions import InvalidHeader +from .exceptions import InvalidHeaderFormat __all__ = [ @@ -54,18 +54,19 @@ def parse_OWS(string, pos): _token_re = re.compile(r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+') -def parse_token(string, pos): +def parse_token(string, pos, header_name): """ Parse a token from ``string`` at the given position. Return the token value and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ match = _token_re.match(string, pos) if match is None: - raise InvalidHeader("Expected token", string=string, pos=pos) + raise InvalidHeaderFormat( + header_name, "expected token", string=string, pos=pos) return match.group(), match.end() @@ -76,46 +77,48 @@ def parse_token(string, pos): _unquote_re = re.compile(r'\\([\x09\x20-\x7e\x80-\xff])') -def parse_quoted_string(string, pos): +def parse_quoted_string(string, pos, header_name): """ Parse a quoted string from ``string`` at the given position. Return the unquoted value and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ match = _quoted_string_re.match(string, pos) if match is None: - raise InvalidHeader("Expected quoted string", string=string, pos=pos) + raise InvalidHeaderFormat( + header_name, "expected quoted string", string=string, pos=pos) return _unquote_re.sub(r'\1', match.group()[1:-1]), match.end() -def parse_extension_param(string, pos): +def parse_extension_param(string, pos, header_name): """ Parse a single extension parameter from ``string`` at the given position. Return a ``(name, value)`` pair and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ # Extract parameter name. - name, pos = parse_token(string, pos) + name, pos = parse_token(string, pos, header_name) pos = parse_OWS(string, pos) # Extract parameter string, if there is one. if peek_ahead(string, pos) == '=': pos = parse_OWS(string, pos + 1) if peek_ahead(string, pos) == '"': pos_before = pos # for proper error reporting below - value, pos = parse_quoted_string(string, pos) + value, pos = parse_quoted_string(string, pos, header_name) # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value # after quoted-string unescaping MUST conform to the 'token' ABNF. if _token_re.fullmatch(value) is None: - raise InvalidHeader("Invalid quoted string content", - string=string, pos=pos_before) + raise InvalidHeaderFormat( + header_name, "invalid quoted string content", + string=string, pos=pos_before) else: - value, pos = parse_token(string, pos) + value, pos = parse_token(string, pos, header_name) pos = parse_OWS(string, pos) else: value = None @@ -123,29 +126,30 @@ def parse_extension_param(string, pos): return (name, value), pos -def parse_extension(string, pos): +def parse_extension(string, pos, header_name): """ Parse an extension definition from ``string`` at the given position. Return an ``(extension name, parameters)`` pair, where ``parameters`` is a list of ``(name, value)`` pairs, and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ # Extract extension name. - name, pos = parse_token(string, pos) + name, pos = parse_token(string, pos, header_name) pos = parse_OWS(string, pos) # Extract all parameters. parameters = [] while peek_ahead(string, pos) == ';': pos = parse_OWS(string, pos + 1) - parameter, pos = parse_extension_param(string, pos) + parameter, pos = parse_extension_param(string, pos, header_name) parameters.append(parameter) return (name, parameters), pos -def parse_extension_list(string, pos=0): +def parse_extension_list( + string, pos=0, header_name='Sec-WebSocket-Extensions'): """ Parse a ``Sec-WebSocket-Extensions`` header. @@ -166,7 +170,7 @@ def parse_extension_list(string, pos=0): Parameter values are ``None`` when no value is provided. - Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST @@ -180,7 +184,7 @@ def parse_extension_list(string, pos=0): extensions = [] while True: # Loop invariant: an extension starts at pos in string. - extension, pos = parse_extension(string, pos) + extension, pos = parse_extension(string, pos, header_name) extensions.append(extension) # We may have reached the end of the string. @@ -191,7 +195,8 @@ def parse_extension_list(string, pos=0): if peek_ahead(string, pos) == ',': pos = parse_OWS(string, pos + 1) else: - raise InvalidHeader("Expected comma", string=string, pos=pos) + raise InvalidHeaderFormat( + header_name, "expected comma", string=string, pos=pos) # Remove extra delimiters before the next extension. while peek_ahead(string, pos) == ',': @@ -235,21 +240,21 @@ def build_extension_list(extensions): ) -def parse_protocol(string, pos): +def parse_protocol(string, pos, header_name): """ Parse a protocol definition from ``string`` at the given position. Return the protocol and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - name, pos = parse_token(string, pos) + name, pos = parse_token(string, pos, header_name) pos = parse_OWS(string, pos) return name, pos -def parse_protocol_list(string, pos=0): +def parse_protocol_list(string, pos=0, header_name='Sec-WebSocket-Protocol'): """ Parse a ``Sec-WebSocket-Protocol`` header. @@ -257,7 +262,7 @@ def parse_protocol_list(string, pos=0): Return a list of protocols. - Raise :exc:`~websockets.exceptions.InvalidHeader` on invalid inputs. + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST @@ -271,7 +276,7 @@ def parse_protocol_list(string, pos=0): protocols = [] while True: # Loop invariant: a protocol starts at pos in string. - protocol, pos = parse_protocol(string, pos) + protocol, pos = parse_protocol(string, pos, header_name) protocols.append(protocol) # We may have reached the end of the string. @@ -282,7 +287,8 @@ def parse_protocol_list(string, pos=0): if peek_ahead(string, pos) == ',': pos = parse_OWS(string, pos + 1) else: - raise InvalidHeader("Expected comma", string=string, pos=pos) + raise InvalidHeaderFormat( + header_name, "expected comma", string=string, pos=pos) # Remove extra delimiters before the next protocol. while peek_ahead(string, pos) == ',': diff --git a/websockets/server.py b/websockets/server.py index 1a1406fb8..9f1e4477b 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -264,7 +264,7 @@ def process_origin(self, get_header, origins=None): origin = get_header('Origin') if origins is not None: if origin not in origins: - raise InvalidOrigin("Origin not allowed: {}".format(origin)) + raise InvalidOrigin(origin) return origin @staticmethod diff --git a/websockets/test_exceptions.py b/websockets/test_exceptions.py index 031490f9e..b7aafa86a 100644 --- a/websockets/test_exceptions.py +++ b/websockets/test_exceptions.py @@ -20,12 +20,22 @@ def test_str(self): "Malformed HTTP message", ), ( - InvalidHeader("Expected token", "a=|", 3), - "Expected token at 3 in a=|", + InvalidHeader('Upgrade', ''), + "Missing or empty Upgrade header", ), ( - InvalidOrigin("Origin not allowed: ''"), - "Origin not allowed: ''", + InvalidHeader('Connection', 'websocket'), + "Invalid Connection header: websocket", + ), + ( + InvalidHeaderFormat( + 'Sec-WebSocket-Protocol', "expected token", 'a=|', 3), + "Invalid Sec-WebSocket-Protocol header: " + "expected token at 3 in a=|", + ), + ( + InvalidOrigin('http://bad.origin'), + 'Invalid Origin header: http://bad.origin', ), ( InvalidStatusCode(403), diff --git a/websockets/test_headers.py b/websockets/test_headers.py index 230aadfac..efb41ae29 100644 --- a/websockets/test_headers.py +++ b/websockets/test_headers.py @@ -1,6 +1,6 @@ import unittest -from .exceptions import InvalidHeader +from .exceptions import InvalidHeaderFormat from .headers import * @@ -66,7 +66,7 @@ def test_parse_extension_list_invalid_header(self): 'foo; bar=" "', ]: with self.subTest(header=header): - with self.assertRaises(InvalidHeader): + with self.assertRaises(InvalidHeaderFormat): parse_extension_list(header) def test_parse_protocol_list(self): @@ -101,5 +101,5 @@ def test_parse_protocol_list_invalid_header(self): 'foo; bar', ]: with self.subTest(header=header): - with self.assertRaises(InvalidHeader): + with self.assertRaises(InvalidHeaderFormat): parse_protocol_list(header) From b91088418cbb061b3a033931e4975d21b0dd249e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 18:07:21 +0200 Subject: [PATCH 0409/1539] Raise more specific exceptions on hanshake errors. --- websockets/exceptions.py | 25 ++++++++++++++---- websockets/handshake.py | 49 ++++++++++++++++++++++------------- websockets/test_exceptions.py | 21 ++++++++++++--- websockets/test_handshake.py | 5 ++-- 4 files changed, 71 insertions(+), 29 deletions(-) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 2dfdac14e..b4f7dfc2e 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,9 +1,10 @@ __all__ = [ - 'AbortHandshake', 'InvalidHandshake', 'InvalidHeader', - 'InvalidHeaderFormat', 'InvalidMessage', 'InvalidOrigin', 'InvalidState', - 'InvalidStatusCode', 'NegotiationError', 'InvalidParameterName', - 'InvalidParameterValue', 'DuplicateParameter', 'InvalidURI', - 'ConnectionClosed', 'PayloadTooBig', 'WebSocketProtocolError', + 'AbortHandshake', 'ConnectionClosed', 'DuplicateParameter', + 'InvalidHandshake', 'InvalidHeader', 'InvalidHeaderFormat', + 'InvalidHeaderValue', 'InvalidMessage', 'InvalidOrigin', + 'InvalidParameterName', 'InvalidParameterValue', 'InvalidState', + 'InvalidStatusCode', 'InvalidUpgrade', 'InvalidURI', 'NegotiationError', + 'PayloadTooBig', 'WebSocketProtocolError', ] @@ -58,6 +59,20 @@ def __init__(self, name, error, string, pos): super().__init__(name, error) +class InvalidHeaderValue(InvalidHeader): + """ + Exception raised when a Sec-WebSocket-* HTTP header has a wrong value. + + """ + + +class InvalidUpgrade(InvalidHeader): + """ + Exception raised when a Upgrade or Connection header isn't correct. + + """ + + class InvalidOrigin(InvalidHeader): """ Exception raised when the Origin header in a request isn't allowed. diff --git a/websockets/handshake.py b/websockets/handshake.py index 1a9bb8564..142fef61c 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -35,10 +35,11 @@ """ import base64 +import binascii import hashlib import random -from .exceptions import InvalidHandshake +from .exceptions import InvalidHeaderValue, InvalidUpgrade __all__ = [ @@ -56,8 +57,8 @@ def build_request(set_header): Return the ``key`` which must be passed to :func:`check_response`. """ - rand = bytes(random.getrandbits(8) for _ in range(16)) - key = base64.b64encode(rand).decode() + raw_key = bytes(random.getrandbits(8) for _ in range(16)) + key = base64.b64encode(raw_key).decode() set_header('Upgrade', 'websocket') set_header('Connection', 'Upgrade') set_header('Sec-WebSocket-Key', key) @@ -81,16 +82,25 @@ def check_request(get_header): responsibility of the caller. """ + if get_header('Upgrade').lower() != 'websocket': + raise InvalidUpgrade('Upgrade', get_header('Upgrade')) + + if get_header('Connection').lower() != 'upgrade': + raise InvalidUpgrade('Connection', get_header('Connection')) + + key = get_header('Sec-WebSocket-Key') try: - assert get_header('Upgrade').lower() == 'websocket' - assert get_header('Connection').lower() == 'upgrade' - key = get_header('Sec-WebSocket-Key') - assert len(base64.b64decode(key.encode(), validate=True)) == 16 - assert get_header('Sec-WebSocket-Version') == '13' - except Exception as exc: - raise InvalidHandshake("Invalid request") from exc - else: - return key + raw_key = base64.b64decode(key.encode(), validate=True) + except binascii.Error: + raise InvalidHeaderValue('Sec-WebSocket-Key', key) + if len(raw_key) != 16: + raise InvalidHeaderValue('Sec-WebSocket-Key', key) + + version = get_header('Sec-WebSocket-Version') + if version != '13': + raise InvalidHeaderValue('Sec-WebSocket-Version', version) + + return key def build_response(set_header, key): @@ -121,12 +131,15 @@ def check_response(get_header, key): the caller. """ - try: - assert get_header('Upgrade').lower() == 'websocket' - assert get_header('Connection').lower() == 'upgrade' - assert get_header('Sec-WebSocket-Accept') == accept(key) - except Exception as exc: - raise InvalidHandshake("Invalid response") from exc + if get_header('Upgrade').lower() != 'websocket': + raise InvalidUpgrade('Upgrade', get_header('Upgrade')) + + if get_header('Connection').lower() != 'upgrade': + raise InvalidUpgrade('Connection', get_header('Connection')) + + if get_header('Sec-WebSocket-Accept') != accept(key): + raise InvalidHeaderValue( + 'Sec-WebSocket-Accept', get_header('Sec-WebSocket-Accept')) def accept(key): diff --git a/websockets/test_exceptions.py b/websockets/test_exceptions.py index b7aafa86a..da87bed5d 100644 --- a/websockets/test_exceptions.py +++ b/websockets/test_exceptions.py @@ -20,12 +20,12 @@ def test_str(self): "Malformed HTTP message", ), ( - InvalidHeader('Upgrade', ''), - "Missing or empty Upgrade header", + InvalidHeader('Name', ''), + "Missing or empty Name header", ), ( - InvalidHeader('Connection', 'websocket'), - "Invalid Connection header: websocket", + InvalidHeader('Name', 'Value'), + "Invalid Name header: Value", ), ( InvalidHeaderFormat( @@ -33,6 +33,19 @@ def test_str(self): "Invalid Sec-WebSocket-Protocol header: " "expected token at 3 in a=|", ), + ( + InvalidHeaderValue('Sec-WebSocket-Version', '42'), + "Invalid Sec-WebSocket-Version header: 42", + ), + + ( + InvalidUpgrade('Upgrade', ''), + "Missing or empty Upgrade header", + ), + ( + InvalidUpgrade('Connection', 'websocket'), + "Invalid Connection header: websocket", + ), ( InvalidOrigin('http://bad.origin'), 'Invalid Origin header: http://bad.origin', diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index fee35bc9d..62b4ffc0f 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -1,3 +1,4 @@ +import collections import contextlib import unittest @@ -31,7 +32,7 @@ def assertInvalidRequestHeaders(self): Assert that the transformation made them invalid. """ - headers = {} + headers = collections.defaultdict(lambda: '') build_request(headers.__setitem__) yield headers with self.assertRaises(InvalidHandshake): @@ -85,7 +86,7 @@ def assertInvalidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): Assert that the transformation made them invalid. """ - headers = {} + headers = collections.defaultdict(lambda: '') build_response(headers.__setitem__, key) yield headers with self.assertRaises(InvalidHandshake): From b07b88950eefe83bfc7fea6c01ec87969f3ae1d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 18:59:45 +0200 Subject: [PATCH 0410/1539] Respond with HTTP 426 to HTTP requests. --- websockets/compatibility.py | 5 ++++ websockets/server.py | 11 ++++++-- websockets/test_client_server.py | 43 ++++++++++++++++++++++++-------- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/websockets/compatibility.py b/websockets/compatibility.py index c3bb61321..21bc586a4 100644 --- a/websockets/compatibility.py +++ b/websockets/compatibility.py @@ -21,6 +21,7 @@ BAD_REQUEST = http.HTTPStatus.BAD_REQUEST UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED FORBIDDEN = http.HTTPStatus.FORBIDDEN + UPGRADE_REQUIRED = http.HTTPStatus.UPGRADE_REQUIRED INTERNAL_SERVER_ERROR = http.HTTPStatus.INTERNAL_SERVER_ERROR SERVICE_UNAVAILABLE = http.HTTPStatus.SERVICE_UNAVAILABLE except AttributeError: # pragma: no cover @@ -45,6 +46,10 @@ class FORBIDDEN: value = 403 phrase = "Forbidden" + class UPGRADE_REQUIRED: + value = 426 + phrase = "Upgrade Required" + class INTERNAL_SERVER_ERROR: value = 500 phrase = "Internal Server Error" diff --git a/websockets/server.py b/websockets/server.py index 9f1e4477b..ec5517e5d 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -10,11 +10,11 @@ from .compatibility import ( BAD_REQUEST, FORBIDDEN, INTERNAL_SERVER_ERROR, SERVICE_UNAVAILABLE, - SWITCHING_PROTOCOLS, asyncio_ensure_future + SWITCHING_PROTOCOLS, UPGRADE_REQUIRED, asyncio_ensure_future ) from .exceptions import ( AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin, - NegotiationError + InvalidUpgrade, NegotiationError ) from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request @@ -111,6 +111,13 @@ def handler(self): [], str(exc).encode(), ) + elif isinstance(exc, InvalidUpgrade): + logger.debug("Invalid upgrade", exc_info=True) + early_response = ( + UPGRADE_REQUIRED, + [('Upgrade', 'websocket')], + str(exc).encode(), + ) elif isinstance(exc, InvalidHandshake): logger.debug("Invalid handshake", exc_info=True) early_response = ( diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index ae30afffb..67424dfb3 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -10,6 +10,7 @@ import tempfile import unittest import unittest.mock +import urllib.error import urllib.request from .client import * @@ -472,14 +473,10 @@ def test_protocol_custom_response_user_agent(self): self.assertEqual(resp_headers.count("Server"), 1) self.assertIn("('Server', 'Eggs')", resp_headers) - @with_server(create_protocol=HealthCheckServerProtocol) - @with_client() - def test_custom_protocol_http_request(self): - # One URL returns an HTTP response. - - # Set url to 'https?://:/__health__/'. + def make_http_request(self, path='/'): + # Set url to 'https?://:'. url = get_server_uri( - self.server, resource_name='/__health__/', secure=self.secure) + self.server, resource_name=path, secure=self.secure) url = url.replace('ws', 'http') if self.secure: @@ -489,18 +486,42 @@ def test_custom_protocol_http_request(self): open_health_check = functools.partial( urllib.request.urlopen, url) + return self.loop.run_in_executor(None, open_health_check) + + @with_server(create_protocol=HealthCheckServerProtocol) + def test_http_request_http_endpoint(self): + # Making a HTTP request to a HTTP endpoint succeeds. response = self.loop.run_until_complete( - self.loop.run_in_executor(None, open_health_check)) + self.make_http_request('/__health__/')) with contextlib.closing(response): self.assertEqual(response.code, 200) self.assertEqual(response.read(), b'status = green\n') - # Other URLs create a WebSocket connection. + @with_server(create_protocol=HealthCheckServerProtocol) + def test_http_request_ws_endpoint(self): + # Making a HTTP request to a WS endpoint fails. + with self.assertRaises(urllib.error.HTTPError) as raised: + self.loop.run_until_complete(self.make_http_request()) + + self.assertEqual(raised.exception.code, 426) + self.assertEqual(raised.exception.headers['Upgrade'], 'websocket') + @with_server(create_protocol=HealthCheckServerProtocol) + def test_ws_connection_http_endpoint(self): + # Making a WS connection to a HTTP endpoint fails. + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client('/__health__/') + + self.assertEqual(raised.exception.status_code, 200) + + @with_server(create_protocol=HealthCheckServerProtocol) + def test_ws_connection_ws_endpoint(self): + # Making a WS connection to a WS endpoint succeeds. + self.start_client() self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") + self.loop.run_until_complete(self.client.recv()) + self.stop_client() def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: From 4b1527d380ee0fe6526b07107788c77026e12da9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 19:30:32 +0200 Subject: [PATCH 0411/1539] Add trailing newline to HTTP response bodies. --- websockets/server.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index ec5517e5d..3a810410e 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -96,7 +96,7 @@ def handler(self): early_response = ( SERVICE_UNAVAILABLE, [], - b"Server is shutting down.", + b"Server is shutting down.\n", ) elif isinstance(exc, AbortHandshake): early_response = ( @@ -109,28 +109,28 @@ def handler(self): early_response = ( FORBIDDEN, [], - str(exc).encode(), + (str(exc) + "\n").encode(), ) elif isinstance(exc, InvalidUpgrade): logger.debug("Invalid upgrade", exc_info=True) early_response = ( UPGRADE_REQUIRED, [('Upgrade', 'websocket')], - str(exc).encode(), + (str(exc) + "\n").encode(), ) elif isinstance(exc, InvalidHandshake): logger.debug("Invalid handshake", exc_info=True) early_response = ( BAD_REQUEST, [], - str(exc).encode(), + (str(exc) + "\n").encode(), ) else: logger.warning("Error in opening handshake", exc_info=True) early_response = ( INTERNAL_SERVER_ERROR, [], - b"See server log for more information.", + b"See server log for more information.\n", ) yield from self.write_http_response(*early_response) From 7578987add60c72a8e2fdecfc407d7d3af47c154 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 15:57:00 +0200 Subject: [PATCH 0412/1539] Document that handlers can be cancelled. Fix #337. --- websockets/server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/websockets/server.py b/websockets/server.py index 3a810410e..8836b3e93 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -659,6 +659,12 @@ class Serve: delegates to the WebSocket handler. Once the handler completes, the server performs the closing handshake and closes the connection. + When a server is closed with + :meth:`~websockets.server.WebSocketServer.close`, all running WebSocket + handlers are cancelled. They may intercept :exc:`~asyncio.CancelledError` + and perform cleanup actions before re-raising that exception. If a handler + started new tasks, it should cancel them as well in that case. + Since there's no useful way to propagate exceptions triggered in handlers, they're sent to the ``'websockets.server'`` logger instead. Debugging is much easier if you configure logging to print them:: From 2098046e41d139b54f7884b7f13e93bc79894a78 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 15 May 2018 14:00:32 +0200 Subject: [PATCH 0413/1539] Don't treat close code 1001 as an error. Quite often, browsers close WebSocket connections when the user navigates to another page. In most cases this isn't an error condition and thus not worth reporting with an exception. --- docs/changelog.rst | 3 +++ websockets/protocol.py | 6 +++--- websockets/py36/_test_client_server.py | 27 +++++++++++++++++++++++++- websockets/py36/protocol.py | 2 +- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 8886441bb..cf48a298a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -21,6 +21,9 @@ Also: * :func:`~client.connect()` performs HTTP Basic Auth when the URI contains credentials. +* Iterating on incoming messages no longer raises an exception when the + connection terminates with code 1001 (going away). + * :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. diff --git a/websockets/protocol.py b/websockets/protocol.py index 46ce54767..563594bb4 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -59,9 +59,9 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): await process(message) The iterator yields incoming messages. It exits normally when the - connection is closed with the status code 1000 OK. It raises a - :exc:`~websockets.exceptions.ConnectionClosed` exception when the - connection is closed with any other status code. + connection is closed with the status code 1000 (OK) or 1001 (going away). + It raises a :exc:`~websockets.exceptions.ConnectionClosed` exception when + the connection is closed with any other status code. The ``host``, ``port`` and ``secure`` parameters are simply stored as attributes for handlers that need them. diff --git a/websockets/py36/_test_client_server.py b/websockets/py36/_test_client_server.py index 0bd0d8938..e81fbd600 100644 --- a/websockets/py36/_test_client_server.py +++ b/websockets/py36/_test_client_server.py @@ -55,7 +55,7 @@ async def run_client(): server.close() self.loop.run_until_complete(server.wait_closed()) - def test_iterate_on_messages_exit_not_ok(self): + def test_iterate_on_messages_going_away_exit_ok(self): async def handler(ws, path): for message in MESSAGES: @@ -67,6 +67,31 @@ async def handler(ws, path): messages = [] + async def run_client(): + nonlocal messages + async with connect(get_server_uri(server)) as ws: + async for message in ws: + messages.append(message) + + self.loop.run_until_complete(run_client()) + + self.assertEqual(messages, MESSAGES) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + + def test_iterate_on_messages_internal_error_exit_not_ok(self): + + async def handler(ws, path): + for message in MESSAGES: + await ws.send(message) + await ws.close(1011) + + start_server = serve(handler, 'localhost', 0) + server = self.loop.run_until_complete(start_server) + + messages = [] + async def run_client(): nonlocal messages async with connect(get_server_uri(server)) as ws: diff --git a/websockets/py36/protocol.py b/websockets/py36/protocol.py index 919f9a038..f0784de05 100644 --- a/websockets/py36/protocol.py +++ b/websockets/py36/protocol.py @@ -14,7 +14,7 @@ async def __aiter__(self): while True: yield await self.recv() except ConnectionClosed as exc: - if exc.code == 1000: + if exc.code == 1000 or exc.code == 1001: return else: raise From 97c81dfaaa522861c7849a76caff2c790a9b1d02 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 May 2018 22:52:53 +0200 Subject: [PATCH 0414/1539] Add an example of synchronization between clients. Fix #395. --- docs/intro.rst | 26 ++++++++++++-- example/counter.html | 80 ++++++++++++++++++++++++++++++++++++++++++++ example/counter.py | 61 +++++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+), 2 deletions(-) create mode 100644 example/counter.html create mode 100755 example/counter.py diff --git a/docs/intro.rst b/docs/intro.rst index 126f9ddbc..d9ba2cbcc 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -96,6 +96,27 @@ Then open this HTML file in a browser. .. literalinclude:: ../example/show_time.html :language: html +Synchronization example +----------------------- + +A WebSocket server can receive events from clients, process them to update the +application state, and synchronize the resulting state across clients. + +Here's an example where any client can increment or decrement a counter. +Updates are propagated to all connected clients. + +The concurrency model of :mod:`asyncio` guarantees that updates are +serialized. + +Run this script in a console: + +.. literalinclude:: ../example/counter.py + +Then open this HTML file in several browsers. + +.. literalinclude:: ../example/counter.html + :language: html + Common patterns --------------- @@ -167,8 +188,9 @@ patterns shown above and running the two tasks in parallel:: Registration ............ -If you need to maintain a list of currently connected clients, you must -register clients when they connect and unregister them when they disconnect. +As shown in the synchronization example above, if you need to maintain a list +of currently connected clients, you must register them when they connect and +unregister them when they disconnect. :: diff --git a/example/counter.html b/example/counter.html new file mode 100644 index 000000000..6310c4a16 --- /dev/null +++ b/example/counter.html @@ -0,0 +1,80 @@ + + + + WebSocket demo + + + +
+
-
+
?
+
+
+
+
+ ? online +
+ + + diff --git a/example/counter.py b/example/counter.py new file mode 100755 index 000000000..9cce009fd --- /dev/null +++ b/example/counter.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +# WS server example that synchronizes state across clients + +import asyncio +import json +import logging +import websockets + +logging.basicConfig() + +STATE = {'value': 0} + +USERS = set() + +def state_event(): + return json.dumps({'type': 'state', **STATE}) + +def users_event(): + return json.dumps({'type': 'users', 'count': len(USERS)}) + +async def notify_state(): + if USERS: # asyncio.wait doesn't accept an empty list + message = state_event() + await asyncio.wait([user.send(message) for user in USERS]) + +async def notify_users(): + if USERS: # asyncio.wait doesn't accept an empty list + message = users_event() + await asyncio.wait([user.send(message) for user in USERS]) + +async def register(websocket): + USERS.add(websocket) + await notify_users() + +async def unregister(websocket): + USERS.remove(websocket) + await notify_users() + +async def counter(websocket, path): + # register(websocket) sends user_event() to websocket + await register(websocket) + try: + await websocket.send(state_event()) + async for message in websocket: + data = json.loads(message) + if data['action'] == 'minus': + STATE['value'] -= 1 + await notify_state() + elif data['action'] == 'plus': + STATE['value'] += 1 + await notify_state() + else: + logging.error( + "unsupported event: {}", data) + finally: + await unregister(websocket) + +asyncio.get_event_loop().run_until_complete( + websockets.serve(counter, 'localhost', 6789)) +asyncio.get_event_loop().run_forever() From 6204d2e90eae9ccc7b8825b84f90b2d09f07233c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 May 2018 22:53:41 +0200 Subject: [PATCH 0415/1539] Fix a few inconsistencies. --- docs/intro.rst | 3 +-- example/health_check_server.py | 0 example/{send_time.py => show_time.py} | 0 3 files changed, 1 insertion(+), 2 deletions(-) mode change 100644 => 100755 example/health_check_server.py rename example/{send_time.py => show_time.py} (100%) diff --git a/docs/intro.rst b/docs/intro.rst index d9ba2cbcc..a39023bb5 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -89,7 +89,7 @@ Here's an example of how to run a WebSocket server and connect from a browser. Run this script in a console: -.. literalinclude:: ../example/send_time.py +.. literalinclude:: ../example/show_time.py Then open this HTML file in a browser. @@ -197,7 +197,6 @@ unregister them when they disconnect. connected = set() async def handler(websocket, path): - global connected # Register. connected.add(websocket) try: diff --git a/example/health_check_server.py b/example/health_check_server.py old mode 100644 new mode 100755 diff --git a/example/send_time.py b/example/show_time.py similarity index 100% rename from example/send_time.py rename to example/show_time.py From 1f6a09591923488355147c8eba8a0de9553a96b0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 May 2018 23:01:50 +0200 Subject: [PATCH 0416/1539] Extend description of Python < 3.6 idioms. --- docs/intro.rst | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index a39023bb5..3eb3505db 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -231,8 +231,9 @@ in Python < 3.6: - ``await`` and ``async`` were added in Python 3.5; - Asynchronous context managers didn't work well until Python 3.5.1; -- f-strings were introduced in Python 3.6 (unrelated to :mod:`asyncio` - :mod:`websockets`). +- Asynchronous iterators were added in Python 3.6; +- f-strings were introduced in Python 3.6 (this is unrelated to :mod:`asyncio` + and :mod:`websockets`). Here's how to adapt the basic server example. @@ -273,5 +274,33 @@ Asynchronous context managers were added in Python 3.5. However, ``websockets`` only supports them on Python ≥ 3.5.1, where :func:`~asyncio.ensure_future` accepts any awaitable. -If you're using Python < 3.5.1, you must rely on ``try: ... finally: ...`` -instead. +If you're using Python < 3.5.1, instead of:: + + with websockets.connect(...) as client: + ... + +you must write:: + + client = yield from websockets.connect(...) + try: + ... + finally: + yield from client.close() + +Asynchronous iterators +...................... + +If you're using Python < 3.6, you must replace:: + + async for message in websocket: + ... + +with:: + + while True: + message = yield from websocket.recv() + ... + +The latter will always raise a :exc:`~exceptions.ConnectionClosed` exception +when the connection is closed, while the former will only raise that exception +if the connection terminates with an error. From cb5943808e36122cfc05d2ca3165860fd5f249a5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 15 May 2018 22:36:12 +0200 Subject: [PATCH 0417/1539] Factor out duplicated logic. --- websockets/headers.py | 164 +++++++++++++++++------------------------- 1 file changed, 65 insertions(+), 99 deletions(-) diff --git a/websockets/headers.py b/websockets/headers.py index 873984f0f..9140ff110 100644 --- a/websockets/headers.py +++ b/websockets/headers.py @@ -93,6 +93,67 @@ def parse_quoted_string(string, pos, header_name): return _unquote_re.sub(r'\1', match.group()[1:-1]), match.end() +def parse_list(parse_item, string, pos, header_name): + """ + Parse a comma-separated list from ``string`` at the given position. + + This is appropriate for parsing values with the following grammar: + + 1#item + + ``parse_item`` parses one item. + + ``string`` is assumed not to start or end with whitespace. + + (This function is designed for parsing an entire header value and + :func:`~websockets.http.read_headers` strips whitespace from values.) + + Return a list of items. + + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + + """ + # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST + # parse and ignore a reasonable number of empty list elements"; hence + # while loops that remove extra delimiters. + + # Remove extra delimiters before the first item. + while peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + + items = [] + while True: + # Loop invariant: a item starts at pos in string. + item, pos = parse_item(string, pos, header_name) + items.append(item) + pos = parse_OWS(string, pos) + + # We may have reached the end of the string. + if pos == len(string): + break + + # There must be a delimiter after each element except the last one. + if peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + else: + raise InvalidHeaderFormat( + header_name, "expected comma", string=string, pos=pos) + + # Remove extra delimiters before the next item. + while peek_ahead(string, pos) == ',': + pos = parse_OWS(string, pos + 1) + + # We may have reached the end of the string. + if pos == len(string): + break + + # Since we only advance in the string by one character with peek_ahead() + # or with the end position of a regex match, we can't overshoot the end. + assert pos == len(string) + + return items + + def parse_extension_param(string, pos, header_name): """ Parse a single extension parameter from ``string`` at the given position. @@ -148,13 +209,10 @@ def parse_extension(string, pos, header_name): return (name, parameters), pos -def parse_extension_list( - string, pos=0, header_name='Sec-WebSocket-Extensions'): +def parse_extension_list(string): """ Parse a ``Sec-WebSocket-Extensions`` header. - The string is assumed not to start or end with whitespace. - Return a value with the following format:: [ @@ -173,44 +231,7 @@ def parse_extension_list( Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST - # parse and ignore a reasonable number of empty list elements"; hence - # while loops that remove extra delimiters. - - # Remove extra delimiters before the first extension. - while peek_ahead(string, pos) == ',': - pos = parse_OWS(string, pos + 1) - - extensions = [] - while True: - # Loop invariant: an extension starts at pos in string. - extension, pos = parse_extension(string, pos, header_name) - extensions.append(extension) - - # We may have reached the end of the string. - if pos == len(string): - break - - # There must be a delimiter after each element except the last one. - if peek_ahead(string, pos) == ',': - pos = parse_OWS(string, pos + 1) - else: - raise InvalidHeaderFormat( - header_name, "expected comma", string=string, pos=pos) - - # Remove extra delimiters before the next extension. - while peek_ahead(string, pos) == ',': - pos = parse_OWS(string, pos + 1) - - # We may have reached the end of the string. - if pos == len(string): - break - - # Since we only advance in the string by one character with peek_ahead() - # or with the end position of a regex match, we can't overshoot the end. - assert pos == len(string) - - return extensions + return parse_list(parse_extension, string, 0, 'Sec-WebSocket-Extensions') def build_extension(name, parameters): @@ -240,69 +261,14 @@ def build_extension_list(extensions): ) -def parse_protocol(string, pos, header_name): - """ - Parse a protocol definition from ``string`` at the given position. - - Return the protocol and the new position. - - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. - - """ - name, pos = parse_token(string, pos, header_name) - pos = parse_OWS(string, pos) - return name, pos - - -def parse_protocol_list(string, pos=0, header_name='Sec-WebSocket-Protocol'): +def parse_protocol_list(string): """ Parse a ``Sec-WebSocket-Protocol`` header. - The string is assumed not to start or end with whitespace. - - Return a list of protocols. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST - # parse and ignore a reasonable number of empty list elements"; hence - # while loops that remove extra delimiters. - - # Remove extra delimiters before the first extension. - while peek_ahead(string, pos) == ',': - pos = parse_OWS(string, pos + 1) - - protocols = [] - while True: - # Loop invariant: a protocol starts at pos in string. - protocol, pos = parse_protocol(string, pos, header_name) - protocols.append(protocol) - - # We may have reached the end of the string. - if pos == len(string): - break - - # There must be a delimiter after each element except the last one. - if peek_ahead(string, pos) == ',': - pos = parse_OWS(string, pos + 1) - else: - raise InvalidHeaderFormat( - header_name, "expected comma", string=string, pos=pos) - - # Remove extra delimiters before the next protocol. - while peek_ahead(string, pos) == ',': - pos = parse_OWS(string, pos + 1) - - # We may have reached the end of the string. - if pos == len(string): - break - - # Since we only advance in the string by one character with peek_ahead() - # or with the end position of a regex match, we can't overshoot the end. - assert pos == len(string) - - return protocols + return parse_list(parse_token, string, 0, 'Sec-WebSocket-Protocol') def build_protocol_list(protocols): From cffbed8f33995e70d80226d73d7b9c3edc8bcfe2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 May 2018 22:08:19 +0200 Subject: [PATCH 0418/1539] Improve parsing of Connection and Upgrade headers. Revert 43f4a6db which was based on a misunderstanding of the ABNF. This change uniformizes header parsing logic and provides better error messages on invalid headers. --- websockets/client.py | 10 ++--- websockets/handshake.py | 21 ++++++---- websockets/headers.py | 53 ++++++++++++++++++++++++-- websockets/http.py | 2 +- websockets/server.py | 4 +- websockets/test_headers.py | 78 +++++++++++++++++++++++++++++++++----- 6 files changed, 139 insertions(+), 29 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index f873903cb..f8f7c7709 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -13,8 +13,8 @@ from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response from .headers import ( - build_extension_list, build_protocol_list, parse_extension_list, - parse_protocol_list + build_extension_list, build_subprotocol_list, parse_extension_list, + parse_subprotocol_list ) from .http import USER_AGENT, basic_auth_header, build_headers, read_response from .protocol import WebSocketCommonProtocol @@ -163,7 +163,7 @@ def process_subprotocol(headers, available_subprotocols): """ Handle the Sec-WebSocket-Protocol HTTP response header. - Check that it contains a supported subprotocol. + Check that it contains exactly one supported subprotocol. Return the selected subprotocol. @@ -178,7 +178,7 @@ def process_subprotocol(headers, available_subprotocols): raise InvalidHandshake("No subprotocols supported") parsed_header_values = sum([ - parse_protocol_list(header_value) + parse_subprotocol_list(header_value) for header_value in header_values ], []) @@ -242,7 +242,7 @@ def handshake(self, wsuri, origin=None, available_extensions=None, set_header('Sec-WebSocket-Extensions', extensions_header) if available_subprotocols is not None: - protocol_header = build_protocol_list(available_subprotocols) + protocol_header = build_subprotocol_list(available_subprotocols) set_header('Sec-WebSocket-Protocol', protocol_header) if extra_headers is not None: diff --git a/websockets/handshake.py b/websockets/handshake.py index 142fef61c..a23428a11 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -40,6 +40,7 @@ import random from .exceptions import InvalidHeaderValue, InvalidUpgrade +from .headers import parse_connection, parse_upgrade __all__ = [ @@ -82,12 +83,14 @@ def check_request(get_header): responsibility of the caller. """ - if get_header('Upgrade').lower() != 'websocket': - raise InvalidUpgrade('Upgrade', get_header('Upgrade')) - - if get_header('Connection').lower() != 'upgrade': + connection = parse_connection(get_header('Connection')) + if not any(value.lower() == 'upgrade' for value in connection): raise InvalidUpgrade('Connection', get_header('Connection')) + upgrade = parse_upgrade(get_header('Upgrade')) + if not (len(upgrade) == 1 and upgrade[0] == 'websocket'): + raise InvalidUpgrade('Upgrade', get_header('Upgrade')) + key = get_header('Sec-WebSocket-Key') try: raw_key = base64.b64decode(key.encode(), validate=True) @@ -131,12 +134,14 @@ def check_response(get_header, key): the caller. """ - if get_header('Upgrade').lower() != 'websocket': - raise InvalidUpgrade('Upgrade', get_header('Upgrade')) - - if get_header('Connection').lower() != 'upgrade': + connection = parse_connection(get_header('Connection')) + if not any(value.lower() == 'upgrade' for value in connection): raise InvalidUpgrade('Connection', get_header('Connection')) + upgrade = parse_upgrade(get_header('Upgrade')) + if not (len(upgrade) == 1 and upgrade[0] == 'websocket'): + raise InvalidUpgrade('Upgrade', get_header('Upgrade')) + if get_header('Sec-WebSocket-Accept') != accept(key): raise InvalidHeaderValue( 'Sec-WebSocket-Accept', get_header('Sec-WebSocket-Accept')) diff --git a/websockets/headers.py b/websockets/headers.py index 9140ff110..a88c975a9 100644 --- a/websockets/headers.py +++ b/websockets/headers.py @@ -13,8 +13,9 @@ __all__ = [ + 'parse_connection', 'parse_upgrade', 'parse_extension_list', 'build_extension_list', - 'parse_protocol_list', 'build_protocol_list', + 'parse_subprotocol_list', 'build_subprotocol_list', ] @@ -154,6 +155,50 @@ def parse_list(parse_item, string, pos, header_name): return items +def parse_connection(string): + """ + Parse a ``Connection`` header. + + Return a list of connection options. + + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + + """ + return parse_list(parse_token, string, 0, 'Connection') + + +_protocol_re = re.compile( + r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?') + + +def parse_protocol(string, pos, header_name): + """ + Parse a protocol from ``string`` at the given position. + + Return the protocol value and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + + """ + match = _protocol_re.match(string, pos) + if match is None: + raise InvalidHeaderFormat( + header_name, "expected protocol", string=string, pos=pos) + return match.group(), match.end() + + +def parse_upgrade(string): + """ + Parse an ``Upgrade`` header. + + Return a list of connection options. + + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + + """ + return parse_list(parse_protocol, string, 0, 'Upgrade') + + def parse_extension_param(string, pos, header_name): """ Parse a single extension parameter from ``string`` at the given position. @@ -261,7 +306,7 @@ def build_extension_list(extensions): ) -def parse_protocol_list(string): +def parse_subprotocol_list(string): """ Parse a ``Sec-WebSocket-Protocol`` header. @@ -271,11 +316,11 @@ def parse_protocol_list(string): return parse_list(parse_token, string, 0, 'Sec-WebSocket-Protocol') -def build_protocol_list(protocols): +def build_subprotocol_list(protocols): """ Unparse a ``Sec-WebSocket-Protocol`` header. - This is the reverse of :func:`parse_protocol_list`. + This is the reverse of :func:`parse_subprotocol_list`. """ return ', '.join(protocols) diff --git a/websockets/http.py b/websockets/http.py index 3fef6d34b..25f32c34e 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -127,7 +127,7 @@ def read_response(stream): # This may raise "ValueError: invalid literal for int() with base 10" status_code = int(status_code) if not 100 <= status_code < 1000: - raise ValueError("Unsupported HTTP status_code code: %d" % status_code) + raise ValueError("Unsupported HTTP status code: %d" % status_code) if not _value_re.fullmatch(reason): raise ValueError("Invalid HTTP reason phrase: %r" % reason) diff --git a/websockets/server.py b/websockets/server.py index 8836b3e93..85dea02e0 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -19,7 +19,7 @@ from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request from .headers import ( - build_extension_list, parse_extension_list, parse_protocol_list + build_extension_list, parse_extension_list, parse_subprotocol_list ) from .http import USER_AGENT, build_headers, read_request from .protocol import WebSocketCommonProtocol @@ -370,7 +370,7 @@ def process_subprotocol(self, headers, available_subprotocols): if header_values is not None and available_subprotocols is not None: parsed_header_values = sum([ - parse_protocol_list(header_value) + parse_subprotocol_list(header_value) for header_value in header_values ], []) diff --git a/websockets/test_headers.py b/websockets/test_headers.py index efb41ae29..ef5f28eda 100644 --- a/websockets/test_headers.py +++ b/websockets/test_headers.py @@ -6,6 +6,66 @@ class HeadersTests(unittest.TestCase): + def test_parse_connection(self): + for header, parsed in [ + # Realistic use cases + ( + 'Upgrade', # Safari, Chrome + ['Upgrade'], + ), + ( + 'keep-alive, Upgrade', # Firefox + ['keep-alive', 'Upgrade'], + ), + # Pathological example + ( + ',,\t, , ,Upgrade ,,', + ['Upgrade'], + ), + ]: + with self.subTest(header=header): + self.assertEqual(parse_connection(header), parsed) + + def test_parse_connection_invalid_header(self): + for header in [ + '???', + 'keep-alive; Upgrade', + ]: + with self.subTest(header=header): + with self.assertRaises(InvalidHeaderFormat): + parse_connection(header) + + def test_parse_upgrade(self): + for header, parsed in [ + # Realistic use case + ( + 'websocket', + ['websocket'], + ), + # Synthetic example + ( + 'http/3.0, websocket', + ['http/3.0', 'websocket'] + ), + # Pathological example + ( + ',, WebSocket, \t,,', + ['WebSocket'], + ), + ]: + with self.subTest(header=header): + self.assertEqual(parse_upgrade(header), parsed) + + def test_parse_upgrade_invalid_header(self): + for header in [ + '???', + 'websocket 2', + 'http/3.0; websocket', + ]: + with self.subTest(header=header): + with self.assertRaises(InvalidHeaderFormat): + parse_upgrade(header) + def test_parse_extension_list(self): for header, parsed in [ # Synthetic examples @@ -26,7 +86,7 @@ def test_parse_extension_list(self): ('bar', [('quux', None), ('quuux', None)]), ], ), - # Pathological examples + # Pathological example ( ',\t, , ,foo ;bar = 42,, baz,,', [('foo', [('bar', '42')]), ('baz', [])], @@ -69,7 +129,7 @@ def test_parse_extension_list_invalid_header(self): with self.assertRaises(InvalidHeaderFormat): parse_extension_list(header) - def test_parse_protocol_list(self): + def test_parse_subprotocol_list(self): for header, parsed in [ # Synthetic examples ( @@ -80,19 +140,19 @@ def test_parse_protocol_list(self): 'foo, bar', ['foo', 'bar'], ), - # Pathological examples + # Pathological example ( ',\t, , ,foo ,, bar,baz,,', ['foo', 'bar', 'baz'], ), ]: with self.subTest(header=header): - self.assertEqual(parse_protocol_list(header), parsed) - # Also ensure that build_protocol_list round-trips cleanly. - unparsed = build_protocol_list(parsed) - self.assertEqual(parse_protocol_list(unparsed), parsed) + self.assertEqual(parse_subprotocol_list(header), parsed) + # Also ensure that build_subprotocol_list round-trips cleanly. + unparsed = build_subprotocol_list(parsed) + self.assertEqual(parse_subprotocol_list(unparsed), parsed) - def test_parse_protocol_list_invalid_header(self): + def test_parse_subprotocol_list_invalid_header(self): for header in [ # Truncated examples '', @@ -102,4 +162,4 @@ def test_parse_protocol_list_invalid_header(self): ]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): - parse_protocol_list(header) + parse_subprotocol_list(header) From 23d208797c43f28bc278160de4309707c184f280 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 May 2018 22:41:17 +0200 Subject: [PATCH 0419/1539] Fix a race condition in the closing handshake. Fix #339. --- docs/changelog.rst | 3 ++ websockets/protocol.py | 72 ++++++++++++++++++++----------------- websockets/test_protocol.py | 15 ++++---- 3 files changed, 52 insertions(+), 38 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index cf48a298a..ab01df5e3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -36,6 +36,9 @@ Also: * Fixed missing close code, which caused :exc:`TypeError` on connection close. +* Fixed a race condition in the closing handshake that raised + :exc:`~exceptions.InvalidState`. + * Stopped logging stack traces when the TCP connection dies prematurely. 4.0 diff --git a/websockets/protocol.py b/websockets/protocol.py index 563594bb4..bee5eaf0b 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -374,20 +374,16 @@ def close(self, code=1000, reason=''): ``code`` must be an :class:`int` and ``reason`` a :class:`str`. """ - if self.state is State.OPEN: - # 7.1.2. Start the WebSocket Closing Handshake - # 7.1.3. The WebSocket Closing Handshake is Started - frame_data = serialize_close(code, reason) - try: - yield from asyncio.wait_for( - self.write_frame(OP_CLOSE, frame_data), - self.timeout, loop=self.loop) - except asyncio.TimeoutError: - # If the close frame cannot be sent because the send buffers - # are full, the closing handshake won't complete anyway. - # Cancel the data transfer task to shut down faster. - # Cancelling a task is idempotent. - self.transfer_data_task.cancel() + try: + yield from asyncio.wait_for( + self.write_close_frame(serialize_close(code, reason)), + self.timeout, loop=self.loop) + except asyncio.TimeoutError: + # If the close frame cannot be sent because the send buffers + # are full, the closing handshake won't complete anyway. + # Cancel the data transfer task to shut down faster. + # Cancelling a task is idempotent. + self.transfer_data_task.cancel() # If no close frame is received within the timeout, wait_for() cancels # the data transfer task and raises TimeoutError. Then transfer_data() @@ -402,7 +398,8 @@ def close(self, code=1000, reason=''): # is cancelled before the timeout elapses (on Python ≥ 3.4.3). # This helps closing connections when shutting down a server. yield from asyncio.wait_for( - self.transfer_data_task, self.timeout, loop=self.loop) + self.transfer_data_task, + self.timeout, loop=self.loop) except asyncio.TimeoutError: pass @@ -611,14 +608,13 @@ def read_data_frame(self, max_size): # 5.5. Control Frames if frame.opcode == OP_CLOSE: - # Make sure the close frame is valid before echoing it. - code, reason = parse_close(frame.data) # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - self.close_code, self.close_reason = code, reason - if self.state is State.OPEN: - # 7.1.3. The WebSocket Closing Handshake is Started - yield from self.write_frame(OP_CLOSE, frame.data) + self.close_code, self.close_reason = parse_close(frame.data) + # Echo the original data instead of re-serializing it with + # serialize_close() because that fails when the close frame is + # empty and parse_close() synthetizes a 1005 close code. + yield from self.write_close_frame(frame.data) return elif frame.opcode == OP_PING: @@ -679,18 +675,12 @@ def read_frame(self, max_size): return frame @asyncio.coroutine - def write_frame(self, opcode, data=b''): + def write_frame(self, opcode, data=b'', _expected_state=State.OPEN): # Defensive assertion for protocol compliance. - if self.state is not State.OPEN: # pragma: no cover + if self.state is not _expected_state: # pragma: no cover raise InvalidState("Cannot write to a WebSocket " "in the {} state".format(self.state.name)) - # Make sure no other frame will be sent after a close frame. Do this - # before yielding control to avoid sending more than one close frame. - if opcode == OP_CLOSE: - self.state = State.CLOSING - logger.debug("%s - state = CLOSING", self.side) - frame = Frame(True, opcode, data) logger.debug("%s > %s", self.side, frame) frame.write( @@ -737,6 +727,25 @@ def writer_is_closing(self): except AttributeError: return transport._closed + @asyncio.coroutine + def write_close_frame(self, data=b''): + """ + Write a close frame if and only if the connection state is OPEN. + + This dedicated coroutine must be used for writing close frames to + ensure that at most one close frame is sent on a given connection. + + """ + # Test and set the connection state before sending the close frame to + # avoid sending two frames in case of concurrent calls. + if self.state is State.OPEN: + # 7.1.3. The WebSocket Closing Handshake is Started + self.state = State.CLOSING + logger.debug("%s - state = CLOSING", self.side) + + # 7.1.2. Start the WebSocket Closing Handshake + yield from self.write_frame(OP_CLOSE, data, State.CLOSING) + @asyncio.coroutine def close_connection(self, after_handshake=True): """ @@ -845,9 +854,8 @@ def fail_connection(self, code=1011, reason=''): self.side, code, reason, ) # Don't send a close frame if the connection is broken. - if self.state is State.OPEN and code != 1006: - frame_data = serialize_close(code, reason) - yield from self.write_frame(OP_CLOSE, frame_data) + if code != 1006: + yield from self.write_close_frame(serialize_close(code, reason)) # asyncio.StreamReaderProtocol methods diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 9048c6418..5303cc9e2 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -796,16 +796,19 @@ def test_remote_close(self): self.assertNoFrameSent() def test_simultaneous_close(self): - # Delay the incoming close frame until after we send the outgoing one. - self.loop.call_later(MS, self.receive_frame, self.remote_close) - self.loop.call_later(MS, self.receive_eof_if_client) + # Receive the incoming close frame right after self.protocol.close() + # starts executing. This reproduces the error described in: + # https://github.com/aaugustin/websockets/issues/339 + self.loop.call_soon(self.receive_frame, self.remote_close) + self.loop.call_soon(self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason='local')) - # The close code and reason are taken from the remote side because - # that's presumably more useful that the values from the local side. self.assertConnectionClosed(1000, 'remote') - self.assertOneFrameSent(*self.local_close) + # The current implementation sends a close frame in response to the + # close frame received from the remote end. It skips the close frame + # that should be sent as a result of calling close(). + self.assertOneFrameSent(*self.remote_close) def test_close_preserves_incoming_frames(self): self.receive_frame(Frame(True, OP_TEXT, b'hello')) From d375a075c16b47166597d24ec88f8f792b3986dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 May 2018 09:00:25 +0200 Subject: [PATCH 0420/1539] Improve check for the connection to be usable. Fix #320. --- docs/changelog.rst | 2 ++ websockets/protocol.py | 15 +++++++++++---- websockets/test_protocol.py | 15 ++++++++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index ab01df5e3..e61cc35eb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -41,6 +41,8 @@ Also: * Stopped logging stack traces when the TCP connection dies prematurely. +* Prevented writing to a closing TCP connection during unclean shutdowns. + 4.0 ... diff --git a/websockets/protocol.py b/websockets/protocol.py index bee5eaf0b..7eee2c03a 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -271,15 +271,15 @@ def open(self): .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp """ - return self.state is State.OPEN + return self.state is State.OPEN and not self.transfer_data_task.done() @property def closed(self): """ This property is ``True`` once the connection is closed. - Be aware that :attr:`open` and :attr`closed` are ``False`` when the - connection is in the OPENING or CLOSING state. + Be aware that both :attr:`open` and :attr`closed` are ``False`` during + the opening and closing sequences. """ return self.state is State.CLOSED @@ -475,7 +475,14 @@ def ensure_open(self): """ # Handle cases from most common to least common for performance. if self.state is State.OPEN: - return + # If self.transfer_data_task exited without a closing handshake, + # self.close_connection_task may be closing it, going straight + # from OPEN to CLOSED. + if self.transfer_data_task.done(): + yield from asyncio.shield(self.close_connection_task) + raise ConnectionClosed(self.close_code, self.close_reason) + else: + return if self.state is State.CLOSED: raise ConnectionClosed(self.close_code, self.close_reason) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 5303cc9e2..a9230df5b 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -728,13 +728,26 @@ def test_connection_lost(self): self.assertConnectionFailed(1006, '') - def test_ensure_connection_before_opening_handshake(self): + def test_ensure_open_before_opening_handshake(self): # Simulate a bug by forcibly reverting the protocol state. self.protocol.state = State.CONNECTING with self.assertRaises(InvalidState): self.loop.run_until_complete(self.protocol.ensure_open()) + def test_ensure_open_during_unclean_close(self): + # Process connection_made in order to start transfer_data_task. + self.run_loop_once() + + # Ensure the test terminates quickly. + self.loop.call_later(MS, self.receive_eof_if_client) + + # Simulate the situation where sending a close frame times out. + self.protocol.transfer_data_task.cancel() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.ensure_open()) + def test_legacy_recv(self): # By default legacy_recv in disabled. self.assertEqual(self.protocol.legacy_recv, False) From a11fa66afa3c60759a06de0cce9a8513f8bad689 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 8 May 2018 23:38:04 +0200 Subject: [PATCH 0421/1539] Improve failing the WebSocket connection. * Make it more compliant with the specification by preventing further processing of incoming frames. * Make it synchronous, which will make it easier to control error handling. * Allow transfer_data_task to be cancelled. * Extend design documentation. --- docs/changelog.rst | 4 ++ docs/design.rst | 72 +++++++++++++++++++------- websockets/protocol.py | 100 ++++++++++++++++++++++++++++-------- websockets/server.py | 4 +- websockets/test_protocol.py | 4 +- 5 files changed, 141 insertions(+), 43 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index e61cc35eb..8fa29373c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -43,6 +43,10 @@ Also: * Prevented writing to a closing TCP connection during unclean shutdowns. +* Made connection termination more robust to network congestion. + +* Prevented processing of incoming frames after failing the connection. + 4.0 ... diff --git a/docs/design.rst b/docs/design.rst index f885b0047..e9ccde912 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -32,16 +32,18 @@ WebSocket connections go through a trivial state machine: Transitions happen in the following places: - ``CONNECTING -> OPEN``: in - :meth:`~protocol.WebSocketCommonProtocol.connection_open()`, which runs - when the :ref:`opening handshake ` completes and the - WebSocket connection is established — not to be confused with - :meth:`~asyncio.Protocol.connection_made` which runs earlier, when the TCP - connection is established; + :meth:`~protocol.WebSocketCommonProtocol.connection_open()` which runs when + the :ref:`opening handshake ` completes and the WebSocket + connection is established — not to be confused with + :meth:`~asyncio.Protocol.connection_made` which runs when the TCP connection + is established; - ``OPEN -> CLOSING``: in :meth:`~protocol.WebSocketCommonProtocol.write_frame()` immediately before sending a close frame; since receiving a close frame triggers sending a close frame, this does the right thing regardless of which side started the - :ref:`closing handshake `; + :ref:`closing handshake `; also in + :meth:`~protocol.WebSocketCommonProtocol.fail_connection()` which duplicates + a few lines of code from `write_close_frame()` and `write_frame()`; - ``* -> CLOSED``: in :meth:`~protocol.WebSocketCommonProtocol.connection_lost()` which is always called exactly once when the TCP connection is closed. @@ -68,14 +70,32 @@ two tasks: - :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs :meth:`~protocol.WebSocketCommonProtocol.transfer_data()` which handles incoming data and lets :meth:`~protocol.WebSocketCommonProtocol.recv()` - consume it. It never exits with an exception but it may be cancelled. - See :ref:`data transfer ` below. + consume it. It may be cancelled to terminate the connection. It never exits + with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data + transfer ` below. + - :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs :meth:`~protocol.WebSocketCommonProtocol.close_connection()` which waits for the data transfer to terminate, then takes care of closing the TCP - connection. It never exits with an exception and it is never cancelled. See + connection. It must not be cancelled. It never exits with an exception. See :ref:`connection termination ` below. +Splitting the responsibilities between two tasks makes it easier to guarantee +that ``websockets`` can terminate connections: + +- within a fixed timeout, +- without leaking pending tasks, +- without leaking open TCP connections, + +regardless of whether the connection terminates normally or abnormally. + +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` completes when no +more data will be received on the connection. Under normal circumstances, it +exits after exchanging close frames. + +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` completes when +the TCP connection is closed. + .. _opening-handshake: @@ -121,9 +141,8 @@ lies in the negociation of extensions and, to a lesser extent, of the subprotocol. The server knows everything about both sides and decides what the parameters should be for the connection. The client merely applies them. -If anything goes wrong during the opening handshake, ``websockets`` closes the -TCP connection. This is the proper way to fail the WebSocket connection before -it's established. +If anything goes wrong during the opening handshake, ``websockets`` +:ref:`fails the connection `. .. _data-transfer: @@ -213,8 +232,8 @@ The right side of the diagram shows how ``websockets`` sends data. containing the message. Fragmentation isn't supported at this time. :meth:`~protocol.WebSocketCommonProtocol.ping()` writes a ping frame and -returns a :class:`~asyncio.Future` which will be completed when a matching -pong frame is received. +yields a :class:`~asyncio.Future` which will be completed when a matching pong +frame is received. :meth:`~protocol.WebSocketCommonProtocol.pong()` writes a pong frame. @@ -243,9 +262,7 @@ state and sends a close frame. When the other side sends a close frame, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's timeout, -:meth:`~protocol.WebSocketCommonProtocol.close()` cancels -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which has the -same effect. +``websockets`` :ref:`fails the connection `. The closing handshake can take up to ``2 * timeout``: one ``timeout`` to write a close frame and one ``timeout`` to receive a close frame. @@ -271,8 +288,8 @@ which may happen as a result of: - a timeout while waiting for the closing handshake to complete: this cancels :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a protocol error, including connection errors: depending on the exception, - :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` fails the - WebSocket connection with a suitable code and exits. + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the + connection `_ with a suitable code and exits. :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` is separate from :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to make it @@ -290,6 +307,23 @@ aborting the connection. At this point the connection drops regardless of what happens on the network. +.. _connection-failure: + +Connection failure +------------------ + +If the opening handshake doesn't complete successfully, ``websockets`` fails +the connection by closing the TCP connection. + +Once the opening handshake has completed, ``websockets`` fails the connection +by cancelling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +sending a close frame if appropriate. + +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which closes +the TCP connection. + + .. _cancellation: Cancellation diff --git a/websockets/protocol.py b/websockets/protocol.py index 7eee2c03a..a073188e5 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -317,6 +317,7 @@ def recv(self): # Wait for a message until the connection is closed. next_message = asyncio_ensure_future( self.messages.get(), loop=self.loop) + # See https://bugs.python.org/issue23859 for cancellation handling. try: done, pending = yield from asyncio.wait( [next_message, self.transfer_data_task], @@ -381,9 +382,8 @@ def close(self, code=1000, reason=''): except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. - # Cancel the data transfer task to shut down faster. - # Cancelling a task is idempotent. - self.transfer_data_task.cancel() + # Fail the connection to shut down faster. + self.fail_connection() # If no close frame is received within the timeout, wait_for() cancels # the data transfer task and raises TimeoutError. Then transfer_data() @@ -400,7 +400,7 @@ def close(self, code=1000, reason=''): yield from asyncio.wait_for( self.transfer_data_task, self.timeout, loop=self.loop) - except asyncio.TimeoutError: + except (asyncio.TimeoutError, asyncio.CancelledError): pass # Wait for the close connection task to close the TCP connection. @@ -516,25 +516,35 @@ def transfer_data(self): if msg is None: break yield from self.messages.put(msg) + except asyncio.CancelledError: - # This happens if self.close() cancels self.transfer_data_task. - pass + # If fail_connection() cancels this task, avoid logging the error + # twice and failing the connection again. + raise + except WebSocketProtocolError: - yield from self.fail_connection(1002) + self.fail_connection(1002) + except (ConnectionError, EOFError): # Reading data with self.reader.readexactly may raise: # - most subclasses of ConnectionError if the TCP connection # breaks, is reset, or is aborted; # - IncompleteReadError, a subclass of EOFError, if fewer # bytes are available than requested. - yield from self.fail_connection(1006) + self.fail_connection(1006) + except UnicodeDecodeError: - yield from self.fail_connection(1007) + self.fail_connection(1007) + except PayloadTooBig: - yield from self.fail_connection(1009) + self.fail_connection(1009) + except Exception: - logger.warning("Error in data transfer", exc_info=True) - yield from self.fail_connection(1011) + # This shouldn't happen often because exceptions expected under + # regular circumstances are handled above. If it does, consider + # catching and handling more exceptions. + logger.error("Error in data transfer", exc_info=True) + self.fail_connection(1011) @asyncio.coroutine def read_message(self): @@ -711,7 +721,7 @@ def write_frame(self, opcode, data=b'', _expected_state=State.OPEN): yield from self.writer.drain() except ConnectionError: # Terminate the connection if the socket died. - yield from self.fail_connection(1006) + self.fail_connection() # And raise an exception, since the frame couldn't be sent. raise ConnectionClosed(self.close_code, self.close_reason) @@ -769,7 +779,10 @@ def close_connection(self, after_handshake=True): try: # Wait for the data transfer phase to complete. if after_handshake: - yield from self.transfer_data_task + try: + yield from self.transfer_data_task + except asyncio.CancelledError: + pass # Cancel all pending pings because they'll never receive a pong. for ping in self.pings.values(): @@ -784,7 +797,7 @@ def close_connection(self, after_handshake=True): "%s - cancelled pending ping%s: %s", self.side, plural, pings_hex) - # A client should wait for a TCP Close from the server. + # A client should wait for a TCP close from the server. if self.is_client and after_handshake: if (yield from self.wait_for_connection_lost()): return @@ -850,19 +863,66 @@ def wait_for_connection_lost(self): # and the moment this coroutine resumes running. return self.connection_lost_waiter.done() - @asyncio.coroutine - def fail_connection(self, code=1011, reason=''): + def fail_connection(self, code=1006, reason=''): """ 7.1.7. Fail the WebSocket Connection + This requires: + + 1. Stopping all processing of incoming data, which means cancelling + :attr:`transfer_data_task`. The close code will be 1006 unless a + close frame was received earlier. + + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + + 3. Closing the connection. :meth:`close_connection` takes care of + this once :attr:`transfer_data_task` exits after being cancelled. + + (The specification describes these steps in the opposite order.) + """ + # fail_connection() only supports the case when the opening handshake + # succeeded. Before that, use close_connection(after_handshake=False). + assert self.state is not State.CONNECTING, ( + "fail_connection() doesn't support the CONNECTING state") + logger.debug( "%s ! failing WebSocket connection: %d %s", self.side, code, reason, ) + + # transfer_data_task was started when the opening handshake succeeded. + # cancel() is idempotent and ignored if the task is done already. + self.transfer_data_task.cancel() + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because of + # an error reading from or writing to the network. # Don't send a close frame if the connection is broken. - if code != 1006: - yield from self.write_close_frame(serialize_close(code, reason)) + if code != 1006 and self.state is State.OPEN: + + frame_data = serialize_close(code, reason) + + # Write the close frame without draining the write buffer. + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not drainig the write buffer is acceptable in this context. + + # This duplicates a few lines of code from write_close_frame() + # and write_frame(). + + self.state = State.CLOSING + logger.debug("%s - state = CLOSING", self.side) + + frame = Frame(True, OP_CLOSE, frame_data) + logger.debug("%s > %s", self.side, frame) + frame.write( + self.writer.write, + mask=self.is_client, + extensions=self.extensions, + ) # asyncio.StreamReaderProtocol methods @@ -902,7 +962,7 @@ def eof_received(self): returned to have the same behavior on TLS and plain connections. 3. The websockets protocol has its own closing handshake. Endpoints - close the TCP connection after sending a Close frame. + close the TCP connection after sending a close frame. As a consequence we revert to the previous, more useful behavior. diff --git a/websockets/server.py b/websockets/server.py index 85dea02e0..38f879e35 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -142,10 +142,10 @@ def handler(self): yield from self.ws_handler(self, path) except Exception as exc: if self._is_server_shutting_down(exc): - yield from self.fail_connection(1001) + self.fail_connection(1001) else: logger.error("Error in connection handler", exc_info=True) - yield from self.fail_connection(1011) + self.fail_connection(1011) raise try: diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index a9230df5b..19ec7bca6 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -742,8 +742,8 @@ def test_ensure_open_during_unclean_close(self): # Ensure the test terminates quickly. self.loop.call_later(MS, self.receive_eof_if_client) - # Simulate the situation where sending a close frame times out. - self.protocol.transfer_data_task.cancel() + # Simulate the case when close() times out sending a close frame. + self.protocol.fail_connection() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.ensure_open()) From 0a1b76a01e5ea6c2ea9f9a99e141dc85db9bb4cf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 11:11:47 +0200 Subject: [PATCH 0422/1539] Fail the connection before the opening handshake. Revert 9eefbc25: now close_connection_task is always started, either by connection_open or by fail_connection. This doesn't change the behavior but makes the implementtation more maintainable by better mirroring the specification. --- docs/design.rst | 4 ++++ websockets/client.py | 2 +- websockets/protocol.py | 30 ++++++++++++++++++------------ websockets/server.py | 5 ++--- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/docs/design.rst b/docs/design.rst index e9ccde912..9974c1cb8 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -80,6 +80,10 @@ two tasks: connection. It must not be cancelled. It never exits with an exception. See :ref:`connection termination ` below. +Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection()` starts +the same :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` when +the opening handshake fails, in order to close the TCP connection. + Splitting the responsibilities between two tasks makes it easier to guarantee that ``websockets`` can terminate connections: diff --git a/websockets/client.py b/websockets/client.py index f8f7c7709..92f29e9f5 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -404,7 +404,7 @@ def __await__(self): extra_headers=protocol.extra_headers, ) except Exception: - yield from protocol.close_connection(after_handshake=False) + yield from protocol.fail_connection() raise self.ws_client = protocol diff --git a/websockets/protocol.py b/websockets/protocol.py index a073188e5..869cc6b9e 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -764,7 +764,7 @@ def write_close_frame(self, data=b''): yield from self.write_frame(OP_CLOSE, data, State.CLOSING) @asyncio.coroutine - def close_connection(self, after_handshake=True): + def close_connection(self): """ 7.1.1. Close the WebSocket Connection @@ -772,13 +772,13 @@ def close_connection(self, after_handshake=True): this coroutine in a task. It waits for the data transfer phase to complete then it closes the TCP connection cleanly. - When the opening handshake fails, the client or the server runs this - coroutine with ``after_handshake=False`` to close the TCP connection. + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. """ try: # Wait for the data transfer phase to complete. - if after_handshake: + if self.transfer_data_task is not None: try: yield from self.transfer_data_task except asyncio.CancelledError: @@ -798,7 +798,7 @@ def close_connection(self, after_handshake=True): self.side, plural, pings_hex) # A client should wait for a TCP close from the server. - if self.is_client and after_handshake: + if self.is_client and self.transfer_data_task is not None: if (yield from self.wait_for_connection_lost()): return logger.debug( @@ -881,20 +881,19 @@ def fail_connection(self, code=1006, reason=''): (The specification describes these steps in the opposite order.) - """ - # fail_connection() only supports the case when the opening handshake - # succeeded. Before that, use close_connection(after_handshake=False). - assert self.state is not State.CONNECTING, ( - "fail_connection() doesn't support the CONNECTING state") + Return a :class:`~asyncio.Task` that completes when the TCP connection + is closed. + """ logger.debug( "%s ! failing WebSocket connection: %d %s", self.side, code, reason, ) - # transfer_data_task was started when the opening handshake succeeded. + # Cancel transfer_data_task if the opening handshake succeeded. # cancel() is idempotent and ignored if the task is done already. - self.transfer_data_task.cancel() + if self.transfer_data_task is not None: + self.transfer_data_task.cancel() # Send a close frame when the state is OPEN (a close frame was already # sent if it's CLOSING), except when failing the connection because of @@ -924,6 +923,13 @@ def fail_connection(self, code=1006, reason=''): extensions=self.extensions, ) + # Start close_connection_task if the opening handshake didn't succeed. + if self.close_connection_task is None: + self.close_connection_task = asyncio_ensure_future( + self.close_connection(), loop=self.loop) + + return self.close_connection_task + # asyncio.StreamReaderProtocol methods def connection_made(self, transport): diff --git a/websockets/server.py b/websockets/server.py index 38f879e35..6faa9197a 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -134,7 +134,7 @@ def handler(self): ) yield from self.write_http_response(*early_response) - yield from self.close_connection(after_handshake=False) + yield from self.fail_connection() return @@ -585,8 +585,7 @@ def wait_closed(self): yield from asyncio.wait( [websocket.handler_task for websocket in self.websockets] + [websocket.close_connection_task - for websocket in self.websockets - if websocket.close_connection_task], + for websocket in self.websockets], loop=self.loop) yield from self.server.wait_closed() From a28d14763adcd3fc97985753a32ada831d7f3f2e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 13:52:08 +0200 Subject: [PATCH 0423/1539] Don't call fail_connection on closed connections. This change prevents spurious debug logs. fail_connection didn't do anything in the CLOSED state anyway. --- websockets/server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/websockets/server.py b/websockets/server.py index 6faa9197a..8db048282 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -142,10 +142,12 @@ def handler(self): yield from self.ws_handler(self, path) except Exception as exc: if self._is_server_shutting_down(exc): - self.fail_connection(1001) + if not self.closed: + self.fail_connection(1001) else: logger.error("Error in connection handler", exc_info=True) - self.fail_connection(1011) + if not self.closed: + self.fail_connection(1011) raise try: From b32a44a998708d6817b8ea5a484bfcf6698fef66 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 12 May 2018 13:50:37 +0200 Subject: [PATCH 0424/1539] Report cause of ConnectionClosed exception. This makes it much easier to figure out why the connection drops, all the more since the close code is always 1006 when failing the WebSocket connection. Fix #368. --- docs/changelog.rst | 2 ++ websockets/protocol.py | 33 +++++++++++++++++++++++---------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 8fa29373c..c4fc2898c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -32,6 +32,8 @@ Also: * If a :meth:`~protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, it's cancelled when the connection is closed. +* Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. + * Updated documentation with new features from Python 3.6. * Fixed missing close code, which caused :exc:`TypeError` on connection close. diff --git a/websockets/protocol.py b/websockets/protocol.py index 869cc6b9e..e6a0bef4c 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -198,6 +198,9 @@ def __init__(self, *, # Task running the data transfer. self.transfer_data_task = None + # Exception that occurred during data transfer, if any. + self.transfer_data_exc = None + # Task closing the TCP connection. self.close_connection_task = None @@ -485,7 +488,8 @@ def ensure_open(self): return if self.state is State.CLOSED: - raise ConnectionClosed(self.close_code, self.close_reason) + raise ConnectionClosed( + self.close_code, self.close_reason) from self.transfer_data_exc if self.state is State.CLOSING: # If we started the closing handshake, wait for its completion to @@ -495,7 +499,8 @@ def ensure_open(self): # that case self.close_connection_task will complete even faster. if self.close_code is None: yield from asyncio.shield(self.close_connection_task) - raise ConnectionClosed(self.close_code, self.close_reason) + raise ConnectionClosed( + self.close_code, self.close_reason) from self.transfer_data_exc # Control may only reach this point in buggy third-party subclasses. assert self.state is State.CONNECTING @@ -517,33 +522,40 @@ def transfer_data(self): break yield from self.messages.put(msg) - except asyncio.CancelledError: + except asyncio.CancelledError as exc: + self.transfer_data_exc = exc # If fail_connection() cancels this task, avoid logging the error # twice and failing the connection again. raise - except WebSocketProtocolError: + except WebSocketProtocolError as exc: + self.transfer_data_exc = exc self.fail_connection(1002) - except (ConnectionError, EOFError): + except (ConnectionError, EOFError) as exc: # Reading data with self.reader.readexactly may raise: # - most subclasses of ConnectionError if the TCP connection # breaks, is reset, or is aborted; # - IncompleteReadError, a subclass of EOFError, if fewer # bytes are available than requested. + self.transfer_data_exc = exc self.fail_connection(1006) - except UnicodeDecodeError: + except UnicodeDecodeError as exc: + self.transfer_data_exc = exc self.fail_connection(1007) - except PayloadTooBig: + except PayloadTooBig as exc: + self.transfer_data_exc = exc self.fail_connection(1009) - except Exception: + except Exception as exc: # This shouldn't happen often because exceptions expected under # regular circumstances are handled above. If it does, consider # catching and handling more exceptions. logger.error("Error in data transfer", exc_info=True) + + self.transfer_data_exc = exc self.fail_connection(1011) @asyncio.coroutine @@ -722,8 +734,9 @@ def write_frame(self, opcode, data=b'', _expected_state=State.OPEN): except ConnectionError: # Terminate the connection if the socket died. self.fail_connection() - # And raise an exception, since the frame couldn't be sent. - raise ConnectionClosed(self.close_code, self.close_reason) + # Wait until the connection is closed to raise ConnectionClosed + # with the correct code and reason. + yield from self.ensure_open() def writer_is_closing(self): """ From 595e3f052e2004313bfda4443e7e61d836e24f34 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 May 2018 11:58:28 +0200 Subject: [PATCH 0425/1539] Add missing entries in the changelog. --- docs/changelog.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index c4fc2898c..908faba1c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,6 +24,9 @@ Also: * Iterating on incoming messages no longer raises an exception when the connection terminates with code 1001 (going away). +* A plain HTTP request now receives a 426 Upgrade Required response and + doesn't log a stack trace. + * :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. @@ -34,8 +37,12 @@ Also: * Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. +* Added new examples in the documentation. + * Updated documentation with new features from Python 3.6. +* Improved several other sections of the documentation. + * Fixed missing close code, which caused :exc:`TypeError` on connection close. * Fixed a race condition in the closing handshake that raised From 7da5f40a65fc6003b3d2457d042c07443a8256d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 May 2018 11:43:02 +0200 Subject: [PATCH 0426/1539] Declare __await__ as coroutine. Python 3.7 requires this. --- docs/changelog.rst | 4 ++++ websockets/client.py | 1 + websockets/server.py | 1 + 3 files changed, 6 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 908faba1c..1d614129e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -18,6 +18,8 @@ Changelog Also: +* Added compatibility with Python 3.7. + * :func:`~client.connect()` performs HTTP Basic Auth when the URI contains credentials. @@ -126,6 +128,8 @@ Also: 3.3 ... +* Ensured compatibility with Python 3.6. + * Reduced noise in logs caused by connection resets. * Avoided crashing on concurrent writes on slow connections. diff --git a/websockets/client.py b/websockets/client.py index 92f29e9f5..3a810173c 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -393,6 +393,7 @@ def __aenter__(self): def __aexit__(self, exc_type, exc_value, traceback): yield from self.ws_client.close() + @asyncio.coroutine def __await__(self): transport, protocol = yield from self._creating_connection diff --git a/websockets/server.py b/websockets/server.py index 8db048282..12c6514e1 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -738,6 +738,7 @@ def __aexit__(self, exc_type, exc_value, traceback): self.ws_server.close() yield from self.ws_server.wait_closed() + @asyncio.coroutine def __await__(self): server = yield from self._creating_server self.ws_server.wrap(server) From 2b89213dc3a34c98d19f87cd4504046a656fe05c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 May 2018 11:41:33 +0200 Subject: [PATCH 0427/1539] Ignore a warning added in Python 3.7. Also set explicitly the warnings filter to default when running tests to prevent unittest from stomping on our ignore filter. --- .appveyor.yml | 2 +- .travis.yml | 2 +- Makefile | 5 ++--- tox.ini | 4 ++-- websockets/protocol.py | 11 +++++++++++ 5 files changed, 17 insertions(+), 7 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index 31e48fe2d..41ea07d99 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,7 +1,7 @@ environment: # websockets only works on Python >= 3.4. CIBW_SKIP: cp27-* cp33-* - CIBW_TEST_COMMAND: python -m unittest websockets + CIBW_TEST_COMMAND: python -W default -m unittest websockets WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 install: diff --git a/.travis.yml b/.travis.yml index da1ab8b01..d9460952e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,7 @@ env: global: # websockets only works on Python >= 3.4. - CIBW_SKIP="cp27-* cp33-*" - - CIBW_TEST_COMMAND="python3 -m unittest websockets" + - CIBW_TEST_COMMAND="python3 -W default -m unittest websockets" - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 matrix: diff --git a/Makefile b/Makefile index fd24aa7ff..db31c68d5 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,11 @@ export PYTHONASYNCIODEBUG=1 -export PYTHONWARNINGS=default test: - python -m unittest + python -W default -m unittest coverage: python -m coverage erase - python -m coverage run --branch --source=websockets -m unittest + python -W default -m coverage run --branch --source=websockets -m unittest python -m coverage html clean: diff --git a/tox.ini b/tox.ini index c4123c20a..2eaf4fa48 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ commands = ; Before testing with speedups, compile the extension. speedups: python setup.py --quiet build_ext --inplace - python -m unittest {posargs} + python -W default -m unittest {posargs} ; After testing with speedups, remove the extension. speedups: sh -c 'rm websockets/*.so' @@ -27,7 +27,7 @@ commands = python setup.py --quiet build_ext --inplace python -m coverage erase - python -m coverage run --branch --source=websockets -m unittest + python -W default -m coverage run --branch --source=websockets -m unittest python -m coverage report --show-missing --fail-under=100 speedups: sh -c 'rm websockets/*.so' diff --git a/websockets/protocol.py b/websockets/protocol.py index e6a0bef4c..dbc99518e 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -15,6 +15,7 @@ import logging import random import struct +import warnings from .compatibility import asyncio_ensure_future from .exceptions import ( @@ -29,6 +30,16 @@ logger = logging.getLogger(__name__) +# On Python ≥ 3.7, silence a deprecation warning that we can't address before +# dropping support for Python < 3.5. +warnings.filterwarnings( + action='ignore', + message=r"'with \(yield from lock\)' is deprecated " + r"use 'async with lock' instead", + category=DeprecationWarning, +) + + # A WebSocket connection goes through the following four states, in order: class State(enum.IntEnum): From b6a25ceb3555d0ba69e5961b8d7616e4c1aecb2b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 8 May 2018 11:57:12 +0200 Subject: [PATCH 0428/1539] Security fix: defend against zip bombs. --- docs/changelog.rst | 8 ++++++++ websockets/extensions/base.py | 2 +- websockets/extensions/permessage_deflate.py | 15 +++++++++++---- websockets/extensions/test_permessage_deflate.py | 14 +++++++++++++- websockets/framing.py | 4 ++-- websockets/test_client_server.py | 2 +- websockets/test_framing.py | 4 +++- 7 files changed, 39 insertions(+), 10 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 1d614129e..af721068f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,14 @@ Changelog *In development* +.. note:: + + **Version 5.0 fixes a security issue introduced in version 4.0.** + + websockets 4.0 was vulnerable to denial of service by memory exhaustion + because it didn't enforce ``max_size`` when decompressing compressed + messages. + .. warning:: **Version 5.0 adds a** ``user_info`` **field to the return value of** diff --git a/websockets/extensions/base.py b/websockets/extensions/base.py index 3ec7c4321..1888f52fc 100644 --- a/websockets/extensions/base.py +++ b/websockets/extensions/base.py @@ -72,7 +72,7 @@ class Extension: """ name = ... - def decode(self, frame): + def decode(self, frame, *, max_size=None): """ Decode an incoming frame. diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index a5a0cb857..7ca911a2d 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -8,7 +8,7 @@ from ..exceptions import ( DuplicateParameter, InvalidParameterName, InvalidParameterValue, - NegotiationError + NegotiationError, PayloadTooBig ) from ..framing import CTRL_OPCODES, OP_CONT @@ -463,7 +463,7 @@ def __repr__(self): self.local_max_window_bits), ])) - def decode(self, frame): + def decode(self, frame, *, max_size=None): """ Decode an incoming frame. @@ -495,11 +495,18 @@ def decode(self, frame): self.decoder = zlib.decompressobj( wbits=-self.remote_max_window_bits) - # Uncompress compressed frames. + # Uncompress compressed frames. Protect against zip bombs by + # preventing zlib from decompressing more than max_length bytes + # (except when the limit is disabled with max_size = None). data = frame.data if frame.fin: data += _EMPTY_UNCOMPRESSED_BLOCK - data = self.decoder.decompress(data) + max_length = 0 if max_size is None else max_size + data = self.decoder.decompress(data, max_length) + if self.decoder.unconsumed_tail: + raise PayloadTooBig( + "Uncompressed payload length exceeds size limit (? > {} bytes)" + .format(max_size)) # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: diff --git a/websockets/extensions/test_permessage_deflate.py b/websockets/extensions/test_permessage_deflate.py index 8fbcdee58..e4afcec39 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -3,7 +3,7 @@ from ..exceptions import ( DuplicateParameter, InvalidParameterName, InvalidParameterValue, - NegotiationError + NegotiationError, PayloadTooBig ) from ..framing import ( OP_BINARY, OP_CLOSE, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame, @@ -835,3 +835,15 @@ def test_compress_settings(self): rsv1=True, data=b'\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00', # not compressed )) + + # Frames aren't decoded beyond max_length. + + def test_decompress_max_size(self): + frame = Frame(True, OP_TEXT, ('a' * 20).encode('utf-8')) + + enc_frame = self.extension.encode(frame) + + self.assertEqual(enc_frame.data, b'JL\xc4\x04\x00\x00') + + with self.assertRaises(PayloadTooBig): + self.extension.decode(enc_frame, max_size=10) diff --git a/websockets/framing.py b/websockets/framing.py index 8236017da..b1b655b28 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -119,7 +119,7 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): length, = struct.unpack('!Q', data) if max_size is not None and length > max_size: raise PayloadTooBig( - "Payload length exceeds limit: {} > {} bytes" + "Payload length exceeds size limit ({} > {} bytes)" .format(length, max_size)) if mask: mask_bits = yield from reader(4) @@ -134,7 +134,7 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): if extensions is None: extensions = [] for extension in reversed(extensions): - frame = extension.decode(frame) + frame = extension.decode(frame, max_size=max_size) frame.check() diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 67424dfb3..8476913fe 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -204,7 +204,7 @@ class NoOpExtension: def __repr__(self): return 'NoOpExtension()' - def decode(self, frame): + def decode(self, frame, *, max_size=None): return frame def encode(self, frame): diff --git a/websockets/test_framing.py b/websockets/test_framing.py index ba14603b1..d550f7268 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -217,7 +217,9 @@ def encode(frame): return frame._replace(data=data) # This extensions is symmetrical. - decode = encode + @staticmethod + def decode(frame, *, max_size=None): + return Rot13.encode(frame) self.round_trip( b'\x81\x05uryyb', From 414a51414d37bfa51d9e533bde39f545c7fe9041 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 May 2018 12:48:00 +0200 Subject: [PATCH 0429/1539] Fix regression from b07b8895. --- websockets/handshake.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/websockets/handshake.py b/websockets/handshake.py index a23428a11..fedc6c502 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -88,7 +88,9 @@ def check_request(get_header): raise InvalidUpgrade('Connection', get_header('Connection')) upgrade = parse_upgrade(get_header('Upgrade')) - if not (len(upgrade) == 1 and upgrade[0] == 'websocket'): + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. It's supposed to be 'WebSocket'. + if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): raise InvalidUpgrade('Upgrade', get_header('Upgrade')) key = get_header('Sec-WebSocket-Key') @@ -139,7 +141,9 @@ def check_response(get_header, key): raise InvalidUpgrade('Connection', get_header('Connection')) upgrade = parse_upgrade(get_header('Upgrade')) - if not (len(upgrade) == 1 and upgrade[0] == 'websocket'): + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. It's supposed to be 'WebSocket'. + if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): raise InvalidUpgrade('Upgrade', get_header('Upgrade')) if get_header('Sec-WebSocket-Accept') != accept(key): From c08a3186b750be3fdfbe40c7f92b45e026e563e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 22 May 2018 07:19:55 +0200 Subject: [PATCH 0430/1539] Bump version number. --- docs/changelog.rst | 5 ++++- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index af721068f..76073f809 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,14 @@ Changelog .. currentmodule:: websockets -5.0 +5.1 ... *In development* +5.0 +... + .. note:: **Version 5.0 fixes a security issue introduced in version 4.0.** diff --git a/docs/conf.py b/docs/conf.py index f68a0de28..6a7f02b90 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '4.0' +version = '5.0' # The full version, including alpha/beta/rc tags. -release = '4.0' +release = '5.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index dc4e88832..d4bf36377 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '4.0.1' +version = '5.0' From 5b991fb441fc7626d783a813609022e670397132 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 24 May 2018 12:39:39 +0200 Subject: [PATCH 0431/1539] Revert "Declare __await__ as coroutine." This reverts commit 7da5f40a65fc6003b3d2457d042c07443a8256d1. Fix #411. --- docs/changelog.rst | 4 ---- websockets/client.py | 1 - websockets/server.py | 1 - 3 files changed, 6 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 76073f809..5399d4bc9 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -29,8 +29,6 @@ Changelog Also: -* Added compatibility with Python 3.7. - * :func:`~client.connect()` performs HTTP Basic Auth when the URI contains credentials. @@ -139,8 +137,6 @@ Also: 3.3 ... -* Ensured compatibility with Python 3.6. - * Reduced noise in logs caused by connection resets. * Avoided crashing on concurrent writes on slow connections. diff --git a/websockets/client.py b/websockets/client.py index 3a810173c..92f29e9f5 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -393,7 +393,6 @@ def __aenter__(self): def __aexit__(self, exc_type, exc_value, traceback): yield from self.ws_client.close() - @asyncio.coroutine def __await__(self): transport, protocol = yield from self._creating_connection diff --git a/websockets/server.py b/websockets/server.py index 12c6514e1..8db048282 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -738,7 +738,6 @@ def __aexit__(self, exc_type, exc_value, traceback): self.ws_server.close() yield from self.ws_server.wait_closed() - @asyncio.coroutine def __await__(self): server = yield from self._creating_server self.ws_server.wrap(server) From 458169b1c8ce692633d9bf90e8474c249936d49d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 24 May 2018 13:16:36 +0200 Subject: [PATCH 0432/1539] Bump version number. --- websockets/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websockets/version.py b/websockets/version.py index d4bf36377..cd2999a7e 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '5.0' +version = '5.0.1' From 68371dd26ed08cc7839987ca5c77e9e03248ca58 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 May 2018 12:49:04 +0200 Subject: [PATCH 0433/1539] Autobahn testsuite issue 77 was fixed. --- compliance/README.rst | 5 +---- compliance/fuzzingclient.json | 2 +- compliance/fuzzingserver.json | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index c9605f675..cbb4ca2c7 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -2,7 +2,7 @@ Autobahn Testsuite ================== General information and installation instructions are available at -http://autobahn.ws/testsuite. +https://github.com/crossbario/autobahn-testsuite. To improve performance, you should compile the C extension first:: @@ -48,6 +48,3 @@ In 6.4.3 and 6.4.4, even though it uses an incremental decoder, ``websockets`` doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests are more strict than the RFC. -12.4.* are skipped: https://github.com/crossbario/autobahn-testsuite/issues/77 - -12.5.* are skipped: https://github.com/crossbario/autobahn-testsuite/issues/77 diff --git a/compliance/fuzzingclient.json b/compliance/fuzzingclient.json index c572d02e8..202ff49a0 100644 --- a/compliance/fuzzingclient.json +++ b/compliance/fuzzingclient.json @@ -6,6 +6,6 @@ "servers": [{"agent": "websockets", "url": "ws://localhost:8642", "options": {"version": 18}}], "cases": ["*"], - "exclude-cases": ["12.4.*", "12.5.*"], + "exclude-cases": [], "exclude-agent-cases": {} } diff --git a/compliance/fuzzingserver.json b/compliance/fuzzingserver.json index d7abd94c1..1bdb42723 100644 --- a/compliance/fuzzingserver.json +++ b/compliance/fuzzingserver.json @@ -7,6 +7,6 @@ "webport": 8080, "cases": ["*"], - "exclude-cases": ["12.4.*", "12.5.*"], + "exclude-cases": [], "exclude-agent-cases": {} } From 6f8f1c877744623f0a5df5917a85b97807bfb7e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 24 May 2018 22:29:12 +0200 Subject: [PATCH 0434/1539] Add support for Python 3.7. Hopefully for real this time. This is annoyingly complicated. Fix #405. --- docs/changelog.rst | 4 +++ websockets/client.py | 24 +++++++---------- websockets/py35/_test_client_server.py | 37 ++++++++++++++++++++++++++ websockets/py35/client.py | 33 +++++++++++++++++++++++ websockets/py35/server.py | 22 +++++++++++++++ websockets/server.py | 25 +++++++---------- websockets/test_client_server.py | 1 + 7 files changed, 115 insertions(+), 31 deletions(-) create mode 100644 websockets/py35/client.py create mode 100644 websockets/py35/server.py diff --git a/docs/changelog.rst b/docs/changelog.rst index 5399d4bc9..7712e38f3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,8 @@ Changelog *In development* +* Added compatibility with Python 3.7. + 5.0 ... @@ -137,6 +139,8 @@ Also: 3.3 ... +* Ensured compatibility with Python 3.6. + * Reduced noise in logs caused by connection resets. * Avoided crashing on concurrent writes on slow connections. diff --git a/websockets/client.py b/websockets/client.py index 92f29e9f5..a86b90f81 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -385,15 +385,7 @@ def __init__(self, uri, *, self._creating_connection = loop.create_connection( factory, host, port, **kwds) - @asyncio.coroutine - def __aenter__(self): - return (yield from self) - - @asyncio.coroutine - def __aexit__(self, exc_type, exc_value, traceback): - yield from self.ws_client.close() - - def __await__(self): + def __iter__(self): # pragma: no cover transport, protocol = yield from self._creating_connection try: @@ -410,17 +402,19 @@ def __await__(self): self.ws_client = protocol return protocol - __iter__ = __await__ - -# Disable asynchronous context manager functionality only on Python < 3.5.1 -# because it doesn't exist on Python < 3.5 and asyncio.ensure_future didn't -# accept arbitrary awaitables in Python 3.5; that was fixed in Python 3.5.1. +# We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future +# didn't accept arbitrary awaitables until Python 3.5.1. We don't define +# __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover @asyncio.coroutine def connect(*args, **kwds): - return Connect(*args, **kwds).__await__() + return Connect(*args, **kwds).__iter__() connect.__doc__ = Connect.__doc__ else: + from .py35.client import __aenter__, __aexit__, __await__ + Connect.__aenter__ = __aenter__ + Connect.__aexit__ = __aexit__ + Connect.__await__ = __await__ connect = Connect diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py index 437524885..5360d8d01 100644 --- a/websockets/py35/_test_client_server.py +++ b/websockets/py35/_test_client_server.py @@ -13,6 +13,43 @@ from ..test_client_server import get_server_uri, handler +class AsyncAwaitTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_client(self): + start_server = serve(handler, 'localhost', 0) + server = self.loop.run_until_complete(start_server) + + async def run_client(): + # Await connect. + client = await connect(get_server_uri(server)) + self.assertEqual(client.state, State.OPEN) + await client.close() + self.assertEqual(client.state, State.CLOSED) + + self.loop.run_until_complete(run_client()) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + + def test_server(self): + async def run_server(): + # Await serve. + server = await serve(handler, 'localhost', 0) + self.assertTrue(server.sockets) + server.close() + await server.wait_closed() + self.assertFalse(server.sockets) + + self.loop.run_until_complete(run_server()) + + class ContextManagerTests(unittest.TestCase): def setUp(self): diff --git a/websockets/py35/client.py b/websockets/py35/client.py new file mode 100644 index 000000000..7673ea3ad --- /dev/null +++ b/websockets/py35/client.py @@ -0,0 +1,33 @@ +async def __aenter__(self): + return await self + + +async def __aexit__(self, exc_type, exc_value, traceback): + await self.ws_client.close() + + +async def __await_impl__(self): + # Duplicated with __iter__ because Python 3.7 requires an async function + # (as explained in __await__ below) which Python 3.4 doesn't support. + transport, protocol = await self._creating_connection + + try: + await protocol.handshake( + self._wsuri, origin=self._origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + except Exception: + await protocol.fail_connection() + raise + + self.ws_client = protocol + return protocol + + +def __await__(self): + # __await__() must return a type that I don't know how to obtain except + # by calling __await__() on the return value of an async function. + # I'm not finding a better way to take advantage of PEP 492. + return __await_impl__(self).__await__() diff --git a/websockets/py35/server.py b/websockets/py35/server.py new file mode 100644 index 000000000..41a3675e3 --- /dev/null +++ b/websockets/py35/server.py @@ -0,0 +1,22 @@ +async def __aenter__(self): + return await self + + +async def __aexit__(self, exc_type, exc_value, traceback): + self.ws_server.close() + await self.ws_server.wait_closed() + + +async def __await_impl__(self): + # Duplicated with __iter__ because Python 3.7 requires an async function + # (as explained in __await__ below) which Python 3.4 doesn't support. + server = await self._creating_server + self.ws_server.wrap(server) + return self.ws_server + + +def __await__(self): + # __await__() must return a type that I don't know how to obtain except + # by calling __await__() on the return value of an async function. + # I'm not finding a better way to take advantage of PEP 492. + return __await_impl__(self).__await__() diff --git a/websockets/server.py b/websockets/server.py index 8db048282..46c80dc6d 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -729,22 +729,11 @@ def __init__(self, ws_handler, host=None, port=None, *, self._creating_server = creating_server self.ws_server = ws_server - @asyncio.coroutine - def __aenter__(self): - return (yield from self) - - @asyncio.coroutine - def __aexit__(self, exc_type, exc_value, traceback): - self.ws_server.close() - yield from self.ws_server.wait_closed() - - def __await__(self): + def __iter__(self): # pragma: no cover server = yield from self._creating_server self.ws_server.wrap(server) return self.ws_server - __iter__ = __await__ - def unix_serve(ws_handler, path, **kwargs): """ @@ -761,14 +750,18 @@ def unix_serve(ws_handler, path, **kwargs): return serve(ws_handler, path=path, **kwargs) -# Disable asynchronous context manager functionality only on Python < 3.5.1 -# because it doesn't exist on Python < 3.5 and asyncio.ensure_future didn't -# accept arbitrary awaitables in Python 3.5; that was fixed in Python 3.5.1. +# We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future +# didn't accept arbitrary awaitables until Python 3.5.1. We don't define +# __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover @asyncio.coroutine def serve(*args, **kwds): - return Serve(*args, **kwds).__await__() + return Serve(*args, **kwds).__iter__() serve.__doc__ = Serve.__doc__ else: + from .py35.server import __aenter__, __aexit__, __await__ + Serve.__aenter__ = __aenter__ + Serve.__aexit__ = __aexit__ + Serve.__await__ = __await__ serve = Serve diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 8476913fe..27a2a719b 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1057,6 +1057,7 @@ def test_checking_lack_of_origin_succeeds(self): try: + from .py35._test_client_server import AsyncAwaitTests # noqa from .py35._test_client_server import ContextManagerTests # noqa except (SyntaxError, ImportError): # pragma: no cover pass From ada2987ddf2eccbb36a6ead0a5936ba0ed397032 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 5 Jun 2018 22:08:55 +0200 Subject: [PATCH 0435/1539] Replace conditional errors with version checks. This avoids silently ignoring tests instead of failing them in case of mistakes. Fix #415. --- websockets/protocol.py | 6 ++---- websockets/test_client_server.py | 8 ++------ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index dbc99518e..66939aa83 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -15,6 +15,7 @@ import logging import random import struct +import sys import warnings from .compatibility import asyncio_ensure_future @@ -1020,9 +1021,6 @@ def connection_lost(self, exc): super().connection_lost(exc) -try: +if sys.version_info[:2] >= (3, 6): # pragma: no cover from .py36.protocol import __aiter__ -except (SyntaxError, ImportError): # pragma: no cover - pass -else: WebSocketCommonProtocol.__aiter__ = __aiter__ diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 27a2a719b..a3e1e9244 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -1056,14 +1056,10 @@ def test_checking_lack_of_origin_succeeds(self): self.loop.run_until_complete(server.wait_closed()) -try: +if sys.version_info[:2] >= (3, 5): # pragma: no cover from .py35._test_client_server import AsyncAwaitTests # noqa from .py35._test_client_server import ContextManagerTests # noqa -except (SyntaxError, ImportError): # pragma: no cover - pass -try: +if sys.version_info[:2] >= (3, 6): # pragma: no cover from .py36._test_client_server import AsyncIteratorTests # noqa -except (SyntaxError, ImportError): # pragma: no cover - pass From 18235028464343e3584e009bc11709a83ca1a801 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 2 Jun 2018 20:50:43 +0200 Subject: [PATCH 0436/1539] Remove get_header. --- websockets/client.py | 30 ++++++++++----------- websockets/handshake.py | 46 ++++++++++++++++---------------- websockets/http.py | 2 +- websockets/server.py | 28 +++++++++---------- websockets/test_client_server.py | 6 ++--- websockets/test_handshake.py | 21 +++++++-------- websockets/test_http.py | 2 +- 7 files changed, 64 insertions(+), 71 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index a86b90f81..c17452313 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -216,20 +216,19 @@ def handshake(self, wsuri, origin=None, available_extensions=None, fails. """ - request_headers = [] - set_header = lambda k, v: request_headers.append((k, v)) - is_header_set = lambda k: k in dict(request_headers).keys() + request_headers = {} if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover - set_header('Host', wsuri.host) + request_headers['Host'] = wsuri.host else: - set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port)) + request_headers['Host'] = '{}:{}'.format(wsuri.host, wsuri.port) if wsuri.user_info: - set_header(*basic_auth_header(*wsuri.user_info)) + request_headers['Authorization'] = basic_auth_header( + *wsuri.user_info) if origin is not None: - set_header('Origin', origin) + request_headers['Origin'] = origin if available_extensions is not None: extensions_header = build_extension_list([ @@ -239,33 +238,32 @@ def handshake(self, wsuri, origin=None, available_extensions=None, ) for extension_factory in available_extensions ]) - set_header('Sec-WebSocket-Extensions', extensions_header) + request_headers['Sec-WebSocket-Extensions'] = extensions_header if available_subprotocols is not None: protocol_header = build_subprotocol_list(available_subprotocols) - set_header('Sec-WebSocket-Protocol', protocol_header) + request_headers['Sec-WebSocket-Protocol'] = protocol_header if extra_headers is not None: if isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: - set_header(name, value) + request_headers[name] = value - if not is_header_set('User-Agent'): - set_header('User-Agent', USER_AGENT) + if 'User-Agent' not in request_headers: + request_headers['User-Agent'] = USER_AGENT - key = build_request(set_header) + key = build_request(request_headers) yield from self.write_http_request( - wsuri.resource_name, request_headers) + wsuri.resource_name, list(request_headers.items())) status_code, response_headers = yield from self.read_http_response() - get_header = lambda k: response_headers.get(k, '') if status_code != 101: raise InvalidStatusCode(status_code) - check_response(get_header, key) + check_response(response_headers, key) self.extensions = self.process_extensions( response_headers, available_extensions) diff --git a/websockets/handshake.py b/websockets/handshake.py index fedc6c502..e1046b99f 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -51,7 +51,7 @@ GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -def build_request(set_header): +def build_request(headers): """ Build a handshake request to send to the server. @@ -60,14 +60,14 @@ def build_request(set_header): """ raw_key = bytes(random.getrandbits(8) for _ in range(16)) key = base64.b64encode(raw_key).decode() - set_header('Upgrade', 'websocket') - set_header('Connection', 'Upgrade') - set_header('Sec-WebSocket-Key', key) - set_header('Sec-WebSocket-Version', '13') + headers['Upgrade'] = 'websocket' + headers['Connection'] = 'Upgrade' + headers['Sec-WebSocket-Key'] = key + headers['Sec-WebSocket-Version'] = '13' return key -def check_request(get_header): +def check_request(headers): """ Check a handshake request received from the client. @@ -83,17 +83,17 @@ def check_request(get_header): responsibility of the caller. """ - connection = parse_connection(get_header('Connection')) + connection = parse_connection(headers.get('Connection', '')) if not any(value.lower() == 'upgrade' for value in connection): - raise InvalidUpgrade('Connection', get_header('Connection')) + raise InvalidUpgrade('Connection', headers.get('Connection', '')) - upgrade = parse_upgrade(get_header('Upgrade')) + upgrade = parse_upgrade(headers.get('Upgrade', '')) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): - raise InvalidUpgrade('Upgrade', get_header('Upgrade')) + raise InvalidUpgrade('Upgrade', headers.get('Upgrade', '')) - key = get_header('Sec-WebSocket-Key') + key = headers.get('Sec-WebSocket-Key', '') try: raw_key = base64.b64decode(key.encode(), validate=True) except binascii.Error: @@ -101,26 +101,26 @@ def check_request(get_header): if len(raw_key) != 16: raise InvalidHeaderValue('Sec-WebSocket-Key', key) - version = get_header('Sec-WebSocket-Version') + version = headers.get('Sec-WebSocket-Version', '') if version != '13': raise InvalidHeaderValue('Sec-WebSocket-Version', version) return key -def build_response(set_header, key): +def build_response(headers, key): """ Build a handshake response to send to the client. ``key`` comes from :func:`check_request`. """ - set_header('Upgrade', 'websocket') - set_header('Connection', 'Upgrade') - set_header('Sec-WebSocket-Accept', accept(key)) + headers['Upgrade'] = 'websocket' + headers['Connection'] = 'Upgrade' + headers['Sec-WebSocket-Accept'] = accept(key) -def check_response(get_header, key): +def check_response(headers, key): """ Check a handshake response received from the server. @@ -136,19 +136,19 @@ def check_response(get_header, key): the caller. """ - connection = parse_connection(get_header('Connection')) + connection = parse_connection(headers.get('Connection', '')) if not any(value.lower() == 'upgrade' for value in connection): - raise InvalidUpgrade('Connection', get_header('Connection')) + raise InvalidUpgrade('Connection', headers.get('Connection', '')) - upgrade = parse_upgrade(get_header('Upgrade')) + upgrade = parse_upgrade(headers.get('Upgrade', '')) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): - raise InvalidUpgrade('Upgrade', get_header('Upgrade')) + raise InvalidUpgrade('Upgrade', headers.get('Upgrade', '')) - if get_header('Sec-WebSocket-Accept') != accept(key): + if headers.get('Sec-WebSocket-Accept', '') != accept(key): raise InvalidHeaderValue( - 'Sec-WebSocket-Accept', get_header('Sec-WebSocket-Accept')) + 'Sec-WebSocket-Accept', headers.get('Sec-WebSocket-Accept', '')) def accept(key): diff --git a/websockets/http.py b/websockets/http.py index 25f32c34e..1689e6637 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -218,4 +218,4 @@ def basic_auth_header(username, password): assert ':' not in username user_pass = '{}:{}'.format(username, password) basic_credentials = base64.b64encode(user_pass.encode()).decode() - return ('Authorization', 'Basic ' + basic_credentials) + return 'Basic ' + basic_credentials diff --git a/websockets/server.py b/websockets/server.py index 46c80dc6d..a5e14a02a 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -262,7 +262,7 @@ def process_request(self, path, request_headers): """ - def process_origin(self, get_header, origins=None): + def process_origin(self, origin, origins=None): """ Handle the Origin HTTP request header. @@ -270,7 +270,6 @@ def process_origin(self, get_header, origins=None): acceptable. """ - origin = get_header('Origin') if origins is not None: if origin not in origins: raise InvalidOrigin(origin) @@ -441,11 +440,10 @@ def handshake(self, origins=None, available_extensions=None, if early_response is not None: raise AbortHandshake(*early_response) - get_header = lambda k: request_headers.get(k, '') + key = check_request(request_headers) - key = check_request(get_header) - - self.origin = self.process_origin(get_header, origins) + origin = request_headers.get('Origin', '') + self.origin = self.process_origin(origin, origins) extensions_header, self.extensions = self.process_extensions( request_headers, available_extensions) @@ -453,15 +451,13 @@ def handshake(self, origins=None, available_extensions=None, protocol_header = self.subprotocol = self.process_subprotocol( request_headers, available_subprotocols) - response_headers = [] - set_header = lambda k, v: response_headers.append((k, v)) - is_header_set = lambda k: k in dict(response_headers).keys() + response_headers = {} if extensions_header is not None: - set_header('Sec-WebSocket-Extensions', extensions_header) + response_headers['Sec-WebSocket-Extensions'] = extensions_header if self.subprotocol is not None: - set_header('Sec-WebSocket-Protocol', protocol_header) + response_headers['Sec-WebSocket-Protocol'] = protocol_header if extra_headers is not None: if callable(extra_headers): @@ -469,15 +465,15 @@ def handshake(self, origins=None, available_extensions=None, if isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: - set_header(name, value) + response_headers[name] = value - if not is_header_set('Server'): - set_header('Server', USER_AGENT) + if 'Server' not in response_headers: + response_headers['Server'] = USER_AGENT - build_response(set_header, key) + build_response(response_headers, key) yield from self.write_http_response( - SWITCHING_PROTOCOLS, response_headers) + SWITCHING_PROTOCOLS, list(response_headers.items())) self.connection_open() diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index a3e1e9244..321ecbec6 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -809,7 +809,7 @@ def test_client_receives_malformed_response(self, _read_response): @with_server() @unittest.mock.patch('websockets.client.build_request') def test_client_sends_invalid_handshake_request(self, _build_request): - def wrong_build_request(set_header): + def wrong_build_request(headers): return '42' _build_request.side_effect = wrong_build_request @@ -819,8 +819,8 @@ def wrong_build_request(set_header): @with_server() @unittest.mock.patch('websockets.server.build_response') def test_server_sends_invalid_handshake_response(self, _build_response): - def wrong_build_response(set_header, key): - return build_response(set_header, '42') + def wrong_build_response(headers, key): + return build_response(headers, '42') _build_response.side_effect = wrong_build_response with self.assertRaises(InvalidHandshake): diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index 62b4ffc0f..5083d1ee2 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -1,4 +1,3 @@ -import collections import contextlib import unittest @@ -17,12 +16,12 @@ def test_accept(self): def test_round_trip(self): request_headers = {} - request_key = build_request(request_headers.__setitem__) - response_key = check_request(request_headers.__getitem__) + request_key = build_request(request_headers) + response_key = check_request(request_headers) self.assertEqual(request_key, response_key) response_headers = {} - build_response(response_headers.__setitem__, response_key) - check_response(response_headers.__getitem__, request_key) + build_response(response_headers, response_key) + check_response(response_headers, request_key) @contextlib.contextmanager def assertInvalidRequestHeaders(self): @@ -32,11 +31,11 @@ def assertInvalidRequestHeaders(self): Assert that the transformation made them invalid. """ - headers = collections.defaultdict(lambda: '') - build_request(headers.__setitem__) + headers = {} + build_request(headers) yield headers with self.assertRaises(InvalidHandshake): - check_request(headers.__getitem__) + check_request(headers) def test_request_invalid_upgrade(self): with self.assertInvalidRequestHeaders() as headers: @@ -86,11 +85,11 @@ def assertInvalidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): Assert that the transformation made them invalid. """ - headers = collections.defaultdict(lambda: '') - build_response(headers.__setitem__, key) + headers = {} + build_response(headers, key) yield headers with self.assertRaises(InvalidHandshake): - check_response(headers.__getitem__, key) + check_response(headers, key) def test_response_invalid_upgrade(self): with self.assertInvalidResponseHeaders() as headers: diff --git a/websockets/test_http.py b/websockets/test_http.py index 38f6363da..7069e0eda 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -133,5 +133,5 @@ def test_basic_auth_header(self): # Test vector from RFC 7617. self.assertEqual( basic_auth_header("Aladdin", "open sesame"), - ('Authorization', 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='), + 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==', ) From b4e53f702c261d91c73f46712fd0cac213a31510 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 2 Jun 2018 21:38:47 +0200 Subject: [PATCH 0437/1539] Update documentation. --- docs/changelog.rst | 10 +++++++++- websockets/handshake.py | 13 +++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 7712e38f3..044c090d1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,19 @@ Changelog .. currentmodule:: websockets -5.1 +6.0 ... *In development* +.. warning:: + + **Version 6.0 changes public APIs in the** :mod:`~websockets.handshake` + **module. If you're calling these APIs, you must update your code. This + affects mostly libraries that use low-level APIs.** + +Also: + * Added compatibility with Python 3.7. 5.0 diff --git a/websockets/handshake.py b/websockets/handshake.py index e1046b99f..00fdd18aa 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -4,15 +4,12 @@ .. _section 4 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 -It provides functions to implement the handshake with any existing HTTP -library. You must pass to these functions: +Functions defined in this module manipulate HTTP headers. The ``headers`` +argument must implement ``get`` and ``__setitem__`` and ``get`` — a small +subset of the :class:`~collections.abc.MutableMapping` abstract base class. -- A ``set_header`` function accepting a header name and a header value, -- A ``get_header`` function accepting a header name and returning the header - value. - -The inputs and outputs of ``get_header`` and ``set_header`` are :class:`str` -objects containing only ASCII characters. +Headers names and values are :class:`str` objects containing only ASCII +characters. Some checks cannot be performed because they depend too much on the context; instead, they're documented below. From 2ad9b4822cf3c956bcae1b069a93e0a19ef1b62a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 2 Jun 2018 21:46:23 +0200 Subject: [PATCH 0438/1539] Move basic auth to the headers module. --- websockets/client.py | 8 ++++---- websockets/headers.py | 13 +++++++++++++ websockets/http.py | 13 ------------- websockets/test_headers.py | 8 ++++++++ websockets/test_http.py | 9 +-------- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index c17452313..33a9650c9 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -13,10 +13,10 @@ from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response from .headers import ( - build_extension_list, build_subprotocol_list, parse_extension_list, - parse_subprotocol_list + build_basic_auth, build_extension_list, build_subprotocol_list, + parse_extension_list, parse_subprotocol_list ) -from .http import USER_AGENT, basic_auth_header, build_headers, read_response +from .http import USER_AGENT, build_headers, read_response from .protocol import WebSocketCommonProtocol from .uri import parse_uri @@ -224,7 +224,7 @@ def handshake(self, wsuri, origin=None, available_extensions=None, request_headers['Host'] = '{}:{}'.format(wsuri.host, wsuri.port) if wsuri.user_info: - request_headers['Authorization'] = basic_auth_header( + request_headers['Authorization'] = build_basic_auth( *wsuri.user_info) if origin is not None: diff --git a/websockets/headers.py b/websockets/headers.py index a88c975a9..6da5ec7f0 100644 --- a/websockets/headers.py +++ b/websockets/headers.py @@ -7,6 +7,7 @@ """ +import base64 import re from .exceptions import InvalidHeaderFormat @@ -324,3 +325,15 @@ def build_subprotocol_list(protocols): """ return ', '.join(protocols) + + +def build_basic_auth(username, password): + """ + Build an Authorization header for HTTP Basic Auth. + + """ + # https://tools.ietf.org/html/rfc7617#section-2 + assert ':' not in username + user_pass = '{}:{}'.format(username, password) + basic_credentials = base64.b64encode(user_pass.encode()).decode() + return 'Basic ' + basic_credentials diff --git a/websockets/http.py b/websockets/http.py index 1689e6637..5a91a4b7c 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -8,7 +8,6 @@ """ import asyncio -import base64 import http.client import re import sys @@ -207,15 +206,3 @@ def build_headers(raw_headers): headers = http.client.HTTPMessage() headers._headers = raw_headers # HACK return headers - - -def basic_auth_header(username, password): - """ - Build an Authorization header for HTTP Basic Auth. - - """ - # https://tools.ietf.org/html/rfc7617#section-2 - assert ':' not in username - user_pass = '{}:{}'.format(username, password) - basic_credentials = base64.b64encode(user_pass.encode()).decode() - return 'Basic ' + basic_credentials diff --git a/websockets/test_headers.py b/websockets/test_headers.py index ef5f28eda..10c2a7fd8 100644 --- a/websockets/test_headers.py +++ b/websockets/test_headers.py @@ -2,6 +2,7 @@ from .exceptions import InvalidHeaderFormat from .headers import * +from .headers import build_basic_auth class HeadersTests(unittest.TestCase): @@ -163,3 +164,10 @@ def test_parse_subprotocol_list_invalid_header(self): with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_subprotocol_list(header) + + def test_build_basic_auth(self): + # Test vector from RFC 7617. + self.assertEqual( + build_basic_auth("Aladdin", "open sesame"), + 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==', + ) diff --git a/websockets/test_http.py b/websockets/test_http.py index 7069e0eda..a6d61299b 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -2,7 +2,7 @@ import unittest from .http import * -from .http import basic_auth_header, build_headers, read_headers +from .http import build_headers, read_headers class HTTPAsyncTests(unittest.TestCase): @@ -128,10 +128,3 @@ def test_build_headers_multi_value(self): # Ordering is deterministic when getting all values. self.assertEqual(headers.get_all('X-Foo'), ['Bar', 'Baz']) - - def test_basic_auth_header(self): - # Test vector from RFC 7617. - self.assertEqual( - basic_auth_header("Aladdin", "open sesame"), - 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==', - ) From 43809a0688591dd7c24daf16b6057d3be7cf1ca8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 2 Jun 2018 23:28:43 +0200 Subject: [PATCH 0439/1539] Add a class for managing HTTP headers. --- websockets/http.py | 136 +++++++++++++++++++++++++++++++++++++++- websockets/test_http.py | 126 +++++++++++++++++++++++++++++++++++++ 2 files changed, 261 insertions(+), 1 deletion(-) diff --git a/websockets/http.py b/websockets/http.py index 5a91a4b7c..f4d78a51f 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -8,6 +8,7 @@ """ import asyncio +import collections.abc import http.client import re import sys @@ -15,7 +16,11 @@ from .version import version as websockets_version -__all__ = ['read_request', 'read_response', 'USER_AGENT'] +__all__ = [ + 'Headers', 'MultipleValuesError', + 'read_request', 'read_response', + 'USER_AGENT', +] MAX_HEADERS = 256 MAX_LINE = 4096 @@ -196,6 +201,135 @@ def read_line(stream): return line +class MultipleValuesError(LookupError): + """ + Exception raised when :class:`Headers` has more than one value for a key. + + """ + + def __str__(self): + # Implement the same logic as KeyError_str in Objects/exceptions.c. + if len(self.args) == 1: + return repr(self.args[0]) + return super().__str__() + + +class Headers(collections.abc.MutableMapping): + """ + Data structure for working with HTTP headers efficiently. + + A :class:`list` of ``(name, values)`` is inefficient for lookups. + + A :class:`dict` doesn't suffice because header names are case-insensitive + and multiple occurrences of headers with the same name are possible. + + :class:`Headers` stores HTTP headers in a hybrid data structure to provide + efficient insertions and lookups while preserving the original data. + + In order to account for multiple values with minimal hassle, + :class:`Headers` follows this logic: + + - When getting a header with ``headers[name]``: + - if there's no value, :exc:`KeyError` is raised; + - if there's exactly one value, it's returned; + - if there's more than one value, :exc:`MultipleValuesError` is raised. + + - When setting a header with ``headers[name] = value``, the value is + appended to the list of values for that header. + + - When deleting a header with ``del headers[name]``, all values for that + header are removed (this is slow). + + Other methods for manipulating headers are consistent with this logic. + + As long as no header occurs multiple times, :class:`Headers` behaves like + :class:`dict`, except keys are lower-cased to provide case-insensitivity. + + :meth:`get_all()` returns a list of all values for a header and + :meth:`raw_items()` returns an iterator of ``(name, values)`` pairs, + similar to :meth:`http.client.HTTPMessage`. + + """ + + __slots__ = ['_dict', '_list'] + + def __init__(self, *args, **kwargs): + self._dict = {} + self._list = [] + # MutableMapping.update calls __setitem__ for each (name, value) pair. + self.update(*args, **kwargs) + + def __str__(self): + return ''.join( + '{}: {}\r\n'.format(key, value) + for key, value in self._list + ) + '\r\n' + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, repr(self._list)) + + def copy(self): + copy = self.__class__() + copy._dict = self._dict.copy() + copy._list = self._list.copy() + return copy + + # Collection methods + + def __contains__(self, key): + return key.lower() in self._dict + + def __iter__(self): + return iter(self._dict) + + def __len__(self): + return len(self._dict) + + # MutableMapping methods + + def __getitem__(self, key): + value = self._dict[key.lower()] + if len(value) == 1: + return value[0] + else: + raise MultipleValuesError(key) + + def __setitem__(self, key, value): + self._dict.setdefault(key.lower(), []).append(value) + self._list.append((key, value)) + + def __delitem__(self, key): + key_lower = key.lower() + self._dict.__delitem__(key_lower) + # This is inefficent. Fortunately deleting HTTP headers is uncommon. + self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] + + def __eq__(self, other): + if not isinstance(other, Headers): + return NotImplemented + return self._list == other._list + + def clear(self): + self._dict = {} + self._list = [] + + # Methods for handling multiple values + + def get_all(self, key): + """ + Return the (possibly empty) list of all values for a header. + + """ + return self._dict.get(key.lower(), []) + + def raw_items(self): + """ + Return an iterator of (header name, header value). + + """ + return iter(self._list) + + def build_headers(raw_headers): """ Build a date structure for HTTP headers from a list of name - value pairs. diff --git a/websockets/test_http.py b/websockets/test_http.py index a6d61299b..51df901f1 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -128,3 +128,129 @@ def test_build_headers_multi_value(self): # Ordering is deterministic when getting all values. self.assertEqual(headers.get_all('X-Foo'), ['Bar', 'Baz']) + + +class HeadersTests(unittest.TestCase): + + def setUp(self): + self.headers = Headers([ + ('Connection', 'Upgrade'), + ('Server', USER_AGENT), + ]) + + def test_str(self): + self.assertEqual( + str(self.headers), + "Connection: Upgrade\r\nServer: {}\r\n\r\n".format(USER_AGENT), + ) + + def test_repr(self): + self.assertEqual( + repr(self.headers), + "Headers([('Connection', 'Upgrade'), " + "('Server', '{}')])".format(USER_AGENT), + ) + + def test_multiple_values_error_str(self): + self.assertEqual( + str(MultipleValuesError('Connection')), + "'Connection'", + ) + self.assertEqual( + str(MultipleValuesError()), + "", + ) + + def test_contains(self): + self.assertIn('Server', self.headers) + + def test_contains_case_insensitive(self): + self.assertIn('server', self.headers) + + def test_contains_not_found(self): + self.assertNotIn('Date', self.headers) + + def test_iter(self): + self.assertEqual(set(iter(self.headers)), {'connection', 'server'}) + + def test_len(self): + self.assertEqual(len(self.headers), 2) + + def test_getitem(self): + self.assertEqual(self.headers['Server'], USER_AGENT) + + def test_getitem_case_insensitive(self): + self.assertEqual(self.headers['server'], USER_AGENT) + + def test_getitem_key_error(self): + with self.assertRaises(KeyError): + self.headers['Upgrade'] + + def test_getitem_multiple_values_error(self): + self.headers['Server'] = '2' + with self.assertRaises(MultipleValuesError): + self.headers['Server'] + + def test_setitem(self): + self.headers['Upgrade'] = 'websocket' + self.assertEqual(self.headers['Upgrade'], 'websocket') + + def test_setitem_case_insensitive(self): + self.headers['upgrade'] = 'websocket' + self.assertEqual(self.headers['Upgrade'], 'websocket') + + def test_setitem_multiple_values(self): + self.headers['Connection'] = 'close' + with self.assertRaises(MultipleValuesError): + self.headers['Connection'] + + def test_delitem(self): + del self.headers['Connection'] + with self.assertRaises(KeyError): + self.headers['Connection'] + + def test_delitem_case_insensitive(self): + del self.headers['connection'] + with self.assertRaises(KeyError): + self.headers['Connection'] + + def test_delitem_multiple_values(self): + self.headers['Connection'] = 'close' + del self.headers['Connection'] + with self.assertRaises(KeyError): + self.headers['Connection'] + + def test_eq(self): + other_headers = self.headers.copy() + self.assertEqual(self.headers, other_headers) + + def test_eq_not_equal(self): + self.assertNotEqual(self.headers, []) + + def test_clear(self): + self.headers.clear() + self.assertFalse(self.headers) + self.assertEqual(self.headers, Headers()) + + def test_get_all(self): + self.assertEqual(self.headers.get_all('Connection'), ['Upgrade']) + + def test_get_all_case_insensitive(self): + self.assertEqual(self.headers.get_all('connection'), ['Upgrade']) + + def test_get_all_no_values(self): + self.assertEqual(self.headers.get_all('Upgrade'), []) + + def test_get_all_multiple_values(self): + self.headers['Connection'] = 'close' + self.assertEqual( + self.headers.get_all('Connection'), ['Upgrade', 'close']) + + def test_raw_items(self): + self.assertEqual( + list(self.headers.raw_items()), + [ + ('Connection', 'Upgrade'), + ('Server', USER_AGENT), + ], + ) From 2037226b0f1dc2c8e5ff255ec57a1fd18b0d1da9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Jun 2018 10:12:07 +0200 Subject: [PATCH 0440/1539] Use Headers throughout the library. Document changes. Fix #210, #256. --- docs/changelog.rst | 30 +++++++++-- websockets/client.py | 36 +++++++------- websockets/http.py | 33 +++++-------- websockets/protocol.py | 13 ++--- websockets/server.py | 85 +++++++++++++++++--------------- websockets/test_client_server.py | 64 +++++++++++++----------- websockets/test_exceptions.py | 3 +- websockets/test_http.py | 34 ++----------- 8 files changed, 149 insertions(+), 149 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 044c090d1..5bc47c99d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,9 +10,33 @@ Changelog .. warning:: - **Version 6.0 changes public APIs in the** :mod:`~websockets.handshake` - **module. If you're calling these APIs, you must update your code. This - affects mostly libraries that use low-level APIs.** + **Version 6.0 introduces the** :class:`~http.Headers` **class for managing + HTTP headers and changes several public APIs:** + + * :meth:`~server.WebSocketServerProtocol.process_request` now receives a + :class:`~http.Headers` instead of a :class:`~http.client.HTTPMessage` in + the ``request_headers`` argument. + + * The :attr:`~protocol.WebSocketCommonProtocol.request_headers` and + :attr:`~protocol.WebSocketCommonProtocol.response_headers` attributes of + :class:`~protocol.WebSocketCommonProtocol` are :class:`~http.Headers` + instead of :class:`~http.client.HTTPMessage`. + + * The :attr:`~protocol.WebSocketCommonProtocol.raw_request_headers` and + :attr:`~protocol.WebSocketCommonProtocol.raw_response_headers` + attributes of :class:`~protocol.WebSocketCommonProtocol` are removed. + Use :meth:`~http.Headers.raw_items` instead. + + * Functions defined in the :mod:`~handshake` module now receive + :class:`~http.Headers` in argument instead of ``get_header`` or + ``set_header`` fucntions. This affects libraries that rely on + low-level APIs. + + * Functions defined in the :mod:`~http` module now return HTTP headers as + :class:`~http.Headers` instead of lists of ``(name, value)`` pairs. + + Note that :class:`~http.Headers` and :class:`~http.client.HTTPMessage` + provide similar APIs. Also: diff --git a/websockets/client.py b/websockets/client.py index 33a9650c9..e4733d7e2 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -16,7 +16,7 @@ build_basic_auth, build_extension_list, build_subprotocol_list, parse_extension_list, parse_subprotocol_list ) -from .http import USER_AGENT, build_headers, read_response +from .http import USER_AGENT, Headers, read_response from .protocol import WebSocketCommonProtocol from .uri import parse_uri @@ -51,17 +51,14 @@ def write_http_request(self, path, headers): """ self.path = path - self.request_headers = build_headers(headers) - self.raw_request_headers = headers + self.request_headers = headers # Since the path and headers only contain ASCII characters, # we can keep this simple. - request = ['GET {path} HTTP/1.1'.format(path=path)] - request.extend('{}: {}'.format(k, v) for k, v in headers) - request.append('\r\n') - request = '\r\n'.join(request).encode() + request = 'GET {path} HTTP/1.1\r\n'.format(path=path) + request += str(headers) - self.writer.write(request) + self.writer.write(request.encode()) @asyncio.coroutine def read_http_response(self): @@ -81,8 +78,7 @@ def read_http_response(self): except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc - self.response_headers = build_headers(headers) - self.raw_response_headers = headers + self.response_headers = headers return status_code, self.response_headers @@ -118,7 +114,7 @@ def process_extensions(headers, available_extensions): header_values = headers.get_all('Sec-WebSocket-Extensions') - if header_values is not None: + if header_values: if available_extensions is None: raise InvalidHandshake("No extensions supported") @@ -172,7 +168,7 @@ def process_subprotocol(headers, available_subprotocols): header_values = headers.get_all('Sec-WebSocket-Protocol') - if header_values is not None: + if header_values: if available_subprotocols is None: raise InvalidHandshake("No subprotocols supported") @@ -210,13 +206,15 @@ def handshake(self, wsuri, origin=None, available_extensions=None, subprotocols in order of decreasing preference. If provided, ``extra_headers`` sets additional HTTP request headers. - It must be a mapping or an iterable of (name, value) pairs. + It must be a :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` + pairs. Raise :exc:`~websockets.exceptions.InvalidHandshake` if the handshake fails. """ - request_headers = {} + request_headers = Headers() if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover request_headers['Host'] = wsuri.host @@ -245,7 +243,9 @@ def handshake(self, wsuri, origin=None, available_extensions=None, request_headers['Sec-WebSocket-Protocol'] = protocol_header if extra_headers is not None: - if isinstance(extra_headers, collections.abc.Mapping): + if isinstance(extra_headers, Headers): + extra_headers = extra_headers.raw_items() + elif isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: request_headers[name] = value @@ -256,7 +256,7 @@ def handshake(self, wsuri, origin=None, available_extensions=None, key = build_request(request_headers) yield from self.write_http_request( - wsuri.resource_name, list(request_headers.items())) + wsuri.resource_name, request_headers) status_code, response_headers = yield from self.read_http_response() @@ -312,7 +312,9 @@ class Connect: * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference * ``extra_headers`` sets additional HTTP request headers – it can be a - mapping or an iterable of (name, value) pairs + :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` + pairs * ``compression`` is a shortcut to configure compression extensions; by default it enables the "permessage-deflate" extension; set it to ``None`` to disable compression diff --git a/websockets/http.py b/websockets/http.py index f4d78a51f..5ec339055 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -9,7 +9,6 @@ import asyncio import collections.abc -import http.client import re import sys @@ -59,7 +58,7 @@ def read_request(stream): ``stream`` is an :class:`~asyncio.StreamReader`. Return ``(path, headers)`` where ``path`` is a :class:`str` and - ``headers`` is a list of ``(name, value)`` tuples. + ``headers`` is a :class:`Headers` instance. ``path`` isn't URL-decoded or validated in any way. @@ -104,7 +103,7 @@ def read_response(stream): ``stream`` is an :class:`~asyncio.StreamReader`. Return ``(status_code, headers)`` where ``status_code`` is a :class:`int` - and ``headers`` is a list of ``(name, value)`` tuples. + and ``headers`` is a :class:`Headers` instance. Non-ASCII characters are represented with surrogate escapes. @@ -147,8 +146,7 @@ def read_headers(stream): ``stream`` is an :class:`~asyncio.StreamReader`. - Return ``(start_line, headers)`` where ``start_line`` is :class:`bytes` - and ``headers`` is a list of ``(name, value)`` tuples. + Return a :class:`Headers` instance Non-ASCII characters are represented with surrogate escapes. @@ -157,7 +155,7 @@ def read_headers(stream): # We don't attempt to support obsolete line folding. - headers = [] + headers = Headers() for _ in range(MAX_HEADERS + 1): line = yield from read_line(stream) if line == b'\r\n': @@ -171,10 +169,9 @@ def read_headers(stream): if not _value_re.fullmatch(value): raise ValueError("Invalid HTTP header value: %r" % value) - headers.append(( - name.decode('ascii'), # guaranteed to be ASCII at this point - value.decode('ascii', 'surrogateescape'), - )) + name = name.decode('ascii') # guaranteed to be ASCII at this point + value = value.decode('ascii', 'surrogateescape') + headers[name] = value else: raise ValueError("Too many HTTP headers") @@ -310,6 +307,10 @@ def __eq__(self, other): return self._list == other._list def clear(self): + """ + Remove all headers. + + """ self._dict = {} self._list = [] @@ -328,15 +329,3 @@ def raw_items(self): """ return iter(self._list) - - -def build_headers(raw_headers): - """ - Build a date structure for HTTP headers from a list of name - value pairs. - - See also https://github.com/aaugustin/websockets/issues/210. - - """ - headers = http.client.HTTPMessage() - headers._headers = raw_headers # HACK - return headers diff --git a/websockets/protocol.py b/websockets/protocol.py index 66939aa83..adadf6328 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -116,13 +116,12 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): current implementation of ``FlowControlMixin``). As soon as the HTTP request and response in the opening handshake are - processed, the request path is available in the :attr:`path` attribute, - and the request and response HTTP headers are available: + processed: - * as a :class:`~http.client.HTTPMessage` in the :attr:`request_headers` - and :attr:`response_headers` attributes - * as an iterable of (name, value) pairs in the :attr:`raw_request_headers` - and :attr:`raw_response_headers` attributes + * the request path is available in the :attr:`path` attribute; + * the request and response HTTP headers are available in the + :attr:`request_headers` and :attr:`response_headers` attributes, + which are :class:`~websockets.http.Headers` instances. These attributes must be treated as immutable. @@ -182,9 +181,7 @@ def __init__(self, *, # HTTP protocol parameters. self.path = None self.request_headers = None - self.raw_request_headers = None self.response_headers = None - self.raw_response_headers = None # WebSocket protocol parameters. self.extensions = [] diff --git a/websockets/server.py b/websockets/server.py index a5e14a02a..1ac131a6a 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -21,7 +21,7 @@ from .headers import ( build_extension_list, parse_extension_list, parse_subprotocol_list ) -from .http import USER_AGENT, build_headers, read_request +from .http import USER_AGENT, Headers, read_request from .protocol import WebSocketCommonProtocol @@ -93,47 +93,49 @@ def handler(self): raise except Exception as exc: if self._is_server_shutting_down(exc): - early_response = ( + status, headers, body = ( SERVICE_UNAVAILABLE, [], b"Server is shutting down.\n", ) elif isinstance(exc, AbortHandshake): - early_response = ( + status, headers, body = ( exc.status, exc.headers, exc.body, ) elif isinstance(exc, InvalidOrigin): logger.debug("Invalid origin", exc_info=True) - early_response = ( + status, headers, body = ( FORBIDDEN, [], (str(exc) + "\n").encode(), ) elif isinstance(exc, InvalidUpgrade): logger.debug("Invalid upgrade", exc_info=True) - early_response = ( + status, headers, body = ( UPGRADE_REQUIRED, [('Upgrade', 'websocket')], (str(exc) + "\n").encode(), ) elif isinstance(exc, InvalidHandshake): logger.debug("Invalid handshake", exc_info=True) - early_response = ( + status, headers, body = ( BAD_REQUEST, [], (str(exc) + "\n").encode(), ) else: logger.warning("Error in opening handshake", exc_info=True) - early_response = ( + status, headers, body = ( INTERNAL_SERVER_ERROR, [], b"See server log for more information.\n", ) - yield from self.write_http_response(*early_response) + if not isinstance(headers, Headers): + headers = Headers(headers) + yield from self.write_http_response(status, headers, body) yield from self.fail_connection() return @@ -204,8 +206,7 @@ def read_http_request(self): raise InvalidMessage("Malformed HTTP message") from exc self.path = path - self.request_headers = build_headers(headers) - self.raw_request_headers = headers + self.request_headers = headers return path, self.request_headers @@ -217,19 +218,15 @@ def write_http_response(self, status, headers, body=None): This coroutine is also able to write a response body. """ - self.response_headers = build_headers(headers) - self.raw_response_headers = headers + self.response_headers = headers # Since the status line and headers only contain ASCII characters, # we can keep this simple. - response = [ - 'HTTP/1.1 {value} {phrase}'.format( - value=status.value, phrase=status.phrase)] - response.extend('{}: {}'.format(k, v) for k, v in headers) - response.append('\r\n') - response = '\r\n'.join(response).encode() + response = 'HTTP/1.1 {status.value} {status.phrase}\r\n'.format( + status=status) + response += str(headers) - self.writer.write(response) + self.writer.write(response.encode()) if body is not None: self.writer.write(body) @@ -239,20 +236,23 @@ def process_request(self, path, request_headers): """ Intercept the HTTP request and return an HTTP response if needed. - ``request_headers`` are a :class:`~http.client.HTTPMessage`. + ``request_headers`` is a :class:`~websockets.http.Headers` instance. If this coroutine returns ``None``, the WebSocket handshake continues. - If it returns a status code, headers and a optionally a response body, - that HTTP response is sent and the connection is closed. - - The HTTP status must be a :class:`~http.HTTPStatus`. HTTP headers must - be an iterable of ``(name, value)`` pairs. If provided, the HTTP - response body must be :class:`bytes`. + If it returns a status code, headers and a response body, that HTTP + response is sent and the connection is closed. + The HTTP status must be a :class:`~http.HTTPStatus`. (:class:`~http.HTTPStatus` was added in Python 3.5. Use a compatible object on earlier versions. Look at ``SWITCHING_PROTOCOLS`` in ``websockets.compatibility`` for an example.) + HTTP headers must be a :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` + pairs. + + The HTTP response body must be :class:`bytes`. It may be empty. + This method may be overridden to check the request headers and set a different status, for example to authenticate the request and return ``HTTPStatus.UNAUTHORIZED`` or ``HTTPStatus.FORBIDDEN``. @@ -262,7 +262,8 @@ def process_request(self, path, request_headers): """ - def process_origin(self, origin, origins=None): + @staticmethod + def process_origin(headers, origins=None): """ Handle the Origin HTTP request header. @@ -270,6 +271,7 @@ def process_origin(self, origin, origins=None): acceptable. """ + origin = headers.get('Origin', '') if origins is not None: if origin not in origins: raise InvalidOrigin(origin) @@ -314,7 +316,7 @@ def process_extensions(headers, available_extensions): header_values = headers.get_all('Sec-WebSocket-Extensions') - if header_values is not None and available_extensions is not None: + if header_values and available_extensions: parsed_header_values = sum([ parse_extension_list(header_value) @@ -368,7 +370,7 @@ def process_subprotocol(self, headers, available_subprotocols): header_values = headers.get_all('Sec-WebSocket-Protocol') - if header_values is not None and available_subprotocols is not None: + if header_values and available_subprotocols: parsed_header_values = sum([ parse_subprotocol_list(header_value) @@ -422,8 +424,10 @@ def handshake(self, origins=None, available_extensions=None, subprotocols in order of decreasing preference. If provided, ``extra_headers`` sets additional HTTP response headers. - It can be a mapping or an iterable of (name, value) pairs. It can also - be a callable taking the request path and headers in arguments. + It can be a :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` + pairs, or a callable taking the request path and headers in arguments + and returning one of the above. Raise :exc:`~websockets.exceptions.InvalidHandshake` if the handshake fails. @@ -442,8 +446,7 @@ def handshake(self, origins=None, available_extensions=None, key = check_request(request_headers) - origin = request_headers.get('Origin', '') - self.origin = self.process_origin(origin, origins) + self.origin = self.process_origin(request_headers, origins) extensions_header, self.extensions = self.process_extensions( request_headers, available_extensions) @@ -451,7 +454,7 @@ def handshake(self, origins=None, available_extensions=None, protocol_header = self.subprotocol = self.process_subprotocol( request_headers, available_subprotocols) - response_headers = {} + response_headers = Headers() if extensions_header is not None: response_headers['Sec-WebSocket-Extensions'] = extensions_header @@ -461,8 +464,10 @@ def handshake(self, origins=None, available_extensions=None, if extra_headers is not None: if callable(extra_headers): - extra_headers = extra_headers(path, self.raw_request_headers) - if isinstance(extra_headers, collections.abc.Mapping): + extra_headers = extra_headers(path, self.request_headers) + if isinstance(extra_headers, Headers): + extra_headers = extra_headers.raw_items() + elif isinstance(extra_headers, collections.abc.Mapping): extra_headers = extra_headers.items() for name, value in extra_headers: response_headers[name] = value @@ -473,7 +478,7 @@ def handshake(self, origins=None, available_extensions=None, build_response(response_headers, key) yield from self.write_http_response( - SWITCHING_PROTOCOLS, list(response_headers.items())) + SWITCHING_PROTOCOLS, response_headers) self.connection_open() @@ -645,8 +650,10 @@ class Serve: * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference * ``extra_headers`` sets additional HTTP response headers — it can be a - mapping, an iterable of (name, value) pairs, or a callable taking the - request path and headers in arguments. + :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` + pairs, or a callable taking the request path and headers in arguments + and returning one of the above. * ``compression`` is a shortcut to configure compression extensions; by default it enables the "permessage-deflate" extension; set it to ``None`` to disable compression diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 321ecbec6..5c1e4a6f9 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -23,7 +23,7 @@ ServerPerMessageDeflateFactory ) from .handshake import build_response -from .http import USER_AGENT, read_response +from .http import USER_AGENT, Headers, read_response from .server import * from .test_protocol import MS @@ -48,11 +48,8 @@ def handler(ws, path): elif path == '/path': yield from ws.send(str(ws.path)) elif path == '/headers': - yield from ws.send(str(ws.request_headers)) - yield from ws.send(str(ws.response_headers)) - elif path == '/raw_headers': - yield from ws.send(repr(ws.raw_request_headers)) - yield from ws.send(repr(ws.raw_response_headers)) + yield from ws.send(repr(ws.request_headers)) + yield from ws.send(repr(ws.response_headers)) elif path == '/extensions': yield from ws.send(repr(ws.extensions)) elif path == '/subprotocol': @@ -149,14 +146,16 @@ class UnauthorizedServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): - return UNAUTHORIZED, [] + # Use [...] here rather than Headers(...) to ensure that both work. + return UNAUTHORIZED, [('X-Access', 'denied')] class ForbiddenServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): - return FORBIDDEN, [] + # Use Headers(...) here rather than [...] to ensure that both work. + return FORBIDDEN, Headers({'X-Access': 'Denied'}) class HealthCheckServerProtocol(WebSocketServerProtocol): @@ -400,73 +399,82 @@ def test_protocol_headers(self): self.assertEqual(client_resp['Server'], USER_AGENT) server_req = self.loop.run_until_complete(self.client.recv()) server_resp = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_req, str(client_req)) - self.assertEqual(server_resp, str(client_resp)) - - @with_server() - @with_client('/raw_headers') - def test_protocol_raw_headers(self): - client_req = self.client.raw_request_headers - client_resp = self.client.raw_response_headers - self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT) - self.assertEqual(dict(client_resp)['Server'], USER_AGENT) - server_req = self.loop.run_until_complete(self.client.recv()) - server_resp = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_req, repr(client_req)) self.assertEqual(server_resp, repr(client_resp)) @with_server() - @with_client('/raw_headers', extra_headers={'X-Spam': 'Eggs'}) + @with_client('/headers', extra_headers=Headers({'X-Spam': 'Eggs'})) + def test_protocol_custom_request_headers(self): + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", req_headers) + + @with_server() + @with_client('/headers', extra_headers={'X-Spam': 'Eggs'}) def test_protocol_custom_request_headers_dict(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client('/raw_headers', extra_headers=[('X-Spam', 'Eggs')]) + @with_client('/headers', extra_headers=[('X-Spam', 'Eggs')]) def test_protocol_custom_request_headers_list(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client('/raw_headers', extra_headers=[('User-Agent', 'Eggs')]) + @with_client('/headers', extra_headers=[('User-Agent', 'Eggs')]) def test_protocol_custom_request_user_agent(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) self.assertIn("('User-Agent', 'Eggs')", req_headers) + @with_server(extra_headers=lambda p, r: Headers({'X-Spam': 'Eggs'})) + @with_client('/headers') + def test_protocol_custom_response_headers_callable(self): + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + @with_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) - @with_client('/raw_headers') + @with_client('/headers') def test_protocol_custom_response_headers_callable_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]) - @with_client('/raw_headers') + @with_client('/headers') def test_protocol_custom_response_headers_callable_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) + @with_server(extra_headers=Headers({'X-Spam': 'Eggs'})) + @with_client('/headers') + def test_protocol_custom_response_headers(self): + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertIn("('X-Spam', 'Eggs')", resp_headers) + @with_server(extra_headers={'X-Spam': 'Eggs'}) - @with_client('/raw_headers') + @with_client('/headers') def test_protocol_custom_response_headers_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=[('X-Spam', 'Eggs')]) - @with_client('/raw_headers') + @with_client('/headers') def test_protocol_custom_response_headers_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @with_server(extra_headers=[('Server', 'Eggs')]) - @with_client('/raw_headers') + @with_client('/headers') def test_protocol_custom_response_user_agent(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) diff --git a/websockets/test_exceptions.py b/websockets/test_exceptions.py index da87bed5d..0985e5766 100644 --- a/websockets/test_exceptions.py +++ b/websockets/test_exceptions.py @@ -1,6 +1,7 @@ import unittest from .exceptions import * +from .http import Headers class ExceptionsTests(unittest.TestCase): @@ -12,7 +13,7 @@ def test_str(self): "Invalid request", ), ( - AbortHandshake(200, [], b'OK\n'), + AbortHandshake(200, Headers(), b'OK\n'), "HTTP 200, 0 headers, 3 bytes", ), ( diff --git a/websockets/test_http.py b/websockets/test_http.py index 51df901f1..01ae6de71 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -2,7 +2,7 @@ import unittest from .http import * -from .http import build_headers, read_headers +from .http import read_headers class HTTPAsyncTests(unittest.TestCase): @@ -33,7 +33,7 @@ def test_read_request(self): path, headers = self.loop.run_until_complete( read_request(self.stream)) self.assertEqual(path, '/chat') - self.assertEqual(dict(headers)['Upgrade'], 'websocket') + self.assertEqual(headers['Upgrade'], 'websocket') def test_read_response(self): # Example from the protocol overview in RFC 6455 @@ -48,7 +48,7 @@ def test_read_response(self): status_code, headers = self.loop.run_until_complete( read_response(self.stream)) self.assertEqual(status_code, 101) - self.assertEqual(dict(headers)['Upgrade'], 'websocket') + self.assertEqual(headers['Upgrade'], 'websocket') def test_request_method(self): self.stream.feed_data(b'OPTIONS * HTTP/1.1\r\n\r\n') @@ -102,34 +102,6 @@ def test_line_ending(self): self.loop.run_until_complete(read_headers(self.stream)) -class HTTPSyncTests(unittest.TestCase): - - def test_build_headers(self): - headers = build_headers([ - ('X-Foo', 'Bar'), - ('X-Baz', 'Quux Quux'), - ]) - - self.assertEqual(headers['X-Foo'], 'Bar') - self.assertEqual(headers['X-Bar'], None) - - self.assertEqual(headers.get('X-Bar', ''), '') - self.assertEqual(headers.get('X-Baz', ''), 'Quux Quux') - - def test_build_headers_multi_value(self): - headers = build_headers([ - ('X-Foo', 'Bar'), - ('X-Foo', 'Baz'), - ]) - - # Getting a single value is non-deterministic. - self.assertIn(headers['X-Foo'], ['Bar', 'Baz']) - self.assertIn(headers.get('X-Foo'), ['Bar', 'Baz']) - - # Ordering is deterministic when getting all values. - self.assertEqual(headers.get_all('X-Foo'), ['Bar', 'Baz']) - - class HeadersTests(unittest.TestCase): def setUp(self): From 25630f3c26b57ec035c8e2f151abafb4104b8edc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Jun 2018 16:14:36 +0200 Subject: [PATCH 0441/1539] Send mandatory HTTP headers, including Date. This improves HTTP compliance. Fix #386. --- websockets/client.py | 7 +++---- websockets/exceptions.py | 4 ++-- websockets/server.py | 16 ++++++++++++---- websockets/test_client_server.py | 12 ++++++------ 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/websockets/client.py b/websockets/client.py index e4733d7e2..b8f33b2cb 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -228,6 +228,8 @@ def handshake(self, wsuri, origin=None, available_extensions=None, if origin is not None: request_headers['Origin'] = origin + key = build_request(request_headers) + if available_extensions is not None: extensions_header = build_extension_list([ ( @@ -250,10 +252,7 @@ def handshake(self, wsuri, origin=None, available_extensions=None, for name, value in extra_headers: request_headers[name] = value - if 'User-Agent' not in request_headers: - request_headers['User-Agent'] = USER_AGENT - - key = build_request(request_headers) + request_headers.setdefault('User-Agent', USER_AGENT) yield from self.write_http_request( wsuri.resource_name, request_headers) diff --git a/websockets/exceptions.py b/websockets/exceptions.py index b4f7dfc2e..74619cabd 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -20,12 +20,12 @@ class AbortHandshake(InvalidHandshake): Exception raised to abort a handshake and return a HTTP response. """ - def __init__(self, status, headers, body=None): + def __init__(self, status, headers, body=b''): self.status = status self.headers = headers self.body = body message = "HTTP {}, {} headers, {} bytes".format( - status, len(headers), 0 if body is None else len(body)) + status, len(headers), len(body)) super().__init__(message) diff --git a/websockets/server.py b/websockets/server.py index 1ac131a6a..e1a420c79 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -5,6 +5,7 @@ import asyncio import collections.abc +import email.utils import logging import sys @@ -135,6 +136,13 @@ def handler(self): if not isinstance(headers, Headers): headers = Headers(headers) + + headers.setdefault('Date', email.utils.formatdate(usegmt=True)) + headers.setdefault('Server', USER_AGENT) + headers.setdefault('Content-Length', str(len(body))) + headers.setdefault('Content-Type', 'text/plain') + headers.setdefault('Connection', 'close') + yield from self.write_http_response(status, headers, body) yield from self.fail_connection() @@ -455,6 +463,9 @@ def handshake(self, origins=None, available_extensions=None, request_headers, available_subprotocols) response_headers = Headers() + response_headers['Date'] = email.utils.formatdate(usegmt=True) + + build_response(response_headers, key) if extensions_header is not None: response_headers['Sec-WebSocket-Extensions'] = extensions_header @@ -472,10 +483,7 @@ def handshake(self, origins=None, available_extensions=None, for name, value in extra_headers: response_headers[name] = value - if 'Server' not in response_headers: - response_headers['Server'] = USER_AGENT - - build_response(response_headers, key) + response_headers.setdefault('Server', USER_AGENT) yield from self.write_http_response( SWITCHING_PROTOCOLS, response_headers) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 5c1e4a6f9..599629acd 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -146,25 +146,25 @@ class UnauthorizedServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): - # Use [...] here rather than Headers(...) to ensure that both work. - return UNAUTHORIZED, [('X-Access', 'denied')] + # Test returning headers as a Headers instance (1/3) + return UNAUTHORIZED, Headers([('X-Access', 'denied')]), b'' class ForbiddenServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): - # Use Headers(...) here rather than [...] to ensure that both work. - return FORBIDDEN, Headers({'X-Access': 'Denied'}) + # Test returning headers as a dict (2/3) + return FORBIDDEN, {'X-Access': 'denied'}, b'' class HealthCheckServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): + # Test returning headers as a list of pairs (3/3) if path == '/__health__/': - body = b'status = green\n' - return OK, [('Content-Length', str(len(body)))], body + return OK, [('X-Access', 'OK')], b'status = green\n' class FooClientProtocol(WebSocketClientProtocol): From e45fa38282099b94db966dc04d1334336ddbb2f2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Jun 2018 16:26:02 +0200 Subject: [PATCH 0442/1539] Add code of conduct. --- CODE_OF_CONDUCT.md | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 CODE_OF_CONDUCT.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..80f80d51b --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,46 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at aymeric DOT augustin AT fractalideas DOT com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ From b7a2bfe4ed5a14208ee3658eaad9be9f689cb6fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Jun 2018 16:56:46 +0200 Subject: [PATCH 0443/1539] Expand contribution guidelines. Hopefully this will reduce the number of "question issues", which can be hard to answer patiently. --- docs/contributing.rst | 56 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/docs/contributing.rst b/docs/contributing.rst index 4b869dcac..21e2152c1 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -1,11 +1,61 @@ Contributing ============ -Bug reports, patches and suggestions are welcome! Please open an issue_ or -send a `pull request`_. +Thanks for taking the time to contribute to websockets! -Feedback about this documentation is especially valuable — the authors of +Code of Conduct +--------------- + +This project and everyone participating in it is governed by the `Code of +Conduct`_. By participating, you are expected to uphold this code. Please +report inappropriate behavior to aymeric DOT augustin AT fractalideas DOT com. + +.. _Code of Conduct: https://github.com/aaugustin/websockets/blob/master/CODE_OF_CONDUCT.md + +*(If I'm the person with the inappropriate behavior, please accept my +apologies. I know I can mess up. I can't expect you to tell me, but if you +chose to do so, I'll do my best to handle criticism constructively. +-- Aymeric)* + +Contributions +------------- + +Bug reports, patches and suggestions are welcome! + +Please open an issue_ or send a `pull request`_. + +Feedback about the documentation is especially valuable — the authors of ``websockets`` feel more confident about writing code than writing docs :-) +If you're wondering why things are done in a certain way, the :doc:`design +document ` provides lots of details about the internals of websockets. + .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ + +Questions +--------- + +GitHub issues aren't a good medium for handling questions. There are better +places to ask questions, for example Stack Overflow. + +If you want to ask a question anyway, please make sure that: + +- it's a question about ``websockets`` and not about :mod:`asyncio`; +- it isn't answered by the documentation; +- it wasn't asked already. + +A good question can be written as a suggestion to improve the documentation. + +Bitcoin users +------------- + +websockets appears to be quite popular for interfacing with Bitcoin or other +cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. + +Please stop heating the planet where my children are supposed to live, thanks. + +Since websockets is released under an open-source license, you can use it for +any purpose you like. However, I won't spend any of my time to help. + +I will summarily close issues related to Bitcoin or cryptocurrency in any way. From 39b050fc8a053f21fe490c09b7d294b988241280 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 26 Jun 2018 21:50:43 +0200 Subject: [PATCH 0444/1539] Add CVE reference. --- docs/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5bc47c99d..375ebbc8f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -51,7 +51,7 @@ Also: websockets 4.0 was vulnerable to denial of service by memory exhaustion because it didn't enforce ``max_size`` when decompressing compressed - messages. + messages (CVE-2018-1000518). .. warning:: From 02371af16a6311fe9d07e5f0f0bbd6fd996becd5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jul 2018 11:26:37 +0200 Subject: [PATCH 0445/1539] Set up Circle 2. Fix #437. --- .circleci/config.yml | 70 ++++++++++++++++++++++++++++++++++++++++++++ circle.yml | 14 --------- tox.ini | 2 +- 3 files changed, 71 insertions(+), 15 deletions(-) create mode 100644 .circleci/config.yml delete mode 100644 circle.yml diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 000000000..c4fdb2fc0 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,70 @@ +version: 2 + +jobs: + quality: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: sudo pip install tox + - run: tox -e flake8,isort + py34: + docker: + - image: circleci/python:3.4 + steps: + - checkout + - run: sudo pip install tox + - run: tox -e py34,py34-speedups + py35: + docker: + - image: circleci/python:3.5 + steps: + - checkout + - run: sudo pip install tox + - run: tox -e py35,py35-speedups + py36: + docker: + - image: circleci/python:3.6 + steps: + - checkout + - run: sudo pip install tox + - run: tox -e py36,py36-speedups + py37: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: sudo pip install tox + - run: tox -e py37,py37-speedups + coverage: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: sudo pip install tox codecov + - run: tox -e coverage + - run: codecov + +workflows: + version: 2 + build: + jobs: + - quality + - py34: + requires: + - quality + - py35: + requires: + - quality + - py36: + requires: + - quality + - py37: + requires: + - quality + - coverage: + requires: + - py34 + - py35 + - py36 + - py37 diff --git a/circle.yml b/circle.yml deleted file mode 100644 index 1726ea432..000000000 --- a/circle.yml +++ /dev/null @@ -1,14 +0,0 @@ -machine: - post: - - pyenv global 3.6.1 3.5.3 3.4.4 - python: - version: 3.6.1 - -dependencies: - override: - - pip install tox codecov - -test: - override: - - tox - - codecov diff --git a/tox.ini b/tox.ini index 2eaf4fa48..e30ca5288 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = {py34,py35,py36}{,-speedups},coverage,flake8,isort +envlist = {py34,py35,py36,py37}{,-speedups},coverage,flake8,isort [testenv] commands = From e4e21e35219d04e7bca97fe690191b4da29c0ab1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jul 2018 12:47:31 +0200 Subject: [PATCH 0446/1539] Remove IPv6 host entries. It's impossible to connect to ::1 in a Circle CI container because IPv6 networking isn't set up: OSError: [Errno 99] error while attempting to bind on address ('::1', 0, 0, 0): cannot assign requested address --- .circleci/config.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index c4fdb2fc0..1bce9d635 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -5,6 +5,8 @@ jobs: docker: - image: circleci/python:3.7 steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - run: tox -e flake8,isort @@ -12,6 +14,8 @@ jobs: docker: - image: circleci/python:3.4 steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - run: tox -e py34,py34-speedups @@ -19,6 +23,8 @@ jobs: docker: - image: circleci/python:3.5 steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - run: tox -e py35,py35-speedups @@ -26,6 +32,8 @@ jobs: docker: - image: circleci/python:3.6 steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - run: tox -e py36,py36-speedups @@ -33,6 +41,8 @@ jobs: docker: - image: circleci/python:3.7 steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - run: tox -e py37,py37-speedups @@ -40,6 +50,8 @@ jobs: docker: - image: circleci/python:3.7 steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox codecov - run: tox -e coverage From 91a376685b1ab7103d3d861ff8b02a1c00f142b1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jul 2018 11:07:47 +0200 Subject: [PATCH 0447/1539] Support yield from connect/serve on Python 3.7. Fix #435. --- websockets/client.py | 1 + websockets/py35/_test_client_server.py | 3 ++ websockets/server.py | 1 + websockets/test_client_server.py | 41 ++++++++++++++++++++++++++ 4 files changed, 46 insertions(+) diff --git a/websockets/client.py b/websockets/client.py index b8f33b2cb..aa45180bb 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -384,6 +384,7 @@ def __init__(self, uri, *, self._creating_connection = loop.create_connection( factory, host, port, **kwds) + @asyncio.coroutine def __iter__(self): # pragma: no cover transport, protocol = yield from self._creating_connection diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py index 5360d8d01..c656dd38a 100644 --- a/websockets/py35/_test_client_server.py +++ b/websockets/py35/_test_client_server.py @@ -39,6 +39,7 @@ async def run_client(): self.loop.run_until_complete(server.wait_closed()) def test_server(self): + async def run_server(): # Await serve. server = await serve(handler, 'localhost', 0) @@ -83,6 +84,7 @@ async def run_client(): @unittest.skipIf( sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+') def test_server(self): + async def run_server(): # Use serve as an asynchronous context manager. async with serve(handler, 'localhost', 0) as server: @@ -99,6 +101,7 @@ async def run_server(): @unittest.skipUnless( hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') def test_unix_server(self): + async def run_server(path): async with unix_serve(handler, path) as server: self.assertTrue(server.sockets) diff --git a/websockets/server.py b/websockets/server.py index e1a420c79..1204eabf0 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -740,6 +740,7 @@ def __init__(self, ws_handler, host=None, port=None, *, self._creating_server = creating_server self.ws_server = ws_server + @asyncio.coroutine def __iter__(self): # pragma: no cover server = yield from self._creating_server self.ws_server.wrap(server) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 599629acd..7111f044a 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -24,6 +24,7 @@ ) from .handshake import build_response from .http import USER_AGENT, Headers, read_response +from .protocol import State from .server import * from .test_protocol import MS @@ -1064,6 +1065,46 @@ def test_checking_lack_of_origin_succeeds(self): self.loop.run_until_complete(server.wait_closed()) +class YieldFromTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_client(self): + start_server = serve(handler, 'localhost', 0) + server = self.loop.run_until_complete(start_server) + + @asyncio.coroutine + def run_client(): + # Yield from connect. + client = yield from connect(get_server_uri(server)) + self.assertEqual(client.state, State.OPEN) + yield from client.close() + self.assertEqual(client.state, State.CLOSED) + + self.loop.run_until_complete(run_client()) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + + def test_server(self): + + @asyncio.coroutine + def run_server(): + # Yield from serve. + server = yield from serve(handler, 'localhost', 0) + self.assertTrue(server.sockets) + server.close() + yield from server.wait_closed() + self.assertFalse(server.sockets) + + self.loop.run_until_complete(run_server()) + + if sys.version_info[:2] >= (3, 5): # pragma: no cover from .py35._test_client_server import AsyncAwaitTests # noqa from .py35._test_client_server import ContextManagerTests # noqa From bfdcb571d245ae7f7cbec12170d22fe59bd2c43e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jul 2018 15:13:12 +0200 Subject: [PATCH 0448/1539] Optimize Circle CI workflow. --- .circleci/config.yml | 33 +++++++++------------------------ 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1bce9d635..1632a17f5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,15 +1,16 @@ version: 2 jobs: - quality: + main: docker: - image: circleci/python:3.7 steps: # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - - run: sudo pip install tox - - run: tox -e flake8,isort + - run: sudo pip install tox codecov + - run: tox -e coverage,flake8,isort + - run: codecov py34: docker: - image: circleci/python:3.4 @@ -46,37 +47,21 @@ jobs: - checkout - run: sudo pip install tox - run: tox -e py37,py37-speedups - coverage: - docker: - - image: circleci/python:3.7 - steps: - # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - - checkout - - run: sudo pip install tox codecov - - run: tox -e coverage - - run: codecov workflows: version: 2 build: jobs: - - quality + - main - py34: requires: - - quality + - main - py35: requires: - - quality + - main - py36: requires: - - quality + - main - py37: requires: - - quality - - coverage: - requires: - - py34 - - py35 - - py36 - - py37 + - main From bc401f179ffd29715417d5e6946fc4c061f7f71b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jul 2018 15:27:27 +0200 Subject: [PATCH 0449/1539] Prevent duplicate AppVeyor builds. This was configured in the UI, however the UI is ignored when there's a YAML config file. --- .appveyor.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.appveyor.yml b/.appveyor.yml index 41ea07d99..77e07e9b7 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,3 +1,9 @@ +branches: + only: + - master + +skip_branch_with_pr: true + environment: # websockets only works on Python >= 3.4. CIBW_SKIP: cp27-* cp33-* From 67ba5335f828ff446ac0e7cab7a9b45ed2513724 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 16 Jul 2018 21:50:15 +0200 Subject: [PATCH 0450/1539] Bump version number. --- docs/changelog.rst | 5 ++++- docs/conf.py | 4 ++-- websockets/version.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 375ebbc8f..25ed5191a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,14 @@ Changelog .. currentmodule:: websockets -6.0 +6.1 ... *In development* +6.0 +... + .. warning:: **Version 6.0 introduces the** :class:`~http.Headers` **class for managing diff --git a/docs/conf.py b/docs/conf.py index 6a7f02b90..3bdeb3616 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = '5.0' +version = '6.0' # The full version, including alpha/beta/rc tags. -release = '5.0' +release = '6.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/websockets/version.py b/websockets/version.py index cd2999a7e..9d929a970 100644 --- a/websockets/version.py +++ b/websockets/version.py @@ -1 +1 @@ -version = '5.0.1' +version = '6.0' From 2c68e9cb0a917de48af99e29d71e1d85ec8c0a62 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 16 Jul 2018 21:59:22 +0200 Subject: [PATCH 0451/1539] Advertise support for Python 3.7. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index df7aa24f5..7abafd399 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', ], packages=packages, ext_modules=ext_modules, From 75b5874a235c82990d5b97da4bbb454b00b978eb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Jul 2018 18:21:26 +0200 Subject: [PATCH 0452/1539] Fix typos in docs. --- docs/changelog.rst | 2 +- docs/design.rst | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 25ed5191a..63d08ba98 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -276,7 +276,7 @@ Also: * Returned a 403 status code instead of 400 when the request Origin isn't allowed. -* Cancelling :meth:`~protocol.WebSocketCommonProtocol.recv` no longer drops +* Canceling :meth:`~protocol.WebSocketCommonProtocol.recv` no longer drops the next message. * Clarified that the closing handshake can be initiated by the client. diff --git a/docs/design.rst b/docs/design.rst index 9974c1cb8..db2924b0c 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -293,13 +293,13 @@ which may happen as a result of: :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a protocol error, including connection errors: depending on the exception, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the - connection `_ with a suitable code and exits. + connection ` with a suitable code and exits. :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` is separate from :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to make it -easier to implement the timeout on the closing handshake. Cancelling +easier to implement the timeout on the closing handshake. Canceling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk -of cancelling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` +of canceling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` and failing to close the TCP connection, thus leaking resources. Terminating the TCP connection can take up to ``2 * timeout`` on the server @@ -320,7 +320,7 @@ If the opening handshake doesn't complete successfully, ``websockets`` fails the connection by closing the TCP connection. Once the opening handshake has completed, ``websockets`` fails the connection -by cancelling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +by canceling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and sending a close frame if appropriate. :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking From 1c9d314dbe7c371e1be39e0738492aaf9f7e0869 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Jul 2018 18:21:37 +0200 Subject: [PATCH 0453/1539] Document that cancelling recv is safe. (It wasn't until version 2.5.) Fix #333. --- websockets/protocol.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/websockets/protocol.py b/websockets/protocol.py index adadf6328..fe73c0971 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -313,6 +313,11 @@ def recv(self): :meth:`recv` used to return ``None`` instead. Refer to the changelog for details. + Canceling :meth:`recv` is safe. There's no risk of losing the next + message. The next invocation of :meth:`recv` will return it. This + makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.wait_for`. + """ # Don't yield from self.ensure_open() here because messages could be # available in the queue even if the connection is closed. From 0523ee82ea1f136eaf4e6dcf828ce21461dd34aa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Jul 2018 18:36:44 +0200 Subject: [PATCH 0454/1539] Spell-check the documentation. --- docs/Makefile | 7 +++++++ docs/changelog.rst | 6 +++--- docs/cheatsheet.rst | 4 ++-- docs/conf.py | 5 +++++ docs/deployment.rst | 2 +- docs/design.rst | 26 +++++++++++++------------- docs/security.rst | 4 ++-- docs/spelling_wordlist.txt | 32 ++++++++++++++++++++++++++++++++ websockets/exceptions.py | 2 +- websockets/http.py | 2 +- websockets/protocol.py | 14 +++++++------- websockets/server.py | 2 +- websockets/test_client_server.py | 2 +- websockets/test_protocol.py | 8 ++++---- 14 files changed, 80 insertions(+), 36 deletions(-) create mode 100644 docs/spelling_wordlist.txt diff --git a/docs/Makefile b/docs/Makefile index d875339dd..bb25aa49d 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -37,6 +37,7 @@ help: @echo " changes to make an overview of all changed/added/deprecated items" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" + @echo " spelling to check for typos in documentation" clean: -rm -rf $(BUILDDIR)/* @@ -151,3 +152,9 @@ doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." + +spelling: + $(SPHINXBUILD) -b spelling $(ALLSPHINXOPTS) $(BUILDDIR)/spelling + @echo + @echo "Check finished. Wrong words can be found in " \ + "$(BUILDDIR)/spelling/output.txt." diff --git a/docs/changelog.rst b/docs/changelog.rst index 63d08ba98..f8bc08227 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -32,7 +32,7 @@ Changelog * Functions defined in the :mod:`~handshake` module now receive :class:`~http.Headers` in argument instead of ``get_header`` or - ``set_header`` fucntions. This affects libraries that rely on + ``set_header`` functions. This affects libraries that rely on low-level APIs. * Functions defined in the :mod:`~http` module now return HTTP headers as @@ -81,7 +81,7 @@ Also: * Added :meth:`~protocol.WebSocketCommonProtocol.closed` property. * If a :meth:`~protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, - it's cancelled when the connection is closed. + it's canceled when the connection is closed. * Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. @@ -242,7 +242,7 @@ Also: * Worked around an asyncio bug affecting connection termination under load. -* Made ``state_name`` atttribute on protocols a public API. +* Made ``state_name`` attribute on protocols a public API. * Improved documentation. diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 259f85d50..8857152c4 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -6,7 +6,7 @@ Cheat sheet Server ------ -* Write a coroutine that handles a single connection. It receives a websocket +* Write a coroutine that handles a single connection. It receives a WebSocket protocol instance and the URI path in argument. * Call :meth:`~protocol.WebSocketCommonProtocol.recv` and @@ -117,6 +117,6 @@ connection handler, you can bind them with :func:`functools.partial`:: asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() -Another way to achieve this result is to define the ``handler`` corountine in +Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it through an argument. diff --git a/docs/conf.py b/docs/conf.py index 3bdeb3616..04db46ca7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,6 +27,11 @@ # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.viewcode'] +# Spelling check needs an additional module that is not installed by default. +# Add it only if spelling check is requested so docs can be generated without it. +if 'spelling' in sys.argv: + extensions.append('sphinxcontrib.spelling') + # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] diff --git a/docs/deployment.rst b/docs/deployment.rst index ed4453cd6..15f722eea 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -30,7 +30,7 @@ with the object returned by :func:`~server.serve`: - calling its ``close()`` method, then waiting for its ``wait_closed()`` method to complete. -Tasks that handle connections will be cancelled. For example, if the handler +Tasks that handle connections will be canceled. For example, if the handler is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, that call will raise :exc:`~asyncio.CancelledError`. diff --git a/docs/design.rst b/docs/design.rst index db2924b0c..6cd095369 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -70,14 +70,14 @@ two tasks: - :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs :meth:`~protocol.WebSocketCommonProtocol.transfer_data()` which handles incoming data and lets :meth:`~protocol.WebSocketCommonProtocol.recv()` - consume it. It may be cancelled to terminate the connection. It never exits + consume it. It may be canceled to terminate the connection. It never exits with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data transfer ` below. - :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs :meth:`~protocol.WebSocketCommonProtocol.close_connection()` which waits for the data transfer to terminate, then takes care of closing the TCP - connection. It must not be cancelled. It never exits with an exception. See + connection. It must not be canceled. It never exits with an exception. See :ref:`connection termination ` below. Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection()` starts @@ -113,7 +113,7 @@ passing the protocol to the ``ws_handler`` coroutine handling the connection. While the opening handshake is asymmetrical — the client sends an HTTP Upgrade request and the server replies with an HTTP Switching Protocols response — -``websockets`` aims at keepping the implementation of both sides consistent +``websockets`` aims at keeping the implementation of both sides consistent with one another. On the client side, :meth:`~client.WebSocketClientProtocol.handshake()`: @@ -132,7 +132,7 @@ On the server side, :meth:`~server.WebSocketServerProtocol.handshake()`: - calls :meth:`~server.WebSocketServerProtocol.process_request()` which may abort the WebSocket handshake and return a HTTP response instead; this hook only makes sense on the server side; -- checks the HTTP request, negociates ``extensions`` and ``subprotocol``, and +- checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds a HTTP response based on the above and parameters passed to :meth:`~server.serve()`; @@ -140,8 +140,8 @@ On the server side, :meth:`~server.WebSocketServerProtocol.handshake()`: - moves to the ``OPEN`` state; - returns the ``path`` part of the ``uri``. -The most significant assymetry between the two sides of the opening handshake -lies in the negociation of extensions and, to a lesser extent, of the +The most significant asymmetry between the two sides of the opening handshake +lies in the negotiation of extensions and, to a lesser extent, of the subprotocol. The server knows everything about both sides and decides what the parameters should be for the connection. The client merely applies them. @@ -213,7 +213,7 @@ messages in the :attr:`~protocol.WebSocketCommonProtocol.messages` queue. When it encounters a control frame: - if it's a close frame, it starts the closing handshake; -- if it's a ping frame, it anwsers with a pong frame; +- if it's a ping frame, it answers with a pong frame; - if it's a pong frame, it acknowledges the corresponding ping (unless it's an unsolicited pong). @@ -334,7 +334,7 @@ Cancellation ------------ Most :doc:`public APIs ` of ``websockets`` are coroutines. They may be -cancelled. ``websockets`` must handle this situation. +canceled. ``websockets`` must handle this situation. Cancellation during the opening handshake is handled like any other exception: the TCP connection is closed and the exception is re-raised or logged. @@ -342,7 +342,7 @@ the TCP connection is closed and the exception is re-raised or logged. Once the WebSocket connection is established, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get -accidentally cancelled if a coroutine that awaits them is cancelled. They must +accidentally canceled if a coroutine that awaits them is canceled. They must be shielded from cancellation. :meth:`~protocol.WebSocketCommonProtocol.recv()` waits for the next message in @@ -360,11 +360,11 @@ on the transfer data task, it doesn't propagate cancellation to that task. prevent cancellation. :meth:`~protocol.WebSocketCommonProtocol.close()` waits for the data transfer -task to terminate with :func:`~asyncio.wait_for`. If it's cancelled or if the -timout elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` -is cancelled. :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is +task to terminate with :func:`~asyncio.wait_for`. If it's canceled or if the +timeout elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` +is canceled. :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is expected to catch the cancellation and terminate properly. This is the only -point where it may be cancelled. +point where it may be canceled. :meth:`~protocol.WebSocketCommonProtocol.close()` then waits for :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it diff --git a/docs/security.rst b/docs/security.rst index aae067f34..f0d1deee3 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -36,9 +36,9 @@ Other limits ------------ ``websockets`` implements additional limits on the amount of data it accepts -in order to mimimize exposure to security vulnerabilities. +in order to minimize exposure to security vulnerabilities. In the opening handshake, ``websockets`` limits the number of HTTP headers to 256 and the size of an individual header to 4096 bytes. These limits are 10 to -20 times larger than what's expected in standard use cases. They're hardcoded. +20 times larger than what's expected in standard use cases. They're hard-coded. If you need to change them, monkey-patch the constants in ``websockets.http``. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt new file mode 100644 index 000000000..63be00f1f --- /dev/null +++ b/docs/spelling_wordlist.txt @@ -0,0 +1,32 @@ +attr +augustin +Auth +awaitable +aymeric +backpressure +Backpressure +Bitcoin +bufferbloat +Bufferbloat +bugfix +changelog +cryptocurrency +daemonize +fractalideas +iterable +kB +keepalive +lifecycle +Lifecycle +nginx +permessage +pong +Pythonic +serializers +subprotocol +subprotocols +TLS +Unparse +websocket +WebSocket +websockets diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 74619cabd..91ae52549 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -97,7 +97,7 @@ def __init__(self, status_code): class NegotiationError(InvalidHandshake): """ - Exception raised when negociating an extension fails. + Exception raised when negotiating an extension fails. """ diff --git a/websockets/http.py b/websockets/http.py index 5ec339055..e0bd17609 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -189,7 +189,7 @@ def read_line(stream): """ # Security: this is bounded by the StreamReader's limit (default = 32kB). line = yield from stream.readline() - # Security: this guarantees header values are small (hardcoded = 4kB) + # Security: this guarantees header values are small (hard-coded = 4kB) if len(line) > MAX_LINE: raise ValueError("Line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 diff --git a/websockets/protocol.py b/websockets/protocol.py index fe73c0971..f14c9665e 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -384,7 +384,7 @@ def close(self, code=1000, reason=''): connection to terminate. It doesn't do anything once the connection is closed. In other words - it's idemptotent. + it's idempotent. It's safe to wrap this coroutine in :func:`~asyncio.ensure_future` since errors during connection termination aren't particularly useful. @@ -411,8 +411,8 @@ def close(self, code=1000, reason=''): # exception, so there's no need to catch CancelledError here. try: - # If close() is cancelled during the wait, self.transfer_data_task - # is cancelled before the timeout elapses (on Python ≥ 3.4.3). + # If close() is canceled during the wait, self.transfer_data_task + # is canceled before the timeout elapses (on Python ≥ 3.4.3). # This helps closing connections when shutting down a server. yield from asyncio.wait_for( self.transfer_data_task, @@ -821,7 +821,7 @@ def close_connection(self): ) plural = 's' if len(self.pings) > 1 else '' logger.debug( - "%s - cancelled pending ping%s: %s", + "%s - canceled pending ping%s: %s", self.side, plural, pings_hex) # A client should wait for a TCP close from the server. @@ -844,7 +844,7 @@ def close_connection(self): finally: # The try/finally ensures that the transport never remains open, - # even if this coroutine is cancelled (for example). + # even if this coroutine is canceled (for example). # If connection_lost() was called, the TCP connection is closed. # However, if TLS is enabled, the transport still needs closing. @@ -904,7 +904,7 @@ def fail_connection(self, code=1006, reason=''): handshake succeeded and the other side is likely to process it. 3. Closing the connection. :meth:`close_connection` takes care of - this once :attr:`transfer_data_task` exits after being cancelled. + this once :attr:`transfer_data_task` exits after being canceled. (The specification describes these steps in the opposite order.) @@ -1018,7 +1018,7 @@ def connection_lost(self, exc): self.close_code, self.close_reason or '[empty]') # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; - # - it must never be cancelled. + # - it must never be canceled. self.connection_lost_waiter.set_result(None) super().connection_lost(exc) diff --git a/websockets/server.py b/websockets/server.py index 1204eabf0..af7089983 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -673,7 +673,7 @@ class Serve: When a server is closed with :meth:`~websockets.server.WebSocketServer.close`, all running WebSocket - handlers are cancelled. They may intercept :exc:`~asyncio.CancelledError` + handlers are canceled. They may intercept :exc:`~asyncio.CancelledError` and perform cleanup actions before re-raising that exception. If a handler started new tasks, it should cancel them as well in that case. diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 7111f044a..fcf768a8e 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -277,7 +277,7 @@ def test_basic(self): def test_server_close_while_client_connected(self): with self.temp_server(loop=self.loop): - # This endpoint waits just a bit when the connection is cancelled + # This endpoint waits just a bit when the connection is canceled # in order to test that wait_closed() really waits for completion. self.start_client('/slow_stop') with self.assertRaises(ConnectionClosed): diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 19ec7bca6..e9e227104 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -189,7 +189,7 @@ def half_close_connection_local(self, code=1000, reason='close'): left in the CLOSING state until the event loop runs again. The current implementation returns a task that must be awaited or - cancelled, else asyncio complains about destroying a pending task. + canceled, else asyncio complains about destroying a pending task. """ close_frame_data = serialize_close(code, reason) @@ -208,7 +208,7 @@ def half_close_connection_local(self, code=1000, reason='close'): MS, self.receive_frame, Frame(True, OP_CLOSE, close_frame_data)) self.loop.call_later(2 * MS, self.receive_eof_if_client) - # This task must be awaited or cancelled by the caller. + # This task must be awaited or canceled by the caller. return close_task def half_close_connection_remote(self, code=1000, reason='close'): @@ -433,7 +433,7 @@ def read_message(): self.process_invalid_frames() self.assertConnectionFailed(1011, '') - def test_recv_cancelled(self): + def test_recv_canceled(self): recv = self.ensure_future(self.protocol.recv()) self.loop.call_soon(recv.cancel) with self.assertRaises(asyncio.CancelledError): @@ -630,7 +630,7 @@ def test_acknowledge_previous_pings(self): self.assertTrue(pings[1][0].done()) self.assertFalse(pings[2][0].done()) - def test_cancelled_ping(self): + def test_canceled_ping(self): ping = self.loop.run_until_complete(self.protocol.ping()) ping_frame = self.last_sent_frame() ping.cancel() From 3525ec391cb031d62015e71dcb4710b6f06248aa Mon Sep 17 00:00:00 2001 From: cclauss Date: Mon, 30 Jul 2018 07:54:05 +0200 Subject: [PATCH 0455/1539] Travis CI upgrade to Python 3.7 Upgrade to Python 3.7 in alignment with travis-ci/travis-ci#9069 --- .travis.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index d9460952e..1bf868253 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,4 @@ +language: python env: global: # websockets only works on Python >= 3.4. @@ -7,10 +8,9 @@ env: matrix: include: - - dist: trusty + - dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069) sudo: required - language: python - python: "3.6" + python: "3.7" services: - docker - os: osx From ef33ebd2a913d548cc261ef054d89d04f20d9dd1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Jul 2018 14:59:41 +0200 Subject: [PATCH 0456/1539] Clarify why timeout only applies to close. Fix #428. --- websockets/protocol.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/websockets/protocol.py b/websockets/protocol.py index f14c9665e..462904b23 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -78,10 +78,18 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``host``, ``port`` and ``secure`` parameters are simply stored as attributes for handlers that need them. - The ``timeout`` parameter defines the maximum wait time in seconds for - completing the closing handshake and, only on the client side, for - terminating the TCP connection. :meth:`close()` will complete in at most - ``4 * timeout`` on the server side and ``5 * timeout`` on the client side. + The ``timeout`` parameter defines a maximum wait time in seconds for + completing the closing handshake and terminating the TCP connection. + :meth:`close()` completes in at most ``4 * timeout`` on the server side + and ``5 * timeout`` on the client side. + + ``timeout`` is a parameter of the protocol because websockets usually + calls :meth:`close()` implicitly: + + - on the server side, when the connection handler terminates, + - on the client side, when exiting the context manager for the connection. + + To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1MB. ``None`` disables the limit. If a From e2c4303129801c87a1f0f1c36ec7134f306e3194 Mon Sep 17 00:00:00 2001 From: cclauss Date: Mon, 30 Jul 2018 15:07:44 +0200 Subject: [PATCH 0457/1539] Move language: python --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 1bf868253..b66c0f5b7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,3 @@ -language: python env: global: # websockets only works on Python >= 3.4. @@ -8,7 +7,8 @@ env: matrix: include: - - dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069) + - language: python + dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069) sudo: required python: "3.7" services: From ab590c620dd195a508e2884cfb9763e50d99addd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 16 Jul 2018 21:33:51 +0200 Subject: [PATCH 0458/1539] Add an interactive client. Fix #361. --- Makefile | 2 +- docs/changelog.rst | 2 + docs/intro.rst | 7 ++ tox.ini | 2 +- websockets/__main__.py | 143 +++++++++++++++++++++++++++++++++++++++ websockets/exceptions.py | 34 +++++++--- 6 files changed, 177 insertions(+), 13 deletions(-) create mode 100644 websockets/__main__.py diff --git a/Makefile b/Makefile index db31c68d5..4992263b3 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ test: coverage: python -m coverage erase - python -W default -m coverage run --branch --source=websockets -m unittest + python -W default -m coverage run --branch --omit=websockets/__main__.py --source=websockets -m unittest python -m coverage html clean: diff --git a/docs/changelog.rst b/docs/changelog.rst index f8bc08227..9ff32f62a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,8 @@ Changelog *In development* +* Added an interactive client: `python -m websockets ` + 6.0 ... diff --git a/docs/intro.rst b/docs/intro.rst index 3eb3505db..154e1d8ea 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -221,6 +221,13 @@ answering pings, or any other behavior required by the specification. ``websockets`` handles all this under the hood so you don't have to. +One more thing... +----------------- + +``websockets`` provides an interactive client:: + + $ python -m websockets wss://echo.websocket.org/ + .. _python-lt-36: Python < 3.6 diff --git a/tox.ini b/tox.ini index e30ca5288..9a3516af9 100644 --- a/tox.ini +++ b/tox.ini @@ -27,7 +27,7 @@ commands = python setup.py --quiet build_ext --inplace python -m coverage erase - python -W default -m coverage run --branch --source=websockets -m unittest + python -W default -m coverage run --branch --omit=websockets/__main__.py --source=websockets -m unittest python -m coverage report --show-missing --fail-under=100 speedups: sh -c 'rm websockets/*.so' diff --git a/websockets/__main__.py b/websockets/__main__.py new file mode 100644 index 000000000..582594474 --- /dev/null +++ b/websockets/__main__.py @@ -0,0 +1,143 @@ +import argparse +import asyncio +import os +import signal +import sys +import threading + +import websockets +from websockets.compatibility import asyncio_ensure_future +from websockets.exceptions import format_close + + +def exit_from_event_loop_thread(loop, stop): + loop.stop() + if not stop.done(): + # When exiting the thread that runs the event loop, raise + # KeyboardInterrupt in the main thead to exit the program. + try: + ctrl_c = signal.CTRL_C_EVENT # Windows + except AttributeError: + ctrl_c = signal.SIGINT # POSIX + os.kill(os.getpid(), ctrl_c) + + +def print_during_input(string): + sys.stdout.write( + '\N{ESC}7' # Save cursor position + '\N{LINE FEED}' # Add a new line + '\N{ESC}[A' # Move cursor up + '\N{ESC}[L' # Insert blank line, scroll last line down + '{string}\N{LINE FEED}' # Print string in the inserted blank line + '\N{ESC}8' # Restore cursor position + '\N{ESC}[B' # Move cursor down + .format(string=string) + ) + sys.stdout.flush() + + +def print_over_input(string): + sys.stdout.write( + '\N{CARRIAGE RETURN}' # Move cursor to beginning of line + '\N{ESC}[K' # Delete current line + '{string}\N{LINE FEED}' + .format(string=string) + ) + sys.stdout.flush() + + +@asyncio.coroutine +def run_client(uri, loop, inputs, stop): + try: + websocket = yield from websockets.connect(uri) + except Exception as exc: + print_over_input("Failed to connect to {}: {}.".format(uri, exc)) + exit_from_event_loop_thread(loop, stop) + return + else: + print_during_input("Connected to {}.".format(uri)) + + try: + while True: + incoming = asyncio_ensure_future(websocket.recv()) + outgoing = asyncio_ensure_future(inputs.get()) + done, pending = yield from asyncio.wait( + [incoming, outgoing, stop], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel pending tasks to avoid leaking them. + if incoming in pending: + incoming.cancel() + if outgoing in pending: + outgoing.cancel() + + if incoming in done: + try: + message = incoming.result() + except websockets.ConnectionClosed: + break + else: + print_during_input('< ' + message) + + if outgoing in done: + message = outgoing.result() + yield from websocket.send(message) + + if stop in done: + break + + finally: + yield from websocket.close() + close_status = format_close( + websocket.close_code, websocket.close_reason) + + print_over_input( + "Connection closed: {close_status}." + .format(close_status=close_status) + ) + + exit_from_event_loop_thread(loop, stop) + + +def main(): + # Parse command line arguments. + parser = argparse.ArgumentParser( + prog="python -m websockets", + description="Interactive WebSocket client.", + add_help=False, + ) + parser.add_argument('uri', metavar='') + args = parser.parse_args() + + # Create an event loop that will run in a background thread. + loop = asyncio.new_event_loop() + + # Create a queue of user inputs. There's no need to limit its size. + inputs = asyncio.Queue(loop=loop) + + # Create a stop condition when receiving SIGINT or SIGTERM. + stop = asyncio.Future(loop=loop) + + # Schedule the task that will manage the connection. + asyncio_ensure_future(run_client(args.uri, loop, inputs, stop), loop=loop) + + # Start the event loop in a background thread. + thread = threading.Thread(target=loop.run_forever) + thread.start() + + # Read from stdin in the main thread in order to receive signals. + try: + while True: + # Since there's no size limit, put_nowait is identical to put. + message = input('> ') + loop.call_soon_threadsafe(inputs.put_nowait, message) + except (KeyboardInterrupt, EOFError): # ^C, ^D + loop.call_soon_threadsafe(stop.set_result, None) + + # Wait for the event loop to terminate. + thread.join() + + +if __name__ == '__main__': + main() diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 91ae52549..1b758c648 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -160,6 +160,28 @@ class InvalidState(Exception): } +def format_close(code, reason): + """ + Display a human-readable version of the close code and reason. + + + """ + if 3000 <= code < 4000: + explanation = "registered" + elif 4000 <= code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODES.get(code, "unknown") + result = "code = {} ({}), ".format(code, explanation) + + if reason: + result += "reason = {}".format(reason) + else: + result += "no reason" + + return result + + class ConnectionClosed(InvalidState): """ Exception raised when trying to read or write on a closed connection. @@ -172,17 +194,7 @@ def __init__(self, code, reason): self.code = code self.reason = reason message = "WebSocket connection is closed: " - if 3000 <= code < 4000: - explanation = "registered" - elif 4000 <= code < 5000: - explanation = "private use" - else: - explanation = CLOSE_CODES.get(code, "unknown") - message += "code = {} ({}), ".format(code, explanation) - if reason: - message += "reason = {}".format(reason) - else: - message += "no reason" + message += format_close(code, reason) super().__init__(message) From 2d6b4290bd4e59badbcfe1592f9312c3e28fa48d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Ho=CC=88ner?= Date: Fri, 3 Aug 2018 00:20:30 +0200 Subject: [PATCH 0459/1539] Enable VT100 support on Windows --- websockets/__main__.py | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/websockets/__main__.py b/websockets/__main__.py index 582594474..9e0f0a297 100644 --- a/websockets/__main__.py +++ b/websockets/__main__.py @@ -10,6 +10,38 @@ from websockets.exceptions import format_close +def win_enable_vt100(): + """ + Enable VT-100 for console output on Windows. + + See also https://bugs.python.org/issue29059. + + """ + import ctypes + + STD_OUTPUT_HANDLE = ctypes.c_uint(-11) + INVALID_HANDLE_VALUE = ctypes.c_uint(-1) + ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004 + + handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE) + if handle == INVALID_HANDLE_VALUE: + raise RuntimeError("Unable to obtain stdout handle") + + cur_mode = ctypes.c_uint() + if ctypes.windll.kernel32.GetConsoleMode( + handle, ctypes.byref(cur_mode) + ) == 0: + raise RuntimeError("Unable to query current console mode") + + # ctypes ints lack support for the required bit-OR operation. + # Temporarily convert to Py int, do the OR and convert back. + py_int_mode = int.from_bytes(cur_mode, sys.byteorder) + new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING) + + if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0: + raise RuntimeError("Unable to set console mode") + + def exit_from_event_loop_thread(loop, stop): loop.stop() if not stop.done(): @@ -101,6 +133,20 @@ def run_client(uri, loop, inputs, stop): def main(): + # If we're on Windows, enable VT100 terminal support. + if os.name == 'nt': + try: + win_enable_vt100() + except RuntimeError as exc: + sys.stderr.write( + "Unable to set terminal to VT100 mode. This is only " + "supported since Win10 anniversary update. Expect " + "weird symbols on the terminal. Error: {exc!s}" + "\N{LINE FEED}" + .format(exc=exc) + ) + sys.stderr.flush() + # Parse command line arguments. parser = argparse.ArgumentParser( prog="python -m websockets", From 4216b35384c177981c4d18d763248c712b8e21d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 11 Aug 2018 12:16:12 +0200 Subject: [PATCH 0460/1539] Fix typo. Fix #454. --- docs/design.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/design.rst b/docs/design.rst index 6cd095369..33e835d82 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -393,7 +393,7 @@ accumulates in buffers, eventually causing the server to run out of memory and crash. The solution to this problem is backpressure. Any part of the server that -receives inputs faster than it can it can process them and send the outputs +receives inputs faster than it can process them and send the outputs must propagate that information back to the previous part in the chain. ``websockets`` is designed to make it easy to get backpressure right. From adbe4e5bc70090a3a8aa514e0d613a08e3ec960e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Aug 2018 13:51:55 +0200 Subject: [PATCH 0461/1539] Improve handling of missing headers. Specifically this change makes websockets return a HTTP 426 error with a proper message instead of a HTTP 400 with a cryptic message when the Connection header is missing. Fix #456. --- docs/changelog.rst | 2 + websockets/exceptions.py | 10 +++-- websockets/handshake.py | 72 +++++++++++++++++++++++++---------- websockets/test_exceptions.py | 14 +++++-- websockets/test_handshake.py | 47 +++++++++++++---------- 5 files changed, 97 insertions(+), 48 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9ff32f62a..7961301cb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,8 @@ Changelog * Added an interactive client: `python -m websockets ` +* Improved error messages when a required HTTP header is missing. + 6.0 ... diff --git a/websockets/exceptions.py b/websockets/exceptions.py index 1b758c648..e256f218a 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -41,11 +41,13 @@ class InvalidHeader(InvalidHandshake): Exception raised when a HTTP header doesn't have a valid format or value. """ - def __init__(self, name, value): - if value: - message = "Invalid {} header: {}".format(name, value) + def __init__(self, name, value=None): + if value is None: + message = "Missing {} header".format(name) + elif value == '': + message = "Empty {} header".format(name) else: - message = "Missing or empty {} header".format(name) + message = "Invalid {} header: {}".format(name, value) super().__init__(message) diff --git a/websockets/handshake.py b/websockets/handshake.py index 00fdd18aa..d8f79d371 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -36,7 +36,7 @@ import hashlib import random -from .exceptions import InvalidHeaderValue, InvalidUpgrade +from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade from .headers import parse_connection, parse_upgrade @@ -80,29 +80,47 @@ def check_request(headers): responsibility of the caller. """ - connection = parse_connection(headers.get('Connection', '')) + try: + connection = headers['Connection'] + except KeyError: + raise InvalidUpgrade('Connection') + + connection = parse_connection(connection) if not any(value.lower() == 'upgrade' for value in connection): - raise InvalidUpgrade('Connection', headers.get('Connection', '')) + raise InvalidUpgrade('Connection', connection) + + try: + upgrade = headers['Upgrade'] + except KeyError: + raise InvalidUpgrade('Upgrade') - upgrade = parse_upgrade(headers.get('Upgrade', '')) + upgrade = parse_upgrade(upgrade) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): - raise InvalidUpgrade('Upgrade', headers.get('Upgrade', '')) + raise InvalidUpgrade('Upgrade', upgrade) - key = headers.get('Sec-WebSocket-Key', '') try: - raw_key = base64.b64decode(key.encode(), validate=True) + s_w_key = headers['Sec-WebSocket-Key'] + except KeyError: + raise InvalidHeader('Sec-WebSocket-Key') + + try: + raw_key = base64.b64decode(s_w_key.encode(), validate=True) except binascii.Error: - raise InvalidHeaderValue('Sec-WebSocket-Key', key) + raise InvalidHeaderValue('Sec-WebSocket-Key', s_w_key) if len(raw_key) != 16: - raise InvalidHeaderValue('Sec-WebSocket-Key', key) + raise InvalidHeaderValue('Sec-WebSocket-Key', s_w_key) - version = headers.get('Sec-WebSocket-Version', '') - if version != '13': - raise InvalidHeaderValue('Sec-WebSocket-Version', version) + try: + s_w_version = headers['Sec-WebSocket-Version'] + except KeyError: + raise InvalidHeader('Sec-WebSocket-Version') - return key + if s_w_version != '13': + raise InvalidHeaderValue('Sec-WebSocket-Version', s_w_version) + + return s_w_key def build_response(headers, key): @@ -133,19 +151,33 @@ def check_response(headers, key): the caller. """ - connection = parse_connection(headers.get('Connection', '')) + try: + connection = headers['Connection'] + except KeyError: + raise InvalidUpgrade('Connection') + + connection = parse_connection(connection) if not any(value.lower() == 'upgrade' for value in connection): - raise InvalidUpgrade('Connection', headers.get('Connection', '')) + raise InvalidUpgrade('Connection', connection) - upgrade = parse_upgrade(headers.get('Upgrade', '')) + try: + upgrade = headers['Upgrade'] + except KeyError: + raise InvalidUpgrade('Upgrade') + + upgrade = parse_upgrade(upgrade) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): - raise InvalidUpgrade('Upgrade', headers.get('Upgrade', '')) + raise InvalidUpgrade('Upgrade', upgrade) + + try: + s_w_accept = headers['Sec-WebSocket-Accept'] + except KeyError: + raise InvalidHeader('Sec-WebSocket-Accept') - if headers.get('Sec-WebSocket-Accept', '') != accept(key): - raise InvalidHeaderValue( - 'Sec-WebSocket-Accept', headers.get('Sec-WebSocket-Accept', '')) + if s_w_accept != accept(key): + raise InvalidHeaderValue('Sec-WebSocket-Accept', s_w_accept) def accept(key): diff --git a/websockets/test_exceptions.py b/websockets/test_exceptions.py index 0985e5766..8092b6d11 100644 --- a/websockets/test_exceptions.py +++ b/websockets/test_exceptions.py @@ -20,9 +20,17 @@ def test_str(self): InvalidMessage("Malformed HTTP message"), "Malformed HTTP message", ), + ( + InvalidHeader('Name'), + "Missing Name header", + ), + ( + InvalidHeader('Name', None), + "Missing Name header", + ), ( InvalidHeader('Name', ''), - "Missing or empty Name header", + "Empty Name header", ), ( InvalidHeader('Name', 'Value'), @@ -40,8 +48,8 @@ def test_str(self): ), ( - InvalidUpgrade('Upgrade', ''), - "Missing or empty Upgrade header", + InvalidUpgrade('Upgrade'), + "Missing Upgrade header", ), ( InvalidUpgrade('Connection', 'websocket'), diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index 5083d1ee2..53cedb55e 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -1,7 +1,9 @@ import contextlib import unittest -from .exceptions import InvalidHandshake +from .exceptions import ( + InvalidHandshake, InvalidHeader, InvalidHeaderValue, InvalidUpgrade +) from .handshake import * from .handshake import accept # private API @@ -24,94 +26,97 @@ def test_round_trip(self): check_response(response_headers, request_key) @contextlib.contextmanager - def assertInvalidRequestHeaders(self): + def assertInvalidRequestHeaders(self, exc_type=InvalidHandshake): """ Provide request headers for corruption. Assert that the transformation made them invalid. """ + assert issubclass(exc_type, InvalidHandshake) headers = {} build_request(headers) yield headers - with self.assertRaises(InvalidHandshake): + with self.assertRaises(exc_type): check_request(headers) def test_request_invalid_upgrade(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: headers['Upgrade'] = 'socketweb' def test_request_missing_upgrade(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: del headers['Upgrade'] def test_request_invalid_connection(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: headers['Connection'] = 'Downgrade' def test_request_missing_connection(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: del headers['Connection'] def test_request_invalid_key_not_base64(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: headers['Sec-WebSocket-Key'] = "!@#$%^&*()" def test_request_invalid_key_not_well_padded(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: headers['Sec-WebSocket-Key'] = "CSIRmL8dWYxeAdr/XpEHRw" def test_request_invalid_key_not_16_bytes_long(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: headers['Sec-WebSocket-Key'] = "ZLpprpvK4PE=" def test_request_missing_key(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidHeader) as headers: del headers['Sec-WebSocket-Key'] def test_request_invalid_version(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: headers['Sec-WebSocket-Version'] = '42' def test_request_missing_version(self): - with self.assertInvalidRequestHeaders() as headers: + with self.assertInvalidRequestHeaders(InvalidHeader) as headers: del headers['Sec-WebSocket-Version'] @contextlib.contextmanager - def assertInvalidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): + def assertInvalidResponseHeaders( + self, exc_type=InvalidHandshake, key='CSIRmL8dWYxeAdr/XpEHRw=='): """ Provide response headers for corruption. Assert that the transformation made them invalid. """ + assert issubclass(exc_type, InvalidHandshake) headers = {} build_response(headers, key) yield headers - with self.assertRaises(InvalidHandshake): + with self.assertRaises(exc_type): check_response(headers, key) def test_response_invalid_upgrade(self): - with self.assertInvalidResponseHeaders() as headers: + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: headers['Upgrade'] = 'socketweb' def test_response_missing_upgrade(self): - with self.assertInvalidResponseHeaders() as headers: + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: del headers['Upgrade'] def test_response_invalid_connection(self): - with self.assertInvalidResponseHeaders() as headers: + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: headers['Connection'] = 'Downgrade' def test_response_missing_connection(self): - with self.assertInvalidResponseHeaders() as headers: + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: del headers['Connection'] def test_response_invalid_accept(self): - with self.assertInvalidResponseHeaders() as headers: + with self.assertInvalidResponseHeaders(InvalidHeaderValue) as headers: other_key = "1Eq4UDEFQYg3YspNgqxv5g==" headers['Sec-WebSocket-Accept'] = accept(other_key) def test_response_missing_accept(self): - with self.assertInvalidResponseHeaders() as headers: + with self.assertInvalidResponseHeaders(InvalidHeader) as headers: del headers['Sec-WebSocket-Accept'] From b99f1bd1a108ac8e8e3b1624696f03d8fe5c3537 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Aug 2018 14:39:00 +0200 Subject: [PATCH 0462/1539] Improve handling of the HTTP Origin header. * Represent the lack of a header by None rather than ''. (The empty string suggests that the header is present with an empty value.) Backwards-compatibility is preserved with a deprecation warning. * Provide a better error message if there's more than one Origin header in a HTTP request. (Clients aren't supposed to do that.) --- docs/changelog.rst | 3 +++ websockets/server.py | 23 +++++++++++++++------ websockets/test_client_server.py | 34 +++++++++++++++++++++++++++++++- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 7961301cb..c98bb73cf 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,9 @@ Changelog * Added an interactive client: `python -m websockets ` +* Changed the ``origins`` argument to represent the lack of an origin with + ``None`` rather than ``''``. + * Improved error messages when a required HTTP header is missing. 6.0 diff --git a/websockets/server.py b/websockets/server.py index af7089983..199309f50 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -8,21 +8,22 @@ import email.utils import logging import sys +import warnings from .compatibility import ( BAD_REQUEST, FORBIDDEN, INTERNAL_SERVER_ERROR, SERVICE_UNAVAILABLE, SWITCHING_PROTOCOLS, UPGRADE_REQUIRED, asyncio_ensure_future ) from .exceptions import ( - AbortHandshake, InvalidHandshake, InvalidMessage, InvalidOrigin, - InvalidUpgrade, NegotiationError + AbortHandshake, InvalidHandshake, InvalidHeader, InvalidMessage, + InvalidOrigin, InvalidUpgrade, NegotiationError ) from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request from .headers import ( build_extension_list, parse_extension_list, parse_subprotocol_list ) -from .http import USER_AGENT, Headers, read_request +from .http import USER_AGENT, Headers, MultipleValuesError, read_request from .protocol import WebSocketCommonProtocol @@ -48,6 +49,11 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__(self, ws_handler, ws_server, *, origins=None, extensions=None, subprotocols=None, extra_headers=None, **kwds): + # For backwards-compatibility with 6.0 or earlier. + if origins is not None and '' in origins: + warnings.warn( + "use None instead of '' in origins", DeprecationWarning) + origins = [None if origin == '' else origin for origin in origins] self.ws_handler = ws_handler self.ws_server = ws_server self.origins = origins @@ -279,7 +285,12 @@ def process_origin(headers, origins=None): acceptable. """ - origin = headers.get('Origin', '') + # "The user agent MUST NOT include more than one Origin header field" + # per https://tools.ietf.org/html/rfc6454#section-7.3. + try: + origin = headers.get('Origin') + except MultipleValuesError: + raise InvalidHeader('Origin', "more than one Origin header found") if origins is not None: if origin not in origins: raise InvalidOrigin(origin) @@ -423,7 +434,7 @@ def handshake(self, origins=None, available_extensions=None, Perform the server side of the opening handshake. If provided, ``origins`` is a list of acceptable HTTP Origin values. - Include ``''`` if the lack of an origin is acceptable. + Include ``None`` if the lack of an origin is acceptable. If provided, ``available_extensions`` is a list of supported extensions in the order in which they should be used. @@ -651,7 +662,7 @@ class Serve: :func:`serve` also accepts the following optional arguments: - * ``origins`` defines acceptable Origin HTTP headers — include ``''`` if + * ``origins`` defines acceptable Origin HTTP headers — include ``None`` if the lack of an origin is acceptable * ``extensions`` is a list of supported extensions in order of decreasing preference diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index fcf768a8e..8173ec482 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -12,6 +12,7 @@ import unittest.mock import urllib.error import urllib.request +import warnings from .client import * from .compatibility import FORBIDDEN, OK, UNAUTHORIZED @@ -1052,9 +1053,21 @@ def test_checking_origin_fails(self): server.close() self.loop.run_until_complete(server.wait_closed()) + def test_checking_origins_fails_with_multiple_headers(self): + server = self.loop.run_until_complete( + serve(handler, 'localhost', 0, origins=['http://localhost'])) + with self.assertRaisesRegex(InvalidHandshake, + "Status code not 101: 400"): + self.loop.run_until_complete( + connect(get_server_uri(server), origin='http://localhost', + extra_headers=[('Origin', 'http://otherhost')])) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + def test_checking_lack_of_origin_succeeds(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=[''])) + serve(handler, 'localhost', 0, origins=[None])) client = self.loop.run_until_complete(connect(get_server_uri(server))) self.loop.run_until_complete(client.send("Hello!")) @@ -1064,6 +1077,25 @@ def test_checking_lack_of_origin_succeeds(self): server.close() self.loop.run_until_complete(server.wait_closed()) + def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): + with warnings.catch_warnings(record=True) as recorded_warnings: + server = self.loop.run_until_complete( + serve(handler, 'localhost', 0, origins=[''])) + client = self.loop.run_until_complete( + connect(get_server_uri(server))) + + self.assertEqual(len(recorded_warnings), 1) + warning = recorded_warnings[0].message + self.assertEqual(str(warning), "use None instead of '' in origins") + self.assertEqual(type(warning), DeprecationWarning) + + self.loop.run_until_complete(client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") + + self.loop.run_until_complete(client.close()) + server.close() + self.loop.run_until_complete(server.wait_closed()) + class YieldFromTests(unittest.TestCase): From c2dd6b6eff0b24aaf99a48dbcd543ea6918b3241 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Aug 2018 15:30:04 +0200 Subject: [PATCH 0463/1539] Handle multiple HTTP headers with the same name. Fix #424. --- docs/changelog.rst | 2 + websockets/handshake.py | 49 +++++++++------- websockets/test_handshake.py | 111 ++++++++++++++++++++++++++++------- 3 files changed, 120 insertions(+), 42 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c98bb73cf..b73f897ab 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,6 +13,8 @@ Changelog * Changed the ``origins`` argument to represent the lack of an origin with ``None`` rather than ``''``. +* Improved handling of multiple HTTP headers with the same name. + * Improved error messages when a required HTTP header is missing. 6.0 diff --git a/websockets/handshake.py b/websockets/handshake.py index d8f79d371..aef467034 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -38,6 +38,7 @@ from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade from .headers import parse_connection, parse_upgrade +from .http import MultipleValuesError __all__ = [ @@ -80,21 +81,19 @@ def check_request(headers): responsibility of the caller. """ - try: - connection = headers['Connection'] - except KeyError: - raise InvalidUpgrade('Connection') + connection = sum([ + parse_connection(value) + for value in headers.get_all('Connection') + ], []) - connection = parse_connection(connection) if not any(value.lower() == 'upgrade' for value in connection): raise InvalidUpgrade('Connection', connection) - try: - upgrade = headers['Upgrade'] - except KeyError: - raise InvalidUpgrade('Upgrade') + upgrade = sum([ + parse_upgrade(value) + for value in headers.get_all('Upgrade') + ], []) - upgrade = parse_upgrade(upgrade) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): @@ -104,6 +103,10 @@ def check_request(headers): s_w_key = headers['Sec-WebSocket-Key'] except KeyError: raise InvalidHeader('Sec-WebSocket-Key') + except MultipleValuesError: + raise InvalidHeader( + 'Sec-WebSocket-Key', + "more than one Sec-WebSocket-Key header found") try: raw_key = base64.b64decode(s_w_key.encode(), validate=True) @@ -116,6 +119,10 @@ def check_request(headers): s_w_version = headers['Sec-WebSocket-Version'] except KeyError: raise InvalidHeader('Sec-WebSocket-Version') + except MultipleValuesError: + raise InvalidHeader( + 'Sec-WebSocket-Version', + "more than one Sec-WebSocket-Version header found") if s_w_version != '13': raise InvalidHeaderValue('Sec-WebSocket-Version', s_w_version) @@ -151,21 +158,19 @@ def check_response(headers, key): the caller. """ - try: - connection = headers['Connection'] - except KeyError: - raise InvalidUpgrade('Connection') + connection = sum([ + parse_connection(value) + for value in headers.get_all('Connection') + ], []) - connection = parse_connection(connection) if not any(value.lower() == 'upgrade' for value in connection): raise InvalidUpgrade('Connection', connection) - try: - upgrade = headers['Upgrade'] - except KeyError: - raise InvalidUpgrade('Upgrade') + upgrade = sum([ + parse_upgrade(value) + for value in headers.get_all('Upgrade') + ], []) - upgrade = parse_upgrade(upgrade) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): @@ -175,6 +180,10 @@ def check_response(headers, key): s_w_accept = headers['Sec-WebSocket-Accept'] except KeyError: raise InvalidHeader('Sec-WebSocket-Accept') + except MultipleValuesError: + raise InvalidHeader( + 'Sec-WebSocket-Accept', + "more than one Sec-WebSocket-Accept header found") if s_w_accept != accept(key): raise InvalidHeaderValue('Sec-WebSocket-Accept', s_w_accept) diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index 53cedb55e..ebdb75b62 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -6,6 +6,7 @@ ) from .handshake import * from .handshake import accept # private API +from .http import Headers class HandshakeTests(unittest.TestCase): @@ -17,106 +18,172 @@ def test_accept(self): self.assertEqual(accept(key), acc) def test_round_trip(self): - request_headers = {} + request_headers = Headers() request_key = build_request(request_headers) response_key = check_request(request_headers) self.assertEqual(request_key, response_key) - response_headers = {} + response_headers = Headers() build_response(response_headers, response_key) check_response(response_headers, request_key) @contextlib.contextmanager - def assertInvalidRequestHeaders(self, exc_type=InvalidHandshake): + def assertValidRequestHeaders(self): """ - Provide request headers for corruption. + Provide request headers for modification. + + Assert that the transformation kept them valid. + + """ + headers = Headers() + build_request(headers) + yield headers + check_request(headers) + + @contextlib.contextmanager + def assertInvalidRequestHeaders(self, exc_type): + """ + Provide request headers for modification. Assert that the transformation made them invalid. """ - assert issubclass(exc_type, InvalidHandshake) - headers = {} + headers = Headers() build_request(headers) yield headers + assert issubclass(exc_type, InvalidHandshake) with self.assertRaises(exc_type): check_request(headers) + def test_request_invalid_connection(self): + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: + del headers['Connection'] + headers['Connection'] = 'Downgrade' + + def test_request_missing_connection(self): + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: + del headers['Connection'] + + def test_request_additional_connection(self): + with self.assertValidRequestHeaders() as headers: + headers['Connection'] = 'close' + def test_request_invalid_upgrade(self): with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: + del headers['Upgrade'] headers['Upgrade'] = 'socketweb' def test_request_missing_upgrade(self): with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: del headers['Upgrade'] - def test_request_invalid_connection(self): - with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - headers['Connection'] = 'Downgrade' - - def test_request_missing_connection(self): + def test_request_additional_upgrade(self): with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers['Connection'] + headers['Upgrade'] = 'socketweb' def test_request_invalid_key_not_base64(self): with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: + del headers['Sec-WebSocket-Key'] headers['Sec-WebSocket-Key'] = "!@#$%^&*()" def test_request_invalid_key_not_well_padded(self): with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: + del headers['Sec-WebSocket-Key'] headers['Sec-WebSocket-Key'] = "CSIRmL8dWYxeAdr/XpEHRw" def test_request_invalid_key_not_16_bytes_long(self): with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: + del headers['Sec-WebSocket-Key'] headers['Sec-WebSocket-Key'] = "ZLpprpvK4PE=" def test_request_missing_key(self): with self.assertInvalidRequestHeaders(InvalidHeader) as headers: del headers['Sec-WebSocket-Key'] + def test_request_additional_key(self): + with self.assertInvalidRequestHeaders(InvalidHeader) as headers: + # This duplicates the Sec-WebSocket-Key header. + headers['Sec-WebSocket-Key'] = headers['Sec-WebSocket-Key'] + def test_request_invalid_version(self): with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: + del headers['Sec-WebSocket-Version'] headers['Sec-WebSocket-Version'] = '42' def test_request_missing_version(self): with self.assertInvalidRequestHeaders(InvalidHeader) as headers: del headers['Sec-WebSocket-Version'] + def test_request_additional_version(self): + with self.assertInvalidRequestHeaders(InvalidHeader) as headers: + # This duplicates the Sec-WebSocket-Version header. + headers['Sec-WebSocket-Version'] = headers['Sec-WebSocket-Version'] + + @contextlib.contextmanager + def assertValidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): + """ + Provide response headers for modification. + + Assert that the transformation kept them valid. + + """ + headers = Headers() + build_response(headers, key) + yield headers + check_response(headers, key) + @contextlib.contextmanager def assertInvalidResponseHeaders( - self, exc_type=InvalidHandshake, key='CSIRmL8dWYxeAdr/XpEHRw=='): + self, exc_type, key='CSIRmL8dWYxeAdr/XpEHRw=='): """ - Provide response headers for corruption. + Provide response headers for modification. Assert that the transformation made them invalid. """ - assert issubclass(exc_type, InvalidHandshake) - headers = {} + headers = Headers() build_response(headers, key) yield headers + assert issubclass(exc_type, InvalidHandshake) with self.assertRaises(exc_type): check_response(headers, key) + def test_response_invalid_connection(self): + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: + del headers['Connection'] + headers['Connection'] = 'Downgrade' + + def test_response_missing_connection(self): + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: + del headers['Connection'] + + def test_response_additional_connection(self): + with self.assertValidResponseHeaders() as headers: + headers['Connection'] = 'close' + def test_response_invalid_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: + del headers['Upgrade'] headers['Upgrade'] = 'socketweb' def test_response_missing_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: del headers['Upgrade'] - def test_response_invalid_connection(self): - with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - headers['Connection'] = 'Downgrade' - - def test_response_missing_connection(self): + def test_response_additional_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers['Connection'] + headers['Upgrade'] = 'socketweb' def test_response_invalid_accept(self): with self.assertInvalidResponseHeaders(InvalidHeaderValue) as headers: + del headers['Sec-WebSocket-Accept'] other_key = "1Eq4UDEFQYg3YspNgqxv5g==" headers['Sec-WebSocket-Accept'] = accept(other_key) def test_response_missing_accept(self): with self.assertInvalidResponseHeaders(InvalidHeader) as headers: del headers['Sec-WebSocket-Accept'] + + def test_response_additional_accept(self): + with self.assertInvalidResponseHeaders(InvalidHeader) as headers: + # This duplicates the Sec-WebSocket-Accept header. + headers['Sec-WebSocket-Accept'] = headers['Sec-WebSocket-Accept'] From d2b9537e0ef904fdec92f45e9f0b6a4e0e671a08 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Aug 2018 16:16:13 +0200 Subject: [PATCH 0464/1539] Improve test coverage. --- websockets/extensions/test_permessage_deflate.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/websockets/extensions/test_permessage_deflate.py b/websockets/extensions/test_permessage_deflate.py index e4afcec39..a681355b1 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -624,7 +624,7 @@ def test_process_response_params_deduplication(self): [], [PerMessageDeflate(False, False, 15, 15)]) -class PerMessageDeflateTests(unittest.TestCase): +class PerMessageDeflateTests(unittest.TestCase, ExtensionTestsMixin): def setUp(self): # Set up an instance of the permessage-deflate extension with the most @@ -635,6 +635,9 @@ def setUp(self): def test_name(self): assert self.extension.name == 'permessage-deflate' + def test_repr(self): + self.assertExtensionEqual(eval(repr(self.extension)), self.extension) + # Control frames aren't encoded or decoded. def test_no_encode_decode_ping_frame(self): From 0197e82c7cd5b825a8d9b966e5586d77fa01f454 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Jul 2018 17:56:40 +0200 Subject: [PATCH 0465/1539] Ping at regular intervals. Fix #383. --- docs/api.rst | 12 ++-- docs/changelog.rst | 5 +- docs/cheatsheet.rst | 22 ------- docs/design.rst | 11 ++++ docs/lifecycle.graffle | Bin 3073 -> 3134 bytes docs/lifecycle.svg | 2 +- websockets/client.py | 15 +++-- websockets/protocol.py | 76 ++++++++++++++++++++- websockets/server.py | 15 +++-- websockets/test_client_server.py | 8 ++- websockets/test_protocol.py | 110 ++++++++++++++++++++++++++++++- 11 files changed, 230 insertions(+), 46 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index d41075ad8..67f3756a4 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,9 +32,9 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) .. autoclass:: WebSocketServer @@ -43,7 +43,7 @@ Server .. automethod:: wait_closed() .. autoattribute:: sockets - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(origins=None, available_extensions=None, available_subprotocols=None, extra_headers=None) .. automethod:: process_request(path, request_headers) @@ -54,9 +54,9 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(wsuri, origin=None, available_extensions=None, available_subprotocols=None, extra_headers=None) @@ -65,7 +65,7 @@ Shared .. automodule:: websockets.protocol - .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) + .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) .. automethod:: close(code=1000, reason='') diff --git a/docs/changelog.rst b/docs/changelog.rst index b73f897ab..c895baac9 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,7 +8,10 @@ Changelog *In development* -* Added an interactive client: `python -m websockets ` +* websockets sends Ping frames at regular intervals and closes the connection + if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. + +* Added an interactive client: ``python -m websockets ``. * Changed the ``origins`` argument to represent the lack of an origin with ``None`` rather than ``''``. diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 8857152c4..2f6a47e9c 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -76,28 +76,6 @@ in particular. Fortunately Python's official documentation provides advice to .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html -Keeping connections open ------------------------- - -Pinging the other side once in a while is a good way to check whether the -connection is still working, and also to keep it open in case something kills -idle connections after some time:: - - while True: - try: - msg = await asyncio.wait_for(ws.recv(), timeout=20) - except asyncio.TimeoutError: - # No data in 20 seconds, check the connection. - try: - pong_waiter = await ws.ping() - await asyncio.wait_for(pong_waiter, timeout=10) - except asyncio.TimeoutError: - # No response to ping in 10 seconds, disconnect. - break - else: - # do something with msg - ... - Passing additional arguments to the connection handler ------------------------------------------------------ diff --git a/docs/design.rst b/docs/design.rst index 33e835d82..63afbb8d4 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -74,6 +74,12 @@ two tasks: with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data transfer ` below. +- :attr:`~protocol.WebSocketCommonProtocol.keepalive_ping_task` runs + :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping()` which sends Ping + frames at regular intervals and ensures that corresponding Pong frames are + received. It is cancelled when the connection terminates. It never exits + with an exception other than :exc:`~asyncio.CancelledError`. + - :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs :meth:`~protocol.WebSocketCommonProtocol.close_connection()` which waits for the data transfer to terminate, then takes care of closing the TCP @@ -302,6 +308,11 @@ easier to implement the timeout on the closing handshake. Canceling of canceling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` and failing to close the TCP connection, thus leaking resources. +Then :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` cancels +:attr:`~protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no +protocol compliance responsibilities. Terminating it to avoid leaking it is +the only concern. + Terminating the TCP connection can take up to ``2 * timeout`` on the server side and ``3 * timeout`` on the client side. Clients start by waiting for the server to close the connection, hence the extra ``timeout``. Then both sides diff --git a/docs/lifecycle.graffle b/docs/lifecycle.graffle index b1df4afd045d331a7492f39c3855171566d6a700..a8ab7ff09f52a41e3e8a8abc2e8a5503e481c44a 100644 GIT binary patch literal 3134 zcmV-E48ijsiwFP!000030PS5{SK~Mmeja{>mxq1W1xJ!4+3I1~6E4#*O9MmF^iH3g z)5q9K@Zi{K+i9BV!++nBbMZA0Zqv}%53yraN#&Bhs?x>s#(%yJeeH{kVlNDSS;ab9 z)nwp=t``h`S$*IAb=_M1&*sX;KX#6`yMG+-Y9rr^6Yco@*1`U^wz}SGwMQdgwpu&g z9qssFzthzqk5+5HsoZ2L|0u+*K%M zl6yfMxa)dOvbplp#)Z7vY&(hfMIP8!GTIMZ`E{eE?n0s$Byu35O|;RPUCvmbY|pV1 z$ntcO^PI_!B3oVjv=Ju}Jn|+;sE5PA8${uFq#wc6+sN+seaSakQ<5At=0Yb1h)+z; zImwZ1v}RdPL>MPw+x>GKC-b6{Xe?W^^ex-DP(=nV6gN8eoH^O_*{5WRBA%tq#rXCO zt!e0bP5UMYU(-x1uCEK?D*A2Tj;pMl(D$-3ODy|48U0MBB(L5HStq&jWkxxZc#;TN zh8J0CE_sr^6WN!USeamlM|p8TpL0dbuCj}X24nd#42Mg!hiyCfV#i03H}JBSFjHl! z$$Uj-93Ag3S)re6yuuUu<6+SDy+Kfs30GxGsXHUvfh^|qk+l6ygXtnY2o;PlONfnD znlNL9KdNo-*hxh^g0Q)c#5$rH#&0NoV{q-?2#$)hDw8lO^S7^tGV&a)Jsv>YRq86T z?RfFXx34;m?N`bM@WP=vL+uxBeqGMsSI?K-t5Jp72EdeQ(`LF8I^&@Xk}{I#swEZO zPMdeEj7<&li|w@tles(xULb2=GJS|JNcMw%xVV)oike2gC;cxn0ozo_bBd~5Yg6~q zDpyZ8t@}JCCLDXxb{C6~19J|}!YJT*Mq7B8jG7jS%L3$!$V%WPCPi@^!mYfT=hcz3 zYFJo0B@7G3g(0}G=$d9=tXmi(f-OWXB$#$2W(R`XW%-HxJy0GVNjW_2-5&6hJ_J&*(2xhnAWZ9h&O z(wwMkd;Y@wQC+qlJU2NnP)}N|&T`GR&}^z^{gnd}i~C@VJr>4HYfU4J=@?rWStb>R zB{;ha#wy&T!Ps-hSShZ|1(xL+<`C@=Vd~s4O<^z!Hk2}2IZ6iyJ;dJ zIKIBFqIChnJLmRDx{r_X?Qk45^6_pT#;zgK) zwDpCFxD>tS)7Gnrcg14g^ZlF90#R%hLzXEfBxm{Lscw;OF)9#2^P|v~r<{X(b4K^z zEZBtFbZUJV!Pfrt{RyM zo*6;_aFbXJ^Ph*`AE2E&-ch2o7mI>1G_GEo4^xb#epRa+* zQD6YBpL>yX!7KI#k(7bDIP=F6&MDQudh+YY4qWNZ>EV8FWJhi<`N}wgUyMn5;+8rQ zLerCl)X6YX5<{dUl%^yMr6j8ArC?`|PM2i8nY;WYD)&>q9Fm z>M_rH{g_a#=k$?wA?3*Sy)W|92vm4|Tfh7V3gGJPI;yejM^SPf4#L3px9mub_6(s< z?&P>3Fgu9#*CQNfpV_fHg02yM8 zVj)cNctLKh0kTAm0eWEM_5ehDSG3(H@4E#YPZmf7o}Y^%PG#iUf&KC$hb(Sz?ZnoS z2(;6eQE%m+_HgvG=J?=_f`N7pl0(zEkavzJ7Q>igtRrG@a8C@If-<%U#@Nsei&=zY z@L5JWviSGI*lJ+h!1(#Y*a?FGI;Ido_Er)Z4ZR?pZhA!Af>X6P#}pe_5R*a>`K)nk zHn?qY`}}bmj%4sKPy5QIUQF@SbDUJ`7&%#UA1U#QUhU3O`>9%DjHw; ze3?a>`Lbo_%P1-yWXbt5y!7nY%_TImW6yJT?EB7o^`MNKDSx!7u}C^!w3Zgjz=hL_ zQotF~)HE6V9fj&N;nC88FR49Zui0X{3U#<(mVt?|7+vIQ%9NJz+FcE9T+Llw4S#Z1 z!y8xAxSH3vz+yVW`iQ=vt{YOGW@o1I4F*LzW+pW)!$`Z`uiZD8qVWyQvc1MzG~VL5 zdJEN&_>|4cEpi)~;!mV+tG6cGP?2@sV)pi+8h$OkIrvB`#=*%5P=uFYuzjPNhTlDowEdhJ=3&zcRzX4lb+k(MEmM2XYueC07Q|Q`N)?fd) zZ|9+VPmCVhNdkziofpOU`%jF%9^{aFBge3GwNA{Wh;RrzxwsQ@YPcKZG{|{D$ieqT zPHEw#Wm$JYP91lHoCY~B2RYPw%)JzrM@}7ggPb23IXAq>0#+1pA}m1+DhN&Eev0+9 z?@qX(1|qDQnSCMn(+;84`x?|Ni2nf=#J}ueisc3-whxXvyE~6K(OP+Tc{!jy`4XI3 zpz`L$(FUk52dMjpZyyn=_Rmi8^4&Hd_D5sg__dAhRU&oHWG zpLgF;&2H1@;bpA2DtlS~vD>L^VciMrQQQrepf7GG{dVlRGCl~M3+a~kCN8w+9QCOe zduP>skW0m+BO%B3K*p85i>q4~?+<|x8YA3|me@u5zfz96VR1L+`GCn}L`7|!9^bkP zvU>2Ug)HyMtQhGlXb2k`xK~)wv zOe(?zYEN%95PKklL2meJR={j23wTdNz`xQ`fhqqT1WYc&k+&3m5o{ph#Uk`n!8TnG zbqD$w(9=1Wxm8q?hCfHeyU_Le(@z8}*hbx2N9H;a8fI?@dt=c0ZPX_j$vf|(ei$Zl z`6oI8-xMST9p7L>CL7%^ow+hiTTo3G@e;HO&5Ln!*x3q`Bpkl8quf^e`N8v9)_p&; zlOn38*Ay2KoLQ_%343AW{iR-42f3#^liXC01D@WkEoiHq{N?Wb=Kt{lBl*8A|Kt8% z*rva4{l4GXJ#U{My+;h6efZ6H2HWu8J$jEno}cZV`fW(x-euh``R((^4~M9||6#Y? z`3y9_e)TW6-`*_#f$S#_QI9ezH|2i}F>O?^(KXpmPPO?5_PQ^V)cf38Oko11u9!5v zrjTzOOWCI+f&wmHyk}h_C@=hR>V;t zmFZXzcbrJuJqrA*+il88c)9FHGk~=Iiy6H5g^bkpcfF)){w5#xC4FT+zk>hx=HP)^?_zJvkzVXUB|_y ze0v;4povtI)i}K-Oy8aMfIdxBoQvZ!-#eY^jSA)|zbbtlr|ZNig=@<+YrbludIN=- zp{g*jgRP*EjxsGq2?Ma30QVe7mrI2armQUoMV7c;2 z;8f@N!%UUJv8fyR!(7v-DT?=SjmGjH$*G&goD-ee(4~!Q?MBV)cMXyj&XT5=v@kKB Y*{zC=$ zW(iwKJUDi0I|ZhT|9-Zdi?4x%LR(t%P&?L^v?c5NwzMVh?)-K=@{~^^jNHKgbrY*- zQxU!$IIchVb@Ofa_ibzQx804Mf9I3B-jG_J0$R4%1dI~Ym9sBszwV-dzT$57ZdWKf;h*@SAQ zrRUUvr;cmKyBjZdF2&7m(~jLw;+Wrv@W^+>^-e=Rg+$kn#Xy9+Xs0p1oo9i%U7N>H zfjK(x63iAU|w9k}UGnD=GU zmDx9$=IFTph?0ET<)uu4HyQa&&mH(BrEpoPM7lHPHWaat4+ZyjHL9}YkLZjNYB9Fc zND}5*;fEaE`#dh`M-Vo*5!*((g7Hfoztjxn7X*JLZDk=rSj^wN8Hv!fmF8psqgVD< zKHI(kfS?+43=xs<3F<%h!iBw}EvT$wXiWrmx4 zmb`M=y4kzWf}(`iv_Du505i#U7!QL%;B#+}hmxH}FoNt%?m0bG$)5&pwlNDR=$fQM zO*aVAO==nzAOo0D)l8F7jCG1lf{fd8_Z3Qy+&J@M=PlpH*_O++M33yNTD;P%QiQ6Q zfi5P))Xh*BLhX0qD~m@Q$p&XA9Vvy|H^NyTCnq;k(V@4KG&M5FwhBIqGy zm&TvCj+ggbJ~bURy)f?M9`_?R`Ntj)FrlPp2X??^5)jt(M>M&B--c-|ZH(CNk3)Ijhr#w58Imb?(6Cl?FJi>4(a8cRfgN#zB!M$_E#V&Fu?0Mn6_ z>xMp$Cn5KGA8q*{b_c^)_;-LXB<+fbuJr6aQXC$0B@Ur?`XcOY{F{%)uN2z@ z6XOq*A!rUm=Th9>H`QDn$3=6aXS^Pl zYpZov6=q>(J(JBfCY9EgEkW+04*hF61PCT71%4zNqozf4Q?9*>3!Mi$1f<(lBLsLD z*q9=uk%xH*Ftb!WnU@(%gN2tV{dt8)%lrey)|dpW2F9udrh#A!0ia+qSo;TlM*o0? z*Fmzjwlm=;BS!r6#Hf8MVl?QJ5u?rwM9ptbjD~i<+qEXfniwC07}xW%VoO)$h&D`J zLm2BOk&C>d5vMTtiKS|~p;O3inusw2Ki(qm!NlAOCZQ|B6LEFgdDC5)ozc46NEWFe z$2WZ7i+gzlK_@Tc2}s*snurT~vDvika^kxDao>8lC6`n=aAYK=A{W`@1^wzM|Ccp| z^PGc)8dV)921z$5qJ+v#wDPmtFJ|Oc5c(oK zQ-2>pyFAAvbg9d}ED(6e_v zTx6UNp|C$G4J8u31D6fi0T)OaOMgT|83*z*&7&lH0z!F0f8ugeQ}QSN^ri;are2ZZ+RwM2BGvRUhO+yo2WR3L++bDG@sZ5#n38+|!==Od%jmn9TayQcnE^BgXvCTY3qhvSA9jNj zJ$PUAl*ch3wsKn(<8Qw&`tu-%JQO*U64j&>5JMQ!bd4J8Ag3yKjhq@e&j>mAp~%rO zRuRz*%nXWQpqq3Z_pQjPkX`WmUv2dPIVuOAVrrKuSHuAy3gHeW;ac|f&&dT{cX zN0*9F)4yk&GE(ET#_97gsqK@KgI0GHt6B&y`8sr+tiL87Gg77&q6*G;e^o1w#&oef zJKOMg5LRYiEZ|W^FvwwW70FzsP?PTuI;9Ac9iNY*Zm@!UKHB5!iR*~yIIu5;Q;Zy# z8_yZ(b2oA?%25gn)g-=SpiGb8Qi=rpCD|uI?>%*#HM^ zgGs!6h&O$AlnH&gM3F?)c0)UOnNoav6Hc zV4EF@T0=ep^mK+IbB<~{@E53f6F6>v7Q*+gAg9<=Fng(!ml~=Layl(YtR3a_B#6c8 zEoy^p^5a||-@zf%IJ7GlE~#k`Uejf?f~-XIa@?GB_JTMLMsIkSiG-VVpU;c#djXI0 zsGi;HIFI0zv8FW~2BG_p{9YaOo&{~Xt0EgLy;E7yUNip3IsC(W{|+PZzdi56(P7Zk z|J?iYsB7ygTTN~YrmC7Gy`qxO z2F9}p#7H+ZLJez6F&VkNr&2lk&V-ESEh-@K9A5vKoU}dV`|k$hrzk$p-^q(;`$AI3 zl6l9O;7;53Ztmxlv*2pgkk%kt9nadx$FEMfzw5?j`kThx&P`ONIYuy{PSlW*pr_Nj zX=L~nW_L+uIRW6lE!NEY)^|TnM7q|zqxEkiap1>pcvrWdS5GvxQep30AC`{Gsk}7_ zL+FX5C(HBnifQ`6Yz6e!w2L$IxYYK}HhaSyJ*6JIErlir#fb@YK(IzCg3im;KyEKN z4VSV=Kbv@|MJ@@z>$%|fCFeH)Q)7!NS`Luwc>Ll}NNgA8K$=#VGR@5LnHc1lD?OaA zQlP@*Q!R^ZDaC|qYx18l P>~8!Y*@LYtBVqsmQ?K-A diff --git a/docs/lifecycle.svg b/docs/lifecycle.svg index d783421f9..0a9818d29 100644 --- a/docs/lifecycle.svg +++ b/docs/lifecycle.svg @@ -1,3 +1,3 @@ - Produced by OmniGraffle 6.6.2 2017-09-17 19:42:30 +0000Canvas 1Layer 1CONNECTINGOPENCLOSINGCLOSEDtransfer_dataclose_connectionconnectrecv / send / ping / pong / close opening handshakeconnectionterminationdata transfer& closing handshake + Produced by OmniGraffle 6.6.2 2018-07-29 15:25:34 +0000Canvas 1Layer 1CONNECTINGOPENCLOSINGCLOSEDtransfer_dataclose_connectionconnectrecv / send / ping / pong / close opening handshakeconnectionterminationdata transfer& closing handshakekeepalive_ping diff --git a/websockets/client.py b/websockets/client.py index aa45180bb..d01ef7395 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -293,9 +293,10 @@ class Connect: a ``wss://`` URI, if this argument isn't provided explicitly, it's set to ``True``, which means Python's default :class:`~ssl.SSLContext` is used. - The behavior of the ``timeout``, ``max_size``, and ``max_queue``, - ``read_limit``, and ``write_limit`` optional arguments is described in the - documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. + The behavior of the ``ping_interval``, ``ping_timeout``, ``timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` optional + arguments is described in the documentation of + :class:`~websockets.protocol.WebSocketCommonProtocol`. The ``create_protocol`` parameter allows customizing the asyncio protocol that manages the connection. It should be a callable or class accepting @@ -326,7 +327,9 @@ class Connect: def __init__(self, uri, *, create_protocol=None, - timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + ping_interval=20, ping_timeout=20, + timeout=10, + max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, origin=None, extensions=None, subprotocols=None, @@ -364,7 +367,9 @@ def __init__(self, uri, *, factory = lambda: create_protocol( host=wsuri.host, port=wsuri.port, secure=wsuri.secure, - timeout=timeout, max_size=max_size, max_queue=max_queue, + ping_interval=ping_interval, ping_timeout=ping_timeout, + timeout=timeout, + max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, origin=origin, extensions=extensions, subprotocols=subprotocols, diff --git a/websockets/protocol.py b/websockets/protocol.py index 462904b23..9c72c5f54 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -78,13 +78,27 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``host``, ``port`` and ``secure`` parameters are simply stored as attributes for handlers that need them. + Once the connection is open, a `Ping frame`_ is sent every + ``ping_interval`` seconds. This serves as a keepalive. It helps keeping + the connection open, especially in the presence of proxies with short + timeouts. Set ``ping_interval`` to ``None`` to disable this behavior. + + .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 + + If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` + seconds, the connection is considered unusable and is closed with status + code 1011. This ensures that the remote endpoint remains responsive. Set + ``ping_timeout`` to ``None`` to disable this behavior. + + .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 + The ``timeout`` parameter defines a maximum wait time in seconds for completing the closing handshake and terminating the TCP connection. :meth:`close()` completes in at most ``4 * timeout`` on the server side and ``5 * timeout`` on the client side. - ``timeout`` is a parameter of the protocol because websockets usually - calls :meth:`close()` implicitly: + ``timeout`` needs to be a parameter of the protocol because websockets + usually calls :meth:`close()` implicitly: - on the server side, when the connection handler terminates, - on the client side, when exiting the context manager for the connection. @@ -148,12 +162,16 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): def __init__(self, *, host=None, port=None, secure=None, - timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + ping_interval=20, ping_timeout=20, + timeout=10, + max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False): self.host = host self.port = port self.secure = secure + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout self.timeout = timeout self.max_size = max_size self.max_queue = max_queue @@ -218,6 +236,9 @@ def __init__(self, *, # Exception that occurred during data transfer, if any. self.transfer_data_exc = None + # Task sending keepalive pings. + self.keepalive_ping_task = None + # Task closing the TCP connection. self.close_connection_task = None @@ -247,6 +268,9 @@ def connection_open(self): # Start the task that receives incoming WebSocket messages. self.transfer_data_task = asyncio_ensure_future( self.transfer_data(), loop=self.loop) + # Start the task that sends pings at regular intervals. + self.keepalive_ping_task = asyncio_ensure_future( + self.keepalive_ping(), loop=self.loop) # Start the task that eventually closes the TCP connection. self.close_connection_task = asyncio_ensure_future( self.close_connection(), loop=self.loop) @@ -798,6 +822,48 @@ def write_close_frame(self, data=b''): # 7.1.2. Start the WebSocket Closing Handshake yield from self.write_frame(OP_CLOSE, data, State.CLOSING) + @asyncio.coroutine + def keepalive_ping(self): + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + This coroutine exits when the connection terminates and one of the + following happens: + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`close_connection` cancels :attr:`keepalive_ping_task`. + + """ + if self.ping_interval is None: + return + + try: + while True: + yield from asyncio.sleep(self.ping_interval, loop=self.loop) + + # ping() cannot raise ConnectionClosed, only CancelledError: + # - If the connection is CLOSING, keepalive_ping_task will be + # cancelled by close_connection() before ping() returns. + # - If the connection is CLOSED, keepalive_ping_task must be + # cancelled already. + ping_waiter = yield from self.ping() + + if self.ping_timeout is not None: + try: + yield from asyncio.wait_for( + ping_waiter, self.ping_timeout, loop=self.loop) + except asyncio.TimeoutError: + logger.debug( + "%s ! timed out waiting for pong", self.side) + self.fail_connection(1011) + break + + except asyncio.CancelledError: + raise + + except Exception as exc: + logger.warning( + "Unexpected exception in keepalive ping task", exc_info=True) + @asyncio.coroutine def close_connection(self): """ @@ -819,6 +885,10 @@ def close_connection(self): except asyncio.CancelledError: pass + # Cancel the keepalive ping task. + if self.keepalive_ping_task is not None: + self.keepalive_ping_task.cancel() + # Cancel all pending pings because they'll never receive a pong. for ping in self.pings.values(): ping.cancel() diff --git a/websockets/server.py b/websockets/server.py index 199309f50..27cca4b92 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -656,9 +656,10 @@ class Serve: :class:`WebSocketServerProtocol` instance. It defaults to :class:`WebSocketServerProtocol`. - The behavior of the ``timeout``, ``max_size``, and ``max_queue``, - ``read_limit``, and ``write_limit`` optional arguments is described in the - documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. + The behavior of the ``ping_interval``, ``ping_timeout``, ``timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` optional + arguments is described in the documentation of + :class:`~websockets.protocol.WebSocketCommonProtocol`. :func:`serve` also accepts the following optional arguments: @@ -701,7 +702,9 @@ class Serve: def __init__(self, ws_handler, host=None, port=None, *, path=None, create_protocol=None, - timeout=10, max_size=2 ** 20, max_queue=2 ** 5, + ping_interval=20, ping_timeout=20, + timeout=10, + max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, klass=None, origins=None, extensions=None, subprotocols=None, @@ -735,7 +738,9 @@ def __init__(self, ws_handler, host=None, port=None, *, factory = lambda: create_protocol( ws_handler, ws_server, host=host, port=port, secure=secure, - timeout=timeout, max_size=max_size, max_queue=max_queue, + ping_interval=ping_interval, ping_timeout=ping_timeout, + timeout=timeout, + max_size=max_size, max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, loop=loop, legacy_recv=legacy_recv, origins=origins, extensions=extensions, subprotocols=subprotocols, diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index 8173ec482..b3b15bb30 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -230,14 +230,18 @@ def run_loop_once(self): self.loop.run_forever() def start_server(self, **kwds): - # Don't enable compression by default in tests. + # Disable compression by default in tests. kwds.setdefault('compression', None) + # Disable pings by default in tests. + kwds.setdefault('ping_interval', None) start_server = serve(handler, 'localhost', 0, **kwds) self.server = self.loop.run_until_complete(start_server) def start_client(self, resource_name='/', user_info=None, **kwds): - # Don't enable compression by default in tests. + # Disable compression by default in tests. kwds.setdefault('compression', None) + # Disable pings by default in tests. + kwds.setdefault('ping_interval', None) secure = kwds.get('ssl') is not None server_uri = get_server_uri( self.server, secure, resource_name, user_info) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index e9e227104..3c934a41d 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -92,7 +92,8 @@ def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.protocol = WebSocketCommonProtocol() + # Disable pings to make it easier to test what frames are sent exactly. + self.protocol = WebSocketCommonProtocol(ping_interval=None) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) @@ -770,6 +771,113 @@ def test_connection_closed_attributes(self): self.assertEqual(connection_closed_exc.code, 1000) self.assertEqual(connection_closed_exc.reason, 'close') + # Test the protocol logic for sending keepalive pings. + + def restart_protocol_with_keepalive_ping( + self, ping_interval=3 * MS, ping_timeout=3 * MS): + initial_protocol = self.protocol + # copied from tearDown + self.transport.close() + self.loop.run_until_complete(self.protocol.close()) + # copied from setUp, but enables keepalive pings + self.protocol = WebSocketCommonProtocol( + ping_interval=ping_interval, ping_timeout=ping_timeout) + self.transport = TransportMock() + self.transport.setup_mock(self.loop, self.protocol) + self.protocol.is_client = initial_protocol.is_client + self.protocol.side = initial_protocol.side + + def test_keepalive_ping(self): + self.restart_protocol_with_keepalive_ping() + + # Ping is sent at 3ms and acknowledged at 4ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + ping_1, = tuple(self.protocol.pings) + self.assertOneFrameSent(True, OP_PING, ping_1) + self.receive_frame(Frame(True, OP_PONG, ping_1)) + + # Next ping is sent at 7ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + ping_2, = tuple(self.protocol.pings) + self.assertOneFrameSent(True, OP_PING, ping_2) + + # The keepalive ping task goes on. + self.assertFalse(self.protocol.keepalive_ping_task.done()) + + def test_keepalive_ping_not_acknowledged_closes_connection(self): + self.restart_protocol_with_keepalive_ping() + + # Ping is sent at 3ms and not acknowleged. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + ping_1, = tuple(self.protocol.pings) + self.assertOneFrameSent(True, OP_PING, ping_1) + + # Connection is closed at 6ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1011, '')) + + # The keepalive ping task is complete. + self.assertEqual(self.protocol.keepalive_ping_task.result(), None) + + def test_keepalive_ping_stops_when_connection_closing(self): + self.restart_protocol_with_keepalive_ping() + close_task = self.half_close_connection_local() + + # No ping sent at 3ms because the closing handshake is in progress. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + self.assertNoFrameSent() + + # The keepalive ping task terminated. + self.assertTrue(self.protocol.keepalive_ping_task.cancelled()) + + self.loop.run_until_complete(close_task) # cleanup + + def test_keepalive_ping_stops_when_connection_closed(self): + self.restart_protocol_with_keepalive_ping() + self.close_connection() + + # The keepalive ping task terminated. + self.assertTrue(self.protocol.keepalive_ping_task.cancelled()) + + def test_keepalive_ping_with_no_ping_interval(self): + self.restart_protocol_with_keepalive_ping(ping_interval=None) + + # No ping is sent at 3ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + self.assertNoFrameSent() + + def test_keepalive_ping_with_no_ping_timeout(self): + self.restart_protocol_with_keepalive_ping(ping_timeout=None) + + # Ping is sent at 3ms and not acknowleged. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + ping_1, = tuple(self.protocol.pings) + self.assertOneFrameSent(True, OP_PING, ping_1) + + # Next ping is sent at 7ms anyway. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + ping_1_again, ping_2 = tuple(self.protocol.pings) + self.assertEqual(ping_1, ping_1_again) + self.assertOneFrameSent(True, OP_PING, ping_2) + + # The keepalive ping task goes on. + self.assertFalse(self.protocol.keepalive_ping_task.done()) + + def test_keepalive_ping_unexpected_error(self): + self.restart_protocol_with_keepalive_ping() + + @asyncio.coroutine + def ping(): + raise Exception("BOOM") + self.protocol.ping = ping + + # The keepalive ping task fails when sending a ping at 3ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + + # The keepalive ping task is complete. + # It logs and swallows the exception. + self.assertEqual(self.protocol.keepalive_ping_task.result(), None) + # Test the protocol logic for closing the connection. def test_local_close(self): From 1928ad71c0d97be6f7aeac84fc4c6605f3a7bb32 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 22 Aug 2018 15:52:31 +0200 Subject: [PATCH 0466/1539] Reformat code with black. --- .circleci/config.yml | 2 +- Makefile | 3 + setup.cfg | 9 +- tox.ini | 6 +- websockets/__init__.py | 12 +- websockets/__main__.py | 65 +++-- websockets/client.py | 171 ++++++++---- websockets/compatibility.py | 16 +- websockets/exceptions.py | 36 ++- websockets/extensions/base.py | 3 + websockets/extensions/permessage_deflate.py | 120 ++++---- .../extensions/test_permessage_deflate.py | 264 +++++++----------- websockets/framing.py | 52 ++-- websockets/handshake.py | 41 +-- websockets/headers.py | 53 ++-- websockets/http.py | 21 +- websockets/protocol.py | 175 +++++++----- websockets/py35/_test_client_server.py | 17 +- websockets/py35/client.py | 3 +- websockets/py36/_test_client_server.py | 5 +- websockets/server.py | 201 +++++++------ websockets/test_client_server.py | 246 ++++++++-------- websockets/test_exceptions.py | 6 +- websockets/test_framing.py | 61 ++-- websockets/test_handshake.py | 9 +- websockets/test_headers.py | 84 ++---- websockets/test_http.py | 31 +- websockets/test_protocol.py | 65 +++-- websockets/test_uri.py | 26 +- websockets/test_utils.py | 11 +- websockets/uri.py | 3 +- 31 files changed, 916 insertions(+), 901 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1632a17f5..fbcc172d1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,7 +9,7 @@ jobs: - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox codecov - - run: tox -e coverage,flake8,isort + - run: tox -e coverage,black,flake8,isort - run: codecov py34: docker: diff --git a/Makefile b/Makefile index 4992263b3..3282fd1e1 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,8 @@ export PYTHONASYNCIODEBUG=1 +style: + black --skip-string-normalization websockets + test: python -W default -m unittest diff --git a/setup.cfg b/setup.cfg index dc625d46c..2fa9d1394 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,9 +2,14 @@ python-tag = py34.py35.py36 [flake8] -ignore = E731,F403,F405 +ignore = E731,F403,F405,W503 +max-line-length = 88 [isort] +combine_as_imports = True +force_grid_wrap = 0 +include_trailing_comma = True known_standard_library = asyncio +line_length = 88 lines_after_imports = 2 -multi_line_output = 5 +multi_line_output = 3 diff --git a/tox.ini b/tox.ini index 9a3516af9..b6c8eeab2 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = {py34,py35,py36,py37}{,-speedups},coverage,flake8,isort +envlist = {py34,py35,py36,py37}{,-speedups},coverage,black,flake8,isort [testenv] commands = @@ -34,6 +34,10 @@ commands = deps = coverage +[testenv:black] +commands = black --check --skip-string-normalization websockets +deps = black + [testenv:flake8] commands = flake8 websockets deps = flake8 diff --git a/websockets/__init__.py b/websockets/__init__.py index b394c5692..5fbff0d41 100644 --- a/websockets/__init__.py +++ b/websockets/__init__.py @@ -5,13 +5,13 @@ from .protocol import * from .server import * from .uri import * -from .version import version as __version__ # noqa +from .version import version as __version__ # noqa __all__ = ( - client.__all__ + - exceptions.__all__ + - protocol.__all__ + - server.__all__ + - uri.__all__ + client.__all__ + + exceptions.__all__ + + protocol.__all__ + + server.__all__ + + uri.__all__ ) diff --git a/websockets/__main__.py b/websockets/__main__.py index 9e0f0a297..af9286637 100644 --- a/websockets/__main__.py +++ b/websockets/__main__.py @@ -28,9 +28,7 @@ def win_enable_vt100(): raise RuntimeError("Unable to obtain stdout handle") cur_mode = ctypes.c_uint() - if ctypes.windll.kernel32.GetConsoleMode( - handle, ctypes.byref(cur_mode) - ) == 0: + if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0: raise RuntimeError("Unable to query current console mode") # ctypes ints lack support for the required bit-OR operation. @@ -48,32 +46,44 @@ def exit_from_event_loop_thread(loop, stop): # When exiting the thread that runs the event loop, raise # KeyboardInterrupt in the main thead to exit the program. try: - ctrl_c = signal.CTRL_C_EVENT # Windows + ctrl_c = signal.CTRL_C_EVENT # Windows except AttributeError: - ctrl_c = signal.SIGINT # POSIX + ctrl_c = signal.SIGINT # POSIX os.kill(os.getpid(), ctrl_c) def print_during_input(string): sys.stdout.write( - '\N{ESC}7' # Save cursor position - '\N{LINE FEED}' # Add a new line - '\N{ESC}[A' # Move cursor up - '\N{ESC}[L' # Insert blank line, scroll last line down - '{string}\N{LINE FEED}' # Print string in the inserted blank line - '\N{ESC}8' # Restore cursor position - '\N{ESC}[B' # Move cursor down - .format(string=string) + ( + # Save cursor position + '\N{ESC}7' + # Add a new line + '\N{LINE FEED}' + # Move cursor up + '\N{ESC}[A' + # Insert blank line, scroll last line down + '\N{ESC}[L' + # Print string in the inserted blank line + '{string}\N{LINE FEED}' + # Restore cursor position + '\N{ESC}8' + # Move cursor down + '\N{ESC}[B' + ).format(string=string) ) sys.stdout.flush() def print_over_input(string): sys.stdout.write( - '\N{CARRIAGE RETURN}' # Move cursor to beginning of line - '\N{ESC}[K' # Delete current line - '{string}\N{LINE FEED}' - .format(string=string) + ( + # Move cursor to beginning of line + '\N{CARRIAGE RETURN}' + # Delete current line + '\N{ESC}[K' + # Print string + '{string}\N{LINE FEED}' + ).format(string=string) ) sys.stdout.flush() @@ -94,8 +104,7 @@ def run_client(uri, loop, inputs, stop): incoming = asyncio_ensure_future(websocket.recv()) outgoing = asyncio_ensure_future(inputs.get()) done, pending = yield from asyncio.wait( - [incoming, outgoing, stop], - return_when=asyncio.FIRST_COMPLETED, + [incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED ) # Cancel pending tasks to avoid leaking them. @@ -121,12 +130,10 @@ def run_client(uri, loop, inputs, stop): finally: yield from websocket.close() - close_status = format_close( - websocket.close_code, websocket.close_reason) + close_status = format_close(websocket.close_code, websocket.close_reason) print_over_input( - "Connection closed: {close_status}." - .format(close_status=close_status) + "Connection closed: {close_status}.".format(close_status=close_status) ) exit_from_event_loop_thread(loop, stop) @@ -139,11 +146,11 @@ def main(): win_enable_vt100() except RuntimeError as exc: sys.stderr.write( - "Unable to set terminal to VT100 mode. This is only " - "supported since Win10 anniversary update. Expect " - "weird symbols on the terminal. Error: {exc!s}" - "\N{LINE FEED}" - .format(exc=exc) + ( + "Unable to set terminal to VT100 mode. This is only " + "supported since Win10 anniversary update. Expect " + "weird symbols on the terminal.\nError: {exc!s}\n" + ).format(exc=exc) ) sys.stderr.flush() @@ -178,7 +185,7 @@ def main(): # Since there's no size limit, put_nowait is identical to put. message = input('> ') loop.call_soon_threadsafe(inputs.put_nowait, message) - except (KeyboardInterrupt, EOFError): # ^C, ^D + except (KeyboardInterrupt, EOFError): # ^C, ^D loop.call_soon_threadsafe(stop.set_result, None) # Wait for the event loop to terminate. diff --git a/websockets/client.py b/websockets/client.py index d01ef7395..94c761745 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -8,13 +8,19 @@ import sys from .exceptions import ( - InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError + InvalidHandshake, + InvalidMessage, + InvalidStatusCode, + NegotiationError, ) from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response from .headers import ( - build_basic_auth, build_extension_list, build_subprotocol_list, - parse_extension_list, parse_subprotocol_list + build_basic_auth, + build_extension_list, + build_subprotocol_list, + parse_extension_list, + parse_subprotocol_list, ) from .http import USER_AGENT, Headers, read_response from .protocol import WebSocketCommonProtocol @@ -32,12 +38,19 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): :class:`~websockets.protocol.WebSocketCommonProtocol`. """ + is_client = True side = 'client' - def __init__(self, *, - origin=None, extensions=None, subprotocols=None, - extra_headers=None, **kwds): + def __init__( + self, + *, + origin=None, + extensions=None, + subprotocols=None, + extra_headers=None, + **kwds + ): self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols @@ -119,10 +132,10 @@ def process_extensions(headers, available_extensions): if available_extensions is None: raise InvalidHandshake("No extensions supported") - parsed_header_values = sum([ - parse_extension_list(header_value) - for header_value in header_values - ], []) + parsed_header_values = sum( + [parse_extension_list(header_value) for header_value in header_values], + [], + ) for name, response_params in parsed_header_values: @@ -135,7 +148,8 @@ def process_extensions(headers, available_extensions): # Skip non-matching extensions based on their params. try: extension = extension_factory.process_response_params( - response_params, accepted_extensions) + response_params, accepted_extensions + ) except NegotiationError: continue @@ -150,7 +164,9 @@ def process_extensions(headers, available_extensions): else: raise NegotiationError( "Unsupported extension: name = {}, params = {}".format( - name, response_params)) + name, response_params + ) + ) return accepted_extensions @@ -173,27 +189,37 @@ def process_subprotocol(headers, available_subprotocols): if available_subprotocols is None: raise InvalidHandshake("No subprotocols supported") - parsed_header_values = sum([ - parse_subprotocol_list(header_value) - for header_value in header_values - ], []) + parsed_header_values = sum( + [ + parse_subprotocol_list(header_value) + for header_value in header_values + ], + [], + ) if len(parsed_header_values) > 1: raise InvalidHandshake( - "Multiple subprotocols: {}".format( - ', '.join(parsed_header_values))) + "Multiple subprotocols: {}".format(', '.join(parsed_header_values)) + ) subprotocol = parsed_header_values[0] if subprotocol not in available_subprotocols: raise NegotiationError( - "Unsupported subprotocol: {}".format(subprotocol)) + "Unsupported subprotocol: {}".format(subprotocol) + ) return subprotocol @asyncio.coroutine - def handshake(self, wsuri, origin=None, available_extensions=None, - available_subprotocols=None, extra_headers=None): + def handshake( + self, + wsuri, + origin=None, + available_extensions=None, + available_subprotocols=None, + extra_headers=None, + ): """ Perform the client side of the opening handshake. @@ -216,14 +242,13 @@ def handshake(self, wsuri, origin=None, available_extensions=None, """ request_headers = Headers() - if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover + if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover request_headers['Host'] = wsuri.host else: request_headers['Host'] = '{}:{}'.format(wsuri.host, wsuri.port) if wsuri.user_info: - request_headers['Authorization'] = build_basic_auth( - *wsuri.user_info) + request_headers['Authorization'] = build_basic_auth(*wsuri.user_info) if origin is not None: request_headers['Origin'] = origin @@ -231,13 +256,12 @@ def handshake(self, wsuri, origin=None, available_extensions=None, key = build_request(request_headers) if available_extensions is not None: - extensions_header = build_extension_list([ - ( - extension_factory.name, - extension_factory.get_request_params(), - ) - for extension_factory in available_extensions - ]) + extensions_header = build_extension_list( + [ + (extension_factory.name, extension_factory.get_request_params()) + for extension_factory in available_extensions + ] + ) request_headers['Sec-WebSocket-Extensions'] = extensions_header if available_subprotocols is not None: @@ -254,8 +278,7 @@ def handshake(self, wsuri, origin=None, available_extensions=None, request_headers.setdefault('User-Agent', USER_AGENT) - yield from self.write_http_request( - wsuri.resource_name, request_headers) + yield from self.write_http_request(wsuri.resource_name, request_headers) status_code, response_headers = yield from self.read_http_response() @@ -265,10 +288,12 @@ def handshake(self, wsuri, origin=None, available_extensions=None, check_response(response_headers, key) self.extensions = self.process_extensions( - response_headers, available_extensions) + response_headers, available_extensions + ) self.subprotocol = self.process_subprotocol( - response_headers, available_subprotocols) + response_headers, available_subprotocols + ) self.connection_open() @@ -325,15 +350,28 @@ class Connect: """ - def __init__(self, uri, *, - create_protocol=None, - ping_interval=20, ping_timeout=20, - timeout=10, - max_size=2 ** 20, max_queue=2 ** 5, - read_limit=2 ** 16, write_limit=2 ** 16, - loop=None, legacy_recv=False, klass=None, - origin=None, extensions=None, subprotocols=None, - extra_headers=None, compression='deflate', **kwds): + def __init__( + self, + uri, + *, + create_protocol=None, + ping_interval=20, + ping_timeout=20, + timeout=10, + max_size=2 ** 20, + max_queue=2 ** 5, + read_limit=2 ** 16, + write_limit=2 ** 16, + loop=None, + legacy_recv=False, + klass=None, + origin=None, + extensions=None, + subprotocols=None, + extra_headers=None, + compression='deflate', + **kwds + ): if loop is None: loop = asyncio.get_event_loop() @@ -349,8 +387,10 @@ def __init__(self, uri, *, if wsuri.secure: kwds.setdefault('ssl', True) elif kwds.get('ssl') is not None: - raise ValueError("connect() received a SSL context for a ws:// " - "URI, use a wss:// URI to enable TLS") + raise ValueError( + "connect() received a SSL context for a ws:// URI, " + "use a wss:// URI to enable TLS" + ) if compression == 'deflate': if extensions is None: @@ -359,20 +399,28 @@ def __init__(self, uri, *, extension_factory.name == ClientPerMessageDeflateFactory.name for extension_factory in extensions ): - extensions.append(ClientPerMessageDeflateFactory( - client_max_window_bits=True, - )) + extensions.append( + ClientPerMessageDeflateFactory(client_max_window_bits=True) + ) elif compression is not None: raise ValueError("Unsupported compression: {}".format(compression)) factory = lambda: create_protocol( - host=wsuri.host, port=wsuri.port, secure=wsuri.secure, - ping_interval=ping_interval, ping_timeout=ping_timeout, + host=wsuri.host, + port=wsuri.port, + secure=wsuri.secure, + ping_interval=ping_interval, + ping_timeout=ping_timeout, timeout=timeout, - max_size=max_size, max_queue=max_queue, - read_limit=read_limit, write_limit=write_limit, - loop=loop, legacy_recv=legacy_recv, - origin=origin, extensions=extensions, subprotocols=subprotocols, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + loop=loop, + legacy_recv=legacy_recv, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, extra_headers=extra_headers, ) @@ -386,16 +434,16 @@ def __init__(self, uri, *, self._origin = origin # This is a coroutine object. - self._creating_connection = loop.create_connection( - factory, host, port, **kwds) + self._creating_connection = loop.create_connection(factory, host, port, **kwds) @asyncio.coroutine - def __iter__(self): # pragma: no cover + def __iter__(self): # pragma: no cover transport, protocol = yield from self._creating_connection try: yield from protocol.handshake( - self._wsuri, origin=self._origin, + self._wsuri, + origin=self._origin, available_extensions=protocol.available_extensions, available_subprotocols=protocol.available_subprotocols, extra_headers=protocol.extra_headers, @@ -411,14 +459,17 @@ def __iter__(self): # pragma: no cover # We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future # didn't accept arbitrary awaitables until Python 3.5.1. We don't define # __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. -if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover +if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover + @asyncio.coroutine def connect(*args, **kwds): return Connect(*args, **kwds).__iter__() + connect.__doc__ = Connect.__doc__ else: from .py35.client import __aenter__, __aexit__, __await__ + Connect.__aenter__ = __aenter__ Connect.__aexit__ = __aexit__ Connect.__await__ = __await__ diff --git a/websockets/compatibility.py b/websockets/compatibility.py index 21bc586a4..b6506b70c 100644 --- a/websockets/compatibility.py +++ b/websockets/compatibility.py @@ -9,13 +9,13 @@ # Replace with BaseEventLoop.create_task when dropping Python < 3.4.2. -try: # pragma: no cover - asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5 -except AttributeError: # pragma: no cover - asyncio_ensure_future = getattr(asyncio, 'async') # Python < 3.5 +try: # pragma: no cover + asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5 +except AttributeError: # pragma: no cover + asyncio_ensure_future = getattr(asyncio, 'async') # Python < 3.5 -try: # pragma: no cover - # Python ≥ 3.5 +try: # pragma: no cover + # Python ≥ 3.5 SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS OK = http.HTTPStatus.OK BAD_REQUEST = http.HTTPStatus.BAD_REQUEST @@ -24,8 +24,8 @@ UPGRADE_REQUIRED = http.HTTPStatus.UPGRADE_REQUIRED INTERNAL_SERVER_ERROR = http.HTTPStatus.INTERNAL_SERVER_ERROR SERVICE_UNAVAILABLE = http.HTTPStatus.SERVICE_UNAVAILABLE -except AttributeError: # pragma: no cover - # Python < 3.5 +except AttributeError: # pragma: no cover + # Python < 3.5 class SWITCHING_PROTOCOLS: value = 101 phrase = "Switching Protocols" diff --git a/websockets/exceptions.py b/websockets/exceptions.py index e256f218a..b34a2c0dc 100644 --- a/websockets/exceptions.py +++ b/websockets/exceptions.py @@ -1,10 +1,22 @@ __all__ = [ - 'AbortHandshake', 'ConnectionClosed', 'DuplicateParameter', - 'InvalidHandshake', 'InvalidHeader', 'InvalidHeaderFormat', - 'InvalidHeaderValue', 'InvalidMessage', 'InvalidOrigin', - 'InvalidParameterName', 'InvalidParameterValue', 'InvalidState', - 'InvalidStatusCode', 'InvalidUpgrade', 'InvalidURI', 'NegotiationError', - 'PayloadTooBig', 'WebSocketProtocolError', + 'AbortHandshake', + 'ConnectionClosed', + 'DuplicateParameter', + 'InvalidHandshake', + 'InvalidHeader', + 'InvalidHeaderFormat', + 'InvalidHeaderValue', + 'InvalidMessage', + 'InvalidOrigin', + 'InvalidParameterName', + 'InvalidParameterValue', + 'InvalidState', + 'InvalidStatusCode', + 'InvalidUpgrade', + 'InvalidURI', + 'NegotiationError', + 'PayloadTooBig', + 'WebSocketProtocolError', ] @@ -20,12 +32,14 @@ class AbortHandshake(InvalidHandshake): Exception raised to abort a handshake and return a HTTP response. """ + def __init__(self, status, headers, body=b''): self.status = status self.headers = headers self.body = body message = "HTTP {}, {} headers, {} bytes".format( - status, len(headers), len(body)) + status, len(headers), len(body) + ) super().__init__(message) @@ -41,6 +55,7 @@ class InvalidHeader(InvalidHandshake): Exception raised when a HTTP header doesn't have a valid format or value. """ + def __init__(self, name, value=None): if value is None: message = "Missing {} header".format(name) @@ -56,6 +71,7 @@ class InvalidHeaderFormat(InvalidHeader): Exception raised when a Sec-WebSocket-* HTTP header cannot be parsed. """ + def __init__(self, name, error, string, pos): error = "{} at {} in {}".format(error, pos, string) super().__init__(name, error) @@ -80,6 +96,7 @@ class InvalidOrigin(InvalidHeader): Exception raised when the Origin header in a request isn't allowed. """ + def __init__(self, origin): super().__init__('Origin', origin) @@ -91,6 +108,7 @@ class InvalidStatusCode(InvalidHandshake): Provides the integer status code in its ``status_code`` attribute. """ + def __init__(self, status_code): self.status_code = status_code message = "Status code not 101: {}".format(status_code) @@ -109,6 +127,7 @@ class InvalidParameterName(NegotiationError): Exception raised when a parameter name in an extension header is invalid. """ + def __init__(self, name): self.name = name message = "Invalid parameter name: {}".format(name) @@ -120,6 +139,7 @@ class InvalidParameterValue(NegotiationError): Exception raised when a parameter value in an extension header is invalid. """ + def __init__(self, name, value): self.name = name self.value = value @@ -132,6 +152,7 @@ class DuplicateParameter(NegotiationError): Exception raised when a parameter name is repeated in an extension header. """ + def __init__(self, name): self.name = name message = "Duplicate parameter: {}".format(name) @@ -192,6 +213,7 @@ class ConnectionClosed(InvalidState): ``reason`` attributes respectively. """ + def __init__(self, code, reason): self.code = code self.reason = reason diff --git a/websockets/extensions/base.py b/websockets/extensions/base.py index 1888f52fc..69b55b3f8 100644 --- a/websockets/extensions/base.py +++ b/websockets/extensions/base.py @@ -13,6 +13,7 @@ class ClientExtensionFactory: Extension factories handle configuration and negotiation. """ + name = ... def get_request_params(self): @@ -47,6 +48,7 @@ class ServerExtensionFactory: Extension factories handle configuration and negotiation. """ + name = ... def process_request_params(self, params, accepted_extensions): @@ -70,6 +72,7 @@ class Extension: Abstract class for extensions. """ + name = ... def decode(self, frame, *, max_size=None): diff --git a/websockets/extensions/permessage_deflate.py b/websockets/extensions/permessage_deflate.py index 7ca911a2d..19f340734 100644 --- a/websockets/extensions/permessage_deflate.py +++ b/websockets/extensions/permessage_deflate.py @@ -7,8 +7,11 @@ import zlib from ..exceptions import ( - DuplicateParameter, InvalidParameterName, InvalidParameterValue, - NegotiationError, PayloadTooBig + DuplicateParameter, + InvalidParameterName, + InvalidParameterValue, + NegotiationError, + PayloadTooBig, ) from ..framing import CTRL_OPCODES, OP_CONT @@ -41,7 +44,7 @@ def _build_parameters( params.append(('client_no_context_takeover', None)) if server_max_window_bits: params.append(('server_max_window_bits', str(server_max_window_bits))) - if client_max_window_bits is True: # only in handshake requests + if client_max_window_bits is True: # only in handshake requests params.append(('client_max_window_bits', None)) elif client_max_window_bits: params.append(('client_max_window_bits', str(client_max_window_bits))) @@ -90,7 +93,7 @@ def _extract_parameters(params, *, is_server): elif name == 'client_max_window_bits': if client_max_window_bits is not None: raise DuplicateParameter(name) - if is_server and value is None: # only in handshake requests + if is_server and value is None: # only in handshake requests client_max_window_bits = True elif value in _MAX_WINDOW_BITS_VALUES: client_max_window_bits = int(value) @@ -113,6 +116,7 @@ class ClientPerMessageDeflateFactory: Client-side extension factory for permessage-deflate extension. """ + name = 'permessage-deflate' def __init__( @@ -129,16 +133,19 @@ def __init__( See https://tools.ietf.org/html/rfc7692#section-7.1. """ - if not (server_max_window_bits is None or - 8 <= server_max_window_bits <= 15): + if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): raise ValueError("server_max_window_bits must be between 8 and 15") - if not (client_max_window_bits is None or - client_max_window_bits is True or - 8 <= client_max_window_bits <= 15): + if not ( + client_max_window_bits is None + or client_max_window_bits is True + or 8 <= client_max_window_bits <= 15 + ): raise ValueError("client_max_window_bits must be between 8 and 15") if compress_settings is not None and 'wbits' in compress_settings: - raise ValueError("compress_settings must not include wbits, " - "set client_max_window_bits instead") + raise ValueError( + "compress_settings must not include wbits, " + "set client_max_window_bits instead" + ) self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover @@ -152,8 +159,10 @@ def get_request_params(self): """ return _build_parameters( - self.server_no_context_takeover, self.client_no_context_takeover, - self.server_max_window_bits, self.client_max_window_bits, + self.server_no_context_takeover, + self.client_no_context_takeover, + self.server_max_window_bits, + self.client_max_window_bits, ) def process_response_params(self, params, accepted_extensions): @@ -250,10 +259,10 @@ def process_response_params(self, params, accepted_extensions): raise NegotiationError("Unsupported client_max_window_bits") return PerMessageDeflate( - server_no_context_takeover, # remote_no_context_takeover - client_no_context_takeover, # local_no_context_takeover - server_max_window_bits or 15, # remote_max_window_bits - client_max_window_bits or 15, # local_max_window_bits + server_no_context_takeover, # remote_no_context_takeover + client_no_context_takeover, # local_no_context_takeover + server_max_window_bits or 15, # remote_max_window_bits + client_max_window_bits or 15, # local_max_window_bits self.compress_settings, ) @@ -263,6 +272,7 @@ class ServerPerMessageDeflateFactory: Server-side extension factory for permessage-deflate extension. """ + name = 'permessage-deflate' def __init__( @@ -279,15 +289,15 @@ def __init__( See https://tools.ietf.org/html/rfc7692#section-7.1. """ - if not (server_max_window_bits is None or - 8 <= server_max_window_bits <= 15): + if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): raise ValueError("server_max_window_bits must be between 8 and 15") - if not (client_max_window_bits is None or - 8 <= client_max_window_bits <= 15): + if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15): raise ValueError("client_max_window_bits must be between 8 and 15") if compress_settings is not None and 'wbits' in compress_settings: - raise ValueError("compress_settings must not include wbits, " - "set server_max_window_bits instead") + raise ValueError( + "compress_settings must not include wbits, " + "set server_max_window_bits instead" + ) self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover @@ -389,16 +399,18 @@ def process_request_params(self, params, accepted_extensions): return ( _build_parameters( - server_no_context_takeover, client_no_context_takeover, - server_max_window_bits, client_max_window_bits, + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits, ), PerMessageDeflate( - client_no_context_takeover, # remote_no_context_takeover - server_no_context_takeover, # local_no_context_takeover - client_max_window_bits or 15, # remote_max_window_bits - server_max_window_bits or 15, # local_max_window_bits + client_no_context_takeover, # remote_no_context_takeover + server_no_context_takeover, # local_no_context_takeover + client_max_window_bits or 15, # remote_max_window_bits + server_max_window_bits or 15, # local_max_window_bits self.compress_settings, - ) + ), ) @@ -407,6 +419,7 @@ class PerMessageDeflate: permessage-deflate extension. """ + name = 'permessage-deflate' def __init__( @@ -437,13 +450,12 @@ def __init__( self.compress_settings = compress_settings if not self.remote_no_context_takeover: - self.decoder = zlib.decompressobj( - wbits=-self.remote_max_window_bits) + self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) if not self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, - **self.compress_settings) + wbits=-self.local_max_window_bits, **self.compress_settings + ) # To handle continuation frames properly, we must keep track of # whether that initial frame was encoded. @@ -452,16 +464,18 @@ def __init__( # outgoing frames, so it would always be True. def __repr__(self): - return 'PerMessageDeflate({})'.format(', '.join([ - 'remote_no_context_takeover={}'.format( - self.remote_no_context_takeover), - 'local_no_context_takeover={}'.format( - self.local_no_context_takeover), - 'remote_max_window_bits={}'.format( - self.remote_max_window_bits), - 'local_max_window_bits={}'.format( - self.local_max_window_bits), - ])) + return ( + 'PerMessageDeflate(' + 'remote_no_context_takeover={}, ' + 'local_no_context_takeover={}, ' + 'remote_max_window_bits={}, ' + 'local_max_window_bits={})' + ).format( + self.remote_no_context_takeover, + self.local_no_context_takeover, + self.remote_max_window_bits, + self.local_max_window_bits, + ) def decode(self, frame, *, max_size=None): """ @@ -492,8 +506,7 @@ def decode(self, frame, *, max_size=None): # Re-initialize per-message decoder. if self.remote_no_context_takeover: - self.decoder = zlib.decompressobj( - wbits=-self.remote_max_window_bits) + self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) # Uncompress compressed frames. Protect against zip bombs by # preventing zlib from decompressing more than max_length bytes @@ -505,8 +518,10 @@ def decode(self, frame, *, max_size=None): data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: raise PayloadTooBig( - "Uncompressed payload length exceeds size limit (? > {} bytes)" - .format(max_size)) + "Uncompressed payload length exceeds size limit (? > {} bytes)".format( + max_size + ) + ) # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -530,14 +545,11 @@ def encode(self, frame): # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, - **self.compress_settings) + wbits=-self.local_max_window_bits, **self.compress_settings + ) # Compress data frames. - data = ( - self.encoder.compress(frame.data) + - self.encoder.flush(zlib.Z_SYNC_FLUSH) - ) + data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): data = data[:-4] diff --git a/websockets/extensions/test_permessage_deflate.py b/websockets/extensions/test_permessage_deflate.py index a681355b1..67dec5af2 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/websockets/extensions/test_permessage_deflate.py @@ -2,42 +2,52 @@ import zlib from ..exceptions import ( - DuplicateParameter, InvalidParameterName, InvalidParameterValue, - NegotiationError, PayloadTooBig + DuplicateParameter, + InvalidParameterName, + InvalidParameterValue, + NegotiationError, + PayloadTooBig, ) from ..framing import ( - OP_BINARY, OP_CLOSE, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame, - serialize_close + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Frame, + serialize_close, ) from .permessage_deflate import * class ExtensionTestsMixin: - def assertExtensionEqual(self, extension1, extension2): - self.assertEqual(extension1.remote_no_context_takeover, - extension2.remote_no_context_takeover) - self.assertEqual(extension1.local_no_context_takeover, - extension2.local_no_context_takeover) - self.assertEqual(extension1.remote_max_window_bits, - extension2.remote_max_window_bits) - self.assertEqual(extension1.local_max_window_bits, - extension2.local_max_window_bits) - - -class ClientPerMessageDeflateFactoryTests(unittest.TestCase, - ExtensionTestsMixin): - + self.assertEqual( + extension1.remote_no_context_takeover, extension2.remote_no_context_takeover + ) + self.assertEqual( + extension1.local_no_context_takeover, extension2.local_no_context_takeover + ) + self.assertEqual( + extension1.remote_max_window_bits, extension2.remote_max_window_bits + ) + self.assertEqual( + extension1.local_max_window_bits, extension2.local_max_window_bits + ) + + +class ClientPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): def test_name(self): assert ClientPerMessageDeflateFactory.name == 'permessage-deflate' def test_init(self): for config in [ - (False, False, 8, None), # server_max_window_bits ≥ 8 - (False, True, 15, None), # server_max_window_bits ≤ 15 - (True, False, None, 8), # client_max_window_bits ≥ 8 - (True, True, None, 15), # client_max_window_bits ≤ 15 - (False, False, None, True), # client_max_window_bits + (False, False, 8, None), # server_max_window_bits ≥ 8 + (False, True, 15, None), # server_max_window_bits ≤ 15 + (True, False, None, 8), # client_max_window_bits ≥ 8 + (True, True, None, 15), # client_max_window_bits ≤ 15 + (False, False, None, True), # client_max_window_bits (False, False, None, None, {'memLevel': 4}), ]: with self.subTest(config=config): @@ -46,11 +56,11 @@ def test_init(self): def test_init_error(self): for config in [ - (False, False, 7, 8), # server_max_window_bits < 8 - (False, True, 8, 7), # client_max_window_bits < 8 - (True, False, 16, 15), # server_max_window_bits > 15 - (True, True, 15, 16), # client_max_window_bits > 15 - (False, False, True, None), # server_max_window_bits + (False, False, 7, 8), # server_max_window_bits < 8 + (False, True, 8, 7), # client_max_window_bits < 8 + (True, False, 16, 15), # server_max_window_bits > 15 + (True, True, 15, 16), # client_max_window_bits > 15 + (False, False, True, None), # server_max_window_bits (False, False, None, None, {'wbits': 11}), ]: with self.subTest(config=config): @@ -60,34 +70,16 @@ def test_init_error(self): def test_get_request_params(self): for config, result in [ # Test without any parameter - ( - (False, False, None, None), - [], - ), + ((False, False, None, None), []), # Test server_no_context_takeover - ( - (True, False, None, None), - [('server_no_context_takeover', None)], - ), + ((True, False, None, None), [('server_no_context_takeover', None)]), # Test client_no_context_takeover - ( - (False, True, None, None), - [('client_no_context_takeover', None)], - ), + ((False, True, None, None), [('client_no_context_takeover', None)]), # Test server_max_window_bits - ( - (False, False, 10, None), - [('server_max_window_bits', '10')], - ), + ((False, False, 10, None), [('server_max_window_bits', '10')]), # Test client_max_window_bits - ( - (False, False, None, 10), - [('client_max_window_bits', '10')], - ), - ( - (False, False, None, True), - [('client_max_window_bits', None)], - ), + ((False, False, None, 10), [('client_max_window_bits', '10')]), + ((False, False, None, True), [('client_max_window_bits', None)]), # Test all parameters together ( (True, True, 12, 12), @@ -106,27 +98,15 @@ def test_get_request_params(self): def test_process_response_params(self): for config, response_params, result in [ # Test without any parameter - ( - (False, False, None, None), - [], - (False, False, 15, 15), - ), - ( - (False, False, None, None), - [('unknown', None)], - InvalidParameterName, - ), + ((False, False, None, None), [], (False, False, 15, 15)), + ((False, False, None, None), [('unknown', None)], InvalidParameterName), # Test server_no_context_takeover ( (False, False, None, None), [('server_no_context_takeover', None)], (True, False, 15, 15), ), - ( - (True, False, None, None), - [], - NegotiationError, - ), + ((True, False, None, None), [], NegotiationError), ( (True, False, None, None), [('server_no_context_takeover', None)], @@ -148,11 +128,7 @@ def test_process_response_params(self): [('client_no_context_takeover', None)], (False, True, 15, 15), ), - ( - (False, True, None, None), - [], - (False, True, 15, 15), - ), + ((False, True, None, None), [], (False, True, 15, 15)), ( (False, True, None, None), [('client_no_context_takeover', None)], @@ -184,11 +160,7 @@ def test_process_response_params(self): [('server_max_window_bits', '16')], NegotiationError, ), - ( - (False, False, 12, None), - [], - NegotiationError, - ), + ((False, False, 12, None), [], NegotiationError), ( (False, False, 12, None), [('server_max_window_bits', '10')], @@ -220,11 +192,7 @@ def test_process_response_params(self): [('client_max_window_bits', '10')], NegotiationError, ), - ( - (False, False, None, True), - [], - (False, False, 15, 15), - ), + ((False, False, None, True), [], (False, False, 15, 15)), ( (False, False, None, True), [('client_max_window_bits', '7')], @@ -240,11 +208,7 @@ def test_process_response_params(self): [('client_max_window_bits', '16')], NegotiationError, ), - ( - (False, False, None, 12), - [], - (False, False, 15, 12), - ), + ((False, False, None, 12), [], (False, False, 15, 12)), ( (False, False, None, 12), [('client_max_window_bits', '10')], @@ -300,17 +264,13 @@ def test_process_response_params(self): (True, True, 12, 12), ), ]: - with self.subTest( - config=config, - response_params=response_params, - ): + with self.subTest(config=config, response_params=response_params): factory = ClientPerMessageDeflateFactory(*config) if isinstance(result, type) and issubclass(result, Exception): with self.assertRaises(result): factory.process_response_params(response_params, []) else: - extension = factory.process_response_params( - response_params, []) + extension = factory.process_response_params(response_params, []) expected = PerMessageDeflate(*result) self.assertExtensionEqual(extension, expected) @@ -318,21 +278,20 @@ def test_process_response_params_deduplication(self): factory = ClientPerMessageDeflateFactory(False, False, None, None) with self.assertRaises(NegotiationError): factory.process_response_params( - [], [PerMessageDeflate(False, False, 15, 15)]) - + [], [PerMessageDeflate(False, False, 15, 15)] + ) -class ServerPerMessageDeflateFactoryTests(unittest.TestCase, - ExtensionTestsMixin): +class ServerPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): def test_name(self): assert ServerPerMessageDeflateFactory.name == 'permessage-deflate' def test_init(self): for config in [ - (False, False, 8, None), # server_max_window_bits ≥ 8 - (False, True, 15, None), # server_max_window_bits ≤ 15 - (True, False, None, 8), # client_max_window_bits ≥ 8 - (True, True, None, 15), # client_max_window_bits ≤ 15 + (False, False, 8, None), # server_max_window_bits ≥ 8 + (False, True, 15, None), # server_max_window_bits ≤ 15 + (True, False, None, 8), # client_max_window_bits ≥ 8 + (True, True, None, 15), # client_max_window_bits ≤ 15 (False, False, None, None, {'memLevel': 4}), ]: with self.subTest(config=config): @@ -341,12 +300,12 @@ def test_init(self): def test_init_error(self): for config in [ - (False, False, 7, 8), # server_max_window_bits < 8 - (False, True, 8, 7), # client_max_window_bits < 8 - (True, False, 16, 15), # server_max_window_bits > 15 - (True, True, 15, 16), # client_max_window_bits > 15 - (False, False, None, True), # client_max_window_bits - (False, False, True, None), # server_max_window_bits + (False, False, 7, 8), # server_max_window_bits < 8 + (False, True, 8, 7), # client_max_window_bits < 8 + (True, False, 16, 15), # server_max_window_bits > 15 + (True, True, 15, 16), # client_max_window_bits > 15 + (False, False, None, True), # client_max_window_bits + (False, False, True, None), # server_max_window_bits (False, False, None, None, {'wbits': 11}), ]: with self.subTest(config=config): @@ -358,12 +317,7 @@ def test_process_request_params(self): # (remote, local) vs. (server, client). for config, request_params, response_params, result in [ # Test without any parameter - ( - (False, False, None, None), - [], - [], - (False, False, 15, 15), - ), + ((False, False, None, None), [], [], (False, False, 15, 15)), ( (False, False, None, None), [('unknown', None)], @@ -405,7 +359,7 @@ def test_process_request_params(self): ( (False, False, None, None), [('client_no_context_takeover', None)], - [('client_no_context_takeover', None)], # doesn't matter + [('client_no_context_takeover', None)], # doesn't matter (True, False, 15, 15), ), ( @@ -417,7 +371,7 @@ def test_process_request_params(self): ( (False, True, None, None), [('client_no_context_takeover', None)], - [('client_no_context_takeover', None)], # doesn't matter + [('client_no_context_takeover', None)], # doesn't matter (True, False, 15, 15), ), ( @@ -503,7 +457,7 @@ def test_process_request_params(self): ( (False, False, None, None), [('client_max_window_bits', '10')], - [('client_max_window_bits', '10')], # doesn't matter + [('client_max_window_bits', '10')], # doesn't matter (False, False, 10, 15), ), ( @@ -512,12 +466,7 @@ def test_process_request_params(self): None, InvalidParameterValue, ), - ( - (False, False, None, 12), - [], - None, - NegotiationError, - ), + ((False, False, None, 12), [], None, NegotiationError), ( (False, False, None, 12), [('client_max_window_bits', None)], @@ -533,13 +482,13 @@ def test_process_request_params(self): ( (False, False, None, 12), [('client_max_window_bits', '12')], - [('client_max_window_bits', '12')], # doesn't matter + [('client_max_window_bits', '12')], # doesn't matter (False, False, 12, 15), ), ( (False, False, None, 12), [('client_max_window_bits', '13')], - [('client_max_window_bits', '12')], # doesn't matter + [('client_max_window_bits', '12')], # doesn't matter (False, False, 12, 15), ), ( @@ -589,9 +538,7 @@ def test_process_request_params(self): ), ( (True, True, 12, 12), - [ - ('client_max_window_bits', None), - ], + [('client_max_window_bits', None)], [ ('server_no_context_takeover', None), ('client_no_context_takeover', None), @@ -612,7 +559,8 @@ def test_process_request_params(self): factory.process_request_params(request_params, []) else: params, extension = factory.process_request_params( - request_params, []) + request_params, [] + ) self.assertEqual(params, response_params) expected = PerMessageDeflate(*result) self.assertExtensionEqual(extension, expected) @@ -621,11 +569,11 @@ def test_process_response_params_deduplication(self): factory = ServerPerMessageDeflateFactory(False, False, None, None) with self.assertRaises(NegotiationError): factory.process_request_params( - [], [PerMessageDeflate(False, False, 15, 15)]) + [], [PerMessageDeflate(False, False, 15, 15)] + ) class PerMessageDeflateTests(unittest.TestCase, ExtensionTestsMixin): - def setUp(self): # Set up an instance of the permessage-deflate extension with the most # common settings. Since the extension is symmetrical, this instance @@ -668,10 +616,7 @@ def test_encode_decode_text_frame(self): enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame, frame._replace( - rsv1=True, - data=b'JNL;\xbc\x12\x00', - )) + self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b'JNL;\xbc\x12\x00')) dec_frame = self.extension.decode(enc_frame) @@ -682,10 +627,7 @@ def test_encode_decode_binary_frame(self): enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame, frame._replace( - rsv1=True, - data=b'*IM\x04\x00', - )) + self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b'*IM\x04\x00')) dec_frame = self.extension.decode(enc_frame) @@ -700,18 +642,16 @@ def test_encode_decode_fragmented_text_frame(self): enc_frame2 = self.extension.encode(frame2) enc_frame3 = self.extension.encode(frame3) - self.assertEqual(enc_frame1, frame1._replace( - rsv1=True, - data=b'JNL;\xbc\x12\x00\x00\x00\xff\xff', - )) - self.assertEqual(enc_frame2, frame2._replace( - rsv1=True, - data=b'RPS\x00\x00\x00\x00\xff\xff', - )) - self.assertEqual(enc_frame3, frame3._replace( - rsv1=True, - data=b'J.\xca\xcf,.N\xcc+)\x06\x00', - )) + self.assertEqual( + enc_frame1, + frame1._replace(rsv1=True, data=b'JNL;\xbc\x12\x00\x00\x00\xff\xff'), + ) + self.assertEqual( + enc_frame2, frame2._replace(rsv1=True, data=b'RPS\x00\x00\x00\x00\xff\xff') + ) + self.assertEqual( + enc_frame3, frame3._replace(rsv1=True, data=b'J.\xca\xcf,.N\xcc+)\x06\x00') + ) dec_frame1 = self.extension.decode(enc_frame1) dec_frame2 = self.extension.decode(enc_frame2) @@ -728,14 +668,12 @@ def test_encode_decode_fragmented_binary_frame(self): enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) - self.assertEqual(enc_frame1, frame1._replace( - rsv1=True, - data=b'*IMT\x00\x00\x00\x00\xff\xff', - )) - self.assertEqual(enc_frame2, frame2._replace( - rsv1=True, - data=b'*\xc9\xccM\x05\x00', - )) + self.assertEqual( + enc_frame1, frame1._replace(rsv1=True, data=b'*IMT\x00\x00\x00\x00\xff\xff') + ) + self.assertEqual( + enc_frame2, frame2._replace(rsv1=True, data=b'*\xc9\xccM\x05\x00') + ) dec_frame1 = self.extension.decode(enc_frame1) dec_frame2 = self.extension.decode(enc_frame2) @@ -834,10 +772,12 @@ def test_compress_settings(self): enc_frame = extension.encode(frame) - self.assertEqual(enc_frame, frame._replace( - rsv1=True, - data=b'\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00', # not compressed - )) + self.assertEqual( + enc_frame, + frame._replace( + rsv1=True, data=b'\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00' # not compressed + ), + ) # Frames aren't decoded beyond max_length. diff --git a/websockets/framing.py b/websockets/framing.py index b1b655b28..afbc664c6 100644 --- a/websockets/framing.py +++ b/websockets/framing.py @@ -20,14 +20,23 @@ try: from .speedups import apply_mask -except ImportError: # pragma: no cover +except ImportError: # pragma: no cover from .utils import apply_mask __all__ = [ - 'DATA_OPCODES', 'CTRL_OPCODES', - 'OP_CONT', 'OP_TEXT', 'OP_BINARY', 'OP_CLOSE', 'OP_PING', 'OP_PONG', - 'Frame', 'encode_data', 'parse_close', 'serialize_close' + 'DATA_OPCODES', + 'CTRL_OPCODES', + 'OP_CONT', + 'OP_TEXT', + 'OP_BINARY', + 'OP_CLOSE', + 'OP_PING', + 'OP_PONG', + 'Frame', + 'encode_data', + 'parse_close', + 'serialize_close', ] DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 @@ -35,21 +44,10 @@ # Close code that are allowed in a close frame. # Using a list optimizes `code in EXTERNAL_CLOSE_CODES`. -EXTERNAL_CLOSE_CODES = [ - 1000, - 1001, - 1002, - 1003, - 1007, - 1008, - 1009, - 1010, - 1011, -] +EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011] FrameData = collections.namedtuple( - 'FrameData', - ['fin', 'opcode', 'data', 'rsv1', 'rsv2', 'rsv3'], + 'FrameData', ['fin', 'opcode', 'data', 'rsv1', 'rsv2', 'rsv3'] ) @@ -69,6 +67,7 @@ class Frame(FrameData): :meth:`write`. """ + def __new__(cls, fin, opcode, data, rsv1=False, rsv2=False, rsv3=False): return FrameData.__new__(cls, fin, opcode, data, rsv1, rsv2, rsv3) @@ -119,8 +118,10 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): length, = struct.unpack('!Q', data) if max_size is not None and length > max_size: raise PayloadTooBig( - "Payload length exceeds size limit ({} > {} bytes)" - .format(length, max_size)) + "Payload length exceeds size limit ({} > {} bytes)".format( + length, max_size + ) + ) if mask: mask_bits = yield from reader(4) @@ -174,11 +175,11 @@ def write(frame, writer, *, mask, extensions=None): # Prepare the header. head1 = ( - (0b10000000 if frame.fin else 0) | - (0b01000000 if frame.rsv1 else 0) | - (0b00100000 if frame.rsv2 else 0) | - (0b00010000 if frame.rsv3 else 0) | - frame.opcode + (0b10000000 if frame.fin else 0) + | (0b01000000 if frame.rsv1 else 0) + | (0b00100000 if frame.rsv2 else 0) + | (0b00010000 if frame.rsv3 else 0) + | frame.opcode ) head2 = 0b10000000 if mask else 0 @@ -231,8 +232,7 @@ def check(frame): if not frame.fin: raise WebSocketProtocolError("Fragmented control frame") else: - raise WebSocketProtocolError( - "Invalid opcode: {}".format(frame.opcode)) + raise WebSocketProtocolError("Invalid opcode: {}".format(frame.opcode)) def encode_data(data): diff --git a/websockets/handshake.py b/websockets/handshake.py index aef467034..cc4248974 100644 --- a/websockets/handshake.py +++ b/websockets/handshake.py @@ -41,10 +41,7 @@ from .http import MultipleValuesError -__all__ = [ - 'build_request', 'check_request', - 'build_response', 'check_response', -] +__all__ = ['build_request', 'check_request', 'build_response', 'check_response'] GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' @@ -81,18 +78,14 @@ def check_request(headers): responsibility of the caller. """ - connection = sum([ - parse_connection(value) - for value in headers.get_all('Connection') - ], []) + connection = sum( + [parse_connection(value) for value in headers.get_all('Connection')], [] + ) if not any(value.lower() == 'upgrade' for value in connection): raise InvalidUpgrade('Connection', connection) - upgrade = sum([ - parse_upgrade(value) - for value in headers.get_all('Upgrade') - ], []) + upgrade = sum([parse_upgrade(value) for value in headers.get_all('Upgrade')], []) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. @@ -105,8 +98,8 @@ def check_request(headers): raise InvalidHeader('Sec-WebSocket-Key') except MultipleValuesError: raise InvalidHeader( - 'Sec-WebSocket-Key', - "more than one Sec-WebSocket-Key header found") + 'Sec-WebSocket-Key', "more than one Sec-WebSocket-Key header found" + ) try: raw_key = base64.b64decode(s_w_key.encode(), validate=True) @@ -121,8 +114,8 @@ def check_request(headers): raise InvalidHeader('Sec-WebSocket-Version') except MultipleValuesError: raise InvalidHeader( - 'Sec-WebSocket-Version', - "more than one Sec-WebSocket-Version header found") + 'Sec-WebSocket-Version', "more than one Sec-WebSocket-Version header found" + ) if s_w_version != '13': raise InvalidHeaderValue('Sec-WebSocket-Version', s_w_version) @@ -158,18 +151,14 @@ def check_response(headers, key): the caller. """ - connection = sum([ - parse_connection(value) - for value in headers.get_all('Connection') - ], []) + connection = sum( + [parse_connection(value) for value in headers.get_all('Connection')], [] + ) if not any(value.lower() == 'upgrade' for value in connection): raise InvalidUpgrade('Connection', connection) - upgrade = sum([ - parse_upgrade(value) - for value in headers.get_all('Upgrade') - ], []) + upgrade = sum([parse_upgrade(value) for value in headers.get_all('Upgrade')], []) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. @@ -182,8 +171,8 @@ def check_response(headers, key): raise InvalidHeader('Sec-WebSocket-Accept') except MultipleValuesError: raise InvalidHeader( - 'Sec-WebSocket-Accept', - "more than one Sec-WebSocket-Accept header found") + 'Sec-WebSocket-Accept', "more than one Sec-WebSocket-Accept header found" + ) if s_w_accept != accept(key): raise InvalidHeaderValue('Sec-WebSocket-Accept', s_w_accept) diff --git a/websockets/headers.py b/websockets/headers.py index 6da5ec7f0..937962376 100644 --- a/websockets/headers.py +++ b/websockets/headers.py @@ -14,9 +14,12 @@ __all__ = [ - 'parse_connection', 'parse_upgrade', - 'parse_extension_list', 'build_extension_list', - 'parse_subprotocol_list', 'build_subprotocol_list', + 'parse_connection', + 'parse_upgrade', + 'parse_extension_list', + 'build_extension_list', + 'parse_subprotocol_list', + 'build_subprotocol_list', ] @@ -24,6 +27,7 @@ # described in https://tools.ietf.org/html/rfc6455#section-9.1 with the # definitions from https://tools.ietf.org/html/rfc7230#appendix-B. + def peek_ahead(string, pos): """ Return the next character from ``string`` at the given position. @@ -67,13 +71,13 @@ def parse_token(string, pos, header_name): """ match = _token_re.match(string, pos) if match is None: - raise InvalidHeaderFormat( - header_name, "expected token", string=string, pos=pos) + raise InvalidHeaderFormat(header_name, "expected token", string=string, pos=pos) return match.group(), match.end() _quoted_string_re = re.compile( - r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"') + r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"' +) _unquote_re = re.compile(r'\\([\x09\x20-\x7e\x80-\xff])') @@ -91,7 +95,8 @@ def parse_quoted_string(string, pos, header_name): match = _quoted_string_re.match(string, pos) if match is None: raise InvalidHeaderFormat( - header_name, "expected quoted string", string=string, pos=pos) + header_name, "expected quoted string", string=string, pos=pos + ) return _unquote_re.sub(r'\1', match.group()[1:-1]), match.end() @@ -139,7 +144,8 @@ def parse_list(parse_item, string, pos, header_name): pos = parse_OWS(string, pos + 1) else: raise InvalidHeaderFormat( - header_name, "expected comma", string=string, pos=pos) + header_name, "expected comma", string=string, pos=pos + ) # Remove extra delimiters before the next item. while peek_ahead(string, pos) == ',': @@ -169,7 +175,8 @@ def parse_connection(string): _protocol_re = re.compile( - r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?') + r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?' +) def parse_protocol(string, pos, header_name): @@ -184,7 +191,8 @@ def parse_protocol(string, pos, header_name): match = _protocol_re.match(string, pos) if match is None: raise InvalidHeaderFormat( - header_name, "expected protocol", string=string, pos=pos) + header_name, "expected protocol", string=string, pos=pos + ) return match.group(), match.end() @@ -216,14 +224,17 @@ def parse_extension_param(string, pos, header_name): if peek_ahead(string, pos) == '=': pos = parse_OWS(string, pos + 1) if peek_ahead(string, pos) == '"': - pos_before = pos # for proper error reporting below + pos_before = pos # for proper error reporting below value, pos = parse_quoted_string(string, pos, header_name) # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value # after quoted-string unescaping MUST conform to the 'token' ABNF. if _token_re.fullmatch(value) is None: raise InvalidHeaderFormat( - header_name, "invalid quoted string content", - string=string, pos=pos_before) + header_name, + "invalid quoted string content", + string=string, + pos=pos_before, + ) else: value, pos = parse_token(string, pos, header_name) pos = parse_OWS(string, pos) @@ -287,11 +298,14 @@ def build_extension(name, parameters): This is the reverse of :func:`parse_extension`. """ - return '; '.join([name] + [ - # Quoted strings aren't necessary because values are always tokens. - name if value is None else '{}={}'.format(name, value) - for name, value in parameters - ]) + return '; '.join( + [name] + + [ + # Quoted strings aren't necessary because values are always tokens. + name if value is None else '{}={}'.format(name, value) + for name, value in parameters + ] + ) def build_extension_list(extensions): @@ -302,8 +316,7 @@ def build_extension_list(extensions): """ return ', '.join( - build_extension(name, parameters) - for name, parameters in extensions + build_extension(name, parameters) for name, parameters in extensions ) diff --git a/websockets/http.py b/websockets/http.py index e0bd17609..e56a4a2c5 100644 --- a/websockets/http.py +++ b/websockets/http.py @@ -16,18 +16,17 @@ __all__ = [ - 'Headers', 'MultipleValuesError', - 'read_request', 'read_response', + 'Headers', + 'MultipleValuesError', + 'read_request', + 'read_response', 'USER_AGENT', ] MAX_HEADERS = 256 MAX_LINE = 4096 -USER_AGENT = ' '.join(( - 'Python/{}'.format(sys.version[:3]), - 'websockets/{}'.format(websockets_version), -)) +USER_AGENT = 'Python/{} websockets/{}'.format(sys.version[:3], websockets_version) # See https://tools.ietf.org/html/rfc7230#appendix-B. @@ -169,7 +168,7 @@ def read_headers(stream): if not _value_re.fullmatch(value): raise ValueError("Invalid HTTP header value: %r" % value) - name = name.decode('ascii') # guaranteed to be ASCII at this point + name = name.decode('ascii') # guaranteed to be ASCII at this point value = value.decode('ascii', 'surrogateescape') headers[name] = value @@ -257,10 +256,10 @@ def __init__(self, *args, **kwargs): self.update(*args, **kwargs) def __str__(self): - return ''.join( - '{}: {}\r\n'.format(key, value) - for key, value in self._list - ) + '\r\n' + return ( + ''.join('{}: {}\r\n'.format(key, value) for key, value in self._list) + + '\r\n' + ) def __repr__(self): return '{}({})'.format(self.__class__.__name__, repr(self._list)) diff --git a/websockets/protocol.py b/websockets/protocol.py index 9c72c5f54..b9bbda37a 100644 --- a/websockets/protocol.py +++ b/websockets/protocol.py @@ -20,7 +20,10 @@ from .compatibility import asyncio_ensure_future from .exceptions import ( - ConnectionClosed, InvalidState, PayloadTooBig, WebSocketProtocolError + ConnectionClosed, + InvalidState, + PayloadTooBig, + WebSocketProtocolError, ) from .framing import * from .handshake import * @@ -35,17 +38,18 @@ # dropping support for Python < 3.5. warnings.filterwarnings( action='ignore', - message=r"'with \(yield from lock\)' is deprecated " - r"use 'async with lock' instead", + message=r"'with \(yield from lock\)' is deprecated use 'async with lock' instead", category=DeprecationWarning, ) # A WebSocket connection goes through the following four states, in order: + class State(enum.IntEnum): CONNECTING, OPEN, CLOSING, CLOSED = range(4) + # In order to ensure consistency, the code always checks the current value of # WebSocketCommonProtocol.state before assigning a new value and never yields # between the check and the assignment. @@ -154,19 +158,29 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): :attr:`close_code` attribute and the reason in :attr:`close_reason`. """ + # There are only two differences between the client-side and the server- # side behavior: masking the payload and closing the underlying TCP # connection. Set is_client and side to pick a side. is_client = None side = 'undefined' - def __init__(self, *, - host=None, port=None, secure=None, - ping_interval=20, ping_timeout=20, - timeout=10, - max_size=2 ** 20, max_queue=2 ** 5, - read_limit=2 ** 16, write_limit=2 ** 16, - loop=None, legacy_recv=False): + def __init__( + self, + *, + host=None, + port=None, + secure=None, + ping_interval=20, + ping_timeout=20, + timeout=10, + max_size=2 ** 20, + max_queue=2 ** 5, + read_limit=2 ** 16, + write_limit=2 ** 16, + loop=None, + legacy_recv=False + ): self.host = host self.port = port self.secure = secure @@ -267,13 +281,16 @@ def connection_open(self): logger.debug("%s - state = OPEN", self.side) # Start the task that receives incoming WebSocket messages. self.transfer_data_task = asyncio_ensure_future( - self.transfer_data(), loop=self.loop) + self.transfer_data(), loop=self.loop + ) # Start the task that sends pings at regular intervals. self.keepalive_ping_task = asyncio_ensure_future( - self.keepalive_ping(), loop=self.loop) + self.keepalive_ping(), loop=self.loop + ) # Start the task that eventually closes the TCP connection. self.close_connection_task = asyncio_ensure_future( - self.close_connection(), loop=self.loop) + self.close_connection(), loop=self.loop + ) # Public API @@ -364,13 +381,14 @@ def recv(self): # received before the closing frame even if the connection is closing. # Wait for a message until the connection is closed. - next_message = asyncio_ensure_future( - self.messages.get(), loop=self.loop) + next_message = asyncio_ensure_future(self.messages.get(), loop=self.loop) # See https://bugs.python.org/issue23859 for cancellation handling. try: done, pending = yield from asyncio.wait( [next_message, self.transfer_data_task], - loop=self.loop, return_when=asyncio.FIRST_COMPLETED) + loop=self.loop, + return_when=asyncio.FIRST_COMPLETED, + ) except asyncio.CancelledError: # Propagate cancellation to avoid leaking the next_message Task. next_message.cancel() @@ -427,7 +445,9 @@ def close(self, code=1000, reason=''): try: yield from asyncio.wait_for( self.write_close_frame(serialize_close(code, reason)), - self.timeout, loop=self.loop) + self.timeout, + loop=self.loop, + ) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. @@ -447,8 +467,8 @@ def close(self, code=1000, reason=''): # is canceled before the timeout elapses (on Python ≥ 3.4.3). # This helps closing connections when shutting down a server. yield from asyncio.wait_for( - self.transfer_data_task, - self.timeout, loop=self.loop) + self.transfer_data_task, self.timeout, loop=self.loop + ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -535,7 +555,8 @@ def ensure_open(self): if self.state is State.CLOSED: raise ConnectionClosed( - self.close_code, self.close_reason) from self.transfer_data_exc + self.close_code, self.close_reason + ) from self.transfer_data_exc if self.state is State.CLOSING: # If we started the closing handshake, wait for its completion to @@ -546,7 +567,8 @@ def ensure_open(self): if self.close_code is None: yield from asyncio.shield(self.close_connection_task) raise ConnectionClosed( - self.close_code, self.close_reason) from self.transfer_data_exc + self.close_code, self.close_reason + ) from self.transfer_data_exc # Control may only reach this point in buggy third-party subclasses. assert self.state is State.CONNECTING @@ -624,7 +646,7 @@ def read_message(self): text = True elif frame.opcode == OP_BINARY: text = False - else: # frame.opcode == OP_CONT + else: # frame.opcode == OP_CONT raise WebSocketProtocolError("Unexpected opcode") # Shortcut for the common case - no fragmentation @@ -637,24 +659,32 @@ def read_message(self): if text: decoder = codecs.getincrementaldecoder('utf-8')(errors='strict') if max_size is None: + def append(frame): nonlocal chunks chunks.append(decoder.decode(frame.data, frame.fin)) + else: + def append(frame): nonlocal chunks, max_size chunks.append(decoder.decode(frame.data, frame.fin)) max_size -= len(frame.data) + else: if max_size is None: + def append(frame): nonlocal chunks chunks.append(frame.data) + else: + def append(frame): nonlocal chunks, max_size chunks.append(frame.data) max_size -= len(frame.data) + append(frame) while not frame.fin: @@ -696,8 +726,9 @@ def read_data_frame(self, max_size): # Answer pings. # Replace by frame.data.hex() when dropping Python < 3.5. ping_hex = binascii.hexlify(frame.data).decode() or '[empty]' - logger.debug("%s - received ping, sending pong: %s", - self.side, ping_hex) + logger.debug( + "%s - received ping, sending pong: %s", self.side, ping_hex + ) yield from self.pong(frame.data) elif frame.opcode == OP_PONG: @@ -710,10 +741,10 @@ def read_data_frame(self, max_size): ping_id, pong_waiter = self.pings.popitem(0) ping_ids.append(ping_id) pong_waiter.set_result(None) - pong_hex = ( - binascii.hexlify(frame.data).decode() or '[empty]') - logger.debug("%s - received solicited pong: %s", - self.side, pong_hex) + pong_hex = binascii.hexlify(frame.data).decode() or '[empty]' + logger.debug( + "%s - received solicited pong: %s", self.side, pong_hex + ) ping_ids = ping_ids[:-1] if ping_ids: pings_hex = ', '.join( @@ -723,12 +754,15 @@ def read_data_frame(self, max_size): plural = 's' if len(ping_ids) > 1 else '' logger.debug( "%s - acknowledged previous ping%s: %s", - self.side, plural, pings_hex) + self.side, + plural, + pings_hex, + ) else: - pong_hex = ( - binascii.hexlify(frame.data).decode() or '[empty]') - logger.debug("%s - received unsolicited pong: %s", - self.side, pong_hex) + pong_hex = binascii.hexlify(frame.data).decode() or '[empty]' + logger.debug( + "%s - received unsolicited pong: %s", self.side, pong_hex + ) # 5.6. Data Frames else: @@ -752,23 +786,20 @@ def read_frame(self, max_size): @asyncio.coroutine def write_frame(self, opcode, data=b'', _expected_state=State.OPEN): # Defensive assertion for protocol compliance. - if self.state is not _expected_state: # pragma: no cover - raise InvalidState("Cannot write to a WebSocket " - "in the {} state".format(self.state.name)) + if self.state is not _expected_state: # pragma: no cover + raise InvalidState( + "Cannot write to a WebSocket " "in the {} state".format(self.state.name) + ) frame = Frame(True, opcode, data) logger.debug("%s > %s", self.side, frame) - frame.write( - self.writer.write, - mask=self.is_client, - extensions=self.extensions, - ) + frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions) # Backport of https://github.com/python/asyncio/pull/280. # Remove when dropping support for Python < 3.6. - if self.writer.transport is not None: # pragma: no cover + if self.writer.transport is not None: # pragma: no cover if self.writer_is_closing(): - yield + yield try: # drain() cannot be called concurrently by multiple coroutines: @@ -796,7 +827,7 @@ def writer_is_closing(self): transport = self.writer.transport try: return transport.is_closing() - except AttributeError: # pragma: no cover + except AttributeError: # pragma: no cover # This emulates what is_closing would return if it existed. try: return transport._closing @@ -850,10 +881,10 @@ def keepalive_ping(self): if self.ping_timeout is not None: try: yield from asyncio.wait_for( - ping_waiter, self.ping_timeout, loop=self.loop) + ping_waiter, self.ping_timeout, loop=self.loop + ) except asyncio.TimeoutError: - logger.debug( - "%s ! timed out waiting for pong", self.side) + logger.debug("%s ! timed out waiting for pong", self.side) self.fail_connection(1011) break @@ -861,8 +892,7 @@ def keepalive_ping(self): raise except Exception as exc: - logger.warning( - "Unexpected exception in keepalive ping task", exc_info=True) + logger.warning("Unexpected exception in keepalive ping task", exc_info=True) @asyncio.coroutine def close_connection(self): @@ -899,26 +929,23 @@ def close_connection(self): ) plural = 's' if len(self.pings) > 1 else '' logger.debug( - "%s - canceled pending ping%s: %s", - self.side, plural, pings_hex) + "%s - canceled pending ping%s: %s", self.side, plural, pings_hex + ) # A client should wait for a TCP close from the server. if self.is_client and self.transfer_data_task is not None: if (yield from self.wait_for_connection_lost()): return - logger.debug( - "%s ! timed out waiting for TCP close", self.side) + logger.debug("%s ! timed out waiting for TCP close", self.side) # Half-close the TCP connection if possible (when there's no TLS). if self.writer.can_write_eof(): - logger.debug( - "%s x half-closing TCP connection", self.side) + logger.debug("%s x half-closing TCP connection", self.side) self.writer.write_eof() if (yield from self.wait_for_connection_lost()): return - logger.debug( - "%s ! timed out waiting for TCP close", self.side) + logger.debug("%s ! timed out waiting for TCP close", self.side) finally: # The try/finally ensures that the transport never remains open, @@ -931,18 +958,15 @@ def close_connection(self): return # Close the TCP connection. Buffers are flushed asynchronously. - logger.debug( - "%s x closing TCP connection", self.side) + logger.debug("%s x closing TCP connection", self.side) self.writer.close() if (yield from self.wait_for_connection_lost()): return - logger.debug( - "%s ! timed out waiting for TCP close", self.side) + logger.debug("%s ! timed out waiting for TCP close", self.side) # Abort the TCP connection. Buffers are discarded. - logger.debug( - "%s x aborting TCP connection", self.side) + logger.debug("%s x aborting TCP connection", self.side) self.writer.transport.abort() # connection_lost() is called quickly after aborting. @@ -960,7 +984,9 @@ def wait_for_connection_lost(self): try: yield from asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), - self.timeout, loop=self.loop) + self.timeout, + loop=self.loop, + ) except asyncio.TimeoutError: pass # Re-check self.connection_lost_waiter.done() synchronously because @@ -991,8 +1017,7 @@ def fail_connection(self, code=1006, reason=''): """ logger.debug( - "%s ! failing WebSocket connection: %d %s", - self.side, code, reason, + "%s ! failing WebSocket connection: %d %s", self.side, code, reason ) # Cancel transfer_data_task if the opening handshake succeeded. @@ -1023,15 +1048,14 @@ def fail_connection(self, code=1006, reason=''): frame = Frame(True, OP_CLOSE, frame_data) logger.debug("%s > %s", self.side, frame) frame.write( - self.writer.write, - mask=self.is_client, - extensions=self.extensions, + self.writer.write, mask=self.is_client, extensions=self.extensions ) # Start close_connection_task if the opening handshake didn't succeed. if self.close_connection_task is None: self.close_connection_task = asyncio_ensure_future( - self.close_connection(), loop=self.loop) + self.close_connection(), loop=self.loop + ) return self.close_connection_task @@ -1092,8 +1116,12 @@ def connection_lost(self, exc): logger.debug("%s - state = CLOSED", self.side) if self.close_code is None: self.close_code = 1006 - logger.debug("%s x code = %d, reason = %s", self.side, - self.close_code, self.close_reason or '[empty]') + logger.debug( + "%s x code = %d, reason = %s", + self.side, + self.close_code, + self.close_reason or '[empty]', + ) # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. @@ -1101,6 +1129,7 @@ def connection_lost(self, exc): super().connection_lost(exc) -if sys.version_info[:2] >= (3, 6): # pragma: no cover +if sys.version_info[:2] >= (3, 6): # pragma: no cover from .py36.protocol import __aiter__ + WebSocketCommonProtocol.__aiter__ = __aiter__ diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py index c656dd38a..7e7218247 100644 --- a/websockets/py35/_test_client_server.py +++ b/websockets/py35/_test_client_server.py @@ -14,7 +14,6 @@ class AsyncAwaitTests(unittest.TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -39,7 +38,6 @@ async def run_client(): self.loop.run_until_complete(server.wait_closed()) def test_server(self): - async def run_server(): # Await serve. server = await serve(handler, 'localhost', 0) @@ -52,7 +50,6 @@ async def run_server(): class ContextManagerTests(unittest.TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -62,7 +59,8 @@ def tearDown(self): # Asynchronous context managers are only enabled on Python ≥ 3.5.1. @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+') + sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+' + ) def test_client(self): start_server = serve(handler, 'localhost', 0) server = self.loop.run_until_complete(start_server) @@ -82,9 +80,9 @@ async def run_client(): # Asynchronous context managers are only enabled on Python ≥ 3.5.1. @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+') + sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+' + ) def test_server(self): - async def run_server(): # Use serve as an asynchronous context manager. async with serve(handler, 'localhost', 0) as server: @@ -97,11 +95,10 @@ async def run_server(): # Asynchronous context managers are only enabled on Python ≥ 3.5.1. @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+') - @unittest.skipUnless( - hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') + sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+' + ) + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') def test_unix_server(self): - async def run_server(path): async with unix_serve(handler, path) as server: self.assertTrue(server.sockets) diff --git a/websockets/py35/client.py b/websockets/py35/client.py index 7673ea3ad..f62e7d69e 100644 --- a/websockets/py35/client.py +++ b/websockets/py35/client.py @@ -13,7 +13,8 @@ async def __await_impl__(self): try: await protocol.handshake( - self._wsuri, origin=self._origin, + self._wsuri, + origin=self._origin, available_extensions=protocol.available_extensions, available_subprotocols=protocol.available_subprotocols, extra_headers=protocol.extra_headers, diff --git a/websockets/py36/_test_client_server.py b/websockets/py36/_test_client_server.py index e81fbd600..693242f13 100644 --- a/websockets/py36/_test_client_server.py +++ b/websockets/py36/_test_client_server.py @@ -12,7 +12,7 @@ # Fail at import time, not just at run time, to prevent test # discovery. -if sys.version_info[:2] < (3, 6): # pragma: no cover +if sys.version_info[:2] < (3, 6): # pragma: no cover raise ImportError("Python 3.6+ only") @@ -32,7 +32,6 @@ def tearDown(self): self.loop.close() def test_iterate_on_messages(self): - async def handler(ws, path): for message in MESSAGES: await ws.send(message) @@ -56,7 +55,6 @@ async def run_client(): self.loop.run_until_complete(server.wait_closed()) def test_iterate_on_messages_going_away_exit_ok(self): - async def handler(ws, path): for message in MESSAGES: await ws.send(message) @@ -81,7 +79,6 @@ async def run_client(): self.loop.run_until_complete(server.wait_closed()) def test_iterate_on_messages_internal_error_exit_not_ok(self): - async def handler(ws, path): for message in MESSAGES: await ws.send(message) diff --git a/websockets/server.py b/websockets/server.py index 27cca4b92..4284d09ad 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -11,18 +11,26 @@ import warnings from .compatibility import ( - BAD_REQUEST, FORBIDDEN, INTERNAL_SERVER_ERROR, SERVICE_UNAVAILABLE, - SWITCHING_PROTOCOLS, UPGRADE_REQUIRED, asyncio_ensure_future + BAD_REQUEST, + FORBIDDEN, + INTERNAL_SERVER_ERROR, + SERVICE_UNAVAILABLE, + SWITCHING_PROTOCOLS, + UPGRADE_REQUIRED, + asyncio_ensure_future, ) from .exceptions import ( - AbortHandshake, InvalidHandshake, InvalidHeader, InvalidMessage, - InvalidOrigin, InvalidUpgrade, NegotiationError + AbortHandshake, + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, ) from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request -from .headers import ( - build_extension_list, parse_extension_list, parse_subprotocol_list -) +from .headers import build_extension_list, parse_extension_list, parse_subprotocol_list from .http import USER_AGENT, Headers, MultipleValuesError, read_request from .protocol import WebSocketCommonProtocol @@ -43,16 +51,24 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): Its support for HTTP responses is very limited. """ + is_client = False side = 'server' - def __init__(self, ws_handler, ws_server, *, - origins=None, extensions=None, subprotocols=None, - extra_headers=None, **kwds): + def __init__( + self, + ws_handler, + ws_server, + *, + origins=None, + extensions=None, + subprotocols=None, + extra_headers=None, + **kwds + ): # For backwards-compatibility with 6.0 or earlier. if origins is not None and '' in origins: - warnings.warn( - "use None instead of '' in origins", DeprecationWarning) + warnings.warn("use None instead of '' in origins", DeprecationWarning) origins = [None if origin == '' else origin for origin in origins] self.ws_handler = ws_handler self.ws_server = ws_server @@ -73,8 +89,7 @@ def connection_made(self, transport): # create a race condition between the creation of the task, which # schedules its execution, and the moment the handler starts running. self.ws_server.register(self) - self.handler_task = asyncio_ensure_future( - self.handler(), loop=self.loop) + self.handler_task = asyncio_ensure_future(self.handler(), loop=self.loop) @asyncio.coroutine def handler(self): @@ -95,8 +110,7 @@ def handler(self): extra_headers=self.extra_headers, ) except ConnectionError as exc: - logger.debug( - "Connection error in opening handshake", exc_info=True) + logger.debug("Connection error in opening handshake", exc_info=True) raise except Exception as exc: if self._is_server_shutting_down(exc): @@ -106,18 +120,10 @@ def handler(self): b"Server is shutting down.\n", ) elif isinstance(exc, AbortHandshake): - status, headers, body = ( - exc.status, - exc.headers, - exc.body, - ) + status, headers, body = exc.status, exc.headers, exc.body elif isinstance(exc, InvalidOrigin): logger.debug("Invalid origin", exc_info=True) - status, headers, body = ( - FORBIDDEN, - [], - (str(exc) + "\n").encode(), - ) + status, headers, body = FORBIDDEN, [], (str(exc) + "\n").encode() elif isinstance(exc, InvalidUpgrade): logger.debug("Invalid upgrade", exc_info=True) status, headers, body = ( @@ -169,8 +175,7 @@ def handler(self): try: yield from self.close() except ConnectionError as exc: - logger.debug( - "Connection error in closing handshake", exc_info=True) + logger.debug("Connection error in closing handshake", exc_info=True) raise except Exception as exc: if not self._is_server_shutting_down(exc): @@ -181,7 +186,7 @@ def handler(self): # Last-ditch attempt to avoid leaking connections on errors. try: self.writer.close() - except Exception: # pragma: no cover + except Exception: # pragma: no cover pass finally: @@ -196,10 +201,7 @@ def _is_server_shutting_down(self, exc): Decide whether an exception means that the server is shutting down. """ - return ( - isinstance(exc, asyncio.CancelledError) and - self.ws_server.closing - ) + return isinstance(exc, asyncio.CancelledError) and self.ws_server.closing @asyncio.coroutine def read_http_request(self): @@ -236,8 +238,7 @@ def write_http_response(self, status, headers, body=None): # Since the status line and headers only contain ASCII characters, # we can keep this simple. - response = 'HTTP/1.1 {status.value} {status.phrase}\r\n'.format( - status=status) + response = 'HTTP/1.1 {status.value} {status.phrase}\r\n'.format(status=status) response += str(headers) self.writer.write(response.encode()) @@ -337,24 +338,24 @@ def process_extensions(headers, available_extensions): if header_values and available_extensions: - parsed_header_values = sum([ - parse_extension_list(header_value) - for header_value in header_values - ], []) + parsed_header_values = sum( + [parse_extension_list(header_value) for header_value in header_values], + [], + ) for name, request_params in parsed_header_values: - for extension_factory in available_extensions: + for ext_factory in available_extensions: # Skip non-matching extensions based on their name. - if extension_factory.name != name: + if ext_factory.name != name: continue # Skip non-matching extensions based on their params. try: - response_params, extension = ( - extension_factory.process_request_params( - request_params, accepted_extensions)) + response_params, extension = ext_factory.process_request_params( + request_params, accepted_extensions + ) except NegotiationError: continue @@ -391,14 +392,16 @@ def process_subprotocol(self, headers, available_subprotocols): if header_values and available_subprotocols: - parsed_header_values = sum([ - parse_subprotocol_list(header_value) - for header_value in header_values - ], []) + parsed_header_values = sum( + [ + parse_subprotocol_list(header_value) + for header_value in header_values + ], + [], + ) subprotocol = self.select_subprotocol( - parsed_header_values, - available_subprotocols, + parsed_header_values, available_subprotocols ) return subprotocol @@ -424,12 +427,18 @@ def select_subprotocol(client_subprotocols, server_subprotocols): if not subprotocols: return None priority = lambda p: ( - client_subprotocols.index(p) + server_subprotocols.index(p)) + client_subprotocols.index(p) + server_subprotocols.index(p) + ) return sorted(subprotocols, key=priority)[0] @asyncio.coroutine - def handshake(self, origins=None, available_extensions=None, - available_subprotocols=None, extra_headers=None): + def handshake( + self, + origins=None, + available_extensions=None, + available_subprotocols=None, + extra_headers=None, + ): """ Perform the server side of the opening handshake. @@ -468,10 +477,12 @@ def handshake(self, origins=None, available_extensions=None, self.origin = self.process_origin(request_headers, origins) extensions_header, self.extensions = self.process_extensions( - request_headers, available_extensions) + request_headers, available_extensions + ) protocol_header = self.subprotocol = self.process_subprotocol( - request_headers, available_subprotocols) + request_headers, available_subprotocols + ) response_headers = Headers() response_headers['Date'] = email.utils.formatdate(usegmt=True) @@ -496,8 +507,7 @@ def handshake(self, origins=None, available_extensions=None, response_headers.setdefault('Server', USER_AGENT) - yield from self.write_http_response( - SWITCHING_PROTOCOLS, response_headers) + yield from self.write_http_response(SWITCHING_PROTOCOLS, response_headers) self.connection_open() @@ -523,6 +533,7 @@ class WebSocketServer: custom :class:`~asyncio.Server` class. """ + def __init__(self, loop): # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop @@ -605,10 +616,10 @@ def wait_closed(self): # depending on how the client behaves and the server is # implemented. yield from asyncio.wait( - [websocket.handler_task for websocket in self.websockets] + - [websocket.close_connection_task - for websocket in self.websockets], - loop=self.loop) + [websocket.handler_task for websocket in self.websockets] + + [websocket.close_connection_task for websocket in self.websockets], + loop=self.loop, + ) yield from self.server.wait_closed() @property @@ -700,15 +711,31 @@ class Serve: """ - def __init__(self, ws_handler, host=None, port=None, *, - path=None, create_protocol=None, - ping_interval=20, ping_timeout=20, - timeout=10, - max_size=2 ** 20, max_queue=2 ** 5, - read_limit=2 ** 16, write_limit=2 ** 16, - loop=None, legacy_recv=False, klass=None, - origins=None, extensions=None, subprotocols=None, - extra_headers=None, compression='deflate', **kwds): + def __init__( + self, + ws_handler, + host=None, + port=None, + *, + path=None, + create_protocol=None, + ping_interval=20, + ping_timeout=20, + timeout=10, + max_size=2 ** 20, + max_queue=2 ** 5, + read_limit=2 ** 16, + write_limit=2 ** 16, + loop=None, + legacy_recv=False, + klass=None, + origins=None, + extensions=None, + subprotocols=None, + extra_headers=None, + compression='deflate', + **kwds + ): # Backwards-compatibility: create_protocol used to be called klass. # In the unlikely event that both are specified, klass is ignored. if create_protocol is None: @@ -728,22 +755,31 @@ def __init__(self, ws_handler, host=None, port=None, *, if extensions is None: extensions = [] if not any( - extension_factory.name == ServerPerMessageDeflateFactory.name - for extension_factory in extensions + ext_factory.name == ServerPerMessageDeflateFactory.name + for ext_factory in extensions ): extensions.append(ServerPerMessageDeflateFactory()) elif compression is not None: raise ValueError("Unsupported compression: {}".format(compression)) factory = lambda: create_protocol( - ws_handler, ws_server, - host=host, port=port, secure=secure, - ping_interval=ping_interval, ping_timeout=ping_timeout, + ws_handler, + ws_server, + host=host, + port=port, + secure=secure, + ping_interval=ping_interval, + ping_timeout=ping_timeout, timeout=timeout, - max_size=max_size, max_queue=max_queue, - read_limit=read_limit, write_limit=write_limit, - loop=loop, legacy_recv=legacy_recv, - origins=origins, extensions=extensions, subprotocols=subprotocols, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + loop=loop, + legacy_recv=legacy_recv, + origins=origins, + extensions=extensions, + subprotocols=subprotocols, extra_headers=extra_headers, ) @@ -757,7 +793,7 @@ def __init__(self, ws_handler, host=None, port=None, *, self.ws_server = ws_server @asyncio.coroutine - def __iter__(self): # pragma: no cover + def __iter__(self): # pragma: no cover server = yield from self._creating_server self.ws_server.wrap(server) return self.ws_server @@ -781,14 +817,17 @@ def unix_serve(ws_handler, path, **kwargs): # We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future # didn't accept arbitrary awaitables until Python 3.5.1. We don't define # __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. -if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover +if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover + @asyncio.coroutine def serve(*args, **kwds): return Serve(*args, **kwds).__iter__() + serve.__doc__ = Serve.__doc__ else: from .py35.server import __aenter__, __aexit__, __await__ + Serve.__aenter__ = __aenter__ Serve.__aexit__ = __aexit__ Serve.__await__ = __await__ diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index b3b15bb30..7f66ea036 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -17,11 +17,15 @@ from .client import * from .compatibility import FORBIDDEN, OK, UNAUTHORIZED from .exceptions import ( - ConnectionClosed, InvalidHandshake, InvalidStatusCode, NegotiationError + ConnectionClosed, + InvalidHandshake, + InvalidStatusCode, + NegotiationError, ) from .extensions.permessage_deflate import ( - ClientPerMessageDeflateFactory, PerMessageDeflate, - ServerPerMessageDeflateFactory + ClientPerMessageDeflateFactory, + PerMessageDeflate, + ServerPerMessageDeflateFactory, ) from .handshake import build_response from .http import USER_AGENT, Headers, read_response @@ -89,6 +93,7 @@ def with_manager(manager, *args, **kwds): Return a decorator that wraps a function with a context manager. """ + def decorate(func): @functools.wraps(func) def _decorate(self, *_args, **_kwds): @@ -130,22 +135,21 @@ def get_server_uri(server, secure=False, resource_name='/', user_info=None): # needed, either use the first socket, or test separately IPv4 and IPv6. server_socket = random.choice(server.sockets) - if server_socket.family == socket.AF_INET6: # pragma: no cover - host, port = server_socket.getsockname()[:2] # (no IPv6 on CI) + if server_socket.family == socket.AF_INET6: # pragma: no cover + host, port = server_socket.getsockname()[:2] # (no IPv6 on CI) host = '[{}]'.format(host) elif server_socket.family == socket.AF_INET: host, port = server_socket.getsockname() elif server_socket.family == socket.AF_UNIX: # The host and port are ignored when connecting to a Unix socket. host, port = 'localhost', 0 - else: # pragma: no cover + else: # pragma: no cover raise ValueError("Expected an IPv6, IPv4, or Unix socket") return '{}://{}{}:{}{}'.format(proto, user_info, host, port, resource_name) class UnauthorizedServerProtocol(WebSocketServerProtocol): - @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a Headers instance (1/3) @@ -153,7 +157,6 @@ def process_request(self, path, request_headers): class ForbiddenServerProtocol(WebSocketServerProtocol): - @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a dict (2/3) @@ -161,7 +164,6 @@ def process_request(self, path, request_headers): class HealthCheckServerProtocol(WebSocketServerProtocol): - @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a list of pairs (3/3) @@ -243,24 +245,25 @@ def start_client(self, resource_name='/', user_info=None, **kwds): # Disable pings by default in tests. kwds.setdefault('ping_interval', None) secure = kwds.get('ssl') is not None - server_uri = get_server_uri( - self.server, secure, resource_name, user_info) + server_uri = get_server_uri(self.server, secure, resource_name, user_info) start_client = connect(server_uri, **kwds) self.client = self.loop.run_until_complete(start_client) def stop_client(self): try: self.loop.run_until_complete( - asyncio.wait_for(self.client.close_connection_task, timeout=1)) - except asyncio.TimeoutError: # pragma: no cover + asyncio.wait_for(self.client.close_connection_task, timeout=1) + ) + except asyncio.TimeoutError: # pragma: no cover self.fail("Client failed to stop") def stop_server(self): self.server.close() try: self.loop.run_until_complete( - asyncio.wait_for(self.server.wait_closed(), timeout=1)) - except asyncio.TimeoutError: # pragma: no cover + asyncio.wait_for(self.server.wait_closed(), timeout=1) + ) + except asyncio.TimeoutError: # pragma: no cover self.fail("Server failed to stop") @contextlib.contextmanager @@ -299,11 +302,9 @@ def test_explicit_event_loop(self): # The way the legacy SSL implementation wraps sockets makes it extremely # hard to write a test for Python 3.4. - @unittest.skipIf( - sys.version_info[:2] <= (3, 4), 'this test requires Python 3.5+') + @unittest.skipIf(sys.version_info[:2] <= (3, 4), 'this test requires Python 3.5+') @with_server() def test_explicit_socket(self): - class TrackedSocket(socket.socket): def __init__(self, *args, **kwargs): self.used_for_read = False @@ -319,7 +320,8 @@ def send(self, *args, **kwargs): return super().send(*args, **kwargs) server_socket = [ - s for s in self.server.sockets if s.family == socket.AF_INET][0] + sock for sock in self.server.sockets if sock.family == socket.AF_INET + ][0] client_socket = TrackedSocket(socket.AF_INET, socket.SOCK_STREAM) client_socket.connect(server_socket.getsockname()) @@ -342,8 +344,7 @@ def send(self, *args, **kwargs): finally: client_socket.close() - @unittest.skipUnless( - hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') def test_unix_socket(self): with tempfile.TemporaryDirectory() as temp_dir: path = bytes(pathlib.Path(temp_dir) / 'websockets') @@ -392,8 +393,7 @@ def test_protocol_path(self): @with_client('/headers', user_info=('user', 'pass')) def test_protocol_basic_auth(self): self.assertEqual( - self.client.request_headers['Authorization'], - 'Basic dXNlcjpwYXNz', + self.client.request_headers['Authorization'], 'Basic dXNlcjpwYXNz' ) @with_server() @@ -489,24 +489,22 @@ def test_protocol_custom_response_user_agent(self): def make_http_request(self, path='/'): # Set url to 'https?://:'. - url = get_server_uri( - self.server, resource_name=path, secure=self.secure) + url = get_server_uri(self.server, resource_name=path, secure=self.secure) url = url.replace('ws', 'http') if self.secure: open_health_check = functools.partial( - urllib.request.urlopen, url, context=self.client_context) + urllib.request.urlopen, url, context=self.client_context + ) else: - open_health_check = functools.partial( - urllib.request.urlopen, url) + open_health_check = functools.partial(urllib.request.urlopen, url) return self.loop.run_in_executor(None, open_health_check) @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): # Making a HTTP request to a HTTP endpoint succeeds. - response = self.loop.run_until_complete( - self.make_http_request('/__health__/')) + response = self.loop.run_until_complete(self.make_http_request('/__health__/')) with contextlib.closing(response): self.assertEqual(response.code, 200) @@ -546,8 +544,11 @@ def assert_client_raises_code(self, status_code): def test_server_create_protocol(self): self.assert_client_raises_code(401) - @with_server(create_protocol=(lambda *args, **kwargs: - UnauthorizedServerProtocol(*args, **kwargs))) + @with_server( + create_protocol=( + lambda *args, **kwargs: UnauthorizedServerProtocol(*args, **kwargs) + ) + ) def test_server_create_protocol_function(self): self.assert_client_raises_code(401) @@ -555,8 +556,9 @@ def test_server_create_protocol_function(self): def test_server_klass(self): self.assert_client_raises_code(401) - @with_server(create_protocol=ForbiddenServerProtocol, - klass=UnauthorizedServerProtocol) + @with_server( + create_protocol=ForbiddenServerProtocol, klass=UnauthorizedServerProtocol + ) def test_server_create_protocol_over_klass(self): self.assert_client_raises_code(403) @@ -566,8 +568,10 @@ def test_client_create_protocol(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client('/path', create_protocol=( - lambda *args, **kwargs: FooClientProtocol(*args, **kwargs))) + @with_client( + '/path', + create_protocol=(lambda *args, **kwargs: FooClientProtocol(*args, **kwargs)), + ) def test_client_create_protocol_function(self): self.assertIsInstance(self.client, FooClientProtocol) @@ -577,8 +581,7 @@ def test_client_klass(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client('/path', create_protocol=BarClientProtocol, - klass=FooClientProtocol) + @with_client('/path', create_protocol=BarClientProtocol, klass=FooClientProtocol) def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) @@ -613,33 +616,26 @@ def test_extension_not_requested(self): @with_server(extensions=[ServerNoOpExtensionFactory([('foo', None)])]) def test_extension_client_rejection(self): with self.assertRaises(NegotiationError): - self.start_client( - '/extensions', - extensions=[ClientNoOpExtensionFactory()], - ) + self.start_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) @with_server( extensions=[ # No match because the client doesn't send client_max_window_bits. ServerPerMessageDeflateFactory(client_max_window_bits=10), ServerPerMessageDeflateFactory(), - ], - ) - @with_client( - '/extensions', - extensions=[ - ClientPerMessageDeflateFactory(), - ], + ] ) + @with_client('/extensions', extensions=[ClientPerMessageDeflateFactory()]) def test_extension_no_match_then_match(self): # The order requested by the client has priority. server_extensions = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_extensions, repr([ - PerMessageDeflate(False, False, 15, 15), - ])) - self.assertEqual(repr(self.client.extensions), repr([ - PerMessageDeflate(False, False, 15, 15), - ])) + self.assertEqual( + server_extensions, repr([PerMessageDeflate(False, False, 15, 15)]) + ) + self.assertEqual( + repr(self.client.extensions), + repr([PerMessageDeflate(False, False, 15, 15)]), + ) @with_server(extensions=[ServerPerMessageDeflateFactory()]) @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) @@ -649,29 +645,23 @@ def test_extension_mismatch(self): self.assertEqual(repr(self.client.extensions), repr([])) @with_server( - extensions=[ - ServerNoOpExtensionFactory(), - ServerPerMessageDeflateFactory(), - ], + extensions=[ServerNoOpExtensionFactory(), ServerPerMessageDeflateFactory()] ) @with_client( '/extensions', - extensions=[ - ClientPerMessageDeflateFactory(), - ClientNoOpExtensionFactory(), - ], + extensions=[ClientPerMessageDeflateFactory(), ClientNoOpExtensionFactory()], ) def test_extension_order(self): # The order requested by the client has priority. server_extensions = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_extensions, repr([ - PerMessageDeflate(False, False, 15, 15), - NoOpExtension(), - ])) - self.assertEqual(repr(self.client.extensions), repr([ - PerMessageDeflate(False, False, 15, 15), - NoOpExtension(), - ])) + self.assertEqual( + server_extensions, + repr([PerMessageDeflate(False, False, 15, 15), NoOpExtension()]), + ) + self.assertEqual( + repr(self.client.extensions), + repr([PerMessageDeflate(False, False, 15, 15), NoOpExtension()]), + ) @with_server(extensions=[ServerNoOpExtensionFactory()]) @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions') @@ -680,8 +670,7 @@ def test_extensions_error(self, _process_extensions): with self.assertRaises(NegotiationError): self.start_client( - '/extensions', - extensions=[ClientPerMessageDeflateFactory()], + '/extensions', extensions=[ClientPerMessageDeflateFactory()] ) @with_server(extensions=[ServerNoOpExtensionFactory()]) @@ -696,19 +685,19 @@ def test_extensions_error_no_extensions(self, _process_extensions): @with_client('/extensions', compression='deflate') def test_compression_deflate(self): server_extensions = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_extensions, repr([ - PerMessageDeflate(False, False, 15, 15), - ])) - self.assertEqual(repr(self.client.extensions), repr([ - PerMessageDeflate(False, False, 15, 15), - ])) + self.assertEqual( + server_extensions, repr([PerMessageDeflate(False, False, 15, 15)]) + ) + self.assertEqual( + repr(self.client.extensions), + repr([PerMessageDeflate(False, False, 15, 15)]), + ) @with_server( extensions=[ ServerPerMessageDeflateFactory( - client_no_context_takeover=True, - server_max_window_bits=10, - ), + client_no_context_takeover=True, server_max_window_bits=10 + ) ], compression='deflate', # overridden by explicit config ) @@ -716,20 +705,19 @@ def test_compression_deflate(self): '/extensions', extensions=[ ClientPerMessageDeflateFactory( - server_no_context_takeover=True, - client_max_window_bits=12, - ), + server_no_context_takeover=True, client_max_window_bits=12 + ) ], compression='deflate', # overridden by explicit config ) def test_compression_deflate_and_explicit_config(self): server_extensions = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_extensions, repr([ - PerMessageDeflate(True, True, 12, 10), - ])) - self.assertEqual(repr(self.client.extensions), repr([ - PerMessageDeflate(True, True, 10, 12), - ])) + self.assertEqual( + server_extensions, repr([PerMessageDeflate(True, True, 12, 10)]) + ) + self.assertEqual( + repr(self.client.extensions), repr([PerMessageDeflate(True, True, 10, 12)]) + ) def test_compression_unsupported_server(self): with self.assertRaises(ValueError): @@ -799,8 +787,7 @@ def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = 'superchat, chat' with self.assertRaises(InvalidHandshake): - self.start_client( - '/subprotocol', subprotocols=['superchat', 'chat']) + self.start_client('/subprotocol', subprotocols=['superchat', 'chat']) self.run_loop_once() @with_server() @@ -825,6 +812,7 @@ def test_client_receives_malformed_response(self, _read_response): def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(headers): return '42' + _build_request.side_effect = wrong_build_request with self.assertRaises(InvalidHandshake): @@ -835,6 +823,7 @@ def wrong_build_request(headers): def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(headers, key): return build_response(headers, '42') + _build_response.side_effect = wrong_build_response with self.assertRaises(InvalidHandshake): @@ -847,6 +836,7 @@ def test_server_does_not_switch_protocols(self, _read_response): def wrong_read_response(stream): status_code, headers = yield from read_response(stream) return 400, headers + _read_response.side_effect = wrong_read_response with self.assertRaises(InvalidStatusCode): @@ -854,8 +844,7 @@ def wrong_read_response(stream): self.run_loop_once() @with_server() - @unittest.mock.patch( - 'websockets.server.WebSocketServerProtocol.process_request') + @unittest.mock.patch('websockets.server.WebSocketServerProtocol.process_request') def test_server_error_in_handshake(self, _process_request): _process_request.side_effect = Exception("process_request crashed") @@ -944,11 +933,12 @@ def test_invalid_status_error_during_client_connect(self): @with_server() @unittest.mock.patch( - 'websockets.server.WebSocketServerProtocol.write_http_response') - @unittest.mock.patch( - 'websockets.server.WebSocketServerProtocol.read_http_request') + 'websockets.server.WebSocketServerProtocol.write_http_response' + ) + @unittest.mock.patch('websockets.server.WebSocketServerProtocol.read_http_request') def test_connection_error_during_opening_handshake( - self, _read_http_request, _write_http_response): + self, _read_http_request, _write_http_response + ): _read_http_request.side_effect = ConnectionError # This exception is currently platform-dependent. It was observed to @@ -996,7 +986,7 @@ def client_context(self): ssl_context.verify_mode = ssl.CERT_REQUIRED # ssl.match_hostname can't match IP addresses on Python < 3.5. # We're using IP addresses to enforce testing of IPv4 and IPv6. - if sys.version_info[:2] >= (3, 5): # pragma: no cover + if sys.version_info[:2] >= (3, 5): # pragma: no cover ssl_context.check_hostname = True return ssl_context @@ -1015,17 +1005,15 @@ def start_client(self, path='/', **kwds): def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): client = connect( - get_server_uri(self.server, secure=False), - ssl=self.client_context, + get_server_uri(self.server, secure=False), ssl=self.client_context ) # With Python ≥ 3.5, the exception is raised by connect() even # before awaiting. However, with Python 3.4 the exception is # raised only when awaiting. - self.loop.run_until_complete(client) # pragma: no cover + self.loop.run_until_complete(client) # pragma: no cover class ClientServerOriginTests(unittest.TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -1035,9 +1023,11 @@ def tearDown(self): def test_checking_origin_succeeds(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=['http://localhost'])) + serve(handler, 'localhost', 0, origins=['http://localhost']) + ) client = self.loop.run_until_complete( - connect(get_server_uri(server), origin='http://localhost')) + connect(get_server_uri(server), origin='http://localhost') + ) self.loop.run_until_complete(client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") @@ -1048,30 +1038,36 @@ def test_checking_origin_succeeds(self): def test_checking_origin_fails(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=['http://localhost'])) - with self.assertRaisesRegex(InvalidHandshake, - "Status code not 101: 403"): + serve(handler, 'localhost', 0, origins=['http://localhost']) + ) + with self.assertRaisesRegex(InvalidHandshake, "Status code not 101: 403"): self.loop.run_until_complete( - connect(get_server_uri(server), origin='http://otherhost')) + connect(get_server_uri(server), origin='http://otherhost') + ) server.close() self.loop.run_until_complete(server.wait_closed()) def test_checking_origins_fails_with_multiple_headers(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=['http://localhost'])) - with self.assertRaisesRegex(InvalidHandshake, - "Status code not 101: 400"): + serve(handler, 'localhost', 0, origins=['http://localhost']) + ) + with self.assertRaisesRegex(InvalidHandshake, "Status code not 101: 400"): self.loop.run_until_complete( - connect(get_server_uri(server), origin='http://localhost', - extra_headers=[('Origin', 'http://otherhost')])) + connect( + get_server_uri(server), + origin='http://localhost', + extra_headers=[('Origin', 'http://otherhost')], + ) + ) server.close() self.loop.run_until_complete(server.wait_closed()) def test_checking_lack_of_origin_succeeds(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=[None])) + serve(handler, 'localhost', 0, origins=[None]) + ) client = self.loop.run_until_complete(connect(get_server_uri(server))) self.loop.run_until_complete(client.send("Hello!")) @@ -1084,9 +1080,9 @@ def test_checking_lack_of_origin_succeeds(self): def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=[''])) - client = self.loop.run_until_complete( - connect(get_server_uri(server))) + serve(handler, 'localhost', 0, origins=['']) + ) + client = self.loop.run_until_complete(connect(get_server_uri(server))) self.assertEqual(len(recorded_warnings), 1) warning = recorded_warnings[0].message @@ -1102,7 +1098,6 @@ def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): class YieldFromTests(unittest.TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -1128,7 +1123,6 @@ def run_client(): self.loop.run_until_complete(server.wait_closed()) def test_server(self): - @asyncio.coroutine def run_server(): # Yield from serve. @@ -1141,10 +1135,10 @@ def run_server(): self.loop.run_until_complete(run_server()) -if sys.version_info[:2] >= (3, 5): # pragma: no cover - from .py35._test_client_server import AsyncAwaitTests # noqa - from .py35._test_client_server import ContextManagerTests # noqa +if sys.version_info[:2] >= (3, 5): # pragma: no cover + from .py35._test_client_server import AsyncAwaitTests # noqa + from .py35._test_client_server import ContextManagerTests # noqa -if sys.version_info[:2] >= (3, 6): # pragma: no cover - from .py36._test_client_server import AsyncIteratorTests # noqa +if sys.version_info[:2] >= (3, 6): # pragma: no cover + from .py36._test_client_server import AsyncIteratorTests # noqa diff --git a/websockets/test_exceptions.py b/websockets/test_exceptions.py index 8092b6d11..49042105c 100644 --- a/websockets/test_exceptions.py +++ b/websockets/test_exceptions.py @@ -5,9 +5,9 @@ class ExceptionsTests(unittest.TestCase): - def test_str(self): for exception, exception_str in [ + # fmt: off ( InvalidHandshake("Invalid request"), "Invalid request", @@ -38,7 +38,8 @@ def test_str(self): ), ( InvalidHeaderFormat( - 'Sec-WebSocket-Protocol', "expected token", 'a=|', 3), + 'Sec-WebSocket-Protocol', "expected token", 'a=|', 3 + ), "Invalid Sec-WebSocket-Protocol header: " "expected token at 3 in a=|", ), @@ -125,6 +126,7 @@ def test_str(self): WebSocketProtocolError("Invalid opcode: 7"), "Invalid opcode: 7", ), + # fmt: on ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) diff --git a/websockets/test_framing.py b/websockets/test_framing.py index d550f7268..0ea0a2851 100644 --- a/websockets/test_framing.py +++ b/websockets/test_framing.py @@ -8,7 +8,6 @@ class FramingTests(unittest.TestCase): - def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -20,10 +19,14 @@ def decode(self, message, mask=False, max_size=None, extensions=None): self.stream = asyncio.StreamReader(loop=self.loop) self.stream.feed_data(message) self.stream.feed_eof() - frame = self.loop.run_until_complete(Frame.read( - self.stream.readexactly, mask=mask, - max_size=max_size, extensions=extensions, - )) + frame = self.loop.run_until_complete( + Frame.read( + self.stream.readexactly, + mask=mask, + max_size=max_size, + extensions=extensions, + ) + ) # Make sure all the data was consumed. self.assertTrue(self.stream.at_eof()) return frame @@ -43,10 +46,10 @@ def round_trip(self, message, expected, mask=False, extensions=None): decoded = self.decode(message, mask, extensions=extensions) self.assertEqual(decoded, expected) encoded = self.encode(decoded, mask, extensions=extensions) - if mask: # non-deterministic encoding + if mask: # non-deterministic encoding decoded = self.decode(encoded, mask, extensions=extensions) self.assertEqual(decoded, expected) - else: # deterministic encoding + else: # deterministic encoding self.assertEqual(encoded, message) def round_trip_close(self, data, code, reason): @@ -56,10 +59,7 @@ def round_trip_close(self, data, code, reason): self.assertEqual(serialized, data) def test_text(self): - self.round_trip( - b'\x81\x04Spam', - Frame(True, OP_TEXT, b'Spam'), - ) + self.round_trip(b'\x81\x04Spam', Frame(True, OP_TEXT, b'Spam')) def test_text_masked(self): self.round_trip( @@ -69,10 +69,7 @@ def test_text_masked(self): ) def test_binary(self): - self.round_trip( - b'\x82\x04Eggs', - Frame(True, OP_BINARY, b'Eggs'), - ) + self.round_trip(b'\x82\x04Eggs', Frame(True, OP_BINARY, b'Eggs')) def test_binary_masked(self): self.round_trip( @@ -83,8 +80,7 @@ def test_binary_masked(self): def test_non_ascii_text(self): self.round_trip( - b'\x81\x05caf\xc3\xa9', - Frame(True, OP_TEXT, 'café'.encode('utf-8')), + b'\x81\x05caf\xc3\xa9', Frame(True, OP_TEXT, 'café'.encode('utf-8')) ) def test_non_ascii_text_masked(self): @@ -95,27 +91,17 @@ def test_non_ascii_text_masked(self): ) def test_close(self): - self.round_trip( - b'\x88\x00', - Frame(True, OP_CLOSE, b''), - ) + self.round_trip(b'\x88\x00', Frame(True, OP_CLOSE, b'')) def test_ping(self): - self.round_trip( - b'\x89\x04ping', - Frame(True, OP_PING, b'ping'), - ) + self.round_trip(b'\x89\x04ping', Frame(True, OP_PING, b'ping')) def test_pong(self): - self.round_trip( - b'\x8a\x04pong', - Frame(True, OP_PONG, b'pong'), - ) + self.round_trip(b'\x8a\x04pong', Frame(True, OP_PONG, b'pong')) def test_long(self): self.round_trip( - b'\x82\x7e\x00\x7e' + 126 * b'a', - Frame(True, OP_BINARY, 126 * b'a'), + b'\x82\x7e\x00\x7e' + 126 * b'a', Frame(True, OP_BINARY, 126 * b'a') ) def test_very_long(self): @@ -126,10 +112,7 @@ def test_very_long(self): def test_payload_too_big(self): with self.assertRaises(PayloadTooBig): - self.decode( - b'\x82\x7e\x04\x01' + 1025 * b'a', - max_size=1024, - ) + self.decode(b'\x82\x7e\x04\x01' + 1025 * b'a', max_size=1024) def test_bad_reserved_bits(self): for encoded in [b'\xc0\x00', b'\xa0\x00', b'\x90\x00']: @@ -141,7 +124,7 @@ def test_good_opcode(self): for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0b)): encoded = bytes([0x80 | opcode, 0]) with self.subTest(encoded=encoded): - self.decode(encoded) # does not raise an exception + self.decode(encoded) # does not raise an exception def test_bad_opcode(self): for opcode in list(range(0x03, 0x08)) + list(range(0x0b, 0x10)): @@ -206,9 +189,7 @@ def test_serialize_close_errors(self): serialize_close(999, '') def test_extensions(self): - class Rot13: - @staticmethod def encode(frame): assert frame.opcode == OP_TEXT @@ -222,7 +203,5 @@ def decode(frame, *, max_size=None): return Rot13.encode(frame) self.round_trip( - b'\x81\x05uryyb', - Frame(True, OP_TEXT, b'hello'), - extensions=[Rot13()], + b'\x81\x05uryyb', Frame(True, OP_TEXT, b'hello'), extensions=[Rot13()] ) diff --git a/websockets/test_handshake.py b/websockets/test_handshake.py index ebdb75b62..bf695a472 100644 --- a/websockets/test_handshake.py +++ b/websockets/test_handshake.py @@ -2,7 +2,10 @@ import unittest from .exceptions import ( - InvalidHandshake, InvalidHeader, InvalidHeaderValue, InvalidUpgrade + InvalidHandshake, + InvalidHeader, + InvalidHeaderValue, + InvalidUpgrade, ) from .handshake import * from .handshake import accept # private API @@ -10,7 +13,6 @@ class HandshakeTests(unittest.TestCase): - def test_accept(self): # Test vector from RFC 6455 key = "dGhlIHNhbXBsZSBub25jZQ==" @@ -132,8 +134,7 @@ def assertValidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): check_response(headers, key) @contextlib.contextmanager - def assertInvalidResponseHeaders( - self, exc_type, key='CSIRmL8dWYxeAdr/XpEHRw=='): + def assertInvalidResponseHeaders(self, exc_type, key='CSIRmL8dWYxeAdr/XpEHRw=='): """ Provide response headers for modification. diff --git a/websockets/test_headers.py b/websockets/test_headers.py index 10c2a7fd8..f85c9b044 100644 --- a/websockets/test_headers.py +++ b/websockets/test_headers.py @@ -6,32 +6,19 @@ class HeadersTests(unittest.TestCase): - def test_parse_connection(self): for header, parsed in [ # Realistic use cases - ( - 'Upgrade', # Safari, Chrome - ['Upgrade'], - ), - ( - 'keep-alive, Upgrade', # Firefox - ['keep-alive', 'Upgrade'], - ), + ('Upgrade', ['Upgrade']), # Safari, Chrome + ('keep-alive, Upgrade', ['keep-alive', 'Upgrade']), # Firefox # Pathological example - ( - ',,\t, , ,Upgrade ,,', - ['Upgrade'], - ), + (',,\t, , ,Upgrade ,,', ['Upgrade']), ]: with self.subTest(header=header): self.assertEqual(parse_connection(header), parsed) def test_parse_connection_invalid_header(self): - for header in [ - '???', - 'keep-alive; Upgrade', - ]: + for header in ['???', 'keep-alive; Upgrade']: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_connection(header) @@ -39,30 +26,17 @@ def test_parse_connection_invalid_header(self): def test_parse_upgrade(self): for header, parsed in [ # Realistic use case - ( - 'websocket', - ['websocket'], - ), + ('websocket', ['websocket']), # Synthetic example - ( - 'http/3.0, websocket', - ['http/3.0', 'websocket'] - ), + ('http/3.0, websocket', ['http/3.0', 'websocket']), # Pathological example - ( - ',, WebSocket, \t,,', - ['WebSocket'], - ), + (',, WebSocket, \t,,', ['WebSocket']), ]: with self.subTest(header=header): self.assertEqual(parse_upgrade(header), parsed) def test_parse_upgrade_invalid_header(self): - for header in [ - '???', - 'websocket 2', - 'http/3.0; websocket', - ]: + for header in ['???', 'websocket 2', 'http/3.0; websocket']: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_upgrade(header) @@ -70,20 +44,20 @@ def test_parse_upgrade_invalid_header(self): def test_parse_extension_list(self): for header, parsed in [ # Synthetic examples - ( - 'foo', - [('foo', [])], - ), - ( - 'foo, bar', - [('foo', []), ('bar', [])], - ), + ('foo', [('foo', [])]), + ('foo, bar', [('foo', []), ('bar', [])]), ( 'foo; name; token=token; quoted-string="quoted-string", ' 'bar; quux; quuux', [ - ('foo', [('name', None), ('token', 'token'), - ('quoted-string', 'quoted-string')]), + ( + 'foo', + [ + ('name', None), + ('token', 'token'), + ('quoted-string', 'quoted-string'), + ], + ), ('bar', [('quux', None), ('quuux', None)]), ], ), @@ -93,10 +67,7 @@ def test_parse_extension_list(self): [('foo', [('bar', '42')]), ('baz', [])], ), # Realistic use cases for permessage-deflate - ( - 'permessage-deflate', - [('permessage-deflate', [])], - ), + ('permessage-deflate', [('permessage-deflate', [])]), ( 'permessage-deflate; client_max_window_bits', [('permessage-deflate', [('client_max_window_bits', None)])], @@ -116,7 +87,7 @@ def test_parse_extension_list_invalid_header(self): for header in [ # Truncated examples '', - ',\t,' + ',\t,', 'foo;', 'foo; bar;', 'foo; bar=', @@ -133,19 +104,10 @@ def test_parse_extension_list_invalid_header(self): def test_parse_subprotocol_list(self): for header, parsed in [ # Synthetic examples - ( - 'foo', - ['foo'], - ), - ( - 'foo, bar', - ['foo', 'bar'], - ), + ('foo', ['foo']), + ('foo, bar', ['foo', 'bar']), # Pathological example - ( - ',\t, , ,foo ,, bar,baz,,', - ['foo', 'bar', 'baz'], - ), + (',\t, , ,foo ,, bar,baz,,', ['foo', 'bar', 'baz']), ]: with self.subTest(header=header): self.assertEqual(parse_subprotocol_list(header), parsed) diff --git a/websockets/test_http.py b/websockets/test_http.py index 01ae6de71..371bafc16 100644 --- a/websockets/test_http.py +++ b/websockets/test_http.py @@ -6,7 +6,6 @@ class HTTPAsyncTests(unittest.TestCase): - def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() @@ -30,8 +29,7 @@ def test_read_request(self): b'Sec-WebSocket-Version: 13\r\n' b'\r\n' ) - path, headers = self.loop.run_until_complete( - read_request(self.stream)) + path, headers = self.loop.run_until_complete(read_request(self.stream)) self.assertEqual(path, '/chat') self.assertEqual(headers['Upgrade'], 'websocket') @@ -45,8 +43,7 @@ def test_read_response(self): b'Sec-WebSocket-Protocol: chat\r\n' b'\r\n' ) - status_code, headers = self.loop.run_until_complete( - read_response(self.stream)) + status_code, headers = self.loop.run_until_complete(read_response(self.stream)) self.assertEqual(status_code, 101) self.assertEqual(headers['Upgrade'], 'websocket') @@ -103,12 +100,8 @@ def test_line_ending(self): class HeadersTests(unittest.TestCase): - def setUp(self): - self.headers = Headers([ - ('Connection', 'Upgrade'), - ('Server', USER_AGENT), - ]) + self.headers = Headers([('Connection', 'Upgrade'), ('Server', USER_AGENT)]) def test_str(self): self.assertEqual( @@ -124,14 +117,8 @@ def test_repr(self): ) def test_multiple_values_error_str(self): - self.assertEqual( - str(MultipleValuesError('Connection')), - "'Connection'", - ) - self.assertEqual( - str(MultipleValuesError()), - "", - ) + self.assertEqual(str(MultipleValuesError('Connection')), "'Connection'") + self.assertEqual(str(MultipleValuesError()), "") def test_contains(self): self.assertIn('Server', self.headers) @@ -215,14 +202,10 @@ def test_get_all_no_values(self): def test_get_all_multiple_values(self): self.headers['Connection'] = 'close' - self.assertEqual( - self.headers.get_all('Connection'), ['Upgrade', 'close']) + self.assertEqual(self.headers.get_all('Connection'), ['Upgrade', 'close']) def test_raw_items(self): self.assertEqual( list(self.headers.raw_items()), - [ - ('Connection', 'Upgrade'), - ('Server', USER_AGENT), - ], + [('Connection', 'Upgrade'), ('Server', USER_AGENT)], ) diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py index 3c934a41d..e1ce1e3e0 100644 --- a/websockets/test_protocol.py +++ b/websockets/test_protocol.py @@ -22,7 +22,7 @@ MS = 0.001 * int(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1)) # asyncio's debug mode has a 10x performance penalty for this test suite. -if os.environ.get('PYTHONASYNCIODEBUG'): # pragma: no cover +if os.environ.get('PYTHONASYNCIODEBUG'): # pragma: no cover MS *= 10 # Ensure that timeouts are larger than the clock's resolution (for Windows). @@ -45,6 +45,7 @@ class TransportMock(unittest.mock.Mock): They could also pause_writing and resume_writing to test flow control. """ + # This should happen in __init__ but overriding Mock.__init__ is hard. def setup_mock(self, loop, protocol): self.loop = loop @@ -88,6 +89,7 @@ class CommonTests: Tests are run by the ServerTests and ClientTests subclasses. """ + def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() @@ -196,8 +198,8 @@ def half_close_connection_local(self, code=1000, reason='close'): close_frame_data = serialize_close(code, reason) # Trigger the closing handshake from the local endpoint. close_task = self.ensure_future(self.protocol.close(code, reason)) - self.run_loop_once() # wait_for executes - self.run_loop_once() # write_frame executes + self.run_loop_once() # wait_for executes + self.run_loop_once() # write_frame executes # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) @@ -206,7 +208,8 @@ def half_close_connection_local(self, code=1000, reason='close'): # Complete the closing sequence at 1ms intervals so the test can run # at each point even it goes back to the event loop several times. self.loop.call_later( - MS, self.receive_frame, Frame(True, OP_CLOSE, close_frame_data)) + MS, self.receive_frame, Frame(True, OP_CLOSE, close_frame_data) + ) self.loop.call_later(2 * MS, self.receive_eof_if_client) # This task must be awaited or canceled by the caller. @@ -230,7 +233,7 @@ def half_close_connection_remote(self, code=1000, reason='close'): close_frame_data = serialize_close(code, reason) # Trigger the closing handshake from the remote endpoint. self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) - self.run_loop_once() # read_frame executes + self.run_loop_once() # read_frame executes # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) @@ -270,10 +273,11 @@ def last_sent_frame(self): if stream.at_eof(): frame = None else: - frame = self.loop.run_until_complete(Frame.read( - stream.readexactly, mask=self.protocol.is_client)) + frame = self.loop.run_until_complete( + Frame.read(stream.readexactly, mask=self.protocol.is_client) + ) - if not stream.at_eof(): # pragma: no cover + if not stream.at_eof(): # pragma: no cover data = self.loop.run_until_complete(stream.read()) raise AssertionError("Trailing data found: {!r}".format(data)) @@ -302,8 +306,7 @@ def assertConnectionFailed(self, code, message): if code == 1006: self.assertNoFrameSent() else: - self.assertOneFrameSent( - True, OP_CLOSE, serialize_close(code, message)) + self.assertOneFrameSent(True, OP_CLOSE, serialize_close(code, message)) @contextlib.contextmanager def assertCompletesWithin(self, min_time, max_time): @@ -311,10 +314,8 @@ def assertCompletesWithin(self, min_time, max_time): yield t1 = self.loop.time() dt = t1 - t0 - self.assertGreaterEqual( - dt, min_time, "Too fast: {} < {}".format(dt, min_time)) - self.assertLess( - dt, max_time, "Too slow: {} >= {}".format(dt, max_time)) + self.assertGreaterEqual(dt, min_time, "Too fast: {} < {}".format(dt, min_time)) + self.assertLess(dt, max_time, "Too slow: {} >= {}".format(dt, max_time)) # Test public attributes. @@ -378,7 +379,7 @@ def test_recv_on_closing_connection_local(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) - self.loop.run_until_complete(close_task) # cleanup + self.loop.run_until_complete(close_task) # cleanup def test_recv_on_closing_connection_remote(self): self.half_close_connection_remote() @@ -415,13 +416,13 @@ def test_recv_binary_payload_too_big(self): self.assertConnectionFailed(1009, '') def test_recv_text_no_max_size(self): - self.protocol.max_size = None # for test coverage + self.protocol.max_size = None # for test coverage self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café' * 205) def test_recv_binary_no_max_size(self): - self.protocol.max_size = None # for test coverage + self.protocol.max_size = None # for test coverage self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea' * 342) @@ -430,6 +431,7 @@ def test_recv_other_error(self): @asyncio.coroutine def read_message(): raise Exception("BOOM") + self.protocol.read_message = read_message self.process_invalid_frames() self.assertConnectionFailed(1011, '') @@ -468,7 +470,7 @@ def test_send_on_closing_connection_local(self): self.assertNoFrameSent() - self.loop.run_until_complete(close_task) # cleanup + self.loop.run_until_complete(close_task) # cleanup def test_send_on_closing_connection_remote(self): self.half_close_connection_remote() @@ -518,7 +520,7 @@ def test_ping_on_closing_connection_local(self): self.assertNoFrameSent() - self.loop.run_until_complete(close_task) # cleanup + self.loop.run_until_complete(close_task) # cleanup def test_ping_on_closing_connection_remote(self): self.half_close_connection_remote() @@ -563,7 +565,7 @@ def test_pong_on_closing_connection_local(self): self.assertNoFrameSent() - self.loop.run_until_complete(close_task) # cleanup + self.loop.run_until_complete(close_task) # cleanup def test_pong_on_closing_connection_remote(self): self.half_close_connection_remote() @@ -612,10 +614,10 @@ def test_cancel_ping(self): self.assertTrue(ping.cancelled()) def test_acknowledge_previous_pings(self): - pings = [( - self.loop.run_until_complete(self.protocol.ping()), - self.last_sent_frame(), - ) for i in range(3)] + pings = [ + (self.loop.run_until_complete(self.protocol.ping()), self.last_sent_frame()) + for i in range(3) + ] # Unsolicited pong doesn't acknowledge pings self.receive_frame(Frame(True, OP_PONG, b'')) self.run_loop_once() @@ -678,14 +680,14 @@ def test_fragmented_binary_payload_too_big(self): self.assertConnectionFailed(1009, '') def test_fragmented_text_no_max_size(self): - self.protocol.max_size = None # for test coverage + self.protocol.max_size = None # for test coverage self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café' * 205) def test_fragmented_binary_no_max_size(self): - self.protocol.max_size = None # for test coverage + self.protocol.max_size = None # for test coverage self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) data = self.loop.run_until_complete(self.protocol.recv()) @@ -774,14 +776,16 @@ def test_connection_closed_attributes(self): # Test the protocol logic for sending keepalive pings. def restart_protocol_with_keepalive_ping( - self, ping_interval=3 * MS, ping_timeout=3 * MS): + self, ping_interval=3 * MS, ping_timeout=3 * MS + ): initial_protocol = self.protocol # copied from tearDown self.transport.close() self.loop.run_until_complete(self.protocol.close()) # copied from setUp, but enables keepalive pings self.protocol = WebSocketCommonProtocol( - ping_interval=ping_interval, ping_timeout=ping_timeout) + ping_interval=ping_interval, ping_timeout=ping_timeout + ) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) self.protocol.is_client = initial_protocol.is_client @@ -830,7 +834,7 @@ def test_keepalive_ping_stops_when_connection_closing(self): # The keepalive ping task terminated. self.assertTrue(self.protocol.keepalive_ping_task.cancelled()) - self.loop.run_until_complete(close_task) # cleanup + self.loop.run_until_complete(close_task) # cleanup def test_keepalive_ping_stops_when_connection_closed(self): self.restart_protocol_with_keepalive_ping() @@ -869,6 +873,7 @@ def test_keepalive_ping_unexpected_error(self): @asyncio.coroutine def ping(): raise Exception("BOOM") + self.protocol.ping = ping # The keepalive ping task fails when sending a ping at 3ms. @@ -993,7 +998,6 @@ def test_remote_close_during_send(self): class ServerTests(CommonTests, unittest.TestCase): - def setUp(self): super().setUp() self.protocol.is_client = False @@ -1045,7 +1049,6 @@ def test_local_close_connection_lost_timeout_after_close(self): class ClientTests(CommonTests, unittest.TestCase): - def setUp(self): super().setUp() self.protocol.is_client = True diff --git a/websockets/test_uri.py b/websockets/test_uri.py index 86e305ae2..f8f9e042f 100644 --- a/websockets/test_uri.py +++ b/websockets/test_uri.py @@ -5,26 +5,11 @@ VALID_URIS = [ - ( - 'ws://localhost/', - (False, 'localhost', 80, '/', None), - ), - ( - 'wss://localhost/', - (True, 'localhost', 443, '/', None), - ), - ( - 'ws://localhost/path?query', - (False, 'localhost', 80, '/path?query', None), - ), - ( - 'WS://LOCALHOST/PATH?QUERY', - (False, 'localhost', 80, '/PATH?QUERY', None), - ), - ( - 'ws://user:pass@localhost/', - (False, 'localhost', 80, '/', ('user', 'pass')), - ), + ('ws://localhost/', (False, 'localhost', 80, '/', None)), + ('wss://localhost/', (True, 'localhost', 443, '/', None)), + ('ws://localhost/path?query', (False, 'localhost', 80, '/path?query', None)), + ('WS://LOCALHOST/PATH?QUERY', (False, 'localhost', 80, '/PATH?QUERY', None)), + ('ws://user:pass@localhost/', (False, 'localhost', 80, '/', ('user', 'pass'))), ] INVALID_URIS = [ @@ -35,7 +20,6 @@ class URITests(unittest.TestCase): - def test_success(self): for uri, parsed in VALID_URIS: with self.subTest(uri=uri): diff --git a/websockets/test_utils.py b/websockets/test_utils.py index 9611ee777..13ce9ff99 100644 --- a/websockets/test_utils.py +++ b/websockets/test_utils.py @@ -4,7 +4,6 @@ class UtilsTests(unittest.TestCase): - @staticmethod def apply_mask(*args, **kwargs): return py_apply_mask(*args, **kwargs) @@ -20,11 +19,7 @@ def test_apply_mask(self): self.assertEqual(self.apply_mask(data_in, mask), data_out) def test_apply_mask_check_input_types(self): - for data_in, mask in [ - (None, None), - (b'abcd', None), - (None, b'abcd'), - ]: + for data_in, mask in [(None, None), (b'abcd', None), (None, b'abcd')]: with self.subTest(data_in=data_in, mask=mask): with self.assertRaises(TypeError): self.apply_mask(data_in, mask) @@ -43,11 +38,11 @@ def test_apply_mask_check_mask_length(self): try: from .speedups import apply_mask as c_apply_mask -except ImportError: # pragma: no cover +except ImportError: # pragma: no cover pass else: - class SpeedupsTests(UtilsTests): + class SpeedupsTests(UtilsTests): @staticmethod def apply_mask(*args, **kwargs): return c_apply_mask(*args, **kwargs) diff --git a/websockets/uri.py b/websockets/uri.py index 21f757f8a..d793fc6aa 100644 --- a/websockets/uri.py +++ b/websockets/uri.py @@ -15,7 +15,8 @@ __all__ = ['parse_uri', 'WebSocketURI'] WebSocketURI = collections.namedtuple( - 'WebSocketURI', ['secure', 'host', 'port', 'resource_name', 'user_info']) + 'WebSocketURI', ['secure', 'host', 'port', 'resource_name', 'user_info'] +) WebSocketURI.__doc__ = """WebSocket URI. * ``secure`` is the secure flag From 013583837fed38ac787fb5c4bc08899383b74cd9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 23 Aug 2018 12:43:34 +0200 Subject: [PATCH 0467/1539] Wheels work on Python 3.7 too. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 2fa9d1394..2e14eee38 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py34.py35.py36 +python-tag = py34.py35.py36.py37 [flake8] ignore = E731,F403,F405,W503 From e5332e93d31827f71baae976aaad244097cb64a7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 23 Aug 2018 14:56:00 +0200 Subject: [PATCH 0468/1539] Add LICENSE to wheels. --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index 2e14eee38..b029a684c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,9 @@ [bdist_wheel] python-tag = py34.py35.py36.py37 +[metadata] +license_file = LICENSE + [flake8] ignore = E731,F403,F405,W503 max-line-length = 88 From 638df3ac16669ecc711cb585c0e50bab55d8f89f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 23 Aug 2018 12:42:38 +0200 Subject: [PATCH 0469/1539] Improve project layout. This allows tox to run tests against an installed version rather than the local checkout. Also stop running tests with and without the C extension. The tests in test_utils.py are sufficient. --- .appveyor.yml | 2 +- .circleci/config.yml | 8 ++--- .travis.yml | 2 +- MANIFEST.in | 1 - Makefile | 6 ++-- setup.cfg | 12 +++++++ setup.py | 7 ++-- {websockets => src/websockets}/__init__.py | 0 {websockets => src/websockets}/__main__.py | 0 {websockets => src/websockets}/client.py | 0 .../websockets}/compatibility.py | 0 {websockets => src/websockets}/exceptions.py | 0 .../websockets}/extensions/__init__.py | 0 .../websockets}/extensions/base.py | 0 .../extensions/permessage_deflate.py | 0 {websockets => src/websockets}/framing.py | 0 {websockets => src/websockets}/handshake.py | 0 {websockets => src/websockets}/headers.py | 0 {websockets => src/websockets}/http.py | 0 {websockets => src/websockets}/protocol.py | 0 .../websockets}/py35/__init__.py | 0 {websockets => src/websockets}/py35/client.py | 0 {websockets => src/websockets}/py35/server.py | 0 .../websockets}/py36/__init__.py | 0 .../websockets}/py36/protocol.py | 0 {websockets => src/websockets}/server.py | 0 {websockets => src/websockets}/speedups.c | 0 {websockets => src/websockets}/uri.py | 0 {websockets => src/websockets}/utils.py | 0 {websockets => src/websockets}/version.py | 0 .../test_speedups.py => tests/__init__.py | 0 tests/extensions/__init__.py | 0 {websockets => tests}/extensions/test_base.py | 2 +- .../extensions/test_permessage_deflate.py | 6 ++-- tests/py35/__init__.py | 0 .../py35/_test_client_server.py | 7 ++-- tests/py36/__init__.py | 0 .../py36/_test_client_server.py | 7 ++-- {websockets => tests}/test_client_server.py | 17 ++++----- {websockets => tests}/test_exceptions.py | 4 +-- {websockets => tests}/test_framing.py | 4 +-- {websockets => tests}/test_handshake.py | 8 ++--- {websockets => tests}/test_headers.py | 6 ++-- {websockets => tests}/test_http.py | 4 +-- {websockets => tests}/test_localhost.cnf | 0 {websockets => tests}/test_localhost.pem | 0 {websockets => tests}/test_protocol.py | 8 ++--- tests/test_speedups.py | 0 {websockets => tests}/test_uri.py | 4 +-- {websockets => tests}/test_utils.py | 4 +-- tox.ini | 35 ++++--------------- 51 files changed, 74 insertions(+), 80 deletions(-) rename {websockets => src/websockets}/__init__.py (100%) rename {websockets => src/websockets}/__main__.py (100%) rename {websockets => src/websockets}/client.py (100%) rename {websockets => src/websockets}/compatibility.py (100%) rename {websockets => src/websockets}/exceptions.py (100%) rename {websockets => src/websockets}/extensions/__init__.py (100%) rename {websockets => src/websockets}/extensions/base.py (100%) rename {websockets => src/websockets}/extensions/permessage_deflate.py (100%) rename {websockets => src/websockets}/framing.py (100%) rename {websockets => src/websockets}/handshake.py (100%) rename {websockets => src/websockets}/headers.py (100%) rename {websockets => src/websockets}/http.py (100%) rename {websockets => src/websockets}/protocol.py (100%) rename {websockets => src/websockets}/py35/__init__.py (100%) rename {websockets => src/websockets}/py35/client.py (100%) rename {websockets => src/websockets}/py35/server.py (100%) rename {websockets => src/websockets}/py36/__init__.py (100%) rename {websockets => src/websockets}/py36/protocol.py (100%) rename {websockets => src/websockets}/server.py (100%) rename {websockets => src/websockets}/speedups.c (100%) rename {websockets => src/websockets}/uri.py (100%) rename {websockets => src/websockets}/utils.py (100%) rename {websockets => src/websockets}/version.py (100%) rename websockets/test_speedups.py => tests/__init__.py (100%) create mode 100644 tests/extensions/__init__.py rename {websockets => tests}/extensions/test_base.py (53%) rename {websockets => tests}/extensions/test_permessage_deflate.py (99%) create mode 100644 tests/py35/__init__.py rename {websockets => tests}/py35/_test_client_server.py (97%) create mode 100644 tests/py36/__init__.py rename {websockets => tests}/py36/_test_client_server.py (96%) rename {websockets => tests}/test_client_server.py (99%) rename {websockets => tests}/test_exceptions.py (98%) rename {websockets => tests}/test_framing.py (98%) rename {websockets => tests}/test_handshake.py (97%) rename {websockets => tests}/test_headers.py (97%) rename {websockets => tests}/test_http.py (99%) rename {websockets => tests}/test_localhost.cnf (100%) rename {websockets => tests}/test_localhost.pem (100%) rename {websockets => tests}/test_protocol.py (99%) create mode 100644 tests/test_speedups.py rename {websockets => tests}/test_uri.py (92%) rename {websockets => tests}/test_utils.py (92%) diff --git a/.appveyor.yml b/.appveyor.yml index 77e07e9b7..461ff5ced 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -7,7 +7,7 @@ skip_branch_with_pr: true environment: # websockets only works on Python >= 3.4. CIBW_SKIP: cp27-* cp33-* - CIBW_TEST_COMMAND: python -W default -m unittest websockets + CIBW_TEST_COMMAND: python -W default -m unittest WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 install: diff --git a/.circleci/config.yml b/.circleci/config.yml index fbcc172d1..f0ca45b21 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -19,7 +19,7 @@ jobs: - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - - run: tox -e py34,py34-speedups + - run: tox -e py34 py35: docker: - image: circleci/python:3.5 @@ -28,7 +28,7 @@ jobs: - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - - run: tox -e py35,py35-speedups + - run: tox -e py35 py36: docker: - image: circleci/python:3.6 @@ -37,7 +37,7 @@ jobs: - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - - run: tox -e py36,py36-speedups + - run: tox -e py36 py37: docker: - image: circleci/python:3.7 @@ -46,7 +46,7 @@ jobs: - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox - - run: tox -e py37,py37-speedups + - run: tox -e py37 workflows: version: 2 diff --git a/.travis.yml b/.travis.yml index b66c0f5b7..3d6dd2089 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,7 @@ env: global: # websockets only works on Python >= 3.4. - CIBW_SKIP="cp27-* cp33-*" - - CIBW_TEST_COMMAND="python3 -W default -m unittest websockets" + - CIBW_TEST_COMMAND="python3 -W default -m unittest" - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 matrix: diff --git a/MANIFEST.in b/MANIFEST.in index d0cff2af7..1aba38f67 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1 @@ include LICENSE -include websockets/test_localhost.pem diff --git a/Makefile b/Makefile index 3282fd1e1..5d729181b 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,16 @@ export PYTHONASYNCIODEBUG=1 +export PYTHONPATH=src style: - black --skip-string-normalization websockets + isort --recursive src tests + black --skip-string-normalization src tests test: python -W default -m unittest coverage: python -m coverage erase - python -W default -m coverage run --branch --omit=websockets/__main__.py --source=websockets -m unittest + python -W default -m coverage run -m unittest python -m coverage html clean: diff --git a/setup.cfg b/setup.cfg index b029a684c..ad3af102f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,3 +16,15 @@ known_standard_library = asyncio line_length = 88 lines_after_imports = 2 multi_line_output = 3 + +[coverage:run] +branch = True +omit = */__main__.py +source = + websockets + tests + +[coverage:paths] +source = + src/websockets + .tox/*/lib/python*/site-packages/websockets diff --git a/setup.py b/setup.py index 7abafd399..63401d49d 100644 --- a/setup.py +++ b/setup.py @@ -14,8 +14,8 @@ long_description = f.read() # When dropping Python < 3.5, change to: -# exec((root_dir / 'websockets' / 'version.py').read_text(encoding='utf-8')) -with (root_dir / 'websockets' / 'version.py').open(encoding='utf-8') as f: +# exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) +with (root_dir / 'src' / 'websockets' / 'version.py').open(encoding='utf-8') as f: exec(f.read()) py_version = sys.version_info[:2] @@ -34,7 +34,7 @@ ext_modules = [ setuptools.Extension( 'websockets.speedups', - sources=['websockets/speedups.c'], + sources=['src/websockets/speedups.c'], optional=not (root_dir / '.cibuildwheel').exists(), ) ] @@ -61,6 +61,7 @@ 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', ], + package_dir = {'': 'src'}, packages=packages, ext_modules=ext_modules, include_package_data=True, diff --git a/websockets/__init__.py b/src/websockets/__init__.py similarity index 100% rename from websockets/__init__.py rename to src/websockets/__init__.py diff --git a/websockets/__main__.py b/src/websockets/__main__.py similarity index 100% rename from websockets/__main__.py rename to src/websockets/__main__.py diff --git a/websockets/client.py b/src/websockets/client.py similarity index 100% rename from websockets/client.py rename to src/websockets/client.py diff --git a/websockets/compatibility.py b/src/websockets/compatibility.py similarity index 100% rename from websockets/compatibility.py rename to src/websockets/compatibility.py diff --git a/websockets/exceptions.py b/src/websockets/exceptions.py similarity index 100% rename from websockets/exceptions.py rename to src/websockets/exceptions.py diff --git a/websockets/extensions/__init__.py b/src/websockets/extensions/__init__.py similarity index 100% rename from websockets/extensions/__init__.py rename to src/websockets/extensions/__init__.py diff --git a/websockets/extensions/base.py b/src/websockets/extensions/base.py similarity index 100% rename from websockets/extensions/base.py rename to src/websockets/extensions/base.py diff --git a/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py similarity index 100% rename from websockets/extensions/permessage_deflate.py rename to src/websockets/extensions/permessage_deflate.py diff --git a/websockets/framing.py b/src/websockets/framing.py similarity index 100% rename from websockets/framing.py rename to src/websockets/framing.py diff --git a/websockets/handshake.py b/src/websockets/handshake.py similarity index 100% rename from websockets/handshake.py rename to src/websockets/handshake.py diff --git a/websockets/headers.py b/src/websockets/headers.py similarity index 100% rename from websockets/headers.py rename to src/websockets/headers.py diff --git a/websockets/http.py b/src/websockets/http.py similarity index 100% rename from websockets/http.py rename to src/websockets/http.py diff --git a/websockets/protocol.py b/src/websockets/protocol.py similarity index 100% rename from websockets/protocol.py rename to src/websockets/protocol.py diff --git a/websockets/py35/__init__.py b/src/websockets/py35/__init__.py similarity index 100% rename from websockets/py35/__init__.py rename to src/websockets/py35/__init__.py diff --git a/websockets/py35/client.py b/src/websockets/py35/client.py similarity index 100% rename from websockets/py35/client.py rename to src/websockets/py35/client.py diff --git a/websockets/py35/server.py b/src/websockets/py35/server.py similarity index 100% rename from websockets/py35/server.py rename to src/websockets/py35/server.py diff --git a/websockets/py36/__init__.py b/src/websockets/py36/__init__.py similarity index 100% rename from websockets/py36/__init__.py rename to src/websockets/py36/__init__.py diff --git a/websockets/py36/protocol.py b/src/websockets/py36/protocol.py similarity index 100% rename from websockets/py36/protocol.py rename to src/websockets/py36/protocol.py diff --git a/websockets/server.py b/src/websockets/server.py similarity index 100% rename from websockets/server.py rename to src/websockets/server.py diff --git a/websockets/speedups.c b/src/websockets/speedups.c similarity index 100% rename from websockets/speedups.c rename to src/websockets/speedups.c diff --git a/websockets/uri.py b/src/websockets/uri.py similarity index 100% rename from websockets/uri.py rename to src/websockets/uri.py diff --git a/websockets/utils.py b/src/websockets/utils.py similarity index 100% rename from websockets/utils.py rename to src/websockets/utils.py diff --git a/websockets/version.py b/src/websockets/version.py similarity index 100% rename from websockets/version.py rename to src/websockets/version.py diff --git a/websockets/test_speedups.py b/tests/__init__.py similarity index 100% rename from websockets/test_speedups.py rename to tests/__init__.py diff --git a/tests/extensions/__init__.py b/tests/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/websockets/extensions/test_base.py b/tests/extensions/test_base.py similarity index 53% rename from websockets/extensions/test_base.py rename to tests/extensions/test_base.py index 9dd15c857..ba8657b65 100644 --- a/websockets/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,4 +1,4 @@ -from .base import * # noqa +from websockets.extensions.base import * # noqa # Abstract classes don't provide any behavior to test. diff --git a/websockets/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py similarity index 99% rename from websockets/extensions/test_permessage_deflate.py rename to tests/extensions/test_permessage_deflate.py index 67dec5af2..0b7b78eae 100644 --- a/websockets/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,14 +1,15 @@ import unittest import zlib -from ..exceptions import ( +from websockets.exceptions import ( DuplicateParameter, InvalidParameterName, InvalidParameterValue, NegotiationError, PayloadTooBig, ) -from ..framing import ( +from websockets.extensions.permessage_deflate import * +from websockets.framing import ( OP_BINARY, OP_CLOSE, OP_CONT, @@ -18,7 +19,6 @@ Frame, serialize_close, ) -from .permessage_deflate import * class ExtensionTestsMixin: diff --git a/tests/py35/__init__.py b/tests/py35/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/websockets/py35/_test_client_server.py b/tests/py35/_test_client_server.py similarity index 97% rename from websockets/py35/_test_client_server.py rename to tests/py35/_test_client_server.py index 7e7218247..46e9111a5 100644 --- a/websockets/py35/_test_client_server.py +++ b/tests/py35/_test_client_server.py @@ -7,9 +7,10 @@ import tempfile import unittest -from ..client import * -from ..protocol import State -from ..server import * +from websockets.client import * +from websockets.protocol import State +from websockets.server import * + from ..test_client_server import get_server_uri, handler diff --git a/tests/py36/__init__.py b/tests/py36/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/websockets/py36/_test_client_server.py b/tests/py36/_test_client_server.py similarity index 96% rename from websockets/py36/_test_client_server.py rename to tests/py36/_test_client_server.py index 693242f13..f38fbe6f6 100644 --- a/websockets/py36/_test_client_server.py +++ b/tests/py36/_test_client_server.py @@ -4,9 +4,10 @@ import sys import unittest -from ..client import * -from ..exceptions import ConnectionClosed -from ..server import * +from websockets.client import * +from websockets.exceptions import ConnectionClosed +from websockets.server import * + from ..test_client_server import get_server_uri diff --git a/websockets/test_client_server.py b/tests/test_client_server.py similarity index 99% rename from websockets/test_client_server.py rename to tests/test_client_server.py index 7f66ea036..83bd6ad9d 100644 --- a/websockets/test_client_server.py +++ b/tests/test_client_server.py @@ -14,23 +14,24 @@ import urllib.request import warnings -from .client import * -from .compatibility import FORBIDDEN, OK, UNAUTHORIZED -from .exceptions import ( +from websockets.client import * +from websockets.compatibility import FORBIDDEN, OK, UNAUTHORIZED +from websockets.exceptions import ( ConnectionClosed, InvalidHandshake, InvalidStatusCode, NegotiationError, ) -from .extensions.permessage_deflate import ( +from websockets.extensions.permessage_deflate import ( ClientPerMessageDeflateFactory, PerMessageDeflate, ServerPerMessageDeflateFactory, ) -from .handshake import build_response -from .http import USER_AGENT, Headers, read_response -from .protocol import State -from .server import * +from websockets.handshake import build_response +from websockets.http import USER_AGENT, Headers, read_response +from websockets.protocol import State +from websockets.server import * + from .test_protocol import MS diff --git a/websockets/test_exceptions.py b/tests/test_exceptions.py similarity index 98% rename from websockets/test_exceptions.py rename to tests/test_exceptions.py index 49042105c..7b935491b 100644 --- a/websockets/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,7 +1,7 @@ import unittest -from .exceptions import * -from .http import Headers +from websockets.exceptions import * +from websockets.http import Headers class ExceptionsTests(unittest.TestCase): diff --git a/websockets/test_framing.py b/tests/test_framing.py similarity index 98% rename from websockets/test_framing.py rename to tests/test_framing.py index 0ea0a2851..ee8515762 100644 --- a/websockets/test_framing.py +++ b/tests/test_framing.py @@ -3,8 +3,8 @@ import unittest import unittest.mock -from .exceptions import PayloadTooBig, WebSocketProtocolError -from .framing import * +from websockets.exceptions import PayloadTooBig, WebSocketProtocolError +from websockets.framing import * class FramingTests(unittest.TestCase): diff --git a/websockets/test_handshake.py b/tests/test_handshake.py similarity index 97% rename from websockets/test_handshake.py rename to tests/test_handshake.py index bf695a472..a0cb55a9e 100644 --- a/websockets/test_handshake.py +++ b/tests/test_handshake.py @@ -1,15 +1,15 @@ import contextlib import unittest -from .exceptions import ( +from websockets.exceptions import ( InvalidHandshake, InvalidHeader, InvalidHeaderValue, InvalidUpgrade, ) -from .handshake import * -from .handshake import accept # private API -from .http import Headers +from websockets.handshake import * +from websockets.handshake import accept # private API +from websockets.http import Headers class HandshakeTests(unittest.TestCase): diff --git a/websockets/test_headers.py b/tests/test_headers.py similarity index 97% rename from websockets/test_headers.py rename to tests/test_headers.py index f85c9b044..7d52b9f74 100644 --- a/websockets/test_headers.py +++ b/tests/test_headers.py @@ -1,8 +1,8 @@ import unittest -from .exceptions import InvalidHeaderFormat -from .headers import * -from .headers import build_basic_auth +from websockets.exceptions import InvalidHeaderFormat +from websockets.headers import * +from websockets.headers import build_basic_auth class HeadersTests(unittest.TestCase): diff --git a/websockets/test_http.py b/tests/test_http.py similarity index 99% rename from websockets/test_http.py rename to tests/test_http.py index 371bafc16..b18e24a26 100644 --- a/websockets/test_http.py +++ b/tests/test_http.py @@ -1,8 +1,8 @@ import asyncio import unittest -from .http import * -from .http import read_headers +from websockets.http import * +from websockets.http import read_headers class HTTPAsyncTests(unittest.TestCase): diff --git a/websockets/test_localhost.cnf b/tests/test_localhost.cnf similarity index 100% rename from websockets/test_localhost.cnf rename to tests/test_localhost.cnf diff --git a/websockets/test_localhost.pem b/tests/test_localhost.pem similarity index 100% rename from websockets/test_localhost.pem rename to tests/test_localhost.pem diff --git a/websockets/test_protocol.py b/tests/test_protocol.py similarity index 99% rename from websockets/test_protocol.py rename to tests/test_protocol.py index e1ce1e3e0..133a015ed 100644 --- a/websockets/test_protocol.py +++ b/tests/test_protocol.py @@ -7,10 +7,10 @@ import unittest import unittest.mock -from .compatibility import asyncio_ensure_future -from .exceptions import ConnectionClosed, InvalidState -from .framing import * -from .protocol import State, WebSocketCommonProtocol +from websockets.compatibility import asyncio_ensure_future +from websockets.exceptions import ConnectionClosed, InvalidState +from websockets.framing import * +from websockets.protocol import State, WebSocketCommonProtocol # Avoid displaying stack traces at the ERROR logging level. diff --git a/tests/test_speedups.py b/tests/test_speedups.py new file mode 100644 index 000000000..e69de29bb diff --git a/websockets/test_uri.py b/tests/test_uri.py similarity index 92% rename from websockets/test_uri.py rename to tests/test_uri.py index f8f9e042f..ad4ec4013 100644 --- a/websockets/test_uri.py +++ b/tests/test_uri.py @@ -1,7 +1,7 @@ import unittest -from .exceptions import InvalidURI -from .uri import * +from websockets.exceptions import InvalidURI +from websockets.uri import * VALID_URIS = [ diff --git a/websockets/test_utils.py b/tests/test_utils.py similarity index 92% rename from websockets/test_utils.py rename to tests/test_utils.py index 13ce9ff99..c7699232e 100644 --- a/websockets/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,6 @@ import unittest -from .utils import apply_mask as py_apply_mask +from websockets.utils import apply_mask as py_apply_mask class UtilsTests(unittest.TestCase): @@ -37,7 +37,7 @@ def test_apply_mask_check_mask_length(self): try: - from .speedups import apply_mask as c_apply_mask + from websockets.speedups import apply_mask as c_apply_mask except ImportError: # pragma: no cover pass else: diff --git a/tox.ini b/tox.ini index b6c8eeab2..6cff294e5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,47 +1,24 @@ [tox] -envlist = {py34,py35,py36,py37}{,-speedups},coverage,black,flake8,isort +envlist = py34,py35,py36,py37,coverage,black,flake8,isort [testenv] -commands = - ; Unfortunately tox has no support for building C extensions. - ; Do it manually in the git checkout - that's where tests run. - - ; Remove any existing compiled extension. - sh -c 'rm -f websockets/*.so' - - ; Before testing with speedups, compile the extension. - speedups: python setup.py --quiet build_ext --inplace - - python -W default -m unittest {posargs} - - ; After testing with speedups, remove the extension. - speedups: sh -c 'rm websockets/*.so' - -whitelist_externals = - sh +commands = python -W default -m unittest {posargs} [testenv:coverage] commands = - ; Handle speedups as above. - sh -c 'rm -f websockets/*.so' - python setup.py --quiet build_ext --inplace - python -m coverage erase - python -W default -m coverage run --branch --omit=websockets/__main__.py --source=websockets -m unittest + python -W default -m coverage run -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 - - speedups: sh -c 'rm websockets/*.so' - deps = coverage [testenv:black] -commands = black --check --skip-string-normalization websockets +commands = black --check --skip-string-normalization src tests deps = black [testenv:flake8] -commands = flake8 websockets +commands = flake8 src tests deps = flake8 [testenv:isort] -commands = isort --check-only --recursive websockets +commands = isort --check-only --recursive src tests deps = isort From 983f4fd8ec3f24362e0bdf2d821ea8282eb2d008 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Sep 2018 19:04:03 +0200 Subject: [PATCH 0470/1539] Abort pending pings with ConnectionClosed. Fix #464. --- docs/changelog.rst | 12 +++++++++++- src/websockets/protocol.py | 38 +++++++++++++++++++++++++------------- tests/test_protocol.py | 7 ++++--- 3 files changed, 40 insertions(+), 17 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c895baac9..630ee8786 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,21 @@ Changelog .. currentmodule:: websockets -6.1 +7.0 ... *In development* +.. warning:: + + **Version 7.0 changes how a :meth:`~protocol.WebSocketCommonProtocol.ping` + that hasn't received a pong yet behaves when the connection is closed.** + + The ping — as in ``ping = await websocket.ping()`` — used to be canceled + when the connection is closed, so that ``await ping`` raised + :exc:`~concurrent.futures.CancelledError`. Now ``await ping`` raises + :exc:`~exceptions.ConnectionClosed` like other public APIs. + * websockets sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index b9bbda37a..34fa48834 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -919,19 +919,6 @@ def close_connection(self): if self.keepalive_ping_task is not None: self.keepalive_ping_task.cancel() - # Cancel all pending pings because they'll never receive a pong. - for ping in self.pings.values(): - ping.cancel() - if self.pings: - pings_hex = ', '.join( - binascii.hexlify(ping_id).decode() or '[empty]' - for ping_id in self.pings - ) - plural = 's' if len(self.pings) > 1 else '' - logger.debug( - "%s - canceled pending ping%s: %s", self.side, plural, pings_hex - ) - # A client should wait for a TCP close from the server. if self.is_client and self.transfer_data_task is not None: if (yield from self.wait_for_connection_lost()): @@ -1059,6 +1046,30 @@ def fail_connection(self, code=1006, reason=''): return self.close_connection_task + def abort_keepalive_pings(self): + """ + Raise ConnectionClosed in pending keepalive pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.state is State.CLOSED + exc = ConnectionClosed(self.close_code, self.close_reason) + exc.__cause__ = self.transfer_data_exc # emulate raise ... from ... + + for ping in self.pings.values(): + ping.set_exception(exc) + + if self.pings: + pings_hex = ', '.join( + binascii.hexlify(ping_id).decode() or '[empty]' + for ping_id in self.pings + ) + plural = 's' if len(self.pings) > 1 else '' + logger.debug( + "%s - aborted pending ping%s: %s", self.side, plural, pings_hex + ) + # asyncio.StreamReaderProtocol methods def connection_made(self, transport): @@ -1122,6 +1133,7 @@ def connection_lost(self, exc): self.close_code, self.close_reason or '[empty]', ) + self.abort_keepalive_pings() # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 133a015ed..063960a19 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -605,13 +605,14 @@ def test_acknowledge_ping(self): self.run_loop_once() self.assertTrue(ping.done()) - def test_cancel_ping(self): + def test_abort_ping(self): ping = self.loop.run_until_complete(self.protocol.ping()) # Remove the frame from the buffer, else close_connection() complains. self.last_sent_frame() - self.assertFalse(ping.cancelled()) + self.assertFalse(ping.done()) self.close_connection() - self.assertTrue(ping.cancelled()) + self.assertTrue(ping.done()) + self.assertIsInstance(ping.exception(), ConnectionClosed) def test_acknowledge_previous_pings(self): pings = [ From a8d54373901c56832fafa33b7c713a555f5bce02 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Sep 2018 20:58:31 +0200 Subject: [PATCH 0471/1539] Fix typo. Fix #471. --- docs/design.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/design.rst b/docs/design.rst index 63afbb8d4..7ce40db5a 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -412,7 +412,7 @@ must propagate that information back to the previous part in the chain. For incoming data, ``websockets`` builds upon :class:`~asyncio.StreamReader` which propagates backpressure to its own buffer and to the TCP stream. Frames are parsed from the input stream and added to a bounded queue. If the queue -fills up, parsing halts until some the application reads a frame. +fills up, parsing halts until the application reads a frame. For outgoing data, ``websockets`` builds upon :class:`~asyncio.StreamWriter` which implements flow control. If the output buffers grow too large, it waits From ff9fc363f70068a52c8d75f9d3048b5a6aa223fd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Sep 2018 16:25:29 +0200 Subject: [PATCH 0472/1539] Add wait_closed method to protocols. Fix #469. --- docs/api.rst | 1 + docs/changelog.rst | 3 +++ src/websockets/protocol.py | 19 ++++++++++++++++--- tests/test_protocol.py | 5 +++++ 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 67f3756a4..cefa89bc1 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -68,6 +68,7 @@ Shared .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) .. automethod:: close(code=1000, reason='') + .. automethod:: wait_closed(code=1000, reason='') .. automethod:: recv() .. automethod:: send(data) diff --git a/docs/changelog.rst b/docs/changelog.rst index 630ee8786..a3d7448d5 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -21,6 +21,9 @@ Changelog * websockets sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. +* Added the :meth:`~protocol.WebSocketCommonProtocol.wait_closed` method to + protocols. + * Added an interactive client: ``python -m websockets ``. * Changed the ``origins`` argument to represent the lack of an origin with diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 34fa48834..f59ddd707 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -345,6 +345,18 @@ def closed(self): """ return self.state is State.CLOSED + def wait_closed(self): + """ + Return a :class:`asyncio.Future` that completes when the connection is closed. + + This is identical to :attr:`closed`, except it can be awaited. + + This can make it easier to handle connection termination, regardless + of its cause, in tasks that interact with the WebSocket connection. + + """ + return asyncio.shield(self.connection_lost_waiter) + @asyncio.coroutine def recv(self): """ @@ -431,10 +443,11 @@ def close(self, code=1000, reason=''): This coroutine performs the closing handshake. It waits for the other end to complete the handshake and for the TCP - connection to terminate. + connection to terminate. As a consequence, there's no need to await + :meth:`wait_closed`; :meth:`close` already does it. - It doesn't do anything once the connection is closed. In other words - it's idempotent. + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. It's safe to wrap this coroutine in :func:`~asyncio.ensure_future` since errors during connection termination aren't particularly useful. diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 063960a19..6c0a4325f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -361,6 +361,11 @@ def test_closed(self): self.close_connection() self.assertTrue(self.protocol.closed) + def test_wait_closed(self): + self.assertFalse(self.protocol.wait_closed().done()) + self.close_connection() + self.assertTrue(self.protocol.wait_closed().done()) + # Test the recv coroutine. def test_recv_text(self): From 4e7a82eeec4621f7a6d99d43aa0a995b70382992 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Sep 2018 16:09:21 +0200 Subject: [PATCH 0473/1539] Rename timeout to close_timeout. Fix #460. --- docs/api.rst | 12 ++++++------ docs/changelog.rst | 14 ++++++++++++-- docs/conf.py | 2 +- docs/design.rst | 23 ++++++++++++----------- src/websockets/client.py | 19 +++++++++++-------- src/websockets/protocol.py | 36 +++++++++++++++++++++--------------- src/websockets/server.py | 19 +++++++++++-------- tests/test_client_server.py | 37 ++++++++++++++++++++++++++++++++++++- tests/test_protocol.py | 16 ++++++++-------- 9 files changed, 118 insertions(+), 60 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index cefa89bc1..39b5922f9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,9 +32,9 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) .. autoclass:: WebSocketServer @@ -43,7 +43,7 @@ Server .. automethod:: wait_closed() .. autoattribute:: sockets - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(origins=None, available_extensions=None, available_subprotocols=None, extra_headers=None) .. automethod:: process_request(path, request_headers) @@ -54,9 +54,9 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(wsuri, origin=None, available_extensions=None, available_subprotocols=None, extra_headers=None) @@ -65,7 +65,7 @@ Shared .. automodule:: websockets.protocol - .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) + .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) .. automethod:: close(code=1000, reason='') .. automethod:: wait_closed(code=1000, reason='') diff --git a/docs/changelog.rst b/docs/changelog.rst index a3d7448d5..f04054b24 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,8 +10,18 @@ Changelog .. warning:: - **Version 7.0 changes how a :meth:`~protocol.WebSocketCommonProtocol.ping` - that hasn't received a pong yet behaves when the connection is closed.** + **Version 7.0 renames the** ``timeout`` **argument of** + :func:`~server.serve()` **and** :func:`~client.connect()` **to** + ``close_timeout`` **.** + + This prevents confusion with ``ping_timeout``. + + For backwards compatibility, ``timeout`` is still supported. + +.. warning:: + + **Version 7.0 changes how a** :meth:`~protocol.WebSocketCommonProtocol.ping` + **that hasn't received a pong yet behaves when the connection is closed.** The ping — as in ``ping = await websocket.ping()`` — used to be canceled when the connection is closed, so that ``await ping`` raised diff --git a/docs/conf.py b/docs/conf.py index 04db46ca7..8fdf12b4d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,7 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.join(os.path.abspath('..'), 'src')) # -- General configuration ----------------------------------------------------- diff --git a/docs/design.rst b/docs/design.rst index 7ce40db5a..93869732a 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -271,11 +271,12 @@ state and sends a close frame. When the other side sends a close frame, ``CLOSING`` state and returns ``None``, also causing :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. -If the other side doesn't send a close frame within the connection's timeout, -``websockets`` :ref:`fails the connection `. +If the other side doesn't send a close frame within the connection's close +timeout, ``websockets`` :ref:`fails the connection `. -The closing handshake can take up to ``2 * timeout``: one ``timeout`` to write -a close frame and one ``timeout`` to receive a close frame. +The closing handshake can take up to ``2 * close_timeout``: one +``close_timeout`` to write a close frame and one ``close_timeout`` to receive +a close frame. Then ``websockets`` terminates the TCP connection. @@ -313,13 +314,13 @@ Then :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` cancels protocol compliance responsibilities. Terminating it to avoid leaking it is the only concern. -Terminating the TCP connection can take up to ``2 * timeout`` on the server -side and ``3 * timeout`` on the client side. Clients start by waiting for the -server to close the connection, hence the extra ``timeout``. Then both sides -go through the following steps until the TCP connection is lost: half-closing -the connection (only for non-TLS connections), closing the connection, -aborting the connection. At this point the connection drops regardless of what -happens on the network. +Terminating the TCP connection can take up to ``2 * close_timeout`` on the +server side and ``3 * close_timeout`` on the client side. Clients start by +waiting for the server to close the connection, hence the extra +``close_timeout``. Then both sides go through the following steps until the +TCP connection is lost: half-closing the connection (only for non-TLS +connections), closing the connection, aborting the connection. At this point +the connection drops regardless of what happens on the network. .. _connection-failure: diff --git a/src/websockets/client.py b/src/websockets/client.py index 94c761745..b1b14bc62 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -318,7 +318,7 @@ class Connect: a ``wss://`` URI, if this argument isn't provided explicitly, it's set to ``True``, which means Python's default :class:`~ssl.SSLContext` is used. - The behavior of the ``ping_interval``, ``ping_timeout``, ``timeout``, + The behavior of the ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -357,14 +357,15 @@ def __init__( create_protocol=None, ping_interval=20, ping_timeout=20, - timeout=10, + close_timeout=None, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, - klass=None, + klass=WebSocketClientProtocol, + timeout=10, origin=None, extensions=None, subprotocols=None, @@ -375,14 +376,16 @@ def __init__( if loop is None: loop = asyncio.get_event_loop() + # Backwards-compatibility: close_timeout used to be called timeout. + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + # Backwards-compatibility: create_protocol used to be called klass. - # In the unlikely event that both are specified, klass is ignored. + # If both are specified, klass is ignored. if create_protocol is None: create_protocol = klass - if create_protocol is None: - create_protocol = WebSocketClientProtocol - wsuri = parse_uri(uri) if wsuri.secure: kwds.setdefault('ssl', True) @@ -411,7 +414,7 @@ def __init__( secure=wsuri.secure, ping_interval=ping_interval, ping_timeout=ping_timeout, - timeout=timeout, + close_timeout=close_timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index f59ddd707..d95c26b38 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -96,13 +96,13 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 - The ``timeout`` parameter defines a maximum wait time in seconds for + The ``close_timeout`` parameter defines a maximum wait time in seconds for completing the closing handshake and terminating the TCP connection. - :meth:`close()` completes in at most ``4 * timeout`` on the server side - and ``5 * timeout`` on the client side. + :meth:`close()` completes in at most ``4 * close_timeout`` on the server + side and ``5 * close_timeout`` on the client side. - ``timeout`` needs to be a parameter of the protocol because websockets - usually calls :meth:`close()` implicitly: + ``close_timeout`` needs to be a parameter of the protocol because + websockets usually calls :meth:`close()` implicitly: - on the server side, when the connection handler terminates, - on the client side, when exiting the context manager for the connection. @@ -173,20 +173,26 @@ def __init__( secure=None, ping_interval=20, ping_timeout=20, - timeout=10, + close_timeout=None, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, - legacy_recv=False + legacy_recv=False, + timeout=10 ): + # Backwards-compatibility: close_timeout used to be called timeout. + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + self.host = host self.port = port self.secure = secure self.ping_interval = ping_interval self.ping_timeout = ping_timeout - self.timeout = timeout + self.close_timeout = close_timeout self.max_size = max_size self.max_queue = max_queue self.read_limit = read_limit @@ -458,7 +464,7 @@ def close(self, code=1000, reason=''): try: yield from asyncio.wait_for( self.write_close_frame(serialize_close(code, reason)), - self.timeout, + self.close_timeout, loop=self.loop, ) except asyncio.TimeoutError: @@ -480,7 +486,7 @@ def close(self, code=1000, reason=''): # is canceled before the timeout elapses (on Python ≥ 3.4.3). # This helps closing connections when shutting down a server. yield from asyncio.wait_for( - self.transfer_data_task, self.timeout, loop=self.loop + self.transfer_data_task, self.close_timeout, loop=self.loop ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -574,9 +580,9 @@ def ensure_open(self): if self.state is State.CLOSING: # If we started the closing handshake, wait for its completion to # get the proper close code and status. self.close_connection_task - # will complete within 4 or 5 * timeout after calling close(). - # The CLOSING state also occurs when failing the connection. In - # that case self.close_connection_task will complete even faster. + # will complete within 4 or 5 * close_timeout after close(). The + # CLOSING state also occurs when failing the connection. In that + # case self.close_connection_task will complete even faster. if self.close_code is None: yield from asyncio.shield(self.close_connection_task) raise ConnectionClosed( @@ -975,7 +981,7 @@ def close_connection(self): @asyncio.coroutine def wait_for_connection_lost(self): """ - Wait until the TCP connection is closed or ``self.timeout`` elapses. + Wait until the TCP connection is closed or ``self.close_timeout`` elapses. Return ``True`` if the connection is closed and ``False`` otherwise. @@ -984,7 +990,7 @@ def wait_for_connection_lost(self): try: yield from asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), - self.timeout, + self.close_timeout, loop=self.loop, ) except asyncio.TimeoutError: diff --git a/src/websockets/server.py b/src/websockets/server.py index 4284d09ad..a349fea9c 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -667,7 +667,7 @@ class Serve: :class:`WebSocketServerProtocol` instance. It defaults to :class:`WebSocketServerProtocol`. - The behavior of the ``ping_interval``, ``ping_timeout``, ``timeout``, + The behavior of the ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -721,14 +721,15 @@ def __init__( create_protocol=None, ping_interval=20, ping_timeout=20, - timeout=10, + close_timeout=None, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, legacy_recv=False, - klass=None, + klass=WebSocketServerProtocol, + timeout=10, origins=None, extensions=None, subprotocols=None, @@ -736,14 +737,16 @@ def __init__( compression='deflate', **kwds ): + # Backwards-compatibility: close_timeout used to be called timeout. + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + # Backwards-compatibility: create_protocol used to be called klass. - # In the unlikely event that both are specified, klass is ignored. + # If both are specified, klass is ignored. if create_protocol is None: create_protocol = klass - if create_protocol is None: - create_protocol = WebSocketServerProtocol - if loop is None: loop = asyncio.get_event_loop() @@ -770,7 +773,7 @@ def __init__( secure=secure, ping_interval=ping_interval, ping_timeout=ping_timeout, - timeout=timeout, + close_timeout=close_timeout, max_size=max_size, max_queue=max_queue, read_limit=read_limit, diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 83bd6ad9d..55fee7340 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -52,6 +52,8 @@ def handler(ws, path): if path == '/attributes': yield from ws.send(repr((ws.host, ws.port, ws.secure))) + elif path == '/close_timeout': + yield from ws.send(repr(ws.close_timeout)) elif path == '/path': yield from ws.send(str(ws.path)) elif path == '/headers': @@ -554,7 +556,7 @@ def test_server_create_protocol_function(self): self.assert_client_raises_code(401) @with_server(klass=UnauthorizedServerProtocol) - def test_server_klass(self): + def test_server_klass_backwards_compatibility(self): self.assert_client_raises_code(401) @with_server( @@ -586,6 +588,39 @@ def test_client_klass(self): def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) + @with_server(close_timeout=7) + @with_client('/close_timeout') + def test_server_close_timeout(self): + close_timeout = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(eval(close_timeout), 7) + + @with_server(timeout=6) + @with_client('/close_timeout') + def test_server_timeout_backwards_compatibility(self): + close_timeout = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(eval(close_timeout), 6) + + @with_server(close_timeout=7, timeout=6) + @with_client('/close_timeout') + def test_server_close_timeout_over_timeout(self): + close_timeout = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(eval(close_timeout), 7) + + @with_server() + @with_client('/close_timeout', close_timeout=7) + def test_client_close_timeout(self): + self.assertEqual(self.client.close_timeout, 7) + + @with_server() + @with_client('/close_timeout', timeout=6) + def test_client_timeout_backwards_compatibility(self): + self.assertEqual(self.client.close_timeout, 6) + + @with_server() + @with_client('/close_timeout', close_timeout=7, timeout=6) + def test_client_close_timeout_over_timeout(self): + self.assertEqual(self.client.close_timeout, 7) + @with_server() @with_client('/extensions') def test_no_extension(self): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 6c0a4325f..27a5c1133 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1010,7 +1010,7 @@ def setUp(self): self.protocol.side = 'server' def test_local_close_send_close_frame_timeout(self): - self.protocol.timeout = 10 * MS + self.protocol.close_timeout = 10 * MS self.make_drain_slow(50 * MS) # If we can't send a close frame, time out in 10ms. # Check the timing within -1/+9ms for robustness. @@ -1019,7 +1019,7 @@ def test_local_close_send_close_frame_timeout(self): self.assertConnectionClosed(1006, '') def test_local_close_receive_close_frame_timeout(self): - self.protocol.timeout = 10 * MS + self.protocol.close_timeout = 10 * MS # If the client doesn't send a close frame, time out in 10ms. # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): @@ -1027,7 +1027,7 @@ def test_local_close_receive_close_frame_timeout(self): self.assertConnectionClosed(1006, '') def test_local_close_connection_lost_timeout_after_write_eof(self): - self.protocol.timeout = 10 * MS + self.protocol.close_timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof(), time out in 10ms. # Check the timing within -1/+9ms for robustness. @@ -1039,7 +1039,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.assertConnectionClosed(1000, 'close') def test_local_close_connection_lost_timeout_after_close(self): - self.protocol.timeout = 10 * MS + self.protocol.close_timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof() and close it with close(), time # out in 20ms. @@ -1061,7 +1061,7 @@ def setUp(self): self.protocol.side = 'client' def test_local_close_send_close_frame_timeout(self): - self.protocol.timeout = 10 * MS + self.protocol.close_timeout = 10 * MS self.make_drain_slow(50 * MS) # If we can't send a close frame, time out in 20ms. # - 10ms waiting for sending a close frame @@ -1072,7 +1072,7 @@ def test_local_close_send_close_frame_timeout(self): self.assertConnectionClosed(1006, '') def test_local_close_receive_close_frame_timeout(self): - self.protocol.timeout = 10 * MS + self.protocol.close_timeout = 10 * MS # If the server doesn't send a close frame, time out in 20ms: # - 10ms waiting for receiving a close frame # - 10ms waiting for receiving a half-close @@ -1082,7 +1082,7 @@ def test_local_close_receive_close_frame_timeout(self): self.assertConnectionClosed(1006, '') def test_local_close_connection_lost_timeout_after_write_eof(self): - self.protocol.timeout = 10 * MS + self.protocol.close_timeout = 10 * MS # If the server doesn't half-close its side of the TCP connection # after we send a close frame, time out in 20ms: # - 10ms waiting for receiving a half-close @@ -1096,7 +1096,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.assertConnectionClosed(1000, 'close') def test_local_close_connection_lost_timeout_after_close(self): - self.protocol.timeout = 10 * MS + self.protocol.close_timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof() and close it with close(), time # out in 20ms. From 84ce48cb148fa815fb40882bf4aa4806f562886d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Sep 2018 21:04:45 +0200 Subject: [PATCH 0474/1539] Make it possible to fragment outgoing messages. Fix #258. --- docs/changelog.rst | 2 ++ src/websockets/protocol.py | 73 +++++++++++++++++++++++++++++++------- tests/test_protocol.py | 68 +++++++++++++++++++++++++++-------- 3 files changed, 117 insertions(+), 26 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index f04054b24..f2f9412d2 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -31,6 +31,8 @@ Changelog * websockets sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. +* Added support for sending fragmented messages. + * Added the :meth:`~protocol.WebSocketCommonProtocol.wait_closed` method to protocols. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index d95c26b38..6e158534a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -11,6 +11,7 @@ import binascii import codecs import collections +import collections.abc import enum import logging import random @@ -428,20 +429,66 @@ def send(self, data): This coroutine sends a message. It sends :class:`str` as a text frame and :class:`bytes` as a binary - frame. It raises a :exc:`TypeError` for other inputs. + frame. + + It also accepts an iterable of :class:`str` or :class:`bytes`. Each + item is treated as a message fragment and sent in its own frame. All + items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + It raises a :exc:`TypeError` for other inputs. """ yield from self.ensure_open() + # Unfragmented message (first because str and bytes are iterable). + if isinstance(data, str): - opcode = 1 - data = data.encode('utf-8') + yield from self.write_frame(True, OP_TEXT, data.encode('utf-8')) + elif isinstance(data, bytes): - opcode = 2 - else: - raise TypeError("data must be bytes or str") + yield from self.write_frame(True, OP_BINARY, data) - yield from self.write_frame(opcode, data) + # Fragmented message -- regular iterator. + + elif isinstance(data, collections.abc.Iterable): + iter_data = iter(data) + + # First fragment. + try: + data = next(iter_data) + except StopIteration: + return + data_type = type(data) + if isinstance(data, str): + yield from self.write_frame(False, OP_TEXT, data.encode('utf-8')) + encode_data = True + elif isinstance(data, bytes): + yield from self.write_frame(False, OP_BINARY, data) + encode_data = False + else: + raise TypeError("data must be an iterable of bytes or str") + + # Other fragments. + for data in iter_data: + if type(data) != data_type: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(1011) + raise TypeError("data contains inconsistent types") + if encode_data: + data = data.encode('utf-8') + yield from self.write_frame(False, OP_CONT, data) + + # Final fragment. + yield from self.write_frame(True, OP_CONT, type(data)()) + + # Fragmented message -- asynchronous iterator + + # To be implemented after dropping support for Python 3.4. + + else: + raise TypeError("data must be bytes, str, or iterable") @asyncio.coroutine def close(self, code=1000, reason=''): @@ -529,7 +576,7 @@ def ping(self, data=None): self.pings[data] = asyncio.Future(loop=self.loop) - yield from self.write_frame(OP_PING, data) + yield from self.write_frame(True, OP_PING, data) return asyncio.shield(self.pings[data]) @@ -549,7 +596,7 @@ def pong(self, data=b''): data = encode_data(data) - yield from self.write_frame(OP_PONG, data) + yield from self.write_frame(True, OP_PONG, data) # Private methods - no guarantees. @@ -803,14 +850,14 @@ def read_frame(self, max_size): return frame @asyncio.coroutine - def write_frame(self, opcode, data=b'', _expected_state=State.OPEN): + def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): # Defensive assertion for protocol compliance. if self.state is not _expected_state: # pragma: no cover raise InvalidState( "Cannot write to a WebSocket " "in the {} state".format(self.state.name) ) - frame = Frame(True, opcode, data) + frame = Frame(fin, opcode, data) logger.debug("%s > %s", self.side, frame) frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions) @@ -870,7 +917,9 @@ def write_close_frame(self, data=b''): logger.debug("%s - state = CLOSING", self.side) # 7.1.2. Start the WebSocket Closing Handshake - yield from self.write_frame(OP_CLOSE, data, State.CLOSING) + yield from self.write_frame( + True, OP_CLOSE, data, _expected_state=State.CLOSING + ) @asyncio.coroutine def keepalive_ping(self): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 27a5c1133..d9aea196b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -255,12 +255,9 @@ def process_invalid_frames(self): self.receive_eof() self.loop.run_until_complete(self.protocol.close_connection_task) - def last_sent_frame(self): + def sent_frames(self): """ - Read the last frame sent to the transport. - - This method assumes that at most one frame was sent. It raises an - AssertionError otherwise. + Read all frames sent to the transport. """ stream = asyncio.StreamReader(loop=self.loop) @@ -270,18 +267,30 @@ def last_sent_frame(self): self.transport.write.call_args_list = [] stream.feed_eof() - if stream.at_eof(): - frame = None - else: - frame = self.loop.run_until_complete( - Frame.read(stream.readexactly, mask=self.protocol.is_client) + frames = [] + while not stream.at_eof(): + frames.append( + self.loop.run_until_complete( + Frame.read(stream.readexactly, mask=self.protocol.is_client) + ) ) + return frames + + def last_sent_frame(self): + """ + Read the last frame sent to the transport. + + This method assumes that at most one frame was sent. It raises an + AssertionError otherwise. - if not stream.at_eof(): # pragma: no cover - data = self.loop.run_until_complete(stream.read()) - raise AssertionError("Trailing data found: {!r}".format(data)) + """ + frames = self.sent_frames() + if frames: + assert len(frames) == 1 + return frames[0] - return frame + def assertFramesSent(self, *frames): + self.assertEqual(self.sent_frames(), [Frame(*args) for args in frames]) def assertOneFrameSent(self, *args): self.assertEqual(self.last_sent_frame(), Frame(*args)) @@ -467,6 +476,37 @@ def test_send_type_error(self): self.loop.run_until_complete(self.protocol.send(42)) self.assertNoFrameSent() + def test_send_iterable_text(self): + self.loop.run_until_complete(self.protocol.send(['ca', 'fé'])) + self.assertFramesSent( + (False, OP_TEXT, 'ca'.encode('utf-8')), + (False, OP_CONT, 'fé'.encode('utf-8')), + (True, OP_CONT, ''.encode('utf-8')), + ) + + def test_send_iterable_binary(self): + self.loop.run_until_complete(self.protocol.send([b'te', b'a'])) + self.assertFramesSent( + (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') + ) + + def test_send_empty_iterable(self): + self.loop.run_until_complete(self.protocol.send([])) + self.assertNoFrameSent() + + def test_send_iterable_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send([42])) + self.assertNoFrameSent() + + def test_send_iterable_mixed_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send(['café', b'tea'])) + self.assertFramesSent( + (False, OP_TEXT, 'café'.encode('utf-8')), + (True, OP_CLOSE, serialize_close(1011, '')), + ) + def test_send_on_closing_connection_local(self): close_task = self.half_close_connection_local() From 94e4ee4dbe33d8da65e7362141a83b5ac05306d3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 13:09:18 +0100 Subject: [PATCH 0475/1539] Don't prevent overriding the Date header. This is mostly for consistency. --- src/websockets/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index a349fea9c..78f8c80b4 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -485,7 +485,6 @@ def handshake( ) response_headers = Headers() - response_headers['Date'] = email.utils.formatdate(usegmt=True) build_response(response_headers, key) @@ -505,6 +504,7 @@ def handshake( for name, value in extra_headers: response_headers[name] = value + response_headers.setdefault('Date', email.utils.formatdate(usegmt=True)) response_headers.setdefault('Server', USER_AGENT) yield from self.write_http_response(SWITCHING_PROTOCOLS, response_headers) From 363b5e600f23373296e56b10f43f81dfce8b261a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 13:09:41 +0100 Subject: [PATCH 0476/1539] Negligible cleanup. --- src/websockets/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index 78f8c80b4..625c21209 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -491,7 +491,7 @@ def handshake( if extensions_header is not None: response_headers['Sec-WebSocket-Extensions'] = extensions_header - if self.subprotocol is not None: + if protocol_header is not None: response_headers['Sec-WebSocket-Protocol'] = protocol_header if extra_headers is not None: From aa6ba992b346d56b80e8aabc074ae83c40113213 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 14:33:28 +0100 Subject: [PATCH 0477/1539] Add flake8 to `make style`. --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 5d729181b..7de30b002 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ export PYTHONPATH=src style: isort --recursive src tests black --skip-string-normalization src tests + flake8 src tests test: python -W default -m unittest From db8820790f0f066493db5de7c964c26771f73c75 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 14:30:58 +0100 Subject: [PATCH 0478/1539] Black now normalizes hex representations of integers. --- src/websockets/framing.py | 2 +- tests/test_framing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index afbc664c6..00a24d807 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -40,7 +40,7 @@ ] DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 -CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG = 0x08, 0x09, 0x0a +CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG = 0x08, 0x09, 0x0A # Close code that are allowed in a close frame. # Using a list optimizes `code in EXTERNAL_CLOSE_CODES`. diff --git a/tests/test_framing.py b/tests/test_framing.py index ee8515762..9da64f14c 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -121,13 +121,13 @@ def test_bad_reserved_bits(self): self.decode(encoded) def test_good_opcode(self): - for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0b)): + for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0B)): encoded = bytes([0x80 | opcode, 0]) with self.subTest(encoded=encoded): self.decode(encoded) # does not raise an exception def test_bad_opcode(self): - for opcode in list(range(0x03, 0x08)) + list(range(0x0b, 0x10)): + for opcode in list(range(0x03, 0x08)) + list(range(0x0B, 0x10)): encoded = bytes([0x80 | opcode, 0]) with self.subTest(encoded=encoded): with self.assertRaises(WebSocketProtocolError): From 2f5ec328679e77eec03a5410a11bdccaa51e8c91 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Oct 2018 17:30:09 +0200 Subject: [PATCH 0479/1539] Fix indentation. --- docs/changelog.rst | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index f2f9412d2..b46b9aff7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,26 +10,27 @@ Changelog .. warning:: - **Version 7.0 renames the** ``timeout`` **argument of** - :func:`~server.serve()` **and** :func:`~client.connect()` **to** - ``close_timeout`` **.** + **Version 7.0 renames the** ``timeout`` **argument of** + :func:`~server.serve()` **and** :func:`~client.connect()` **to** + ``close_timeout`` **.** - This prevents confusion with ``ping_timeout``. + This prevents confusion with ``ping_timeout``. - For backwards compatibility, ``timeout`` is still supported. + For backwards compatibility, ``timeout`` is still supported. .. warning:: - **Version 7.0 changes how a** :meth:`~protocol.WebSocketCommonProtocol.ping` - **that hasn't received a pong yet behaves when the connection is closed.** + **Version 7.0 changes how a** :meth:`~protocol.WebSocketCommonProtocol.ping` + **that hasn't received a pong yet behaves when the connection is closed.** - The ping — as in ``ping = await websocket.ping()`` — used to be canceled - when the connection is closed, so that ``await ping`` raised - :exc:`~concurrent.futures.CancelledError`. Now ``await ping`` raises - :exc:`~exceptions.ConnectionClosed` like other public APIs. + The ping — as in ``ping = await websocket.ping()`` — used to be canceled + when the connection is closed, so that ``await ping`` raised + :exc:`~asyncio.CancelledError`. Now ``await ping`` raises + :exc:`~exceptions.ConnectionClosed` like other public APIs. * websockets sends Ping frames at regular intervals and closes the connection - if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. + if it doesn't receive a matching Pong frame. See + :class:`~protocol.WebSocketCommonProtocol` for details. * Added support for sending fragmented messages. From 8dfe6e3e860efc1cb753518ef652adc38b8da4c9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Oct 2018 17:31:55 +0200 Subject: [PATCH 0480/1539] Normalize the spelling of "canceled". --- docs/design.rst | 2 +- src/websockets/protocol.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/design.rst b/docs/design.rst index 93869732a..2688a27dc 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -77,7 +77,7 @@ two tasks: - :attr:`~protocol.WebSocketCommonProtocol.keepalive_ping_task` runs :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping()` which sends Ping frames at regular intervals and ensures that corresponding Pong frames are - received. It is cancelled when the connection terminates. It never exits + received. It is canceled when the connection terminates. It never exits with an exception other than :exc:`~asyncio.CancelledError`. - :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 6e158534a..53f19a57d 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -941,9 +941,9 @@ def keepalive_ping(self): # ping() cannot raise ConnectionClosed, only CancelledError: # - If the connection is CLOSING, keepalive_ping_task will be - # cancelled by close_connection() before ping() returns. + # canceled by close_connection() before ping() returns. # - If the connection is CLOSED, keepalive_ping_task must be - # cancelled already. + # canceled already. ping_waiter = yield from self.ping() if self.ping_timeout is not None: From af9f318973f71bfdb79d00af0ff443b195ef8fc7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Oct 2018 17:34:08 +0200 Subject: [PATCH 0481/1539] Call WebSocket codes "close codes" consistently. --- docs/changelog.rst | 2 +- src/websockets/protocol.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index b46b9aff7..f4df0ff03 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -108,7 +108,7 @@ Also: credentials. * Iterating on incoming messages no longer raises an exception when the - connection terminates with code 1001 (going away). + connection terminates with close code 1001 (going away). * A plain HTTP request now receives a 426 Upgrade Required response and doesn't log a stack trace. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 53f19a57d..b98c37c60 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -76,7 +76,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): await process(message) The iterator yields incoming messages. It exits normally when the - connection is closed with the status code 1000 (OK) or 1001 (going away). + connection is closed with the close code 1000 (OK) or 1001 (going away). It raises a :exc:`~websockets.exceptions.ConnectionClosed` exception when the connection is closed with any other status code. From 9ebae9c41daa0bada6a3f1ae734ae92c55ced6a7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 13:23:59 +0100 Subject: [PATCH 0482/1539] Improve debug logs when closing or failing connection. --- src/websockets/protocol.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index b98c37c60..2bbffaf7b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -1072,7 +1072,11 @@ def fail_connection(self, code=1006, reason=''): """ logger.debug( - "%s ! failing WebSocket connection: %d %s", self.side, code, reason + "%s ! failing WebSocket connection in the %s state: %d %s", + self.side, + self.state.name, + code, + reason or '[no reason]', ) # Cancel transfer_data_task if the opening handshake succeeded. @@ -1199,7 +1203,7 @@ def connection_lost(self, exc): "%s x code = %d, reason = %s", self.side, self.close_code, - self.close_reason or '[empty]', + self.close_reason or '[no reason]', ) self.abort_keepalive_pings() # If self.connection_lost_waiter isn't pending, that's a bug, because: From c3a81ff690acadaa3f57da2bbdaec69f873d4c0d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Oct 2018 17:37:36 +0200 Subject: [PATCH 0483/1539] Convert wait_closed() to a regular coroutine. Refs #469. --- src/websockets/protocol.py | 5 +++-- tests/test_protocol.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2bbffaf7b..5a609a495 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -352,9 +352,10 @@ def closed(self): """ return self.state is State.CLOSED + @asyncio.coroutine def wait_closed(self): """ - Return a :class:`asyncio.Future` that completes when the connection is closed. + Wait until the connection is closed. This is identical to :attr:`closed`, except it can be awaited. @@ -362,7 +363,7 @@ def wait_closed(self): of its cause, in tasks that interact with the WebSocket connection. """ - return asyncio.shield(self.connection_lost_waiter) + yield from asyncio.shield(self.connection_lost_waiter) @asyncio.coroutine def recv(self): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d9aea196b..cceb63e85 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -371,9 +371,10 @@ def test_closed(self): self.assertTrue(self.protocol.closed) def test_wait_closed(self): - self.assertFalse(self.protocol.wait_closed().done()) + wait_closed = asyncio_ensure_future(self.protocol.wait_closed()) + self.assertFalse(wait_closed.done()) self.close_connection() - self.assertTrue(self.protocol.wait_closed().done()) + self.assertTrue(wait_closed.done()) # Test the recv coroutine. From 71c4db9c5947dc7a69eb05513bfc46a1a9771c1a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Sep 2018 18:27:06 +0200 Subject: [PATCH 0484/1539] Simplify connection termination on server shutdown. Previously, there were two scenarios where a connection handler on the server side would receive an unexpected connection termination: 1. When the remote endpoint closed the connection or it dropped for any reason: this was signalled with a ConnectionClosed exception when interacting with the connection, usually thrown by recv() since most connection handlers spend the majority of their time awaiting recv() 2. When the server shut down: this caused the connection handler to be cancelled, meaning that a CancelledError exception was injected at an arbitrary location. Now, when the server shuts down, scenario 1 also applies. This removes the need to prepare for CancelledError in connection handlers. A good way to understand this change is to consider that "shutting down the server" means "shutting down all connections". Given that WebSocket connections are long-running, I find it sensible to require connection handlers to detect when the connection they're managing drops and to terminate. I expect this to make it easier for users to write connection handlers that behave correctly when the server shuts down, for two reasons: - There's only one scenario to implement, which reduces the amount of exception handling required. In addition, this avoids the specially messy situation where the clean-up process for ConnectionClosed gets interrupted by CancelledError. - Users are much more likely to check what happens when the connection drops than when the server shuts down. Making the latter behave like the former gives them proper behavior for free. Another argument for this change is that cancellation wasn't properly used there. Cancellation should only be used for terminating the execution of a task that one started when one no longer cares about its result. Since websockets runs a connection handler provided by the user, it can't make the decision that whatever code comes next doesn't matter, especially finalization code. The counter-argument against this change is that shutdown will be much slower for a server that reads events from an external source and sends them over the WebSocket connection. Previously, such a server would close immediately. Now, unless the connection handler checks for termination explicitly, it won't notice that the server is closing until it attempts to send the next event, which could take some time. Fix #338 and #394. Revert #337 and #392. --- docs/changelog.rst | 17 ++++ docs/cheatsheet.rst | 9 +++ docs/design.rst | 34 +++++++- src/websockets/exceptions.py | 8 ++ src/websockets/protocol.py | 2 +- src/websockets/server.py | 145 +++++++++++++++++++---------------- tests/test_client_server.py | 53 +++++++------ 7 files changed, 170 insertions(+), 98 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index f4df0ff03..f1bfd0b8f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -18,6 +18,23 @@ Changelog For backwards compatibility, ``timeout`` is still supported. +.. warning:: + + **Version 7.0 changes how a server terminates connections when it's + closed with :meth:`~websockets.server.WebSocketServer.close`.** + + Previously, connections handlers were canceled. Now, connections are + closed with close code 1001 (going away). From the perspective of the + connection handler, this is the same as if the remote endpoint was + disconnecting. This removes the need to prepare for + :exc:`~asyncio.CancelledError` in connection handlers. + + You can restore the previous behavior by adding the following line at the + beginning of connection handlers:: + + def handler(ws, path): + ws.wait_closed().add_done_callback(asyncio.current_task().cancel) + .. warning:: **Version 7.0 changes how a** :meth:`~protocol.WebSocketCommonProtocol.ping` diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 2f6a47e9c..3b8993a8c 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -13,6 +13,15 @@ Server :meth:`~protocol.WebSocketCommonProtocol.send` to receive and send messages at any time. + * When :meth:`~protocol.WebSocketCommonProtocol.recv` or + :meth:`~protocol.WebSocketCommonProtocol.send` raises + :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started + other :class:`asyncio.Task`, terminate them before exiting. + + * If you aren't awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, + consider awaiting :meth:`~protocol.WebSocketCommonProtocol.wait_closed` + to detect quickly when the connection is closed. + * You may :meth:`~protocol.WebSocketCommonProtocol.ping` or :meth:`~protocol.WebSocketCommonProtocol.pong` if you wish but it isn't needed in general. diff --git a/docs/design.rst b/docs/design.rst index 2688a27dc..244d233e1 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -340,22 +340,44 @@ sending a close frame if appropriate. the TCP connection. +.. _server-shutdown: + +Server shutdown +--------------- + +:class:`~websockets.server.WebSocketServer` closes asynchronously like +:class:`asyncio.Server`. The shutdown happen in two steps: + +1. Stop listening and accepting new connections; +2. Close established connections with close code 1001 (going away) or, if + the opening handshake is still in progress, with HTTP status code 503 + (Service Unavailable). + +The first call to :class:`~websockets.server.WebSocketServer.close` starts a +task that performs this sequence. Further calls are ignored. This is the +easiest way to make :class:`~websockets.server.WebSocketServer.close` and +:class:`~websockets.server.WebSocketServer.wait_closed` idempotent. + + .. _cancellation: Cancellation ------------ Most :doc:`public APIs ` of ``websockets`` are coroutines. They may be -canceled. ``websockets`` must handle this situation. +canceled, for example if the user starts a task that calls these coroutines +and cancels the task later. ``websockets`` must handle this situation. Cancellation during the opening handshake is handled like any other exception: -the TCP connection is closed and the exception is re-raised or logged. +the TCP connection is closed and the exception is re-raised. This can only +happen on the client side. On the server side, the opening handshake is +managed by ``websockets`` and nothing results in a cancellation. Once the WebSocket connection is established, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get -accidentally canceled if a coroutine that awaits them is canceled. They must -be shielded from cancellation. +accidentally canceled if a coroutine that awaits them is canceled. In other +words, they must be shielded from cancellation. :meth:`~protocol.WebSocketCommonProtocol.recv()` waits for the next message in the queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` @@ -388,6 +410,10 @@ Since :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` handles :exc:`~asyncio.CancelledError`, cancellation doesn't propagate to :attr:`~protocol.WebSocketCommonProtocol.close_connnection_task`. +Conversely, ``websockets`` never injects :exc:`~asyncio.CancelledError` into +user code. It doesn't cancel connection handler coroutines. Instead it expects +them to detect when the connection is closed and to exit. + .. _backpressure: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b34a2c0dc..b1618fa73 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -1,5 +1,6 @@ __all__ = [ 'AbortHandshake', + 'CancelHandshake', 'ConnectionClosed', 'DuplicateParameter', 'InvalidHandshake', @@ -43,6 +44,13 @@ def __init__(self, status, headers, body=b''): super().__init__(message) +class CancelHandshake(InvalidHandshake): + """ + Exception raised to cancel a handshake when the connection is closed. + + """ + + class InvalidMessage(InvalidHandshake): """ Exception raised when the HTTP message in a handshake request is malformed. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 5a609a495..6147eff1e 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -960,7 +960,7 @@ def keepalive_ping(self): except asyncio.CancelledError: raise - except Exception as exc: + except Exception: logger.warning("Unexpected exception in keepalive ping task", exc_info=True) @asyncio.coroutine diff --git a/src/websockets/server.py b/src/websockets/server.py index 625c21209..857dde5de 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -14,13 +14,13 @@ BAD_REQUEST, FORBIDDEN, INTERNAL_SERVER_ERROR, - SERVICE_UNAVAILABLE, SWITCHING_PROTOCOLS, UPGRADE_REQUIRED, asyncio_ensure_future, ) from .exceptions import ( AbortHandshake, + CancelHandshake, InvalidHandshake, InvalidHeader, InvalidMessage, @@ -32,7 +32,7 @@ from .handshake import build_response, check_request from .headers import build_extension_list, parse_extension_list, parse_subprotocol_list from .http import USER_AGENT, Headers, MultipleValuesError, read_request -from .protocol import WebSocketCommonProtocol +from .protocol import State, WebSocketCommonProtocol __all__ = ['serve', 'unix_serve', 'WebSocketServerProtocol'] @@ -97,7 +97,8 @@ def handler(self): Handle the lifecycle of a WebSocket connection. Since this method doesn't have a caller able to handle exceptions, it - attemps to log relevant ones and close the connection properly. + attemps to log relevant ones and guarantees that the TCP connection is + closed before exiting. """ try: @@ -109,17 +110,14 @@ def handler(self): available_subprotocols=self.available_subprotocols, extra_headers=self.extra_headers, ) - except ConnectionError as exc: + except ConnectionError: logger.debug("Connection error in opening handshake", exc_info=True) raise + except CancelHandshake: + yield from self.fail_connection() + return except Exception as exc: - if self._is_server_shutting_down(exc): - status, headers, body = ( - SERVICE_UNAVAILABLE, - [], - b"Server is shutting down.\n", - ) - elif isinstance(exc, AbortHandshake): + if isinstance(exc, AbortHandshake): status, headers, body = exc.status, exc.headers, exc.body elif isinstance(exc, InvalidOrigin): logger.debug("Invalid origin", exc_info=True) @@ -157,29 +155,23 @@ def handler(self): yield from self.write_http_response(status, headers, body) yield from self.fail_connection() - return try: yield from self.ws_handler(self, path) - except Exception as exc: - if self._is_server_shutting_down(exc): - if not self.closed: - self.fail_connection(1001) - else: - logger.error("Error in connection handler", exc_info=True) - if not self.closed: - self.fail_connection(1011) + except Exception: + logger.error("Error in connection handler", exc_info=True) + if not self.closed: + self.fail_connection(1011) raise try: yield from self.close() - except ConnectionError as exc: + except ConnectionError: logger.debug("Connection error in closing handshake", exc_info=True) raise - except Exception as exc: - if not self._is_server_shutting_down(exc): - logger.warning("Error in closing handshake", exc_info=True) + except Exception: + logger.warning("Error in closing handshake", exc_info=True) raise except Exception: @@ -196,13 +188,6 @@ def handler(self): # connections before terminating. self.ws_server.unregister(self) - def _is_server_shutting_down(self, exc): - """ - Decide whether an exception means that the server is shutting down. - - """ - return isinstance(exc, asyncio.CancelledError) and self.ws_server.closing - @asyncio.coroutine def read_http_request(self): """ @@ -467,8 +452,14 @@ def handshake( # Hook for customizing request handling, for example checking # authentication or treating some paths as plain HTTP endpoints. - early_response = yield from self.process_request(path, request_headers) + + # Give up immediately and don't attempt to write a HTTP response if + # the TCP connection was closed while process_request() was running. + # This happens if the server shuts down and calls fail_connection(). + if self.state != State.CONNECTING: + raise CancelHandshake() + if early_response is not None: raise AbortHandshake(*early_response) @@ -538,9 +529,15 @@ def __init__(self, loop): # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop - self.closing = False + # Keep track of active connections. self.websockets = set() + # Task responsible for closing the server and terminating connections. + self.close_task = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter = asyncio.Future(loop=loop) + def wrap(self, server): """ Attach to a given :class:`~asyncio.Server`. @@ -574,53 +571,65 @@ def unregister(self, protocol): def close(self): """ - Close the underlying server, and clean up connections. + Close the server and terminate connections with close code 1001. - This calls :meth:`~asyncio.Server.close` on the underlying - :class:`~asyncio.Server` object, closes open connections with - status code 1001, and stops accepting new connections. + This method is idempotent. """ - # Make a note that the server is shutting down. Websocket connections - # check this attribute to decide to send a "going away" close code. - self.closing = True + if self.close_task is None: + self.close_task = asyncio_ensure_future(self._close(), loop=self.loop) - # Stop accepting new connections. - self.server.close() + @asyncio.coroutine + def _close(self): + """ + Implementation of :meth:`close`. - # Close open connections. For each connection, two tasks are running: - # 1. self.transfer_data_task receives incoming WebSocket messages - # 2. self.handler_task runs the opening handshake, the handler provided - # by the user and the closing handshake - # In the general case, cancelling the handler task will cause the - # handler provided by the user to exit with a CancelledError, which - # will then cause the transfer data task to terminate. - for websocket in self.websockets: - websocket.handler_task.cancel() + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. - @asyncio.coroutine - def wait_closed(self): """ - Wait until the underlying server and all connections are closed. + # Stop accepting new connections. + self.server.close() - This calls :meth:`~asyncio.Server.wait_closed` on the underlying - :class:`~asyncio.Server` object and waits until closing handshakes - are complete and all connections are closed. + # Wait until self.server.close() completes. + yield from self.server.wait_closed() - This method must be called after :meth:`close()`. + # Wait until all accepted connections reach connection_made() and call + # register(). See https://bugs.python.org/issue34852 for details. + yield from asyncio.sleep(0) + + # Close open connections. fail_connection() will cancel the transfer + # data task, which is expected to cause the handler task to terminate. + for websocket in self.websockets: + websocket.fail_connection(1001) - """ # asyncio.wait doesn't accept an empty first argument. if self.websockets: - # Either the handler or the connection can terminate first, - # depending on how the client behaves and the server is - # implemented. + # The connection handler can terminate before or after the + # connection closes. Wait until both are done to avoid leaking + # running tasks. + # TODO: it would be nicer to wait only for the connection handler + # and let the handler wait for the connection to close. yield from asyncio.wait( [websocket.handler_task for websocket in self.websockets] + [websocket.close_connection_task for websocket in self.websockets], loop=self.loop, ) - yield from self.server.wait_closed() + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + @asyncio.coroutine + def wait_closed(self): + """ + Wait until the server is closed and all connections are terminated. + + When :meth:`wait_closed()` returns, all TCP connections are closed and + there are no pending tasks left. + + """ + yield from asyncio.shield(self.closed_waiter) @property def sockets(self): @@ -694,11 +703,11 @@ class Serve: delegates to the WebSocket handler. Once the handler completes, the server performs the closing handshake and closes the connection. - When a server is closed with - :meth:`~websockets.server.WebSocketServer.close`, all running WebSocket - handlers are canceled. They may intercept :exc:`~asyncio.CancelledError` - and perform cleanup actions before re-raising that exception. If a handler - started new tasks, it should cancel them as well in that case. + When a server is closed with :meth:`~WebSocketServer.close`, it closes all + connections with close code 1001 (going away). WebSocket handlers — which + are running the coroutine passed in the ``ws_handler`` — will receive a + :exc:`~websockets.exceptions.ConnectionClosed` exception on their current + or next interaction with the WebSocket connection. Since there's no useful way to propagate exceptions triggered in handlers, they're sent to the ``'websockets.server'`` logger instead. Debugging is diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 55fee7340..20a78374a 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -64,11 +64,8 @@ def handler(ws, path): elif path == '/subprotocol': yield from ws.send(repr(ws.subprotocol)) elif path == '/slow_stop': - try: - yield from asyncio.sleep(1000 * MS) - except asyncio.CancelledError: - yield from asyncio.sleep(MS) - raise + yield from ws.wait_closed() + yield from asyncio.sleep(2 * MS) else: yield from ws.send((yield from ws.recv())) @@ -174,6 +171,12 @@ def process_request(self, path, request_headers): return OK, [('X-Access', 'OK')], b'status = green\n' +class SlowServerProtocol(WebSocketServerProtocol): + @asyncio.coroutine + def process_request(self, path, request_headers): + yield from asyncio.sleep(10 * MS) + + class FooClientProtocol(WebSocketClientProtocol): pass @@ -286,16 +289,6 @@ def test_basic(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - def test_server_close_while_client_connected(self): - with self.temp_server(loop=self.loop): - # This endpoint waits just a bit when the connection is canceled - # in order to test that wait_closed() really waits for completion. - self.start_client('/slow_stop') - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.client.recv()) - # Connection ends with 1001 going away. - self.assertEqual(self.client.close_code, 1001) - def test_explicit_event_loop(self): with self.temp_server(loop=self.loop): with self.temp_client(loop=self.loop): @@ -923,18 +916,12 @@ def test_client_closes_connection_before_handshake(self, handshake): # The server should stop properly anyway. It used to hang because the # task handling the connection was waiting for the opening handshake. - @with_server() - @unittest.mock.patch('websockets.server.read_request') - def test_server_shuts_down_during_opening_handshake(self, _read_request): - _read_request.side_effect = asyncio.CancelledError - - self.server.closing = True - with self.assertRaises(InvalidHandshake) as raised: + @with_server(create_protocol=SlowServerProtocol) + def test_server_shuts_down_during_opening_handshake(self): + self.loop.call_later(5 * MS, self.server.close) + with self.assertRaises(InvalidHandshake): self.start_client() - # Opening handshake fails with 503 Service Unavailable - self.assertEqual(str(raised.exception), "Status code not 101: 503") - @with_server() def test_server_shuts_down_during_connection_handling(self): with self.temp_client(): @@ -959,6 +946,22 @@ def test_server_shuts_down_during_connection_close(self, _close): # Websocket connection terminates abnormally. self.assertEqual(self.client.close_code, 1006) + @with_server() + def test_server_shuts_down_waits_until_handlers_terminate(self): + # This handler waits a bit after the connection is closed in order + # to test that wait_closed() really waits for handlers to complete. + self.start_client('/slow_stop') + server_ws = next(iter(self.server.websockets)) + + # Test that the handler task keeps running after close(). + self.server.close() + self.loop.run_until_complete(asyncio.sleep(MS)) + self.assertFalse(server_ws.handler_task.done()) + + # Test that the handler task terminates before wait_closed() returns. + self.loop.run_until_complete(self.server.wait_closed()) + self.assertTrue(server_ws.handler_task.done()) + @with_server(create_protocol=ForbiddenServerProtocol) def test_invalid_status_error_during_client_connect(self): with self.assertRaises(InvalidStatusCode) as raised: From 64575c87c82b34a77a358ccd25f90a537ef570a0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 15:01:34 +0100 Subject: [PATCH 0485/1539] Add support for python setup.py test. Tell setuptools to use the default unittest loader, unittest.TestLoader, rather than its own version, setuptools.command.test.ScanningLoader, as documented here: https://setuptools.readthedocs.io/en/latest/setuptools.html#test-loader Fix #415. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 63401d49d..3a86887aa 100644 --- a/setup.py +++ b/setup.py @@ -67,4 +67,5 @@ include_package_data=True, zip_safe=True, python_requires='>=3.4', + test_loader='unittest:TestLoader', ) From c1b74b61dc07d5a1359307a1ff50510e6026765f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 15:47:22 +0100 Subject: [PATCH 0486/1539] Fix backwards-compatibility recipe. Refs #338. --- docs/changelog.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index f1bfd0b8f..b86faec29 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -32,8 +32,9 @@ Changelog You can restore the previous behavior by adding the following line at the beginning of connection handlers:: - def handler(ws, path): - ws.wait_closed().add_done_callback(asyncio.current_task().cancel) + def handler(websocket, path): + closed = asyncio.ensure_future(websocket.wait_closed()) + closed.add_done_callback(lambda task: task.cancel()) .. warning:: From b64fee8eb849d292ddc0e0246c20bd4196cda030 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 20:39:34 +0100 Subject: [PATCH 0487/1539] Support common customizations without subclassing. This reduces the amount of boilerplate required for defining a custom process_request. Fix #495. --- docs/api.rst | 8 +++---- docs/changelog.rst | 6 ++++++ docs/deployment.rst | 2 +- example/health_check_server.py | 10 ++++----- src/websockets/client.py | 8 +++---- src/websockets/server.py | 39 ++++++++++++++++++++++++++++------ tests/test_client_server.py | 16 ++++++++++++++ 7 files changed, 68 insertions(+), 21 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 39b5922f9..3971ff8b4 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,9 +32,9 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) .. autoclass:: WebSocketServer @@ -43,7 +43,7 @@ Server .. automethod:: wait_closed() .. autoattribute:: sockets - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) .. automethod:: handshake(origins=None, available_extensions=None, available_subprotocols=None, extra_headers=None) .. automethod:: process_request(path, request_headers) @@ -54,7 +54,7 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) diff --git a/docs/changelog.rst b/docs/changelog.rst index b86faec29..2fe4a4dfc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -50,6 +50,12 @@ Changelog if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. +* Added ``process_request`` and ``select_subprotocol`` arguments to + :func:`~server.serve()` and :class:`~server.WebSocketServerProtocol` to + customize :meth:`~server.WebSocketServerProtocol.process_request` and + :meth:`~server.WebSocketServerProtocol.select_subprotocol` without + subclassing :class:`~server.WebSocketServerProtocol` + * Added support for sending fragmented messages. * Added the :meth:`~protocol.WebSocketCommonProtocol.wait_closed` method to diff --git a/docs/deployment.rst b/docs/deployment.rst index 15f722eea..0f571520d 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -78,4 +78,4 @@ the :meth:`~server.WebSocketServerProtocol.process_request()` hook. Typical use cases include health checks. Here's an example: .. literalinclude:: ../example/health_check_server.py - :emphasize-lines: 9-13,19-20 + :emphasize-lines: 9-11,17-18 diff --git a/example/health_check_server.py b/example/health_check_server.py index 89fd1e2ff..8e70890b5 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -6,18 +6,16 @@ import http import websockets -class ServerProtocol(websockets.WebSocketServerProtocol): - - async def process_request(self, path, request_headers): - if path == '/health/': - return http.HTTPStatus.OK, [], b'OK\n' +def health_check(path, request_headers): + if path == '/health/': + return http.HTTPStatus.OK, [], b'OK\n' async def echo(websocket, path): async for message in websocket: await websocket.send(message) start_server = websockets.serve( - echo, 'localhost', 8765, create_protocol=ServerProtocol) + echo, 'localhost', 8765, process_request=health_check) asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() diff --git a/src/websockets/client.py b/src/websockets/client.py index b1b14bc62..bd2544862 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -331,6 +331,9 @@ class Connect: :func:`connect` also accepts the following optional arguments: + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression * ``origin`` sets the Origin HTTP header * ``extensions`` is a list of supported extensions in order of decreasing preference @@ -340,9 +343,6 @@ class Connect: :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` pairs - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening @@ -366,11 +366,11 @@ def __init__( legacy_recv=False, klass=WebSocketClientProtocol, timeout=10, + compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, - compression='deflate', **kwds ): if loop is None: diff --git a/src/websockets/server.py b/src/websockets/server.py index 857dde5de..2b6ab0529 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -64,6 +64,8 @@ def __init__( extensions=None, subprotocols=None, extra_headers=None, + process_request=None, + select_subprotocol=None, **kwds ): # For backwards-compatibility with 6.0 or earlier. @@ -76,6 +78,10 @@ def __init__( self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers + if process_request is not None: + self.process_request = process_request + if select_subprotocol is not None: + self.select_subprotocol = select_subprotocol super().__init__(**kwds) def connection_made(self, transport): @@ -260,6 +266,10 @@ def process_request(self, path, request_headers): It is declared as a coroutine because such authentication checks are likely to require network requests. + This coroutine may be overridden by passing a ``process_request`` + argument to the :class:`WebSocketServerProtocol` contstructor or the + :func:`serve` function. + """ @staticmethod @@ -407,6 +417,10 @@ def select_subprotocol(client_subprotocols, server_subprotocols): many servers providing a subprotocol will require that the client uses that subprotocol. Such rules can be implemented in a subclass. + This method may be overridden by passing a ``select_subprotocol`` + argument to the :class:`WebSocketServerProtocol` contstructor or the + :func:`serve` function. + """ subprotocols = set(client_subprotocols) & set(server_subprotocols) if not subprotocols: @@ -452,7 +466,10 @@ def handshake( # Hook for customizing request handling, for example checking # authentication or treating some paths as plain HTTP endpoints. - early_response = yield from self.process_request(path, request_headers) + if asyncio.iscoroutinefunction(self.process_request): + early_response = yield from self.process_request(path, request_headers) + else: + early_response = self.process_request(path, request_headers) # Give up immediately and don't attempt to write a HTTP response if # the TCP connection was closed while process_request() was running. @@ -683,6 +700,9 @@ class Serve: :func:`serve` also accepts the following optional arguments: + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression * ``origins`` defines acceptable Origin HTTP headers — include ``None`` if the lack of an origin is acceptable * ``extensions`` is a list of supported extensions in order of @@ -693,10 +713,13 @@ class Serve: :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` pairs, or a callable taking the request path and headers in arguments - and returning one of the above. - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression + and returning one of the above + * ``process_request`` is a callable or a coroutine taking the request path + and headers in argument, see + :meth:`~WebSocketServerProtocol.process_request` for details + * ``select_subprotocol`` is a callable taking the subprotocols offered by + the client and available on the server in argument, see + :meth:`~WebSocketServerProtocol.select_subprotocol` for details Whenever a client connects, the server accepts the connection, creates a :class:`WebSocketServerProtocol`, performs the opening handshake, and @@ -739,11 +762,13 @@ def __init__( legacy_recv=False, klass=WebSocketServerProtocol, timeout=10, + compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, - compression='deflate', + process_request=None, + select_subprotocol=None, **kwds ): # Backwards-compatibility: close_timeout used to be called timeout. @@ -793,6 +818,8 @@ def __init__( extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, + process_request=process_request, + select_subprotocol=select_subprotocol, ) if path is None: diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 20a78374a..dee44a662 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -362,6 +362,22 @@ def test_unix_socket(self): client_socket.close() self.stop_server() + @with_server(process_request=lambda p, rh: (OK, [], b'OK\n')) + def test_process_request_argument(self): + response = self.loop.run_until_complete(self.make_http_request('/')) + + with contextlib.closing(response): + self.assertEqual(response.code, 200) + + @with_server( + subprotocols=['superchat', 'chat'], select_subprotocol=lambda cs, ss: 'chat' + ) + @with_client('/subprotocol', subprotocols=['superchat', 'chat']) + def test_select_subprotocol_argument(self): + server_subprotocol = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_subprotocol, repr('chat')) + self.assertEqual(self.client.subprotocol, 'chat') + @with_server() @with_client('/attributes') def test_protocol_attributes(self): From 324d436648c816b06b42cf68702c8647b7660c57 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 20:40:45 +0100 Subject: [PATCH 0488/1539] Show coverage info when running make coverage. --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 7de30b002..2d77dcfc7 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,7 @@ coverage: python -m coverage erase python -W default -m coverage run -m unittest python -m coverage html + python -m coverage report --show-missing --fail-under=100 clean: find . -name '*.pyc' -o -name '*.so' -delete From 21d3bed6bcdd4a0374e28dfee0887584ec64e12c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Oct 2018 20:48:19 +0100 Subject: [PATCH 0489/1539] Simplify fail_connection. Wait for the connection to terminate separately. Avoid the weird "yield from / await function that returns a Future" pattern. Fix #498. --- src/websockets/client.py | 3 ++- src/websockets/protocol.py | 5 ----- src/websockets/py35/client.py | 3 ++- src/websockets/server.py | 6 ++++-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index bd2544862..9f92f18e8 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -452,7 +452,8 @@ def __iter__(self): # pragma: no cover extra_headers=protocol.extra_headers, ) except Exception: - yield from protocol.fail_connection() + protocol.fail_connection() + yield from protocol.wait_closed() raise self.ws_client = protocol diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 6147eff1e..5f7f3b0cd 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -1068,9 +1068,6 @@ def fail_connection(self, code=1006, reason=''): (The specification describes these steps in the opposite order.) - Return a :class:`~asyncio.Task` that completes when the TCP connection - is closed. - """ logger.debug( "%s ! failing WebSocket connection in the %s state: %d %s", @@ -1117,8 +1114,6 @@ def fail_connection(self, code=1006, reason=''): self.close_connection(), loop=self.loop ) - return self.close_connection_task - def abort_keepalive_pings(self): """ Raise ConnectionClosed in pending keepalive pings. diff --git a/src/websockets/py35/client.py b/src/websockets/py35/client.py index f62e7d69e..a016ba437 100644 --- a/src/websockets/py35/client.py +++ b/src/websockets/py35/client.py @@ -20,7 +20,8 @@ async def __await_impl__(self): extra_headers=protocol.extra_headers, ) except Exception: - await protocol.fail_connection() + protocol.fail_connection() + await protocol.wait_closed() raise self.ws_client = protocol diff --git a/src/websockets/server.py b/src/websockets/server.py index 2b6ab0529..be9abaad6 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -120,7 +120,8 @@ def handler(self): logger.debug("Connection error in opening handshake", exc_info=True) raise except CancelHandshake: - yield from self.fail_connection() + self.fail_connection() + yield from self.wait_closed() return except Exception as exc: if isinstance(exc, AbortHandshake): @@ -160,7 +161,8 @@ def handler(self): headers.setdefault('Connection', 'close') yield from self.write_http_response(status, headers, body) - yield from self.fail_connection() + self.fail_connection() + yield from self.wait_closed() return try: From b870638044c8b4e9c5944b05142014258654f79c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 30 Oct 2018 13:05:14 +0100 Subject: [PATCH 0490/1539] Rewrite recv() to avoid a data loss bug. A race condition between receiving a message and cancelling recv() could result in losing the message. Fix #486 and #470. --- docs/changelog.rst | 11 ++++ docs/design.rst | 7 +-- docs/protocol.graffle | Bin 4656 -> 4664 bytes docs/protocol.svg | 2 +- src/websockets/protocol.py | 104 ++++++++++++++++++++++++------------- tests/test_protocol.py | 63 ++++++++++++++++++++++ 6 files changed, 144 insertions(+), 43 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 2fe4a4dfc..18f5f0764 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -46,6 +46,14 @@ Changelog :exc:`~asyncio.CancelledError`. Now ``await ping`` raises :exc:`~exceptions.ConnectionClosed` like other public APIs. +.. warning:: + + **Version 7.0 raises a** :exc:`RuntimeError` **exception if two coroutines + call** :meth:`~protocol.WebSocketCommonProtocol.recv` **concurrently.** + + Concurrent calls lead to non-deterministic behavior because there are no + guarantees about which coroutine will receive which message. + * websockets sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. @@ -66,6 +74,9 @@ Changelog * Changed the ``origins`` argument to represent the lack of an origin with ``None`` rather than ``''``. +* Fixed a data loss bug in :meth:`~protocol.WebSocketCommonProtocol.recv`: + cancelling it at the wrong time could result in messages being dropped. + * Improved handling of multiple HTTP headers with the same name. * Improved error messages when a required HTTP header is missing. diff --git a/docs/design.rst b/docs/design.rst index 244d233e1..349c63b31 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -481,7 +481,7 @@ For each connection, the receiving side contains these buffers: - :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64kB. You can set another limit by passing a ``read_limit`` keyword argument to :func:`~client.connect()` or :func:`~server.serve()`. -- Incoming messages :class:`~asyncio.queues.Queue`: its size depends both on +- Incoming messages :class:`~collections.deque`: its size depends both on the size and the number of messages it contains. By default the maximum UTF-8 encoded size is 1MB and the maximum number is 32. In the worst case, after UTF-8 decoding, a single message could take up to 4MB of memory and @@ -511,11 +511,6 @@ As shown above, receiving frames is independent from sending frames. That isolates :meth:`~protocol.WebSocketCommonProtocol.recv()`, which receives frames, from the other methods, which send frames. -While :meth:`~protocol.WebSocketCommonProtocol.recv()` supports being called -multiple times concurrently, this is unlikely to be useful: when multiple -callers are waiting for the next message, exactly one of them will get it, but -there is no guarantee about which one. - Methods that send frames also support concurrent calls. While the connection is open, each frame is sent with a single write. Combined with the concurrency model of :mod:`asyncio`, this enforces serialization. After the connection is diff --git a/docs/protocol.graffle b/docs/protocol.graffle index 98b4cdb581d83dbd8aa3041d77ba003c72a71a96..13fdb307ef5907b16265b4fcb669180e509c71cc 100644 GIT binary patch literal 4664 zcmV-8636WyiwFP!000030PS7tbJIE&|9tr?vV7d7uHLp-E<2@#g}Xqzq`+RLGk4rX zO?`3fU^|pvhW~w!oR=SRvoAN=ya~Bal2=h*eS5O4zlDoIl&OJAo}wq)vyZbT zaPOe!XVB$`Nzd~}evq@uq(F^@ zbO?pv6ITd93M5aqXI)Q*FwWwg{-5JCo8LN1#%g3$N^I?4djvjT;k4K&Fm}2&*1H+O(ev}5RTkM8@T5sHq!(egON|RT63;NltlHvqU z$hz5is1}s7F;6Z+m+|Mt)xzb;^}WPDTMTP4u($ig;lPZ}r(*VL@nJHDsrnSh!xk%I z*N?vX>3$Lng2h^xO=doq#aFp#(GC7mHTuGw*ILElco^-3!62%Q3fD)KQ+G#x54xDs zC&~|>Qtrt7hcL+r_oRHXonM$Wg+KKA-t)8C013kG7Lr?tbufO+@M9`E|3L6l8?Ejn zPO9y9&W9=qdYzr|0M@UbuyV7#ARUGNdAH|>wR-~t!%wF-?EKo9f39}$A_&#t`KZ>| z2d1gg<_mQ%?u~~k%BtKvpITCz+vx&NHQ*Syzr12AY*yei#;L;@rHC>3pG)dJ>M##G zo+n*_q>GV|lxa@%xPz#qLUNB{B0Y~j>JZPNOfW7uaa|;(z-2BqvFUUUgl&{2||;3GdugZ#hVXn+YPN4>Zg`}z_` zZ1P7sK8C+NiXMDFI)Rm&9fwCJ=;&k^_f--do%BxqBvl#el-KV+e*k?i4&wwG{PQSL zeZu+CAWS?nll+(P^e;kK{>wx9 z7o|A~l{pDxISGs(OJyh zcY-=6i9b~74+B2}>SdtLgr*n4j9gjx^fN%mspoKy2$vxum?w#_KxYCvW`WKEodvp< zKzDqeT{+G*AkFI_QkRc2FID(e>5;dF7_u_NPy@v5C)sH{h$BCI<|o?V9D+&EE3}?j zoF@8%I4C~Md>K=Q9L!zjdXy55LCbyAseZ?)0o*ISr$JWG)Mrn=g^N#~Go2ns@m0-r z^x~id2qG$hqxQA}(_b$KuVEz- zk2;cZ0*gWtLMVkLL8Wj!Ee$aBTnT@9+DdYT!zmb4L=fg-A(7a0D+wF8SYy(v@F~p7 zdrjk*F9pyI5b&>%6m$lTB&q52@8T@?rd9+_letiRUTf=FTW>>by)+B~_+b**%64B5 zoAH7H5V|f4&=~Tls69elvjqst-ND$y$a5Kt2!^qIP|bQEBlRc_Z*1Kr z22^crZMYk@;Z_hUYnX6^Z@g8;<+@f`J*)sqG}v0&nZpqho&X)0H*q5u!jS}&ZNxdl z64dbJtt~g+9oFJx74HbOeCS;kAL^_Yq7$5exu}aUC5#EW3Lh?OZTavnxhb(_#FZG) zy-b1FKD=RM#3ijQBi`G1!A3EnUKy)SCuJ9V^UTaGuFL#3u`?44tWE5BF5=nEGCiL1rc<}i{e zpMhdn(ikIxJ;Xc&ay!rOcaei!&!rLzK|rLFI}k?NdKQQM_OYZgCbdFiQ$+ZUKZw*) zBtrmPO;`cfwi++u$}WN8wXOQaRY7D~kZ4zX_e&g7pqP*x7cb9M6NV-~V*0xj{hGEM z7+PWjU1Ge0H|ZE-Ir%`#(cdDZ-L3%E6-T6QCVDhwzS|o@`qFEP6&>qn5t~Mf!Z=kw zmTd@HA9A%GP!B5HPROnDer+Ki8wIVyOrmzM0&UWYw#IpvnrX0xbih_RSm|(8(qR>2 z{}zgeyD1qKa$yulgDVsZS`)|8b6B1(D0DiPbg2~!_dzTmD;BI+uwvnUiiO~6#e&}O z&#>cCLOdd<$C>o3f2Ij)nXpi6q1O6m?y-L+RZ)Lsz`b`Nf4*+WpGn~$L8K=MWs)!o za)S#;l*43=K`le;YW#rNjY)OjK=`J2Cba>t&EX%mHviPZ$1CxVp>@?6?0J@f?%mG5 zd&fX@Z3Ysq<6^-nagj?fqhcinGPJI`%{(iCECb!~T(Q+sM93P6Wo+$j6voV^_8e?z zUH5RLdg5HRbU|Ox*j>&`mL2Y5&L6)NK)3R7o#|XB@tV2oBwiQR z=L#F2iNlNaUUI7tNyin`bv=Lq#A$6K#Z_lsSE?;a!eLNEIJ%}32b)#GKesXaxmv3c zljV8a7S}c;{duMRUHB5ypR21f-%VhnYpo%X+vu8&uG#3?J&vv=s`qu1qHAmEp7_dr zGvwNIPa<7MtC^B=h9qZP?+5}>r#Jn9f~Q@q+~v2R=njm}p%l4Da0vruZsWeaYHR&S z%QRQ3du~YmoEzJEhnWjX0p>2iw!kiTi>jw|1ryxMbx+0}UB8iP><8e-xsqCQ6Kzld zM|KCdRm=)IRsm6>Dj?=>r~cA*&%*$8gQWhx?wUrsOjB*vak15C8*4_gORE_vT)d8= zdjy@{$JE}UHat{m>JLrZqYLJF}uiM_S_i!IL_Y^o|&t>U%(35)YHjq3}7gme*>&{}ebxSV?~^1w!h zr)q3%D2FDHbimB;FsF>`uCYbfO>}KiW^r~kvDFSbudx7QHSe#++zgAGOar2TI_`4B zX?P@pP@e4H8tpKq4&XfC9mdF`(pw2>joRquoStB@@*<6k+*AsZsqRa?ET~OEYC>aX zQ;=>;3X%=a*o32>4zZjiL8ku5Pv*MBGAgER6$W>+JyhO+u+$zZZx0%@hsw7+T(IwF z{xzK}xEkCU9PbwH3~+}n+*!D@aMu#DD>cdzk2+;U1xgPf z8Z!yBLO&j(w*X~7sHE>l{?#EXHb#INc4de0xPR$evR}aU#ymS}%(H`e?guK$R>#Qs zsab-#h|7w<(+G}jnk1_YV{pE4Ci$in9C6-oa(dKYTy0XSU8S? zuzxxy<{R|pLAO)v&=qw9JQ z28=J6Ht&g`4Dn($>q&}-V5Vx;jBJuI$80rgTg}>5v%ZxwYmDyGgu14 zkNN0qjq5eemjDrrz;+}pMY1b5UL+GMlC4O#BH4=MTaXiadqwiQd~{YMw-(8jT?-y_ zjXN1~S8lngCB3cH3a&gGmtTIBf?Irs=f!uhX|gs=_BQ!I{n01{nAf1%3FAS~yTqLf zCakR2aB0FyLsj4nz)Hqe2TKbZ*9;r=c}6Zhj9z~#nb0cH4geJeXT{)yrGc|3)CNiG&o)WUc+tjM5Y6M z!04RNh@lgx9ReY!1MqE9BgZxJ|0p#&jJrL*Fp7De`6S2OKUc}(&U62AMKtOMU*;Q` zP`>qt20wFtn+~p%A0);N5ylj*F4!F^Z8wY$fjV2-ZPNmW(Z=Jfd5P~t!O+iCZ2?t$ zMfz$lNP6*`AleJktile(W2+XlT?Oo0c3Bx*XgPhD5A%662%`E}`f*kpOr-oQ*1HN` zt7uSw=GhHq4mPl|kUn9HRc@HKUtxgRSv(3_(U+MGhWNROd+N+KjnS+?p91%EPgNnX zGg+egT>n-@>!qe_uTO*Xktund#A>a#8;WU``|@j5c&X z`DN@s^0z{jW#x!r!);2Igy}ZZc_|Wzu2@I(0q_?P8dycRM{DuWF&)<&Cuh;72{RV%zso z@MWwP7WM^C|2|dEqbx`+n)VNE70r)91V&^tl<$s{1ZE5dG@|jmEQoeH(_IuPRA{IQsQ5cF^OF zih6NBhz7qt{&4hi%X|E<=Z~KK&)&iA(I1EVolzL1S?BP>i#M-#JCC=vw|7RPP;GDT z9qn}v-@NV~b)b&z?frL;JCDz@Z1i+{`~3Xe@impxi-(#p-9Aj>kxH_QH&EFY6ma@k z|1mT(slC)4__ZJOvgeO}dG zBX<4htDhbu!5~=7h1p={V_AHZixS=7Z&jl&jCrk89FB+4P8bZLTCZ@uS2=Ze;uzOY4eG?7x%_P6=hX!o)0aljqP-Try6h!++Uut6*epI8ROL9j8eoH{O6K-Pdd!Q zj^{~NAn9TxBxRZtJ?Da0(=&ghOpH1m z#5l@cN2hT!pJ@sMo9)KqBv8pab$-}=dsM1;%G`xXuYb$u9Y*sBSY*u_YAha#5GC9& z0Y|ONr$HDNn_gTtLq`*K6ni)Yr8B=DpI=km6~Oz6x+mnHf_`>ZBA+}1hZi0~=_ynm z$dUS%HIH%$pNfwRm7myz`q1y!W`kGiN4>=b2OTHbDL(e2G|2z;MgvSZIqt>1*w>dh zVv{fF_yoRt6g~QWd|2X@`1%m$=ll)D1`i&5lzj;W% zQJRxbnUgSW^c zi=7f6h_ujQo0{L~NA z=~SH&0jm~m#1edy3VctI*$=|ideR_%kmI(NKhS3LlPU7@qW9b=eag_Ypsp$x8wh^H z0DkRoqsHaSxo=UL5$*t$$6SfIV3OgD;@vJh)i2Ylb=J{EOCF#`?QI37zgZ4m!%89^ zbtK~iCWRz~PzqCmO5u1~8erFyxlp}dYwKBCZ$oUoGzpc>dT>e*x+x3Lxa3h$dxW@Q3lOHegRzH^=Q8LKbYpp^n)N_N>Z>@sv2>d? zplWGr!`-qCw}Mz%!-OMzViIGHvA?C_e5Xulld4=Zp<)xv^4xxfzmr2ic<)*ITY`786#P#B8^D>et zo`GVR(ikIxJ;Xc&ayyUjcaei!&!rLzK|rLFd$5eO^&$@Y?bnjZnA8f5O)J80{XwLb zA{heUYQhS*w$*qMS9S>$uWi*Yt_mW{7KwJXcfZ6T1&RsDN%8VrHDPG-64Uq6>esa7 zz)%t!=n`WK-lSrT<>Uh`ul^Px?QSh#UGs|6?W`V6neXl{A^p{BS}Quyt3_@K!rg@2D(}}O@`+K>wH9ivf93)EXHpgQR|eb%x8%>)UGisAI7krbNkW+< zjDpcD~UZSPEK174fMKWuIOsVyI`#6O18RcEm0Sq6G= zJNq6S1JSh^NVtxR1*gPCF2Rh7l^DoSy6QIbtOT+QbkAeORtpg!YXp|DwYO0iBb(ZD zu%UF_!(rp1tVVdF0*$S`o1yB7bJ@}feMw<=IWJjuxQ{V^{7L}b%ExtvbDh9z=B^WX zU6`LMY8Qh#3eO`~0=sW$7l*lM(mIV0H}$r(A* zdyk=whZyNwRE~!#P5psNk8kQogQjlbyUjl`{C<0p_=^_I&FYpoKqX46WvQhAlDUWSC!ZI``T zoNZuL^aqfy(k^@b9Fy~PnEXnYwyW_Q%?UG&r7K4p2xgi5{xdnXOn!w-j@fnl7pZ%j zNyX~Ft?OQN9YP9AcM^MRxf@#^LD-aJteVto`5hL=XS&QU2olmoSVC#Z9pZBCxyS>P z8SbjFv>`j1K+*v-!^501uB*ltVK>p!Ntwmj)x=gi=)A#vjMXr}8gnx&ZZj%~0_wQS z5vSpn2ts-6e`~bEm^y&-fOi-pk4kSPq%~@zn{$|g#mdVxE^=EDNv1R~^(>(_BB=?D znT<%gD-lU{wZ_ID{qz#cc@kvm&-|dSE6l25+GSyIH`_Dj4G2r^8T0nsL3_r0%d-ai ze&*lML4&Koox$;L;m!be*utHKI}3L$;qER1UHyCy3v>pcLl)>P&{?2s33MkH*)8In zY2RY?Xkk}aH(?Xl5t|m&x(ThD@F8te^!tBK!?CZ>PiVy1dQhvy*(RK|8^kTnTAY1g zIIDe{SH2`;g4${vp4DiZk(OGdwMc7`_F*CIO{{whmpK|b1qXy>pw3E-vc#iK8Bu}K z1Bk{<0cv<%GUh!alr(QLKferzlp zM?qM?;weMk7f)}SGo3(*I**et8-L0OAvYE4P6>CU#|Q^4%@ZJpc#DlJ*B$F#jS!Q3 zl~{MvExBTn+l2_iJeb>di}QxYo3rbUH{Y}kljOQLOcH!kNNCE(cI!xJ>eMnUyF4G8 z-KF@DkZ=vt)U*|>YJ_@u=*}W_A0b${QfE%9{wCxGh3;T!eSd9ze|lRu!Z1GosKAo( zaQmcXElMV9l-y8bcQ;NVmAv=+!8omh%*3G&sxp;evw0kg`e{WE+IxAw)2FpHv5OAR zEP>%XRmo&|Df81W)5K<-TTibxkNP@-UNw0E7PICs^m}TkO6r_BW^WBSS&iFg0JhX4 zj?m(dVi*h$)G;k5P5Z?SN69=$Ir-{k7#K3e{juMz_uDK}-FqlgAA07{O<90zbu!+4 zPZatP70GuXCiL!#U;ffoEVGGUA7e@pG^1oJk_NZ zHz$j)W~s-{69*Kj^b|5FZ5y_c0LI?_b*hkXnF2#u82l``C($*A7PB>>Vg@f(steW7^w55)izCV=xsdCny2_q6b$`L)h1BIb)>KMf}|I} z4Whju%_{6rJjH53+f~56We=9Yqn6W$c{lf?K@ipZ(vP!RXCmcivEF3xMn!`HG|w8C z`P;zgLi)fhR%tLVzd{GI^LP}rqAxQWbn#0Q_tcqfx@5BgeG1&uJyV5sp2@_YbH&@Z zADm8Y-o3(;o>fW^cDIn+0zk#`DV0wN?fgTJT&=CVq5`i@FYjCoRTA_%JLADP&4Q>} z?j7KB`#U;F^rrCA7@N0^%&1|;w60u3(+hY^U(yz|nwy*X=3VzioMrLwt)C2n!gMGs za!ujt+z)(zk{eU@9GWpz}D&;8+x&m{oQ}{TlmK( zjMV?V2tU7m74NX$U;O^MyMMNGcJKjleDdkHus7I+@BM=h`19GxtB>IhT;JX2M@Qth zzdnC@hjw0n+TZE^1vD?eh3C8fzP$Ty)M@rJ!=pjvvic9foJdL-ZK!_oW~@K*w?dU= zuuNh z{Ey<%byQ!*@BBd=L`p+PgWjQ6l@6ltVmYR~kIze~axYPLfV84IP9zOoXqr=%5wQO# z$m;O7>qlSxw4|s^N(@1#&by&q3={Qkr>B>%T4;B9VXXl8QBSQ{_Cpl>HC77?`;xo= zkgEMC3zEx*{bO52^HUIk5!npoyW=E*k;q4~uBRK+^!=$x^yg%V3wqo__D($(NeP}7 z(WeYhfbe{?Aod7+!P+0^jPx@Hvi$RST@rk|z1+gXx+Y m;lL33VqlP847HnlFGVn_0mtHf6auF_fAs$g=`ROJ@Bjcs1u7i? diff --git a/docs/protocol.svg b/docs/protocol.svg index 7108927b8..301bb1b4c 100644 --- a/docs/protocol.svg +++ b/docs/protocol.svg @@ -1,3 +1,3 @@ - Produced by OmniGraffle 6.6.2 2017-09-24 19:39:13 +0000Canvas 1Layer 1remote endpointwebsocketsWebSocketCommonProtocolapplication logicreaderStreamReaderwriterStreamWriterpingsdicttransfer_data_taskTasknetworkread_frameread_data_frameread_messagebytesframesdataframeswrite_framemessagesQueuerecvsendpingpongclosecontrolframesbytesframes + Produced by OmniGraffle 6.6.2 2017-09-24 19:39:13 +0000Canvas 1Layer 1remote endpointwebsocketsWebSocketCommonProtocolapplication logicreaderStreamReaderwriterStreamWriterpingsdicttransfer_data_taskTasknetworkread_frameread_data_frameread_messagebytesframesdataframeswrite_framemessagesdequerecvsendpingpongclosecontrolframesbytesframes diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 5f7f3b0cd..96d9b1d03 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -7,7 +7,6 @@ """ import asyncio -import asyncio.queues import binascii import codecs import collections @@ -246,7 +245,9 @@ def __init__( self.connection_lost_waiter = asyncio.Future(loop=loop) # Queue of received messages. - self.messages = asyncio.queues.Queue(max_queue, loop=loop) + self.messages = collections.deque() + self._pop_message_waiter = None + self._put_message_waiter = None # Mapping of ping IDs to waiters, in chronological order. self.pings = collections.OrderedDict() @@ -387,42 +388,57 @@ def recv(self): makes it possible to enforce a timeout by wrapping :meth:`recv` in :func:`~asyncio.wait_for`. + .. versionchanged:: 7.0 + + Calling :meth:`recv` concurrently raises :exc:`RuntimeError`. + """ - # Don't yield from self.ensure_open() here because messages could be - # available in the queue even if the connection is closed. + if self._pop_message_waiter is not None: + raise RuntimeError( + "cannot call recv() while another coroutine " + "is already waiting for the next message" + ) - # Return any available message - try: - return self.messages.get_nowait() - except asyncio.queues.QueueEmpty: - pass + # Don't yield from self.ensure_open() here: + # - messages could be available in the queue even if the connection + # is closed; + # - messages could be received before the closing frame even if the + # connection is closing. + + # Wait until there's a message in the queue (if necessary) or the + # connection is closed. + while len(self.messages) <= 0: + pop_message_waiter = asyncio.Future(loop=self.loop) + self._pop_message_waiter = pop_message_waiter + try: + # If asyncio.wait() is canceled, it doesn't cancel + # pop_message_waiter and self.transfer_data_task. + yield from asyncio.wait( + [pop_message_waiter, self.transfer_data_task], + loop=self.loop, + return_when=asyncio.FIRST_COMPLETED, + ) + if pop_message_waiter.done(): + pass + elif self.legacy_recv: + return + else: + assert self.state in [State.CLOSING, State.CLOSED] + # Wait until the connection is closed to raise + # ConnectionClosed with the correct code and reason. + yield from self.ensure_open() + finally: + self._pop_message_waiter = None - # Don't yield from self.ensure_open() here because messages could be - # received before the closing frame even if the connection is closing. + # Pop a message from the queue. + message = self.messages.popleft() - # Wait for a message until the connection is closed. - next_message = asyncio_ensure_future(self.messages.get(), loop=self.loop) - # See https://bugs.python.org/issue23859 for cancellation handling. - try: - done, pending = yield from asyncio.wait( - [next_message, self.transfer_data_task], - loop=self.loop, - return_when=asyncio.FIRST_COMPLETED, - ) - except asyncio.CancelledError: - # Propagate cancellation to avoid leaking the next_message Task. - next_message.cancel() - raise + # Notify transfer_data(). + if self._put_message_waiter is not None: + self._put_message_waiter.set_result(None) + self._put_message_waiter = None - if next_message in done: - return next_message.result() - else: - next_message.cancel() - if not self.legacy_recv: - assert self.state in [State.CLOSING, State.CLOSED] - # Wait until the connection is closed to raise - # ConnectionClosed with the correct code and reason. - yield from self.ensure_open() + return message @asyncio.coroutine def send(self, data): @@ -651,11 +667,27 @@ def transfer_data(self): """ try: while True: - msg = yield from self.read_message() + message = yield from self.read_message() + # Exit the loop when receiving a close frame. - if msg is None: + if message is None: break - yield from self.messages.put(msg) + + # Wait until there's room in the queue (if necessary). + while len(self.messages) >= self.max_queue: + self._put_message_waiter = asyncio.Future(loop=self.loop) + try: + yield from self._put_message_waiter + finally: + self._put_message_waiter = None + + # Put the message in the queue. + self.messages.append(message) + + # Notify recv(). + if self._pop_message_waiter is not None: + self._pop_message_waiter.set_result(None) + self._pop_message_waiter = None except asyncio.CancelledError as exc: self.transfer_data_exc = exc diff --git a/tests/test_protocol.py b/tests/test_protocol.py index cceb63e85..aee3289ea 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -442,6 +442,43 @@ def test_recv_binary_no_max_size(self): data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, b'tea' * 342) + def test_recv_queue_empty(self): + recv = self.ensure_future(self.protocol.recv()) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete( + asyncio.wait_for(asyncio.shield(recv), timeout=MS) + ) + + self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + data = self.loop.run_until_complete(recv) + self.assertEqual(data, 'café') + + def test_recv_queue_full(self): + self.protocol.max_queue = 2 + # Test internals because it's hard to verify buffers from the outside. + self.assertEqual(list(self.protocol.messages), []) + + self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), ['café']) + + self.receive_frame(Frame(True, OP_BINARY, b'tea')) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), ['café', b'tea']) + + self.receive_frame(Frame(True, OP_BINARY, b'milk')) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), ['café', b'tea']) + + self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(list(self.protocol.messages), [b'tea', b'milk']) + + self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(list(self.protocol.messages), [b'milk']) + + self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(list(self.protocol.messages), []) + def test_recv_other_error(self): @asyncio.coroutine def read_message(): @@ -454,6 +491,7 @@ def read_message(): def test_recv_canceled(self): recv = self.ensure_future(self.protocol.recv()) self.loop.call_soon(recv.cancel) + with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(recv) @@ -462,6 +500,31 @@ def test_recv_canceled(self): data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, 'café') + def test_recv_canceled_race_condition(self): + recv = self.ensure_future( + asyncio.wait_for(self.protocol.recv(), timeout=0.000001) + ) + self.loop.call_soon( + self.receive_frame, Frame(True, OP_TEXT, 'café'.encode('utf-8')) + ) + + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(recv) + + # The previous frame doesn't disappear in a vacuum (it used to). + self.receive_frame(Frame(True, OP_TEXT, 'tea'.encode('utf-8'))) + data = self.loop.run_until_complete(self.protocol.recv()) + # If we're getting "tea" there, it means "café" was swallowed (ha, ha). + self.assertEqual(data, 'café') + + def test_recv_prevents_concurrent_calls(self): + recv = self.ensure_future(self.protocol.recv()) + + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(self.protocol.recv()) + + recv.cancel() + # Test the send coroutine. def test_send_text(self): From 2552297b052a50b49b4e74eb1a17322a72860beb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 31 Oct 2018 21:16:05 +0100 Subject: [PATCH 0491/1539] Raise ConnectionClose consistently. --- src/websockets/protocol.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 96d9b1d03..b5145e27b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -632,7 +632,9 @@ def ensure_open(self): # from OPEN to CLOSED. if self.transfer_data_task.done(): yield from asyncio.shield(self.close_connection_task) - raise ConnectionClosed(self.close_code, self.close_reason) + raise ConnectionClosed( + self.close_code, self.close_reason + ) from self.transfer_data_exc else: return From 001b51df4a27051b8d226f170c96daeac555f3b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 31 Oct 2018 21:17:18 +0100 Subject: [PATCH 0492/1539] Remove self._pop_messsage_waiter as soon as possible. I don't think this changes the behavior, however the new version is easier to reason about and mirrors transfer_task more clearly. --- src/websockets/protocol.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index b5145e27b..3e4ced21d 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -418,17 +418,20 @@ def recv(self): loop=self.loop, return_when=asyncio.FIRST_COMPLETED, ) - if pop_message_waiter.done(): - pass - elif self.legacy_recv: + finally: + self._pop_message_waiter = None + + # If asyncio.wait(...) exited because self.transfer_data_task + # completed before receiving a new message, raise a suitable + # exception (or return None if legacy_recv is enabled). + if not pop_message_waiter.done(): + if self.legacy_recv: return else: assert self.state in [State.CLOSING, State.CLOSED] # Wait until the connection is closed to raise # ConnectionClosed with the correct code and reason. yield from self.ensure_open() - finally: - self._pop_message_waiter = None # Pop a message from the queue. message = self.messages.popleft() From dc2722bcf5e275f44d6fbd3f36da529642c697c0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 31 Oct 2018 22:33:15 +0100 Subject: [PATCH 0493/1539] Document why websockets never cancels handlers. --- docs/design.rst | 42 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/docs/design.rst b/docs/design.rst index 349c63b31..83c8ce07a 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -364,6 +364,44 @@ easiest way to make :class:`~websockets.server.WebSocketServer.close` and Cancellation ------------ +User code +......... + +``websockets`` provides a WebSocket application server. It manages connections +and passes them to user-provided connection handlers. This is an *inversion of +control* scenario: library code calls user code. + +If a connection drops, the corresponding handler should terminate. If the +server shuts down, all connection handlers must terminate. Canceling +connection handlers would terminate them. + +However, using cancellation for this purpose would require all connection +handlers to handle it properly. For example, if a connection handler starts +some tasks, it should catch :exc:`~asyncio.CancelledError`, terminate or +cancel these tasks, and then re-raise the exception. + +Cancellation is tricky in :mod:`asyncio` applications, especially when it +interacts with finalization logic. In the example above, what if a handler +gets interrupted with :exc:`~asyncio.CancelledError` while it's finalizing +the tasks it started, after detecting that the connection dropped? + +``websockets`` considers that cancellation may only be triggered by the caller +of a coroutine when it doesn't care about the results of that coroutine +anymore. (Source: `Guido van Rossum `_). Since connection handlers run +arbitrary user code, ``websockets`` has no way of deciding whether that code +is still doing something worth caring about. + +For these reasons, ``websockets`` never cancels connection handlers. Instead +it expects them to detect when the connection is closed, execute finalization +logic if needed, and exit. + +Conversely, cancellation isn't a concern for WebSocket clients because they +don't involve inversion of control. + +Library +....... + Most :doc:`public APIs ` of ``websockets`` are coroutines. They may be canceled, for example if the user starts a task that calls these coroutines and cancels the task later. ``websockets`` must handle this situation. @@ -410,10 +448,6 @@ Since :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` handles :exc:`~asyncio.CancelledError`, cancellation doesn't propagate to :attr:`~protocol.WebSocketCommonProtocol.close_connnection_task`. -Conversely, ``websockets`` never injects :exc:`~asyncio.CancelledError` into -user code. It doesn't cancel connection handler coroutines. Instead it expects -them to detect when the connection is closed and to exit. - .. _backpressure: From 62795d1a370858bdd3c43e092d273ed8953c2b59 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 31 Oct 2018 22:34:21 +0100 Subject: [PATCH 0494/1539] Update discussion of transfer_data_task cancellation. Fix #509. --- docs/design.rst | 21 +++++++++++---------- src/websockets/protocol.py | 7 +++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/design.rst b/docs/design.rst index 83c8ce07a..03f1ec163 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -411,7 +411,7 @@ the TCP connection is closed and the exception is re-raised. This can only happen on the client side. On the server side, the opening handshake is managed by ``websockets`` and nothing results in a cancellation. -Once the WebSocket connection is established, +Once the WebSocket connection is established, internal tasks :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get accidentally canceled if a coroutine that awaits them is canceled. In other @@ -434,20 +434,21 @@ prevent cancellation. :meth:`~protocol.WebSocketCommonProtocol.close()` waits for the data transfer task to terminate with :func:`~asyncio.wait_for`. If it's canceled or if the timeout elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` -is canceled. :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is -expected to catch the cancellation and terminate properly. This is the only -point where it may be canceled. - +is canceled, which is correct at this point. :meth:`~protocol.WebSocketCommonProtocol.close()` then waits for :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it to prevent cancellation. -:attr:`~protocol.WebSocketCommonProtocol.close_connnection_task` starts by -waiting for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`. -Since :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` handles -:exc:`~asyncio.CancelledError`, cancellation doesn't propagate to -:attr:`~protocol.WebSocketCommonProtocol.close_connnection_task`. +:meth:`~protocol.WebSocketCommonProtocol.close()` and +:func:`~protocol.WebSocketCommonProtocol.fail_connection()` are the only +places where :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` may +be canceled. +:attr:`~protocol.WebSocketCommonProtocol.close_connnection_task` starts by +waiting for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`. It +catches :exc:`~asyncio.CancelledError` to prevent a cancellation of +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` from propagating +to :attr:`~protocol.WebSocketCommonProtocol.close_connnection_task`. .. _backpressure: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 3e4ced21d..ae87c450b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -541,12 +541,11 @@ def close(self, code=1000, reason=''): self.fail_connection() # If no close frame is received within the timeout, wait_for() cancels - # the data transfer task and raises TimeoutError. Then transfer_data() - # catches CancelledError and exits without an exception. + # the data transfer task and raises TimeoutError. # If close() is called multiple times concurrently and one of these - # calls hits the timeout, other calls will resume executing without an - # exception, so there's no need to catch CancelledError here. + # calls hits the timeout, the data transfer task will be cancelled. + # Other calls will receive a CancelledError here. try: # If close() is canceled during the wait, self.transfer_data_task From 6e7666e4203e245b7a0097e7d578d871a71d024e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 15:54:34 +0100 Subject: [PATCH 0495/1539] Fix typos in the docs. --- docs/changelog.rst | 2 +- docs/spelling_wordlist.txt | 1 + src/websockets/server.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 18f5f0764..e79041e6e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -75,7 +75,7 @@ Changelog ``None`` rather than ``''``. * Fixed a data loss bug in :meth:`~protocol.WebSocketCommonProtocol.recv`: - cancelling it at the wrong time could result in messages being dropped. + canceling it at the wrong time could result in messages being dropped. * Improved handling of multiple HTTP headers with the same name. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 63be00f1f..ba30efd99 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -23,6 +23,7 @@ permessage pong Pythonic serializers +subclassing subprotocol subprotocols TLS diff --git a/src/websockets/server.py b/src/websockets/server.py index be9abaad6..556c270d4 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -269,7 +269,7 @@ def process_request(self, path, request_headers): likely to require network requests. This coroutine may be overridden by passing a ``process_request`` - argument to the :class:`WebSocketServerProtocol` contstructor or the + argument to the :class:`WebSocketServerProtocol` constructor or the :func:`serve` function. """ @@ -420,7 +420,7 @@ def select_subprotocol(client_subprotocols, server_subprotocols): that subprotocol. Such rules can be implemented in a subclass. This method may be overridden by passing a ``select_subprotocol`` - argument to the :class:`WebSocketServerProtocol` contstructor or the + argument to the :class:`WebSocketServerProtocol` constructor or the :func:`serve` function. """ From 1138974fc7dd003d6626faf0b5f3a13f3a44c13d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 15:57:01 +0100 Subject: [PATCH 0496/1539] Bump version number. --- docs/changelog.rst | 7 ++++++- docs/conf.py | 4 ++-- src/websockets/version.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index e79041e6e..eea0693e0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,14 @@ Changelog .. currentmodule:: websockets -7.0 +7.1 ... *In development* +7.0 +... + .. warning:: **Version 7.0 renames the** ``timeout`` **argument of** @@ -54,6 +57,8 @@ Changelog Concurrent calls lead to non-deterministic behavior because there are no guarantees about which coroutine will receive which message. +Also: + * websockets sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. diff --git a/docs/conf.py b/docs/conf.py index 8fdf12b4d..1a5448f7b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,9 +53,9 @@ # built documents. # # The short X.Y version. -version = '6.0' +version = '7.0' # The full version, including alpha/beta/rc tags. -release = '6.0' +release = '7.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/src/websockets/version.py b/src/websockets/version.py index 9d929a970..fe9ed183b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = '6.0' +version = '7.0' From 9d6fe04da62b35368de4ec047ab940e3a4e06b77 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 16:28:28 +0100 Subject: [PATCH 0497/1539] Make the client and server examples compatible. Fix #484. --- README.rst | 5 +++-- example/hello.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index adc3a9210..863829c33 100644 --- a/README.rst +++ b/README.rst @@ -36,7 +36,7 @@ Python with a focus on correctness and simplicity. Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. -Here's a client that says "Hello world!": +Here's how a client sends and receives messages (Python ≥ 3.6): .. copy-pasted because GitHub doesn't support the include directive @@ -50,11 +50,12 @@ Here's a client that says "Hello world!": async def hello(uri): async with websockets.connect(uri) as websocket: await websocket.send("Hello world!") + await websocket.recv() asyncio.get_event_loop().run_until_complete( hello('ws://localhost:8765')) -And here's an echo server (for Python ≥ 3.6): +And here's an echo server (Python ≥ 3.6): .. code:: python diff --git a/example/hello.py b/example/hello.py index bbb3d9a0e..f90c0de55 100755 --- a/example/hello.py +++ b/example/hello.py @@ -6,6 +6,7 @@ async def hello(uri): async with websockets.connect(uri) as websocket: await websocket.send("Hello world!") + await websocket.recv() asyncio.get_event_loop().run_until_complete( hello('ws://localhost:8765')) From 82b575bdcc98e9f9702c1f53d0b7414297383bca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 16:34:16 +0100 Subject: [PATCH 0498/1539] Make link to docs more prominent. Ref #484. --- README.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 863829c33..8c6fe1f03 100644 --- a/README.rst +++ b/README.rst @@ -72,9 +72,11 @@ And here's an echo server (Python ≥ 3.6): websockets.serve(echo, 'localhost', 8765)) asyncio.get_event_loop().run_forever() -Does that look good? `Start here`_. +Does that look good? -.. _Start here: https://websockets.readthedocs.io/en/stable/intro.html +`Start here!`_ + +.. _Start here!: https://websockets.readthedocs.io/en/stable/intro.html Why should I use ``websockets``? -------------------------------- From 59d4c2c7648c6a143923703f97a8af6c41e39e1f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 16:46:44 +0100 Subject: [PATCH 0499/1539] Point to the CoC in the README. --- README.rst | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 8c6fe1f03..b64a32abd 100644 --- a/README.rst +++ b/README.rst @@ -129,12 +129,17 @@ Why shouldn't I use ``websockets``? What else? ---------- -Bug reports, patches and suggestions welcome! Just open an issue_ or send a -`pull request`_. +Bug reports, patches and suggestions are welcome! + +Please open an issue_ or send a `pull request`_. .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ +Participants must uphold the `Contributor Covenant code of conduct`_. + +.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/master/CODE_OF_CONDUCT.md + ``websockets`` is released under the `BSD license`_. .. _BSD license: https://websockets.readthedocs.io/en/stable/license.html From 2f357dbeaa6513ced67d70142d947caef294a62f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 16:47:51 +0100 Subject: [PATCH 0500/1539] Link to the LICENSE on GitHub. --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index b64a32abd..b57317d19 100644 --- a/README.rst +++ b/README.rst @@ -142,4 +142,4 @@ Participants must uphold the `Contributor Covenant code of conduct`_. ``websockets`` is released under the `BSD license`_. -.. _BSD license: https://websockets.readthedocs.io/en/stable/license.html +.. _BSD license: https://github.com/aaugustin/websockets/blob/master/LICENSE From 391aa13091869cb1073e967be4295e83bd4649cc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 16:49:21 +0100 Subject: [PATCH 0501/1539] Better two months early than ten months late. --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 7101662c8..b2962adba 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013-2017 Aymeric Augustin and contributors. +Copyright (c) 2013-2019 Aymeric Augustin and contributors. All rights reserved. Redistribution and use in source and binary forms, with or without From 9668b5bb93a7ffb738125ae8f6c1e9002bc57c13 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 16:50:56 +0100 Subject: [PATCH 0502/1539] Fix typo. --- docs/contributing.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contributing.rst b/docs/contributing.rst index 21e2152c1..00a529243 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -14,7 +14,7 @@ report inappropriate behavior to aymeric DOT augustin AT fractalideas DOT com. *(If I'm the person with the inappropriate behavior, please accept my apologies. I know I can mess up. I can't expect you to tell me, but if you -chose to do so, I'll do my best to handle criticism constructively. +choose to do so, I'll do my best to handle criticism constructively. -- Aymeric)* Contributions From 5a92a1124f47b5d439ac38a607c5a47a2115d6d7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 17:12:23 +0100 Subject: [PATCH 0503/1539] Factor out CRLF stripping. --- src/websockets/http.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/websockets/http.py b/src/websockets/http.py index e56a4a2c5..507be0555 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -76,11 +76,10 @@ def read_request(stream): # version and because path isn't checked. Since WebSocket software tends # to implement HTTP/1.1 strictly, there's little need for lenient parsing. - # Given the implementation of read_line(), request_line ends with CRLF. request_line = yield from read_line(stream) # This may raise "ValueError: not enough values to unpack" - method, path, version = request_line[:-2].split(b' ', 2) + method, path, version = request_line.split(b' ', 2) if method != b'GET': raise ValueError("Unsupported HTTP method: %r" % method) @@ -118,11 +117,10 @@ def read_response(stream): # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. - # Given the implementation of read_line(), status_line ends with CRLF. status_line = yield from read_line(stream) # This may raise "ValueError: not enough values to unpack" - version, status_code, reason = status_line[:-2].split(b' ', 2) + version, status_code, reason = status_line.split(b' ', 2) if version != b'HTTP/1.1': raise ValueError("Unsupported HTTP version: %r" % version) @@ -157,11 +155,11 @@ def read_headers(stream): headers = Headers() for _ in range(MAX_HEADERS + 1): line = yield from read_line(stream) - if line == b'\r\n': + if line == b'': break # This may raise "ValueError: not enough values to unpack" - name, value = line[:-2].split(b':', 1) + name, value = line.split(b':', 1) if not _token_re.fullmatch(name): raise ValueError("Invalid HTTP header name: %r" % name) value = value.strip(b' \t') @@ -185,6 +183,8 @@ def read_line(stream): ``stream`` is an :class:`~asyncio.StreamReader`. + Return :class:`bytes` without CRLF. + """ # Security: this is bounded by the StreamReader's limit (default = 32kB). line = yield from stream.readline() @@ -194,7 +194,7 @@ def read_line(stream): # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b'\r\n'): raise ValueError("Line without CRLF") - return line + return line[:-2] class MultipleValuesError(LookupError): From 82baae15dba99dc4b6d7a476be501a643c1ae1bf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 18:04:07 +0100 Subject: [PATCH 0504/1539] Add debug logs of HTTP requests and responses. Fix #493. --- docs/changelog.rst | 7 ++++++- src/websockets/client.py | 11 ++++++++++- src/websockets/http.py | 9 +++++---- src/websockets/protocol.py | 6 +++--- src/websockets/server.py | 9 ++++++++- tests/test_client_server.py | 4 ++-- tests/test_http.py | 5 ++++- 7 files changed, 38 insertions(+), 13 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index eea0693e0..4b2521d05 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,16 @@ Changelog .. currentmodule:: websockets -7.1 +8.0 ... *In development* +.. warning:: + + **Version 8.0 adds the reason phrase to the return type of the low-level + API** :func:`~http.read_response` **.** + 7.0 ... diff --git a/src/websockets/client.py b/src/websockets/client.py index 9f92f18e8..2ee654ec0 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -5,6 +5,7 @@ import asyncio import collections.abc +import logging import sys from .exceptions import ( @@ -29,6 +30,8 @@ __all__ = ['connect', 'WebSocketClientProtocol'] +logger = logging.getLogger(__name__) + class WebSocketClientProtocol(WebSocketCommonProtocol): """ @@ -66,6 +69,9 @@ def write_http_request(self, path, headers): self.path = path self.request_headers = headers + logger.debug("%s > GET %s HTTP/1.1", self.side, path) + logger.debug("%s > %r", self.side, headers) + # Since the path and headers only contain ASCII characters, # we can keep this simple. request = 'GET {path} HTTP/1.1\r\n'.format(path=path) @@ -87,10 +93,13 @@ def read_http_response(self): """ try: - status_code, headers = yield from read_response(self.reader) + status_code, reason, headers = yield from read_response(self.reader) except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc + logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) + logger.debug("%s < %r", self.side, headers) + self.response_headers = headers return status_code, self.response_headers diff --git a/src/websockets/http.py b/src/websockets/http.py index 507be0555..5062c03d7 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -85,7 +85,6 @@ def read_request(stream): raise ValueError("Unsupported HTTP method: %r" % method) if version != b'HTTP/1.1': raise ValueError("Unsupported HTTP version: %r" % version) - path = path.decode('ascii', 'surrogateescape') headers = yield from read_headers(stream) @@ -100,8 +99,9 @@ def read_response(stream): ``stream`` is an :class:`~asyncio.StreamReader`. - Return ``(status_code, headers)`` where ``status_code`` is a :class:`int` - and ``headers`` is a :class:`Headers` instance. + Return ``(status_code, reason, headers)`` where ``status_code`` is an + :class:`int`, ``reason`` is a :class:`str`, and ``headers`` is a + :class:`Headers` instance. Non-ASCII characters are represented with surrogate escapes. @@ -130,10 +130,11 @@ def read_response(stream): raise ValueError("Unsupported HTTP status code: %d" % status_code) if not _value_re.fullmatch(reason): raise ValueError("Invalid HTTP reason phrase: %r" % reason) + reason = reason.decode() headers = yield from read_headers(stream) - return status_code, headers + return status_code, reason, headers @asyncio.coroutine diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index ae87c450b..ebbf95530 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -883,7 +883,7 @@ def read_frame(self, max_size): max_size=max_size, extensions=self.extensions, ) - logger.debug("%s < %s", self.side, frame) + logger.debug("%s < %r", self.side, frame) return frame @asyncio.coroutine @@ -895,7 +895,7 @@ def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): ) frame = Frame(fin, opcode, data) - logger.debug("%s > %s", self.side, frame) + logger.debug("%s > %r", self.side, frame) frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions) # Backport of https://github.com/python/asyncio/pull/280. @@ -1139,7 +1139,7 @@ def fail_connection(self, code=1006, reason=''): logger.debug("%s - state = CLOSING", self.side) frame = Frame(True, OP_CLOSE, frame_data) - logger.debug("%s > %s", self.side, frame) + logger.debug("%s > %r", self.side, frame) frame.write( self.writer.write, mask=self.is_client, extensions=self.extensions ) diff --git a/src/websockets/server.py b/src/websockets/server.py index 556c270d4..5465ccd7e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -214,10 +214,13 @@ def read_http_request(self): except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc + logger.debug("%s < GET %s HTTP/1.1", self.side, path) + logger.debug("%s < %r", self.side, headers) + self.path = path self.request_headers = headers - return path, self.request_headers + return path, headers @asyncio.coroutine def write_http_response(self, status, headers, body=None): @@ -229,6 +232,9 @@ def write_http_response(self, status, headers, body=None): """ self.response_headers = headers + logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) + logger.debug("%s > %r", self.side, headers) + # Since the status line and headers only contain ASCII characters, # we can keep this simple. response = 'HTTP/1.1 {status.value} {status.phrase}\r\n'.format(status=status) @@ -237,6 +243,7 @@ def write_http_response(self, status, headers, body=None): self.writer.write(response.encode()) if body is not None: + logger.debug("%s > Body (%d bytes)", self.side, len(body)) self.writer.write(body) @asyncio.coroutine diff --git a/tests/test_client_server.py b/tests/test_client_server.py index dee44a662..0d6ee144d 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -879,8 +879,8 @@ def wrong_build_response(headers, key): def test_server_does_not_switch_protocols(self, _read_response): @asyncio.coroutine def wrong_read_response(stream): - status_code, headers = yield from read_response(stream) - return 400, headers + status_code, reason, headers = yield from read_response(stream) + return 400, 'Bad Request', headers _read_response.side_effect = wrong_read_response diff --git a/tests/test_http.py b/tests/test_http.py index b18e24a26..c222b370f 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -43,8 +43,11 @@ def test_read_response(self): b'Sec-WebSocket-Protocol: chat\r\n' b'\r\n' ) - status_code, headers = self.loop.run_until_complete(read_response(self.stream)) + status_code, reason, headers = self.loop.run_until_complete( + read_response(self.stream) + ) self.assertEqual(status_code, 101) + self.assertEqual(reason, 'Switching Protocols') self.assertEqual(headers['Upgrade'], 'websocket') def test_request_method(self): From 10b16ab82cccb7651a22597e3b8be2c61705889b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 19:22:08 +0100 Subject: [PATCH 0505/1539] Shorten debug logs a bit. --- src/websockets/protocol.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index ebbf95530..eb34c9174 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -1106,11 +1106,10 @@ def fail_connection(self, code=1006, reason=''): """ logger.debug( - "%s ! failing WebSocket connection in the %s state: %d %s", + "%s ! failing %s WebSocket connection with code %d", self.side, self.state.name, code, - reason or '[no reason]', ) # Cancel transfer_data_task if the opening handshake succeeded. From 6e315128b575de9240a4e3603ba3c9c095e0edd0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 20:42:33 +0100 Subject: [PATCH 0506/1539] Make write_http_request/response synchronous. They make small writes early in the lifetime of the connection so they're extremely unlikely to require draining the write buffer. --- src/websockets/client.py | 3 +-- src/websockets/server.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 2ee654ec0..2de160e9c 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -60,7 +60,6 @@ def __init__( self.extra_headers = extra_headers super().__init__(**kwds) - @asyncio.coroutine def write_http_request(self, path, headers): """ Write request line and headers to the HTTP request. @@ -287,7 +286,7 @@ def handshake( request_headers.setdefault('User-Agent', USER_AGENT) - yield from self.write_http_request(wsuri.resource_name, request_headers) + self.write_http_request(wsuri.resource_name, request_headers) status_code, response_headers = yield from self.read_http_response() diff --git a/src/websockets/server.py b/src/websockets/server.py index 5465ccd7e..fd3ecf30e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -160,7 +160,7 @@ def handler(self): headers.setdefault('Content-Type', 'text/plain') headers.setdefault('Connection', 'close') - yield from self.write_http_response(status, headers, body) + self.write_http_response(status, headers, body) self.fail_connection() yield from self.wait_closed() return @@ -222,7 +222,6 @@ def read_http_request(self): return path, headers - @asyncio.coroutine def write_http_response(self, status, headers, body=None): """ Write status line and headers to the HTTP response. @@ -524,7 +523,7 @@ def handshake( response_headers.setdefault('Date', email.utils.formatdate(usegmt=True)) response_headers.setdefault('Server', USER_AGENT) - yield from self.write_http_response(SWITCHING_PROTOCOLS, response_headers) + self.write_http_response(SWITCHING_PROTOCOLS, response_headers) self.connection_open() From 4f1a14c341df27338460db97ea6376571dc3ada7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 21:33:26 +0100 Subject: [PATCH 0507/1539] Declare process_request as function by default. This keeps things simple. --- src/websockets/server.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index fd3ecf30e..b42068764 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -245,14 +245,13 @@ def write_http_response(self, status, headers, body=None): logger.debug("%s > Body (%d bytes)", self.side, len(body)) self.writer.write(body) - @asyncio.coroutine def process_request(self, path, request_headers): """ Intercept the HTTP request and return an HTTP response if needed. ``request_headers`` is a :class:`~websockets.http.Headers` instance. - If this coroutine returns ``None``, the WebSocket handshake continues. + If this method returns ``None``, the WebSocket handshake continues. If it returns a status code, headers and a response body, that HTTP response is sent and the connection is closed. @@ -271,12 +270,12 @@ def process_request(self, path, request_headers): different status, for example to authenticate the request and return ``HTTPStatus.UNAUTHORIZED`` or ``HTTPStatus.FORBIDDEN``. - It is declared as a coroutine because such authentication checks are - likely to require network requests. + It can be declared as a function or as a coroutine because such + authentication checks are likely to require network requests. - This coroutine may be overridden by passing a ``process_request`` - argument to the :class:`WebSocketServerProtocol` constructor or the - :func:`serve` function. + It may also be overridden by passing a ``process_request`` argument to + the :class:`WebSocketServerProtocol` constructor or the :func:`serve` + function. """ From 771a5f2d1f1c873ea09a7a1191529a93f3f21846 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 1 Nov 2018 21:38:33 +0100 Subject: [PATCH 0508/1539] Return 503 on server shutdown during handshake. That was a regression in 71c4db9c. Fix #499. Ref #483. --- src/websockets/exceptions.py | 8 -------- src/websockets/server.py | 33 +++++++++++++++++++++------------ tests/test_client_server.py | 5 ++++- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b1618fa73..b34a2c0dc 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -1,6 +1,5 @@ __all__ = [ 'AbortHandshake', - 'CancelHandshake', 'ConnectionClosed', 'DuplicateParameter', 'InvalidHandshake', @@ -44,13 +43,6 @@ def __init__(self, status, headers, body=b''): super().__init__(message) -class CancelHandshake(InvalidHandshake): - """ - Exception raised to cancel a handshake when the connection is closed. - - """ - - class InvalidMessage(InvalidHandshake): """ Exception raised when the HTTP message in a handshake request is malformed. diff --git a/src/websockets/server.py b/src/websockets/server.py index b42068764..1d88e73a1 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -14,13 +14,13 @@ BAD_REQUEST, FORBIDDEN, INTERNAL_SERVER_ERROR, + SERVICE_UNAVAILABLE, SWITCHING_PROTOCOLS, UPGRADE_REQUIRED, asyncio_ensure_future, ) from .exceptions import ( AbortHandshake, - CancelHandshake, InvalidHandshake, InvalidHeader, InvalidMessage, @@ -119,10 +119,6 @@ def handler(self): except ConnectionError: logger.debug("Connection error in opening handshake", exc_info=True) raise - except CancelHandshake: - self.fail_connection() - yield from self.wait_closed() - return except Exception as exc: if isinstance(exc, AbortHandshake): status, headers, body = exc.status, exc.headers, exc.body @@ -478,11 +474,9 @@ def handshake( else: early_response = self.process_request(path, request_headers) - # Give up immediately and don't attempt to write a HTTP response if - # the TCP connection was closed while process_request() was running. - # This happens if the server shuts down and calls fail_connection(). - if self.state != State.CONNECTING: - raise CancelHandshake() + # Change the response to a 503 error if the server is shutting down. + if not self.ws_server.is_serving(): + early_response = SERVICE_UNAVAILABLE, [], b"Server is shutting down.\n" if early_response is not None: raise AbortHandshake(*early_response) @@ -593,6 +587,16 @@ def unregister(self, protocol): """ self.websockets.remove(protocol) + def is_serving(self): + """ + Tell whether the server is accepting new connections or shutting down. + + """ + try: + return self.server.is_serving() # Python ≥ 3.7 + except AttributeError: # pragma: no cover + return self.server.sockets is not None # Python < 3.7 + def close(self): """ Close the server and terminate connections with close code 1001. @@ -626,7 +630,8 @@ def _close(self): # Close open connections. fail_connection() will cancel the transfer # data task, which is expected to cause the handler task to terminate. for websocket in self.websockets: - websocket.fail_connection(1001) + if websocket.state is State.OPEN: + websocket.fail_connection(1001) # asyncio.wait doesn't accept an empty first argument. if self.websockets: @@ -637,7 +642,11 @@ def _close(self): # and let the handler wait for the connection to close. yield from asyncio.wait( [websocket.handler_task for websocket in self.websockets] - + [websocket.close_connection_task for websocket in self.websockets], + + [ + websocket.close_connection_task + for websocket in self.websockets + if websocket.state is State.OPEN + ], loop=self.loop, ) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 0d6ee144d..73866ff63 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -935,8 +935,11 @@ def test_client_closes_connection_before_handshake(self, handshake): @with_server(create_protocol=SlowServerProtocol) def test_server_shuts_down_during_opening_handshake(self): self.loop.call_later(5 * MS, self.server.close) - with self.assertRaises(InvalidHandshake): + with self.assertRaises(InvalidStatusCode) as raised: self.start_client() + exception = raised.exception + self.assertEqual(str(exception), "Status code not 101: 503") + self.assertEqual(exception.status_code, 503) @with_server() def test_server_shuts_down_during_connection_handling(self): From 9329ef30f4af2c6720ca17aa3980fe59ac52efce Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Nov 2018 21:25:39 +0100 Subject: [PATCH 0509/1539] Fix formatting in changelog. --- docs/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 4b2521d05..393abf1f8 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -29,7 +29,7 @@ Changelog .. warning:: **Version 7.0 changes how a server terminates connections when it's - closed with :meth:`~websockets.server.WebSocketServer.close`.** + closed with** :meth:`~websockets.server.WebSocketServer.close` **.** Previously, connections handlers were canceled. Now, connections are closed with close code 1001 (going away). From the perspective of the From 0d3c7411f62af0d407426fa028011ad65a845e8d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Nov 2018 10:38:38 +0100 Subject: [PATCH 0510/1539] Handle bytearray like bytes. Ref #478. --- src/websockets/framing.py | 16 ++++++---- src/websockets/protocol.py | 4 +-- src/websockets/speedups.c | 61 +++++++++++++++++++++++++++++++++++--- src/websockets/utils.py | 7 ++++- tests/test_framing.py | 9 +++++- tests/test_protocol.py | 20 +++++++++++++ tests/test_utils.py | 19 +++++++----- 7 files changed, 114 insertions(+), 22 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 00a24d807..850e7e7e2 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -237,18 +237,22 @@ def check(frame): def encode_data(data): """ - Helper that converts :class:`str` or :class:`bytes` to :class:`bytes`. + Convert a string or byte-like object to bytes. - :class:`str` are encoded with UTF-8. + If ``data`` is a :class:`str`, return a :class:`bytes` object encoding + ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return a :class:`bytes` object. + + Raise :exc:`TypeError` for other inputs. """ - # Expect str or bytes, return bytes. if isinstance(data, str): return data.encode('utf-8') - elif isinstance(data, bytes): - return data + elif isinstance(data, collections.abc.ByteString): + return bytes(data) else: - raise TypeError("data must be bytes or str") + raise TypeError("data must be bytes-like or str") def parse_close(data): diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index eb34c9174..7af86133f 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -466,7 +466,7 @@ def send(self, data): if isinstance(data, str): yield from self.write_frame(True, OP_TEXT, data.encode('utf-8')) - elif isinstance(data, bytes): + elif isinstance(data, collections.abc.ByteString): yield from self.write_frame(True, OP_BINARY, data) # Fragmented message -- regular iterator. @@ -483,7 +483,7 @@ def send(self, data): if isinstance(data, str): yield from self.write_frame(False, OP_TEXT, data.encode('utf-8')) encode_data = True - elif isinstance(data, bytes): + elif isinstance(data, collections.abc.ByteString): yield from self.write_frame(False, OP_BINARY, data) encode_data = False else: diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index 4d7622231..bb9c7053f 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -10,16 +10,50 @@ static const Py_ssize_t MASK_LEN = 4; +/* Similar to PyBytes_AsStringAndSize, but accepts more types */ + +static int +_PyBytesLike_AsStringAndSize(PyObject *obj, char **buffer, Py_ssize_t *length) +{ + if (PyBytes_Check(obj)) + { + *buffer = PyBytes_AS_STRING(obj); + *length = PyBytes_GET_SIZE(obj); + } + else if (PyByteArray_Check(obj)) + { + *buffer = PyByteArray_AS_STRING(obj); + *length = PyByteArray_GET_SIZE(obj); + } + else + { + PyErr_Format( + PyExc_TypeError, + "expected a bytes-like object, %.200s found", + Py_TYPE(obj)->tp_name); + return -1; + } + + return 0; +} + +/* C implementation of websockets.utils.apply_mask */ + static PyObject * apply_mask(PyObject *self, PyObject *args, PyObject *kwds) { - // Inputs are treated as immutable, which causes an extra memory copy. + // In order to support bytes and bytearray, accept any Python object. static char *kwlist[] = {"data", "mask", NULL}; - const char *input; + PyObject *input_obj; + PyObject *mask_obj; + + // A pointer to the underlying char * will be extracted from these inputs. + + char *input; Py_ssize_t input_len; - const char *mask; + char *mask; Py_ssize_t mask_len; // Initialize a PyBytesObject then get a pointer to the underlying char * @@ -27,10 +61,25 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) PyObject *result; char *output; + + // Other variables. + Py_ssize_t i = 0; + // Parse inputs. + if (!PyArg_ParseTupleAndKeywords( - args, kwds, "y#y#", kwlist, &input, &input_len, &mask, &mask_len)) + args, kwds, "OO", kwlist, &input_obj, &mask_obj)) + { + return NULL; + } + + if (_PyBytesLike_AsStringAndSize(input_obj, &input, &input_len) == -1) + { + return NULL; + } + + if (_PyBytesLike_AsStringAndSize(mask_obj, &mask, &mask_len) == -1) { return NULL; } @@ -41,6 +90,8 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } + // Create output. + result = PyBytes_FromStringAndSize(NULL, input_len); if (result == NULL) { @@ -50,6 +101,8 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // Since we juste created result, we don't need error checks. output = PyBytes_AS_STRING(result); + // Perform the masking operation. + // Apparently GCC cannot figure out the following optimizations by itself. // We need a new scope for MSVC 2010 (non C99 friendly) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index b4083dff4..def997841 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -6,9 +6,14 @@ def apply_mask(data, mask): """ - Apply masking to websocket message. + Apply masking to the data of a WebSocket message. + + ``data`` and ``mask`` are bytes-like objects. + + Return :class:`bytes`. """ if len(mask) != 4: raise ValueError("mask must contain 4 bytes") + return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask))) diff --git a/tests/test_framing.py b/tests/test_framing.py index 9da64f14c..ae5acc1a6 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -158,7 +158,14 @@ def test_encode_data_str(self): def test_encode_data_bytes(self): self.assertEqual(encode_data(b'tea'), b'tea') - def test_encode_data_other(self): + def test_encode_data_bytearray(self): + self.assertEqual(encode_data(bytearray(b'tea')), b'tea') + + def test_encode_data_list(self): + with self.assertRaises(TypeError): + encode_data([]) + + def test_encode_data_none(self): with self.assertRaises(TypeError): encode_data(None) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index aee3289ea..c546e4e48 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -535,6 +535,10 @@ def test_send_binary(self): self.loop.run_until_complete(self.protocol.send(b'tea')) self.assertOneFrameSent(True, OP_BINARY, b'tea') + def test_send_binary_from_bytearray(self): + self.loop.run_until_complete(self.protocol.send(bytearray(b'tea'))) + self.assertOneFrameSent(True, OP_BINARY, b'tea') + def test_send_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.send(42)) @@ -554,6 +558,14 @@ def test_send_iterable_binary(self): (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') ) + def test_send_iterable_binary_from_bytearray(self): + self.loop.run_until_complete( + self.protocol.send([bytearray(b'te'), bytearray(b'a')]) + ) + self.assertFramesSent( + (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') + ) + def test_send_empty_iterable(self): self.loop.run_until_complete(self.protocol.send([])) self.assertNoFrameSent() @@ -616,6 +628,10 @@ def test_ping_binary(self): self.loop.run_until_complete(self.protocol.ping(b'tea')) self.assertOneFrameSent(True, OP_PING, b'tea') + def test_ping_binary_from_bytearray(self): + self.loop.run_until_complete(self.protocol.ping(bytearray(b'tea'))) + self.assertOneFrameSent(True, OP_PING, b'tea') + def test_ping_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.ping(42)) @@ -661,6 +677,10 @@ def test_pong_binary(self): self.loop.run_until_complete(self.protocol.pong(b'tea')) self.assertOneFrameSent(True, OP_PONG, b'tea') + def test_pong_binary_from_bytearray(self): + self.loop.run_until_complete(self.protocol.pong(bytearray(b'tea'))) + self.assertOneFrameSent(True, OP_PONG, b'tea') + def test_pong_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.pong(42)) diff --git a/tests/test_utils.py b/tests/test_utils.py index c7699232e..d2573e235 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,4 @@ +import itertools import unittest from websockets.utils import apply_mask as py_apply_mask @@ -9,14 +10,16 @@ def apply_mask(*args, **kwargs): return py_apply_mask(*args, **kwargs) def test_apply_mask(self): - for data_in, mask, data_out in [ - (b'', b'1234', b''), - (b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'), - (b'abcdABCD', b'1234', b'PPPPpppp'), - (b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10), - ]: - with self.subTest(data_in=data_in, mask=mask): - self.assertEqual(self.apply_mask(data_in, mask), data_out) + for data_type, mask_type in itertools.product([bytes, bytearray], repeat=2): + for data_in, mask, data_out in [ + (b'', b'1234', b''), + (b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'), + (b'abcdABCD', b'1234', b'PPPPpppp'), + (b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10), + ]: + data_in, mask = data_type(data_in), mask_type(mask) + with self.subTest(data_in=data_in, mask=mask): + self.assertEqual(self.apply_mask(data_in, mask), data_out) def test_apply_mask_check_input_types(self): for data_in, mask in [(None, None), (b'abcd', None), (None, b'abcd')]: From 5897ee913650efdcdcf2c2c98b9f74c9b605e83a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Nov 2018 11:35:34 +0100 Subject: [PATCH 0511/1539] Factor out logic for encoding data. This is slightly different for data frames and control frames. --- src/websockets/framing.py | 26 ++++++++++++++++++++++++++ src/websockets/protocol.py | 36 ++++++++++++++++-------------------- tests/test_framing.py | 19 +++++++++++++++++++ 3 files changed, 61 insertions(+), 20 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 850e7e7e2..3e3f9386d 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -34,6 +34,7 @@ 'OP_PING', 'OP_PONG', 'Frame', + 'prepare_data', 'encode_data', 'parse_close', 'serialize_close', @@ -235,10 +236,35 @@ def check(frame): raise WebSocketProtocolError("Invalid opcode: {}".format(frame.opcode)) +def prepare_data(data): + """ + Convert a string or byte-like object to an opcode and a bytes-like object. + + This function is designed for data frames. + + If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` + object encoding ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like + object. + + Raise :exc:`TypeError` for other inputs. + + """ + if isinstance(data, str): + return OP_TEXT, data.encode('utf-8') + elif isinstance(data, collections.abc.ByteString): + return OP_BINARY, data + else: + raise TypeError("data must be bytes-like or str") + + def encode_data(data): """ Convert a string or byte-like object to bytes. + This function is designed for ping and pong frames. + If ``data`` is a :class:`str`, return a :class:`bytes` object encoding ``data`` in UTF-8. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 7af86133f..13a370aca 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -461,17 +461,21 @@ def send(self, data): """ yield from self.ensure_open() - # Unfragmented message (first because str and bytes are iterable). + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. - if isinstance(data, str): - yield from self.write_frame(True, OP_TEXT, data.encode('utf-8')) - - elif isinstance(data, collections.abc.ByteString): - yield from self.write_frame(True, OP_BINARY, data) + try: + opcode, data = prepare_data(data) + except TypeError: + # Perhaps data is an iterator, see below. + pass + else: + yield from self.write_frame(True, opcode, data) + return # Fragmented message -- regular iterator. - elif isinstance(data, collections.abc.Iterable): + if isinstance(data, collections.abc.Iterable): iter_data = iter(data) # First fragment. @@ -479,29 +483,21 @@ def send(self, data): data = next(iter_data) except StopIteration: return - data_type = type(data) - if isinstance(data, str): - yield from self.write_frame(False, OP_TEXT, data.encode('utf-8')) - encode_data = True - elif isinstance(data, collections.abc.ByteString): - yield from self.write_frame(False, OP_BINARY, data) - encode_data = False - else: - raise TypeError("data must be an iterable of bytes or str") + opcode, data = prepare_data(data) + yield from self.write_frame(False, opcode, data) # Other fragments. for data in iter_data: - if type(data) != data_type: + confirm_opcode, data = prepare_data(data) + if confirm_opcode != opcode: # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. self.fail_connection(1011) raise TypeError("data contains inconsistent types") - if encode_data: - data = data.encode('utf-8') yield from self.write_frame(False, OP_CONT, data) # Final fragment. - yield from self.write_frame(True, OP_CONT, type(data)()) + yield from self.write_frame(True, OP_CONT, b'') # Fragmented message -- asynchronous iterator diff --git a/tests/test_framing.py b/tests/test_framing.py index ae5acc1a6..570fe3bdf 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -152,6 +152,25 @@ def test_control_frame_max_length(self): with self.assertRaises(WebSocketProtocolError): self.decode(b'\x88\x7e\x00\x7e' + 126 * b'a') + def test_prepare_data_str(self): + self.assertEqual(prepare_data('café'), (OP_TEXT, b'caf\xc3\xa9')) + + def test_prepare_data_bytes(self): + self.assertEqual(prepare_data(b'tea'), (OP_BINARY, b'tea')) + + def test_prepare_data_bytearray(self): + self.assertEqual( + prepare_data(bytearray(b'tea')), (OP_BINARY, bytearray(b'tea')) + ) + + def test_prepare_data_list(self): + with self.assertRaises(TypeError): + prepare_data([]) + + def test_prepare_data_none(self): + with self.assertRaises(TypeError): + prepare_data(None) + def test_encode_data_str(self): self.assertEqual(encode_data('café'), b'caf\xc3\xa9') From 6a8c8332838ef9814b773a4752aa80f7bca42d96 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Nov 2018 17:45:47 +0100 Subject: [PATCH 0512/1539] Support memoryview objects like bytes. Minimize memory copies when they're C-contiguous. Fix #478. --- src/websockets/framing.py | 7 +++++ src/websockets/speedups.c | 27 +++++++++++++++++-- tests/test_framing.py | 14 ++++++++++ tests/test_protocol.py | 40 +++++++++++++++++++++++++++ tests/test_utils.py | 57 +++++++++++++++++++++++++++++++++------ 5 files changed, 135 insertions(+), 10 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 3e3f9386d..feebd3983 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -255,6 +255,11 @@ def prepare_data(data): return OP_TEXT, data.encode('utf-8') elif isinstance(data, collections.abc.ByteString): return OP_BINARY, data + elif isinstance(data, memoryview): + if data.c_contiguous: + return OP_BINARY, data + else: + return OP_BINARY, data.tobytes() else: raise TypeError("data must be bytes-like or str") @@ -277,6 +282,8 @@ def encode_data(data): return data.encode('utf-8') elif isinstance(data, collections.abc.ByteString): return bytes(data) + elif isinstance(data, memoryview): + return data.tobytes() else: raise TypeError("data must be bytes-like or str") diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index bb9c7053f..d1c2b37e6 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -15,6 +15,11 @@ static const Py_ssize_t MASK_LEN = 4; static int _PyBytesLike_AsStringAndSize(PyObject *obj, char **buffer, Py_ssize_t *length) { + // This supports bytes, bytearrays, and C-contiguous memoryview objects, + // which are the most useful data structures for handling byte streams. + // websockets.framing.prepare_data() returns only values of these types. + // Any object implementing the buffer protocol could be supported, however + // that would require allocation or copying memory, which is expensive. if (PyBytes_Check(obj)) { *buffer = PyBytes_AS_STRING(obj); @@ -25,6 +30,23 @@ _PyBytesLike_AsStringAndSize(PyObject *obj, char **buffer, Py_ssize_t *length) *buffer = PyByteArray_AS_STRING(obj); *length = PyByteArray_GET_SIZE(obj); } + else if (PyMemoryView_Check(obj)) + { + Py_buffer *mv_buf; + mv_buf = PyMemoryView_GET_BUFFER(obj); + if (PyBuffer_IsContiguous(mv_buf, 'C')) + { + *buffer = mv_buf->buf; + *length = mv_buf->len; + } + else + { + PyErr_Format( + PyExc_TypeError, + "expected a contiguous memoryview"); + return -1; + } + } else { PyErr_Format( @@ -43,13 +65,14 @@ static PyObject * apply_mask(PyObject *self, PyObject *args, PyObject *kwds) { - // In order to support bytes and bytearray, accept any Python object. + // In order to support various bytes-like types, accept any Python object. static char *kwlist[] = {"data", "mask", NULL}; PyObject *input_obj; PyObject *mask_obj; - // A pointer to the underlying char * will be extracted from these inputs. + // A pointer to a char * + length will be extracted from the data and mask + // arguments, possibly via a Py_buffer. char *input; Py_ssize_t input_len; diff --git a/tests/test_framing.py b/tests/test_framing.py index 570fe3bdf..ab11f6bdc 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -163,6 +163,14 @@ def test_prepare_data_bytearray(self): prepare_data(bytearray(b'tea')), (OP_BINARY, bytearray(b'tea')) ) + def test_prepare_data_memoryview(self): + self.assertEqual( + prepare_data(memoryview(b'tea')), (OP_BINARY, memoryview(b'tea')) + ) + + def test_prepare_data_non_contiguous_memoryview(self): + self.assertEqual(prepare_data(memoryview(b'tteeaa')[::2]), (OP_BINARY, b'tea')) + def test_prepare_data_list(self): with self.assertRaises(TypeError): prepare_data([]) @@ -180,6 +188,12 @@ def test_encode_data_bytes(self): def test_encode_data_bytearray(self): self.assertEqual(encode_data(bytearray(b'tea')), b'tea') + def test_encode_data_memoryview(self): + self.assertEqual(encode_data(memoryview(b'tea')), b'tea') + + def test_encode_data_non_contiguous_memoryview(self): + self.assertEqual(encode_data(memoryview(b'tteeaa')[::2]), b'tea') + def test_encode_data_list(self): with self.assertRaises(TypeError): encode_data([]) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index c546e4e48..a5eb251c9 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -539,6 +539,14 @@ def test_send_binary_from_bytearray(self): self.loop.run_until_complete(self.protocol.send(bytearray(b'tea'))) self.assertOneFrameSent(True, OP_BINARY, b'tea') + def test_send_binary_from_memoryview(self): + self.loop.run_until_complete(self.protocol.send(memoryview(b'tea'))) + self.assertOneFrameSent(True, OP_BINARY, b'tea') + + def test_send_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete(self.protocol.send(memoryview(b'tteeaa')[::2])) + self.assertOneFrameSent(True, OP_BINARY, b'tea') + def test_send_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.send(42)) @@ -566,6 +574,22 @@ def test_send_iterable_binary_from_bytearray(self): (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') ) + def test_send_iterable_binary_from_memoryview(self): + self.loop.run_until_complete( + self.protocol.send([memoryview(b'te'), memoryview(b'a')]) + ) + self.assertFramesSent( + (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') + ) + + def test_send_iterable_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete( + self.protocol.send([memoryview(b'ttee')[::2], memoryview(b'aa')[::2]]) + ) + self.assertFramesSent( + (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') + ) + def test_send_empty_iterable(self): self.loop.run_until_complete(self.protocol.send([])) self.assertNoFrameSent() @@ -632,6 +656,14 @@ def test_ping_binary_from_bytearray(self): self.loop.run_until_complete(self.protocol.ping(bytearray(b'tea'))) self.assertOneFrameSent(True, OP_PING, b'tea') + def test_ping_binary_from_memoryview(self): + self.loop.run_until_complete(self.protocol.ping(memoryview(b'tea'))) + self.assertOneFrameSent(True, OP_PING, b'tea') + + def test_ping_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete(self.protocol.ping(memoryview(b'tteeaa')[::2])) + self.assertOneFrameSent(True, OP_PING, b'tea') + def test_ping_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.ping(42)) @@ -681,6 +713,14 @@ def test_pong_binary_from_bytearray(self): self.loop.run_until_complete(self.protocol.pong(bytearray(b'tea'))) self.assertOneFrameSent(True, OP_PONG, b'tea') + def test_pong_binary_from_memoryview(self): + self.loop.run_until_complete(self.protocol.pong(memoryview(b'tea'))) + self.assertOneFrameSent(True, OP_PONG, b'tea') + + def test_pong_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete(self.protocol.pong(memoryview(b'tteeaa')[::2])) + self.assertOneFrameSent(True, OP_PONG, b'tea') + def test_pong_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.pong(42)) diff --git a/tests/test_utils.py b/tests/test_utils.py index d2573e235..1b913fe7f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,17 +9,45 @@ class UtilsTests(unittest.TestCase): def apply_mask(*args, **kwargs): return py_apply_mask(*args, **kwargs) + apply_mask_type_combos = list(itertools.product([bytes, bytearray], repeat=2)) + + apply_mask_test_values = [ + (b'', b'1234', b''), + (b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'), + (b'abcdABCD', b'1234', b'PPPPpppp'), + (b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10), + ] + def test_apply_mask(self): - for data_type, mask_type in itertools.product([bytes, bytearray], repeat=2): - for data_in, mask, data_out in [ - (b'', b'1234', b''), - (b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'), - (b'abcdABCD', b'1234', b'PPPPpppp'), - (b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10), - ]: + for data_type, mask_type in self.apply_mask_type_combos: + for data_in, mask, data_out in self.apply_mask_test_values: data_in, mask = data_type(data_in), mask_type(mask) + + with self.subTest(data_in=data_in, mask=mask): + result = self.apply_mask(data_in, mask) + self.assertEqual(result, data_out) + + def test_apply_mask_memoryview(self): + for data_type, mask_type in self.apply_mask_type_combos: + for data_in, mask, data_out in self.apply_mask_test_values: + data_in, mask = data_type(data_in), mask_type(mask) + data_in, mask = memoryview(data_in), memoryview(mask) + with self.subTest(data_in=data_in, mask=mask): - self.assertEqual(self.apply_mask(data_in, mask), data_out) + result = self.apply_mask(data_in, mask) + self.assertEqual(result, data_out) + + def test_apply_mask_non_contiguous_memoryview(self): + for data_type, mask_type in self.apply_mask_type_combos: + for data_in, mask, data_out in self.apply_mask_test_values: + data_in, mask = data_type(data_in), mask_type(mask) + data_in, mask = memoryview(data_in), memoryview(mask) + data_in, mask = data_in[::-1], mask[::-1] + data_out = data_out[::-1] + + with self.subTest(data_in=data_in, mask=mask): + result = self.apply_mask(data_in, mask) + self.assertEqual(result, data_out) def test_apply_mask_check_input_types(self): for data_in, mask in [(None, None), (b'abcd', None), (None, b'abcd')]: @@ -49,3 +77,16 @@ class SpeedupsTests(UtilsTests): @staticmethod def apply_mask(*args, **kwargs): return c_apply_mask(*args, **kwargs) + + def test_apply_mask_non_contiguous_memoryview(self): + for data_type, mask_type in self.apply_mask_type_combos: + for data_in, mask, data_out in self.apply_mask_test_values: + data_in, mask = data_type(data_in), mask_type(mask) + data_in, mask = memoryview(data_in), memoryview(mask) + data_in, mask = data_in[::-1], mask[::-1] + data_out = data_out[::-1] + + with self.subTest(data_in=data_in, mask=mask): + # The C extension only supports contiguous memoryviews. + with self.assertRaises(TypeError): + self.apply_mask(data_in, mask) From dff6cfc1a285c0a36ad440290fe81b46269bbba8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Nov 2018 18:12:57 +0100 Subject: [PATCH 0513/1539] Add documentation. Ref #478. --- docs/changelog.rst | 7 +++++++ src/websockets/protocol.py | 15 ++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 393abf1f8..3ec35445d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,6 +13,13 @@ Changelog **Version 8.0 adds the reason phrase to the return type of the low-level API** :func:`~http.read_response` **.** +Also: + +* :meth:`~protocol.WebSocketCommonProtocol.send`, + :meth:`~protocol.WebSocketCommonProtocol.ping`, and + :meth:`~protocol.WebSocketCommonProtocol.pong` support bytes-like types + :class:`bytearray` and :class:`memoryview` in addition to :class:`bytes`. + 7.0 ... diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 13a370aca..e154a62cf 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -448,10 +448,11 @@ def send(self, data): """ This coroutine sends a message. - It sends :class:`str` as a text frame and :class:`bytes` as a binary - frame. + It sends a string (:class:`str`) as a text frame and a bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) + as a binary frame. - It also accepts an iterable of :class:`str` or :class:`bytes`. Each + It also accepts an iterable of strings or bytes-like objects. Each item is treated as a message fragment and sent in its own frame. All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. @@ -572,8 +573,8 @@ def ping(self, data=None): await pong_waiter # only if you want to wait for the pong By default, the ping contains four random bytes. The content may be - overridden with the optional ``data`` argument which must be of type - :class:`str` (which will be encoded to UTF-8) or :class:`bytes`. + overridden with the optional ``data`` argument which must be a string + (which will be encoded to UTF-8) or a bytes-like object. """ yield from self.ensure_open() @@ -603,8 +604,8 @@ def pong(self, data=b''): An unsolicited pong may serve as a unidirectional heartbeat. The content may be overridden with the optional ``data`` argument - which must be of type :class:`str` (which will be encoded to UTF-8) or - :class:`bytes`. + which must be a string (which will be encoded to UTF-8) or a + bytes-like object. """ yield from self.ensure_open() From a4dbe6ccb22fd9a591e7557de8ed9b6aa7202741 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 6 Nov 2018 22:50:49 +0100 Subject: [PATCH 0514/1539] Fix wait_closed signature in docs. Fix #512. --- docs/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api.rst b/docs/api.rst index 3971ff8b4..80d64e254 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -68,7 +68,7 @@ Shared .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) .. automethod:: close(code=1000, reason='') - .. automethod:: wait_closed(code=1000, reason='') + .. automethod:: wait_closed() .. automethod:: recv() .. automethod:: send(data) From b6fc5c06d91fbbd76c5db5293adb9b3269116557 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Nov 2018 10:22:18 +0100 Subject: [PATCH 0515/1539] Fix side effect of automatic code formatting.` --- src/websockets/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index e154a62cf..52e39a2af 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -888,7 +888,7 @@ def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): # Defensive assertion for protocol compliance. if self.state is not _expected_state: # pragma: no cover raise InvalidState( - "Cannot write to a WebSocket " "in the {} state".format(self.state.name) + "Cannot write to a WebSocket in the {} state".format(self.state.name) ) frame = Frame(fin, opcode, data) From 00458f2749bbaeb36280c3129af74f00dab26b3d Mon Sep 17 00:00:00 2001 From: Cory Johns Date: Wed, 12 Dec 2018 17:47:32 -0500 Subject: [PATCH 0516/1539] Handle redirects in client when connecting Per https://tools.ietf.org/html/rfc6455.html#section-4.2.2 the server may redirect the client during the handshake. This allows the client to handle redirects properly instead of raising an InvalidStatusCode error. --- src/websockets/client.py | 119 ++++++++++++++++++++++------------ src/websockets/exceptions.py | 10 +++ src/websockets/py35/client.py | 41 ++++++++---- tests/test_client_server.py | 87 ++++++++++++++++++++++++- 4 files changed, 202 insertions(+), 55 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 2de160e9c..7b0421a44 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -13,6 +13,7 @@ InvalidMessage, InvalidStatusCode, NegotiationError, + RedirectHandshake, ) from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response @@ -289,8 +290,11 @@ def handshake( self.write_http_request(wsuri.resource_name, request_headers) status_code, response_headers = yield from self.read_http_response() - - if status_code != 101: + if status_code in (301, 302, 303, 307, 308): + if 'Location' not in response_headers: + raise InvalidMessage('Redirect response missing Location') + raise RedirectHandshake(parse_uri(response_headers['Location'])) + elif status_code != 101: raise InvalidStatusCode(status_code) check_response(response_headers, key) @@ -358,6 +362,8 @@ class Connect: """ + MAX_REDIRECTS_ALLOWED = 10 + def __init__( self, uri, @@ -394,8 +400,8 @@ def __init__( if create_protocol is None: create_protocol = klass - wsuri = parse_uri(uri) - if wsuri.secure: + self._wsuri = parse_uri(uri) + if self._wsuri.secure: kwds.setdefault('ssl', True) elif kwds.get('ssl') is not None: raise ValueError( @@ -416,53 +422,86 @@ def __init__( elif compression is not None: raise ValueError("Unsupported compression: {}".format(compression)) - factory = lambda: create_protocol( - host=wsuri.host, - port=wsuri.port, - secure=wsuri.secure, - ping_interval=ping_interval, - ping_timeout=ping_timeout, - close_timeout=close_timeout, - max_size=max_size, - max_queue=max_queue, - read_limit=read_limit, - write_limit=write_limit, - loop=loop, - legacy_recv=legacy_recv, - origin=origin, - extensions=extensions, - subprotocols=subprotocols, - extra_headers=extra_headers, + self._create_protocol = create_protocol + self._ping_interval = ping_interval + self._ping_timeout = ping_timeout + self._close_timeout = close_timeout + self._max_size = max_size + self._max_queue = max_queue + self._read_limit = read_limit + self._write_limit = write_limit + self._loop = loop + self._legacy_recv = legacy_recv + self._klass = klass + self._timeout = timeout + self._compression = compression + self._origin = origin + self._extensions = extensions + self._subprotocols = subprotocols + self._extra_headers = extra_headers + self._kwds = kwds + + def _creating_connection(self): + if self._wsuri.secure: + self._kwds.setdefault('ssl', True) + + factory = lambda: self._create_protocol( + host=self._wsuri.host, + port=self._wsuri.port, + secure=self._wsuri.secure, + ping_interval=self._ping_interval, + ping_timeout=self._ping_timeout, + close_timeout=self._close_timeout, + max_size=self._max_size, + max_queue=self._max_queue, + read_limit=self._read_limit, + write_limit=self._write_limit, + loop=self._loop, + legacy_recv=self._legacy_recv, + origin=self._origin, + extensions=self._extensions, + subprotocols=self._subprotocols, + extra_headers=self._extra_headers, ) - if kwds.get('sock') is None: - host, port = wsuri.host, wsuri.port + if self._kwds.get('sock') is None: + host, port = self._wsuri.host, self._wsuri.port else: # If sock is given, host and port mustn't be specified. host, port = None, None - self._wsuri = wsuri - self._origin = origin + self._wsuri = self._wsuri + self._origin = self._origin # This is a coroutine object. - self._creating_connection = loop.create_connection(factory, host, port, **kwds) + return self._loop.create_connection(factory, host, port, **self._kwds) @asyncio.coroutine def __iter__(self): # pragma: no cover - transport, protocol = yield from self._creating_connection - - try: - yield from protocol.handshake( - self._wsuri, - origin=self._origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - except Exception: - protocol.fail_connection() - yield from protocol.wait_closed() - raise + for redirects in range(self.MAX_REDIRECTS_ALLOWED): + transport, protocol = yield from self._creating_connection() + + try: + try: + yield from protocol.handshake( + self._wsuri, + origin=self._origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + break # redirection chain ended + except Exception: + protocol.fail_connection() + yield from protocol.wait_closed() + raise + except RedirectHandshake as e: + if self._wsuri.secure and not e.wsuri.secure: + raise InvalidHandshake('Redirect dropped TLS') + self._wsuri = e.wsuri + continue # redirection chain continues + else: + raise InvalidHandshake('Maximum redirects exceeded') self.ws_client = protocol return protocol diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b34a2c0dc..39fa093ee 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -43,6 +43,16 @@ def __init__(self, status, headers, body=b''): super().__init__(message) +class RedirectHandshake(InvalidHandshake): + """ + Exception raised when a handshake gets redirected. + + """ + + def __init__(self, wsuri): + self.wsuri = wsuri + + class InvalidMessage(InvalidHandshake): """ Exception raised when the HTTP message in a handshake request is malformed. diff --git a/src/websockets/py35/client.py b/src/websockets/py35/client.py index a016ba437..bd902841a 100644 --- a/src/websockets/py35/client.py +++ b/src/websockets/py35/client.py @@ -1,3 +1,6 @@ +from ..exceptions import InvalidHandshake, RedirectHandshake + + async def __aenter__(self): return await self @@ -9,20 +12,30 @@ async def __aexit__(self, exc_type, exc_value, traceback): async def __await_impl__(self): # Duplicated with __iter__ because Python 3.7 requires an async function # (as explained in __await__ below) which Python 3.4 doesn't support. - transport, protocol = await self._creating_connection - - try: - await protocol.handshake( - self._wsuri, - origin=self._origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - except Exception: - protocol.fail_connection() - await protocol.wait_closed() - raise + for redirects in range(self.MAX_REDIRECTS_ALLOWED): + transport, protocol = await self._creating_connection() + + try: + try: + await protocol.handshake( + self._wsuri, + origin=self._origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + break # redirection chain ended + except Exception: + protocol.fail_connection() + await protocol.wait_closed() + raise + except RedirectHandshake as e: + if self._wsuri.secure and not e.wsuri.secure: + raise InvalidHandshake('Redirect dropped TLS') + self._wsuri = e.wsuri + continue # redirection chain continues + else: + raise InvalidHandshake('Maximum redirects exceeded') self.ws_client = protocol return protocol diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 73866ff63..394d090a7 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1,6 +1,7 @@ import asyncio import contextlib import functools +import http import logging import pathlib import random @@ -19,6 +20,7 @@ from websockets.exceptions import ( ConnectionClosed, InvalidHandshake, + InvalidMessage, InvalidStatusCode, NegotiationError, ) @@ -79,6 +81,16 @@ def temp_test_server(test, **kwds): test.stop_server() +@contextlib.contextmanager +def temp_test_redirecting_server(test, status, + include_location=True, force_insecure=False): + test.start_redirecting_server(status, include_location, force_insecure) + try: + yield + finally: + test.stop_redirecting_server() + + @contextlib.contextmanager def temp_test_client(test, *args, **kwds): test.start_client(*args, **kwds) @@ -227,6 +239,8 @@ class ClientServerTests(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) + self.server = None + self.redirecting_server = None def tearDown(self): self.loop.close() @@ -237,6 +251,10 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() + @property + def server_context(self): + return None + def start_server(self, **kwds): # Disable compression by default in tests. kwds.setdefault('compression', None) @@ -245,13 +263,30 @@ def start_server(self, **kwds): start_server = serve(handler, 'localhost', 0, **kwds) self.server = self.loop.run_until_complete(start_server) + def start_redirecting_server(self, status, + include_location=True, force_insecure=False): + def _process_request(path, headers): + server_uri = get_server_uri(self.server, self.secure, path) + if force_insecure: + server_uri = server_uri.replace('wss:', 'ws:') + headers = {'Location': server_uri} if include_location else [] + return status, headers, b"" + + start_server = serve(handler, 'localhost', 0, + compression=None, + ping_interval=None, + process_request=_process_request, + ssl=self.server_context) + self.redirecting_server = self.loop.run_until_complete(start_server) + def start_client(self, resource_name='/', user_info=None, **kwds): # Disable compression by default in tests. kwds.setdefault('compression', None) # Disable pings by default in tests. kwds.setdefault('ping_interval', None) secure = kwds.get('ssl') is not None - server_uri = get_server_uri(self.server, secure, resource_name, user_info) + server = self.redirecting_server if self.redirecting_server else self.server + server_uri = get_server_uri(server, secure, resource_name, user_info) start_client = connect(server_uri, **kwds) self.client = self.loop.run_until_complete(start_client) @@ -272,6 +307,17 @@ def stop_server(self): except asyncio.TimeoutError: # pragma: no cover self.fail("Server failed to stop") + def stop_redirecting_server(self): + self.redirecting_server.close() + try: + self.loop.run_until_complete( + asyncio.wait_for(self.redirecting_server.wait_closed(), timeout=1) + ) + except asyncio.TimeoutError: # pragma: no cover + self.fail("Redirecting server failed to stop") + finally: + self.redirecting_server = None + @contextlib.contextmanager def temp_server(self, **kwds): with temp_test_server(self, **kwds): @@ -289,6 +335,37 @@ def test_basic(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") + @with_server() + def test_redirect(self): + redirect_statuses = [ + http.HTTPStatus.MOVED_PERMANENTLY, + http.HTTPStatus.FOUND, + http.HTTPStatus.SEE_OTHER, + http.HTTPStatus.TEMPORARY_REDIRECT, + http.HTTPStatus.PERMANENT_REDIRECT, + ] + for status in redirect_statuses: + with temp_test_redirecting_server(self, status): + with temp_test_client(self): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + + def test_infinite_redirect(self): + with temp_test_redirecting_server(self, http.HTTPStatus.FOUND): + self.server = self.redirecting_server + with self.assertRaises(InvalidHandshake): + with temp_test_client(self): + self.fail('Did not raise') # pragma: no cover + + @with_server() + def test_redirect_missing_location(self): + with temp_test_redirecting_server(self, http.HTTPStatus.FOUND, + include_location=False): + with self.assertRaises(InvalidMessage): + with temp_test_client(self): + self.fail('Did not raise') # pragma: no cover + def test_explicit_event_loop(self): with self.temp_server(loop=self.loop): with self.temp_client(loop=self.loop): @@ -1070,6 +1147,14 @@ def test_ws_uri_is_rejected(self): # raised only when awaiting. self.loop.run_until_complete(client) # pragma: no cover + @with_server() + def test_redirect_insecure(self): + with temp_test_redirecting_server(self, http.HTTPStatus.FOUND, + force_insecure=True): + with self.assertRaises(InvalidHandshake): + with temp_test_client(self): + self.fail('Did not raise') # pragma: no cover + class ClientServerOriginTests(unittest.TestCase): def setUp(self): From ee92bc490bd762ef575cd2eee8883561d4066f15 Mon Sep 17 00:00:00 2001 From: Cory Johns Date: Wed, 12 Dec 2018 19:57:19 -0500 Subject: [PATCH 0517/1539] Run black to normalize formatting --- tests/test_client_server.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 394d090a7..86f3ff277 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -82,8 +82,9 @@ def temp_test_server(test, **kwds): @contextlib.contextmanager -def temp_test_redirecting_server(test, status, - include_location=True, force_insecure=False): +def temp_test_redirecting_server( + test, status, include_location=True, force_insecure=False +): test.start_redirecting_server(status, include_location, force_insecure) try: yield @@ -263,8 +264,9 @@ def start_server(self, **kwds): start_server = serve(handler, 'localhost', 0, **kwds) self.server = self.loop.run_until_complete(start_server) - def start_redirecting_server(self, status, - include_location=True, force_insecure=False): + def start_redirecting_server( + self, status, include_location=True, force_insecure=False + ): def _process_request(path, headers): server_uri = get_server_uri(self.server, self.secure, path) if force_insecure: @@ -272,11 +274,15 @@ def _process_request(path, headers): headers = {'Location': server_uri} if include_location else [] return status, headers, b"" - start_server = serve(handler, 'localhost', 0, - compression=None, - ping_interval=None, - process_request=_process_request, - ssl=self.server_context) + start_server = serve( + handler, + 'localhost', + 0, + compression=None, + ping_interval=None, + process_request=_process_request, + ssl=self.server_context, + ) self.redirecting_server = self.loop.run_until_complete(start_server) def start_client(self, resource_name='/', user_info=None, **kwds): @@ -360,8 +366,9 @@ def test_infinite_redirect(self): @with_server() def test_redirect_missing_location(self): - with temp_test_redirecting_server(self, http.HTTPStatus.FOUND, - include_location=False): + with temp_test_redirecting_server( + self, http.HTTPStatus.FOUND, include_location=False + ): with self.assertRaises(InvalidMessage): with temp_test_client(self): self.fail('Did not raise') # pragma: no cover @@ -1149,8 +1156,9 @@ def test_ws_uri_is_rejected(self): @with_server() def test_redirect_insecure(self): - with temp_test_redirecting_server(self, http.HTTPStatus.FOUND, - force_insecure=True): + with temp_test_redirecting_server( + self, http.HTTPStatus.FOUND, force_insecure=True + ): with self.assertRaises(InvalidHandshake): with temp_test_client(self): self.fail('Did not raise') # pragma: no cover From e4b49877e8ffdc618360f5ed85fc1212056a6ded Mon Sep 17 00:00:00 2001 From: Cory Johns Date: Wed, 12 Dec 2018 20:13:01 -0500 Subject: [PATCH 0518/1539] Fix references to HTTPStatus --- src/websockets/compatibility.py | 25 +++++++++++++++++++++++++ tests/test_client_server.py | 32 ++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/src/websockets/compatibility.py b/src/websockets/compatibility.py index b6506b70c..369c63e32 100644 --- a/src/websockets/compatibility.py +++ b/src/websockets/compatibility.py @@ -24,6 +24,11 @@ UPGRADE_REQUIRED = http.HTTPStatus.UPGRADE_REQUIRED INTERNAL_SERVER_ERROR = http.HTTPStatus.INTERNAL_SERVER_ERROR SERVICE_UNAVAILABLE = http.HTTPStatus.SERVICE_UNAVAILABLE + MOVED_PERMANENTLY = http.HTTPStatus.MOVED_PERMANENTLY + FOUND = http.HTTPStatus.FOUND + SEE_OTHER = http.HTTPStatus.SEE_OTHER + TEMPORARY_REDIRECT = http.HTTPStatus.TEMPORARY_REDIRECT + PERMANENT_REDIRECT = http.HTTPStatus.PERMANENT_REDIRECT except AttributeError: # pragma: no cover # Python < 3.5 class SWITCHING_PROTOCOLS: @@ -57,3 +62,23 @@ class INTERNAL_SERVER_ERROR: class SERVICE_UNAVAILABLE: value = 503 phrase = "Service Unavailable" + + class MOVED_PERMANENTLY: + value = 301 + phrase = "Moved Permanently" + + class FOUND: + value = 302 + phrase = "Found" + + class SEE_OTHER: + value = 303 + phrase = "See Other" + + class TEMPORARY_REDIRECT: + value = 307 + phrase = "Temporary Redirect" + + class PERMANENT_REDIRECT: + value = 308 + phrase = "Permanent Redirect" diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 86f3ff277..eade7e066 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1,7 +1,6 @@ import asyncio import contextlib import functools -import http import logging import pathlib import random @@ -16,7 +15,16 @@ import warnings from websockets.client import * -from websockets.compatibility import FORBIDDEN, OK, UNAUTHORIZED +from websockets.compatibility import ( + FORBIDDEN, + FOUND, + MOVED_PERMANENTLY, + OK, + PERMANENT_REDIRECT, + SEE_OTHER, + TEMPORARY_REDIRECT, + UNAUTHORIZED, +) from websockets.exceptions import ( ConnectionClosed, InvalidHandshake, @@ -344,11 +352,11 @@ def test_basic(self): @with_server() def test_redirect(self): redirect_statuses = [ - http.HTTPStatus.MOVED_PERMANENTLY, - http.HTTPStatus.FOUND, - http.HTTPStatus.SEE_OTHER, - http.HTTPStatus.TEMPORARY_REDIRECT, - http.HTTPStatus.PERMANENT_REDIRECT, + MOVED_PERMANENTLY, + FOUND, + SEE_OTHER, + TEMPORARY_REDIRECT, + PERMANENT_REDIRECT, ] for status in redirect_statuses: with temp_test_redirecting_server(self, status): @@ -358,7 +366,7 @@ def test_redirect(self): self.assertEqual(reply, "Hello!") def test_infinite_redirect(self): - with temp_test_redirecting_server(self, http.HTTPStatus.FOUND): + with temp_test_redirecting_server(self, FOUND): self.server = self.redirecting_server with self.assertRaises(InvalidHandshake): with temp_test_client(self): @@ -366,9 +374,7 @@ def test_infinite_redirect(self): @with_server() def test_redirect_missing_location(self): - with temp_test_redirecting_server( - self, http.HTTPStatus.FOUND, include_location=False - ): + with temp_test_redirecting_server(self, FOUND, include_location=False): with self.assertRaises(InvalidMessage): with temp_test_client(self): self.fail('Did not raise') # pragma: no cover @@ -1156,9 +1162,7 @@ def test_ws_uri_is_rejected(self): @with_server() def test_redirect_insecure(self): - with temp_test_redirecting_server( - self, http.HTTPStatus.FOUND, force_insecure=True - ): + with temp_test_redirecting_server(self, FOUND, force_insecure=True): with self.assertRaises(InvalidHandshake): with temp_test_client(self): self.fail('Did not raise') # pragma: no cover From 170088dcf478b661097556e97ad60d0dba410408 Mon Sep 17 00:00:00 2001 From: Cory Johns Date: Fri, 21 Dec 2018 15:23:05 -0500 Subject: [PATCH 0519/1539] Add line to changelog for redirect handling --- docs/changelog.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 3ec35445d..a76e1212e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -20,6 +20,9 @@ Also: :meth:`~protocol.WebSocketCommonProtocol.pong` support bytes-like types :class:`bytearray` and :class:`memoryview` in addition to :class:`bytes`. +* :func:`~client.connect()` handles redirects from the server during the + handshake. + 7.0 ... From 5931865413bfe7afa2be6e6e947870668b729e15 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 14:33:07 +0100 Subject: [PATCH 0520/1539] Normalize string style with black. --- Makefile | 2 +- src/websockets/__main__.py | 30 +- src/websockets/client.py | 48 +-- src/websockets/compatibility.py | 2 +- src/websockets/exceptions.py | 42 +- .../extensions/permessage_deflate.py | 48 +-- src/websockets/framing.py | 54 +-- src/websockets/handshake.py | 70 ++-- src/websockets/headers.py | 54 +-- src/websockets/http.py | 48 +-- src/websockets/protocol.py | 50 +-- src/websockets/py35/client.py | 4 +- src/websockets/server.py | 44 +-- src/websockets/uri.py | 16 +- src/websockets/utils.py | 2 +- src/websockets/version.py | 2 +- tests/extensions/test_permessage_deflate.py | 308 +++++++-------- tests/py35/_test_client_server.py | 18 +- tests/py36/_test_client_server.py | 8 +- tests/test_client_server.py | 328 ++++++++-------- tests/test_framing.py | 88 ++--- tests/test_handshake.py | 68 ++-- tests/test_headers.py | 72 ++-- tests/test_http.py | 118 +++--- tests/test_protocol.py | 358 +++++++++--------- tests/test_uri.py | 16 +- tests/test_utils.py | 18 +- tox.ini | 2 +- 28 files changed, 959 insertions(+), 959 deletions(-) diff --git a/Makefile b/Makefile index 2d77dcfc7..0863f8578 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ export PYTHONPATH=src style: isort --recursive src tests - black --skip-string-normalization src tests + black src tests flake8 src tests test: diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index af9286637..4c880c24c 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -56,19 +56,19 @@ def print_during_input(string): sys.stdout.write( ( # Save cursor position - '\N{ESC}7' + "\N{ESC}7" # Add a new line - '\N{LINE FEED}' + "\N{LINE FEED}" # Move cursor up - '\N{ESC}[A' + "\N{ESC}[A" # Insert blank line, scroll last line down - '\N{ESC}[L' + "\N{ESC}[L" # Print string in the inserted blank line - '{string}\N{LINE FEED}' + "{string}\N{LINE FEED}" # Restore cursor position - '\N{ESC}8' + "\N{ESC}8" # Move cursor down - '\N{ESC}[B' + "\N{ESC}[B" ).format(string=string) ) sys.stdout.flush() @@ -78,11 +78,11 @@ def print_over_input(string): sys.stdout.write( ( # Move cursor to beginning of line - '\N{CARRIAGE RETURN}' + "\N{CARRIAGE RETURN}" # Delete current line - '\N{ESC}[K' + "\N{ESC}[K" # Print string - '{string}\N{LINE FEED}' + "{string}\N{LINE FEED}" ).format(string=string) ) sys.stdout.flush() @@ -119,7 +119,7 @@ def run_client(uri, loop, inputs, stop): except websockets.ConnectionClosed: break else: - print_during_input('< ' + message) + print_during_input("< " + message) if outgoing in done: message = outgoing.result() @@ -141,7 +141,7 @@ def run_client(uri, loop, inputs, stop): def main(): # If we're on Windows, enable VT100 terminal support. - if os.name == 'nt': + if os.name == "nt": try: win_enable_vt100() except RuntimeError as exc: @@ -160,7 +160,7 @@ def main(): description="Interactive WebSocket client.", add_help=False, ) - parser.add_argument('uri', metavar='') + parser.add_argument("uri", metavar="") args = parser.parse_args() # Create an event loop that will run in a background thread. @@ -183,7 +183,7 @@ def main(): try: while True: # Since there's no size limit, put_nowait is identical to put. - message = input('> ') + message = input("> ") loop.call_soon_threadsafe(inputs.put_nowait, message) except (KeyboardInterrupt, EOFError): # ^C, ^D loop.call_soon_threadsafe(stop.set_result, None) @@ -192,5 +192,5 @@ def main(): thread.join() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/websockets/client.py b/src/websockets/client.py index 7b0421a44..66034ce25 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -29,7 +29,7 @@ from .uri import parse_uri -__all__ = ['connect', 'WebSocketClientProtocol'] +__all__ = ["connect", "WebSocketClientProtocol"] logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ is_client = True - side = 'client' + side = "client" def __init__( self, @@ -74,7 +74,7 @@ def write_http_request(self, path, headers): # Since the path and headers only contain ASCII characters, # we can keep this simple. - request = 'GET {path} HTTP/1.1\r\n'.format(path=path) + request = "GET {path} HTTP/1.1\r\n".format(path=path) request += str(headers) self.writer.write(request.encode()) @@ -134,7 +134,7 @@ def process_extensions(headers, available_extensions): """ accepted_extensions = [] - header_values = headers.get_all('Sec-WebSocket-Extensions') + header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values: @@ -191,7 +191,7 @@ def process_subprotocol(headers, available_subprotocols): """ subprotocol = None - header_values = headers.get_all('Sec-WebSocket-Protocol') + header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values: @@ -208,7 +208,7 @@ def process_subprotocol(headers, available_subprotocols): if len(parsed_header_values) > 1: raise InvalidHandshake( - "Multiple subprotocols: {}".format(', '.join(parsed_header_values)) + "Multiple subprotocols: {}".format(", ".join(parsed_header_values)) ) subprotocol = parsed_header_values[0] @@ -252,15 +252,15 @@ def handshake( request_headers = Headers() if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover - request_headers['Host'] = wsuri.host + request_headers["Host"] = wsuri.host else: - request_headers['Host'] = '{}:{}'.format(wsuri.host, wsuri.port) + request_headers["Host"] = "{}:{}".format(wsuri.host, wsuri.port) if wsuri.user_info: - request_headers['Authorization'] = build_basic_auth(*wsuri.user_info) + request_headers["Authorization"] = build_basic_auth(*wsuri.user_info) if origin is not None: - request_headers['Origin'] = origin + request_headers["Origin"] = origin key = build_request(request_headers) @@ -271,11 +271,11 @@ def handshake( for extension_factory in available_extensions ] ) - request_headers['Sec-WebSocket-Extensions'] = extensions_header + request_headers["Sec-WebSocket-Extensions"] = extensions_header if available_subprotocols is not None: protocol_header = build_subprotocol_list(available_subprotocols) - request_headers['Sec-WebSocket-Protocol'] = protocol_header + request_headers["Sec-WebSocket-Protocol"] = protocol_header if extra_headers is not None: if isinstance(extra_headers, Headers): @@ -285,15 +285,15 @@ def handshake( for name, value in extra_headers: request_headers[name] = value - request_headers.setdefault('User-Agent', USER_AGENT) + request_headers.setdefault("User-Agent", USER_AGENT) self.write_http_request(wsuri.resource_name, request_headers) status_code, response_headers = yield from self.read_http_response() if status_code in (301, 302, 303, 307, 308): - if 'Location' not in response_headers: - raise InvalidMessage('Redirect response missing Location') - raise RedirectHandshake(parse_uri(response_headers['Location'])) + if "Location" not in response_headers: + raise InvalidMessage("Redirect response missing Location") + raise RedirectHandshake(parse_uri(response_headers["Location"])) elif status_code != 101: raise InvalidStatusCode(status_code) @@ -380,7 +380,7 @@ def __init__( legacy_recv=False, klass=WebSocketClientProtocol, timeout=10, - compression='deflate', + compression="deflate", origin=None, extensions=None, subprotocols=None, @@ -402,14 +402,14 @@ def __init__( self._wsuri = parse_uri(uri) if self._wsuri.secure: - kwds.setdefault('ssl', True) - elif kwds.get('ssl') is not None: + kwds.setdefault("ssl", True) + elif kwds.get("ssl") is not None: raise ValueError( "connect() received a SSL context for a ws:// URI, " "use a wss:// URI to enable TLS" ) - if compression == 'deflate': + if compression == "deflate": if extensions is None: extensions = [] if not any( @@ -443,7 +443,7 @@ def __init__( def _creating_connection(self): if self._wsuri.secure: - self._kwds.setdefault('ssl', True) + self._kwds.setdefault("ssl", True) factory = lambda: self._create_protocol( host=self._wsuri.host, @@ -464,7 +464,7 @@ def _creating_connection(self): extra_headers=self._extra_headers, ) - if self._kwds.get('sock') is None: + if self._kwds.get("sock") is None: host, port = self._wsuri.host, self._wsuri.port else: # If sock is given, host and port mustn't be specified. @@ -497,11 +497,11 @@ def __iter__(self): # pragma: no cover raise except RedirectHandshake as e: if self._wsuri.secure and not e.wsuri.secure: - raise InvalidHandshake('Redirect dropped TLS') + raise InvalidHandshake("Redirect dropped TLS") self._wsuri = e.wsuri continue # redirection chain continues else: - raise InvalidHandshake('Maximum redirects exceeded') + raise InvalidHandshake("Maximum redirects exceeded") self.ws_client = protocol return protocol diff --git a/src/websockets/compatibility.py b/src/websockets/compatibility.py index 369c63e32..8b7a21a5c 100644 --- a/src/websockets/compatibility.py +++ b/src/websockets/compatibility.py @@ -12,7 +12,7 @@ try: # pragma: no cover asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5 except AttributeError: # pragma: no cover - asyncio_ensure_future = getattr(asyncio, 'async') # Python < 3.5 + asyncio_ensure_future = getattr(asyncio, "async") # Python < 3.5 try: # pragma: no cover # Python ≥ 3.5 diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 39fa093ee..611e68188 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -1,22 +1,22 @@ __all__ = [ - 'AbortHandshake', - 'ConnectionClosed', - 'DuplicateParameter', - 'InvalidHandshake', - 'InvalidHeader', - 'InvalidHeaderFormat', - 'InvalidHeaderValue', - 'InvalidMessage', - 'InvalidOrigin', - 'InvalidParameterName', - 'InvalidParameterValue', - 'InvalidState', - 'InvalidStatusCode', - 'InvalidUpgrade', - 'InvalidURI', - 'NegotiationError', - 'PayloadTooBig', - 'WebSocketProtocolError', + "AbortHandshake", + "ConnectionClosed", + "DuplicateParameter", + "InvalidHandshake", + "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", + "InvalidMessage", + "InvalidOrigin", + "InvalidParameterName", + "InvalidParameterValue", + "InvalidState", + "InvalidStatusCode", + "InvalidUpgrade", + "InvalidURI", + "NegotiationError", + "PayloadTooBig", + "WebSocketProtocolError", ] @@ -33,7 +33,7 @@ class AbortHandshake(InvalidHandshake): """ - def __init__(self, status, headers, body=b''): + def __init__(self, status, headers, body=b""): self.status = status self.headers = headers self.body = body @@ -69,7 +69,7 @@ class InvalidHeader(InvalidHandshake): def __init__(self, name, value=None): if value is None: message = "Missing {} header".format(name) - elif value == '': + elif value == "": message = "Empty {} header".format(name) else: message = "Invalid {} header: {}".format(name, value) @@ -108,7 +108,7 @@ class InvalidOrigin(InvalidHeader): """ def __init__(self, origin): - super().__init__('Origin', origin) + super().__init__("Origin", origin) class InvalidStatusCode(InvalidHandshake): diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 19f340734..dad6f1ec1 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -17,12 +17,12 @@ __all__ = [ - 'ClientPerMessageDeflateFactory', - 'ServerPerMessageDeflateFactory', - 'PerMessageDeflate', + "ClientPerMessageDeflateFactory", + "ServerPerMessageDeflateFactory", + "PerMessageDeflate", ] -_EMPTY_UNCOMPRESSED_BLOCK = b'\x00\x00\xff\xff' +_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff" _MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)] @@ -39,15 +39,15 @@ def _build_parameters( """ params = [] if server_no_context_takeover: - params.append(('server_no_context_takeover', None)) + params.append(("server_no_context_takeover", None)) if client_no_context_takeover: - params.append(('client_no_context_takeover', None)) + params.append(("client_no_context_takeover", None)) if server_max_window_bits: - params.append(('server_max_window_bits', str(server_max_window_bits))) + params.append(("server_max_window_bits", str(server_max_window_bits))) if client_max_window_bits is True: # only in handshake requests - params.append(('client_max_window_bits', None)) + params.append(("client_max_window_bits", None)) elif client_max_window_bits: - params.append(('client_max_window_bits', str(client_max_window_bits))) + params.append(("client_max_window_bits", str(client_max_window_bits))) return params @@ -66,7 +66,7 @@ def _extract_parameters(params, *, is_server): for name, value in params: - if name == 'server_no_context_takeover': + if name == "server_no_context_takeover": if server_no_context_takeover: raise DuplicateParameter(name) if value is None: @@ -74,7 +74,7 @@ def _extract_parameters(params, *, is_server): else: raise InvalidParameterValue(name, value) - elif name == 'client_no_context_takeover': + elif name == "client_no_context_takeover": if client_no_context_takeover: raise DuplicateParameter(name) if value is None: @@ -82,7 +82,7 @@ def _extract_parameters(params, *, is_server): else: raise InvalidParameterValue(name, value) - elif name == 'server_max_window_bits': + elif name == "server_max_window_bits": if server_max_window_bits is not None: raise DuplicateParameter(name) if value in _MAX_WINDOW_BITS_VALUES: @@ -90,7 +90,7 @@ def _extract_parameters(params, *, is_server): else: raise InvalidParameterValue(name, value) - elif name == 'client_max_window_bits': + elif name == "client_max_window_bits": if client_max_window_bits is not None: raise DuplicateParameter(name) if is_server and value is None: # only in handshake requests @@ -117,7 +117,7 @@ class ClientPerMessageDeflateFactory: """ - name = 'permessage-deflate' + name = "permessage-deflate" def __init__( self, @@ -141,7 +141,7 @@ def __init__( or 8 <= client_max_window_bits <= 15 ): raise ValueError("client_max_window_bits must be between 8 and 15") - if compress_settings is not None and 'wbits' in compress_settings: + if compress_settings is not None and "wbits" in compress_settings: raise ValueError( "compress_settings must not include wbits, " "set client_max_window_bits instead" @@ -273,7 +273,7 @@ class ServerPerMessageDeflateFactory: """ - name = 'permessage-deflate' + name = "permessage-deflate" def __init__( self, @@ -293,7 +293,7 @@ def __init__( raise ValueError("server_max_window_bits must be between 8 and 15") if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15): raise ValueError("client_max_window_bits must be between 8 and 15") - if compress_settings is not None and 'wbits' in compress_settings: + if compress_settings is not None and "wbits" in compress_settings: raise ValueError( "compress_settings must not include wbits, " "set server_max_window_bits instead" @@ -420,7 +420,7 @@ class PerMessageDeflate: """ - name = 'permessage-deflate' + name = "permessage-deflate" def __init__( self, @@ -441,7 +441,7 @@ def __init__( assert local_no_context_takeover in [False, True] assert 8 <= remote_max_window_bits <= 15 assert 8 <= local_max_window_bits <= 15 - assert 'wbits' not in compress_settings + assert "wbits" not in compress_settings self.remote_no_context_takeover = remote_no_context_takeover self.local_no_context_takeover = local_no_context_takeover @@ -465,11 +465,11 @@ def __init__( def __repr__(self): return ( - 'PerMessageDeflate(' - 'remote_no_context_takeover={}, ' - 'local_no_context_takeover={}, ' - 'remote_max_window_bits={}, ' - 'local_max_window_bits={})' + "PerMessageDeflate(" + "remote_no_context_takeover={}, " + "local_no_context_takeover={}, " + "remote_max_window_bits={}, " + "local_max_window_bits={})" ).format( self.remote_no_context_takeover, self.local_no_context_takeover, diff --git a/src/websockets/framing.py b/src/websockets/framing.py index feebd3983..8b0242715 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -25,19 +25,19 @@ __all__ = [ - 'DATA_OPCODES', - 'CTRL_OPCODES', - 'OP_CONT', - 'OP_TEXT', - 'OP_BINARY', - 'OP_CLOSE', - 'OP_PING', - 'OP_PONG', - 'Frame', - 'prepare_data', - 'encode_data', - 'parse_close', - 'serialize_close', + "DATA_OPCODES", + "CTRL_OPCODES", + "OP_CONT", + "OP_TEXT", + "OP_BINARY", + "OP_CLOSE", + "OP_PING", + "OP_PONG", + "Frame", + "prepare_data", + "encode_data", + "parse_close", + "serialize_close", ] DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 @@ -48,7 +48,7 @@ EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011] FrameData = collections.namedtuple( - 'FrameData', ['fin', 'opcode', 'data', 'rsv1', 'rsv2', 'rsv3'] + "FrameData", ["fin", "opcode", "data", "rsv1", "rsv2", "rsv3"] ) @@ -98,7 +98,7 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): """ # Read the header. data = yield from reader(2) - head1, head2 = struct.unpack('!BB', data) + head1, head2 = struct.unpack("!BB", data) # While not Pythonic, this is marginally faster than calling bool(). fin = True if head1 & 0b10000000 else False @@ -113,10 +113,10 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): length = head2 & 0b01111111 if length == 126: data = yield from reader(2) - length, = struct.unpack('!H', data) + length, = struct.unpack("!H", data) elif length == 127: data = yield from reader(8) - length, = struct.unpack('!Q', data) + length, = struct.unpack("!Q", data) if max_size is not None and length > max_size: raise PayloadTooBig( "Payload length exceeds size limit ({} > {} bytes)".format( @@ -187,14 +187,14 @@ def write(frame, writer, *, mask, extensions=None): length = len(frame.data) if length < 126: - output.write(struct.pack('!BB', head1, head2 | length)) + output.write(struct.pack("!BB", head1, head2 | length)) elif length < 65536: - output.write(struct.pack('!BBH', head1, head2 | 126, length)) + output.write(struct.pack("!BBH", head1, head2 | 126, length)) else: - output.write(struct.pack('!BBQ', head1, head2 | 127, length)) + output.write(struct.pack("!BBQ", head1, head2 | 127, length)) if mask: - mask_bits = struct.pack('!I', random.getrandbits(32)) + mask_bits = struct.pack("!I", random.getrandbits(32)) output.write(mask_bits) # Prepare the data. @@ -252,7 +252,7 @@ def prepare_data(data): """ if isinstance(data, str): - return OP_TEXT, data.encode('utf-8') + return OP_TEXT, data.encode("utf-8") elif isinstance(data, collections.abc.ByteString): return OP_BINARY, data elif isinstance(data, memoryview): @@ -279,7 +279,7 @@ def encode_data(data): """ if isinstance(data, str): - return data.encode('utf-8') + return data.encode("utf-8") elif isinstance(data, collections.abc.ByteString): return bytes(data) elif isinstance(data, memoryview): @@ -301,12 +301,12 @@ def parse_close(data): """ length = len(data) if length >= 2: - code, = struct.unpack('!H', data[:2]) + code, = struct.unpack("!H", data[:2]) check_close(code) - reason = data[2:].decode('utf-8') + reason = data[2:].decode("utf-8") return code, reason elif length == 0: - return 1005, '' + return 1005, "" else: assert length == 1 raise WebSocketProtocolError("Close frame too short") @@ -320,7 +320,7 @@ def serialize_close(code, reason): """ check_close(code) - return struct.pack('!H', code) + reason.encode('utf-8') + return struct.pack("!H", code) + reason.encode("utf-8") def check_close(code): diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index cc4248974..e6bd61fab 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -41,9 +41,9 @@ from .http import MultipleValuesError -__all__ = ['build_request', 'check_request', 'build_response', 'check_response'] +__all__ = ["build_request", "check_request", "build_response", "check_response"] -GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" def build_request(headers): @@ -55,10 +55,10 @@ def build_request(headers): """ raw_key = bytes(random.getrandbits(8) for _ in range(16)) key = base64.b64encode(raw_key).decode() - headers['Upgrade'] = 'websocket' - headers['Connection'] = 'Upgrade' - headers['Sec-WebSocket-Key'] = key - headers['Sec-WebSocket-Version'] = '13' + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Key"] = key + headers["Sec-WebSocket-Version"] = "13" return key @@ -79,46 +79,46 @@ def check_request(headers): """ connection = sum( - [parse_connection(value) for value in headers.get_all('Connection')], [] + [parse_connection(value) for value in headers.get_all("Connection")], [] ) - if not any(value.lower() == 'upgrade' for value in connection): - raise InvalidUpgrade('Connection', connection) + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", connection) - upgrade = sum([parse_upgrade(value) for value in headers.get_all('Upgrade')], []) + upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. - if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): - raise InvalidUpgrade('Upgrade', upgrade) + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", upgrade) try: - s_w_key = headers['Sec-WebSocket-Key'] + s_w_key = headers["Sec-WebSocket-Key"] except KeyError: - raise InvalidHeader('Sec-WebSocket-Key') + raise InvalidHeader("Sec-WebSocket-Key") except MultipleValuesError: raise InvalidHeader( - 'Sec-WebSocket-Key', "more than one Sec-WebSocket-Key header found" + "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" ) try: raw_key = base64.b64decode(s_w_key.encode(), validate=True) except binascii.Error: - raise InvalidHeaderValue('Sec-WebSocket-Key', s_w_key) + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) if len(raw_key) != 16: - raise InvalidHeaderValue('Sec-WebSocket-Key', s_w_key) + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) try: - s_w_version = headers['Sec-WebSocket-Version'] + s_w_version = headers["Sec-WebSocket-Version"] except KeyError: - raise InvalidHeader('Sec-WebSocket-Version') + raise InvalidHeader("Sec-WebSocket-Version") except MultipleValuesError: raise InvalidHeader( - 'Sec-WebSocket-Version', "more than one Sec-WebSocket-Version header found" + "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found" ) - if s_w_version != '13': - raise InvalidHeaderValue('Sec-WebSocket-Version', s_w_version) + if s_w_version != "13": + raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) return s_w_key @@ -130,9 +130,9 @@ def build_response(headers, key): ``key`` comes from :func:`check_request`. """ - headers['Upgrade'] = 'websocket' - headers['Connection'] = 'Upgrade' - headers['Sec-WebSocket-Accept'] = accept(key) + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Accept"] = accept(key) def check_response(headers, key): @@ -152,30 +152,30 @@ def check_response(headers, key): """ connection = sum( - [parse_connection(value) for value in headers.get_all('Connection')], [] + [parse_connection(value) for value in headers.get_all("Connection")], [] ) - if not any(value.lower() == 'upgrade' for value in connection): - raise InvalidUpgrade('Connection', connection) + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", connection) - upgrade = sum([parse_upgrade(value) for value in headers.get_all('Upgrade')], []) + upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. - if not (len(upgrade) == 1 and upgrade[0].lower() == 'websocket'): - raise InvalidUpgrade('Upgrade', upgrade) + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", upgrade) try: - s_w_accept = headers['Sec-WebSocket-Accept'] + s_w_accept = headers["Sec-WebSocket-Accept"] except KeyError: - raise InvalidHeader('Sec-WebSocket-Accept') + raise InvalidHeader("Sec-WebSocket-Accept") except MultipleValuesError: raise InvalidHeader( - 'Sec-WebSocket-Accept', "more than one Sec-WebSocket-Accept header found" + "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found" ) if s_w_accept != accept(key): - raise InvalidHeaderValue('Sec-WebSocket-Accept', s_w_accept) + raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) def accept(key): diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 937962376..6151b16db 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -14,12 +14,12 @@ __all__ = [ - 'parse_connection', - 'parse_upgrade', - 'parse_extension_list', - 'build_extension_list', - 'parse_subprotocol_list', - 'build_subprotocol_list', + "parse_connection", + "parse_upgrade", + "parse_extension_list", + "build_extension_list", + "parse_subprotocol_list", + "build_subprotocol_list", ] @@ -40,7 +40,7 @@ def peek_ahead(string, pos): return None if pos == len(string) else string[pos] -_OWS_re = re.compile(r'[\t ]*') +_OWS_re = re.compile(r"[\t ]*") def parse_OWS(string, pos): @@ -57,7 +57,7 @@ def parse_OWS(string, pos): return match.end() -_token_re = re.compile(r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+') +_token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") def parse_token(string, pos, header_name): @@ -80,7 +80,7 @@ def parse_token(string, pos, header_name): ) -_unquote_re = re.compile(r'\\([\x09\x20-\x7e\x80-\xff])') +_unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") def parse_quoted_string(string, pos, header_name): @@ -97,7 +97,7 @@ def parse_quoted_string(string, pos, header_name): raise InvalidHeaderFormat( header_name, "expected quoted string", string=string, pos=pos ) - return _unquote_re.sub(r'\1', match.group()[1:-1]), match.end() + return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() def parse_list(parse_item, string, pos, header_name): @@ -125,7 +125,7 @@ def parse_list(parse_item, string, pos, header_name): # while loops that remove extra delimiters. # Remove extra delimiters before the first item. - while peek_ahead(string, pos) == ',': + while peek_ahead(string, pos) == ",": pos = parse_OWS(string, pos + 1) items = [] @@ -140,7 +140,7 @@ def parse_list(parse_item, string, pos, header_name): break # There must be a delimiter after each element except the last one. - if peek_ahead(string, pos) == ',': + if peek_ahead(string, pos) == ",": pos = parse_OWS(string, pos + 1) else: raise InvalidHeaderFormat( @@ -148,7 +148,7 @@ def parse_list(parse_item, string, pos, header_name): ) # Remove extra delimiters before the next item. - while peek_ahead(string, pos) == ',': + while peek_ahead(string, pos) == ",": pos = parse_OWS(string, pos + 1) # We may have reached the end of the string. @@ -171,11 +171,11 @@ def parse_connection(string): Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - return parse_list(parse_token, string, 0, 'Connection') + return parse_list(parse_token, string, 0, "Connection") _protocol_re = re.compile( - r'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?' + r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?" ) @@ -205,7 +205,7 @@ def parse_upgrade(string): Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - return parse_list(parse_protocol, string, 0, 'Upgrade') + return parse_list(parse_protocol, string, 0, "Upgrade") def parse_extension_param(string, pos, header_name): @@ -221,7 +221,7 @@ def parse_extension_param(string, pos, header_name): name, pos = parse_token(string, pos, header_name) pos = parse_OWS(string, pos) # Extract parameter string, if there is one. - if peek_ahead(string, pos) == '=': + if peek_ahead(string, pos) == "=": pos = parse_OWS(string, pos + 1) if peek_ahead(string, pos) == '"': pos_before = pos # for proper error reporting below @@ -259,7 +259,7 @@ def parse_extension(string, pos, header_name): pos = parse_OWS(string, pos) # Extract all parameters. parameters = [] - while peek_ahead(string, pos) == ';': + while peek_ahead(string, pos) == ";": pos = parse_OWS(string, pos + 1) parameter, pos = parse_extension_param(string, pos, header_name) parameters.append(parameter) @@ -288,7 +288,7 @@ def parse_extension_list(string): Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - return parse_list(parse_extension, string, 0, 'Sec-WebSocket-Extensions') + return parse_list(parse_extension, string, 0, "Sec-WebSocket-Extensions") def build_extension(name, parameters): @@ -298,11 +298,11 @@ def build_extension(name, parameters): This is the reverse of :func:`parse_extension`. """ - return '; '.join( + return "; ".join( [name] + [ # Quoted strings aren't necessary because values are always tokens. - name if value is None else '{}={}'.format(name, value) + name if value is None else "{}={}".format(name, value) for name, value in parameters ] ) @@ -315,7 +315,7 @@ def build_extension_list(extensions): This is the reverse of :func:`parse_extension_list`. """ - return ', '.join( + return ", ".join( build_extension(name, parameters) for name, parameters in extensions ) @@ -327,7 +327,7 @@ def parse_subprotocol_list(string): Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - return parse_list(parse_token, string, 0, 'Sec-WebSocket-Protocol') + return parse_list(parse_token, string, 0, "Sec-WebSocket-Protocol") def build_subprotocol_list(protocols): @@ -337,7 +337,7 @@ def build_subprotocol_list(protocols): This is the reverse of :func:`parse_subprotocol_list`. """ - return ', '.join(protocols) + return ", ".join(protocols) def build_basic_auth(username, password): @@ -346,7 +346,7 @@ def build_basic_auth(username, password): """ # https://tools.ietf.org/html/rfc7617#section-2 - assert ':' not in username - user_pass = '{}:{}'.format(username, password) + assert ":" not in username + user_pass = "{}:{}".format(username, password) basic_credentials = base64.b64encode(user_pass.encode()).decode() - return 'Basic ' + basic_credentials + return "Basic " + basic_credentials diff --git a/src/websockets/http.py b/src/websockets/http.py index 5062c03d7..ea17e0a2e 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -16,24 +16,24 @@ __all__ = [ - 'Headers', - 'MultipleValuesError', - 'read_request', - 'read_response', - 'USER_AGENT', + "Headers", + "MultipleValuesError", + "read_request", + "read_response", + "USER_AGENT", ] MAX_HEADERS = 256 MAX_LINE = 4096 -USER_AGENT = 'Python/{} websockets/{}'.format(sys.version[:3], websockets_version) +USER_AGENT = "Python/{} websockets/{}".format(sys.version[:3], websockets_version) # See https://tools.ietf.org/html/rfc7230#appendix-B. # Regex for validating header names. -_token_re = re.compile(rb'[-!#$%&\'*+.^_`|~0-9a-zA-Z]+') +_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") # Regex for validating header values. @@ -46,7 +46,7 @@ # See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 -_value_re = re.compile(rb'[\x09\x20-\x7e\x80-\xff]*') +_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") @asyncio.coroutine @@ -79,13 +79,13 @@ def read_request(stream): request_line = yield from read_line(stream) # This may raise "ValueError: not enough values to unpack" - method, path, version = request_line.split(b' ', 2) + method, path, version = request_line.split(b" ", 2) - if method != b'GET': + if method != b"GET": raise ValueError("Unsupported HTTP method: %r" % method) - if version != b'HTTP/1.1': + if version != b"HTTP/1.1": raise ValueError("Unsupported HTTP version: %r" % version) - path = path.decode('ascii', 'surrogateescape') + path = path.decode("ascii", "surrogateescape") headers = yield from read_headers(stream) @@ -120,9 +120,9 @@ def read_response(stream): status_line = yield from read_line(stream) # This may raise "ValueError: not enough values to unpack" - version, status_code, reason = status_line.split(b' ', 2) + version, status_code, reason = status_line.split(b" ", 2) - if version != b'HTTP/1.1': + if version != b"HTTP/1.1": raise ValueError("Unsupported HTTP version: %r" % version) # This may raise "ValueError: invalid literal for int() with base 10" status_code = int(status_code) @@ -156,19 +156,19 @@ def read_headers(stream): headers = Headers() for _ in range(MAX_HEADERS + 1): line = yield from read_line(stream) - if line == b'': + if line == b"": break # This may raise "ValueError: not enough values to unpack" - name, value = line.split(b':', 1) + name, value = line.split(b":", 1) if not _token_re.fullmatch(name): raise ValueError("Invalid HTTP header name: %r" % name) - value = value.strip(b' \t') + value = value.strip(b" \t") if not _value_re.fullmatch(value): raise ValueError("Invalid HTTP header value: %r" % value) - name = name.decode('ascii') # guaranteed to be ASCII at this point - value = value.decode('ascii', 'surrogateescape') + name = name.decode("ascii") # guaranteed to be ASCII at this point + value = value.decode("ascii", "surrogateescape") headers[name] = value else: @@ -193,7 +193,7 @@ def read_line(stream): if len(line) > MAX_LINE: raise ValueError("Line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 - if not line.endswith(b'\r\n'): + if not line.endswith(b"\r\n"): raise ValueError("Line without CRLF") return line[:-2] @@ -248,7 +248,7 @@ class Headers(collections.abc.MutableMapping): """ - __slots__ = ['_dict', '_list'] + __slots__ = ["_dict", "_list"] def __init__(self, *args, **kwargs): self._dict = {} @@ -258,12 +258,12 @@ def __init__(self, *args, **kwargs): def __str__(self): return ( - ''.join('{}: {}\r\n'.format(key, value) for key, value in self._list) - + '\r\n' + "".join("{}: {}\r\n".format(key, value) for key, value in self._list) + + "\r\n" ) def __repr__(self): - return '{}({})'.format(self.__class__.__name__, repr(self._list)) + return "{}({})".format(self.__class__.__name__, repr(self._list)) def copy(self): copy = self.__class__() diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 52e39a2af..d7d7282a1 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -29,7 +29,7 @@ from .handshake import * -__all__ = ['WebSocketCommonProtocol'] +__all__ = ["WebSocketCommonProtocol"] logger = logging.getLogger(__name__) @@ -37,7 +37,7 @@ # On Python ≥ 3.7, silence a deprecation warning that we can't address before # dropping support for Python < 3.5. warnings.filterwarnings( - action='ignore', + action="ignore", message=r"'with \(yield from lock\)' is deprecated use 'async with lock' instead", category=DeprecationWarning, ) @@ -163,7 +163,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): # side behavior: masking the payload and closing the underlying TCP # connection. Set is_client and side to pick a side. is_client = None - side = 'undefined' + side = "undefined" def __init__( self, @@ -236,7 +236,7 @@ def __init__( # The close code and reason are set when receiving a close frame or # losing the TCP connection. self.close_code = None - self.close_reason = '' + self.close_reason = "" # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost()` callback to a :class:`~asyncio.Future` @@ -313,7 +313,7 @@ def local_address(self): """ if self.writer is None: return None - return self.writer.get_extra_info('sockname') + return self.writer.get_extra_info("sockname") @property def remote_address(self): @@ -326,7 +326,7 @@ def remote_address(self): """ if self.writer is None: return None - return self.writer.get_extra_info('peername') + return self.writer.get_extra_info("peername") @property def open(self): @@ -498,7 +498,7 @@ def send(self, data): yield from self.write_frame(False, OP_CONT, data) # Final fragment. - yield from self.write_frame(True, OP_CONT, b'') + yield from self.write_frame(True, OP_CONT, b"") # Fragmented message -- asynchronous iterator @@ -508,7 +508,7 @@ def send(self, data): raise TypeError("data must be bytes, str, or iterable") @asyncio.coroutine - def close(self, code=1000, reason=''): + def close(self, code=1000, reason=""): """ This coroutine performs the closing handshake. @@ -588,7 +588,7 @@ def ping(self, data=None): # Generate a unique random payload otherwise. while data is None or data in self.pings: - data = struct.pack('!I', random.getrandbits(32)) + data = struct.pack("!I", random.getrandbits(32)) self.pings[data] = asyncio.Future(loop=self.loop) @@ -597,7 +597,7 @@ def ping(self, data=None): return asyncio.shield(self.pings[data]) @asyncio.coroutine - def pong(self, data=b''): + def pong(self, data=b""): """ This coroutine sends a pong. @@ -751,13 +751,13 @@ def read_message(self): # Shortcut for the common case - no fragmentation if frame.fin: - return frame.data.decode('utf-8') if text else frame.data + return frame.data.decode("utf-8") if text else frame.data # 5.4. Fragmentation chunks = [] max_size = self.max_size if text: - decoder = codecs.getincrementaldecoder('utf-8')(errors='strict') + decoder = codecs.getincrementaldecoder("utf-8")(errors="strict") if max_size is None: def append(frame): @@ -795,7 +795,7 @@ def append(frame): raise WebSocketProtocolError("Unexpected opcode") append(frame) - return ('' if text else b'').join(chunks) + return ("" if text else b"").join(chunks) @asyncio.coroutine def read_data_frame(self, max_size): @@ -825,7 +825,7 @@ def read_data_frame(self, max_size): elif frame.opcode == OP_PING: # Answer pings. # Replace by frame.data.hex() when dropping Python < 3.5. - ping_hex = binascii.hexlify(frame.data).decode() or '[empty]' + ping_hex = binascii.hexlify(frame.data).decode() or "[empty]" logger.debug( "%s - received ping, sending pong: %s", self.side, ping_hex ) @@ -841,17 +841,17 @@ def read_data_frame(self, max_size): ping_id, pong_waiter = self.pings.popitem(0) ping_ids.append(ping_id) pong_waiter.set_result(None) - pong_hex = binascii.hexlify(frame.data).decode() or '[empty]' + pong_hex = binascii.hexlify(frame.data).decode() or "[empty]" logger.debug( "%s - received solicited pong: %s", self.side, pong_hex ) ping_ids = ping_ids[:-1] if ping_ids: - pings_hex = ', '.join( - binascii.hexlify(ping_id).decode() or '[empty]' + pings_hex = ", ".join( + binascii.hexlify(ping_id).decode() or "[empty]" for ping_id in ping_ids ) - plural = 's' if len(ping_ids) > 1 else '' + plural = "s" if len(ping_ids) > 1 else "" logger.debug( "%s - acknowledged previous ping%s: %s", self.side, @@ -859,7 +859,7 @@ def read_data_frame(self, max_size): pings_hex, ) else: - pong_hex = binascii.hexlify(frame.data).decode() or '[empty]' + pong_hex = binascii.hexlify(frame.data).decode() or "[empty]" logger.debug( "%s - received unsolicited pong: %s", self.side, pong_hex ) @@ -935,7 +935,7 @@ def writer_is_closing(self): return transport._closed @asyncio.coroutine - def write_close_frame(self, data=b''): + def write_close_frame(self, data=b""): """ Write a close frame if and only if the connection state is OPEN. @@ -1083,7 +1083,7 @@ def wait_for_connection_lost(self): # and the moment this coroutine resumes running. return self.connection_lost_waiter.done() - def fail_connection(self, code=1006, reason=''): + def fail_connection(self, code=1006, reason=""): """ 7.1.7. Fail the WebSocket Connection @@ -1161,11 +1161,11 @@ def abort_keepalive_pings(self): ping.set_exception(exc) if self.pings: - pings_hex = ', '.join( - binascii.hexlify(ping_id).decode() or '[empty]' + pings_hex = ", ".join( + binascii.hexlify(ping_id).decode() or "[empty]" for ping_id in self.pings ) - plural = 's' if len(self.pings) > 1 else '' + plural = "s" if len(self.pings) > 1 else "" logger.debug( "%s - aborted pending ping%s: %s", self.side, plural, pings_hex ) @@ -1231,7 +1231,7 @@ def connection_lost(self, exc): "%s x code = %d, reason = %s", self.side, self.close_code, - self.close_reason or '[no reason]', + self.close_reason or "[no reason]", ) self.abort_keepalive_pings() # If self.connection_lost_waiter isn't pending, that's a bug, because: diff --git a/src/websockets/py35/client.py b/src/websockets/py35/client.py index bd902841a..ccb098483 100644 --- a/src/websockets/py35/client.py +++ b/src/websockets/py35/client.py @@ -31,11 +31,11 @@ async def __await_impl__(self): raise except RedirectHandshake as e: if self._wsuri.secure and not e.wsuri.secure: - raise InvalidHandshake('Redirect dropped TLS') + raise InvalidHandshake("Redirect dropped TLS") self._wsuri = e.wsuri continue # redirection chain continues else: - raise InvalidHandshake('Maximum redirects exceeded') + raise InvalidHandshake("Maximum redirects exceeded") self.ws_client = protocol return protocol diff --git a/src/websockets/server.py b/src/websockets/server.py index 1d88e73a1..e207db2bc 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -35,7 +35,7 @@ from .protocol import State, WebSocketCommonProtocol -__all__ = ['serve', 'unix_serve', 'WebSocketServerProtocol'] +__all__ = ["serve", "unix_serve", "WebSocketServerProtocol"] logger = logging.getLogger(__name__) @@ -53,7 +53,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ is_client = False - side = 'server' + side = "server" def __init__( self, @@ -69,9 +69,9 @@ def __init__( **kwds ): # For backwards-compatibility with 6.0 or earlier. - if origins is not None and '' in origins: + if origins is not None and "" in origins: warnings.warn("use None instead of '' in origins", DeprecationWarning) - origins = [None if origin == '' else origin for origin in origins] + origins = [None if origin == "" else origin for origin in origins] self.ws_handler = ws_handler self.ws_server = ws_server self.origins = origins @@ -129,7 +129,7 @@ def handler(self): logger.debug("Invalid upgrade", exc_info=True) status, headers, body = ( UPGRADE_REQUIRED, - [('Upgrade', 'websocket')], + [("Upgrade", "websocket")], (str(exc) + "\n").encode(), ) elif isinstance(exc, InvalidHandshake): @@ -150,11 +150,11 @@ def handler(self): if not isinstance(headers, Headers): headers = Headers(headers) - headers.setdefault('Date', email.utils.formatdate(usegmt=True)) - headers.setdefault('Server', USER_AGENT) - headers.setdefault('Content-Length', str(len(body))) - headers.setdefault('Content-Type', 'text/plain') - headers.setdefault('Connection', 'close') + headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + headers.setdefault("Server", USER_AGENT) + headers.setdefault("Content-Length", str(len(body))) + headers.setdefault("Content-Type", "text/plain") + headers.setdefault("Connection", "close") self.write_http_response(status, headers, body) self.fail_connection() @@ -232,7 +232,7 @@ def write_http_response(self, status, headers, body=None): # Since the status line and headers only contain ASCII characters, # we can keep this simple. - response = 'HTTP/1.1 {status.value} {status.phrase}\r\n'.format(status=status) + response = "HTTP/1.1 {status.value} {status.phrase}\r\n".format(status=status) response += str(headers) self.writer.write(response.encode()) @@ -287,9 +287,9 @@ def process_origin(headers, origins=None): # "The user agent MUST NOT include more than one Origin header field" # per https://tools.ietf.org/html/rfc6454#section-7.3. try: - origin = headers.get('Origin') + origin = headers.get("Origin") except MultipleValuesError: - raise InvalidHeader('Origin', "more than one Origin header found") + raise InvalidHeader("Origin", "more than one Origin header found") if origins is not None: if origin not in origins: raise InvalidOrigin(origin) @@ -332,7 +332,7 @@ def process_extensions(headers, available_extensions): response_header = [] accepted_extensions = [] - header_values = headers.get_all('Sec-WebSocket-Extensions') + header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and available_extensions: @@ -386,7 +386,7 @@ def process_subprotocol(self, headers, available_subprotocols): """ subprotocol = None - header_values = headers.get_all('Sec-WebSocket-Protocol') + header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: @@ -498,10 +498,10 @@ def handshake( build_response(response_headers, key) if extensions_header is not None: - response_headers['Sec-WebSocket-Extensions'] = extensions_header + response_headers["Sec-WebSocket-Extensions"] = extensions_header if protocol_header is not None: - response_headers['Sec-WebSocket-Protocol'] = protocol_header + response_headers["Sec-WebSocket-Protocol"] = protocol_header if extra_headers is not None: if callable(extra_headers): @@ -513,8 +513,8 @@ def handshake( for name, value in extra_headers: response_headers[name] = value - response_headers.setdefault('Date', email.utils.formatdate(usegmt=True)) - response_headers.setdefault('Server', USER_AGENT) + response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + response_headers.setdefault("Server", USER_AGENT) self.write_http_response(SWITCHING_PROTOCOLS, response_headers) @@ -778,7 +778,7 @@ def __init__( legacy_recv=False, klass=WebSocketServerProtocol, timeout=10, - compression='deflate', + compression="deflate", origins=None, extensions=None, subprotocols=None, @@ -802,9 +802,9 @@ def __init__( ws_server = WebSocketServer(loop) - secure = kwds.get('ssl') is not None + secure = kwds.get("ssl") is not None - if compression == 'deflate': + if compression == "deflate": if extensions is None: extensions = [] if not any( diff --git a/src/websockets/uri.py b/src/websockets/uri.py index d793fc6aa..b6e1ad0ce 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -12,10 +12,10 @@ from .exceptions import InvalidURI -__all__ = ['parse_uri', 'WebSocketURI'] +__all__ = ["parse_uri", "WebSocketURI"] WebSocketURI = collections.namedtuple( - 'WebSocketURI', ['secure', 'host', 'port', 'resource_name', 'user_info'] + "WebSocketURI", ["secure", "host", "port", "resource_name", "user_info"] ) WebSocketURI.__doc__ = """WebSocket URI. @@ -42,19 +42,19 @@ def parse_uri(uri): """ uri = urllib.parse.urlparse(uri) try: - assert uri.scheme in ['ws', 'wss'] - assert uri.params == '' - assert uri.fragment == '' + assert uri.scheme in ["ws", "wss"] + assert uri.params == "" + assert uri.fragment == "" assert uri.hostname is not None except AssertionError as exc: raise InvalidURI("{} isn't a valid URI".format(uri)) from exc - secure = uri.scheme == 'wss' + secure = uri.scheme == "wss" host = uri.hostname port = uri.port or (443 if secure else 80) - resource_name = uri.path or '/' + resource_name = uri.path or "/" if uri.query: - resource_name += '?' + uri.query + resource_name += "?" + uri.query user_info = None if uri.username or uri.password: user_info = (uri.username, uri.password) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index def997841..193f8fc32 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -1,7 +1,7 @@ import itertools -__all__ = ['apply_mask'] +__all__ = ["apply_mask"] def apply_mask(data, mask): diff --git a/src/websockets/version.py b/src/websockets/version.py index fe9ed183b..96b948d8a 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = '7.0' +version = "7.0" diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 0b7b78eae..80003ca2d 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -39,7 +39,7 @@ def assertExtensionEqual(self, extension1, extension2): class ClientPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): def test_name(self): - assert ClientPerMessageDeflateFactory.name == 'permessage-deflate' + assert ClientPerMessageDeflateFactory.name == "permessage-deflate" def test_init(self): for config in [ @@ -48,7 +48,7 @@ def test_init(self): (True, False, None, 8), # client_max_window_bits ≥ 8 (True, True, None, 15), # client_max_window_bits ≤ 15 (False, False, None, True), # client_max_window_bits - (False, False, None, None, {'memLevel': 4}), + (False, False, None, None, {"memLevel": 4}), ]: with self.subTest(config=config): # This does not raise an exception. @@ -61,7 +61,7 @@ def test_init_error(self): (True, False, 16, 15), # server_max_window_bits > 15 (True, True, 15, 16), # client_max_window_bits > 15 (False, False, True, None), # server_max_window_bits - (False, False, None, None, {'wbits': 11}), + (False, False, None, None, {"wbits": 11}), ]: with self.subTest(config=config): with self.assertRaises(ValueError): @@ -72,22 +72,22 @@ def test_get_request_params(self): # Test without any parameter ((False, False, None, None), []), # Test server_no_context_takeover - ((True, False, None, None), [('server_no_context_takeover', None)]), + ((True, False, None, None), [("server_no_context_takeover", None)]), # Test client_no_context_takeover - ((False, True, None, None), [('client_no_context_takeover', None)]), + ((False, True, None, None), [("client_no_context_takeover", None)]), # Test server_max_window_bits - ((False, False, 10, None), [('server_max_window_bits', '10')]), + ((False, False, 10, None), [("server_max_window_bits", "10")]), # Test client_max_window_bits - ((False, False, None, 10), [('client_max_window_bits', '10')]), - ((False, False, None, True), [('client_max_window_bits', None)]), + ((False, False, None, 10), [("client_max_window_bits", "10")]), + ((False, False, None, True), [("client_max_window_bits", None)]), # Test all parameters together ( (True, True, 12, 12), [ - ('server_no_context_takeover', None), - ('client_no_context_takeover', None), - ('server_max_window_bits', '12'), - ('client_max_window_bits', '12'), + ("server_no_context_takeover", None), + ("client_no_context_takeover", None), + ("server_max_window_bits", "12"), + ("client_max_window_bits", "12"), ], ), ]: @@ -99,167 +99,167 @@ def test_process_response_params(self): for config, response_params, result in [ # Test without any parameter ((False, False, None, None), [], (False, False, 15, 15)), - ((False, False, None, None), [('unknown', None)], InvalidParameterName), + ((False, False, None, None), [("unknown", None)], InvalidParameterName), # Test server_no_context_takeover ( (False, False, None, None), - [('server_no_context_takeover', None)], + [("server_no_context_takeover", None)], (True, False, 15, 15), ), ((True, False, None, None), [], NegotiationError), ( (True, False, None, None), - [('server_no_context_takeover', None)], + [("server_no_context_takeover", None)], (True, False, 15, 15), ), ( (True, False, None, None), - [('server_no_context_takeover', None)] * 2, + [("server_no_context_takeover", None)] * 2, DuplicateParameter, ), ( (True, False, None, None), - [('server_no_context_takeover', '42')], + [("server_no_context_takeover", "42")], InvalidParameterValue, ), # Test client_no_context_takeover ( (False, False, None, None), - [('client_no_context_takeover', None)], + [("client_no_context_takeover", None)], (False, True, 15, 15), ), ((False, True, None, None), [], (False, True, 15, 15)), ( (False, True, None, None), - [('client_no_context_takeover', None)], + [("client_no_context_takeover", None)], (False, True, 15, 15), ), ( (False, True, None, None), - [('client_no_context_takeover', None)] * 2, + [("client_no_context_takeover", None)] * 2, DuplicateParameter, ), ( (False, True, None, None), - [('client_no_context_takeover', '42')], + [("client_no_context_takeover", "42")], InvalidParameterValue, ), # Test server_max_window_bits ( (False, False, None, None), - [('server_max_window_bits', '7')], + [("server_max_window_bits", "7")], NegotiationError, ), ( (False, False, None, None), - [('server_max_window_bits', '10')], + [("server_max_window_bits", "10")], (False, False, 10, 15), ), ( (False, False, None, None), - [('server_max_window_bits', '16')], + [("server_max_window_bits", "16")], NegotiationError, ), ((False, False, 12, None), [], NegotiationError), ( (False, False, 12, None), - [('server_max_window_bits', '10')], + [("server_max_window_bits", "10")], (False, False, 10, 15), ), ( (False, False, 12, None), - [('server_max_window_bits', '12')], + [("server_max_window_bits", "12")], (False, False, 12, 15), ), ( (False, False, 12, None), - [('server_max_window_bits', '13')], + [("server_max_window_bits", "13")], NegotiationError, ), ( (False, False, 12, None), - [('server_max_window_bits', '12')] * 2, + [("server_max_window_bits", "12")] * 2, DuplicateParameter, ), ( (False, False, 12, None), - [('server_max_window_bits', '42')], + [("server_max_window_bits", "42")], InvalidParameterValue, ), # Test client_max_window_bits ( (False, False, None, None), - [('client_max_window_bits', '10')], + [("client_max_window_bits", "10")], NegotiationError, ), ((False, False, None, True), [], (False, False, 15, 15)), ( (False, False, None, True), - [('client_max_window_bits', '7')], + [("client_max_window_bits", "7")], NegotiationError, ), ( (False, False, None, True), - [('client_max_window_bits', '10')], + [("client_max_window_bits", "10")], (False, False, 15, 10), ), ( (False, False, None, True), - [('client_max_window_bits', '16')], + [("client_max_window_bits", "16")], NegotiationError, ), ((False, False, None, 12), [], (False, False, 15, 12)), ( (False, False, None, 12), - [('client_max_window_bits', '10')], + [("client_max_window_bits", "10")], (False, False, 15, 10), ), ( (False, False, None, 12), - [('client_max_window_bits', '12')], + [("client_max_window_bits", "12")], (False, False, 15, 12), ), ( (False, False, None, 12), - [('client_max_window_bits', '13')], + [("client_max_window_bits", "13")], NegotiationError, ), ( (False, False, None, 12), - [('client_max_window_bits', '12')] * 2, + [("client_max_window_bits", "12")] * 2, DuplicateParameter, ), ( (False, False, None, 12), - [('client_max_window_bits', '42')], + [("client_max_window_bits", "42")], InvalidParameterValue, ), # Test all parameters together ( (True, True, 12, 12), [ - ('server_no_context_takeover', None), - ('client_no_context_takeover', None), - ('server_max_window_bits', '10'), - ('client_max_window_bits', '10'), + ("server_no_context_takeover", None), + ("client_no_context_takeover", None), + ("server_max_window_bits", "10"), + ("client_max_window_bits", "10"), ], (True, True, 10, 10), ), ( (False, False, None, True), [ - ('server_no_context_takeover', None), - ('client_no_context_takeover', None), - ('server_max_window_bits', '10'), - ('client_max_window_bits', '10'), + ("server_no_context_takeover", None), + ("client_no_context_takeover", None), + ("server_max_window_bits", "10"), + ("client_max_window_bits", "10"), ], (True, True, 10, 10), ), ( (True, True, 12, 12), [ - ('server_no_context_takeover', None), - ('server_max_window_bits', '12'), + ("server_no_context_takeover", None), + ("server_max_window_bits", "12"), ], (True, True, 12, 12), ), @@ -284,7 +284,7 @@ def test_process_response_params_deduplication(self): class ServerPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): def test_name(self): - assert ServerPerMessageDeflateFactory.name == 'permessage-deflate' + assert ServerPerMessageDeflateFactory.name == "permessage-deflate" def test_init(self): for config in [ @@ -292,7 +292,7 @@ def test_init(self): (False, True, 15, None), # server_max_window_bits ≤ 15 (True, False, None, 8), # client_max_window_bits ≥ 8 (True, True, None, 15), # client_max_window_bits ≤ 15 - (False, False, None, None, {'memLevel': 4}), + (False, False, None, None, {"memLevel": 4}), ]: with self.subTest(config=config): # This does not raise an exception. @@ -306,7 +306,7 @@ def test_init_error(self): (True, True, 15, 16), # client_max_window_bits > 15 (False, False, None, True), # client_max_window_bits (False, False, True, None), # server_max_window_bits - (False, False, None, None, {'wbits': 11}), + (False, False, None, None, {"wbits": 11}), ]: with self.subTest(config=config): with self.assertRaises(ValueError): @@ -320,186 +320,186 @@ def test_process_request_params(self): ((False, False, None, None), [], [], (False, False, 15, 15)), ( (False, False, None, None), - [('unknown', None)], + [("unknown", None)], None, InvalidParameterName, ), # Test server_no_context_takeover ( (False, False, None, None), - [('server_no_context_takeover', None)], - [('server_no_context_takeover', None)], + [("server_no_context_takeover", None)], + [("server_no_context_takeover", None)], (False, True, 15, 15), ), ( (True, False, None, None), [], - [('server_no_context_takeover', None)], + [("server_no_context_takeover", None)], (False, True, 15, 15), ), ( (True, False, None, None), - [('server_no_context_takeover', None)], - [('server_no_context_takeover', None)], + [("server_no_context_takeover", None)], + [("server_no_context_takeover", None)], (False, True, 15, 15), ), ( (True, False, None, None), - [('server_no_context_takeover', None)] * 2, + [("server_no_context_takeover", None)] * 2, None, DuplicateParameter, ), ( (True, False, None, None), - [('server_no_context_takeover', '42')], + [("server_no_context_takeover", "42")], None, InvalidParameterValue, ), # Test client_no_context_takeover ( (False, False, None, None), - [('client_no_context_takeover', None)], - [('client_no_context_takeover', None)], # doesn't matter + [("client_no_context_takeover", None)], + [("client_no_context_takeover", None)], # doesn't matter (True, False, 15, 15), ), ( (False, True, None, None), [], - [('client_no_context_takeover', None)], + [("client_no_context_takeover", None)], (True, False, 15, 15), ), ( (False, True, None, None), - [('client_no_context_takeover', None)], - [('client_no_context_takeover', None)], # doesn't matter + [("client_no_context_takeover", None)], + [("client_no_context_takeover", None)], # doesn't matter (True, False, 15, 15), ), ( (False, True, None, None), - [('client_no_context_takeover', None)] * 2, + [("client_no_context_takeover", None)] * 2, None, DuplicateParameter, ), ( (False, True, None, None), - [('client_no_context_takeover', '42')], + [("client_no_context_takeover", "42")], None, InvalidParameterValue, ), # Test server_max_window_bits ( (False, False, None, None), - [('server_max_window_bits', '7')], + [("server_max_window_bits", "7")], None, NegotiationError, ), ( (False, False, None, None), - [('server_max_window_bits', '10')], - [('server_max_window_bits', '10')], + [("server_max_window_bits", "10")], + [("server_max_window_bits", "10")], (False, False, 15, 10), ), ( (False, False, None, None), - [('server_max_window_bits', '16')], + [("server_max_window_bits", "16")], None, NegotiationError, ), ( (False, False, 12, None), [], - [('server_max_window_bits', '12')], + [("server_max_window_bits", "12")], (False, False, 15, 12), ), ( (False, False, 12, None), - [('server_max_window_bits', '10')], - [('server_max_window_bits', '10')], + [("server_max_window_bits", "10")], + [("server_max_window_bits", "10")], (False, False, 15, 10), ), ( (False, False, 12, None), - [('server_max_window_bits', '12')], - [('server_max_window_bits', '12')], + [("server_max_window_bits", "12")], + [("server_max_window_bits", "12")], (False, False, 15, 12), ), ( (False, False, 12, None), - [('server_max_window_bits', '13')], - [('server_max_window_bits', '12')], + [("server_max_window_bits", "13")], + [("server_max_window_bits", "12")], (False, False, 15, 12), ), ( (False, False, 12, None), - [('server_max_window_bits', '12')] * 2, + [("server_max_window_bits", "12")] * 2, None, DuplicateParameter, ), ( (False, False, 12, None), - [('server_max_window_bits', '42')], + [("server_max_window_bits", "42")], None, InvalidParameterValue, ), # Test client_max_window_bits ( (False, False, None, None), - [('client_max_window_bits', None)], + [("client_max_window_bits", None)], [], (False, False, 15, 15), ), ( (False, False, None, None), - [('client_max_window_bits', '7')], + [("client_max_window_bits", "7")], None, InvalidParameterValue, ), ( (False, False, None, None), - [('client_max_window_bits', '10')], - [('client_max_window_bits', '10')], # doesn't matter + [("client_max_window_bits", "10")], + [("client_max_window_bits", "10")], # doesn't matter (False, False, 10, 15), ), ( (False, False, None, None), - [('client_max_window_bits', '16')], + [("client_max_window_bits", "16")], None, InvalidParameterValue, ), ((False, False, None, 12), [], None, NegotiationError), ( (False, False, None, 12), - [('client_max_window_bits', None)], - [('client_max_window_bits', '12')], + [("client_max_window_bits", None)], + [("client_max_window_bits", "12")], (False, False, 12, 15), ), ( (False, False, None, 12), - [('client_max_window_bits', '10')], - [('client_max_window_bits', '10')], + [("client_max_window_bits", "10")], + [("client_max_window_bits", "10")], (False, False, 10, 15), ), ( (False, False, None, 12), - [('client_max_window_bits', '12')], - [('client_max_window_bits', '12')], # doesn't matter + [("client_max_window_bits", "12")], + [("client_max_window_bits", "12")], # doesn't matter (False, False, 12, 15), ), ( (False, False, None, 12), - [('client_max_window_bits', '13')], - [('client_max_window_bits', '12')], # doesn't matter + [("client_max_window_bits", "13")], + [("client_max_window_bits", "12")], # doesn't matter (False, False, 12, 15), ), ( (False, False, None, 12), - [('client_max_window_bits', '12')] * 2, + [("client_max_window_bits", "12")] * 2, None, DuplicateParameter, ), ( (False, False, None, 12), - [('client_max_window_bits', '42')], + [("client_max_window_bits", "42")], None, InvalidParameterValue, ), @@ -507,43 +507,43 @@ def test_process_request_params(self): ( (True, True, 12, 12), [ - ('server_no_context_takeover', None), - ('client_no_context_takeover', None), - ('server_max_window_bits', '10'), - ('client_max_window_bits', '10'), + ("server_no_context_takeover", None), + ("client_no_context_takeover", None), + ("server_max_window_bits", "10"), + ("client_max_window_bits", "10"), ], [ - ('server_no_context_takeover', None), - ('client_no_context_takeover', None), - ('server_max_window_bits', '10'), - ('client_max_window_bits', '10'), + ("server_no_context_takeover", None), + ("client_no_context_takeover", None), + ("server_max_window_bits", "10"), + ("client_max_window_bits", "10"), ], (True, True, 10, 10), ), ( (False, False, None, None), [ - ('server_no_context_takeover', None), - ('client_no_context_takeover', None), - ('server_max_window_bits', '10'), - ('client_max_window_bits', '10'), + ("server_no_context_takeover", None), + ("client_no_context_takeover", None), + ("server_max_window_bits", "10"), + ("client_max_window_bits", "10"), ], [ - ('server_no_context_takeover', None), - ('client_no_context_takeover', None), - ('server_max_window_bits', '10'), - ('client_max_window_bits', '10'), + ("server_no_context_takeover", None), + ("client_no_context_takeover", None), + ("server_max_window_bits", "10"), + ("client_max_window_bits", "10"), ], (True, True, 10, 10), ), ( (True, True, 12, 12), - [('client_max_window_bits', None)], + [("client_max_window_bits", None)], [ - ('server_no_context_takeover', None), - ('client_no_context_takeover', None), - ('server_max_window_bits', '12'), - ('client_max_window_bits', '12'), + ("server_no_context_takeover", None), + ("client_no_context_takeover", None), + ("server_max_window_bits", "12"), + ("client_max_window_bits", "12"), ], (True, True, 12, 12), ), @@ -581,7 +581,7 @@ def setUp(self): self.extension = PerMessageDeflate(False, False, 15, 15) def test_name(self): - assert self.extension.name == 'permessage-deflate' + assert self.extension.name == "permessage-deflate" def test_repr(self): self.assertExtensionEqual(eval(repr(self.extension)), self.extension) @@ -589,21 +589,21 @@ def test_repr(self): # Control frames aren't encoded or decoded. def test_no_encode_decode_ping_frame(self): - frame = Frame(True, OP_PING, b'') + frame = Frame(True, OP_PING, b"") self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_pong_frame(self): - frame = Frame(True, OP_PONG, b'') + frame = Frame(True, OP_PONG, b"") self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_close_frame(self): - frame = Frame(True, OP_CLOSE, serialize_close(1000, '')) + frame = Frame(True, OP_CLOSE, serialize_close(1000, "")) self.assertEqual(self.extension.encode(frame), frame) @@ -612,31 +612,31 @@ def test_no_encode_decode_close_frame(self): # Data frames are encoded and decoded. def test_encode_decode_text_frame(self): - frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b'JNL;\xbc\x12\x00')) + self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"JNL;\xbc\x12\x00")) dec_frame = self.extension.decode(enc_frame) self.assertEqual(dec_frame, frame) def test_encode_decode_binary_frame(self): - frame = Frame(True, OP_BINARY, b'tea') + frame = Frame(True, OP_BINARY, b"tea") enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b'*IM\x04\x00')) + self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"*IM\x04\x00")) dec_frame = self.extension.decode(enc_frame) self.assertEqual(dec_frame, frame) def test_encode_decode_fragmented_text_frame(self): - frame1 = Frame(False, OP_TEXT, 'café'.encode('utf-8')) - frame2 = Frame(False, OP_CONT, ' & '.encode('utf-8')) - frame3 = Frame(True, OP_CONT, 'croissants'.encode('utf-8')) + frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) + frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) + frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) @@ -644,13 +644,13 @@ def test_encode_decode_fragmented_text_frame(self): self.assertEqual( enc_frame1, - frame1._replace(rsv1=True, data=b'JNL;\xbc\x12\x00\x00\x00\xff\xff'), + frame1._replace(rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff"), ) self.assertEqual( - enc_frame2, frame2._replace(rsv1=True, data=b'RPS\x00\x00\x00\x00\xff\xff') + enc_frame2, frame2._replace(rsv1=True, data=b"RPS\x00\x00\x00\x00\xff\xff") ) self.assertEqual( - enc_frame3, frame3._replace(rsv1=True, data=b'J.\xca\xcf,.N\xcc+)\x06\x00') + enc_frame3, frame3._replace(rsv1=True, data=b"J.\xca\xcf,.N\xcc+)\x06\x00") ) dec_frame1 = self.extension.decode(enc_frame1) @@ -662,17 +662,17 @@ def test_encode_decode_fragmented_text_frame(self): self.assertEqual(dec_frame3, frame3) def test_encode_decode_fragmented_binary_frame(self): - frame1 = Frame(False, OP_TEXT, b'tea ') - frame2 = Frame(True, OP_CONT, b'time') + frame1 = Frame(False, OP_TEXT, b"tea ") + frame2 = Frame(True, OP_CONT, b"time") enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) self.assertEqual( - enc_frame1, frame1._replace(rsv1=True, data=b'*IMT\x00\x00\x00\x00\xff\xff') + enc_frame1, frame1._replace(rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff") ) self.assertEqual( - enc_frame2, frame2._replace(rsv1=True, data=b'*\xc9\xccM\x05\x00') + enc_frame2, frame2._replace(rsv1=True, data=b"*\xc9\xccM\x05\x00") ) dec_frame1 = self.extension.decode(enc_frame1) @@ -682,21 +682,21 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_no_decode_text_frame(self): - frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_binary_frame(self): - frame = Frame(True, OP_TEXT, b'tea') + frame = Frame(True, OP_TEXT, b"tea") # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_fragmented_text_frame(self): - frame1 = Frame(False, OP_TEXT, 'café'.encode('utf-8')) - frame2 = Frame(False, OP_CONT, ' & '.encode('utf-8')) - frame3 = Frame(True, OP_CONT, 'croissants'.encode('utf-8')) + frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) + frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) + frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -707,8 +707,8 @@ def test_no_decode_fragmented_text_frame(self): self.assertEqual(dec_frame3, frame3) def test_no_decode_fragmented_binary_frame(self): - frame1 = Frame(False, OP_TEXT, b'tea ') - frame2 = Frame(True, OP_CONT, b'time') + frame1 = Frame(False, OP_TEXT, b"tea ") + frame2 = Frame(True, OP_CONT, b"time") dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -717,25 +717,25 @@ def test_no_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_context_takeover(self): - frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) - self.assertEqual(enc_frame1.data, b'JNL;\xbc\x12\x00') - self.assertEqual(enc_frame2.data, b'J\x06\x11\x00\x00') + self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") + self.assertEqual(enc_frame2.data, b"J\x06\x11\x00\x00") def test_remote_no_context_takeover(self): # No context takeover when decoding messages. self.extension = PerMessageDeflate(True, False, 15, 15) - frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) - self.assertEqual(enc_frame1.data, b'JNL;\xbc\x12\x00') - self.assertEqual(enc_frame2.data, b'J\x06\x11\x00\x00') + self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") + self.assertEqual(enc_frame2.data, b"J\x06\x11\x00\x00") dec_frame1 = self.extension.decode(enc_frame1) self.assertEqual(dec_frame1, frame) @@ -748,13 +748,13 @@ def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. self.extension = PerMessageDeflate(True, True, 15, 15) - frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) - self.assertEqual(enc_frame1.data, b'JNL;\xbc\x12\x00') - self.assertEqual(enc_frame2.data, b'JNL;\xbc\x12\x00') + self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") + self.assertEqual(enc_frame2.data, b"JNL;\xbc\x12\x00") dec_frame1 = self.extension.decode(enc_frame1) dec_frame2 = self.extension.decode(enc_frame2) @@ -766,27 +766,27 @@ def test_local_no_context_takeover(self): def test_compress_settings(self): # Configure an extension so that no compression actually occurs. - extension = PerMessageDeflate(False, False, 15, 15, {'level': 0}) + extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) - frame = Frame(True, OP_TEXT, 'café'.encode('utf-8')) + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) enc_frame = extension.encode(frame) self.assertEqual( enc_frame, frame._replace( - rsv1=True, data=b'\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00' # not compressed + rsv1=True, data=b"\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00" # not compressed ), ) # Frames aren't decoded beyond max_length. def test_decompress_max_size(self): - frame = Frame(True, OP_TEXT, ('a' * 20).encode('utf-8')) + frame = Frame(True, OP_TEXT, ("a" * 20).encode("utf-8")) enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame.data, b'JL\xc4\x04\x00\x00') + self.assertEqual(enc_frame.data, b"JL\xc4\x04\x00\x00") with self.assertRaises(PayloadTooBig): self.extension.decode(enc_frame, max_size=10) diff --git a/tests/py35/_test_client_server.py b/tests/py35/_test_client_server.py index 46e9111a5..869c379b8 100644 --- a/tests/py35/_test_client_server.py +++ b/tests/py35/_test_client_server.py @@ -23,7 +23,7 @@ def tearDown(self): self.loop.close() def test_client(self): - start_server = serve(handler, 'localhost', 0) + start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) async def run_client(): @@ -41,7 +41,7 @@ async def run_client(): def test_server(self): async def run_server(): # Await serve. - server = await serve(handler, 'localhost', 0) + server = await serve(handler, "localhost", 0) self.assertTrue(server.sockets) server.close() await server.wait_closed() @@ -60,10 +60,10 @@ def tearDown(self): # Asynchronous context managers are only enabled on Python ≥ 3.5.1. @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+' + sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" ) def test_client(self): - start_server = serve(handler, 'localhost', 0) + start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) async def run_client(): @@ -81,12 +81,12 @@ async def run_client(): # Asynchronous context managers are only enabled on Python ≥ 3.5.1. @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+' + sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" ) def test_server(self): async def run_server(): # Use serve as an asynchronous context manager. - async with serve(handler, 'localhost', 0) as server: + async with serve(handler, "localhost", 0) as server: self.assertTrue(server.sockets) # Check that exiting the context manager closed the server. @@ -96,9 +96,9 @@ async def run_server(): # Asynchronous context managers are only enabled on Python ≥ 3.5.1. @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+' + sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" ) - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") def test_unix_server(self): async def run_server(path): async with unix_serve(handler, path) as server: @@ -108,5 +108,5 @@ async def run_server(path): self.assertFalse(server.sockets) with tempfile.TemporaryDirectory() as temp_dir: - path = bytes(pathlib.Path(temp_dir) / 'websockets') + path = bytes(pathlib.Path(temp_dir) / "websockets") self.loop.run_until_complete(run_server(path)) diff --git a/tests/py36/_test_client_server.py b/tests/py36/_test_client_server.py index f38fbe6f6..10b135cc9 100644 --- a/tests/py36/_test_client_server.py +++ b/tests/py36/_test_client_server.py @@ -17,7 +17,7 @@ raise ImportError("Python 3.6+ only") -MESSAGES = ['3', '2', '1', 'Fire!'] +MESSAGES = ["3", "2", "1", "Fire!"] class AsyncIteratorTests(unittest.TestCase): @@ -37,7 +37,7 @@ async def handler(ws, path): for message in MESSAGES: await ws.send(message) - start_server = serve(handler, 'localhost', 0) + start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) messages = [] @@ -61,7 +61,7 @@ async def handler(ws, path): await ws.send(message) await ws.close(1001) - start_server = serve(handler, 'localhost', 0) + start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) messages = [] @@ -85,7 +85,7 @@ async def handler(ws, path): await ws.send(message) await ws.close(1011) - start_server = serve(handler, 'localhost', 0) + start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) messages = [] diff --git a/tests/test_client_server.py b/tests/test_client_server.py index eade7e066..9ba2725d9 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -55,25 +55,25 @@ # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt -testcert = bytes(pathlib.Path(__file__).with_name('test_localhost.pem')) +testcert = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) @asyncio.coroutine def handler(ws, path): - if path == '/attributes': + if path == "/attributes": yield from ws.send(repr((ws.host, ws.port, ws.secure))) - elif path == '/close_timeout': + elif path == "/close_timeout": yield from ws.send(repr(ws.close_timeout)) - elif path == '/path': + elif path == "/path": yield from ws.send(str(ws.path)) - elif path == '/headers': + elif path == "/headers": yield from ws.send(repr(ws.request_headers)) yield from ws.send(repr(ws.response_headers)) - elif path == '/extensions': + elif path == "/extensions": yield from ws.send(repr(ws.extensions)) - elif path == '/subprotocol': + elif path == "/subprotocol": yield from ws.send(repr(ws.subprotocol)) - elif path == '/slow_stop': + elif path == "/slow_stop": yield from ws.wait_closed() yield from asyncio.sleep(2 * MS) else: @@ -142,14 +142,14 @@ def with_client(*args, **kwds): return with_manager(temp_test_client, *args, **kwds) -def get_server_uri(server, secure=False, resource_name='/', user_info=None): +def get_server_uri(server, secure=False, resource_name="/", user_info=None): """ Return a WebSocket URI for connecting to the given server. """ - proto = 'wss' if secure else 'ws' + proto = "wss" if secure else "ws" - user_info = ':'.join(user_info) + '@' if user_info else '' + user_info = ":".join(user_info) + "@" if user_info else "" # Pick a random socket in order to test both IPv4 and IPv6 on systems # where both are available. Randomizing tests is usually a bad idea. If @@ -158,38 +158,38 @@ def get_server_uri(server, secure=False, resource_name='/', user_info=None): if server_socket.family == socket.AF_INET6: # pragma: no cover host, port = server_socket.getsockname()[:2] # (no IPv6 on CI) - host = '[{}]'.format(host) + host = "[{}]".format(host) elif server_socket.family == socket.AF_INET: host, port = server_socket.getsockname() elif server_socket.family == socket.AF_UNIX: # The host and port are ignored when connecting to a Unix socket. - host, port = 'localhost', 0 + host, port = "localhost", 0 else: # pragma: no cover raise ValueError("Expected an IPv6, IPv4, or Unix socket") - return '{}://{}{}:{}{}'.format(proto, user_info, host, port, resource_name) + return "{}://{}{}:{}{}".format(proto, user_info, host, port, resource_name) class UnauthorizedServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a Headers instance (1/3) - return UNAUTHORIZED, Headers([('X-Access', 'denied')]), b'' + return UNAUTHORIZED, Headers([("X-Access", "denied")]), b"" class ForbiddenServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a dict (2/3) - return FORBIDDEN, {'X-Access': 'denied'}, b'' + return FORBIDDEN, {"X-Access": "denied"}, b"" class HealthCheckServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a list of pairs (3/3) - if path == '/__health__/': - return OK, [('X-Access', 'OK')], b'status = green\n' + if path == "/__health__/": + return OK, [("X-Access", "OK")], b"status = green\n" class SlowServerProtocol(WebSocketServerProtocol): @@ -207,7 +207,7 @@ class BarClientProtocol(WebSocketClientProtocol): class ClientNoOpExtensionFactory: - name = 'x-no-op' + name = "x-no-op" def get_request_params(self): return [] @@ -219,7 +219,7 @@ def process_response_params(self, params, accepted_extensions): class ServerNoOpExtensionFactory: - name = 'x-no-op' + name = "x-no-op" def __init__(self, params=None): self.params = params or [] @@ -229,10 +229,10 @@ def process_request_params(self, params, accepted_extensions): class NoOpExtension: - name = 'x-no-op' + name = "x-no-op" def __repr__(self): - return 'NoOpExtension()' + return "NoOpExtension()" def decode(self, frame, *, max_size=None): return frame @@ -266,10 +266,10 @@ def server_context(self): def start_server(self, **kwds): # Disable compression by default in tests. - kwds.setdefault('compression', None) + kwds.setdefault("compression", None) # Disable pings by default in tests. - kwds.setdefault('ping_interval', None) - start_server = serve(handler, 'localhost', 0, **kwds) + kwds.setdefault("ping_interval", None) + start_server = serve(handler, "localhost", 0, **kwds) self.server = self.loop.run_until_complete(start_server) def start_redirecting_server( @@ -278,13 +278,13 @@ def start_redirecting_server( def _process_request(path, headers): server_uri = get_server_uri(self.server, self.secure, path) if force_insecure: - server_uri = server_uri.replace('wss:', 'ws:') - headers = {'Location': server_uri} if include_location else [] + server_uri = server_uri.replace("wss:", "ws:") + headers = {"Location": server_uri} if include_location else [] return status, headers, b"" start_server = serve( handler, - 'localhost', + "localhost", 0, compression=None, ping_interval=None, @@ -293,12 +293,12 @@ def _process_request(path, headers): ) self.redirecting_server = self.loop.run_until_complete(start_server) - def start_client(self, resource_name='/', user_info=None, **kwds): + def start_client(self, resource_name="/", user_info=None, **kwds): # Disable compression by default in tests. - kwds.setdefault('compression', None) + kwds.setdefault("compression", None) # Disable pings by default in tests. - kwds.setdefault('ping_interval', None) - secure = kwds.get('ssl') is not None + kwds.setdefault("ping_interval", None) + secure = kwds.get("ssl") is not None server = self.redirecting_server if self.redirecting_server else self.server server_uri = get_server_uri(server, secure, resource_name, user_info) start_client = connect(server_uri, **kwds) @@ -370,14 +370,14 @@ def test_infinite_redirect(self): self.server = self.redirecting_server with self.assertRaises(InvalidHandshake): with temp_test_client(self): - self.fail('Did not raise') # pragma: no cover + self.fail("Did not raise") # pragma: no cover @with_server() def test_redirect_missing_location(self): with temp_test_redirecting_server(self, FOUND, include_location=False): with self.assertRaises(InvalidMessage): with temp_test_client(self): - self.fail('Did not raise') # pragma: no cover + self.fail("Did not raise") # pragma: no cover def test_explicit_event_loop(self): with self.temp_server(loop=self.loop): @@ -388,7 +388,7 @@ def test_explicit_event_loop(self): # The way the legacy SSL implementation wraps sockets makes it extremely # hard to write a test for Python 3.4. - @unittest.skipIf(sys.version_info[:2] <= (3, 4), 'this test requires Python 3.5+') + @unittest.skipIf(sys.version_info[:2] <= (3, 4), "this test requires Python 3.5+") @with_server() def test_explicit_socket(self): class TrackedSocket(socket.socket): @@ -418,7 +418,7 @@ def send(self, *args, **kwargs): with self.temp_client( sock=client_socket, # "You must set server_hostname when using ssl without a host" - server_hostname='localhost' if self.secure else None, + server_hostname="localhost" if self.secure else None, ): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) @@ -430,10 +430,10 @@ def send(self, *args, **kwargs): finally: client_socket.close() - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") def test_unix_socket(self): with tempfile.TemporaryDirectory() as temp_dir: - path = bytes(pathlib.Path(temp_dir) / 'websockets') + path = bytes(pathlib.Path(temp_dir) / "websockets") # Like self.start_server() but with unix_serve(). unix_server = unix_serve(handler, path) @@ -452,24 +452,24 @@ def test_unix_socket(self): client_socket.close() self.stop_server() - @with_server(process_request=lambda p, rh: (OK, [], b'OK\n')) + @with_server(process_request=lambda p, rh: (OK, [], b"OK\n")) def test_process_request_argument(self): - response = self.loop.run_until_complete(self.make_http_request('/')) + response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): self.assertEqual(response.code, 200) @with_server( - subprotocols=['superchat', 'chat'], select_subprotocol=lambda cs, ss: 'chat' + subprotocols=["superchat", "chat"], select_subprotocol=lambda cs, ss: "chat" ) - @with_client('/subprotocol', subprotocols=['superchat', 'chat']) + @with_client("/subprotocol", subprotocols=["superchat", "chat"]) def test_select_subprotocol_argument(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_subprotocol, repr('chat')) - self.assertEqual(self.client.subprotocol, 'chat') + self.assertEqual(server_subprotocol, repr("chat")) + self.assertEqual(self.client.subprotocol, "chat") @with_server() - @with_client('/attributes') + @with_client("/attributes") def test_protocol_attributes(self): # The test could be connecting with IPv6 or IPv4. expected_client_attrs = [ @@ -479,120 +479,120 @@ def test_protocol_attributes(self): client_attrs = (self.client.host, self.client.port, self.client.secure) self.assertIn(client_attrs, expected_client_attrs) - expected_server_attrs = ('localhost', 0, self.secure) + expected_server_attrs = ("localhost", 0, self.secure) server_attrs = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_attrs, repr(expected_server_attrs)) @with_server() - @with_client('/path') + @with_client("/path") def test_protocol_path(self): client_path = self.client.path - self.assertEqual(client_path, '/path') + self.assertEqual(client_path, "/path") server_path = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_path, '/path') + self.assertEqual(server_path, "/path") @with_server() - @with_client('/headers', user_info=('user', 'pass')) + @with_client("/headers", user_info=("user", "pass")) def test_protocol_basic_auth(self): self.assertEqual( - self.client.request_headers['Authorization'], 'Basic dXNlcjpwYXNz' + self.client.request_headers["Authorization"], "Basic dXNlcjpwYXNz" ) @with_server() - @with_client('/headers') + @with_client("/headers") def test_protocol_headers(self): client_req = self.client.request_headers client_resp = self.client.response_headers - self.assertEqual(client_req['User-Agent'], USER_AGENT) - self.assertEqual(client_resp['Server'], USER_AGENT) + self.assertEqual(client_req["User-Agent"], USER_AGENT) + self.assertEqual(client_resp["Server"], USER_AGENT) server_req = self.loop.run_until_complete(self.client.recv()) server_resp = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_req, repr(client_req)) self.assertEqual(server_resp, repr(client_resp)) @with_server() - @with_client('/headers', extra_headers=Headers({'X-Spam': 'Eggs'})) + @with_client("/headers", extra_headers=Headers({"X-Spam": "Eggs"})) def test_protocol_custom_request_headers(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client('/headers', extra_headers={'X-Spam': 'Eggs'}) + @with_client("/headers", extra_headers={"X-Spam": "Eggs"}) def test_protocol_custom_request_headers_dict(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client('/headers', extra_headers=[('X-Spam', 'Eggs')]) + @with_client("/headers", extra_headers=[("X-Spam", "Eggs")]) def test_protocol_custom_request_headers_list(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client('/headers', extra_headers=[('User-Agent', 'Eggs')]) + @with_client("/headers", extra_headers=[("User-Agent", "Eggs")]) def test_protocol_custom_request_user_agent(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) self.assertIn("('User-Agent', 'Eggs')", req_headers) - @with_server(extra_headers=lambda p, r: Headers({'X-Spam': 'Eggs'})) - @with_client('/headers') + @with_server(extra_headers=lambda p, r: Headers({"X-Spam": "Eggs"})) + @with_client("/headers") def test_protocol_custom_response_headers_callable(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'}) - @with_client('/headers') + @with_server(extra_headers=lambda p, r: {"X-Spam": "Eggs"}) + @with_client("/headers") def test_protocol_custom_response_headers_callable_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')]) - @with_client('/headers') + @with_server(extra_headers=lambda p, r: [("X-Spam", "Eggs")]) + @with_client("/headers") def test_protocol_custom_response_headers_callable_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=Headers({'X-Spam': 'Eggs'})) - @with_client('/headers') + @with_server(extra_headers=Headers({"X-Spam": "Eggs"})) + @with_client("/headers") def test_protocol_custom_response_headers(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers={'X-Spam': 'Eggs'}) - @with_client('/headers') + @with_server(extra_headers={"X-Spam": "Eggs"}) + @with_client("/headers") def test_protocol_custom_response_headers_dict(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=[('X-Spam', 'Eggs')]) - @with_client('/headers') + @with_server(extra_headers=[("X-Spam", "Eggs")]) + @with_client("/headers") def test_protocol_custom_response_headers_list(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=[('Server', 'Eggs')]) - @with_client('/headers') + @with_server(extra_headers=[("Server", "Eggs")]) + @with_client("/headers") def test_protocol_custom_response_user_agent(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertEqual(resp_headers.count("Server"), 1) self.assertIn("('Server', 'Eggs')", resp_headers) - def make_http_request(self, path='/'): + def make_http_request(self, path="/"): # Set url to 'https?://:'. url = get_server_uri(self.server, resource_name=path, secure=self.secure) - url = url.replace('ws', 'http') + url = url.replace("ws", "http") if self.secure: open_health_check = functools.partial( @@ -606,11 +606,11 @@ def make_http_request(self, path='/'): @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): # Making a HTTP request to a HTTP endpoint succeeds. - response = self.loop.run_until_complete(self.make_http_request('/__health__/')) + response = self.loop.run_until_complete(self.make_http_request("/__health__/")) with contextlib.closing(response): self.assertEqual(response.code, 200) - self.assertEqual(response.read(), b'status = green\n') + self.assertEqual(response.read(), b"status = green\n") @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_ws_endpoint(self): @@ -619,13 +619,13 @@ def test_http_request_ws_endpoint(self): self.loop.run_until_complete(self.make_http_request()) self.assertEqual(raised.exception.code, 426) - self.assertEqual(raised.exception.headers['Upgrade'], 'websocket') + self.assertEqual(raised.exception.headers["Upgrade"], "websocket") @with_server(create_protocol=HealthCheckServerProtocol) def test_ws_connection_http_endpoint(self): # Making a WS connection to a HTTP endpoint fails. with self.assertRaises(InvalidStatusCode) as raised: - self.start_client('/__health__/') + self.start_client("/__health__/") self.assertEqual(raised.exception.status_code, 200) @@ -665,93 +665,93 @@ def test_server_create_protocol_over_klass(self): self.assert_client_raises_code(403) @with_server() - @with_client('/path', create_protocol=FooClientProtocol) + @with_client("/path", create_protocol=FooClientProtocol) def test_client_create_protocol(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() @with_client( - '/path', + "/path", create_protocol=(lambda *args, **kwargs: FooClientProtocol(*args, **kwargs)), ) def test_client_create_protocol_function(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client('/path', klass=FooClientProtocol) + @with_client("/path", klass=FooClientProtocol) def test_client_klass(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client('/path', create_protocol=BarClientProtocol, klass=FooClientProtocol) + @with_client("/path", create_protocol=BarClientProtocol, klass=FooClientProtocol) def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) @with_server(close_timeout=7) - @with_client('/close_timeout') + @with_client("/close_timeout") def test_server_close_timeout(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 7) @with_server(timeout=6) - @with_client('/close_timeout') + @with_client("/close_timeout") def test_server_timeout_backwards_compatibility(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 6) @with_server(close_timeout=7, timeout=6) - @with_client('/close_timeout') + @with_client("/close_timeout") def test_server_close_timeout_over_timeout(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 7) @with_server() - @with_client('/close_timeout', close_timeout=7) + @with_client("/close_timeout", close_timeout=7) def test_client_close_timeout(self): self.assertEqual(self.client.close_timeout, 7) @with_server() - @with_client('/close_timeout', timeout=6) + @with_client("/close_timeout", timeout=6) def test_client_timeout_backwards_compatibility(self): self.assertEqual(self.client.close_timeout, 6) @with_server() - @with_client('/close_timeout', close_timeout=7, timeout=6) + @with_client("/close_timeout", close_timeout=7, timeout=6) def test_client_close_timeout_over_timeout(self): self.assertEqual(self.client.close_timeout, 7) @with_server() - @with_client('/extensions') + @with_client("/extensions") def test_no_extension(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) + @with_client("/extensions", extensions=[ClientNoOpExtensionFactory()]) def test_extension(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([NoOpExtension()])) self.assertEqual(repr(self.client.extensions), repr([NoOpExtension()])) @with_server() - @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) + @with_client("/extensions", extensions=[ClientNoOpExtensionFactory()]) def test_extension_not_accepted(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @with_client('/extensions') + @with_client("/extensions") def test_extension_not_requested(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) self.assertEqual(repr(self.client.extensions), repr([])) - @with_server(extensions=[ServerNoOpExtensionFactory([('foo', None)])]) + @with_server(extensions=[ServerNoOpExtensionFactory([("foo", None)])]) def test_extension_client_rejection(self): with self.assertRaises(NegotiationError): - self.start_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) + self.start_client("/extensions", extensions=[ClientNoOpExtensionFactory()]) @with_server( extensions=[ @@ -760,7 +760,7 @@ def test_extension_client_rejection(self): ServerPerMessageDeflateFactory(), ] ) - @with_client('/extensions', extensions=[ClientPerMessageDeflateFactory()]) + @with_client("/extensions", extensions=[ClientPerMessageDeflateFactory()]) def test_extension_no_match_then_match(self): # The order requested by the client has priority. server_extensions = self.loop.run_until_complete(self.client.recv()) @@ -773,7 +773,7 @@ def test_extension_no_match_then_match(self): ) @with_server(extensions=[ServerPerMessageDeflateFactory()]) - @with_client('/extensions', extensions=[ClientNoOpExtensionFactory()]) + @with_client("/extensions", extensions=[ClientNoOpExtensionFactory()]) def test_extension_mismatch(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_extensions, repr([])) @@ -783,7 +783,7 @@ def test_extension_mismatch(self): extensions=[ServerNoOpExtensionFactory(), ServerPerMessageDeflateFactory()] ) @with_client( - '/extensions', + "/extensions", extensions=[ClientPerMessageDeflateFactory(), ClientNoOpExtensionFactory()], ) def test_extension_order(self): @@ -799,25 +799,25 @@ def test_extension_order(self): ) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions') + @unittest.mock.patch.object(WebSocketServerProtocol, "process_extensions") def test_extensions_error(self, _process_extensions): - _process_extensions.return_value = 'x-no-op', [NoOpExtension()] + _process_extensions.return_value = "x-no-op", [NoOpExtension()] with self.assertRaises(NegotiationError): self.start_client( - '/extensions', extensions=[ClientPerMessageDeflateFactory()] + "/extensions", extensions=[ClientPerMessageDeflateFactory()] ) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @unittest.mock.patch.object(WebSocketServerProtocol, 'process_extensions') + @unittest.mock.patch.object(WebSocketServerProtocol, "process_extensions") def test_extensions_error_no_extensions(self, _process_extensions): - _process_extensions.return_value = 'x-no-op', [NoOpExtension()] + _process_extensions.return_value = "x-no-op", [NoOpExtension()] with self.assertRaises(InvalidHandshake): - self.start_client('/extensions') + self.start_client("/extensions") - @with_server(compression='deflate') - @with_client('/extensions', compression='deflate') + @with_server(compression="deflate") + @with_client("/extensions", compression="deflate") def test_compression_deflate(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual( @@ -834,16 +834,16 @@ def test_compression_deflate(self): client_no_context_takeover=True, server_max_window_bits=10 ) ], - compression='deflate', # overridden by explicit config + compression="deflate", # overridden by explicit config ) @with_client( - '/extensions', + "/extensions", extensions=[ ClientPerMessageDeflateFactory( server_no_context_takeover=True, client_max_window_bits=12 ) ], - compression='deflate', # overridden by explicit config + compression="deflate", # overridden by explicit config ) def test_compression_deflate_and_explicit_config(self): server_extensions = self.loop.run_until_complete(self.client.recv()) @@ -856,77 +856,77 @@ def test_compression_deflate_and_explicit_config(self): def test_compression_unsupported_server(self): with self.assertRaises(ValueError): - self.loop.run_until_complete(self.start_server(compression='xz')) + self.loop.run_until_complete(self.start_server(compression="xz")) @with_server() def test_compression_unsupported_client(self): with self.assertRaises(ValueError): - self.loop.run_until_complete(self.start_client(compression='xz')) + self.loop.run_until_complete(self.start_client(compression="xz")) @with_server() - @with_client('/subprotocol') + @with_client("/subprotocol") def test_no_subprotocol(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) - @with_server(subprotocols=['superchat', 'chat']) - @with_client('/subprotocol', subprotocols=['otherchat', 'chat']) + @with_server(subprotocols=["superchat", "chat"]) + @with_client("/subprotocol", subprotocols=["otherchat", "chat"]) def test_subprotocol(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(server_subprotocol, repr('chat')) - self.assertEqual(self.client.subprotocol, 'chat') + self.assertEqual(server_subprotocol, repr("chat")) + self.assertEqual(self.client.subprotocol, "chat") - @with_server(subprotocols=['superchat']) - @with_client('/subprotocol', subprotocols=['otherchat']) + @with_server(subprotocols=["superchat"]) + @with_client("/subprotocol", subprotocols=["otherchat"]) def test_subprotocol_not_accepted(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) @with_server() - @with_client('/subprotocol', subprotocols=['otherchat', 'chat']) + @with_client("/subprotocol", subprotocols=["otherchat", "chat"]) def test_subprotocol_not_offered(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) - @with_server(subprotocols=['superchat', 'chat']) - @with_client('/subprotocol') + @with_server(subprotocols=["superchat", "chat"]) + @with_client("/subprotocol") def test_subprotocol_not_requested(self): server_subprotocol = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_subprotocol, repr(None)) self.assertEqual(self.client.subprotocol, None) - @with_server(subprotocols=['superchat']) - @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') + @with_server(subprotocols=["superchat"]) + @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error(self, _process_subprotocol): - _process_subprotocol.return_value = 'superchat' + _process_subprotocol.return_value = "superchat" with self.assertRaises(NegotiationError): - self.start_client('/subprotocol', subprotocols=['otherchat']) + self.start_client("/subprotocol", subprotocols=["otherchat"]) self.run_loop_once() - @with_server(subprotocols=['superchat']) - @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') + @with_server(subprotocols=["superchat"]) + @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): - _process_subprotocol.return_value = 'superchat' + _process_subprotocol.return_value = "superchat" with self.assertRaises(InvalidHandshake): - self.start_client('/subprotocol') + self.start_client("/subprotocol") self.run_loop_once() - @with_server(subprotocols=['superchat', 'chat']) - @unittest.mock.patch.object(WebSocketServerProtocol, 'process_subprotocol') + @with_server(subprotocols=["superchat", "chat"]) + @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): - _process_subprotocol.return_value = 'superchat, chat' + _process_subprotocol.return_value = "superchat, chat" with self.assertRaises(InvalidHandshake): - self.start_client('/subprotocol', subprotocols=['superchat', 'chat']) + self.start_client("/subprotocol", subprotocols=["superchat", "chat"]) self.run_loop_once() @with_server() - @unittest.mock.patch('websockets.server.read_request') + @unittest.mock.patch("websockets.server.read_request") def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") @@ -934,7 +934,7 @@ def test_server_receives_malformed_request(self, _read_request): self.start_client() @with_server() - @unittest.mock.patch('websockets.client.read_response') + @unittest.mock.patch("websockets.client.read_response") def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") @@ -943,10 +943,10 @@ def test_client_receives_malformed_response(self, _read_response): self.run_loop_once() @with_server() - @unittest.mock.patch('websockets.client.build_request') + @unittest.mock.patch("websockets.client.build_request") def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(headers): - return '42' + return "42" _build_request.side_effect = wrong_build_request @@ -954,10 +954,10 @@ def wrong_build_request(headers): self.start_client() @with_server() - @unittest.mock.patch('websockets.server.build_response') + @unittest.mock.patch("websockets.server.build_response") def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(headers, key): - return build_response(headers, '42') + return build_response(headers, "42") _build_response.side_effect = wrong_build_response @@ -965,12 +965,12 @@ def wrong_build_response(headers, key): self.start_client() @with_server() - @unittest.mock.patch('websockets.client.read_response') + @unittest.mock.patch("websockets.client.read_response") def test_server_does_not_switch_protocols(self, _read_response): @asyncio.coroutine def wrong_read_response(stream): status_code, reason, headers = yield from read_response(stream) - return 400, 'Bad Request', headers + return 400, "Bad Request", headers _read_response.side_effect = wrong_read_response @@ -979,7 +979,7 @@ def wrong_read_response(stream): self.run_loop_once() @with_server() - @unittest.mock.patch('websockets.server.WebSocketServerProtocol.process_request') + @unittest.mock.patch("websockets.server.WebSocketServerProtocol.process_request") def test_server_error_in_handshake(self, _process_request): _process_request.side_effect = Exception("process_request crashed") @@ -987,7 +987,7 @@ def test_server_error_in_handshake(self, _process_request): self.start_client() @with_server() - @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send') + @unittest.mock.patch("websockets.server.WebSocketServerProtocol.send") def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") @@ -1000,7 +1000,7 @@ def test_server_handler_crashes(self, send): self.assertEqual(self.client.close_code, 1011) @with_server() - @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') + @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") @@ -1014,7 +1014,7 @@ def test_server_close_crashes(self, close): @with_server() @with_client() - @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake') + @unittest.mock.patch.object(WebSocketClientProtocol, "handshake") def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. @@ -1042,7 +1042,7 @@ def test_server_shuts_down_during_connection_handling(self): self.assertEqual(self.client.close_code, 1001) @with_server() - @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') + @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") def test_server_shuts_down_during_connection_close(self, _close): _close.side_effect = asyncio.CancelledError @@ -1059,7 +1059,7 @@ def test_server_shuts_down_during_connection_close(self, _close): def test_server_shuts_down_waits_until_handlers_terminate(self): # This handler waits a bit after the connection is closed in order # to test that wait_closed() really waits for handlers to complete. - self.start_client('/slow_stop') + self.start_client("/slow_stop") server_ws = next(iter(self.server.websockets)) # Test that the handler task keeps running after close(). @@ -1081,9 +1081,9 @@ def test_invalid_status_error_during_client_connect(self): @with_server() @unittest.mock.patch( - 'websockets.server.WebSocketServerProtocol.write_http_response' + "websockets.server.WebSocketServerProtocol.write_http_response" ) - @unittest.mock.patch('websockets.server.WebSocketServerProtocol.read_http_request') + @unittest.mock.patch("websockets.server.WebSocketServerProtocol.read_http_request") def test_connection_error_during_opening_handshake( self, _read_http_request, _write_http_response ): @@ -1101,7 +1101,7 @@ def test_connection_error_during_opening_handshake( _write_http_response.assert_not_called() @with_server() - @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close') + @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError @@ -1139,11 +1139,11 @@ def client_context(self): return ssl_context def start_server(self, **kwds): - kwds.setdefault('ssl', self.server_context) + kwds.setdefault("ssl", self.server_context) super().start_server(**kwds) - def start_client(self, path='/', **kwds): - kwds.setdefault('ssl', self.client_context) + def start_client(self, path="/", **kwds): + kwds.setdefault("ssl", self.client_context) super().start_client(path, **kwds) # TLS over Unix sockets doesn't make sense. @@ -1165,7 +1165,7 @@ def test_redirect_insecure(self): with temp_test_redirecting_server(self, FOUND, force_insecure=True): with self.assertRaises(InvalidHandshake): with temp_test_client(self): - self.fail('Did not raise') # pragma: no cover + self.fail("Did not raise") # pragma: no cover class ClientServerOriginTests(unittest.TestCase): @@ -1178,10 +1178,10 @@ def tearDown(self): def test_checking_origin_succeeds(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=['http://localhost']) + serve(handler, "localhost", 0, origins=["http://localhost"]) ) client = self.loop.run_until_complete( - connect(get_server_uri(server), origin='http://localhost') + connect(get_server_uri(server), origin="http://localhost") ) self.loop.run_until_complete(client.send("Hello!")) @@ -1193,11 +1193,11 @@ def test_checking_origin_succeeds(self): def test_checking_origin_fails(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=['http://localhost']) + serve(handler, "localhost", 0, origins=["http://localhost"]) ) with self.assertRaisesRegex(InvalidHandshake, "Status code not 101: 403"): self.loop.run_until_complete( - connect(get_server_uri(server), origin='http://otherhost') + connect(get_server_uri(server), origin="http://otherhost") ) server.close() @@ -1205,14 +1205,14 @@ def test_checking_origin_fails(self): def test_checking_origins_fails_with_multiple_headers(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=['http://localhost']) + serve(handler, "localhost", 0, origins=["http://localhost"]) ) with self.assertRaisesRegex(InvalidHandshake, "Status code not 101: 400"): self.loop.run_until_complete( connect( get_server_uri(server), - origin='http://localhost', - extra_headers=[('Origin', 'http://otherhost')], + origin="http://localhost", + extra_headers=[("Origin", "http://otherhost")], ) ) @@ -1221,7 +1221,7 @@ def test_checking_origins_fails_with_multiple_headers(self): def test_checking_lack_of_origin_succeeds(self): server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=[None]) + serve(handler, "localhost", 0, origins=[None]) ) client = self.loop.run_until_complete(connect(get_server_uri(server))) @@ -1235,7 +1235,7 @@ def test_checking_lack_of_origin_succeeds(self): def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: server = self.loop.run_until_complete( - serve(handler, 'localhost', 0, origins=['']) + serve(handler, "localhost", 0, origins=[""]) ) client = self.loop.run_until_complete(connect(get_server_uri(server))) @@ -1261,7 +1261,7 @@ def tearDown(self): self.loop.close() def test_client(self): - start_server = serve(handler, 'localhost', 0) + start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) @asyncio.coroutine @@ -1281,7 +1281,7 @@ def test_server(self): @asyncio.coroutine def run_server(): # Yield from serve. - server = yield from serve(handler, 'localhost', 0) + server = yield from serve(handler, "localhost", 0) self.assertTrue(server.sockets) server.close() yield from server.wait_closed() diff --git a/tests/test_framing.py b/tests/test_framing.py index ab11f6bdc..83d0a251a 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -59,63 +59,63 @@ def round_trip_close(self, data, code, reason): self.assertEqual(serialized, data) def test_text(self): - self.round_trip(b'\x81\x04Spam', Frame(True, OP_TEXT, b'Spam')) + self.round_trip(b"\x81\x04Spam", Frame(True, OP_TEXT, b"Spam")) def test_text_masked(self): self.round_trip( - b'\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5', - Frame(True, OP_TEXT, b'Spam'), + b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", + Frame(True, OP_TEXT, b"Spam"), mask=True, ) def test_binary(self): - self.round_trip(b'\x82\x04Eggs', Frame(True, OP_BINARY, b'Eggs')) + self.round_trip(b"\x82\x04Eggs", Frame(True, OP_BINARY, b"Eggs")) def test_binary_masked(self): self.round_trip( - b'\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa', - Frame(True, OP_BINARY, b'Eggs'), + b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", + Frame(True, OP_BINARY, b"Eggs"), mask=True, ) def test_non_ascii_text(self): self.round_trip( - b'\x81\x05caf\xc3\xa9', Frame(True, OP_TEXT, 'café'.encode('utf-8')) + b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode("utf-8")) ) def test_non_ascii_text_masked(self): self.round_trip( - b'\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd', - Frame(True, OP_TEXT, 'café'.encode('utf-8')), + b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", + Frame(True, OP_TEXT, "café".encode("utf-8")), mask=True, ) def test_close(self): - self.round_trip(b'\x88\x00', Frame(True, OP_CLOSE, b'')) + self.round_trip(b"\x88\x00", Frame(True, OP_CLOSE, b"")) def test_ping(self): - self.round_trip(b'\x89\x04ping', Frame(True, OP_PING, b'ping')) + self.round_trip(b"\x89\x04ping", Frame(True, OP_PING, b"ping")) def test_pong(self): - self.round_trip(b'\x8a\x04pong', Frame(True, OP_PONG, b'pong')) + self.round_trip(b"\x8a\x04pong", Frame(True, OP_PONG, b"pong")) def test_long(self): self.round_trip( - b'\x82\x7e\x00\x7e' + 126 * b'a', Frame(True, OP_BINARY, 126 * b'a') + b"\x82\x7e\x00\x7e" + 126 * b"a", Frame(True, OP_BINARY, 126 * b"a") ) def test_very_long(self): self.round_trip( - b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + 65536 * b'a', - Frame(True, OP_BINARY, 65536 * b'a'), + b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", + Frame(True, OP_BINARY, 65536 * b"a"), ) def test_payload_too_big(self): with self.assertRaises(PayloadTooBig): - self.decode(b'\x82\x7e\x04\x01' + 1025 * b'a', max_size=1024) + self.decode(b"\x82\x7e\x04\x01" + 1025 * b"a", max_size=1024) def test_bad_reserved_bits(self): - for encoded in [b'\xc0\x00', b'\xa0\x00', b'\x90\x00']: + for encoded in [b"\xc0\x00", b"\xa0\x00", b"\x90\x00"]: with self.subTest(encoded=encoded): with self.assertRaises(WebSocketProtocolError): self.decode(encoded) @@ -135,41 +135,41 @@ def test_bad_opcode(self): def test_mask_flag(self): # Mask flag correctly set. - self.decode(b'\x80\x80\x00\x00\x00\x00', mask=True) + self.decode(b"\x80\x80\x00\x00\x00\x00", mask=True) # Mask flag incorrectly unset. with self.assertRaises(WebSocketProtocolError): - self.decode(b'\x80\x80\x00\x00\x00\x00') + self.decode(b"\x80\x80\x00\x00\x00\x00") # Mask flag correctly unset. - self.decode(b'\x80\x00') + self.decode(b"\x80\x00") # Mask flag incorrectly set. with self.assertRaises(WebSocketProtocolError): - self.decode(b'\x80\x00', mask=True) + self.decode(b"\x80\x00", mask=True) def test_control_frame_max_length(self): # At maximum allowed length. - self.decode(b'\x88\x7e\x00\x7d' + 125 * b'a') + self.decode(b"\x88\x7e\x00\x7d" + 125 * b"a") # Above maximum allowed length. with self.assertRaises(WebSocketProtocolError): - self.decode(b'\x88\x7e\x00\x7e' + 126 * b'a') + self.decode(b"\x88\x7e\x00\x7e" + 126 * b"a") def test_prepare_data_str(self): - self.assertEqual(prepare_data('café'), (OP_TEXT, b'caf\xc3\xa9')) + self.assertEqual(prepare_data("café"), (OP_TEXT, b"caf\xc3\xa9")) def test_prepare_data_bytes(self): - self.assertEqual(prepare_data(b'tea'), (OP_BINARY, b'tea')) + self.assertEqual(prepare_data(b"tea"), (OP_BINARY, b"tea")) def test_prepare_data_bytearray(self): self.assertEqual( - prepare_data(bytearray(b'tea')), (OP_BINARY, bytearray(b'tea')) + prepare_data(bytearray(b"tea")), (OP_BINARY, bytearray(b"tea")) ) def test_prepare_data_memoryview(self): self.assertEqual( - prepare_data(memoryview(b'tea')), (OP_BINARY, memoryview(b'tea')) + prepare_data(memoryview(b"tea")), (OP_BINARY, memoryview(b"tea")) ) def test_prepare_data_non_contiguous_memoryview(self): - self.assertEqual(prepare_data(memoryview(b'tteeaa')[::2]), (OP_BINARY, b'tea')) + self.assertEqual(prepare_data(memoryview(b"tteeaa")[::2]), (OP_BINARY, b"tea")) def test_prepare_data_list(self): with self.assertRaises(TypeError): @@ -180,19 +180,19 @@ def test_prepare_data_none(self): prepare_data(None) def test_encode_data_str(self): - self.assertEqual(encode_data('café'), b'caf\xc3\xa9') + self.assertEqual(encode_data("café"), b"caf\xc3\xa9") def test_encode_data_bytes(self): - self.assertEqual(encode_data(b'tea'), b'tea') + self.assertEqual(encode_data(b"tea"), b"tea") def test_encode_data_bytearray(self): - self.assertEqual(encode_data(bytearray(b'tea')), b'tea') + self.assertEqual(encode_data(bytearray(b"tea")), b"tea") def test_encode_data_memoryview(self): - self.assertEqual(encode_data(memoryview(b'tea')), b'tea') + self.assertEqual(encode_data(memoryview(b"tea")), b"tea") def test_encode_data_non_contiguous_memoryview(self): - self.assertEqual(encode_data(memoryview(b'tteeaa')[::2]), b'tea') + self.assertEqual(encode_data(memoryview(b"tteeaa")[::2]), b"tea") def test_encode_data_list(self): with self.assertRaises(TypeError): @@ -204,29 +204,29 @@ def test_encode_data_none(self): def test_fragmented_control_frame(self): # Fin bit correctly set. - self.decode(b'\x88\x00') + self.decode(b"\x88\x00") # Fin bit incorrectly unset. with self.assertRaises(WebSocketProtocolError): - self.decode(b'\x08\x00') + self.decode(b"\x08\x00") def test_parse_close_and_serialize_close(self): - self.round_trip_close(b'\x03\xe8', 1000, '') - self.round_trip_close(b'\x03\xe8OK', 1000, 'OK') + self.round_trip_close(b"\x03\xe8", 1000, "") + self.round_trip_close(b"\x03\xe8OK", 1000, "OK") def test_parse_close_empty(self): - self.assertEqual(parse_close(b''), (1005, '')) + self.assertEqual(parse_close(b""), (1005, "")) def test_parse_close_errors(self): with self.assertRaises(WebSocketProtocolError): - parse_close(b'\x03') + parse_close(b"\x03") with self.assertRaises(WebSocketProtocolError): - parse_close(b'\x03\xe7') + parse_close(b"\x03\xe7") with self.assertRaises(UnicodeDecodeError): - parse_close(b'\x03\xe8\xff\xff') + parse_close(b"\x03\xe8\xff\xff") def test_serialize_close_errors(self): with self.assertRaises(WebSocketProtocolError): - serialize_close(999, '') + serialize_close(999, "") def test_extensions(self): class Rot13: @@ -234,7 +234,7 @@ class Rot13: def encode(frame): assert frame.opcode == OP_TEXT text = frame.data.decode() - data = codecs.encode(text, 'rot13').encode() + data = codecs.encode(text, "rot13").encode() return frame._replace(data=data) # This extensions is symmetrical. @@ -243,5 +243,5 @@ def decode(frame, *, max_size=None): return Rot13.encode(frame) self.round_trip( - b'\x81\x05uryyb', Frame(True, OP_TEXT, b'hello'), extensions=[Rot13()] + b"\x81\x05uryyb", Frame(True, OP_TEXT, b"hello"), extensions=[Rot13()] ) diff --git a/tests/test_handshake.py b/tests/test_handshake.py index a0cb55a9e..7d0477715 100644 --- a/tests/test_handshake.py +++ b/tests/test_handshake.py @@ -58,70 +58,70 @@ def assertInvalidRequestHeaders(self, exc_type): def test_request_invalid_connection(self): with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers['Connection'] - headers['Connection'] = 'Downgrade' + del headers["Connection"] + headers["Connection"] = "Downgrade" def test_request_missing_connection(self): with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers['Connection'] + del headers["Connection"] def test_request_additional_connection(self): with self.assertValidRequestHeaders() as headers: - headers['Connection'] = 'close' + headers["Connection"] = "close" def test_request_invalid_upgrade(self): with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers['Upgrade'] - headers['Upgrade'] = 'socketweb' + del headers["Upgrade"] + headers["Upgrade"] = "socketweb" def test_request_missing_upgrade(self): with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers['Upgrade'] + del headers["Upgrade"] def test_request_additional_upgrade(self): with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - headers['Upgrade'] = 'socketweb' + headers["Upgrade"] = "socketweb" def test_request_invalid_key_not_base64(self): with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: - del headers['Sec-WebSocket-Key'] - headers['Sec-WebSocket-Key'] = "!@#$%^&*()" + del headers["Sec-WebSocket-Key"] + headers["Sec-WebSocket-Key"] = "!@#$%^&*()" def test_request_invalid_key_not_well_padded(self): with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: - del headers['Sec-WebSocket-Key'] - headers['Sec-WebSocket-Key'] = "CSIRmL8dWYxeAdr/XpEHRw" + del headers["Sec-WebSocket-Key"] + headers["Sec-WebSocket-Key"] = "CSIRmL8dWYxeAdr/XpEHRw" def test_request_invalid_key_not_16_bytes_long(self): with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: - del headers['Sec-WebSocket-Key'] - headers['Sec-WebSocket-Key'] = "ZLpprpvK4PE=" + del headers["Sec-WebSocket-Key"] + headers["Sec-WebSocket-Key"] = "ZLpprpvK4PE=" def test_request_missing_key(self): with self.assertInvalidRequestHeaders(InvalidHeader) as headers: - del headers['Sec-WebSocket-Key'] + del headers["Sec-WebSocket-Key"] def test_request_additional_key(self): with self.assertInvalidRequestHeaders(InvalidHeader) as headers: # This duplicates the Sec-WebSocket-Key header. - headers['Sec-WebSocket-Key'] = headers['Sec-WebSocket-Key'] + headers["Sec-WebSocket-Key"] = headers["Sec-WebSocket-Key"] def test_request_invalid_version(self): with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: - del headers['Sec-WebSocket-Version'] - headers['Sec-WebSocket-Version'] = '42' + del headers["Sec-WebSocket-Version"] + headers["Sec-WebSocket-Version"] = "42" def test_request_missing_version(self): with self.assertInvalidRequestHeaders(InvalidHeader) as headers: - del headers['Sec-WebSocket-Version'] + del headers["Sec-WebSocket-Version"] def test_request_additional_version(self): with self.assertInvalidRequestHeaders(InvalidHeader) as headers: # This duplicates the Sec-WebSocket-Version header. - headers['Sec-WebSocket-Version'] = headers['Sec-WebSocket-Version'] + headers["Sec-WebSocket-Version"] = headers["Sec-WebSocket-Version"] @contextlib.contextmanager - def assertValidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): + def assertValidResponseHeaders(self, key="CSIRmL8dWYxeAdr/XpEHRw=="): """ Provide response headers for modification. @@ -134,7 +134,7 @@ def assertValidResponseHeaders(self, key='CSIRmL8dWYxeAdr/XpEHRw=='): check_response(headers, key) @contextlib.contextmanager - def assertInvalidResponseHeaders(self, exc_type, key='CSIRmL8dWYxeAdr/XpEHRw=='): + def assertInvalidResponseHeaders(self, exc_type, key="CSIRmL8dWYxeAdr/XpEHRw=="): """ Provide response headers for modification. @@ -150,41 +150,41 @@ def assertInvalidResponseHeaders(self, exc_type, key='CSIRmL8dWYxeAdr/XpEHRw==') def test_response_invalid_connection(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers['Connection'] - headers['Connection'] = 'Downgrade' + del headers["Connection"] + headers["Connection"] = "Downgrade" def test_response_missing_connection(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers['Connection'] + del headers["Connection"] def test_response_additional_connection(self): with self.assertValidResponseHeaders() as headers: - headers['Connection'] = 'close' + headers["Connection"] = "close" def test_response_invalid_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers['Upgrade'] - headers['Upgrade'] = 'socketweb' + del headers["Upgrade"] + headers["Upgrade"] = "socketweb" def test_response_missing_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers['Upgrade'] + del headers["Upgrade"] def test_response_additional_upgrade(self): with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - headers['Upgrade'] = 'socketweb' + headers["Upgrade"] = "socketweb" def test_response_invalid_accept(self): with self.assertInvalidResponseHeaders(InvalidHeaderValue) as headers: - del headers['Sec-WebSocket-Accept'] + del headers["Sec-WebSocket-Accept"] other_key = "1Eq4UDEFQYg3YspNgqxv5g==" - headers['Sec-WebSocket-Accept'] = accept(other_key) + headers["Sec-WebSocket-Accept"] = accept(other_key) def test_response_missing_accept(self): with self.assertInvalidResponseHeaders(InvalidHeader) as headers: - del headers['Sec-WebSocket-Accept'] + del headers["Sec-WebSocket-Accept"] def test_response_additional_accept(self): with self.assertInvalidResponseHeaders(InvalidHeader) as headers: # This duplicates the Sec-WebSocket-Accept header. - headers['Sec-WebSocket-Accept'] = headers['Sec-WebSocket-Accept'] + headers["Sec-WebSocket-Accept"] = headers["Sec-WebSocket-Accept"] diff --git a/tests/test_headers.py b/tests/test_headers.py index 7d52b9f74..f03dc83cf 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -9,16 +9,16 @@ class HeadersTests(unittest.TestCase): def test_parse_connection(self): for header, parsed in [ # Realistic use cases - ('Upgrade', ['Upgrade']), # Safari, Chrome - ('keep-alive, Upgrade', ['keep-alive', 'Upgrade']), # Firefox + ("Upgrade", ["Upgrade"]), # Safari, Chrome + ("keep-alive, Upgrade", ["keep-alive", "Upgrade"]), # Firefox # Pathological example - (',,\t, , ,Upgrade ,,', ['Upgrade']), + (",,\t, , ,Upgrade ,,", ["Upgrade"]), ]: with self.subTest(header=header): self.assertEqual(parse_connection(header), parsed) def test_parse_connection_invalid_header(self): - for header in ['???', 'keep-alive; Upgrade']: + for header in ["???", "keep-alive; Upgrade"]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_connection(header) @@ -26,17 +26,17 @@ def test_parse_connection_invalid_header(self): def test_parse_upgrade(self): for header, parsed in [ # Realistic use case - ('websocket', ['websocket']), + ("websocket", ["websocket"]), # Synthetic example - ('http/3.0, websocket', ['http/3.0', 'websocket']), + ("http/3.0, websocket", ["http/3.0", "websocket"]), # Pathological example - (',, WebSocket, \t,,', ['WebSocket']), + (",, WebSocket, \t,,", ["WebSocket"]), ]: with self.subTest(header=header): self.assertEqual(parse_upgrade(header), parsed) def test_parse_upgrade_invalid_header(self): - for header in ['???', 'websocket 2', 'http/3.0; websocket']: + for header in ["???", "websocket 2", "http/3.0; websocket"]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): parse_upgrade(header) @@ -44,37 +44,37 @@ def test_parse_upgrade_invalid_header(self): def test_parse_extension_list(self): for header, parsed in [ # Synthetic examples - ('foo', [('foo', [])]), - ('foo, bar', [('foo', []), ('bar', [])]), + ("foo", [("foo", [])]), + ("foo, bar", [("foo", []), ("bar", [])]), ( 'foo; name; token=token; quoted-string="quoted-string", ' - 'bar; quux; quuux', + "bar; quux; quuux", [ ( - 'foo', + "foo", [ - ('name', None), - ('token', 'token'), - ('quoted-string', 'quoted-string'), + ("name", None), + ("token", "token"), + ("quoted-string", "quoted-string"), ], ), - ('bar', [('quux', None), ('quuux', None)]), + ("bar", [("quux", None), ("quuux", None)]), ], ), # Pathological example ( - ',\t, , ,foo ;bar = 42,, baz,,', - [('foo', [('bar', '42')]), ('baz', [])], + ",\t, , ,foo ;bar = 42,, baz,,", + [("foo", [("bar", "42")]), ("baz", [])], ), # Realistic use cases for permessage-deflate - ('permessage-deflate', [('permessage-deflate', [])]), + ("permessage-deflate", [("permessage-deflate", [])]), ( - 'permessage-deflate; client_max_window_bits', - [('permessage-deflate', [('client_max_window_bits', None)])], + "permessage-deflate; client_max_window_bits", + [("permessage-deflate", [("client_max_window_bits", None)])], ), ( - 'permessage-deflate; server_max_window_bits=10', - [('permessage-deflate', [('server_max_window_bits', '10')])], + "permessage-deflate; server_max_window_bits=10", + [("permessage-deflate", [("server_max_window_bits", "10")])], ), ]: with self.subTest(header=header): @@ -86,14 +86,14 @@ def test_parse_extension_list(self): def test_parse_extension_list_invalid_header(self): for header in [ # Truncated examples - '', - ',\t,', - 'foo;', - 'foo; bar;', - 'foo; bar=', + "", + ",\t,", + "foo;", + "foo; bar;", + "foo; bar=", 'foo; bar="baz', # Wrong delimiter - 'foo, bar, baz=quux; quuux', + "foo, bar, baz=quux; quuux", # Value in quoted string parameter that isn't a token 'foo; bar=" "', ]: @@ -104,10 +104,10 @@ def test_parse_extension_list_invalid_header(self): def test_parse_subprotocol_list(self): for header, parsed in [ # Synthetic examples - ('foo', ['foo']), - ('foo, bar', ['foo', 'bar']), + ("foo", ["foo"]), + ("foo, bar", ["foo", "bar"]), # Pathological example - (',\t, , ,foo ,, bar,baz,,', ['foo', 'bar', 'baz']), + (",\t, , ,foo ,, bar,baz,,", ["foo", "bar", "baz"]), ]: with self.subTest(header=header): self.assertEqual(parse_subprotocol_list(header), parsed) @@ -118,10 +118,10 @@ def test_parse_subprotocol_list(self): def test_parse_subprotocol_list_invalid_header(self): for header in [ # Truncated examples - '', - ',\t,' + "", + ",\t," # Wrong delimiter - 'foo; bar', + "foo; bar", ]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): @@ -131,5 +131,5 @@ def test_build_basic_auth(self): # Test vector from RFC 7617. self.assertEqual( build_basic_auth("Aladdin", "open sesame"), - 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==', + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", ) diff --git a/tests/test_http.py b/tests/test_http.py index c222b370f..b28bed6ce 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -19,92 +19,92 @@ def tearDown(self): def test_read_request(self): # Example from the protocol overview in RFC 6455 self.stream.feed_data( - b'GET /chat HTTP/1.1\r\n' - b'Host: server.example.com\r\n' - b'Upgrade: websocket\r\n' - b'Connection: Upgrade\r\n' - b'Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n' - b'Origin: http://example.com\r\n' - b'Sec-WebSocket-Protocol: chat, superchat\r\n' - b'Sec-WebSocket-Version: 13\r\n' - b'\r\n' + b"GET /chat HTTP/1.1\r\n" + b"Host: server.example.com\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + b"Origin: http://example.com\r\n" + b"Sec-WebSocket-Protocol: chat, superchat\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"\r\n" ) path, headers = self.loop.run_until_complete(read_request(self.stream)) - self.assertEqual(path, '/chat') - self.assertEqual(headers['Upgrade'], 'websocket') + self.assertEqual(path, "/chat") + self.assertEqual(headers["Upgrade"], "websocket") def test_read_response(self): # Example from the protocol overview in RFC 6455 self.stream.feed_data( - b'HTTP/1.1 101 Switching Protocols\r\n' - b'Upgrade: websocket\r\n' - b'Connection: Upgrade\r\n' - b'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n' - b'Sec-WebSocket-Protocol: chat\r\n' - b'\r\n' + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + b"Sec-WebSocket-Protocol: chat\r\n" + b"\r\n" ) status_code, reason, headers = self.loop.run_until_complete( read_response(self.stream) ) self.assertEqual(status_code, 101) - self.assertEqual(reason, 'Switching Protocols') - self.assertEqual(headers['Upgrade'], 'websocket') + self.assertEqual(reason, "Switching Protocols") + self.assertEqual(headers["Upgrade"], "websocket") def test_request_method(self): - self.stream.feed_data(b'OPTIONS * HTTP/1.1\r\n\r\n') + self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_request(self.stream)) def test_request_version(self): - self.stream.feed_data(b'GET /chat HTTP/1.0\r\n\r\n') + self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_request(self.stream)) def test_response_version(self): - self.stream.feed_data(b'HTTP/1.0 400 Bad Request\r\n\r\n') + self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_response(self.stream)) def test_response_status(self): - self.stream.feed_data(b'HTTP/1.1 007 My name is Bond\r\n\r\n') + self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_response(self.stream)) def test_response_reason(self): - self.stream.feed_data(b'HTTP/1.1 200 \x7f\r\n\r\n') + self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_response(self.stream)) def test_header_name(self): - self.stream.feed_data(b'foo bar: baz qux\r\n\r\n') + self.stream.feed_data(b"foo bar: baz qux\r\n\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) def test_header_value(self): - self.stream.feed_data(b'foo: \x00\x00\x0f\r\n\r\n') + self.stream.feed_data(b"foo: \x00\x00\x0f\r\n\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) def test_headers_limit(self): - self.stream.feed_data(b'foo: bar\r\n' * 257 + b'\r\n') + self.stream.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) def test_line_limit(self): # Header line contains 5 + 4090 + 2 = 4097 bytes. - self.stream.feed_data(b'foo: ' + b'a' * 4090 + b'\r\n\r\n') + self.stream.feed_data(b"foo: " + b"a" * 4090 + b"\r\n\r\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) def test_line_ending(self): - self.stream.feed_data(b'foo: bar\n\n') + self.stream.feed_data(b"foo: bar\n\n") with self.assertRaises(ValueError): self.loop.run_until_complete(read_headers(self.stream)) class HeadersTests(unittest.TestCase): def setUp(self): - self.headers = Headers([('Connection', 'Upgrade'), ('Server', USER_AGENT)]) + self.headers = Headers([("Connection", "Upgrade"), ("Server", USER_AGENT)]) def test_str(self): self.assertEqual( @@ -120,67 +120,67 @@ def test_repr(self): ) def test_multiple_values_error_str(self): - self.assertEqual(str(MultipleValuesError('Connection')), "'Connection'") + self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'") self.assertEqual(str(MultipleValuesError()), "") def test_contains(self): - self.assertIn('Server', self.headers) + self.assertIn("Server", self.headers) def test_contains_case_insensitive(self): - self.assertIn('server', self.headers) + self.assertIn("server", self.headers) def test_contains_not_found(self): - self.assertNotIn('Date', self.headers) + self.assertNotIn("Date", self.headers) def test_iter(self): - self.assertEqual(set(iter(self.headers)), {'connection', 'server'}) + self.assertEqual(set(iter(self.headers)), {"connection", "server"}) def test_len(self): self.assertEqual(len(self.headers), 2) def test_getitem(self): - self.assertEqual(self.headers['Server'], USER_AGENT) + self.assertEqual(self.headers["Server"], USER_AGENT) def test_getitem_case_insensitive(self): - self.assertEqual(self.headers['server'], USER_AGENT) + self.assertEqual(self.headers["server"], USER_AGENT) def test_getitem_key_error(self): with self.assertRaises(KeyError): - self.headers['Upgrade'] + self.headers["Upgrade"] def test_getitem_multiple_values_error(self): - self.headers['Server'] = '2' + self.headers["Server"] = "2" with self.assertRaises(MultipleValuesError): - self.headers['Server'] + self.headers["Server"] def test_setitem(self): - self.headers['Upgrade'] = 'websocket' - self.assertEqual(self.headers['Upgrade'], 'websocket') + self.headers["Upgrade"] = "websocket" + self.assertEqual(self.headers["Upgrade"], "websocket") def test_setitem_case_insensitive(self): - self.headers['upgrade'] = 'websocket' - self.assertEqual(self.headers['Upgrade'], 'websocket') + self.headers["upgrade"] = "websocket" + self.assertEqual(self.headers["Upgrade"], "websocket") def test_setitem_multiple_values(self): - self.headers['Connection'] = 'close' + self.headers["Connection"] = "close" with self.assertRaises(MultipleValuesError): - self.headers['Connection'] + self.headers["Connection"] def test_delitem(self): - del self.headers['Connection'] + del self.headers["Connection"] with self.assertRaises(KeyError): - self.headers['Connection'] + self.headers["Connection"] def test_delitem_case_insensitive(self): - del self.headers['connection'] + del self.headers["connection"] with self.assertRaises(KeyError): - self.headers['Connection'] + self.headers["Connection"] def test_delitem_multiple_values(self): - self.headers['Connection'] = 'close' - del self.headers['Connection'] + self.headers["Connection"] = "close" + del self.headers["Connection"] with self.assertRaises(KeyError): - self.headers['Connection'] + self.headers["Connection"] def test_eq(self): other_headers = self.headers.copy() @@ -195,20 +195,20 @@ def test_clear(self): self.assertEqual(self.headers, Headers()) def test_get_all(self): - self.assertEqual(self.headers.get_all('Connection'), ['Upgrade']) + self.assertEqual(self.headers.get_all("Connection"), ["Upgrade"]) def test_get_all_case_insensitive(self): - self.assertEqual(self.headers.get_all('connection'), ['Upgrade']) + self.assertEqual(self.headers.get_all("connection"), ["Upgrade"]) def test_get_all_no_values(self): - self.assertEqual(self.headers.get_all('Upgrade'), []) + self.assertEqual(self.headers.get_all("Upgrade"), []) def test_get_all_multiple_values(self): - self.headers['Connection'] = 'close' - self.assertEqual(self.headers.get_all('Connection'), ['Upgrade', 'close']) + self.headers["Connection"] = "close" + self.assertEqual(self.headers.get_all("Connection"), ["Upgrade", "close"]) def test_raw_items(self): self.assertEqual( list(self.headers.raw_items()), - [('Connection', 'Upgrade'), ('Server', USER_AGENT)], + [("Connection", "Upgrade"), ("Server", USER_AGENT)], ) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a5eb251c9..cb562e647 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -19,14 +19,14 @@ # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get('WEBSOCKETS_TESTS_TIMEOUT_FACTOR', 1)) +MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) # asyncio's debug mode has a 10x performance penalty for this test suite. -if os.environ.get('PYTHONASYNCIODEBUG'): # pragma: no cover +if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover MS *= 10 # Ensure that timeouts are larger than the clock's resolution (for Windows). -MS = max(MS, 2.5 * time.get_clock_info('monotonic').resolution) +MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) class TransportMock(unittest.mock.Mock): @@ -126,9 +126,9 @@ def delayed_drain(): self.protocol.writer.drain = delayed_drain - close_frame = Frame(True, OP_CLOSE, serialize_close(1000, 'close')) - local_close = Frame(True, OP_CLOSE, serialize_close(1000, 'local')) - remote_close = Frame(True, OP_CLOSE, serialize_close(1000, 'remote')) + close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) + local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) + remote_close = Frame(True, OP_CLOSE, serialize_close(1000, "remote")) @property def ensure_future(self): @@ -166,7 +166,7 @@ def receive_eof_if_client(self): if self.protocol.is_client: self.receive_eof() - def close_connection(self, code=1000, reason='close'): + def close_connection(self, code=1000, reason="close"): """ Execute a closing handshake. @@ -184,7 +184,7 @@ def close_connection(self, code=1000, reason='close'): assert self.protocol.state is State.CLOSED - def half_close_connection_local(self, code=1000, reason='close'): + def half_close_connection_local(self, code=1000, reason="close"): """ Start a closing handshake but do not complete it. @@ -215,7 +215,7 @@ def half_close_connection_local(self, code=1000, reason='close'): # This task must be awaited or canceled by the caller. return close_task - def half_close_connection_remote(self, code=1000, reason='close'): + def half_close_connection_remote(self, code=1000, reason="close"): """ Receive a closing handshake but do not complete it. @@ -310,7 +310,7 @@ def assertConnectionFailed(self, code, message): self.assertEqual(self.protocol.state, State.CLOSED) # No close frame was received. self.assertEqual(self.protocol.close_code, 1006) - self.assertEqual(self.protocol.close_reason, '') + self.assertEqual(self.protocol.close_reason, "") # A close frame was sent -- unless the connection was already lost. if code == 1006: self.assertNoFrameSent() @@ -329,11 +329,11 @@ def assertCompletesWithin(self, min_time, max_time): # Test public attributes. def test_local_address(self): - get_extra_info = unittest.mock.Mock(return_value=('host', 4312)) + get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) self.transport.get_extra_info = get_extra_info - self.assertEqual(self.protocol.local_address, ('host', 4312)) - get_extra_info.assert_called_with('sockname', None) + self.assertEqual(self.protocol.local_address, ("host", 4312)) + get_extra_info.assert_called_with("sockname", None) def test_local_address_before_connection(self): # Emulate the situation before connection_open() runs. @@ -345,11 +345,11 @@ def test_local_address_before_connection(self): self.protocol.writer = _writer def test_remote_address(self): - get_extra_info = unittest.mock.Mock(return_value=('host', 4312)) + get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) self.transport.get_extra_info = get_extra_info - self.assertEqual(self.protocol.remote_address, ('host', 4312)) - get_extra_info.assert_called_with('peername', None) + self.assertEqual(self.protocol.remote_address, ("host", 4312)) + get_extra_info.assert_called_with("peername", None) def test_remote_address_before_connection(self): # Emulate the situation before connection_open() runs. @@ -379,14 +379,14 @@ def test_wait_closed(self): # Test the recv coroutine. def test_recv_text(self): - self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, 'café') + self.assertEqual(data, "café") def test_recv_binary(self): - self.receive_frame(Frame(True, OP_BINARY, b'tea')) + self.receive_frame(Frame(True, OP_BINARY, b"tea")) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, b'tea') + self.assertEqual(data, b"tea") def test_recv_on_closing_connection_local(self): close_task = self.half_close_connection_local() @@ -409,38 +409,38 @@ def test_recv_on_closed_connection(self): self.loop.run_until_complete(self.protocol.recv()) def test_recv_protocol_error(self): - self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8"))) self.process_invalid_frames() - self.assertConnectionFailed(1002, '') + self.assertConnectionFailed(1002, "") def test_recv_unicode_error(self): - self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('latin-1'))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode("latin-1"))) self.process_invalid_frames() - self.assertConnectionFailed(1007, '') + self.assertConnectionFailed(1007, "") def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) self.process_invalid_frames() - self.assertConnectionFailed(1009, '') + self.assertConnectionFailed(1009, "") def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) + self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342)) self.process_invalid_frames() - self.assertConnectionFailed(1009, '') + self.assertConnectionFailed(1009, "") def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8') * 205)) + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, 'café' * 205) + self.assertEqual(data, "café" * 205) def test_recv_binary_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(True, OP_BINARY, b'tea' * 342)) + self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342)) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, b'tea' * 342) + self.assertEqual(data, b"tea" * 342) def test_recv_queue_empty(self): recv = self.ensure_future(self.protocol.recv()) @@ -449,32 +449,32 @@ def test_recv_queue_empty(self): asyncio.wait_for(asyncio.shield(recv), timeout=MS) ) - self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) data = self.loop.run_until_complete(recv) - self.assertEqual(data, 'café') + self.assertEqual(data, "café") def test_recv_queue_full(self): self.protocol.max_queue = 2 # Test internals because it's hard to verify buffers from the outside. self.assertEqual(list(self.protocol.messages), []) - self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) self.run_loop_once() - self.assertEqual(list(self.protocol.messages), ['café']) + self.assertEqual(list(self.protocol.messages), ["café"]) - self.receive_frame(Frame(True, OP_BINARY, b'tea')) + self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.run_loop_once() - self.assertEqual(list(self.protocol.messages), ['café', b'tea']) + self.assertEqual(list(self.protocol.messages), ["café", b"tea"]) - self.receive_frame(Frame(True, OP_BINARY, b'milk')) + self.receive_frame(Frame(True, OP_BINARY, b"milk")) self.run_loop_once() - self.assertEqual(list(self.protocol.messages), ['café', b'tea']) + self.assertEqual(list(self.protocol.messages), ["café", b"tea"]) self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(list(self.protocol.messages), [b'tea', b'milk']) + self.assertEqual(list(self.protocol.messages), [b"tea", b"milk"]) self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(list(self.protocol.messages), [b'milk']) + self.assertEqual(list(self.protocol.messages), [b"milk"]) self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(list(self.protocol.messages), []) @@ -486,7 +486,7 @@ def read_message(): self.protocol.read_message = read_message self.process_invalid_frames() - self.assertConnectionFailed(1011, '') + self.assertConnectionFailed(1011, "") def test_recv_canceled(self): recv = self.ensure_future(self.protocol.recv()) @@ -496,26 +496,26 @@ def test_recv_canceled(self): self.loop.run_until_complete(recv) # The next frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, 'café'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, 'café') + self.assertEqual(data, "café") def test_recv_canceled_race_condition(self): recv = self.ensure_future( asyncio.wait_for(self.protocol.recv(), timeout=0.000001) ) self.loop.call_soon( - self.receive_frame, Frame(True, OP_TEXT, 'café'.encode('utf-8')) + self.receive_frame, Frame(True, OP_TEXT, "café".encode("utf-8")) ) with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(recv) # The previous frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, 'tea'.encode('utf-8'))) + self.receive_frame(Frame(True, OP_TEXT, "tea".encode("utf-8"))) data = self.loop.run_until_complete(self.protocol.recv()) # If we're getting "tea" there, it means "café" was swallowed (ha, ha). - self.assertEqual(data, 'café') + self.assertEqual(data, "café") def test_recv_prevents_concurrent_calls(self): recv = self.ensure_future(self.protocol.recv()) @@ -528,24 +528,24 @@ def test_recv_prevents_concurrent_calls(self): # Test the send coroutine. def test_send_text(self): - self.loop.run_until_complete(self.protocol.send('café')) - self.assertOneFrameSent(True, OP_TEXT, 'café'.encode('utf-8')) + self.loop.run_until_complete(self.protocol.send("café")) + self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) def test_send_binary(self): - self.loop.run_until_complete(self.protocol.send(b'tea')) - self.assertOneFrameSent(True, OP_BINARY, b'tea') + self.loop.run_until_complete(self.protocol.send(b"tea")) + self.assertOneFrameSent(True, OP_BINARY, b"tea") def test_send_binary_from_bytearray(self): - self.loop.run_until_complete(self.protocol.send(bytearray(b'tea'))) - self.assertOneFrameSent(True, OP_BINARY, b'tea') + self.loop.run_until_complete(self.protocol.send(bytearray(b"tea"))) + self.assertOneFrameSent(True, OP_BINARY, b"tea") def test_send_binary_from_memoryview(self): - self.loop.run_until_complete(self.protocol.send(memoryview(b'tea'))) - self.assertOneFrameSent(True, OP_BINARY, b'tea') + self.loop.run_until_complete(self.protocol.send(memoryview(b"tea"))) + self.assertOneFrameSent(True, OP_BINARY, b"tea") def test_send_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.send(memoryview(b'tteeaa')[::2])) - self.assertOneFrameSent(True, OP_BINARY, b'tea') + self.loop.run_until_complete(self.protocol.send(memoryview(b"tteeaa")[::2])) + self.assertOneFrameSent(True, OP_BINARY, b"tea") def test_send_type_error(self): with self.assertRaises(TypeError): @@ -553,41 +553,41 @@ def test_send_type_error(self): self.assertNoFrameSent() def test_send_iterable_text(self): - self.loop.run_until_complete(self.protocol.send(['ca', 'fé'])) + self.loop.run_until_complete(self.protocol.send(["ca", "fé"])) self.assertFramesSent( - (False, OP_TEXT, 'ca'.encode('utf-8')), - (False, OP_CONT, 'fé'.encode('utf-8')), - (True, OP_CONT, ''.encode('utf-8')), + (False, OP_TEXT, "ca".encode("utf-8")), + (False, OP_CONT, "fé".encode("utf-8")), + (True, OP_CONT, "".encode("utf-8")), ) def test_send_iterable_binary(self): - self.loop.run_until_complete(self.protocol.send([b'te', b'a'])) + self.loop.run_until_complete(self.protocol.send([b"te", b"a"])) self.assertFramesSent( - (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") ) def test_send_iterable_binary_from_bytearray(self): self.loop.run_until_complete( - self.protocol.send([bytearray(b'te'), bytearray(b'a')]) + self.protocol.send([bytearray(b"te"), bytearray(b"a")]) ) self.assertFramesSent( - (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") ) def test_send_iterable_binary_from_memoryview(self): self.loop.run_until_complete( - self.protocol.send([memoryview(b'te'), memoryview(b'a')]) + self.protocol.send([memoryview(b"te"), memoryview(b"a")]) ) self.assertFramesSent( - (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") ) def test_send_iterable_binary_from_non_contiguous_memoryview(self): self.loop.run_until_complete( - self.protocol.send([memoryview(b'ttee')[::2], memoryview(b'aa')[::2]]) + self.protocol.send([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) ) self.assertFramesSent( - (False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'') + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") ) def test_send_empty_iterable(self): @@ -601,17 +601,17 @@ def test_send_iterable_type_error(self): def test_send_iterable_mixed_type_error(self): with self.assertRaises(TypeError): - self.loop.run_until_complete(self.protocol.send(['café', b'tea'])) + self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) self.assertFramesSent( - (False, OP_TEXT, 'café'.encode('utf-8')), - (True, OP_CLOSE, serialize_close(1011, '')), + (False, OP_TEXT, "café".encode("utf-8")), + (True, OP_CLOSE, serialize_close(1011, "")), ) def test_send_on_closing_connection_local(self): close_task = self.half_close_connection_local() with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.send('foobar')) + self.loop.run_until_complete(self.protocol.send("foobar")) self.assertNoFrameSent() @@ -621,7 +621,7 @@ def test_send_on_closing_connection_remote(self): self.half_close_connection_remote() with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.send('foobar')) + self.loop.run_until_complete(self.protocol.send("foobar")) self.assertNoFrameSent() @@ -629,7 +629,7 @@ def test_send_on_closed_connection(self): self.close_connection() with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.send('foobar')) + self.loop.run_until_complete(self.protocol.send("foobar")) self.assertNoFrameSent() @@ -645,24 +645,24 @@ def test_ping_default(self): self.assertOneFrameSent(True, OP_PING, ping_data) def test_ping_text(self): - self.loop.run_until_complete(self.protocol.ping('café')) - self.assertOneFrameSent(True, OP_PING, 'café'.encode('utf-8')) + self.loop.run_until_complete(self.protocol.ping("café")) + self.assertOneFrameSent(True, OP_PING, "café".encode("utf-8")) def test_ping_binary(self): - self.loop.run_until_complete(self.protocol.ping(b'tea')) - self.assertOneFrameSent(True, OP_PING, b'tea') + self.loop.run_until_complete(self.protocol.ping(b"tea")) + self.assertOneFrameSent(True, OP_PING, b"tea") def test_ping_binary_from_bytearray(self): - self.loop.run_until_complete(self.protocol.ping(bytearray(b'tea'))) - self.assertOneFrameSent(True, OP_PING, b'tea') + self.loop.run_until_complete(self.protocol.ping(bytearray(b"tea"))) + self.assertOneFrameSent(True, OP_PING, b"tea") def test_ping_binary_from_memoryview(self): - self.loop.run_until_complete(self.protocol.ping(memoryview(b'tea'))) - self.assertOneFrameSent(True, OP_PING, b'tea') + self.loop.run_until_complete(self.protocol.ping(memoryview(b"tea"))) + self.assertOneFrameSent(True, OP_PING, b"tea") def test_ping_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.ping(memoryview(b'tteeaa')[::2])) - self.assertOneFrameSent(True, OP_PING, b'tea') + self.loop.run_until_complete(self.protocol.ping(memoryview(b"tteeaa")[::2])) + self.assertOneFrameSent(True, OP_PING, b"tea") def test_ping_type_error(self): with self.assertRaises(TypeError): @@ -699,27 +699,27 @@ def test_ping_on_closed_connection(self): def test_pong_default(self): self.loop.run_until_complete(self.protocol.pong()) - self.assertOneFrameSent(True, OP_PONG, b'') + self.assertOneFrameSent(True, OP_PONG, b"") def test_pong_text(self): - self.loop.run_until_complete(self.protocol.pong('café')) - self.assertOneFrameSent(True, OP_PONG, 'café'.encode('utf-8')) + self.loop.run_until_complete(self.protocol.pong("café")) + self.assertOneFrameSent(True, OP_PONG, "café".encode("utf-8")) def test_pong_binary(self): - self.loop.run_until_complete(self.protocol.pong(b'tea')) - self.assertOneFrameSent(True, OP_PONG, b'tea') + self.loop.run_until_complete(self.protocol.pong(b"tea")) + self.assertOneFrameSent(True, OP_PONG, b"tea") def test_pong_binary_from_bytearray(self): - self.loop.run_until_complete(self.protocol.pong(bytearray(b'tea'))) - self.assertOneFrameSent(True, OP_PONG, b'tea') + self.loop.run_until_complete(self.protocol.pong(bytearray(b"tea"))) + self.assertOneFrameSent(True, OP_PONG, b"tea") def test_pong_binary_from_memoryview(self): - self.loop.run_until_complete(self.protocol.pong(memoryview(b'tea'))) - self.assertOneFrameSent(True, OP_PONG, b'tea') + self.loop.run_until_complete(self.protocol.pong(memoryview(b"tea"))) + self.assertOneFrameSent(True, OP_PONG, b"tea") def test_pong_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.pong(memoryview(b'tteeaa')[::2])) - self.assertOneFrameSent(True, OP_PONG, b'tea') + self.loop.run_until_complete(self.protocol.pong(memoryview(b"tteeaa")[::2])) + self.assertOneFrameSent(True, OP_PONG, b"tea") def test_pong_type_error(self): with self.assertRaises(TypeError): @@ -755,12 +755,12 @@ def test_pong_on_closed_connection(self): # Test the protocol's logic for acknowledging pings with pongs. def test_answer_ping(self): - self.receive_frame(Frame(True, OP_PING, b'test')) + self.receive_frame(Frame(True, OP_PING, b"test")) self.run_loop_once() - self.assertOneFrameSent(True, OP_PONG, b'test') + self.assertOneFrameSent(True, OP_PONG, b"test") def test_ignore_pong(self): - self.receive_frame(Frame(True, OP_PONG, b'test')) + self.receive_frame(Frame(True, OP_PONG, b"test")) self.run_loop_once() self.assertNoFrameSent() @@ -789,7 +789,7 @@ def test_acknowledge_previous_pings(self): for i in range(3) ] # Unsolicited pong doesn't acknowledge pings - self.receive_frame(Frame(True, OP_PONG, b'')) + self.receive_frame(Frame(True, OP_PONG, b"")) self.run_loop_once() self.run_loop_once() self.assertFalse(pings[0][0].done()) @@ -814,84 +814,84 @@ def test_canceled_ping(self): self.assertTrue(ping.cancelled()) def test_duplicate_ping(self): - self.loop.run_until_complete(self.protocol.ping(b'foobar')) - self.assertOneFrameSent(True, OP_PING, b'foobar') + self.loop.run_until_complete(self.protocol.ping(b"foobar")) + self.assertOneFrameSent(True, OP_PING, b"foobar") with self.assertRaises(ValueError): - self.loop.run_until_complete(self.protocol.ping(b'foobar')) + self.loop.run_until_complete(self.protocol.ping(b"foobar")) self.assertNoFrameSent() # Test the protocol's logic for rebuilding fragmented messages. def test_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8'))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, 'café') + self.assertEqual(data, "café") def test_fragmented_binary(self): - self.receive_frame(Frame(False, OP_BINARY, b't')) - self.receive_frame(Frame(False, OP_CONT, b'e')) - self.receive_frame(Frame(True, OP_CONT, b'a')) + self.receive_frame(Frame(False, OP_BINARY, b"t")) + self.receive_frame(Frame(False, OP_CONT, b"e")) + self.receive_frame(Frame(True, OP_CONT, b"a")) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, b'tea') + self.assertEqual(data, b"tea") def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) - self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) + self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) self.process_invalid_frames() - self.assertConnectionFailed(1009, '') + self.assertConnectionFailed(1009, "") def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) - self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) + self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171)) + self.receive_frame(Frame(True, OP_CONT, b"tea" * 171)) self.process_invalid_frames() - self.assertConnectionFailed(1009, '') + self.assertConnectionFailed(1009, "") def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(False, OP_TEXT, 'café'.encode('utf-8') * 100)) - self.receive_frame(Frame(True, OP_CONT, 'café'.encode('utf-8') * 105)) + self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, 'café' * 205) + self.assertEqual(data, "café" * 205) def test_fragmented_binary_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(False, OP_BINARY, b'tea' * 171)) - self.receive_frame(Frame(True, OP_CONT, b'tea' * 171)) + self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171)) + self.receive_frame(Frame(True, OP_CONT, b"tea" * 171)) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, b'tea' * 342) + self.assertEqual(data, b"tea" * 342) def test_control_frame_within_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.receive_frame(Frame(True, OP_PING, b'')) - self.receive_frame(Frame(True, OP_CONT, 'fé'.encode('utf-8'))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(True, OP_PING, b"")) + self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, 'café') - self.assertOneFrameSent(True, OP_PONG, b'') + self.assertEqual(data, "café") + self.assertOneFrameSent(True, OP_PONG, b"") def test_unterminated_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) # Missing the second part of the fragmented frame. - self.receive_frame(Frame(True, OP_BINARY, b'tea')) + self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.process_invalid_frames() - self.assertConnectionFailed(1002, '') + self.assertConnectionFailed(1002, "") def test_close_handshake_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) - self.receive_frame(Frame(True, OP_CLOSE, b'')) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CLOSE, b"")) self.process_invalid_frames() # The RFC may have overlooked this case: it says that control frames # can be interjected in the middle of a fragmented message and that a # close frame must be echoed. Even though there's an unterminated # message, technically, the closing handshake was successful. - self.assertConnectionClosed(1005, '') + self.assertConnectionClosed(1005, "") def test_connection_close_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, 'ca'.encode('utf-8'))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) self.process_invalid_frames() - self.assertConnectionFailed(1006, '') + self.assertConnectionFailed(1006, "") # Test miscellaneous code paths to ensure full coverage. @@ -899,7 +899,7 @@ def test_connection_lost(self): # Test calling connection_lost without going through close_connection. self.protocol.connection_lost(None) - self.assertConnectionFailed(1006, '') + self.assertConnectionFailed(1006, "") def test_ensure_open_before_opening_handshake(self): # Simulate a bug by forcibly reverting the protocol state. @@ -941,7 +941,7 @@ def test_connection_closed_attributes(self): connection_closed_exc = context.exception self.assertEqual(connection_closed_exc.code, 1000) - self.assertEqual(connection_closed_exc.reason, 'close') + self.assertEqual(connection_closed_exc.reason, "close") # Test the protocol logic for sending keepalive pings. @@ -988,7 +988,7 @@ def test_keepalive_ping_not_acknowledged_closes_connection(self): # Connection is closed at 6ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1011, '')) + self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1011, "")) # The keepalive ping task is complete. self.assertEqual(self.protocol.keepalive_ping_task.result(), None) @@ -1061,15 +1061,15 @@ def test_local_close(self): self.loop.call_later(MS, self.receive_eof_if_client) # Run the closing handshake. - self.loop.run_until_complete(self.protocol.close(reason='close')) + self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1000, "close") self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. - self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) + self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1000, "close") self.assertNoFrameSent() def test_remote_close(self): @@ -1082,13 +1082,13 @@ def test_remote_close(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1000, "close") self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. - self.loop.run_until_complete(self.protocol.close(reason='oh noes!')) + self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1000, "close") self.assertNoFrameSent() def test_simultaneous_close(self): @@ -1098,42 +1098,42 @@ def test_simultaneous_close(self): self.loop.call_soon(self.receive_frame, self.remote_close) self.loop.call_soon(self.receive_eof_if_client) - self.loop.run_until_complete(self.protocol.close(reason='local')) + self.loop.run_until_complete(self.protocol.close(reason="local")) - self.assertConnectionClosed(1000, 'remote') + self.assertConnectionClosed(1000, "remote") # The current implementation sends a close frame in response to the # close frame received from the remote end. It skips the close frame # that should be sent as a result of calling close(). self.assertOneFrameSent(*self.remote_close) def test_close_preserves_incoming_frames(self): - self.receive_frame(Frame(True, OP_TEXT, b'hello')) + self.receive_frame(Frame(True, OP_TEXT, b"hello")) self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(MS, self.receive_eof_if_client) - self.loop.run_until_complete(self.protocol.close(reason='close')) + self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1000, "close") self.assertOneFrameSent(*self.close_frame) next_message = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(next_message, 'hello') + self.assertEqual(next_message, "hello") def test_close_protocol_error(self): - invalid_close_frame = Frame(True, OP_CLOSE, b'\x00') + invalid_close_frame = Frame(True, OP_CLOSE, b"\x00") self.receive_frame(invalid_close_frame) self.receive_eof_if_client() self.run_loop_once() - self.loop.run_until_complete(self.protocol.close(reason='close')) + self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionFailed(1002, '') + self.assertConnectionFailed(1002, "") def test_close_connection_lost(self): self.receive_eof() self.run_loop_once() - self.loop.run_until_complete(self.protocol.close(reason='close')) + self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionFailed(1006, '') + self.assertConnectionFailed(1006, "") def test_local_close_during_recv(self): recv = self.ensure_future(self.protocol.recv()) @@ -1141,19 +1141,19 @@ def test_local_close_during_recv(self): self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(MS, self.receive_eof_if_client) - self.loop.run_until_complete(self.protocol.close(reason='close')) + self.loop.run_until_complete(self.protocol.close(reason="close")) with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(recv) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1000, "close") # There is no test_remote_close_during_recv because it would be identical # to test_remote_close. def test_remote_close_during_send(self): self.make_drain_slow() - send = self.ensure_future(self.protocol.send('hello')) + send = self.ensure_future(self.protocol.send("hello")) self.receive_frame(self.close_frame) self.receive_eof() @@ -1161,7 +1161,7 @@ def test_remote_close_during_send(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(send) - self.assertConnectionClosed(1000, 'close') + self.assertConnectionClosed(1000, "close") # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. @@ -1171,7 +1171,7 @@ class ServerTests(CommonTests, unittest.TestCase): def setUp(self): super().setUp() self.protocol.is_client = False - self.protocol.side = 'server' + self.protocol.side = "server" def test_local_close_send_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1179,16 +1179,16 @@ def test_local_close_send_close_frame_timeout(self): # If we can't send a close frame, time out in 10ms. # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1006, '') + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1006, "") def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS # If the client doesn't send a close frame, time out in 10ms. # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1006, '') + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1006, "") def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS @@ -1199,8 +1199,8 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True self.receive_frame(self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'close') + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1000, "close") def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1214,15 +1214,15 @@ def test_local_close_connection_lost_timeout_after_close(self): # HACK: disable close => other end drops connection emulation. self.transport._closing = True self.receive_frame(self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'close') + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1000, "close") class ClientTests(CommonTests, unittest.TestCase): def setUp(self): super().setUp() self.protocol.is_client = True - self.protocol.side = 'client' + self.protocol.side = "client" def test_local_close_send_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1232,8 +1232,8 @@ def test_local_close_send_close_frame_timeout(self): # - 10ms waiting for receiving a half-close # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1006, '') + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1006, "") def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1242,8 +1242,8 @@ def test_local_close_receive_close_frame_timeout(self): # - 10ms waiting for receiving a half-close # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1006, '') + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1006, "") def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS @@ -1256,8 +1256,8 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True self.receive_frame(self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'close') + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1000, "close") def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1274,5 +1274,5 @@ def test_local_close_connection_lost_timeout_after_close(self): # HACK: disable close => other end drops connection emulation. self.transport._closing = True self.receive_frame(self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason='close')) - self.assertConnectionClosed(1000, 'close') + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1000, "close") diff --git a/tests/test_uri.py b/tests/test_uri.py index ad4ec4013..b7b69c3c1 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -5,17 +5,17 @@ VALID_URIS = [ - ('ws://localhost/', (False, 'localhost', 80, '/', None)), - ('wss://localhost/', (True, 'localhost', 443, '/', None)), - ('ws://localhost/path?query', (False, 'localhost', 80, '/path?query', None)), - ('WS://LOCALHOST/PATH?QUERY', (False, 'localhost', 80, '/PATH?QUERY', None)), - ('ws://user:pass@localhost/', (False, 'localhost', 80, '/', ('user', 'pass'))), + ("ws://localhost/", (False, "localhost", 80, "/", None)), + ("wss://localhost/", (True, "localhost", 443, "/", None)), + ("ws://localhost/path?query", (False, "localhost", 80, "/path?query", None)), + ("WS://LOCALHOST/PATH?QUERY", (False, "localhost", 80, "/PATH?QUERY", None)), + ("ws://user:pass@localhost/", (False, "localhost", 80, "/", ("user", "pass"))), ] INVALID_URIS = [ - 'http://localhost/', - 'https://localhost/', - 'ws://localhost/path#fragment', + "http://localhost/", + "https://localhost/", + "ws://localhost/path#fragment", ] diff --git a/tests/test_utils.py b/tests/test_utils.py index 1b913fe7f..e5570f098 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,10 +12,10 @@ def apply_mask(*args, **kwargs): apply_mask_type_combos = list(itertools.product([bytes, bytearray], repeat=2)) apply_mask_test_values = [ - (b'', b'1234', b''), - (b'aBcDe', b'\x00\x00\x00\x00', b'aBcDe'), - (b'abcdABCD', b'1234', b'PPPPpppp'), - (b'abcdABCD' * 10, b'1234', b'PPPPpppp' * 10), + (b"", b"1234", b""), + (b"aBcDe", b"\x00\x00\x00\x00", b"aBcDe"), + (b"abcdABCD", b"1234", b"PPPPpppp"), + (b"abcdABCD" * 10, b"1234", b"PPPPpppp" * 10), ] def test_apply_mask(self): @@ -50,17 +50,17 @@ def test_apply_mask_non_contiguous_memoryview(self): self.assertEqual(result, data_out) def test_apply_mask_check_input_types(self): - for data_in, mask in [(None, None), (b'abcd', None), (None, b'abcd')]: + for data_in, mask in [(None, None), (b"abcd", None), (None, b"abcd")]: with self.subTest(data_in=data_in, mask=mask): with self.assertRaises(TypeError): self.apply_mask(data_in, mask) def test_apply_mask_check_mask_length(self): for data_in, mask in [ - (b'', b''), - (b'abcd', b'123'), - (b'', b'aBcDe'), - (b'12345678', b'12345678'), + (b"", b""), + (b"abcd", b"123"), + (b"", b"aBcDe"), + (b"12345678", b"12345678"), ]: with self.subTest(data_in=data_in, mask=mask): with self.assertRaises(ValueError): diff --git a/tox.ini b/tox.ini index 6cff294e5..e9623ec7d 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ commands = deps = coverage [testenv:black] -commands = black --check --skip-string-normalization src tests +commands = black --check src tests deps = black [testenv:flake8] From c6bf1ee284ca7ac42d7e0c556d6e2fb5ff97bd05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 14:46:41 +0100 Subject: [PATCH 0521/1539] Drop compatibility with Python 3.4. It's EOL in three months. I'm not putting effort into supporting obsolete versions for free :-) --- .appveyor.yml | 6 +++--- .circleci/config.yml | 12 ------------ .travis.yml | 4 ++-- README.rst | 2 +- docs/changelog.rst | 5 +++++ docs/intro.rst | 5 +---- setup.cfg | 2 +- setup.py | 7 +++---- tests/test_client_server.py | 8 +------- tox.ini | 2 +- 10 files changed, 18 insertions(+), 35 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index 461ff5ced..5109200b4 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -5,14 +5,14 @@ branches: skip_branch_with_pr: true environment: -# websockets only works on Python >= 3.4. - CIBW_SKIP: cp27-* cp33-* +# websockets only works on Python >= 3.5. + CIBW_SKIP: cp27-* cp33-* cp34-* CIBW_TEST_COMMAND: python -W default -m unittest WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 install: # Ensure python is Python 3. - - set PATH=C:\Python34;%PATH% + - set PATH=C:\Python37;%PATH% - cmd: python -m pip install --upgrade cibuildwheel # Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). - cmd: touch .cibuildwheel diff --git a/.circleci/config.yml b/.circleci/config.yml index f0ca45b21..5ec5b5103 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -11,15 +11,6 @@ jobs: - run: sudo pip install tox codecov - run: tox -e coverage,black,flake8,isort - run: codecov - py34: - docker: - - image: circleci/python:3.4 - steps: - # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - - checkout - - run: sudo pip install tox - - run: tox -e py34 py35: docker: - image: circleci/python:3.5 @@ -53,9 +44,6 @@ workflows: build: jobs: - main - - py34: - requires: - - main - py35: requires: - main diff --git a/.travis.yml b/.travis.yml index 3d6dd2089..c0f11357e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ env: global: - # websockets only works on Python >= 3.4. - - CIBW_SKIP="cp27-* cp33-*" + # websockets only works on Python >= 3.5. + - CIBW_SKIP="cp27-* cp33-* cp34-*" - CIBW_TEST_COMMAND="python3 -W default -m unittest" - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 diff --git a/README.rst b/README.rst index b57317d19..572647a15 100644 --- a/README.rst +++ b/README.rst @@ -124,7 +124,7 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. * If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which - only works on Python 3. ``websockets`` requires Python ≥ 3.4. + only works on Python 3. ``websockets`` requires Python ≥ 3.5. What else? ---------- diff --git a/docs/changelog.rst b/docs/changelog.rst index a76e1212e..87e2e0ac8 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,11 +8,16 @@ Changelog *In development* +.. warning:: + + **Version 8.0 drops compatibility with Python 3.4.** + .. warning:: **Version 8.0 adds the reason phrase to the return type of the low-level API** :func:`~http.read_response` **.** + Also: * :meth:`~protocol.WebSocketCommonProtocol.send`, diff --git a/docs/intro.rst b/docs/intro.rst index 154e1d8ea..b153d2f5d 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -6,15 +6,12 @@ Getting started Requirements ------------ -``websockets`` requires Python ≥ 3.4. +``websockets`` requires Python ≥ 3.5. You should use the latest version of Python if possible. If you're using an older version, be aware that for each minor version (3.x), only the latest bugfix release (3.x.y) is officially supported. -For the best experience, you should start with Python ≥ 3.6. :mod:`asyncio` -received interesting improvements between Python 3.4 and 3.6. - .. warning:: This documentation is written for Python ≥ 3.6. If you're using an older diff --git a/setup.cfg b/setup.cfg index ad3af102f..88b9b1a33 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py34.py35.py36.py37 +python-tag = py35.py36.py37 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index 3a86887aa..b9e121af7 100644 --- a/setup.py +++ b/setup.py @@ -20,8 +20,8 @@ py_version = sys.version_info[:2] -if py_version < (3, 4): - raise Exception("websockets requires Python >= 3.4.") +if py_version < (3, 5): + raise Exception("websockets requires Python >= 3.5.") packages = ['websockets', 'websockets/extensions'] @@ -56,7 +56,6 @@ 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', @@ -66,6 +65,6 @@ ext_modules=ext_modules, include_package_data=True, zip_safe=True, - python_requires='>=3.4', + python_requires='>=3.5', test_loader='unittest:TestLoader', ) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 9ba2725d9..6a06bdaf9 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -386,9 +386,6 @@ def test_explicit_event_loop(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - # The way the legacy SSL implementation wraps sockets makes it extremely - # hard to write a test for Python 3.4. - @unittest.skipIf(sys.version_info[:2] <= (3, 4), "this test requires Python 3.5+") @with_server() def test_explicit_socket(self): class TrackedSocket(socket.socket): @@ -1132,10 +1129,7 @@ def client_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ssl_context.load_verify_locations(testcert) ssl_context.verify_mode = ssl.CERT_REQUIRED - # ssl.match_hostname can't match IP addresses on Python < 3.5. - # We're using IP addresses to enforce testing of IPv4 and IPv6. - if sys.version_info[:2] >= (3, 5): # pragma: no cover - ssl_context.check_hostname = True + ssl_context.check_hostname = True return ssl_context def start_server(self, **kwds): diff --git a/tox.ini b/tox.ini index e9623ec7d..de0f285d0 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py34,py35,py36,py37,coverage,black,flake8,isort +envlist = py35,py36,py37,coverage,black,flake8,isort [testenv] commands = python -W default -m unittest {posargs} From db25f49496343bb6aacbe31994da83a5470dc67c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 15:03:39 +0100 Subject: [PATCH 0522/1539] Remove asyncio_ensure_future compatibility function. --- src/websockets/__main__.py | 7 +++---- src/websockets/compatibility.py | 7 ------- src/websockets/protocol.py | 21 ++++++--------------- src/websockets/server.py | 5 ++--- tests/test_protocol.py | 22 ++++++++-------------- 5 files changed, 19 insertions(+), 43 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 4c880c24c..b0fdaa6fe 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -6,7 +6,6 @@ import threading import websockets -from websockets.compatibility import asyncio_ensure_future from websockets.exceptions import format_close @@ -101,8 +100,8 @@ def run_client(uri, loop, inputs, stop): try: while True: - incoming = asyncio_ensure_future(websocket.recv()) - outgoing = asyncio_ensure_future(inputs.get()) + incoming = asyncio.ensure_future(websocket.recv()) + outgoing = asyncio.ensure_future(inputs.get()) done, pending = yield from asyncio.wait( [incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED ) @@ -173,7 +172,7 @@ def main(): stop = asyncio.Future(loop=loop) # Schedule the task that will manage the connection. - asyncio_ensure_future(run_client(args.uri, loop, inputs, stop), loop=loop) + asyncio.ensure_future(run_client(args.uri, loop, inputs, stop), loop=loop) # Start the event loop in a background thread. thread = threading.Thread(target=loop.run_forever) diff --git a/src/websockets/compatibility.py b/src/websockets/compatibility.py index 8b7a21a5c..2e9fcef2b 100644 --- a/src/websockets/compatibility.py +++ b/src/websockets/compatibility.py @@ -4,16 +4,9 @@ """ -import asyncio import http -# Replace with BaseEventLoop.create_task when dropping Python < 3.4.2. -try: # pragma: no cover - asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5 -except AttributeError: # pragma: no cover - asyncio_ensure_future = getattr(asyncio, "async") # Python < 3.5 - try: # pragma: no cover # Python ≥ 3.5 SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index d7d7282a1..62845e0a8 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -18,7 +18,6 @@ import sys import warnings -from .compatibility import asyncio_ensure_future from .exceptions import ( ConnectionClosed, InvalidState, @@ -288,17 +287,11 @@ def connection_open(self): self.state = State.OPEN logger.debug("%s - state = OPEN", self.side) # Start the task that receives incoming WebSocket messages. - self.transfer_data_task = asyncio_ensure_future( - self.transfer_data(), loop=self.loop - ) + self.transfer_data_task = self.loop.create_task(self.transfer_data()) # Start the task that sends pings at regular intervals. - self.keepalive_ping_task = asyncio_ensure_future( - self.keepalive_ping(), loop=self.loop - ) + self.keepalive_ping_task = self.loop.create_task(self.keepalive_ping()) # Start the task that eventually closes the TCP connection. - self.close_connection_task = asyncio_ensure_future( - self.close_connection(), loop=self.loop - ) + self.close_connection_task = self.loop.create_task(self.close_connection()) # Public API @@ -519,8 +512,8 @@ def close(self, code=1000, reason=""): :meth:`close` is idempotent: it doesn't do anything once the connection is closed. - It's safe to wrap this coroutine in :func:`~asyncio.ensure_future` - since errors during connection termination aren't particularly useful. + It's safe to wrap this coroutine in :func:`~asyncio.create_task` since + errors during connection termination aren't particularly useful. ``code`` must be an :class:`int` and ``reason`` a :class:`str`. @@ -1142,9 +1135,7 @@ def fail_connection(self, code=1006, reason=""): # Start close_connection_task if the opening handshake didn't succeed. if self.close_connection_task is None: - self.close_connection_task = asyncio_ensure_future( - self.close_connection(), loop=self.loop - ) + self.close_connection_task = self.loop.create_task(self.close_connection()) def abort_keepalive_pings(self): """ diff --git a/src/websockets/server.py b/src/websockets/server.py index e207db2bc..c9e2cc23a 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -17,7 +17,6 @@ SERVICE_UNAVAILABLE, SWITCHING_PROTOCOLS, UPGRADE_REQUIRED, - asyncio_ensure_future, ) from .exceptions import ( AbortHandshake, @@ -95,7 +94,7 @@ def connection_made(self, transport): # create a race condition between the creation of the task, which # schedules its execution, and the moment the handler starts running. self.ws_server.register(self) - self.handler_task = asyncio_ensure_future(self.handler(), loop=self.loop) + self.handler_task = self.loop.create_task(self.handler()) @asyncio.coroutine def handler(self): @@ -605,7 +604,7 @@ def close(self): """ if self.close_task is None: - self.close_task = asyncio_ensure_future(self._close(), loop=self.loop) + self.close_task = self.loop.create_task(self._close()) @asyncio.coroutine def _close(self): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index cb562e647..896c0fe4b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,13 +1,11 @@ import asyncio import contextlib -import functools import logging import os import time import unittest import unittest.mock -from websockets.compatibility import asyncio_ensure_future from websockets.exceptions import ConnectionClosed, InvalidState from websockets.framing import * from websockets.protocol import State, WebSocketCommonProtocol @@ -130,10 +128,6 @@ def delayed_drain(): local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) remote_close = Frame(True, OP_CLOSE, serialize_close(1000, "remote")) - @property - def ensure_future(self): - return functools.partial(asyncio_ensure_future, loop=self.loop) - def receive_frame(self, frame): """ Make the protocol receive a frame. @@ -197,7 +191,7 @@ def half_close_connection_local(self, code=1000, reason="close"): """ close_frame_data = serialize_close(code, reason) # Trigger the closing handshake from the local endpoint. - close_task = self.ensure_future(self.protocol.close(code, reason)) + close_task = self.loop.create_task(self.protocol.close(code, reason)) self.run_loop_once() # wait_for executes self.run_loop_once() # write_frame executes # Empty the outgoing data stream so we can make assertions later on. @@ -371,7 +365,7 @@ def test_closed(self): self.assertTrue(self.protocol.closed) def test_wait_closed(self): - wait_closed = asyncio_ensure_future(self.protocol.wait_closed()) + wait_closed = self.loop.create_task(self.protocol.wait_closed()) self.assertFalse(wait_closed.done()) self.close_connection() self.assertTrue(wait_closed.done()) @@ -443,7 +437,7 @@ def test_recv_binary_no_max_size(self): self.assertEqual(data, b"tea" * 342) def test_recv_queue_empty(self): - recv = self.ensure_future(self.protocol.recv()) + recv = self.loop.create_task(self.protocol.recv()) with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete( asyncio.wait_for(asyncio.shield(recv), timeout=MS) @@ -489,7 +483,7 @@ def read_message(): self.assertConnectionFailed(1011, "") def test_recv_canceled(self): - recv = self.ensure_future(self.protocol.recv()) + recv = self.loop.create_task(self.protocol.recv()) self.loop.call_soon(recv.cancel) with self.assertRaises(asyncio.CancelledError): @@ -501,7 +495,7 @@ def test_recv_canceled(self): self.assertEqual(data, "café") def test_recv_canceled_race_condition(self): - recv = self.ensure_future( + recv = self.loop.create_task( asyncio.wait_for(self.protocol.recv(), timeout=0.000001) ) self.loop.call_soon( @@ -518,7 +512,7 @@ def test_recv_canceled_race_condition(self): self.assertEqual(data, "café") def test_recv_prevents_concurrent_calls(self): - recv = self.ensure_future(self.protocol.recv()) + recv = self.loop.create_task(self.protocol.recv()) with self.assertRaises(RuntimeError): self.loop.run_until_complete(self.protocol.recv()) @@ -1136,7 +1130,7 @@ def test_close_connection_lost(self): self.assertConnectionFailed(1006, "") def test_local_close_during_recv(self): - recv = self.ensure_future(self.protocol.recv()) + recv = self.loop.create_task(self.protocol.recv()) self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(MS, self.receive_eof_if_client) @@ -1153,7 +1147,7 @@ def test_local_close_during_recv(self): def test_remote_close_during_send(self): self.make_drain_slow() - send = self.ensure_future(self.protocol.send("hello")) + send = self.loop.create_task(self.protocol.send("hello")) self.receive_frame(self.close_frame) self.receive_eof() From 54b1c370f74712a2131e4eca5415c76a3df5f4e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 16:05:16 +0100 Subject: [PATCH 0523/1539] Remove http.HTTPStatus compatibility definitions. --- src/websockets/compatibility.py | 77 --------------------------------- src/websockets/server.py | 32 +++++++------- tests/test_client_server.py | 39 ++++++++--------- 3 files changed, 32 insertions(+), 116 deletions(-) delete mode 100644 src/websockets/compatibility.py diff --git a/src/websockets/compatibility.py b/src/websockets/compatibility.py deleted file mode 100644 index 2e9fcef2b..000000000 --- a/src/websockets/compatibility.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -The :mod:`websockets.compatibility` module provides helpers for bridging -compatibility issues across Python versions. - -""" - -import http - - -try: # pragma: no cover - # Python ≥ 3.5 - SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS - OK = http.HTTPStatus.OK - BAD_REQUEST = http.HTTPStatus.BAD_REQUEST - UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED - FORBIDDEN = http.HTTPStatus.FORBIDDEN - UPGRADE_REQUIRED = http.HTTPStatus.UPGRADE_REQUIRED - INTERNAL_SERVER_ERROR = http.HTTPStatus.INTERNAL_SERVER_ERROR - SERVICE_UNAVAILABLE = http.HTTPStatus.SERVICE_UNAVAILABLE - MOVED_PERMANENTLY = http.HTTPStatus.MOVED_PERMANENTLY - FOUND = http.HTTPStatus.FOUND - SEE_OTHER = http.HTTPStatus.SEE_OTHER - TEMPORARY_REDIRECT = http.HTTPStatus.TEMPORARY_REDIRECT - PERMANENT_REDIRECT = http.HTTPStatus.PERMANENT_REDIRECT -except AttributeError: # pragma: no cover - # Python < 3.5 - class SWITCHING_PROTOCOLS: - value = 101 - phrase = "Switching Protocols" - - class OK: - value = 200 - phrase = "OK" - - class BAD_REQUEST: - value = 400 - phrase = "Bad Request" - - class UNAUTHORIZED: - value = 401 - phrase = "Unauthorized" - - class FORBIDDEN: - value = 403 - phrase = "Forbidden" - - class UPGRADE_REQUIRED: - value = 426 - phrase = "Upgrade Required" - - class INTERNAL_SERVER_ERROR: - value = 500 - phrase = "Internal Server Error" - - class SERVICE_UNAVAILABLE: - value = 503 - phrase = "Service Unavailable" - - class MOVED_PERMANENTLY: - value = 301 - phrase = "Moved Permanently" - - class FOUND: - value = 302 - phrase = "Found" - - class SEE_OTHER: - value = 303 - phrase = "See Other" - - class TEMPORARY_REDIRECT: - value = 307 - phrase = "Temporary Redirect" - - class PERMANENT_REDIRECT: - value = 308 - phrase = "Permanent Redirect" diff --git a/src/websockets/server.py b/src/websockets/server.py index c9e2cc23a..453acec4d 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -6,18 +6,11 @@ import asyncio import collections.abc import email.utils +import http import logging import sys import warnings -from .compatibility import ( - BAD_REQUEST, - FORBIDDEN, - INTERNAL_SERVER_ERROR, - SERVICE_UNAVAILABLE, - SWITCHING_PROTOCOLS, - UPGRADE_REQUIRED, -) from .exceptions import ( AbortHandshake, InvalidHandshake, @@ -123,25 +116,29 @@ def handler(self): status, headers, body = exc.status, exc.headers, exc.body elif isinstance(exc, InvalidOrigin): logger.debug("Invalid origin", exc_info=True) - status, headers, body = FORBIDDEN, [], (str(exc) + "\n").encode() + status, headers, body = ( + http.HTTPStatus.FORBIDDEN, + [], + (str(exc) + "\n").encode(), + ) elif isinstance(exc, InvalidUpgrade): logger.debug("Invalid upgrade", exc_info=True) status, headers, body = ( - UPGRADE_REQUIRED, + http.HTTPStatus.UPGRADE_REQUIRED, [("Upgrade", "websocket")], (str(exc) + "\n").encode(), ) elif isinstance(exc, InvalidHandshake): logger.debug("Invalid handshake", exc_info=True) status, headers, body = ( - BAD_REQUEST, + http.HTTPStatus.BAD_REQUEST, [], (str(exc) + "\n").encode(), ) else: logger.warning("Error in opening handshake", exc_info=True) status, headers, body = ( - INTERNAL_SERVER_ERROR, + http.HTTPStatus.INTERNAL_SERVER_ERROR, [], b"See server log for more information.\n", ) @@ -251,9 +248,6 @@ def process_request(self, path, request_headers): response is sent and the connection is closed. The HTTP status must be a :class:`~http.HTTPStatus`. - (:class:`~http.HTTPStatus` was added in Python 3.5. Use a compatible - object on earlier versions. Look at ``SWITCHING_PROTOCOLS`` in - ``websockets.compatibility`` for an example.) HTTP headers must be a :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` @@ -475,7 +469,11 @@ def handshake( # Change the response to a 503 error if the server is shutting down. if not self.ws_server.is_serving(): - early_response = SERVICE_UNAVAILABLE, [], b"Server is shutting down.\n" + early_response = ( + http.HTTPStatus.SERVICE_UNAVAILABLE, + [], + b"Server is shutting down.\n", + ) if early_response is not None: raise AbortHandshake(*early_response) @@ -515,7 +513,7 @@ def handshake( response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) response_headers.setdefault("Server", USER_AGENT) - self.write_http_response(SWITCHING_PROTOCOLS, response_headers) + self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) self.connection_open() diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 6a06bdaf9..214b1a627 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1,6 +1,7 @@ import asyncio import contextlib import functools +import http import logging import pathlib import random @@ -15,16 +16,6 @@ import warnings from websockets.client import * -from websockets.compatibility import ( - FORBIDDEN, - FOUND, - MOVED_PERMANENTLY, - OK, - PERMANENT_REDIRECT, - SEE_OTHER, - TEMPORARY_REDIRECT, - UNAUTHORIZED, -) from websockets.exceptions import ( ConnectionClosed, InvalidHandshake, @@ -174,14 +165,14 @@ class UnauthorizedServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a Headers instance (1/3) - return UNAUTHORIZED, Headers([("X-Access", "denied")]), b"" + return http.HTTPStatus.UNAUTHORIZED, Headers([("X-Access", "denied")]), b"" class ForbiddenServerProtocol(WebSocketServerProtocol): @asyncio.coroutine def process_request(self, path, request_headers): # Test returning headers as a dict (2/3) - return FORBIDDEN, {"X-Access": "denied"}, b"" + return http.HTTPStatus.FORBIDDEN, {"X-Access": "denied"}, b"" class HealthCheckServerProtocol(WebSocketServerProtocol): @@ -189,7 +180,7 @@ class HealthCheckServerProtocol(WebSocketServerProtocol): def process_request(self, path, request_headers): # Test returning headers as a list of pairs (3/3) if path == "/__health__/": - return OK, [("X-Access", "OK")], b"status = green\n" + return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" class SlowServerProtocol(WebSocketServerProtocol): @@ -352,11 +343,11 @@ def test_basic(self): @with_server() def test_redirect(self): redirect_statuses = [ - MOVED_PERMANENTLY, - FOUND, - SEE_OTHER, - TEMPORARY_REDIRECT, - PERMANENT_REDIRECT, + http.HTTPStatus.MOVED_PERMANENTLY, + http.HTTPStatus.FOUND, + http.HTTPStatus.SEE_OTHER, + http.HTTPStatus.TEMPORARY_REDIRECT, + http.HTTPStatus.PERMANENT_REDIRECT, ] for status in redirect_statuses: with temp_test_redirecting_server(self, status): @@ -366,7 +357,7 @@ def test_redirect(self): self.assertEqual(reply, "Hello!") def test_infinite_redirect(self): - with temp_test_redirecting_server(self, FOUND): + with temp_test_redirecting_server(self, http.HTTPStatus.FOUND): self.server = self.redirecting_server with self.assertRaises(InvalidHandshake): with temp_test_client(self): @@ -374,7 +365,9 @@ def test_infinite_redirect(self): @with_server() def test_redirect_missing_location(self): - with temp_test_redirecting_server(self, FOUND, include_location=False): + with temp_test_redirecting_server( + self, http.HTTPStatus.FOUND, include_location=False + ): with self.assertRaises(InvalidMessage): with temp_test_client(self): self.fail("Did not raise") # pragma: no cover @@ -449,7 +442,7 @@ def test_unix_socket(self): client_socket.close() self.stop_server() - @with_server(process_request=lambda p, rh: (OK, [], b"OK\n")) + @with_server(process_request=lambda p, rh: (http.HTTPStatus.OK, [], b"OK\n")) def test_process_request_argument(self): response = self.loop.run_until_complete(self.make_http_request("/")) @@ -1156,7 +1149,9 @@ def test_ws_uri_is_rejected(self): @with_server() def test_redirect_insecure(self): - with temp_test_redirecting_server(self, FOUND, force_insecure=True): + with temp_test_redirecting_server( + self, http.HTTPStatus.FOUND, force_insecure=True + ): with self.assertRaises(InvalidHandshake): with temp_test_client(self): self.fail("Did not raise") # pragma: no cover From a1541526172de88020d925cd61703a7a89b8595c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 16:20:12 +0100 Subject: [PATCH 0524/1539] Merge Python 3.5+ packages. --- setup.py | 3 - src/websockets/client.py | 53 ++++++++++++-- src/websockets/py35/__init__.py | 2 - src/websockets/py35/client.py | 48 ------------- src/websockets/py35/server.py | 22 ------ src/websockets/server.py | 30 ++++++-- tests/py35/__init__.py | 0 tests/py35/_test_client_server.py | 112 ------------------------------ tests/test_client_server.py | 99 +++++++++++++++++++++++++- 9 files changed, 169 insertions(+), 200 deletions(-) delete mode 100644 src/websockets/py35/__init__.py delete mode 100644 src/websockets/py35/client.py delete mode 100644 src/websockets/py35/server.py delete mode 100644 tests/py35/__init__.py delete mode 100644 tests/py35/_test_client_server.py diff --git a/setup.py b/setup.py index b9e121af7..78d6f7af4 100644 --- a/setup.py +++ b/setup.py @@ -25,9 +25,6 @@ packages = ['websockets', 'websockets/extensions'] -if py_version >= (3, 5): - packages.append('websockets/py35') - if py_version >= (3, 6): packages.append('websockets/py36') diff --git a/src/websockets/client.py b/src/websockets/client.py index 66034ce25..cb2e3ff7f 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -506,11 +506,58 @@ def __iter__(self): # pragma: no cover self.ws_client = protocol return protocol + async def __aenter__(self): + return await self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.ws_client.close() + + async def __await_impl__(self): + # Duplicated with __iter__ because Python 3.7 requires an async function + # (as explained in __await__ below) which Python 3.4 doesn't support. + for redirects in range(self.MAX_REDIRECTS_ALLOWED): + transport, protocol = await self._creating_connection() + + try: + try: + await protocol.handshake( + self._wsuri, + origin=self._origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + break # redirection chain ended + except Exception: + protocol.fail_connection() + await protocol.wait_closed() + raise + except RedirectHandshake as e: + if self._wsuri.secure and not e.wsuri.secure: + raise InvalidHandshake("Redirect dropped TLS") + self._wsuri = e.wsuri + continue # redirection chain continues + else: + raise InvalidHandshake("Maximum redirects exceeded") + + self.ws_client = protocol + return protocol + + def __await__(self): + # __await__() must return a type that I don't know how to obtain except + # by calling __await__() on the return value of an async function. + # I'm not finding a better way to take advantage of PEP 492. + return self.__await_impl__().__await__() + # We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future # didn't accept arbitrary awaitables until Python 3.5.1. We don't define # __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. -if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover +if sys.version_info[:3] < (3, 5, 1): # pragma: no cover + + del Connect.__aenter__ + del Connect.__aexit__ + del Connect.__await__ @asyncio.coroutine def connect(*args, **kwds): @@ -519,9 +566,5 @@ def connect(*args, **kwds): connect.__doc__ = Connect.__doc__ else: - from .py35.client import __aenter__, __aexit__, __await__ - Connect.__aenter__ = __aenter__ - Connect.__aexit__ = __aexit__ - Connect.__await__ = __await__ connect = Connect diff --git a/src/websockets/py35/__init__.py b/src/websockets/py35/__init__.py deleted file mode 100644 index 9612d9dd7..000000000 --- a/src/websockets/py35/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# This package contains code using async / await syntax added in Python 3.5. -# It cannot be imported on Python < 3.5 because it triggers syntax errors. diff --git a/src/websockets/py35/client.py b/src/websockets/py35/client.py deleted file mode 100644 index ccb098483..000000000 --- a/src/websockets/py35/client.py +++ /dev/null @@ -1,48 +0,0 @@ -from ..exceptions import InvalidHandshake, RedirectHandshake - - -async def __aenter__(self): - return await self - - -async def __aexit__(self, exc_type, exc_value, traceback): - await self.ws_client.close() - - -async def __await_impl__(self): - # Duplicated with __iter__ because Python 3.7 requires an async function - # (as explained in __await__ below) which Python 3.4 doesn't support. - for redirects in range(self.MAX_REDIRECTS_ALLOWED): - transport, protocol = await self._creating_connection() - - try: - try: - await protocol.handshake( - self._wsuri, - origin=self._origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - break # redirection chain ended - except Exception: - protocol.fail_connection() - await protocol.wait_closed() - raise - except RedirectHandshake as e: - if self._wsuri.secure and not e.wsuri.secure: - raise InvalidHandshake("Redirect dropped TLS") - self._wsuri = e.wsuri - continue # redirection chain continues - else: - raise InvalidHandshake("Maximum redirects exceeded") - - self.ws_client = protocol - return protocol - - -def __await__(self): - # __await__() must return a type that I don't know how to obtain except - # by calling __await__() on the return value of an async function. - # I'm not finding a better way to take advantage of PEP 492. - return __await_impl__(self).__await__() diff --git a/src/websockets/py35/server.py b/src/websockets/py35/server.py deleted file mode 100644 index 41a3675e3..000000000 --- a/src/websockets/py35/server.py +++ /dev/null @@ -1,22 +0,0 @@ -async def __aenter__(self): - return await self - - -async def __aexit__(self, exc_type, exc_value, traceback): - self.ws_server.close() - await self.ws_server.wait_closed() - - -async def __await_impl__(self): - # Duplicated with __iter__ because Python 3.7 requires an async function - # (as explained in __await__ below) which Python 3.4 doesn't support. - server = await self._creating_server - self.ws_server.wrap(server) - return self.ws_server - - -def __await__(self): - # __await__() must return a type that I don't know how to obtain except - # by calling __await__() on the return value of an async function. - # I'm not finding a better way to take advantage of PEP 492. - return __await_impl__(self).__await__() diff --git a/src/websockets/server.py b/src/websockets/server.py index 453acec4d..424d08922 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -850,6 +850,26 @@ def __iter__(self): # pragma: no cover self.ws_server.wrap(server) return self.ws_server + async def __aenter__(self): + return await self + + async def __aexit__(self, exc_type, exc_value, traceback): + self.ws_server.close() + await self.ws_server.wait_closed() + + async def __await_impl__(self): + # Duplicated with __iter__ because Python 3.7 requires an async function + # (as explained in __await__ below) which Python 3.4 doesn't support. + server = await self._creating_server + self.ws_server.wrap(server) + return self.ws_server + + def __await__(self): + # __await__() must return a type that I don't know how to obtain except + # by calling __await__() on the return value of an async function. + # I'm not finding a better way to take advantage of PEP 492. + return self.__await_impl__().__await__() + def unix_serve(ws_handler, path, **kwargs): """ @@ -869,7 +889,11 @@ def unix_serve(ws_handler, path, **kwargs): # We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future # didn't accept arbitrary awaitables until Python 3.5.1. We don't define # __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. -if sys.version_info[:3] <= (3, 5, 0): # pragma: no cover +if sys.version_info[:3] < (3, 5, 1): # pragma: no cover + + del Serve.__aenter__ + del Serve.__aexit__ + del Serve.__await__ @asyncio.coroutine def serve(*args, **kwds): @@ -878,9 +902,5 @@ def serve(*args, **kwds): serve.__doc__ = Serve.__doc__ else: - from .py35.server import __aenter__, __aexit__, __await__ - Serve.__aenter__ = __aenter__ - Serve.__aexit__ = __aexit__ - Serve.__await__ = __await__ serve = Serve diff --git a/tests/py35/__init__.py b/tests/py35/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/py35/_test_client_server.py b/tests/py35/_test_client_server.py deleted file mode 100644 index 869c379b8..000000000 --- a/tests/py35/_test_client_server.py +++ /dev/null @@ -1,112 +0,0 @@ -# Tests containing Python 3.5+ syntax, extracted from test_client_server.py. - -import asyncio -import pathlib -import socket -import sys -import tempfile -import unittest - -from websockets.client import * -from websockets.protocol import State -from websockets.server import * - -from ..test_client_server import get_server_uri, handler - - -class AsyncAwaitTests(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - - def test_client(self): - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - - async def run_client(): - # Await connect. - client = await connect(get_server_uri(server)) - self.assertEqual(client.state, State.OPEN) - await client.close() - self.assertEqual(client.state, State.CLOSED) - - self.loop.run_until_complete(run_client()) - - server.close() - self.loop.run_until_complete(server.wait_closed()) - - def test_server(self): - async def run_server(): - # Await serve. - server = await serve(handler, "localhost", 0) - self.assertTrue(server.sockets) - server.close() - await server.wait_closed() - self.assertFalse(server.sockets) - - self.loop.run_until_complete(run_server()) - - -class ContextManagerTests(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - - # Asynchronous context managers are only enabled on Python ≥ 3.5.1. - @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" - ) - def test_client(self): - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - - async def run_client(): - # Use connect as an asynchronous context manager. - async with connect(get_server_uri(server)) as client: - self.assertEqual(client.state, State.OPEN) - - # Check that exiting the context manager closed the connection. - self.assertEqual(client.state, State.CLOSED) - - self.loop.run_until_complete(run_client()) - - server.close() - self.loop.run_until_complete(server.wait_closed()) - - # Asynchronous context managers are only enabled on Python ≥ 3.5.1. - @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" - ) - def test_server(self): - async def run_server(): - # Use serve as an asynchronous context manager. - async with serve(handler, "localhost", 0) as server: - self.assertTrue(server.sockets) - - # Check that exiting the context manager closed the server. - self.assertFalse(server.sockets) - - self.loop.run_until_complete(run_server()) - - # Asynchronous context managers are only enabled on Python ≥ 3.5.1. - @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" - ) - @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") - def test_unix_server(self): - async def run_server(path): - async with unix_serve(handler, path) as server: - self.assertTrue(server.sockets) - - # Check that exiting the context manager closed the server. - self.assertFalse(server.sockets) - - with tempfile.TemporaryDirectory() as temp_dir: - path = bytes(pathlib.Path(temp_dir) / "websockets") - self.loop.run_until_complete(run_server(path)) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 214b1a627..633e097bc 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1279,9 +1279,102 @@ def run_server(): self.loop.run_until_complete(run_server()) -if sys.version_info[:2] >= (3, 5): # pragma: no cover - from .py35._test_client_server import AsyncAwaitTests # noqa - from .py35._test_client_server import ContextManagerTests # noqa +class AsyncAwaitTests(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_client(self): + start_server = serve(handler, "localhost", 0) + server = self.loop.run_until_complete(start_server) + + async def run_client(): + # Await connect. + client = await connect(get_server_uri(server)) + self.assertEqual(client.state, State.OPEN) + await client.close() + self.assertEqual(client.state, State.CLOSED) + + self.loop.run_until_complete(run_client()) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + + def test_server(self): + async def run_server(): + # Await serve. + server = await serve(handler, "localhost", 0) + self.assertTrue(server.sockets) + server.close() + await server.wait_closed() + self.assertFalse(server.sockets) + + self.loop.run_until_complete(run_server()) + + +class ContextManagerTests(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + # Asynchronous context managers are only enabled on Python ≥ 3.5.1. + @unittest.skipIf( + sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" + ) + def test_client(self): + start_server = serve(handler, "localhost", 0) + server = self.loop.run_until_complete(start_server) + + async def run_client(): + # Use connect as an asynchronous context manager. + async with connect(get_server_uri(server)) as client: + self.assertEqual(client.state, State.OPEN) + + # Check that exiting the context manager closed the connection. + self.assertEqual(client.state, State.CLOSED) + + self.loop.run_until_complete(run_client()) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + + # Asynchronous context managers are only enabled on Python ≥ 3.5.1. + @unittest.skipIf( + sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" + ) + def test_server(self): + async def run_server(): + # Use serve as an asynchronous context manager. + async with serve(handler, "localhost", 0) as server: + self.assertTrue(server.sockets) + + # Check that exiting the context manager closed the server. + self.assertFalse(server.sockets) + + self.loop.run_until_complete(run_server()) + + # Asynchronous context managers are only enabled on Python ≥ 3.5.1. + @unittest.skipIf( + sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" + ) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") + def test_unix_server(self): + async def run_server(path): + async with unix_serve(handler, path) as server: + self.assertTrue(server.sockets) + + # Check that exiting the context manager closed the server. + self.assertFalse(server.sockets) + + with tempfile.TemporaryDirectory() as temp_dir: + path = bytes(pathlib.Path(temp_dir) / "websockets") + self.loop.run_until_complete(run_server(path)) if sys.version_info[:2] >= (3, 6): # pragma: no cover From b5c40d597ed664f54382d3555b4f1d4cbd8c13d7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 17:20:23 +0100 Subject: [PATCH 0525/1539] Switch to async / await syntax. --- compliance/test_client.py | 34 ++++----- compliance/test_server.py | 7 +- src/websockets/__main__.py | 11 ++- src/websockets/client.py | 45 ++---------- src/websockets/framing.py | 14 ++-- src/websockets/http.py | 25 +++---- src/websockets/protocol.py | 140 +++++++++++++++--------------------- src/websockets/server.py | 48 +++++-------- tests/test_client_server.py | 42 +++++------ tests/test_protocol.py | 13 ++-- 10 files changed, 144 insertions(+), 235 deletions(-) diff --git a/compliance/test_client.py b/compliance/test_client.py index 382d06a05..1c1d4416a 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -16,42 +16,38 @@ AGENT = 'websockets' -@asyncio.coroutine -def get_case_count(server): +async def get_case_count(server): uri = server + '/getCaseCount' - ws = yield from websockets.connect(uri) - msg = yield from ws.recv() - yield from ws.close() + ws = await websockets.connect(uri) + msg = await ws.recv() + await ws.close() return json.loads(msg) -@asyncio.coroutine -def run_case(server, case, agent): +async def run_case(server, case, agent): uri = server + '/runCase?case={}&agent={}'.format(case, agent) - ws = yield from websockets.connect(uri, max_size=2 ** 25, max_queue=1) + ws = await websockets.connect(uri, max_size=2 ** 25, max_queue=1) while True: try: - msg = yield from ws.recv() - yield from ws.send(msg) + msg = await ws.recv() + await ws.send(msg) except websockets.ConnectionClosed: break -@asyncio.coroutine -def update_reports(server, agent): +async def update_reports(server, agent): uri = server + '/updateReports?agent={}'.format(agent) - ws = yield from websockets.connect(uri) - yield from ws.close() + ws = await websockets.connect(uri) + await ws.close() -@asyncio.coroutine -def run_tests(server, agent): - cases = yield from get_case_count(server) +async def run_tests(server, agent): + cases = await get_case_count(server) for case in range(1, cases + 1): print("Running test case {} out of {}".format(case, cases), end="\r") - yield from run_case(server, case, agent) + await run_case(server, case, agent) print("Ran {} test cases ".format(cases)) - yield from update_reports(server, agent) + await update_reports(server, agent) main = run_tests(SERVER, urllib.parse.quote(AGENT)) diff --git a/compliance/test_server.py b/compliance/test_server.py index 75e0e3044..ac5990d16 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -10,12 +10,11 @@ # logging.getLogger('websockets').setLevel(logging.DEBUG) -@asyncio.coroutine -def echo(ws, path): +async def echo(ws, path): while True: try: - msg = yield from ws.recv() - yield from ws.send(msg) + msg = await ws.recv() + await ws.send(msg) except websockets.ConnectionClosed: break diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index b0fdaa6fe..350fc06e8 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -87,10 +87,9 @@ def print_over_input(string): sys.stdout.flush() -@asyncio.coroutine -def run_client(uri, loop, inputs, stop): +async def run_client(uri, loop, inputs, stop): try: - websocket = yield from websockets.connect(uri) + websocket = await websockets.connect(uri) except Exception as exc: print_over_input("Failed to connect to {}: {}.".format(uri, exc)) exit_from_event_loop_thread(loop, stop) @@ -102,7 +101,7 @@ def run_client(uri, loop, inputs, stop): while True: incoming = asyncio.ensure_future(websocket.recv()) outgoing = asyncio.ensure_future(inputs.get()) - done, pending = yield from asyncio.wait( + done, pending = await asyncio.wait( [incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED ) @@ -122,13 +121,13 @@ def run_client(uri, loop, inputs, stop): if outgoing in done: message = outgoing.result() - yield from websocket.send(message) + await websocket.send(message) if stop in done: break finally: - yield from websocket.close() + await websocket.close() close_status = format_close(websocket.close_code, websocket.close_reason) print_over_input( diff --git a/src/websockets/client.py b/src/websockets/client.py index cb2e3ff7f..46dd1b447 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -79,8 +79,7 @@ def write_http_request(self, path, headers): self.writer.write(request.encode()) - @asyncio.coroutine - def read_http_response(self): + async def read_http_response(self): """ Read status line and headers from the HTTP response. @@ -93,7 +92,7 @@ def read_http_response(self): """ try: - status_code, reason, headers = yield from read_response(self.reader) + status_code, reason, headers = await read_response(self.reader) except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc @@ -220,8 +219,7 @@ def process_subprotocol(headers, available_subprotocols): return subprotocol - @asyncio.coroutine - def handshake( + async def handshake( self, wsuri, origin=None, @@ -289,7 +287,7 @@ def handshake( self.write_http_request(wsuri.resource_name, request_headers) - status_code, response_headers = yield from self.read_http_response() + status_code, response_headers = await self.read_http_response() if status_code in (301, 302, 303, 307, 308): if "Location" not in response_headers: raise InvalidMessage("Redirect response missing Location") @@ -477,34 +475,8 @@ def _creating_connection(self): return self._loop.create_connection(factory, host, port, **self._kwds) @asyncio.coroutine - def __iter__(self): # pragma: no cover - for redirects in range(self.MAX_REDIRECTS_ALLOWED): - transport, protocol = yield from self._creating_connection() - - try: - try: - yield from protocol.handshake( - self._wsuri, - origin=self._origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - break # redirection chain ended - except Exception: - protocol.fail_connection() - yield from protocol.wait_closed() - raise - except RedirectHandshake as e: - if self._wsuri.secure and not e.wsuri.secure: - raise InvalidHandshake("Redirect dropped TLS") - self._wsuri = e.wsuri - continue # redirection chain continues - else: - raise InvalidHandshake("Maximum redirects exceeded") - - self.ws_client = protocol - return protocol + def __iter__(self): + return self.__await_impl__() async def __aenter__(self): return await self @@ -513,8 +485,6 @@ async def __aexit__(self, exc_type, exc_value, traceback): await self.ws_client.close() async def __await_impl__(self): - # Duplicated with __iter__ because Python 3.7 requires an async function - # (as explained in __await__ below) which Python 3.4 doesn't support. for redirects in range(self.MAX_REDIRECTS_ALLOWED): transport, protocol = await self._creating_connection() @@ -559,8 +529,7 @@ def __await__(self): del Connect.__aexit__ del Connect.__await__ - @asyncio.coroutine - def connect(*args, **kwds): + async def connect(*args, **kwds): return Connect(*args, **kwds).__iter__() connect.__doc__ = Connect.__doc__ diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 8b0242715..c6b5564f5 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -9,7 +9,6 @@ """ -import asyncio import collections import io import random @@ -73,8 +72,7 @@ def __new__(cls, fin, opcode, data, rsv1=False, rsv2=False, rsv3=False): return FrameData.__new__(cls, fin, opcode, data, rsv1, rsv2, rsv3) @classmethod - @asyncio.coroutine - def read(cls, reader, *, mask, max_size=None, extensions=None): + async def read(cls, reader, *, mask, max_size=None, extensions=None): """ Read a WebSocket frame and return a :class:`Frame` object. @@ -97,7 +95,7 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): """ # Read the header. - data = yield from reader(2) + data = await reader(2) head1, head2 = struct.unpack("!BB", data) # While not Pythonic, this is marginally faster than calling bool(). @@ -112,10 +110,10 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): length = head2 & 0b01111111 if length == 126: - data = yield from reader(2) + data = await reader(2) length, = struct.unpack("!H", data) elif length == 127: - data = yield from reader(8) + data = await reader(8) length, = struct.unpack("!Q", data) if max_size is not None and length > max_size: raise PayloadTooBig( @@ -124,10 +122,10 @@ def read(cls, reader, *, mask, max_size=None, extensions=None): ) ) if mask: - mask_bits = yield from reader(4) + mask_bits = await reader(4) # Read the data. - data = yield from reader(length) + data = await reader(length) if mask: data = apply_mask(data, mask_bits) diff --git a/src/websockets/http.py b/src/websockets/http.py index ea17e0a2e..5e04e53bd 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -7,7 +7,6 @@ """ -import asyncio import collections.abc import re import sys @@ -49,8 +48,7 @@ _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") -@asyncio.coroutine -def read_request(stream): +async def read_request(stream): """ Read an HTTP/1.1 GET request from ``stream``. @@ -76,7 +74,7 @@ def read_request(stream): # version and because path isn't checked. Since WebSocket software tends # to implement HTTP/1.1 strictly, there's little need for lenient parsing. - request_line = yield from read_line(stream) + request_line = await read_line(stream) # This may raise "ValueError: not enough values to unpack" method, path, version = request_line.split(b" ", 2) @@ -87,13 +85,12 @@ def read_request(stream): raise ValueError("Unsupported HTTP version: %r" % version) path = path.decode("ascii", "surrogateescape") - headers = yield from read_headers(stream) + headers = await read_headers(stream) return path, headers -@asyncio.coroutine -def read_response(stream): +async def read_response(stream): """ Read an HTTP/1.1 response from ``stream``. @@ -117,7 +114,7 @@ def read_response(stream): # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. - status_line = yield from read_line(stream) + status_line = await read_line(stream) # This may raise "ValueError: not enough values to unpack" version, status_code, reason = status_line.split(b" ", 2) @@ -132,13 +129,12 @@ def read_response(stream): raise ValueError("Invalid HTTP reason phrase: %r" % reason) reason = reason.decode() - headers = yield from read_headers(stream) + headers = await read_headers(stream) return status_code, reason, headers -@asyncio.coroutine -def read_headers(stream): +async def read_headers(stream): """ Read HTTP headers from ``stream``. @@ -155,7 +151,7 @@ def read_headers(stream): headers = Headers() for _ in range(MAX_HEADERS + 1): - line = yield from read_line(stream) + line = await read_line(stream) if line == b"": break @@ -177,8 +173,7 @@ def read_headers(stream): return headers -@asyncio.coroutine -def read_line(stream): +async def read_line(stream): """ Read a single line from ``stream``. @@ -188,7 +183,7 @@ def read_line(stream): """ # Security: this is bounded by the StreamReader's limit (default = 32kB). - line = yield from stream.readline() + line = await stream.readline() # Security: this guarantees header values are small (hard-coded = 4kB) if len(line) > MAX_LINE: raise ValueError("Line too long") diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 62845e0a8..7f20bed62 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -16,7 +16,6 @@ import random import struct import sys -import warnings from .exceptions import ( ConnectionClosed, @@ -33,15 +32,6 @@ logger = logging.getLogger(__name__) -# On Python ≥ 3.7, silence a deprecation warning that we can't address before -# dropping support for Python < 3.5. -warnings.filterwarnings( - action="ignore", - message=r"'with \(yield from lock\)' is deprecated use 'async with lock' instead", - category=DeprecationWarning, -) - - # A WebSocket connection goes through the following four states, in order: @@ -346,8 +336,7 @@ def closed(self): """ return self.state is State.CLOSED - @asyncio.coroutine - def wait_closed(self): + async def wait_closed(self): """ Wait until the connection is closed. @@ -357,10 +346,9 @@ def wait_closed(self): of its cause, in tasks that interact with the WebSocket connection. """ - yield from asyncio.shield(self.connection_lost_waiter) + await asyncio.shield(self.connection_lost_waiter) - @asyncio.coroutine - def recv(self): + async def recv(self): """ This coroutine receives the next message. @@ -392,7 +380,7 @@ def recv(self): "is already waiting for the next message" ) - # Don't yield from self.ensure_open() here: + # Don't await self.ensure_open() here: # - messages could be available in the queue even if the connection # is closed; # - messages could be received before the closing frame even if the @@ -406,7 +394,7 @@ def recv(self): try: # If asyncio.wait() is canceled, it doesn't cancel # pop_message_waiter and self.transfer_data_task. - yield from asyncio.wait( + await asyncio.wait( [pop_message_waiter, self.transfer_data_task], loop=self.loop, return_when=asyncio.FIRST_COMPLETED, @@ -424,7 +412,7 @@ def recv(self): assert self.state in [State.CLOSING, State.CLOSED] # Wait until the connection is closed to raise # ConnectionClosed with the correct code and reason. - yield from self.ensure_open() + await self.ensure_open() # Pop a message from the queue. message = self.messages.popleft() @@ -436,8 +424,7 @@ def recv(self): return message - @asyncio.coroutine - def send(self, data): + async def send(self, data): """ This coroutine sends a message. @@ -453,7 +440,7 @@ def send(self, data): It raises a :exc:`TypeError` for other inputs. """ - yield from self.ensure_open() + await self.ensure_open() # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -464,7 +451,7 @@ def send(self, data): # Perhaps data is an iterator, see below. pass else: - yield from self.write_frame(True, opcode, data) + await self.write_frame(True, opcode, data) return # Fragmented message -- regular iterator. @@ -478,7 +465,7 @@ def send(self, data): except StopIteration: return opcode, data = prepare_data(data) - yield from self.write_frame(False, opcode, data) + await self.write_frame(False, opcode, data) # Other fragments. for data in iter_data: @@ -488,10 +475,10 @@ def send(self, data): # complete it. This makes the connection unusable. self.fail_connection(1011) raise TypeError("data contains inconsistent types") - yield from self.write_frame(False, OP_CONT, data) + await self.write_frame(False, OP_CONT, data) # Final fragment. - yield from self.write_frame(True, OP_CONT, b"") + await self.write_frame(True, OP_CONT, b"") # Fragmented message -- asynchronous iterator @@ -500,8 +487,7 @@ def send(self, data): else: raise TypeError("data must be bytes, str, or iterable") - @asyncio.coroutine - def close(self, code=1000, reason=""): + async def close(self, code=1000, reason=""): """ This coroutine performs the closing handshake. @@ -519,7 +505,7 @@ def close(self, code=1000, reason=""): """ try: - yield from asyncio.wait_for( + await asyncio.wait_for( self.write_close_frame(serialize_close(code, reason)), self.close_timeout, loop=self.loop, @@ -541,17 +527,16 @@ def close(self, code=1000, reason=""): # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses (on Python ≥ 3.4.3). # This helps closing connections when shutting down a server. - yield from asyncio.wait_for( + await asyncio.wait_for( self.transfer_data_task, self.close_timeout, loop=self.loop ) except (asyncio.TimeoutError, asyncio.CancelledError): pass # Wait for the close connection task to close the TCP connection. - yield from asyncio.shield(self.close_connection_task) + await asyncio.shield(self.close_connection_task) - @asyncio.coroutine - def ping(self, data=None): + async def ping(self, data=None): """ This coroutine sends a ping. @@ -570,7 +555,7 @@ def ping(self, data=None): (which will be encoded to UTF-8) or a bytes-like object. """ - yield from self.ensure_open() + await self.ensure_open() if data is not None: data = encode_data(data) @@ -585,12 +570,11 @@ def ping(self, data=None): self.pings[data] = asyncio.Future(loop=self.loop) - yield from self.write_frame(True, OP_PING, data) + await self.write_frame(True, OP_PING, data) return asyncio.shield(self.pings[data]) - @asyncio.coroutine - def pong(self, data=b""): + async def pong(self, data=b""): """ This coroutine sends a pong. @@ -601,16 +585,15 @@ def pong(self, data=b""): bytes-like object. """ - yield from self.ensure_open() + await self.ensure_open() data = encode_data(data) - yield from self.write_frame(True, OP_PONG, data) + await self.write_frame(True, OP_PONG, data) # Private methods - no guarantees. - @asyncio.coroutine - def ensure_open(self): + async def ensure_open(self): """ Check that the WebSocket connection is open. @@ -623,7 +606,7 @@ def ensure_open(self): # self.close_connection_task may be closing it, going straight # from OPEN to CLOSED. if self.transfer_data_task.done(): - yield from asyncio.shield(self.close_connection_task) + await asyncio.shield(self.close_connection_task) raise ConnectionClosed( self.close_code, self.close_reason ) from self.transfer_data_exc @@ -642,7 +625,7 @@ def ensure_open(self): # CLOSING state also occurs when failing the connection. In that # case self.close_connection_task will complete even faster. if self.close_code is None: - yield from asyncio.shield(self.close_connection_task) + await asyncio.shield(self.close_connection_task) raise ConnectionClosed( self.close_code, self.close_reason ) from self.transfer_data_exc @@ -651,8 +634,7 @@ def ensure_open(self): assert self.state is State.CONNECTING raise InvalidState("WebSocket connection isn't established yet") - @asyncio.coroutine - def transfer_data(self): + async def transfer_data(self): """ Read incoming messages and put them in a queue. @@ -661,7 +643,7 @@ def transfer_data(self): """ try: while True: - message = yield from self.read_message() + message = await self.read_message() # Exit the loop when receiving a close frame. if message is None: @@ -671,7 +653,7 @@ def transfer_data(self): while len(self.messages) >= self.max_queue: self._put_message_waiter = asyncio.Future(loop=self.loop) try: - yield from self._put_message_waiter + await self._put_message_waiter finally: self._put_message_waiter = None @@ -719,8 +701,7 @@ def transfer_data(self): self.transfer_data_exc = exc self.fail_connection(1011) - @asyncio.coroutine - def read_message(self): + async def read_message(self): """ Read a single message from the connection. @@ -729,7 +710,7 @@ def read_message(self): Return ``None`` when the closing handshake is started. """ - frame = yield from self.read_data_frame(max_size=self.max_size) + frame = await self.read_data_frame(max_size=self.max_size) # A close frame was received. if frame is None: @@ -781,7 +762,7 @@ def append(frame): append(frame) while not frame.fin: - frame = yield from self.read_data_frame(max_size=max_size) + frame = await self.read_data_frame(max_size=max_size) if frame is None: raise WebSocketProtocolError("Incomplete fragmented message") if frame.opcode != OP_CONT: @@ -790,8 +771,7 @@ def append(frame): return ("" if text else b"").join(chunks) - @asyncio.coroutine - def read_data_frame(self, max_size): + async def read_data_frame(self, max_size): """ Read a single data frame from the connection. @@ -802,7 +782,7 @@ def read_data_frame(self, max_size): """ # 6.2. Receiving Data while True: - frame = yield from self.read_frame(max_size) + frame = await self.read_frame(max_size) # 5.5. Control Frames if frame.opcode == OP_CLOSE: @@ -812,7 +792,7 @@ def read_data_frame(self, max_size): # Echo the original data instead of re-serializing it with # serialize_close() because that fails when the close frame is # empty and parse_close() synthetizes a 1005 close code. - yield from self.write_close_frame(frame.data) + await self.write_close_frame(frame.data) return elif frame.opcode == OP_PING: @@ -822,7 +802,7 @@ def read_data_frame(self, max_size): logger.debug( "%s - received ping, sending pong: %s", self.side, ping_hex ) - yield from self.pong(frame.data) + await self.pong(frame.data) elif frame.opcode == OP_PONG: # Acknowledge pings on solicited pongs. @@ -861,13 +841,12 @@ def read_data_frame(self, max_size): else: return frame - @asyncio.coroutine - def read_frame(self, max_size): + async def read_frame(self, max_size): """ Read a single frame from the connection. """ - frame = yield from Frame.read( + frame = await Frame.read( self.reader.readexactly, mask=not self.is_client, max_size=max_size, @@ -876,8 +855,7 @@ def read_frame(self, max_size): logger.debug("%s < %r", self.side, frame) return frame - @asyncio.coroutine - def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): + async def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): # Defensive assertion for protocol compliance. if self.state is not _expected_state: # pragma: no cover raise InvalidState( @@ -892,21 +870,21 @@ def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): # Remove when dropping support for Python < 3.6. if self.writer.transport is not None: # pragma: no cover if self.writer_is_closing(): - yield + await asyncio.sleep(0) try: # drain() cannot be called concurrently by multiple coroutines: # http://bugs.python.org/issue29930. Remove this lock when no # version of Python where this bugs exists is supported anymore. - with (yield from self._drain_lock): + async with self._drain_lock: # Handle flow control automatically. - yield from self.writer.drain() + await self.writer.drain() except ConnectionError: # Terminate the connection if the socket died. self.fail_connection() # Wait until the connection is closed to raise ConnectionClosed # with the correct code and reason. - yield from self.ensure_open() + await self.ensure_open() def writer_is_closing(self): """ @@ -927,8 +905,7 @@ def writer_is_closing(self): except AttributeError: return transport._closed - @asyncio.coroutine - def write_close_frame(self, data=b""): + async def write_close_frame(self, data=b""): """ Write a close frame if and only if the connection state is OPEN. @@ -944,12 +921,9 @@ def write_close_frame(self, data=b""): logger.debug("%s - state = CLOSING", self.side) # 7.1.2. Start the WebSocket Closing Handshake - yield from self.write_frame( - True, OP_CLOSE, data, _expected_state=State.CLOSING - ) + await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) - @asyncio.coroutine - def keepalive_ping(self): + async def keepalive_ping(self): """ Send a Ping frame and wait for a Pong frame at regular intervals. @@ -964,18 +938,18 @@ def keepalive_ping(self): try: while True: - yield from asyncio.sleep(self.ping_interval, loop=self.loop) + await asyncio.sleep(self.ping_interval, loop=self.loop) # ping() cannot raise ConnectionClosed, only CancelledError: # - If the connection is CLOSING, keepalive_ping_task will be # canceled by close_connection() before ping() returns. # - If the connection is CLOSED, keepalive_ping_task must be # canceled already. - ping_waiter = yield from self.ping() + ping_waiter = await self.ping() if self.ping_timeout is not None: try: - yield from asyncio.wait_for( + await asyncio.wait_for( ping_waiter, self.ping_timeout, loop=self.loop ) except asyncio.TimeoutError: @@ -989,8 +963,7 @@ def keepalive_ping(self): except Exception: logger.warning("Unexpected exception in keepalive ping task", exc_info=True) - @asyncio.coroutine - def close_connection(self): + async def close_connection(self): """ 7.1.1. Close the WebSocket Connection @@ -1006,7 +979,7 @@ def close_connection(self): # Wait for the data transfer phase to complete. if self.transfer_data_task is not None: try: - yield from self.transfer_data_task + await self.transfer_data_task except asyncio.CancelledError: pass @@ -1016,7 +989,7 @@ def close_connection(self): # A client should wait for a TCP close from the server. if self.is_client and self.transfer_data_task is not None: - if (yield from self.wait_for_connection_lost()): + if await self.wait_for_connection_lost(): return logger.debug("%s ! timed out waiting for TCP close", self.side) @@ -1025,7 +998,7 @@ def close_connection(self): logger.debug("%s x half-closing TCP connection", self.side) self.writer.write_eof() - if (yield from self.wait_for_connection_lost()): + if await self.wait_for_connection_lost(): return logger.debug("%s ! timed out waiting for TCP close", self.side) @@ -1043,7 +1016,7 @@ def close_connection(self): logger.debug("%s x closing TCP connection", self.side) self.writer.close() - if (yield from self.wait_for_connection_lost()): + if await self.wait_for_connection_lost(): return logger.debug("%s ! timed out waiting for TCP close", self.side) @@ -1052,10 +1025,9 @@ def close_connection(self): self.writer.transport.abort() # connection_lost() is called quickly after aborting. - yield from self.wait_for_connection_lost() + await self.wait_for_connection_lost() - @asyncio.coroutine - def wait_for_connection_lost(self): + async def wait_for_connection_lost(self): """ Wait until the TCP connection is closed or ``self.close_timeout`` elapses. @@ -1064,7 +1036,7 @@ def wait_for_connection_lost(self): """ if not self.connection_lost_waiter.done(): try: - yield from asyncio.wait_for( + await asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), self.close_timeout, loop=self.loop, diff --git a/src/websockets/server.py b/src/websockets/server.py index 424d08922..752170edf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -89,8 +89,7 @@ def connection_made(self, transport): self.ws_server.register(self) self.handler_task = self.loop.create_task(self.handler()) - @asyncio.coroutine - def handler(self): + async def handler(self): """ Handle the lifecycle of a WebSocket connection. @@ -102,7 +101,7 @@ def handler(self): try: try: - path = yield from self.handshake( + path = await self.handshake( origins=self.origins, available_extensions=self.available_extensions, available_subprotocols=self.available_subprotocols, @@ -154,11 +153,11 @@ def handler(self): self.write_http_response(status, headers, body) self.fail_connection() - yield from self.wait_closed() + await self.wait_closed() return try: - yield from self.ws_handler(self, path) + await self.ws_handler(self, path) except Exception: logger.error("Error in connection handler", exc_info=True) if not self.closed: @@ -166,7 +165,7 @@ def handler(self): raise try: - yield from self.close() + await self.close() except ConnectionError: logger.debug("Connection error in closing handshake", exc_info=True) raise @@ -188,8 +187,7 @@ def handler(self): # connections before terminating. self.ws_server.unregister(self) - @asyncio.coroutine - def read_http_request(self): + async def read_http_request(self): """ Read request line and headers from the HTTP request. @@ -202,7 +200,7 @@ def read_http_request(self): """ try: - path, headers = yield from read_request(self.reader) + path, headers = await read_request(self.reader) except ValueError as exc: raise InvalidMessage("Malformed HTTP message") from exc @@ -426,8 +424,7 @@ def select_subprotocol(client_subprotocols, server_subprotocols): ) return sorted(subprotocols, key=priority)[0] - @asyncio.coroutine - def handshake( + async def handshake( self, origins=None, available_extensions=None, @@ -458,12 +455,12 @@ def handshake( Return the path of the URI of the request. """ - path, request_headers = yield from self.read_http_request() + path, request_headers = await self.read_http_request() # Hook for customizing request handling, for example checking # authentication or treating some paths as plain HTTP endpoints. if asyncio.iscoroutinefunction(self.process_request): - early_response = yield from self.process_request(path, request_headers) + early_response = await self.process_request(path, request_headers) else: early_response = self.process_request(path, request_headers) @@ -604,8 +601,7 @@ def close(self): if self.close_task is None: self.close_task = self.loop.create_task(self._close()) - @asyncio.coroutine - def _close(self): + async def _close(self): """ Implementation of :meth:`close`. @@ -618,11 +614,11 @@ def _close(self): self.server.close() # Wait until self.server.close() completes. - yield from self.server.wait_closed() + await self.server.wait_closed() # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - yield from asyncio.sleep(0) + await asyncio.sleep(0) # Close open connections. fail_connection() will cancel the transfer # data task, which is expected to cause the handler task to terminate. @@ -637,7 +633,7 @@ def _close(self): # running tasks. # TODO: it would be nicer to wait only for the connection handler # and let the handler wait for the connection to close. - yield from asyncio.wait( + await asyncio.wait( [websocket.handler_task for websocket in self.websockets] + [ websocket.close_connection_task @@ -650,8 +646,7 @@ def _close(self): # Tell wait_closed() to return. self.closed_waiter.set_result(None) - @asyncio.coroutine - def wait_closed(self): + async def wait_closed(self): """ Wait until the server is closed and all connections are terminated. @@ -659,7 +654,7 @@ def wait_closed(self): there are no pending tasks left. """ - yield from asyncio.shield(self.closed_waiter) + await asyncio.shield(self.closed_waiter) @property def sockets(self): @@ -845,10 +840,8 @@ def __init__( self.ws_server = ws_server @asyncio.coroutine - def __iter__(self): # pragma: no cover - server = yield from self._creating_server - self.ws_server.wrap(server) - return self.ws_server + def __iter__(self): + return self.__await_impl__() async def __aenter__(self): return await self @@ -858,8 +851,6 @@ async def __aexit__(self, exc_type, exc_value, traceback): await self.ws_server.wait_closed() async def __await_impl__(self): - # Duplicated with __iter__ because Python 3.7 requires an async function - # (as explained in __await__ below) which Python 3.4 doesn't support. server = await self._creating_server self.ws_server.wrap(server) return self.ws_server @@ -895,8 +886,7 @@ def unix_serve(ws_handler, path, **kwargs): del Serve.__aexit__ del Serve.__await__ - @asyncio.coroutine - def serve(*args, **kwds): + async def serve(*args, **kwds): return Serve(*args, **kwds).__iter__() serve.__doc__ = Serve.__doc__ diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 633e097bc..d155f7fae 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -49,26 +49,25 @@ testcert = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) -@asyncio.coroutine -def handler(ws, path): +async def handler(ws, path): if path == "/attributes": - yield from ws.send(repr((ws.host, ws.port, ws.secure))) + await ws.send(repr((ws.host, ws.port, ws.secure))) elif path == "/close_timeout": - yield from ws.send(repr(ws.close_timeout)) + await ws.send(repr(ws.close_timeout)) elif path == "/path": - yield from ws.send(str(ws.path)) + await ws.send(str(ws.path)) elif path == "/headers": - yield from ws.send(repr(ws.request_headers)) - yield from ws.send(repr(ws.response_headers)) + await ws.send(repr(ws.request_headers)) + await ws.send(repr(ws.response_headers)) elif path == "/extensions": - yield from ws.send(repr(ws.extensions)) + await ws.send(repr(ws.extensions)) elif path == "/subprotocol": - yield from ws.send(repr(ws.subprotocol)) + await ws.send(repr(ws.subprotocol)) elif path == "/slow_stop": - yield from ws.wait_closed() - yield from asyncio.sleep(2 * MS) + await ws.wait_closed() + await asyncio.sleep(2 * MS) else: - yield from ws.send((yield from ws.recv())) + await ws.send((await ws.recv())) @contextlib.contextmanager @@ -162,31 +161,27 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): class UnauthorizedServerProtocol(WebSocketServerProtocol): - @asyncio.coroutine - def process_request(self, path, request_headers): + async def process_request(self, path, request_headers): # Test returning headers as a Headers instance (1/3) return http.HTTPStatus.UNAUTHORIZED, Headers([("X-Access", "denied")]), b"" class ForbiddenServerProtocol(WebSocketServerProtocol): - @asyncio.coroutine - def process_request(self, path, request_headers): + async def process_request(self, path, request_headers): # Test returning headers as a dict (2/3) return http.HTTPStatus.FORBIDDEN, {"X-Access": "denied"}, b"" class HealthCheckServerProtocol(WebSocketServerProtocol): - @asyncio.coroutine - def process_request(self, path, request_headers): + async def process_request(self, path, request_headers): # Test returning headers as a list of pairs (3/3) if path == "/__health__/": return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" class SlowServerProtocol(WebSocketServerProtocol): - @asyncio.coroutine - def process_request(self, path, request_headers): - yield from asyncio.sleep(10 * MS) + async def process_request(self, path, request_headers): + await asyncio.sleep(10 * MS) class FooClientProtocol(WebSocketClientProtocol): @@ -957,9 +952,8 @@ def wrong_build_response(headers, key): @with_server() @unittest.mock.patch("websockets.client.read_response") def test_server_does_not_switch_protocols(self, _read_response): - @asyncio.coroutine - def wrong_read_response(stream): - status_code, reason, headers = yield from read_response(stream) + async def wrong_read_response(stream): + status_code, reason, headers = await read_response(stream) return 400, "Bad Request", headers _read_response.side_effect = wrong_read_response diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 896c0fe4b..70c2be0bd 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -117,10 +117,9 @@ def make_drain_slow(self, delay=MS): original_drain = self.protocol.writer.drain - @asyncio.coroutine - def delayed_drain(): - yield from asyncio.sleep(delay, loop=self.loop) - yield from original_drain() + async def delayed_drain(): + await asyncio.sleep(delay, loop=self.loop) + await original_drain() self.protocol.writer.drain = delayed_drain @@ -474,8 +473,7 @@ def test_recv_queue_full(self): self.assertEqual(list(self.protocol.messages), []) def test_recv_other_error(self): - @asyncio.coroutine - def read_message(): + async def read_message(): raise Exception("BOOM") self.protocol.read_message = read_message @@ -1034,8 +1032,7 @@ def test_keepalive_ping_with_no_ping_timeout(self): def test_keepalive_ping_unexpected_error(self): self.restart_protocol_with_keepalive_ping() - @asyncio.coroutine - def ping(): + async def ping(): raise Exception("BOOM") self.protocol.ping = ping From f77ab68e23e61658548e2a624683ec07fe816b91 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 17:29:14 +0100 Subject: [PATCH 0526/1539] Miscellaneous cleanups. --- setup.py | 10 ++-------- src/websockets/protocol.py | 5 ++--- src/websockets/server.py | 2 +- tests/test_client_server.py | 8 +------- 4 files changed, 6 insertions(+), 19 deletions(-) diff --git a/setup.py b/setup.py index 78d6f7af4..2956058a4 100644 --- a/setup.py +++ b/setup.py @@ -8,15 +8,9 @@ description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -# When dropping Python < 3.5, change to: -# long_description = (root_dir / 'README.rst').read_text(encoding='utf-8') -with (root_dir / 'README.rst').open(encoding='utf-8') as f: - long_description = f.read() +long_description = (root_dir / 'README.rst').read_text(encoding='utf-8') -# When dropping Python < 3.5, change to: -# exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) -with (root_dir / 'src' / 'websockets' / 'version.py').open(encoding='utf-8') as f: - exec(f.read()) +exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) py_version = sys.version_info[:2] diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 7f20bed62..5c60348aa 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -525,7 +525,7 @@ async def close(self, code=1000, reason=""): try: # If close() is canceled during the wait, self.transfer_data_task - # is canceled before the timeout elapses (on Python ≥ 3.4.3). + # is canceled before the timeout elapses. # This helps closing connections when shutting down a server. await asyncio.wait_for( self.transfer_data_task, self.close_timeout, loop=self.loop @@ -797,8 +797,7 @@ async def read_data_frame(self, max_size): elif frame.opcode == OP_PING: # Answer pings. - # Replace by frame.data.hex() when dropping Python < 3.5. - ping_hex = binascii.hexlify(frame.data).decode() or "[empty]" + ping_hex = frame.data.hex() or "[empty]" logger.debug( "%s - received ping, sending pong: %s", self.side, ping_hex ) diff --git a/src/websockets/server.py b/src/websockets/server.py index 752170edf..839b3c861 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -677,7 +677,7 @@ class Serve: :meth:`~websockets.server.WebSocketServer.wait_closed` methods for terminating the server and cleaning up its resources. - On Python ≥ 3.5, :func:`serve` can also be used as an asynchronous context + On Python ≥ 3.5.1, :func:`serve` can also be used as an asynchronous context manager. In this case, the server is shut down when exiting the context. :func:`serve` is a wrapper around the event loop's diff --git a/tests/test_client_server.py b/tests/test_client_server.py index d155f7fae..6b80c7f6e 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1133,13 +1133,7 @@ def start_client(self, path="/", **kwds): @with_server() def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): - client = connect( - get_server_uri(self.server, secure=False), ssl=self.client_context - ) - # With Python ≥ 3.5, the exception is raised by connect() even - # before awaiting. However, with Python 3.4 the exception is - # raised only when awaiting. - self.loop.run_until_complete(client) # pragma: no cover + connect(get_server_uri(self.server, secure=False), ssl=self.client_context) @with_server() def test_redirect_insecure(self): From 22a4604cfdaedc27141fef8048c55cdf2c899185 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 17:34:31 +0100 Subject: [PATCH 0527/1539] Remove documentation for Python 3.4. --- docs/intro.rst | 25 ------------------------- example/old_client.py | 11 +++++------ example/old_server.py | 7 +++---- 3 files changed, 8 insertions(+), 35 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index b153d2f5d..376b7d9ca 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -233,7 +233,6 @@ Python < 3.6 This documentation takes advantage of several features that aren't available in Python < 3.6: -- ``await`` and ``async`` were added in Python 3.5; - Asynchronous context managers didn't work well until Python 3.5.1; - Asynchronous iterators were added in Python 3.6; - f-strings were introduced in Python 3.6 (this is unrelated to :mod:`asyncio` @@ -242,34 +241,10 @@ in Python < 3.6: Here's how to adapt the basic server example. .. literalinclude:: ../example/old_server.py - :emphasize-lines: 8-9,18 And here's the basic client example. .. literalinclude:: ../example/old_client.py - :emphasize-lines: 8-11,13,22-23 - -``await`` and ``async`` -....................... - -If you're using Python < 3.5, you must substitute:: - - async def ... - -with:: - - @asyncio.coroutine - def ... - -and:: - - await ... - -with:: - - yield from ... - -Otherwise you will encounter a :exc:`SyntaxError`. Asynchronous context managers ............................. diff --git a/example/old_client.py b/example/old_client.py index c44d6edff..be34f14be 100755 --- a/example/old_client.py +++ b/example/old_client.py @@ -5,21 +5,20 @@ import asyncio import websockets -@asyncio.coroutine -def hello(): - websocket = yield from websockets.connect( +async def hello(): + websocket = await websockets.connect( 'ws://localhost:8765/') try: name = input("What's your name? ") - yield from websocket.send(name) + await websocket.send(name) print("> {}".format(name)) - greeting = yield from websocket.recv() + greeting = await websocket.recv() print("< {}".format(greeting)) finally: - yield from websocket.close() + await websocket.close() asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/old_server.py b/example/old_server.py index bb19bdabc..8c63e33e6 100755 --- a/example/old_server.py +++ b/example/old_server.py @@ -5,14 +5,13 @@ import asyncio import websockets -@asyncio.coroutine -def hello(websocket, path): - name = yield from websocket.recv() +async def hello(websocket, path): + name = await websocket.recv() print("< {}".format(name)) greeting = "Hello {}!".format(name) - yield from websocket.send(greeting) + await websocket.send(greeting) print("> {}".format(greeting)) start_server = websockets.serve(hello, 'localhost', 8765) From 40adef93ae4cc74fef34d8ad4e72648a361799e8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Dec 2018 22:33:24 +0100 Subject: [PATCH 0528/1539] Add documentation for extensions. Fix #255. --- docs/api.rst | 9 ++ docs/changelog.rst | 2 + docs/extensions.rst | 85 +++++++++++++++++++ docs/index.rst | 1 + src/websockets/extensions/base.py | 72 ++++++++++------ .../extensions/permessage_deflate.py | 50 ++++++++--- 6 files changed, 180 insertions(+), 39 deletions(-) create mode 100644 docs/extensions.rst diff --git a/docs/api.rst b/docs/api.rst index 80d64e254..e480604bb 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -82,6 +82,15 @@ Shared .. autoattribute:: open .. autoattribute:: closed +Per-Message Deflate Extension +............................. + +.. automodule:: websockets.extensions.permessage_deflate + + .. autoclass:: ServerPerMessageDeflateFactory + + .. autoclass:: ClientPerMessageDeflateFactory + Exceptions .......... diff --git a/docs/changelog.rst b/docs/changelog.rst index 87e2e0ac8..e4fd55fb4 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -28,6 +28,8 @@ Also: * :func:`~client.connect()` handles redirects from the server during the handshake. +* Added documentation for extensions. + 7.0 ... diff --git a/docs/extensions.rst b/docs/extensions.rst new file mode 100644 index 000000000..3a5885009 --- /dev/null +++ b/docs/extensions.rst @@ -0,0 +1,85 @@ +Extensions +========== + +.. currentmodule:: websockets + +The WebSocket protocol supports extensions_. + +At the time of writing, there's only one `registered extension`_, WebSocket +Per-Message Deflate, specified in :rfc:`7692`. + +.. _extensions: https://tools.ietf.org/html/rfc6455#section-9 +.. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name + +Per-Message Deflate +------------------- + +:func:`~server.serve()` and :func:`~client.connect()` enable the Per-Message +Deflate extension by default. You can disable this with ``compression=None``. + +You can also configure the Per-Message Deflate extension explicitly if you +want to customize its parameters. + +Here's an example on the server side:: + + import websockets + from websockets.extensions import permessage_deflate + + websockets.serve( + ..., + extensions=[ + permessage_deflate.ServerPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={'memLevel': 4}, + ), + ], + ) + +Here's an example on the client side:: + + import websockets + from websockets.extensions import permessage_deflate + + websockets.connect( + ..., + extensions=[ + permessage_deflate.ClientPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={'memLevel': 4}, + ), + ], + ) + +Refer to the API documentation of +:class:`~extensions.permessage_deflate.ServerPerMessageDeflateFactory` and +:class:`~extensions.permessage_deflate.ClientPerMessageDeflateFactory` for +details. + +Writing an extension +-------------------- + +During the opening handshake, WebSocket clients and servers negotiate which +extensions will be used with which parameters. Then each frame is processed by +extensions before it's sent and after it's received. + +As a consequence writing an extension requires implementing several classes: + +1. Extension Factory: it negotiates parameters and instanciates the extension. + Clients and servers require separate extension factories with distict APIs. + +2. Extension: it decodes incoming frames and encodes outgoing frames. If the + extension is symmetrical, clients and servers can use the same class. + +``websockets`` provides abstract base classes for extension factories and +extensions. + +.. autoclass:: websockets.extensions.base.ServerExtensionFactory + :members: + +.. autoclass:: websockets.extensions.base.ClientExtensionFactory + :members: + +.. autoclass:: websockets.extensions.base.Extension + :members: diff --git a/docs/index.rst b/docs/index.rst index 7ccd9463e..040d41598 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,6 +60,7 @@ These guides will help you build and deploy a ``websockets`` application. cheatsheet deployment + extensions Reference --------- diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 69b55b3f8..cf3f9a2ec 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,7 +1,8 @@ """ -The :mod:`websockets.extensions.base` defines abstract classes for extensions. +The :mod:`websockets.extensions.base` module defines abstract classes for +implementing extensions as specified in `section 9 of RFC 6455`_. -See https://tools.ietf.org/html/rfc6455#section-9. +.. _section 9 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-9 """ @@ -10,11 +11,14 @@ class ClientExtensionFactory: """ Abstract class for client-side extension factories. - Extension factories handle configuration and negotiation. - """ - name = ... + @property + def name(self): + """ + Extension identifier. + + """ def get_request_params(self): """ @@ -25,18 +29,17 @@ def get_request_params(self): """ def process_response_params(self, params, accepted_extensions): - """" - Process response parameters. + """ + Process response parameters received from the server. - ``params`` are a list of (name, value) pairs. + ``params`` is a list of (name, value) pairs. - ``accepted_extensions`` is a list of previously accepted extensions, - represented by extension instances. + ``accepted_extensions`` is a list of previously accepted extensions. - Return an extension instance (an instance of a subclass of - :class:`Extension`) if these parameters are acceptable. + If parameters are acceptable, return an extension: an instance of a + subclass of :class:`Extension`. - Raise :exc:`~websockets.exceptions.NegotiationError` if they aren't. + If they aren't, raise :exc:`~websockets.exceptions.NegotiationError`. """ @@ -45,24 +48,30 @@ class ServerExtensionFactory: """ Abstract class for server-side extension factories. - Extension factories handle configuration and negotiation. - """ - name = ... + @property + def name(self): + """ + Extension identifier. + + """ def process_request_params(self, params, accepted_extensions): - """" - Process request parameters. + """ + Process request parameters received from the client. + + ``params`` is a list of (name, value) pairs. - ``accepted_extensions`` is a list of previously accepted extensions, - represented by extension instances. + ``accepted_extensions`` is a list of previously accepted extensions. - Return response params (a list of (name, value) pairs) and an - extension instance (an instance of a subclass of :class:`Extension`) - to accept this extension. + To accept the offer, return a 2-uple containing: - Raise :exc:`~websockets.exceptions.NegotiationError` to reject it. + - response parameters: a list of (name, value) pairs + - an extension: an instance of a subclass of :class:`Extension` + + To reject the offer, raise + :exc:`~websockets.exceptions.NegotiationError`. """ @@ -73,13 +82,21 @@ class Extension: """ - name = ... + @property + def name(self): + """ + Extension identifier. + + """ def decode(self, frame, *, max_size=None): """ Decode an incoming frame. - Return a frame. + The ``frame`` parameter and the return value are + :class:`~websockets.framing.Frame` instances. + + """ @@ -87,6 +104,7 @@ def encode(self, frame): """ Encode an outgoing frame. - Return a frame. + The ``frame`` parameter and the return value are + :class:`~websockets.framing.Frame` instances. """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index dad6f1ec1..167746021 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -113,7 +113,22 @@ def _extract_parameters(params, *, is_server): class ClientPerMessageDeflateFactory: """ - Client-side extension factory for permessage-deflate extension. + Client-side extension factory for Per-Message Deflate extension. + + These parameters behave as described in `section 7.1 of RFC 7692`_: + + - ``server_no_context_takeover`` + - ``client_no_context_takeover`` + - ``server_max_window_bits`` + - ``client_max_window_bits`` + + Set them to ``True`` to include them in the negotiation offer without a + value or to an integer value to include them with this value. + + .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1 + + ``compress_settings`` is an optional :class:`dict` of keyword arguments + for :func:`zlib.compressobj`, excluding ``wbits``. """ @@ -128,9 +143,7 @@ def __init__( compress_settings=None, ): """ - Configure permessage-deflate extension factory. - - See https://tools.ietf.org/html/rfc7692#section-7.1. + Configure the Per-Message Deflate extension factory. """ if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): @@ -166,7 +179,7 @@ def get_request_params(self): ) def process_response_params(self, params, accepted_extensions): - """" + """ Process response parameters. Return an extension instance. @@ -269,7 +282,22 @@ def process_response_params(self, params, accepted_extensions): class ServerPerMessageDeflateFactory: """ - Server-side extension factory for permessage-deflate extension. + Server-side extension factory for the Per-Message Deflate extension. + + These parameters behave as described in `section 7.1 of RFC 7692`_: + + - ``server_no_context_takeover`` + - ``client_no_context_takeover`` + - ``server_max_window_bits`` + - ``client_max_window_bits`` + + Set them to ``True`` to include them in the negotiation offer without a + value or to an integer value to include them with this value. + + .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1 + + ``compress_settings`` is an optional :class:`dict` of keyword arguments + for :func:`zlib.compressobj`, excluding ``wbits``. """ @@ -284,9 +312,7 @@ def __init__( compress_settings=None, ): """ - Configure permessage-deflate extension factory. - - See https://tools.ietf.org/html/rfc7692#section-7.1. + Configure the Per-Message Deflate extension factory. """ if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): @@ -306,7 +332,7 @@ def __init__( self.compress_settings = compress_settings def process_request_params(self, params, accepted_extensions): - """" + """ Process request parameters. Return response params and an extension instance. @@ -416,7 +442,7 @@ def process_request_params(self, params, accepted_extensions): class PerMessageDeflate: """ - permessage-deflate extension. + Per-Message Deflate extension. """ @@ -431,7 +457,7 @@ def __init__( compress_settings=None, ): """ - Configure permessage-deflate extension. + Configure the Per-Message Deflate extension. """ if compress_settings is None: From 8bcfd9aacd4b93b9df687a6d1e171b03ea3727c9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 Dec 2018 11:57:38 +0100 Subject: [PATCH 0529/1539] Remove obsolete description. That behavior changed in 7.0. --- docs/deployment.rst | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/deployment.rst b/docs/deployment.rst index 0f571520d..7eb350606 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -30,10 +30,6 @@ with the object returned by :func:`~server.serve`: - calling its ``close()`` method, then waiting for its ``wait_closed()`` method to complete. -Tasks that handle connections will be canceled. For example, if the handler -is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, that call will -raise :exc:`~asyncio.CancelledError`. - On Unix systems, shutdown is usually triggered by sending a signal. Here's a full example (Unix-only): From 4034bc768f1adec08274ef28f62ea6e401d4e88e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 Dec 2018 11:59:56 +0100 Subject: [PATCH 0530/1539] Remove documentation for Python 3.4. Missed from 22a4604c. --- docs/deployment.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/deployment.rst b/docs/deployment.rst index 7eb350606..6758e6afd 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -37,9 +37,9 @@ Here's a full example (Unix-only): .. literalinclude:: ../example/shutdown.py :emphasize-lines: 13,17-19 -``async`` and ``await`` were introduced in Python 3.5. websockets supports -asynchronous context managers on Python ≥ 3.5.1. ``async for`` was introduced -in Python 3.6. Here's the equivalent for older Python versions: +websockets supports asynchronous context managers on Python ≥ 3.5.1. ``async +for`` was introduced in Python 3.6. Here's the equivalent for older Python +versions: .. literalinclude:: ../example/old_shutdown.py :emphasize-lines: 22-25 From 3da06faadec19d50cc068f62e10e2ca456396f53 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Dec 2018 10:16:15 +0100 Subject: [PATCH 0531/1539] Document how to optimize memory usage. Include benchmarking scripts. Also improve neighboring docs. Fix #272. --- docs/changelog.rst | 2 + docs/deployment.rst | 101 ++++++++++++++++++++++++++++++++++--- docs/design.rst | 10 ++-- docs/extensions.rst | 6 ++- docs/intro.rst | 2 + docs/security.rst | 33 +++++------- docs/spelling_wordlist.txt | 4 +- performance/mem_client.py | 54 ++++++++++++++++++++ performance/mem_server.py | 63 +++++++++++++++++++++++ src/websockets/http.py | 4 +- src/websockets/protocol.py | 8 +-- 11 files changed, 247 insertions(+), 40 deletions(-) create mode 100644 performance/mem_client.py create mode 100644 performance/mem_server.py diff --git a/docs/changelog.rst b/docs/changelog.rst index e4fd55fb4..320300f64 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -30,6 +30,8 @@ Also: * Added documentation for extensions. +* Documented how to optimize memory usage. + 7.0 ... diff --git a/docs/deployment.rst b/docs/deployment.rst index 6758e6afd..f8bc7f94b 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -50,15 +50,102 @@ projects try to help with this problem. If your server doesn't run in the main thread, look at :func:`~asyncio.AbstractEventLoop.call_soon_threadsafe`. -Memory use ----------- +Memory usage +------------ + +.. _memory-usage: + +In most cases, memory usage of a WebSocket server is proportional to the +number of open connections. When a server handles thousands of connections, +memory usage can become a bottleneck. + +Memory usage of a single connection is the sum of: + +1. the baseline amount of memory ``websockets`` requires for each connection, +2. the amount of data held in buffers before the application processes it, +3. any additional memory allocated by the application itself. + +Baseline +........ + +Compression settings are the main factor affecting the baseline amount of +memory used by each connection. + +By default ``websockets`` maximizes compression rate at the expense of memory +usage. If memory usage is an issue, lowering compression settings can help: + +- Context Takeover is necessary to get good performance for almost all + applications. It should remain enabled. +- Window Bits is a trade-off between memory usage and compression rate. + It defaults to 15 and can be lowered. The default value isn't optimal + for small, repetitive messages which are typical of WebSocket servers. +- Memory Level is a trade-off between memory usage and compression speed. + It defaults to 8 and can be lowered. A lower memory level can actually + increase speed thanks to memory locality, even if the CPU does more work! + +See this :ref:`example ` for how to +configure compression settings. + +Here's how various compression settings affect memory usage of a single +connection on a 64-bit system, as well a benchmark_ of compressed size and +compression time for a corpus of small JSON documents. + ++-------------+-------------+--------------+--------------+------------------+------------------+ +| Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | ++=============+=============+==============+==============+==================+==================+ +| *default* | 15 | 8 | 325 KiB | +0% | +0% + ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 14 | 7 | 181 KiB | +1.5% | -5.3% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 13 | 6 | 110 KiB | +2.8% | -7.5% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 12 | 5 | 73 KiB | +4.4% | -18.9% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 11 | 4 | 55 KiB | +8.5% | -18.8% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| *disabled* | N/A | N/A | 22 KiB | N/A | N/A | ++-------------+-------------+--------------+--------------+------------------+------------------+ + +*Don't assume this example is representative! Compressed size and compression +time depend heavily on the kind of messages exchanged by the application!* + +You can run the same benchmark for your application by creating a list of +typical messages and passing it to the ``_benchmark`` function_. + +.. _benchmark: https://gist.github.com/aaugustin/fbea09ce8b5b30c4e56458eb081fe599 +.. _function: https://gist.github.com/aaugustin/fbea09ce8b5b30c4e56458eb081fe599#file-compression-py-L48-L144 + +This `blog post by Ilya Grigorik`_ provides more details about how compression +settings affect memory usage and how to optimize them. + +.. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ + +This `experiment by Peter Thorson`_ suggests Window Bits = 11, Memory Level = +4 as a sweet spot for optimizing memory usage. + +.. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html + +Buffers +....... + +Under normal circumstances, buffers are almost always empty. + +Under high load, if a server receives more messages than it can process, +bufferbloat can result in excessive memory use. + +By default ``websockets`` has generous limits. It is strongly recommended to +adapt them to your application. When you call :func:`~server.serve()`: + +- Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of + messages your application generates. +- Set ``max_queue`` (default: 32) to the maximum number of messages your + application expects to receive faster than it can process them. The queue + provides burst tolerance without slowing down the TCP connection. -In order to avoid excessive memory use caused by buffer bloat, it is strongly -recommended to :ref:`tune buffer sizes `. +Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: +64 KiB) to reduce the size of buffers for incoming and outgoing data. -Most importantly ``max_size`` should be lowered according to the expected size -of messages. It is also suggested to lower ``max_queue``, ``read_limit`` and -``write_limit`` if memory use is a concern. +The design document provides :ref:`more details about buffers`. Port sharing ------------ diff --git a/docs/design.rst b/docs/design.rst index 03f1ec163..c6097f724 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -513,21 +513,21 @@ Bufferbloat can happen at every level in the stack where there is a buffer. For each connection, the receiving side contains these buffers: - OS buffers: tuning them is an advanced optimization. -- :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64kB. +- :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64 KiB. You can set another limit by passing a ``read_limit`` keyword argument to :func:`~client.connect()` or :func:`~server.serve()`. - Incoming messages :class:`~collections.deque`: its size depends both on the size and the number of messages it contains. By default the maximum - UTF-8 encoded size is 1MB and the maximum number is 32. In the worst case, - after UTF-8 decoding, a single message could take up to 4MB of memory and - the overall memory consumption could reach 128MB. You should adjust these + UTF-8 encoded size is 1 MiB and the maximum number is 32. In the worst case, + after UTF-8 decoding, a single message could take up to 4 MiB of memory and + the overall memory consumption could reach 128 MiB. You should adjust these limits by setting the ``max_size`` and ``max_queue`` keyword arguments of :func:`~client.connect()` or :func:`~server.serve()` according to your application's requirements. For each connection, the sending side contains these buffers: -- :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64kB. +- :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64 KiB. You can set another limit by passing a ``write_limit`` keyword argument to :func:`~client.connect()` or :func:`~server.serve()`. - OS buffers: tuning them is an advanced optimization. diff --git a/docs/extensions.rst b/docs/extensions.rst index 3a5885009..7c282ffd0 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -20,6 +20,8 @@ Deflate extension by default. You can disable this with ``compression=None``. You can also configure the Per-Message Deflate extension explicitly if you want to customize its parameters. +.. _per-message-deflate-configuration-example: + Here's an example on the server side:: import websockets @@ -66,8 +68,8 @@ extensions before it's sent and after it's received. As a consequence writing an extension requires implementing several classes: -1. Extension Factory: it negotiates parameters and instanciates the extension. - Clients and servers require separate extension factories with distict APIs. +1. Extension Factory: it negotiates parameters and instantiates the extension. + Clients and servers require separate extension factories with distinct APIs. 2. Extension: it decodes incoming frames and encodes outgoing frames. If the extension is symmetrical, clients and servers can use the same class. diff --git a/docs/intro.rst b/docs/intro.rst index 376b7d9ca..dea152ab1 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -50,6 +50,8 @@ Here's a corresponding WebSocket client example. Using :func:`connect` as an asynchronous context manager ensures the connection is closed before exiting the ``hello`` coroutine. +.. _secure-server-example: + Secure example -------------- diff --git a/docs/security.rst b/docs/security.rst index f0d1deee3..e9acf0629 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -1,8 +1,17 @@ Security ======== +Encryption +---------- + +For production use, a server should require encrypted connections. + +See this example of :ref:`encrypting connections with TLS +`. + Memory use ---------- + .. warning:: An attacker who can open an arbitrary number of connections will be able @@ -10,27 +19,13 @@ Memory use by denial of service attacks, you must reject suspicious connections before they reach ``websockets``, typically in a reverse proxy. -The baseline memory use for a connection is about 20kB. - -The incoming bytes buffer, incoming messages queue and outgoing bytes buffer -contribute to the memory use of a connection. By default, each bytes buffer -takes up to 64kB and the messages queue up to 128MB, which is very large. - -Most applications use small messages. Setting ``max_size`` according to the -application's requirements is strongly recommended. See :ref:`buffers` for -details about tuning buffers. - -When compression is enabled, additional memory may be allocated for carrying -the compression context across messages, depending on the context takeover and -window size parameters. With the default configuration, this adds 320kB to the -memory use for a connection. +With the default settings, opening a connection uses 325 KiB of memory. -You can reduce this amount by configuring the ``PerMessageDeflate`` extension -with lower ``server_max_window_bits`` and ``client_max_window_bits`` values. -These parameters default is 15. Lowering them to 11 is a good choice. +Sending some highly compressed messages could use up to 128 MiB of memory +with an amplification factor of 1000 between network traffic and memory use. -Finally, memory consumed by your application code also counts towards the -memory use of a connection. +Configuring a server to :ref:`optimize memory usage ` will +improve security in addition to improving performance. Other limits ------------ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index ba30efd99..c2988ead5 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -14,10 +14,11 @@ cryptocurrency daemonize fractalideas iterable -kB keepalive +KiB lifecycle Lifecycle +MiB nginx permessage pong @@ -28,6 +29,7 @@ subprotocol subprotocols TLS Unparse +uple websocket WebSocket websockets diff --git a/performance/mem_client.py b/performance/mem_client.py new file mode 100644 index 000000000..890216edf --- /dev/null +++ b/performance/mem_client.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python + +import asyncio +import statistics +import tracemalloc + +import websockets +from websockets.extensions import permessage_deflate + + +CLIENTS = 10 +INTERVAL = 1 / 10 # seconds + +MEM_SIZE = [] + + +async def mem_client(client): + # Space out connections to make them sequential. + await asyncio.sleep(client * INTERVAL) + + tracemalloc.start() + + async with websockets.connect( + "ws://localhost:8765", + extensions=[ + permessage_deflate.ClientPerMessageDeflateFactory( + server_max_window_bits=10, + client_max_window_bits=10, + compress_settings={"memLevel": 3}, + ) + ], + ) as ws: + await ws.send("hello") + await ws.recv() + + await ws.send(b"hello") + await ws.recv() + + MEM_SIZE.append(tracemalloc.get_traced_memory()[0]) + tracemalloc.stop() + + # Hold connection open until the end of the test. + await asyncio.sleep(CLIENTS * INTERVAL) + + +asyncio.get_event_loop().run_until_complete( + asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) +) + +# First connection incurs non-representative setup costs. +del MEM_SIZE[0] + +print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB") +print(f"σ = {statistics.stdev(MEM_SIZE) / 1024:.1f} KiB") diff --git a/performance/mem_server.py b/performance/mem_server.py new file mode 100644 index 000000000..6c8cef2ec --- /dev/null +++ b/performance/mem_server.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +import asyncio +import signal +import statistics +import tracemalloc + +import websockets +from websockets.extensions import permessage_deflate + + +CLIENTS = 10 +INTERVAL = 1 / 10 # seconds + +MEM_SIZE = [] + + +async def handler(ws, path): + msg = await ws.recv() + await ws.send(msg) + + msg = await ws.recv() + await ws.send(msg) + + MEM_SIZE.append(tracemalloc.get_traced_memory()[0]) + tracemalloc.stop() + + tracemalloc.start() + + # Hold connection open until the end of the test. + await asyncio.sleep(CLIENTS * INTERVAL) + + +async def mem_server(stop): + async with websockets.serve( + handler, + "localhost", + 8765, + extensions=[ + permessage_deflate.ServerPerMessageDeflateFactory( + server_max_window_bits=10, + client_max_window_bits=10, + compress_settings={"memLevel": 3}, + ) + ], + ): + await stop + + +loop = asyncio.get_event_loop() + +stop = asyncio.Future() +loop.add_signal_handler(signal.SIGINT, stop.set_result, None) + +tracemalloc.start() + +loop.run_until_complete(mem_server(stop)) + +# First connection incurs non-representative setup costs. +del MEM_SIZE[0] + +print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB") +print(f"σ = {statistics.stdev(MEM_SIZE) / 1024:.1f} KiB") diff --git a/src/websockets/http.py b/src/websockets/http.py index 5e04e53bd..e28acac9f 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -182,9 +182,9 @@ async def read_line(stream): Return :class:`bytes` without CRLF. """ - # Security: this is bounded by the StreamReader's limit (default = 32kB). + # Security: this is bounded by the StreamReader's limit (default = 32 KiB). line = await stream.readline() - # Security: this guarantees header values are small (hard-coded = 4kB) + # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: raise ValueError("Line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 5c60348aa..981c0975c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -99,7 +99,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. The ``max_size`` parameter enforces the maximum size for incoming messages - in bytes. The default value is 1MB. ``None`` disables the limit. If a + in bytes. The default value is 1 MiB. ``None`` disables the limit. If a message larger than the maximum size is received, :meth:`recv()` will raise :exc:`~websockets.exceptions.ConnectionClosed` and the connection will be closed with status code 1009. @@ -117,17 +117,17 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): Since Python can use up to 4 bytes of memory to represent a single character, each websocket connection may use up to ``4 * max_size * max_queue`` bytes of memory to store incoming messages. By default, - this is 128MB. You may want to lower the limits, depending on your + this is 128 MiB. You may want to lower the limits, depending on your application's requirements. The ``read_limit`` argument sets the high-water limit of the buffer for incoming bytes. The low-water limit is half the high-water limit. The - default value is 64kB, half of asyncio's default (based on the current + default value is 64 KiB, half of asyncio's default (based on the current implementation of :class:`~asyncio.StreamReader`). The ``write_limit`` argument sets the high-water limit of the buffer for outgoing bytes. The low-water limit is a quarter of the high-water limit. - The default value is 64kB, equal to asyncio's default (based on the + The default value is 64 KiB, equal to asyncio's default (based on the current implementation of ``FlowControlMixin``). As soon as the HTTP request and response in the opening handshake are From 7d72dabd100b65bb05580f3e0163e3b7ce3dc787 Mon Sep 17 00:00:00 2001 From: q Date: Tue, 1 Jan 2019 09:55:26 +0800 Subject: [PATCH 0532/1539] Enable GNU Readline for interactive client --- src/websockets/__main__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 350fc06e8..078733912 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -152,6 +152,11 @@ def main(): ) sys.stderr.flush() + try: + import readline # noqa + except ImportError: # Windows has no `readline` normally + pass + # Parse command line arguments. parser = argparse.ArgumentParser( prog="python -m websockets", From 04336426f894374012da0933ee370d5c20abeeda Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Jan 2019 17:06:27 +0100 Subject: [PATCH 0533/1539] Indicate which functions are coroutines. Also fix indentation issues in the API docs. Thanks @cjerdonek for the idea and @njsmith for sphinxcontrib-trio. Fix #295 (assuming RTD builds properly). --- .readthedocs.yml | 7 +++++++ docs/api.rst | 34 +++++++++++++++++++--------------- docs/conf.py | 7 ++++++- src/websockets/server.py | 3 +-- 4 files changed, 33 insertions(+), 18 deletions(-) create mode 100644 .readthedocs.yml diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 000000000..e5e224afd --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,7 @@ +build: + image: latest + +python: + version: 3.6 + +requirements_file: docs/requirements.txt diff --git a/docs/api.rst b/docs/api.rst index e480604bb..ce6529d1d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,21 +32,24 @@ Server .. automodule:: websockets.server - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + :async: - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + :async: - .. autoclass:: WebSocketServer + .. autoclass:: WebSocketServer .. automethod:: close() .. automethod:: wait_closed() .. autoattribute:: sockets - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) .. automethod:: handshake(origins=None, available_extensions=None, available_subprotocols=None, extra_headers=None) .. automethod:: process_request(path, request_headers) + :async: .. automethod:: select_subprotocol(client_subprotocols, server_subprotocols) Client @@ -54,9 +57,10 @@ Client .. automodule:: websockets.client - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + :async: - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) .. automethod:: handshake(wsuri, origin=None, available_extensions=None, available_subprotocols=None, extra_headers=None) @@ -65,7 +69,7 @@ Shared .. automodule:: websockets.protocol - .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) + .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) .. automethod:: close(code=1000, reason='') .. automethod:: wait_closed() @@ -87,15 +91,15 @@ Per-Message Deflate Extension .. automodule:: websockets.extensions.permessage_deflate - .. autoclass:: ServerPerMessageDeflateFactory + .. autoclass:: ServerPerMessageDeflateFactory - .. autoclass:: ClientPerMessageDeflateFactory + .. autoclass:: ClientPerMessageDeflateFactory Exceptions .......... .. automodule:: websockets.exceptions - :members: + :members: Low-level --------- @@ -104,25 +108,25 @@ Opening handshake ................. .. automodule:: websockets.handshake - :members: + :members: Data transfer ............. .. automodule:: websockets.framing - :members: + :members: URI parser .......... .. automodule:: websockets.uri - :members: + :members: Utilities ......... .. automodule:: websockets.headers - :members: + :members: .. automodule:: websockets.http - :members: + :members: diff --git a/docs/conf.py b/docs/conf.py index 1a5448f7b..4ad4ad4b7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,12 @@ # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.viewcode'] +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx.ext.viewcode', + 'sphinxcontrib_trio', + ] # Spelling check needs an additional module that is not installed by default. # Add it only if spelling check is requested so docs can be generated without it. diff --git a/src/websockets/server.py b/src/websockets/server.py index 839b3c861..c0dc29dc3 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -395,8 +395,7 @@ def process_subprotocol(self, headers, available_subprotocols): return subprotocol - @staticmethod - def select_subprotocol(client_subprotocols, server_subprotocols): + def select_subprotocol(self, client_subprotocols, server_subprotocols): """ Pick a subprotocol among those offered by the client. From 76d739dfcf85b3739181c11dcbaab4d4b542e354 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Jan 2019 17:13:52 +0100 Subject: [PATCH 0534/1539] Add RTD requirements. (forgotten in previous commits) --- docs/requirements.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 docs/requirements.txt diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..954e8c755 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,3 @@ +sphinx +sphinxcontrib-spelling +sphinxcontrib-trio From 8fc78fee48d52bb3c690e925bad0825613319296 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Jan 2019 19:45:03 +0100 Subject: [PATCH 0535/1539] Send fragmented messages from async iterators. Fix #477. --- src/websockets/protocol.py | 35 ++++++++++++++--- tests/test_protocol.py | 80 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 5 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 981c0975c..8dacbf4ce 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -432,10 +432,11 @@ async def send(self, data): object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a binary frame. - It also accepts an iterable of strings or bytes-like objects. Each - item is treated as a message fragment and sent in its own frame. All - items must be of the same type, or else :meth:`send` will raise a - :exc:`TypeError` and the connection will be closed. + It also accepts an iterable or an asynchronous iterator of strings or + bytes-like objects. Each item is treated as a message fragment and + sent in its own frame. All items must be of the same type, or else + :meth:`send` will raise a :exc:`TypeError` and the connection will be + closed. It raises a :exc:`TypeError` for other inputs. @@ -482,7 +483,31 @@ async def send(self, data): # Fragmented message -- asynchronous iterator - # To be implemented after dropping support for Python 3.4. + elif isinstance(data, collections.abc.AsyncIterable): + # aiter_data = aiter(data) without aiter + aiter_data = type(data).__aiter__(data) + + # First fragment. + try: + # data = anext(aiter_data) without anext + data = await type(aiter_data).__anext__(aiter_data) + except StopAsyncIteration: + return + opcode, data = prepare_data(data) + await self.write_frame(False, opcode, data) + + # Other fragments. + async for data in aiter_data: + confirm_opcode, data = prepare_data(data) + if confirm_opcode != opcode: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(1011) + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) + + # Final fragment. + await self.write_frame(True, OP_CONT, b"") else: raise TypeError("data must be bytes, str, or iterable") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 70c2be0bd..7a8b0a69a 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -27,6 +27,27 @@ MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) +class async_iterable: + + # In Python ≥ 3.6, this can be simplified to: + + # async def async_iterable(iterable): + # for item in iterable: + # yield item + + def __init__(self, iterable): + self.iterator = iter(iterable) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iterator) + except StopIteration: + raise StopAsyncIteration + + class TransportMock(unittest.mock.Mock): """ Transport mock to control the protocol's inputs and outputs in tests. @@ -599,6 +620,65 @@ def test_send_iterable_mixed_type_error(self): (True, OP_CLOSE, serialize_close(1011, "")), ) + def test_send_async_iterable_text(self): + self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"]))) + self.assertFramesSent( + (False, OP_TEXT, "ca".encode("utf-8")), + (False, OP_CONT, "fé".encode("utf-8")), + (True, OP_CONT, "".encode("utf-8")), + ) + + def test_send_async_iterable_binary(self): + self.loop.run_until_complete(self.protocol.send(async_iterable([b"te", b"a"]))) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_async_iterable_binary_from_bytearray(self): + self.loop.run_until_complete( + self.protocol.send(async_iterable([bytearray(b"te"), bytearray(b"a")])) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_async_iterable_binary_from_memoryview(self): + self.loop.run_until_complete( + self.protocol.send(async_iterable([memoryview(b"te"), memoryview(b"a")])) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_async_iterable_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete( + self.protocol.send( + async_iterable([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) + ) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_empty_async_iterable(self): + self.loop.run_until_complete(self.protocol.send(async_iterable([]))) + self.assertNoFrameSent() + + def test_send_async_iterable_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send(async_iterable([42]))) + self.assertNoFrameSent() + + def test_send_async_iterable_mixed_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete( + self.protocol.send(async_iterable(["café", b"tea"])) + ) + self.assertFramesSent( + (False, OP_TEXT, "café".encode("utf-8")), + (True, OP_CLOSE, serialize_close(1011, "")), + ) + def test_send_on_closing_connection_local(self): close_task = self.half_close_connection_local() From 207518d813347b42d5f2fb9f50b09bd101016a24 Mon Sep 17 00:00:00 2001 From: Thierry Parmentelat Date: Fri, 18 Jan 2019 13:07:06 +0100 Subject: [PATCH 0536/1539] typo in documentation --- src/websockets/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8dacbf4ce..457e37b80 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -330,7 +330,7 @@ def closed(self): """ This property is ``True`` once the connection is closed. - Be aware that both :attr:`open` and :attr`closed` are ``False`` during + Be aware that both :attr:`open` and :attr:`closed` are ``False`` during the opening and closing sequences. """ From ec2e589b22146c394c2adfd47f5995db76ca1184 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Jan 2019 09:58:53 +0100 Subject: [PATCH 0537/1539] Update `make clean` after introducing the src dir. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 0863f8578..9fa5c2422 100644 --- a/Makefile +++ b/Makefile @@ -18,4 +18,4 @@ coverage: clean: find . -name '*.pyc' -o -name '*.so' -delete find . -name __pycache__ -delete - rm -rf .coverage build compliance/reports dist docs/_build htmlcov MANIFEST README websockets.egg-info + rm -rf .coverage build compliance/reports dist docs/_build htmlcov MANIFEST README src/websockets.egg-info From ed8d800304b0b2f0959060a3a83086131e208ed0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Jan 2019 22:25:30 +0100 Subject: [PATCH 0538/1539] Drop support for Python 3.5. --- .appveyor.yml | 4 ++-- .circleci/config.yml | 12 ------------ .travis.yml | 4 ++-- README.rst | 2 +- docs/changelog.rst | 2 +- setup.cfg | 2 +- setup.py | 7 +++---- tox.ini | 2 +- 8 files changed, 11 insertions(+), 24 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index 5109200b4..7954ee4be 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -5,8 +5,8 @@ branches: skip_branch_with_pr: true environment: -# websockets only works on Python >= 3.5. - CIBW_SKIP: cp27-* cp33-* cp34-* +# websockets only works on Python >= 3.6. + CIBW_SKIP: cp27-* cp33-* cp34-* cp35-* CIBW_TEST_COMMAND: python -W default -m unittest WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 diff --git a/.circleci/config.yml b/.circleci/config.yml index 5ec5b5103..8a7df9ac6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -11,15 +11,6 @@ jobs: - run: sudo pip install tox codecov - run: tox -e coverage,black,flake8,isort - run: codecov - py35: - docker: - - image: circleci/python:3.5 - steps: - # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - - checkout - - run: sudo pip install tox - - run: tox -e py35 py36: docker: - image: circleci/python:3.6 @@ -44,9 +35,6 @@ workflows: build: jobs: - main - - py35: - requires: - - main - py36: requires: - main diff --git a/.travis.yml b/.travis.yml index c0f11357e..030693759 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ env: global: - # websockets only works on Python >= 3.5. - - CIBW_SKIP="cp27-* cp33-* cp34-*" + # websockets only works on Python >= 3.6. + - CIBW_SKIP="cp27-* cp33-* cp34-* cp35-*" - CIBW_TEST_COMMAND="python3 -W default -m unittest" - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 diff --git a/README.rst b/README.rst index 572647a15..0da52524e 100644 --- a/README.rst +++ b/README.rst @@ -124,7 +124,7 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. * If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which - only works on Python 3. ``websockets`` requires Python ≥ 3.5. + only works on Python 3. ``websockets`` requires Python ≥ 3.6. What else? ---------- diff --git a/docs/changelog.rst b/docs/changelog.rst index 320300f64..b53080501 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,7 +10,7 @@ Changelog .. warning:: - **Version 8.0 drops compatibility with Python 3.4.** + **Version 8.0 drops compatibility with Python 3.4 and 3.5.** .. warning:: diff --git a/setup.cfg b/setup.cfg index 88b9b1a33..c306b2d4f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py35.py36.py37 +python-tag = py36.py37 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index 2956058a4..1fe71a4f0 100644 --- a/setup.py +++ b/setup.py @@ -14,8 +14,8 @@ py_version = sys.version_info[:2] -if py_version < (3, 5): - raise Exception("websockets requires Python >= 3.5.") +if py_version < (3, 6): + raise Exception("websockets requires Python >= 3.6.") packages = ['websockets', 'websockets/extensions'] @@ -47,7 +47,6 @@ 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', ], @@ -56,6 +55,6 @@ ext_modules=ext_modules, include_package_data=True, zip_safe=True, - python_requires='>=3.5', + python_requires='>=3.6', test_loader='unittest:TestLoader', ) diff --git a/tox.ini b/tox.ini index de0f285d0..238fcd649 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py35,py36,py37,coverage,black,flake8,isort +envlist = py36,py37,coverage,black,flake8,isort [testenv] commands = python -W default -m unittest {posargs} From 67434cf4d996c259d82b21bb3bcfd1ce0d19c74e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Jan 2019 09:29:46 +0100 Subject: [PATCH 0539/1539] =?UTF-8?q?Update=20documentation=20for=20Python?= =?UTF-8?q?=20=E2=89=A5=203.6.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.rst | 4 +- docs/cheatsheet.rst | 10 ++--- docs/deployment.rst | 7 --- docs/intro.rst | 87 ++------------------------------------ example/old_client.py | 24 ----------- example/old_server.py | 20 --------- example/old_shutdown.py | 29 ------------- src/websockets/client.py | 4 +- src/websockets/protocol.py | 3 +- src/websockets/server.py | 4 +- 10 files changed, 15 insertions(+), 177 deletions(-) delete mode 100755 example/old_client.py delete mode 100755 example/old_server.py delete mode 100755 example/old_shutdown.py diff --git a/README.rst b/README.rst index 0da52524e..ae47c7a48 100644 --- a/README.rst +++ b/README.rst @@ -36,7 +36,7 @@ Python with a focus on correctness and simplicity. Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. -Here's how a client sends and receives messages (Python ≥ 3.6): +Here's how a client sends and receives messages: .. copy-pasted because GitHub doesn't support the include directive @@ -55,7 +55,7 @@ Here's how a client sends and receives messages (Python ≥ 3.6): asyncio.get_event_loop().run_until_complete( hello('ws://localhost:8765')) -And here's an echo server (Python ≥ 3.6): +And here's an echo server: .. code:: python diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 3b8993a8c..15a731084 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -27,9 +27,8 @@ Server needed in general. * Create a server with :func:`~server.serve` which is similar to asyncio's - :meth:`~asyncio.AbstractEventLoop.create_server`. - - * On Python ≥ 3.5.1, you can also use it as an asynchronous context manager. + :meth:`~asyncio.AbstractEventLoop.create_server`. You can also use it as an + asynchronous context manager. * The server takes care of establishing connections, then lets the handler execute the application logic, and finally closes the connection after the @@ -43,9 +42,8 @@ Client ------ * Create a client with :func:`~client.connect` which is similar to asyncio's - :meth:`~asyncio.BaseEventLoop.create_connection`. - - * On Python ≥ 3.5.1, you can also use it as an asynchronous context manager. + :meth:`~asyncio.BaseEventLoop.create_connection`. You can also use it as an + asynchronous context manager. * For advanced customization, you may subclass :class:`~server.WebSocketClientProtocol` and pass either this subclass or diff --git a/docs/deployment.rst b/docs/deployment.rst index f8bc7f94b..b0c05dd73 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -37,13 +37,6 @@ Here's a full example (Unix-only): .. literalinclude:: ../example/shutdown.py :emphasize-lines: 13,17-19 -websockets supports asynchronous context managers on Python ≥ 3.5.1. ``async -for`` was introduced in Python 3.6. Here's the equivalent for older Python -versions: - -.. literalinclude:: ../example/old_shutdown.py - :emphasize-lines: 22-25 - It's more difficult to achieve the same effect on Windows. Some third-party projects try to help with this problem. diff --git a/docs/intro.rst b/docs/intro.rst index dea152ab1..389896ef4 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -6,17 +6,12 @@ Getting started Requirements ------------ -``websockets`` requires Python ≥ 3.5. +``websockets`` requires Python ≥ 3.6. You should use the latest version of Python if possible. If you're using an older version, be aware that for each minor version (3.x), only the latest bugfix release (3.x.y) is officially supported. -.. warning:: - - This documentation is written for Python ≥ 3.6. If you're using an older - Python version, you need to :ref:`adapt the code samples `. - Installation ------------ @@ -61,16 +56,13 @@ because they reduce the risk of interference by bad proxies. The WSS protocol is to WS what HTTPS is to HTTP: the connection is encrypted with TLS. WSS requires TLS certificates like HTTPS. -Here's how to adapt the server example to provide secure connections, using -APIs available in Python ≥ 3.6. - -Refer to the documentation of the :mod:`ssl` module for configuring the -context securely or adapting the code to older Python versions. +Here's how to adapt the server example to provide secure connections. See the +documentation of the :mod:`ssl` module for configuring the context securely. .. literalinclude:: ../example/secure_server.py :emphasize-lines: 19,23-24 -Here's how to adapt the client, also on Python ≥ 3.6. +Here's how to adapt the client. .. literalinclude:: ../example/secure_client.py :emphasize-lines: 10,15-16 @@ -137,18 +129,6 @@ messages received on the WebSocket connection. Iteration terminates when the client disconnects. -Asynchronous iteration was introduced in Python 3.6; here's the same code for -earlier Python versions:: - - async def consumer_handler(websocket, path): - while True: - message = await websocket.recv() - await consumer(message) - -:meth:`~protocol.WebSocketCommonProtocol.recv` raises a -:exc:`~exceptions.ConnectionClosed` exception when the client disconnects, -which breaks out of the ``while True`` loop. - Producer ........ @@ -226,62 +206,3 @@ One more thing... ``websockets`` provides an interactive client:: $ python -m websockets wss://echo.websocket.org/ - -.. _python-lt-36: - -Python < 3.6 ------------- - -This documentation takes advantage of several features that aren't available -in Python < 3.6: - -- Asynchronous context managers didn't work well until Python 3.5.1; -- Asynchronous iterators were added in Python 3.6; -- f-strings were introduced in Python 3.6 (this is unrelated to :mod:`asyncio` - and :mod:`websockets`). - -Here's how to adapt the basic server example. - -.. literalinclude:: ../example/old_server.py - -And here's the basic client example. - -.. literalinclude:: ../example/old_client.py - -Asynchronous context managers -............................. - -Asynchronous context managers were added in Python 3.5. However, -``websockets`` only supports them on Python ≥ 3.5.1, where -:func:`~asyncio.ensure_future` accepts any awaitable. - -If you're using Python < 3.5.1, instead of:: - - with websockets.connect(...) as client: - ... - -you must write:: - - client = yield from websockets.connect(...) - try: - ... - finally: - yield from client.close() - -Asynchronous iterators -...................... - -If you're using Python < 3.6, you must replace:: - - async for message in websocket: - ... - -with:: - - while True: - message = yield from websocket.recv() - ... - -The latter will always raise a :exc:`~exceptions.ConnectionClosed` exception -when the connection is closed, while the former will only raise that exception -if the connection terminates with an error. diff --git a/example/old_client.py b/example/old_client.py deleted file mode 100755 index be34f14be..000000000 --- a/example/old_client.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python - -# WS client example for old Python versions - -import asyncio -import websockets - -async def hello(): - websocket = await websockets.connect( - 'ws://localhost:8765/') - - try: - name = input("What's your name? ") - - await websocket.send(name) - print("> {}".format(name)) - - greeting = await websocket.recv() - print("< {}".format(greeting)) - - finally: - await websocket.close() - -asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/old_server.py b/example/old_server.py deleted file mode 100755 index 8c63e33e6..000000000 --- a/example/old_server.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python - -# WS server example for old Python versions - -import asyncio -import websockets - -async def hello(websocket, path): - name = await websocket.recv() - print("< {}".format(name)) - - greeting = "Hello {}!".format(name) - - await websocket.send(greeting) - print("> {}".format(greeting)) - -start_server = websockets.serve(hello, 'localhost', 8765) - -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() diff --git a/example/old_shutdown.py b/example/old_shutdown.py deleted file mode 100755 index 180da9059..000000000 --- a/example/old_shutdown.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import signal -import websockets - -async def echo(websocket, path): - while True: - try: - msg = await websocket.recv() - except websockets.ConnectionClosed: - break - else: - await websocket.send(msg) - -loop = asyncio.get_event_loop() - -# Create the server. -start_server = websockets.serve(echo, 'localhost', 8765) -server = loop.run_until_complete(start_server) - -# Run the server until receiving SIGTERM. -stop = asyncio.Future() -loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) -loop.run_until_complete(stop) - -# Shut down the server. -server.close() -loop.run_until_complete(server.wait_closed()) diff --git a/src/websockets/client.py b/src/websockets/client.py index 46dd1b447..5e504969b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -316,8 +316,8 @@ class Connect: :class:`WebSocketClientProtocol` which can then be used to send and receive messages. - On Python ≥ 3.5.1, :func:`connect` can be used as a asynchronous context - manager. In that case, the connection is closed when exiting the context. + :func:`connect` can also be used as a asynchronous context manager. In + that case, the connection is closed when exiting the context. :func:`connect` is a wrapper around the event loop's :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 457e37b80..1e0814dcf 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -57,8 +57,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): control frames automatically. It sends outgoing data frames and performs the closing handshake. - On Python ≥ 3.6, :class:`WebSocketCommonProtocol` instances support - asynchronous iteration:: + :class:`WebSocketCommonProtocol` supports asynchronous iteration:: async for message in websocket: await process(message) diff --git a/src/websockets/server.py b/src/websockets/server.py index c0dc29dc3..17b13aec2 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -676,8 +676,8 @@ class Serve: :meth:`~websockets.server.WebSocketServer.wait_closed` methods for terminating the server and cleaning up its resources. - On Python ≥ 3.5.1, :func:`serve` can also be used as an asynchronous context - manager. In this case, the server is shut down when exiting the context. + :func:`serve` can also be used as an asynchronous context manager. In + this case, the server is shut down when exiting the context. :func:`serve` is a wrapper around the event loop's :meth:`~asyncio.AbstractEventLoop.create_server` method. Internally, it From 17cb6949f40f84acc505c5b13f10837f5cb327e4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Jan 2019 09:55:52 +0100 Subject: [PATCH 0540/1539] =?UTF-8?q?Update=20tests=20for=20Python=20?= =?UTF-8?q?=E2=89=A5=203.6.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_client_server.py | 9 ++------- tests/test_protocol.py | 22 +++------------------- 2 files changed, 5 insertions(+), 26 deletions(-) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 6b80c7f6e..20cef5925 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1104,19 +1104,14 @@ class SSLClientServerTests(ClientServerTests): @property def server_context(self): - # Change to ssl.PROTOCOL_TLS_SERVER when dropping Python < 3.6. - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context.load_cert_chain(testcert) return ssl_context @property def client_context(self): - # Change to ssl.PROTOCOL_TLS_CLIENT when dropping Python < 3.6. - # Then remove verify_mode and check_hostname below. - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.load_verify_locations(testcert) - ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.check_hostname = True return ssl_context def start_server(self, **kwds): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7a8b0a69a..9e9d40393 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -27,25 +27,9 @@ MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) -class async_iterable: - - # In Python ≥ 3.6, this can be simplified to: - - # async def async_iterable(iterable): - # for item in iterable: - # yield item - - def __init__(self, iterable): - self.iterator = iter(iterable) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return next(self.iterator) - except StopIteration: - raise StopAsyncIteration +async def async_iterable(iterable): + for item in iterable: + yield item class TransportMock(unittest.mock.Mock): From 31106eb42c846e50e1956043e9f2564e398ccd6c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Jan 2019 22:30:44 +0100 Subject: [PATCH 0541/1539] Remove conditional code for Python < 3.6. --- setup.py | 3 - src/websockets/client.py | 19 +----- src/websockets/protocol.py | 25 +++++-- src/websockets/py36/__init__.py | 2 - src/websockets/py36/protocol.py | 20 ------ src/websockets/server.py | 23 +------ tests/py36/__init__.py | 0 tests/py36/_test_client_server.py | 105 ------------------------------ tests/test_client_server.py | 100 +++++++++++++++++++++++----- 9 files changed, 107 insertions(+), 190 deletions(-) delete mode 100644 src/websockets/py36/__init__.py delete mode 100644 src/websockets/py36/protocol.py delete mode 100644 tests/py36/__init__.py delete mode 100644 tests/py36/_test_client_server.py diff --git a/setup.py b/setup.py index 1fe71a4f0..d4fadb240 100644 --- a/setup.py +++ b/setup.py @@ -19,9 +19,6 @@ packages = ['websockets', 'websockets/extensions'] -if py_version >= (3, 6): - packages.append('websockets/py36') - ext_modules = [ setuptools.Extension( 'websockets.speedups', diff --git a/src/websockets/client.py b/src/websockets/client.py index 5e504969b..9babbb412 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -6,7 +6,6 @@ import asyncio import collections.abc import logging -import sys from .exceptions import ( InvalidHandshake, @@ -520,20 +519,4 @@ def __await__(self): return self.__await_impl__().__await__() -# We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future -# didn't accept arbitrary awaitables until Python 3.5.1. We don't define -# __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. -if sys.version_info[:3] < (3, 5, 1): # pragma: no cover - - del Connect.__aenter__ - del Connect.__aexit__ - del Connect.__await__ - - async def connect(*args, **kwds): - return Connect(*args, **kwds).__iter__() - - connect.__doc__ = Connect.__doc__ - -else: - - connect = Connect +connect = Connect diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1e0814dcf..f87f40086 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -15,7 +15,6 @@ import logging import random import struct -import sys from .exceptions import ( ConnectionClosed, @@ -347,6 +346,24 @@ async def wait_closed(self): """ await asyncio.shield(self.connection_lost_waiter) + async def __aiter__(self): + """ + Iterate on received messages. + + Exit normally when the connection is closed with code 1000 or 1001. + + Raise an exception in other cases. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosed as exc: + if exc.code == 1000 or exc.code == 1001: + return + else: + raise + async def recv(self): """ This coroutine receives the next message. @@ -1225,9 +1242,3 @@ def connection_lost(self, exc): # - it must never be canceled. self.connection_lost_waiter.set_result(None) super().connection_lost(exc) - - -if sys.version_info[:2] >= (3, 6): # pragma: no cover - from .py36.protocol import __aiter__ - - WebSocketCommonProtocol.__aiter__ = __aiter__ diff --git a/src/websockets/py36/__init__.py b/src/websockets/py36/__init__.py deleted file mode 100644 index b9211bf87..000000000 --- a/src/websockets/py36/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# This package contains code using async iteration added in Python 3.6. -# It cannot be imported on Python < 3.6 because it triggers syntax errors. diff --git a/src/websockets/py36/protocol.py b/src/websockets/py36/protocol.py deleted file mode 100644 index f0784de05..000000000 --- a/src/websockets/py36/protocol.py +++ /dev/null @@ -1,20 +0,0 @@ -from ..exceptions import ConnectionClosed - - -async def __aiter__(self): - """ - Iterate on received messages. - - Exit normally when the connection is closed with code 1000. - - Raise an exception in other cases. - - """ - try: - while True: - yield await self.recv() - except ConnectionClosed as exc: - if exc.code == 1000 or exc.code == 1001: - return - else: - raise diff --git a/src/websockets/server.py b/src/websockets/server.py index 17b13aec2..979fbcd1b 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -8,7 +8,6 @@ import email.utils import http import logging -import sys import warnings from .exceptions import ( @@ -861,6 +860,9 @@ def __await__(self): return self.__await_impl__().__await__() +serve = Serve + + def unix_serve(ws_handler, path, **kwargs): """ Similar to :func:`serve()`, but for listening on Unix sockets. @@ -874,22 +876,3 @@ def unix_serve(ws_handler, path, **kwargs): """ return serve(ws_handler, path=path, **kwargs) - - -# We can't define __await__ on Python < 3.5.1 because asyncio.ensure_future -# didn't accept arbitrary awaitables until Python 3.5.1. We don't define -# __aenter__ and __aexit__ either on Python < 3.5.1 to keep things simple. -if sys.version_info[:3] < (3, 5, 1): # pragma: no cover - - del Serve.__aenter__ - del Serve.__aexit__ - del Serve.__await__ - - async def serve(*args, **kwds): - return Serve(*args, **kwds).__iter__() - - serve.__doc__ = Serve.__doc__ - -else: - - serve = Serve diff --git a/tests/py36/__init__.py b/tests/py36/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/py36/_test_client_server.py b/tests/py36/_test_client_server.py deleted file mode 100644 index 10b135cc9..000000000 --- a/tests/py36/_test_client_server.py +++ /dev/null @@ -1,105 +0,0 @@ -# Tests containing Python 3.6+ syntax, extracted from test_client_server.py. - -import asyncio -import sys -import unittest - -from websockets.client import * -from websockets.exceptions import ConnectionClosed -from websockets.server import * - -from ..test_client_server import get_server_uri - - -# Fail at import time, not just at run time, to prevent test -# discovery. -if sys.version_info[:2] < (3, 6): # pragma: no cover - raise ImportError("Python 3.6+ only") - - -MESSAGES = ["3", "2", "1", "Fire!"] - - -class AsyncIteratorTests(unittest.TestCase): - - # This is a protocol-level feature, but since it's a high-level API, it is - # much easier to exercise at the client or server level. - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - - def test_iterate_on_messages(self): - async def handler(ws, path): - for message in MESSAGES: - await ws.send(message) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - - messages = [] - - async def run_client(): - nonlocal messages - async with connect(get_server_uri(server)) as ws: - async for message in ws: - messages.append(message) - - self.loop.run_until_complete(run_client()) - - self.assertEqual(messages, MESSAGES) - - server.close() - self.loop.run_until_complete(server.wait_closed()) - - def test_iterate_on_messages_going_away_exit_ok(self): - async def handler(ws, path): - for message in MESSAGES: - await ws.send(message) - await ws.close(1001) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - - messages = [] - - async def run_client(): - nonlocal messages - async with connect(get_server_uri(server)) as ws: - async for message in ws: - messages.append(message) - - self.loop.run_until_complete(run_client()) - - self.assertEqual(messages, MESSAGES) - - server.close() - self.loop.run_until_complete(server.wait_closed()) - - def test_iterate_on_messages_internal_error_exit_not_ok(self): - async def handler(ws, path): - for message in MESSAGES: - await ws.send(message) - await ws.close(1011) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - - messages = [] - - async def run_client(): - nonlocal messages - async with connect(get_server_uri(server)) as ws: - async for message in ws: - messages.append(message) - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(run_client()) - - self.assertEqual(messages, MESSAGES) - - server.close() - self.loop.run_until_complete(server.wait_closed()) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 20cef5925..cbac7a24c 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -7,7 +7,6 @@ import random import socket import ssl -import sys import tempfile import unittest import unittest.mock @@ -1306,10 +1305,6 @@ def setUp(self): def tearDown(self): self.loop.close() - # Asynchronous context managers are only enabled on Python ≥ 3.5.1. - @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" - ) def test_client(self): start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) @@ -1327,10 +1322,6 @@ async def run_client(): server.close() self.loop.run_until_complete(server.wait_closed()) - # Asynchronous context managers are only enabled on Python ≥ 3.5.1. - @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" - ) def test_server(self): async def run_server(): # Use serve as an asynchronous context manager. @@ -1342,10 +1333,6 @@ async def run_server(): self.loop.run_until_complete(run_server()) - # Asynchronous context managers are only enabled on Python ≥ 3.5.1. - @unittest.skipIf( - sys.version_info[:3] <= (3, 5, 0), "this test requires Python 3.5.1+" - ) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") def test_unix_server(self): async def run_server(path): @@ -1360,5 +1347,88 @@ async def run_server(path): self.loop.run_until_complete(run_server(path)) -if sys.version_info[:2] >= (3, 6): # pragma: no cover - from .py36._test_client_server import AsyncIteratorTests # noqa +class AsyncIteratorTests(unittest.TestCase): + + # This is a protocol-level feature, but since it's a high-level API, it is + # much easier to exercise at the client or server level. + + MESSAGES = ["3", "2", "1", "Fire!"] + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_iterate_on_messages(self): + async def handler(ws, path): + for message in self.MESSAGES: + await ws.send(message) + + start_server = serve(handler, "localhost", 0) + server = self.loop.run_until_complete(start_server) + + messages = [] + + async def run_client(): + nonlocal messages + async with connect(get_server_uri(server)) as ws: + async for message in ws: + messages.append(message) + + self.loop.run_until_complete(run_client()) + + self.assertEqual(messages, self.MESSAGES) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + + def test_iterate_on_messages_going_away_exit_ok(self): + async def handler(ws, path): + for message in self.MESSAGES: + await ws.send(message) + await ws.close(1001) + + start_server = serve(handler, "localhost", 0) + server = self.loop.run_until_complete(start_server) + + messages = [] + + async def run_client(): + nonlocal messages + async with connect(get_server_uri(server)) as ws: + async for message in ws: + messages.append(message) + + self.loop.run_until_complete(run_client()) + + self.assertEqual(messages, self.MESSAGES) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + + def test_iterate_on_messages_internal_error_exit_not_ok(self): + async def handler(ws, path): + for message in self.MESSAGES: + await ws.send(message) + await ws.close(1011) + + start_server = serve(handler, "localhost", 0) + server = self.loop.run_until_complete(start_server) + + messages = [] + + async def run_client(): + nonlocal messages + async with connect(get_server_uri(server)) as ws: + async for message in ws: + messages.append(message) + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(run_client()) + + self.assertEqual(messages, self.MESSAGES) + + server.close() + self.loop.run_until_complete(server.wait_closed()) From 5c08626717e29e55b3e0180050c1292833ceef44 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Jan 2019 10:01:22 +0100 Subject: [PATCH 0542/1539] Remove workarounds for bugs fixed in Python 3.6. --- src/websockets/protocol.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index f87f40086..ec80ecbd9 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -906,12 +906,6 @@ async def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): logger.debug("%s > %r", self.side, frame) frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions) - # Backport of https://github.com/python/asyncio/pull/280. - # Remove when dropping support for Python < 3.6. - if self.writer.transport is not None: # pragma: no cover - if self.writer_is_closing(): - await asyncio.sleep(0) - try: # drain() cannot be called concurrently by multiple coroutines: # http://bugs.python.org/issue29930. Remove this lock when no @@ -926,25 +920,6 @@ async def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): # with the correct code and reason. await self.ensure_open() - def writer_is_closing(self): - """ - Backport of https://github.com/python/asyncio/pull/291. - - Replace with ``self.writer.transport.is_closing()`` when dropping - support for Python < 3.6 and with ``self.writer.is_closing()`` when - https://bugs.python.org/issue31491 is fixed. - - """ - transport = self.writer.transport - try: - return transport.is_closing() - except AttributeError: # pragma: no cover - # This emulates what is_closing would return if it existed. - try: - return transport._closing - except AttributeError: - return transport._closed - async def write_close_frame(self, data=b""): """ Write a close frame if and only if the connection state is OPEN. From d836f8b107f040ce3877c21d19a029c6a534343b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Jan 2019 22:37:51 +0100 Subject: [PATCH 0543/1539] Take advantage of loop.create_future(). It's the best practice for creating futures in asyncio since Python 3.5.2. Fix #504. --- example/shutdown.py | 2 +- performance/mem_server.py | 2 +- src/websockets/__main__.py | 2 +- src/websockets/protocol.py | 8 ++++---- src/websockets/server.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/example/shutdown.py b/example/shutdown.py index dd3e8f6a4..6d75af192 100755 --- a/example/shutdown.py +++ b/example/shutdown.py @@ -15,7 +15,7 @@ async def echo_server(stop): loop = asyncio.get_event_loop() # The stop condition is set when receiving SIGTERM. -stop = asyncio.Future() +stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) # Run the server until the stop condition is met. diff --git a/performance/mem_server.py b/performance/mem_server.py index 6c8cef2ec..0a4a29f76 100644 --- a/performance/mem_server.py +++ b/performance/mem_server.py @@ -49,7 +49,7 @@ async def mem_server(stop): loop = asyncio.get_event_loop() -stop = asyncio.Future() +stop = loop.create_future() loop.add_signal_handler(signal.SIGINT, stop.set_result, None) tracemalloc.start() diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 078733912..4303ce22f 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -173,7 +173,7 @@ def main(): inputs = asyncio.Queue(loop=loop) # Create a stop condition when receiving SIGINT or SIGTERM. - stop = asyncio.Future(loop=loop) + stop = loop.create_future() # Schedule the task that will manage the connection. asyncio.ensure_future(run_client(args.uri, loop, inputs, stop), loop=loop) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index ec80ecbd9..3e02f8465 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -229,7 +229,7 @@ def __init__( # :meth:`connection_lost()` callback to a :class:`~asyncio.Future` # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are # translated by ``self.stream_reader``). - self.connection_lost_waiter = asyncio.Future(loop=loop) + self.connection_lost_waiter = loop.create_future() # Queue of received messages. self.messages = collections.deque() @@ -405,7 +405,7 @@ async def recv(self): # Wait until there's a message in the queue (if necessary) or the # connection is closed. while len(self.messages) <= 0: - pop_message_waiter = asyncio.Future(loop=self.loop) + pop_message_waiter = self.loop.create_future() self._pop_message_waiter = pop_message_waiter try: # If asyncio.wait() is canceled, it doesn't cancel @@ -609,7 +609,7 @@ async def ping(self, data=None): while data is None or data in self.pings: data = struct.pack("!I", random.getrandbits(32)) - self.pings[data] = asyncio.Future(loop=self.loop) + self.pings[data] = self.loop.create_future() await self.write_frame(True, OP_PING, data) @@ -692,7 +692,7 @@ async def transfer_data(self): # Wait until there's room in the queue (if necessary). while len(self.messages) >= self.max_queue: - self._put_message_waiter = asyncio.Future(loop=self.loop) + self._put_message_waiter = self.loop.create_future() try: await self._put_message_waiter finally: diff --git a/src/websockets/server.py b/src/websockets/server.py index 979fbcd1b..a59107b24 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -546,7 +546,7 @@ def __init__(self, loop): self.close_task = None # Completed when the server is closed and connections are terminated. - self.closed_waiter = asyncio.Future(loop=loop) + self.closed_waiter = loop.create_future() def wrap(self, server): """ From 123d471bc94998651808343db69391930900a1f5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Jan 2019 13:03:20 +0100 Subject: [PATCH 0544/1539] Update compliance testing scripts for Python 3.6. --- compliance/test_client.py | 31 +++++++++++++------------------ compliance/test_server.py | 14 ++++++-------- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/compliance/test_client.py b/compliance/test_client.py index 1c1d4416a..5fd0f4b4f 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -12,41 +12,36 @@ # logging.getLogger('websockets').setLevel(logging.DEBUG) -SERVER = 'ws://127.0.0.1:8642' -AGENT = 'websockets' +SERVER = "ws://127.0.0.1:8642" +AGENT = "websockets" async def get_case_count(server): - uri = server + '/getCaseCount' - ws = await websockets.connect(uri) - msg = await ws.recv() - await ws.close() + uri = f"{server}/getCaseCount" + async with websockets.connect(uri) as ws: + msg = ws.recv() return json.loads(msg) async def run_case(server, case, agent): - uri = server + '/runCase?case={}&agent={}'.format(case, agent) - ws = await websockets.connect(uri, max_size=2 ** 25, max_queue=1) - while True: - try: - msg = await ws.recv() + uri = f"{server}/runCase?case={case}&agent={agent}" + async with websockets.connect(uri, max_size=2 ** 25, max_queue=1) as ws: + async for msg in ws: await ws.send(msg) - except websockets.ConnectionClosed: - break async def update_reports(server, agent): - uri = server + '/updateReports?agent={}'.format(agent) - ws = await websockets.connect(uri) - await ws.close() + uri = f"{server}/updateReports?agent={agent}" + async with websockets.connect(uri): + pass async def run_tests(server, agent): cases = await get_case_count(server) for case in range(1, cases + 1): - print("Running test case {} out of {}".format(case, cases), end="\r") + print(f"Running test case {case} out of {cases}", end="\r") await run_case(server, case, agent) - print("Ran {} test cases ".format(cases)) + print(f"Ran {cases} test cases ") await update_reports(server, agent) diff --git a/compliance/test_server.py b/compliance/test_server.py index ac5990d16..8020f68d3 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -10,17 +10,15 @@ # logging.getLogger('websockets').setLevel(logging.DEBUG) +HOST, PORT = "127.0.0.1", 8642 + + async def echo(ws, path): - while True: - try: - msg = await ws.recv() - await ws.send(msg) - except websockets.ConnectionClosed: - break + async for msg in ws: + await ws.send(msg) -start_server = websockets.serve( - echo, '127.0.0.1', 8642, max_size=2 ** 25, max_queue=1) +start_server = websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1) try: asyncio.get_event_loop().run_until_complete(start_server) From 8c86ca1d1e9b8b947c5512f71a78afd5d36bd40c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Jan 2019 20:57:11 +0100 Subject: [PATCH 0545/1539] Switch to f-strings for string formatting. --- docs/conf.py | 2 +- src/websockets/__main__.py | 60 ++++++++----------- src/websockets/client.py | 24 ++++---- src/websockets/exceptions.py | 24 ++++---- .../extensions/permessage_deflate.py | 23 +++---- src/websockets/framing.py | 6 +- src/websockets/headers.py | 4 +- src/websockets/http.py | 9 +-- src/websockets/protocol.py | 4 +- src/websockets/server.py | 8 +-- src/websockets/uri.py | 2 +- tests/test_client_server.py | 4 +- tests/test_http.py | 6 +- tests/test_protocol.py | 6 +- 14 files changed, 77 insertions(+), 105 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 4ad4ad4b7..504656afc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -51,7 +51,7 @@ # General information about the project. project = 'websockets' -copyright = '2013-{}, Aymeric Augustin'.format(datetime.date.today().year) +copyright = f'2013-{datetime.date.today().year}, Aymeric Augustin and contributors' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 4303ce22f..f438750c9 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -53,36 +53,32 @@ def exit_from_event_loop_thread(loop, stop): def print_during_input(string): sys.stdout.write( - ( - # Save cursor position - "\N{ESC}7" - # Add a new line - "\N{LINE FEED}" - # Move cursor up - "\N{ESC}[A" - # Insert blank line, scroll last line down - "\N{ESC}[L" - # Print string in the inserted blank line - "{string}\N{LINE FEED}" - # Restore cursor position - "\N{ESC}8" - # Move cursor down - "\N{ESC}[B" - ).format(string=string) + # Save cursor position + "\N{ESC}7" + # Add a new line + "\N{LINE FEED}" + # Move cursor up + "\N{ESC}[A" + # Insert blank line, scroll last line down + "\N{ESC}[L" + # Print string in the inserted blank line + f"{string}\N{LINE FEED}" + # Restore cursor position + "\N{ESC}8" + # Move cursor down + "\N{ESC}[B" ) sys.stdout.flush() def print_over_input(string): sys.stdout.write( - ( - # Move cursor to beginning of line - "\N{CARRIAGE RETURN}" - # Delete current line - "\N{ESC}[K" - # Print string - "{string}\N{LINE FEED}" - ).format(string=string) + # Move cursor to beginning of line + "\N{CARRIAGE RETURN}" + # Delete current line + "\N{ESC}[K" + # Print string + f"{string}\N{LINE FEED}" ) sys.stdout.flush() @@ -91,11 +87,11 @@ async def run_client(uri, loop, inputs, stop): try: websocket = await websockets.connect(uri) except Exception as exc: - print_over_input("Failed to connect to {}: {}.".format(uri, exc)) + print_over_input(f"Failed to connect to {uri}: {exc}.") exit_from_event_loop_thread(loop, stop) return else: - print_during_input("Connected to {}.".format(uri)) + print_during_input(f"Connected to {uri}.") try: while True: @@ -130,9 +126,7 @@ async def run_client(uri, loop, inputs, stop): await websocket.close() close_status = format_close(websocket.close_code, websocket.close_reason) - print_over_input( - "Connection closed: {close_status}.".format(close_status=close_status) - ) + print_over_input(f"Connection closed: {close_status}.") exit_from_event_loop_thread(loop, stop) @@ -144,11 +138,9 @@ def main(): win_enable_vt100() except RuntimeError as exc: sys.stderr.write( - ( - "Unable to set terminal to VT100 mode. This is only " - "supported since Win10 anniversary update. Expect " - "weird symbols on the terminal.\nError: {exc!s}\n" - ).format(exc=exc) + f"Unable to set terminal to VT100 mode. This is only " + f"supported since Win10 anniversary update. Expect " + f"weird symbols on the terminal.\nError: {exc}\n" ) sys.stderr.flush() diff --git a/src/websockets/client.py b/src/websockets/client.py index 9babbb412..57fd33b25 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -52,7 +52,7 @@ def __init__( extensions=None, subprotocols=None, extra_headers=None, - **kwds + **kwds, ): self.origin = origin self.available_extensions = extensions @@ -73,7 +73,7 @@ def write_http_request(self, path, headers): # Since the path and headers only contain ASCII characters, # we can keep this simple. - request = "GET {path} HTTP/1.1\r\n".format(path=path) + request = f"GET {path} HTTP/1.1\r\n" request += str(headers) self.writer.write(request.encode()) @@ -170,9 +170,8 @@ def process_extensions(headers, available_extensions): # matched what the server sent. Fail the connection. else: raise NegotiationError( - "Unsupported extension: name = {}, params = {}".format( - name, response_params - ) + f"Unsupported extension: " + f"name = {name}, params = {response_params}" ) return accepted_extensions @@ -205,16 +204,13 @@ def process_subprotocol(headers, available_subprotocols): ) if len(parsed_header_values) > 1: - raise InvalidHandshake( - "Multiple subprotocols: {}".format(", ".join(parsed_header_values)) - ) + subprotocols = ", ".join(parsed_header_values) + raise InvalidHandshake(f"Multiple subprotocols: {subprotocols}") subprotocol = parsed_header_values[0] if subprotocol not in available_subprotocols: - raise NegotiationError( - "Unsupported subprotocol: {}".format(subprotocol) - ) + raise NegotiationError(f"Unsupported subprotocol: {subprotocol}") return subprotocol @@ -251,7 +247,7 @@ async def handshake( if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover request_headers["Host"] = wsuri.host else: - request_headers["Host"] = "{}:{}".format(wsuri.host, wsuri.port) + request_headers["Host"] = f"{wsuri.host}:{wsuri.port}" if wsuri.user_info: request_headers["Authorization"] = build_basic_auth(*wsuri.user_info) @@ -382,7 +378,7 @@ def __init__( extensions=None, subprotocols=None, extra_headers=None, - **kwds + **kwds, ): if loop is None: loop = asyncio.get_event_loop() @@ -417,7 +413,7 @@ def __init__( ClientPerMessageDeflateFactory(client_max_window_bits=True) ) elif compression is not None: - raise ValueError("Unsupported compression: {}".format(compression)) + raise ValueError(f"Unsupported compression: {compression}") self._create_protocol = create_protocol self._ping_interval = ping_interval diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 611e68188..50f3ab373 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -37,9 +37,7 @@ def __init__(self, status, headers, body=b""): self.status = status self.headers = headers self.body = body - message = "HTTP {}, {} headers, {} bytes".format( - status, len(headers), len(body) - ) + message = f"HTTP {status}, {len(headers)} headers, {len(body)} bytes" super().__init__(message) @@ -68,11 +66,11 @@ class InvalidHeader(InvalidHandshake): def __init__(self, name, value=None): if value is None: - message = "Missing {} header".format(name) + message = f"Missing {name} header" elif value == "": - message = "Empty {} header".format(name) + message = f"Empty {name} header" else: - message = "Invalid {} header: {}".format(name, value) + message = f"Invalid {name} header: {value}" super().__init__(message) @@ -83,7 +81,7 @@ class InvalidHeaderFormat(InvalidHeader): """ def __init__(self, name, error, string, pos): - error = "{} at {} in {}".format(error, pos, string) + error = f"{error} at {pos} in {string}" super().__init__(name, error) @@ -121,7 +119,7 @@ class InvalidStatusCode(InvalidHandshake): def __init__(self, status_code): self.status_code = status_code - message = "Status code not 101: {}".format(status_code) + message = f"Status code not 101: {status_code}" super().__init__(message) @@ -140,7 +138,7 @@ class InvalidParameterName(NegotiationError): def __init__(self, name): self.name = name - message = "Invalid parameter name: {}".format(name) + message = f"Invalid parameter name: {name}" super().__init__(message) @@ -153,7 +151,7 @@ class InvalidParameterValue(NegotiationError): def __init__(self, name, value): self.name = name self.value = value - message = "Invalid value for parameter {}: {}".format(name, value) + message = f"Invalid value for parameter {name}: {value}" super().__init__(message) @@ -165,7 +163,7 @@ class DuplicateParameter(NegotiationError): def __init__(self, name): self.name = name - message = "Duplicate parameter: {}".format(name) + message = f"Duplicate parameter: {name}" super().__init__(message) @@ -205,10 +203,10 @@ def format_close(code, reason): explanation = "private use" else: explanation = CLOSE_CODES.get(code, "unknown") - result = "code = {} ({}), ".format(code, explanation) + result = f"code = {code} ({explanation}), " if reason: - result += "reason = {}".format(reason) + result += f"reason = {reason}" else: result += "no reason" diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 167746021..2c2be49bd 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -186,7 +186,7 @@ def process_response_params(self, params, accepted_extensions): """ if any(other.name == self.name for other in accepted_extensions): - raise NegotiationError("Received duplicate {}".format(self.name)) + raise NegotiationError(f"Received duplicate {self.name}") # Request parameters are available in instance variables. @@ -339,7 +339,7 @@ def process_request_params(self, params, accepted_extensions): """ if any(other.name == self.name for other in accepted_extensions): - raise NegotiationError("Skipped duplicate {}".format(self.name)) + raise NegotiationError(f"Skipped duplicate {self.name}") # Load request parameters in local variables. ( @@ -491,16 +491,11 @@ def __init__( def __repr__(self): return ( - "PerMessageDeflate(" - "remote_no_context_takeover={}, " - "local_no_context_takeover={}, " - "remote_max_window_bits={}, " - "local_max_window_bits={})" - ).format( - self.remote_no_context_takeover, - self.local_no_context_takeover, - self.remote_max_window_bits, - self.local_max_window_bits, + f"PerMessageDeflate(" + f"remote_no_context_takeover={self.remote_no_context_takeover}, " + f"local_no_context_takeover={self.local_no_context_takeover}, " + f"remote_max_window_bits={self.remote_max_window_bits}, " + f"local_max_window_bits={self.local_max_window_bits})" ) def decode(self, frame, *, max_size=None): @@ -544,9 +539,7 @@ def decode(self, frame, *, max_size=None): data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: raise PayloadTooBig( - "Uncompressed payload length exceeds size limit (? > {} bytes)".format( - max_size - ) + f"Uncompressed payload length exceeds size limit (? > {max_size} bytes)" ) # Allow garbage collection of the decoder if it won't be reused. diff --git a/src/websockets/framing.py b/src/websockets/framing.py index c6b5564f5..0abe8f8db 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -117,9 +117,7 @@ async def read(cls, reader, *, mask, max_size=None, extensions=None): length, = struct.unpack("!Q", data) if max_size is not None and length > max_size: raise PayloadTooBig( - "Payload length exceeds size limit ({} > {} bytes)".format( - length, max_size - ) + f"Payload length exceeds size limit ({length} > {max_size} bytes)" ) if mask: mask_bits = await reader(4) @@ -231,7 +229,7 @@ def check(frame): if not frame.fin: raise WebSocketProtocolError("Fragmented control frame") else: - raise WebSocketProtocolError("Invalid opcode: {}".format(frame.opcode)) + raise WebSocketProtocolError(f"Invalid opcode: {frame.opcode}") def prepare_data(data): diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 6151b16db..73f11edce 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -302,7 +302,7 @@ def build_extension(name, parameters): [name] + [ # Quoted strings aren't necessary because values are always tokens. - name if value is None else "{}={}".format(name, value) + name if value is None else f"{name}={value}" for name, value in parameters ] ) @@ -347,6 +347,6 @@ def build_basic_auth(username, password): """ # https://tools.ietf.org/html/rfc7617#section-2 assert ":" not in username - user_pass = "{}:{}".format(username, password) + user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() return "Basic " + basic_credentials diff --git a/src/websockets/http.py b/src/websockets/http.py index e28acac9f..ab74614af 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -25,7 +25,7 @@ MAX_HEADERS = 256 MAX_LINE = 4096 -USER_AGENT = "Python/{} websockets/{}".format(sys.version[:3], websockets_version) +USER_AGENT = f"Python/{sys.version[:3]} websockets/{websockets_version}" # See https://tools.ietf.org/html/rfc7230#appendix-B. @@ -252,13 +252,10 @@ def __init__(self, *args, **kwargs): self.update(*args, **kwargs) def __str__(self): - return ( - "".join("{}: {}\r\n".format(key, value) for key, value in self._list) - + "\r\n" - ) + return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n" def __repr__(self): - return "{}({})".format(self.__class__.__name__, repr(self._list)) + return f"{self.__class__.__name__}({self._list!r})" def copy(self): copy = self.__class__() diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 3e02f8465..f0b126934 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -167,7 +167,7 @@ def __init__( write_limit=2 ** 16, loop=None, legacy_recv=False, - timeout=10 + timeout=10, ): # Backwards-compatibility: close_timeout used to be called timeout. # If both are specified, timeout is ignored. @@ -899,7 +899,7 @@ async def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): # Defensive assertion for protocol compliance. if self.state is not _expected_state: # pragma: no cover raise InvalidState( - "Cannot write to a WebSocket in the {} state".format(self.state.name) + f"Cannot write to a WebSocket in the {self.state.name} state" ) frame = Frame(fin, opcode, data) diff --git a/src/websockets/server.py b/src/websockets/server.py index a59107b24..007ebd725 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -57,7 +57,7 @@ def __init__( extra_headers=None, process_request=None, select_subprotocol=None, - **kwds + **kwds, ): # For backwards-compatibility with 6.0 or earlier. if origins is not None and "" in origins: @@ -225,7 +225,7 @@ def write_http_response(self, status, headers, body=None): # Since the status line and headers only contain ASCII characters, # we can keep this simple. - response = "HTTP/1.1 {status.value} {status.phrase}\r\n".format(status=status) + response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" response += str(headers) self.writer.write(response.encode()) @@ -775,7 +775,7 @@ def __init__( extra_headers=None, process_request=None, select_subprotocol=None, - **kwds + **kwds, ): # Backwards-compatibility: close_timeout used to be called timeout. # If both are specified, timeout is ignored. @@ -803,7 +803,7 @@ def __init__( ): extensions.append(ServerPerMessageDeflateFactory()) elif compression is not None: - raise ValueError("Unsupported compression: {}".format(compression)) + raise ValueError(f"Unsupported compression: {compression}") factory = lambda: create_protocol( ws_handler, diff --git a/src/websockets/uri.py b/src/websockets/uri.py index b6e1ad0ce..730adf54e 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -47,7 +47,7 @@ def parse_uri(uri): assert uri.fragment == "" assert uri.hostname is not None except AssertionError as exc: - raise InvalidURI("{} isn't a valid URI".format(uri)) from exc + raise InvalidURI(f"{uri} isn't a valid URI") from exc secure = uri.scheme == "wss" host = uri.hostname diff --git a/tests/test_client_server.py b/tests/test_client_server.py index cbac7a24c..fc88b3139 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -147,7 +147,7 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): if server_socket.family == socket.AF_INET6: # pragma: no cover host, port = server_socket.getsockname()[:2] # (no IPv6 on CI) - host = "[{}]".format(host) + host = f"[{host}]" elif server_socket.family == socket.AF_INET: host, port = server_socket.getsockname() elif server_socket.family == socket.AF_UNIX: @@ -156,7 +156,7 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): else: # pragma: no cover raise ValueError("Expected an IPv6, IPv4, or Unix socket") - return "{}://{}{}:{}{}".format(proto, user_info, host, port, resource_name) + return f"{proto}://{user_info}{host}:{port}{resource_name}" class UnauthorizedServerProtocol(WebSocketServerProtocol): diff --git a/tests/test_http.py b/tests/test_http.py index b28bed6ce..a3a8cd403 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -108,15 +108,13 @@ def setUp(self): def test_str(self): self.assertEqual( - str(self.headers), - "Connection: Upgrade\r\nServer: {}\r\n\r\n".format(USER_AGENT), + str(self.headers), f"Connection: Upgrade\r\nServer: {USER_AGENT}\r\n\r\n" ) def test_repr(self): self.assertEqual( repr(self.headers), - "Headers([('Connection', 'Upgrade'), " - "('Server', '{}')])".format(USER_AGENT), + f"Headers([('Connection', 'Upgrade'), " f"('Server', '{USER_AGENT}')])", ) def test_multiple_values_error_str(self): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 9e9d40393..154948e43 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -321,8 +321,8 @@ def assertCompletesWithin(self, min_time, max_time): yield t1 = self.loop.time() dt = t1 - t0 - self.assertGreaterEqual(dt, min_time, "Too fast: {} < {}".format(dt, min_time)) - self.assertLess(dt, max_time, "Too slow: {} >= {}".format(dt, max_time)) + self.assertGreaterEqual(dt, min_time, f"Too fast: {dt} < {min_time}") + self.assertLess(dt, max_time, f"Too slow: {dt} >= {max_time}") # Test public attributes. @@ -499,7 +499,7 @@ def test_recv_canceled(self): def test_recv_canceled_race_condition(self): recv = self.loop.create_task( - asyncio.wait_for(self.protocol.recv(), timeout=0.000001) + asyncio.wait_for(self.protocol.recv(), timeout=0.000_001) ) self.loop.call_soon( self.receive_frame, Frame(True, OP_TEXT, "café".encode("utf-8")) From 7bff03bbd74ab10a08237efbfb61cc831fa67de8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jan 2019 22:47:33 +0100 Subject: [PATCH 0546/1539] Get a clean, non-strict mypy run. --- .circleci/config.yml | 2 +- .gitignore | 1 + Makefile | 1 + src/websockets/protocol.py | 10 +++++----- src/websockets/server.py | 11 +++++++---- src/websockets/speedups.pyi | 1 + tox.ini | 6 +++++- 7 files changed, 21 insertions(+), 11 deletions(-) create mode 100644 src/websockets/speedups.pyi diff --git a/.circleci/config.yml b/.circleci/config.yml index 8a7df9ac6..a6c85d237 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,7 +9,7 @@ jobs: - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - checkout - run: sudo pip install tox codecov - - run: tox -e coverage,black,flake8,isort + - run: tox -e coverage,black,flake8,isort,mypy - run: codecov py36: docker: diff --git a/.gitignore b/.gitignore index 4dc1216b7..ef0d16520 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.pyc *.so .coverage +.mypy_cache .tox build/ compliance/reports/ diff --git a/Makefile b/Makefile index 9fa5c2422..f94a9103c 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ style: isort --recursive src tests black src tests flake8 src tests + mypy src test: python -W default -m unittest diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index f0b126934..5cc1bcc90 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -146,11 +146,11 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): """ - # There are only two differences between the client-side and the server- - # side behavior: masking the payload and closing the underlying TCP - # connection. Set is_client and side to pick a side. - is_client = None - side = "undefined" + # There are only two differences between the client-side and server-side + # behavior: masking the payload and closing the underlying TCP connection. + # Set is_client = True/False and side = "client"/"server" to pick a side. + is_client: bool + side: str = "undefined" def __init__( self, diff --git a/src/websockets/server.py b/src/websockets/server.py index 007ebd725..7fd32ba1e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -69,10 +69,8 @@ def __init__( self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers - if process_request is not None: - self.process_request = process_request - if select_subprotocol is not None: - self.select_subprotocol = select_subprotocol + self._process_request = process_request + self._select_subprotocol = select_subprotocol super().__init__(**kwds) def connection_made(self, transport): @@ -264,6 +262,8 @@ def process_request(self, path, request_headers): function. """ + if self._process_request is not None: + return self._process_request(path, request_headers) @staticmethod def process_origin(headers, origins=None): @@ -414,6 +414,9 @@ def select_subprotocol(self, client_subprotocols, server_subprotocols): :func:`serve` function. """ + if self._select_subprotocol is not None: + return self._select_subprotocol(client_subprotocols, server_subprotocols) + subprotocols = set(client_subprotocols) & set(server_subprotocols) if not subprotocols: return None diff --git a/src/websockets/speedups.pyi b/src/websockets/speedups.pyi new file mode 100644 index 000000000..821438a06 --- /dev/null +++ b/src/websockets/speedups.pyi @@ -0,0 +1 @@ +def apply_mask(data: bytes, mask: bytes) -> bytes: ... diff --git a/tox.ini b/tox.ini index 238fcd649..4d085f56c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36,py37,coverage,black,flake8,isort +envlist = py36,py37,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} @@ -22,3 +22,7 @@ deps = flake8 [testenv:isort] commands = isort --check-only --recursive src tests deps = isort + +[testenv:mypy] +commands = mypy src +deps = mypy From 94945fec3280c1c6d04aa4c9dc41a794f1ef1a42 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Jan 2019 21:00:06 +0100 Subject: [PATCH 0547/1539] Annotate source code with type hints. Fix a few minor bugs revealed by static typing. This required adjustments where the code wasn't statically typable, usally because a variable was reused with a different type. --- Makefile | 2 +- docs/changelog.rst | 2 + src/websockets/__main__.py | 79 ++-- src/websockets/client.py | 121 +++-- src/websockets/exceptions.py | 45 +- src/websockets/extensions/base.py | 88 ++-- .../extensions/permessage_deflate.py | 347 +++++++------- src/websockets/framing.py | 82 +++- src/websockets/handshake.py | 20 +- src/websockets/headers.py | 55 ++- src/websockets/http.py | 94 ++-- src/websockets/protocol.py | 237 +++++----- src/websockets/server.py | 243 ++++++---- src/websockets/uri.py | 50 +- src/websockets/utils.py | 2 +- tests/extensions/test_permessage_deflate.py | 438 +++++++++--------- tests/test_exceptions.py | 2 +- tests/test_http.py | 3 + tox.ini | 2 +- 19 files changed, 1088 insertions(+), 824 deletions(-) diff --git a/Makefile b/Makefile index f94a9103c..30dbfd9c1 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ style: isort --recursive src tests black src tests flake8 src tests - mypy src + mypy --strict src test: python -W default -m unittest diff --git a/docs/changelog.rst b/docs/changelog.rst index b53080501..1c4b1bc96 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -28,6 +28,8 @@ Also: * :func:`~client.connect()` handles redirects from the server during the handshake. +* Added type hints (:pep:`484`). + * Added documentation for extensions. * Documented how to optimize memory usage. diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index f438750c9..604caa5e4 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -4,42 +4,47 @@ import signal import sys import threading +from typing import Any, Set import websockets from websockets.exceptions import format_close -def win_enable_vt100(): - """ - Enable VT-100 for console output on Windows. +if sys.platform == "win32": - See also https://bugs.python.org/issue29059. + def win_enable_vt100() -> None: + """ + Enable VT-100 for console output on Windows. - """ - import ctypes + See also https://bugs.python.org/issue29059. - STD_OUTPUT_HANDLE = ctypes.c_uint(-11) - INVALID_HANDLE_VALUE = ctypes.c_uint(-1) - ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004 + """ + import ctypes - handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE) - if handle == INVALID_HANDLE_VALUE: - raise RuntimeError("Unable to obtain stdout handle") + STD_OUTPUT_HANDLE = ctypes.c_uint(-11) + INVALID_HANDLE_VALUE = ctypes.c_uint(-1) + ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004 - cur_mode = ctypes.c_uint() - if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0: - raise RuntimeError("Unable to query current console mode") + handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE) + if handle == INVALID_HANDLE_VALUE: + raise RuntimeError("Unable to obtain stdout handle") - # ctypes ints lack support for the required bit-OR operation. - # Temporarily convert to Py int, do the OR and convert back. - py_int_mode = int.from_bytes(cur_mode, sys.byteorder) - new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING) + cur_mode = ctypes.c_uint() + if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0: + raise RuntimeError("Unable to query current console mode") - if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0: - raise RuntimeError("Unable to set console mode") + # ctypes ints lack support for the required bit-OR operation. + # Temporarily convert to Py int, do the OR and convert back. + py_int_mode = int.from_bytes(cur_mode, sys.byteorder) + new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING) + if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0: + raise RuntimeError("Unable to set console mode") -def exit_from_event_loop_thread(loop, stop): + +def exit_from_event_loop_thread( + loop: asyncio.AbstractEventLoop, stop: asyncio.Future[None] +) -> None: loop.stop() if not stop.done(): # When exiting the thread that runs the event loop, raise @@ -51,7 +56,7 @@ def exit_from_event_loop_thread(loop, stop): os.kill(os.getpid(), ctrl_c) -def print_during_input(string): +def print_during_input(string: str) -> None: sys.stdout.write( # Save cursor position "\N{ESC}7" @@ -71,7 +76,7 @@ def print_during_input(string): sys.stdout.flush() -def print_over_input(string): +def print_over_input(string: str) -> None: sys.stdout.write( # Move cursor to beginning of line "\N{CARRIAGE RETURN}" @@ -83,7 +88,12 @@ def print_over_input(string): sys.stdout.flush() -async def run_client(uri, loop, inputs, stop): +async def run_client( + uri: str, + loop: asyncio.AbstractEventLoop, + inputs: asyncio.Queue[str], + stop: asyncio.Future[None], +) -> None: try: websocket = await websockets.connect(uri) except Exception as exc: @@ -95,8 +105,10 @@ async def run_client(uri, loop, inputs, stop): try: while True: - incoming = asyncio.ensure_future(websocket.recv()) - outgoing = asyncio.ensure_future(inputs.get()) + incoming: asyncio.Future[Any] = asyncio.ensure_future(websocket.recv()) + outgoing: asyncio.Future[Any] = asyncio.ensure_future(inputs.get()) + done: Set[asyncio.Future[Any]] + pending: Set[asyncio.Future[Any]] done, pending = await asyncio.wait( [incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED ) @@ -113,7 +125,10 @@ async def run_client(uri, loop, inputs, stop): except websockets.ConnectionClosed: break else: - print_during_input("< " + message) + if isinstance(message, str): + print_during_input("< " + message) + else: + print_during_input("< (binary) " + message.hex()) if outgoing in done: message = outgoing.result() @@ -131,9 +146,9 @@ async def run_client(uri, loop, inputs, stop): exit_from_event_loop_thread(loop, stop) -def main(): +def main() -> None: # If we're on Windows, enable VT100 terminal support. - if os.name == "nt": + if sys.platform == "win32": try: win_enable_vt100() except RuntimeError as exc: @@ -162,10 +177,10 @@ def main(): loop = asyncio.new_event_loop() # Create a queue of user inputs. There's no need to limit its size. - inputs = asyncio.Queue(loop=loop) + inputs: asyncio.Queue[str] = asyncio.Queue(loop=loop) # Create a stop condition when receiving SIGINT or SIGTERM. - stop = loop.create_future() + stop: asyncio.Future[None] = loop.create_future() # Schedule the task that will manage the connection. asyncio.ensure_future(run_client(args.uri, loop, inputs, stop), loop=loop) diff --git a/src/websockets/client.py b/src/websockets/client.py index 57fd33b25..40c5b0073 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -6,6 +6,8 @@ import asyncio import collections.abc import logging +from types import TracebackType +from typing import Any, Generator, List, Optional, Tuple, Type, cast from .exceptions import ( InvalidHandshake, @@ -14,18 +16,20 @@ NegotiationError, RedirectHandshake, ) +from .extensions.base import ClientExtensionFactory, Extension from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response from .headers import ( + ExtensionHeader, build_basic_auth, build_extension_list, build_subprotocol_list, parse_extension_list, parse_subprotocol_list, ) -from .http import USER_AGENT, Headers, read_response +from .http import USER_AGENT, Headers, HeadersLike, read_response from .protocol import WebSocketCommonProtocol -from .uri import parse_uri +from .uri import WebSocketURI, parse_uri __all__ = ["connect", "WebSocketClientProtocol"] @@ -48,19 +52,19 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__( self, *, - origin=None, - extensions=None, - subprotocols=None, - extra_headers=None, - **kwds, - ): + origin: Optional[str] = None, + extensions: Optional[List[ClientExtensionFactory]] = None, + subprotocols: Optional[List[str]] = None, + extra_headers: Optional[HeadersLike] = None, + **kwds: Any, + ) -> None: self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers super().__init__(**kwds) - def write_http_request(self, path, headers): + def write_http_request(self, path: str, headers: Headers) -> None: """ Write request line and headers to the HTTP request. @@ -78,7 +82,7 @@ def write_http_request(self, path, headers): self.writer.write(request.encode()) - async def read_http_response(self): + async def read_http_response(self) -> Tuple[int, Headers]: """ Read status line and headers from the HTTP response. @@ -103,7 +107,9 @@ async def read_http_response(self): return status_code, self.response_headers @staticmethod - def process_extensions(headers, available_extensions): + def process_extensions( + headers: Headers, available_extensions: Optional[List[ClientExtensionFactory]] + ) -> List[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -130,7 +136,7 @@ def process_extensions(headers, available_extensions): order of extensions, may be implemented by overriding this method. """ - accepted_extensions = [] + accepted_extensions: List[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") @@ -139,7 +145,7 @@ def process_extensions(headers, available_extensions): if available_extensions is None: raise InvalidHandshake("No extensions supported") - parsed_header_values = sum( + parsed_header_values: List[ExtensionHeader] = sum( [parse_extension_list(header_value) for header_value in header_values], [], ) @@ -177,7 +183,9 @@ def process_extensions(headers, available_extensions): return accepted_extensions @staticmethod - def process_subprotocol(headers, available_subprotocols): + def process_subprotocol( + headers: Headers, available_subprotocols: Optional[List[str]] + ) -> Optional[str]: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -186,7 +194,7 @@ def process_subprotocol(headers, available_subprotocols): Return the selected subprotocol. """ - subprotocol = None + subprotocol: Optional[str] = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -195,7 +203,7 @@ def process_subprotocol(headers, available_subprotocols): if available_subprotocols is None: raise InvalidHandshake("No subprotocols supported") - parsed_header_values = sum( + parsed_header_values: List[str] = sum( [ parse_subprotocol_list(header_value) for header_value in header_values @@ -216,12 +224,12 @@ def process_subprotocol(headers, available_subprotocols): async def handshake( self, - wsuri, - origin=None, - available_extensions=None, - available_subprotocols=None, - extra_headers=None, - ): + wsuri: WebSocketURI, + origin: Optional[str] = None, + available_extensions: Optional[List[ClientExtensionFactory]] = None, + available_subprotocols: Optional[List[str]] = None, + extra_headers: Optional[HeadersLike] = None, + ) -> None: """ Perform the client side of the opening handshake. @@ -359,26 +367,26 @@ class Connect: def __init__( self, - uri, + uri: str, *, - create_protocol=None, - ping_interval=20, - ping_timeout=20, - close_timeout=None, - max_size=2 ** 20, - max_queue=2 ** 5, - read_limit=2 ** 16, - write_limit=2 ** 16, - loop=None, - legacy_recv=False, - klass=WebSocketClientProtocol, - timeout=10, - compression="deflate", - origin=None, - extensions=None, - subprotocols=None, - extra_headers=None, - **kwds, + create_protocol: Optional[Type[WebSocketClientProtocol]] = None, + ping_interval: float = 20, + ping_timeout: float = 20, + close_timeout: Optional[float] = None, + max_size: int = 2 ** 20, + max_queue: int = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, + loop: Optional[asyncio.AbstractEventLoop] = None, + legacy_recv: bool = False, + klass: Type[WebSocketClientProtocol] = WebSocketClientProtocol, + timeout: float = 10, + compression: Optional[str] = "deflate", + origin: Optional[str] = None, + extensions: Optional[List[ClientExtensionFactory]] = None, + subprotocols: Optional[List[str]] = None, + extra_headers: Optional[HeadersLike] = None, + **kwds: Any, ): if loop is None: loop = asyncio.get_event_loop() @@ -434,7 +442,9 @@ def __init__( self._extra_headers = extra_headers self._kwds = kwds - def _creating_connection(self): + async def _creating_connection( + self + ) -> Tuple[asyncio.Transport, WebSocketClientProtocol]: if self._wsuri.secure: self._kwds.setdefault("ssl", True) @@ -457,6 +467,8 @@ def _creating_connection(self): extra_headers=self._extra_headers, ) + host: Optional[str] + port: Optional[int] if self._kwds.get("sock") is None: host, port = self._wsuri.host, self._wsuri.port else: @@ -467,19 +479,30 @@ def _creating_connection(self): self._origin = self._origin # This is a coroutine object. - return self._loop.create_connection(factory, host, port, **self._kwds) + # https://github.com/python/typeshed/pull/2756 + transport, protocol = await self._loop.create_connection( # type: ignore + factory, host, port, **self._kwds + ) + transport = cast(asyncio.Transport, transport) + protocol = cast(WebSocketClientProtocol, protocol) + return transport, protocol @asyncio.coroutine - def __iter__(self): - return self.__await_impl__() + def __iter__(self) -> Generator[Any, None, WebSocketClientProtocol]: + return (yield from self.__await__()) - async def __aenter__(self): + async def __aenter__(self) -> WebSocketClientProtocol: return await self - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: await self.ws_client.close() - async def __await_impl__(self): + async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): transport, protocol = await self._creating_connection() @@ -508,7 +531,7 @@ async def __await_impl__(self): self.ws_client = protocol return protocol - def __await__(self): + def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # __await__() must return a type that I don't know how to obtain except # by calling __await__() on the return value of an async function. # I'm not finding a better way to take advantage of PEP 492. diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 50f3ab373..9999527ef 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -1,3 +1,15 @@ +import http +from typing import TYPE_CHECKING, Any, Optional + +from .http import Headers, HeadersLike + + +if TYPE_CHECKING: # pragma: no cover + from .uri import WebSocketURI +else: + WebSocketURI = Any + + __all__ = [ "AbortHandshake", "ConnectionClosed", @@ -33,11 +45,13 @@ class AbortHandshake(InvalidHandshake): """ - def __init__(self, status, headers, body=b""): + def __init__( + self, status: http.HTTPStatus, headers: HeadersLike, body: bytes = b"" + ) -> None: self.status = status - self.headers = headers + self.headers = Headers(headers) self.body = body - message = f"HTTP {status}, {len(headers)} headers, {len(body)} bytes" + message = f"HTTP {status}, {len(self.headers)} headers, {len(body)} bytes" super().__init__(message) @@ -47,7 +61,7 @@ class RedirectHandshake(InvalidHandshake): """ - def __init__(self, wsuri): + def __init__(self, wsuri: WebSocketURI) -> None: self.wsuri = wsuri @@ -64,7 +78,7 @@ class InvalidHeader(InvalidHandshake): """ - def __init__(self, name, value=None): + def __init__(self, name: str, value: Optional[str] = None) -> None: if value is None: message = f"Missing {name} header" elif value == "": @@ -80,7 +94,7 @@ class InvalidHeaderFormat(InvalidHeader): """ - def __init__(self, name, error, string, pos): + def __init__(self, name: str, error: str, string: str, pos: int) -> None: error = f"{error} at {pos} in {string}" super().__init__(name, error) @@ -105,7 +119,7 @@ class InvalidOrigin(InvalidHeader): """ - def __init__(self, origin): + def __init__(self, origin: Optional[str]) -> None: super().__init__("Origin", origin) @@ -117,7 +131,7 @@ class InvalidStatusCode(InvalidHandshake): """ - def __init__(self, status_code): + def __init__(self, status_code: int) -> None: self.status_code = status_code message = f"Status code not 101: {status_code}" super().__init__(message) @@ -136,7 +150,7 @@ class InvalidParameterName(NegotiationError): """ - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name message = f"Invalid parameter name: {name}" super().__init__(message) @@ -148,7 +162,7 @@ class InvalidParameterValue(NegotiationError): """ - def __init__(self, name, value): + def __init__(self, name: str, value: Optional[str]) -> None: self.name = name self.value = value message = f"Invalid value for parameter {name}: {value}" @@ -161,7 +175,7 @@ class DuplicateParameter(NegotiationError): """ - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name message = f"Duplicate parameter: {name}" super().__init__(message) @@ -191,7 +205,7 @@ class InvalidState(Exception): } -def format_close(code, reason): +def format_close(code: int, reason: str) -> str: """ Display a human-readable version of the close code and reason. @@ -222,7 +236,7 @@ class ConnectionClosed(InvalidState): """ - def __init__(self, code, reason): + def __init__(self, code: int, reason: str) -> None: self.code = code self.reason = reason message = "WebSocket connection is closed: " @@ -236,6 +250,11 @@ class InvalidURI(Exception): """ + def __init__(self, uri: str) -> None: + self.uri = uri + message = "{} isn't a valid URI".format(uri) + super().__init__(message) + class PayloadTooBig(Exception): """ diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index cf3f9a2ec..707e9317a 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -6,6 +6,46 @@ """ +from typing import List, Optional, Tuple + +from ..framing import Frame +from ..headers import ExtensionParameters + + +__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] + + +class Extension: + """ + Abstract class for extensions. + + """ + + @property + def name(self) -> str: + """ + Extension identifier. + + """ + + def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: + """ + Decode an incoming frame. + + The ``frame`` parameter and the return value are + :class:`~websockets.framing.Frame` instances. + + """ + + def encode(self, frame: Frame) -> Frame: + """ + Encode an outgoing frame. + + The ``frame`` parameter and the return value are + :class:`~websockets.framing.Frame` instances. + + """ + class ClientExtensionFactory: """ @@ -14,13 +54,13 @@ class ClientExtensionFactory: """ @property - def name(self): + def name(self) -> str: """ Extension identifier. """ - def get_request_params(self): + def get_request_params(self) -> ExtensionParameters: """ Build request parameters. @@ -28,7 +68,9 @@ def get_request_params(self): """ - def process_response_params(self, params, accepted_extensions): + def process_response_params( + self, params: ExtensionParameters, accepted_extensions: List[Extension] + ) -> Extension: """ Process response parameters received from the server. @@ -51,13 +93,15 @@ class ServerExtensionFactory: """ @property - def name(self): + def name(self) -> str: """ Extension identifier. """ - def process_request_params(self, params, accepted_extensions): + def process_request_params( + self, params: ExtensionParameters, accepted_extensions: List[Extension] + ) -> Tuple[ExtensionParameters, Extension]: """ Process request parameters received from the client. @@ -74,37 +118,3 @@ def process_request_params(self, params, accepted_extensions): :exc:`~websockets.exceptions.NegotiationError`. """ - - -class Extension: - """ - Abstract class for extensions. - - """ - - @property - def name(self): - """ - Extension identifier. - - """ - - def decode(self, frame, *, max_size=None): - """ - Decode an incoming frame. - - The ``frame`` parameter and the return value are - :class:`~websockets.framing.Frame` instances. - - - - """ - - def encode(self, frame): - """ - Encode an outgoing frame. - - The ``frame`` parameter and the return value are - :class:`~websockets.framing.Frame` instances. - - """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 2c2be49bd..93698a363 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -5,6 +5,7 @@ """ import zlib +from typing import Any, Dict, List, Optional, Tuple, Union from ..exceptions import ( DuplicateParameter, @@ -13,13 +14,15 @@ NegotiationError, PayloadTooBig, ) -from ..framing import CTRL_OPCODES, OP_CONT +from ..framing import CTRL_OPCODES, OP_CONT, Frame +from ..headers import ExtensionParameters +from .base import ClientExtensionFactory, Extension, ServerExtensionFactory __all__ = [ + "PerMessageDeflate", "ClientPerMessageDeflateFactory", "ServerPerMessageDeflateFactory", - "PerMessageDeflate", ] _EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff" @@ -27,17 +30,156 @@ _MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)] +class PerMessageDeflate(Extension): + """ + Per-Message Deflate extension. + + """ + + name = "permessage-deflate" + + def __init__( + self, + remote_no_context_takeover: bool, + local_no_context_takeover: bool, + remote_max_window_bits: int, + local_max_window_bits: int, + compress_settings: Optional[Dict[Any, Any]] = None, + ): + """ + Configure the Per-Message Deflate extension. + + """ + if compress_settings is None: + compress_settings = {} + + assert remote_no_context_takeover in [False, True] + assert local_no_context_takeover in [False, True] + assert 8 <= remote_max_window_bits <= 15 + assert 8 <= local_max_window_bits <= 15 + assert "wbits" not in compress_settings + + self.remote_no_context_takeover = remote_no_context_takeover + self.local_no_context_takeover = local_no_context_takeover + self.remote_max_window_bits = remote_max_window_bits + self.local_max_window_bits = local_max_window_bits + self.compress_settings = compress_settings + + if not self.remote_no_context_takeover: + self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) + + if not self.local_no_context_takeover: + self.encoder = zlib.compressobj( + wbits=-self.local_max_window_bits, **self.compress_settings + ) + + # To handle continuation frames properly, we must keep track of + # whether that initial frame was encoded. + self.decode_cont_data = False + # There's no need for self.encode_cont_data because we always encode + # outgoing frames, so it would always be True. + + def __repr__(self) -> str: + return ( + f"PerMessageDeflate(" + f"remote_no_context_takeover={self.remote_no_context_takeover}, " + f"local_no_context_takeover={self.local_no_context_takeover}, " + f"remote_max_window_bits={self.remote_max_window_bits}, " + f"local_max_window_bits={self.local_max_window_bits})" + ) + + def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: + """ + Decode an incoming frame. + + """ + # Skip control frames. + if frame.opcode in CTRL_OPCODES: + return frame + + # Handle continuation data frames: + # - skip if the initial data frame wasn't encoded + # - reset "decode continuation data" flag if it's a final frame + if frame.opcode == OP_CONT: + if not self.decode_cont_data: + return frame + if frame.fin: + self.decode_cont_data = False + + # Handle text and binary data frames: + # - skip if the frame isn't encoded + # - set "decode continuation data" flag if it's a non-final frame + else: + if not frame.rsv1: + return frame + if not frame.fin: # frame.rsv1 is True at this point + self.decode_cont_data = True + + # Re-initialize per-message decoder. + if self.remote_no_context_takeover: + self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) + + # Uncompress compressed frames. Protect against zip bombs by + # preventing zlib from decompressing more than max_length bytes + # (except when the limit is disabled with max_size = None). + data = frame.data + if frame.fin: + data += _EMPTY_UNCOMPRESSED_BLOCK + max_length = 0 if max_size is None else max_size + data = self.decoder.decompress(data, max_length) + if self.decoder.unconsumed_tail: + raise PayloadTooBig( + f"Uncompressed payload length exceeds size limit (? > {max_size} bytes)" + ) + + # Allow garbage collection of the decoder if it won't be reused. + if frame.fin and self.remote_no_context_takeover: + del self.decoder + + return frame._replace(data=data, rsv1=False) + + def encode(self, frame: Frame) -> Frame: + """ + Encode an outgoing frame. + + """ + # Skip control frames. + if frame.opcode in CTRL_OPCODES: + return frame + + # Since we always encode and never fragment messages, there's no logic + # similar to decode() here at this time. + + if frame.opcode != OP_CONT: + # Re-initialize per-message decoder. + if self.local_no_context_takeover: + self.encoder = zlib.compressobj( + wbits=-self.local_max_window_bits, **self.compress_settings + ) + + # Compress data frames. + data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) + if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): + data = data[:-4] + + # Allow garbage collection of the encoder if it won't be reused. + if frame.fin and self.local_no_context_takeover: + del self.encoder + + return frame._replace(data=data, rsv1=True) + + def _build_parameters( - server_no_context_takeover, - client_no_context_takeover, - server_max_window_bits, - client_max_window_bits, -): + server_no_context_takeover: bool, + client_no_context_takeover: bool, + server_max_window_bits: Optional[int], + client_max_window_bits: Optional[Union[int, bool]], +) -> ExtensionParameters: """ Build a list of ``(name, value)`` pairs for some compression parameters. """ - params = [] + params: ExtensionParameters = [] if server_no_context_takeover: params.append(("server_no_context_takeover", None)) if client_no_context_takeover: @@ -51,7 +193,9 @@ def _build_parameters( return params -def _extract_parameters(params, *, is_server): +def _extract_parameters( + params: ExtensionParameters, *, is_server: bool +) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -59,10 +203,10 @@ def _extract_parameters(params, *, is_server): without a value. This is only allow in handshake requests. """ - server_no_context_takeover = False - client_no_context_takeover = False - server_max_window_bits = None - client_max_window_bits = None + server_no_context_takeover: bool = False + client_no_context_takeover: bool = False + server_max_window_bits: Optional[int] = None + client_max_window_bits: Optional[Union[int, bool]] = None for name, value in params: @@ -111,7 +255,7 @@ def _extract_parameters(params, *, is_server): ) -class ClientPerMessageDeflateFactory: +class ClientPerMessageDeflateFactory(ClientExtensionFactory): """ Client-side extension factory for Per-Message Deflate extension. @@ -136,11 +280,11 @@ class ClientPerMessageDeflateFactory: def __init__( self, - server_no_context_takeover=False, - client_no_context_takeover=False, - server_max_window_bits=None, - client_max_window_bits=None, - compress_settings=None, + server_no_context_takeover: bool = False, + client_no_context_takeover: bool = False, + server_max_window_bits: Optional[int] = None, + client_max_window_bits: Optional[Union[int, bool]] = None, + compress_settings: Optional[Dict[Any, Any]] = None, ): """ Configure the Per-Message Deflate extension factory. @@ -166,7 +310,7 @@ def __init__( self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings - def get_request_params(self): + def get_request_params(self) -> ExtensionParameters: """ Build request parameters. @@ -178,7 +322,11 @@ def get_request_params(self): self.client_max_window_bits, ) - def process_response_params(self, params, accepted_extensions): + def process_response_params( + self, + params: List[Tuple[str, Optional[str]]], + accepted_extensions: List["Extension"], + ) -> PerMessageDeflate: """ Process response parameters. @@ -280,7 +428,7 @@ def process_response_params(self, params, accepted_extensions): ) -class ServerPerMessageDeflateFactory: +class ServerPerMessageDeflateFactory(ServerExtensionFactory): """ Server-side extension factory for the Per-Message Deflate extension. @@ -305,11 +453,11 @@ class ServerPerMessageDeflateFactory: def __init__( self, - server_no_context_takeover=False, - client_no_context_takeover=False, - server_max_window_bits=None, - client_max_window_bits=None, - compress_settings=None, + server_no_context_takeover: bool = False, + client_no_context_takeover: bool = False, + server_max_window_bits: Optional[int] = None, + client_max_window_bits: Optional[int] = None, + compress_settings: Optional[Dict[Any, Any]] = None, ): """ Configure the Per-Message Deflate extension factory. @@ -331,7 +479,11 @@ def __init__( self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings - def process_request_params(self, params, accepted_extensions): + def process_request_params( + self, + params: List[Tuple[str, Optional[str]]], + accepted_extensions: List["Extension"], + ) -> Tuple[ExtensionParameters, PerMessageDeflate]: """ Process request parameters. @@ -438,142 +590,3 @@ def process_request_params(self, params, accepted_extensions): self.compress_settings, ), ) - - -class PerMessageDeflate: - """ - Per-Message Deflate extension. - - """ - - name = "permessage-deflate" - - def __init__( - self, - remote_no_context_takeover, - local_no_context_takeover, - remote_max_window_bits, - local_max_window_bits, - compress_settings=None, - ): - """ - Configure the Per-Message Deflate extension. - - """ - if compress_settings is None: - compress_settings = {} - - assert remote_no_context_takeover in [False, True] - assert local_no_context_takeover in [False, True] - assert 8 <= remote_max_window_bits <= 15 - assert 8 <= local_max_window_bits <= 15 - assert "wbits" not in compress_settings - - self.remote_no_context_takeover = remote_no_context_takeover - self.local_no_context_takeover = local_no_context_takeover - self.remote_max_window_bits = remote_max_window_bits - self.local_max_window_bits = local_max_window_bits - self.compress_settings = compress_settings - - if not self.remote_no_context_takeover: - self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) - - if not self.local_no_context_takeover: - self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, **self.compress_settings - ) - - # To handle continuation frames properly, we must keep track of - # whether that initial frame was encoded. - self.decode_cont_data = False - # There's no need for self.encode_cont_data because we always encode - # outgoing frames, so it would always be True. - - def __repr__(self): - return ( - f"PerMessageDeflate(" - f"remote_no_context_takeover={self.remote_no_context_takeover}, " - f"local_no_context_takeover={self.local_no_context_takeover}, " - f"remote_max_window_bits={self.remote_max_window_bits}, " - f"local_max_window_bits={self.local_max_window_bits})" - ) - - def decode(self, frame, *, max_size=None): - """ - Decode an incoming frame. - - """ - # Skip control frames. - if frame.opcode in CTRL_OPCODES: - return frame - - # Handle continuation data frames: - # - skip if the initial data frame wasn't encoded - # - reset "decode continuation data" flag if it's a final frame - if frame.opcode == OP_CONT: - if not self.decode_cont_data: - return frame - if frame.fin: - self.decode_cont_data = False - - # Handle text and binary data frames: - # - skip if the frame isn't encoded - # - set "decode continuation data" flag if it's a non-final frame - else: - if not frame.rsv1: - return frame - if not frame.fin: # frame.rsv1 is True at this point - self.decode_cont_data = True - - # Re-initialize per-message decoder. - if self.remote_no_context_takeover: - self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) - - # Uncompress compressed frames. Protect against zip bombs by - # preventing zlib from decompressing more than max_length bytes - # (except when the limit is disabled with max_size = None). - data = frame.data - if frame.fin: - data += _EMPTY_UNCOMPRESSED_BLOCK - max_length = 0 if max_size is None else max_size - data = self.decoder.decompress(data, max_length) - if self.decoder.unconsumed_tail: - raise PayloadTooBig( - f"Uncompressed payload length exceeds size limit (? > {max_size} bytes)" - ) - - # Allow garbage collection of the decoder if it won't be reused. - if frame.fin and self.remote_no_context_takeover: - self.decoder = None - - return frame._replace(data=data, rsv1=False) - - def encode(self, frame): - """ - Encode an outgoing frame. - - """ - # Skip control frames. - if frame.opcode in CTRL_OPCODES: - return frame - - # Since we always encode and never fragment messages, there's no logic - # similar to decode() here at this time. - - if frame.opcode != OP_CONT: - # Re-initialize per-message decoder. - if self.local_no_context_takeover: - self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, **self.compress_settings - ) - - # Compress data frames. - data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): - data = data[:-4] - - # Allow garbage collection of the encoder if it won't be reused. - if frame.fin and self.local_no_context_takeover: - self.encoder = None - - return frame._replace(data=data, rsv1=True) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 0abe8f8db..8eb1a79bd 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -9,14 +9,29 @@ """ -import collections import io import random import struct +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + NamedTuple, + Optional, + Tuple, + Union, +) from .exceptions import PayloadTooBig, WebSocketProtocolError +if TYPE_CHECKING: # pragma: no cover + from .extensions.base import Extension +else: + Extension = Any + try: from .speedups import apply_mask except ImportError: # pragma: no cover @@ -46,8 +61,24 @@ # Using a list optimizes `code in EXTERNAL_CLOSE_CODES`. EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011] -FrameData = collections.namedtuple( - "FrameData", ["fin", "opcode", "data", "rsv1", "rsv2", "rsv3"] + +Data = Union[str, bytes] + + +# Switch to class-based syntax when dropping support for Python < 3.6. + +# Convert to a dataclass when dropping support for Python < 3.7. + +FrameData = NamedTuple( + "FrameData", + [ + ("fin", bool), + ("opcode", int), + ("data", bytes), + ("rsv1", bool), + ("rsv2", bool), + ("rsv3", bool), + ], ) @@ -68,11 +99,26 @@ class Frame(FrameData): """ - def __new__(cls, fin, opcode, data, rsv1=False, rsv2=False, rsv3=False): + def __new__( + cls, + fin: bool, + opcode: int, + data: bytes, + rsv1: bool = False, + rsv2: bool = False, + rsv3: bool = False, + ) -> "Frame": return FrameData.__new__(cls, fin, opcode, data, rsv1, rsv2, rsv3) @classmethod - async def read(cls, reader, *, mask, max_size=None, extensions=None): + async def read( + cls, + reader: Callable[[int], Awaitable[bytes]], + *, + mask: bool, + max_size: Optional[int] = None, + extensions: Optional[List[Extension]] = None, + ) -> "Frame": """ Read a WebSocket frame and return a :class:`Frame` object. @@ -138,7 +184,13 @@ async def read(cls, reader, *, mask, max_size=None, extensions=None): return frame - def write(frame, writer, *, mask, extensions=None): + def write( + frame, + writer: Callable[[bytes], Any], + *, + mask: bool, + extensions: Optional[List[Extension]] = None, + ) -> None: """ Write a WebSocket frame. @@ -207,7 +259,7 @@ def write(frame, writer, *, mask, extensions=None): # send frames concurrently from multiple coroutines. writer(output.getvalue()) - def check(frame): + def check(frame) -> None: """ Check that this frame contains acceptable values. @@ -232,7 +284,7 @@ def check(frame): raise WebSocketProtocolError(f"Invalid opcode: {frame.opcode}") -def prepare_data(data): +def prepare_data(data: Data) -> Tuple[int, bytes]: """ Convert a string or byte-like object to an opcode and a bytes-like object. @@ -249,7 +301,7 @@ def prepare_data(data): """ if isinstance(data, str): return OP_TEXT, data.encode("utf-8") - elif isinstance(data, collections.abc.ByteString): + elif isinstance(data, (bytes, bytearray)): return OP_BINARY, data elif isinstance(data, memoryview): if data.c_contiguous: @@ -260,11 +312,11 @@ def prepare_data(data): raise TypeError("data must be bytes-like or str") -def encode_data(data): +def encode_data(data: Data) -> bytes: """ Convert a string or byte-like object to bytes. - This function is designed for ping and pong frames. + This function is designed for ping and pon g frames. If ``data`` is a :class:`str`, return a :class:`bytes` object encoding ``data`` in UTF-8. @@ -276,7 +328,7 @@ def encode_data(data): """ if isinstance(data, str): return data.encode("utf-8") - elif isinstance(data, collections.abc.ByteString): + elif isinstance(data, (bytes, bytearray)): return bytes(data) elif isinstance(data, memoryview): return data.tobytes() @@ -284,7 +336,7 @@ def encode_data(data): raise TypeError("data must be bytes-like or str") -def parse_close(data): +def parse_close(data: bytes) -> Tuple[int, str]: """ Parse the data in a close frame. @@ -308,7 +360,7 @@ def parse_close(data): raise WebSocketProtocolError("Close frame too short") -def serialize_close(code, reason): +def serialize_close(code: int, reason: str) -> bytes: """ Serialize the data for a close frame. @@ -319,7 +371,7 @@ def serialize_close(code, reason): return struct.pack("!H", code) + reason.encode("utf-8") -def check_close(code): +def check_close(code: int) -> None: """ Check the close code for a close frame. diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index e6bd61fab..f04d81d59 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -38,7 +38,7 @@ from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade from .headers import parse_connection, parse_upgrade -from .http import MultipleValuesError +from .http import Headers, MultipleValuesError __all__ = ["build_request", "check_request", "build_response", "check_response"] @@ -46,7 +46,7 @@ GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" -def build_request(headers): +def build_request(headers: Headers) -> str: """ Build a handshake request to send to the server. @@ -62,7 +62,7 @@ def build_request(headers): return key -def check_request(headers): +def check_request(headers: Headers) -> str: """ Check a handshake request received from the client. @@ -83,14 +83,14 @@ def check_request(headers): ) if not any(value.lower() == "upgrade" for value in connection): - raise InvalidUpgrade("Connection", connection) + raise InvalidUpgrade("Connection", ", ".join(connection)) upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): - raise InvalidUpgrade("Upgrade", upgrade) + raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) try: s_w_key = headers["Sec-WebSocket-Key"] @@ -123,7 +123,7 @@ def check_request(headers): return s_w_key -def build_response(headers, key): +def build_response(headers: Headers, key: str) -> None: """ Build a handshake response to send to the client. @@ -135,7 +135,7 @@ def build_response(headers, key): headers["Sec-WebSocket-Accept"] = accept(key) -def check_response(headers, key): +def check_response(headers: Headers, key: str) -> None: """ Check a handshake response received from the server. @@ -156,14 +156,14 @@ def check_response(headers, key): ) if not any(value.lower() == "upgrade" for value in connection): - raise InvalidUpgrade("Connection", connection) + raise InvalidUpgrade("Connection", " ".join(connection)) upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): - raise InvalidUpgrade("Upgrade", upgrade) + raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) try: s_w_accept = headers["Sec-WebSocket-Accept"] @@ -178,6 +178,6 @@ def check_response(headers, key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) -def accept(key): +def accept(key: str) -> str: sha1 = hashlib.sha1((key + GUID).encode()).digest() return base64.b64encode(sha1).decode() diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 73f11edce..e2addf4c5 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -9,6 +9,7 @@ import base64 import re +from typing import Callable, List, Optional, Tuple, TypeVar from .exceptions import InvalidHeaderFormat @@ -23,12 +24,19 @@ ] +T = TypeVar("T") + +ExtensionParameter = Tuple[str, Optional[str]] +ExtensionParameters = List[ExtensionParameter] +ExtensionHeader = Tuple[str, ExtensionParameters] +SubprotocolHeader = str + # To avoid a dependency on a parsing library, we implement manually the ABNF # described in https://tools.ietf.org/html/rfc6455#section-9.1 with the # definitions from https://tools.ietf.org/html/rfc7230#appendix-B. -def peek_ahead(string, pos): +def peek_ahead(string: str, pos: int) -> Optional[str]: """ Return the next character from ``string`` at the given position. @@ -43,7 +51,7 @@ def peek_ahead(string, pos): _OWS_re = re.compile(r"[\t ]*") -def parse_OWS(string, pos): +def parse_OWS(string: str, pos: int) -> int: """ Parse optional whitespace from ``string`` at the given position. @@ -54,13 +62,14 @@ def parse_OWS(string, pos): """ # There's always a match, possibly empty, whose content doesn't matter. match = _OWS_re.match(string, pos) + assert match is not None return match.end() _token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") -def parse_token(string, pos, header_name): +def parse_token(string: str, pos: int, header_name: str) -> Tuple[str, int]: """ Parse a token from ``string`` at the given position. @@ -83,7 +92,7 @@ def parse_token(string, pos, header_name): _unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") -def parse_quoted_string(string, pos, header_name): +def parse_quoted_string(string: str, pos: int, header_name: str) -> Tuple[str, int]: """ Parse a quoted string from ``string`` at the given position. @@ -100,7 +109,12 @@ def parse_quoted_string(string, pos, header_name): return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() -def parse_list(parse_item, string, pos, header_name): +def parse_list( + parse_item: Callable[[str, int, str], Tuple[T, int]], + string: str, + pos: int, + header_name: str, +) -> List[T]: """ Parse a comma-separated list from ``string`` at the given position. @@ -162,7 +176,7 @@ def parse_list(parse_item, string, pos, header_name): return items -def parse_connection(string): +def parse_connection(string: str) -> List[str]: """ Parse a ``Connection`` header. @@ -179,7 +193,7 @@ def parse_connection(string): ) -def parse_protocol(string, pos, header_name): +def parse_protocol(string: str, pos: int, header_name: str) -> Tuple[str, int]: """ Parse a protocol from ``string`` at the given position. @@ -196,7 +210,7 @@ def parse_protocol(string, pos, header_name): return match.group(), match.end() -def parse_upgrade(string): +def parse_upgrade(string: str) -> List[str]: """ Parse an ``Upgrade`` header. @@ -208,7 +222,9 @@ def parse_upgrade(string): return parse_list(parse_protocol, string, 0, "Upgrade") -def parse_extension_param(string, pos, header_name): +def parse_extension_param( + string: str, pos: int, header_name: str +) -> Tuple[ExtensionParameter, int]: """ Parse a single extension parameter from ``string`` at the given position. @@ -220,7 +236,8 @@ def parse_extension_param(string, pos, header_name): # Extract parameter name. name, pos = parse_token(string, pos, header_name) pos = parse_OWS(string, pos) - # Extract parameter string, if there is one. + # Extract parameter value, if there is one. + value: Optional[str] = None if peek_ahead(string, pos) == "=": pos = parse_OWS(string, pos + 1) if peek_ahead(string, pos) == '"': @@ -238,13 +255,13 @@ def parse_extension_param(string, pos, header_name): else: value, pos = parse_token(string, pos, header_name) pos = parse_OWS(string, pos) - else: - value = None return (name, value), pos -def parse_extension(string, pos, header_name): +def parse_extension( + string: str, pos: int, header_name: str +) -> Tuple[ExtensionHeader, int]: """ Parse an extension definition from ``string`` at the given position. @@ -266,7 +283,7 @@ def parse_extension(string, pos, header_name): return (name, parameters), pos -def parse_extension_list(string): +def parse_extension_list(string: str) -> List[ExtensionHeader]: """ Parse a ``Sec-WebSocket-Extensions`` header. @@ -291,7 +308,7 @@ def parse_extension_list(string): return parse_list(parse_extension, string, 0, "Sec-WebSocket-Extensions") -def build_extension(name, parameters): +def build_extension(name: str, parameters: ExtensionParameters) -> str: """ Build an extension definition. @@ -308,7 +325,7 @@ def build_extension(name, parameters): ) -def build_extension_list(extensions): +def build_extension_list(extensions: List[ExtensionHeader]) -> str: """ Unparse a ``Sec-WebSocket-Extensions`` header. @@ -320,7 +337,7 @@ def build_extension_list(extensions): ) -def parse_subprotocol_list(string): +def parse_subprotocol_list(string: str) -> List[SubprotocolHeader]: """ Parse a ``Sec-WebSocket-Protocol`` header. @@ -330,7 +347,7 @@ def parse_subprotocol_list(string): return parse_list(parse_token, string, 0, "Sec-WebSocket-Protocol") -def build_subprotocol_list(protocols): +def build_subprotocol_list(protocols: List[SubprotocolHeader]) -> str: """ Unparse a ``Sec-WebSocket-Protocol`` header. @@ -340,7 +357,7 @@ def build_subprotocol_list(protocols): return ", ".join(protocols) -def build_basic_auth(username, password): +def build_basic_auth(username: str, password: str) -> str: """ Build an Authorization header for HTTP Basic Auth. diff --git a/src/websockets/http.py b/src/websockets/http.py index ab74614af..f0c58061d 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -7,9 +7,20 @@ """ -import collections.abc +import asyncio import re import sys +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Tuple, + Union, +) from .version import version as websockets_version @@ -48,7 +59,7 @@ _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") -async def read_request(stream): +async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]: """ Read an HTTP/1.1 GET request from ``stream``. @@ -77,20 +88,20 @@ async def read_request(stream): request_line = await read_line(stream) # This may raise "ValueError: not enough values to unpack" - method, path, version = request_line.split(b" ", 2) + method, raw_path, version = request_line.split(b" ", 2) if method != b"GET": raise ValueError("Unsupported HTTP method: %r" % method) if version != b"HTTP/1.1": raise ValueError("Unsupported HTTP version: %r" % version) - path = path.decode("ascii", "surrogateescape") + path = raw_path.decode("ascii", "surrogateescape") headers = await read_headers(stream) return path, headers -async def read_response(stream): +async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, "Headers"]: """ Read an HTTP/1.1 response from ``stream``. @@ -117,24 +128,24 @@ async def read_response(stream): status_line = await read_line(stream) # This may raise "ValueError: not enough values to unpack" - version, status_code, reason = status_line.split(b" ", 2) + version, raw_status_code, raw_reason = status_line.split(b" ", 2) if version != b"HTTP/1.1": raise ValueError("Unsupported HTTP version: %r" % version) # This may raise "ValueError: invalid literal for int() with base 10" - status_code = int(status_code) + status_code = int(raw_status_code) if not 100 <= status_code < 1000: raise ValueError("Unsupported HTTP status code: %d" % status_code) - if not _value_re.fullmatch(reason): - raise ValueError("Invalid HTTP reason phrase: %r" % reason) - reason = reason.decode() + if not _value_re.fullmatch(raw_reason): + raise ValueError("Invalid HTTP reason phrase: %r" % raw_reason) + reason = raw_reason.decode() headers = await read_headers(stream) return status_code, reason, headers -async def read_headers(stream): +async def read_headers(stream: asyncio.StreamReader) -> "Headers": """ Read HTTP headers from ``stream``. @@ -156,15 +167,15 @@ async def read_headers(stream): break # This may raise "ValueError: not enough values to unpack" - name, value = line.split(b":", 1) - if not _token_re.fullmatch(name): - raise ValueError("Invalid HTTP header name: %r" % name) - value = value.strip(b" \t") - if not _value_re.fullmatch(value): - raise ValueError("Invalid HTTP header value: %r" % value) - - name = name.decode("ascii") # guaranteed to be ASCII at this point - value = value.decode("ascii", "surrogateescape") + raw_name, raw_value = line.split(b":", 1) + if not _token_re.fullmatch(raw_name): + raise ValueError("Invalid HTTP header name: %r" % raw_name) + raw_value = raw_value.strip(b" \t") + if not _value_re.fullmatch(raw_value): + raise ValueError("Invalid HTTP header value: %r" % raw_value) + + name = raw_name.decode("ascii") # guaranteed to be ASCII at this point + value = raw_value.decode("ascii", "surrogateescape") headers[name] = value else: @@ -173,7 +184,7 @@ async def read_headers(stream): return headers -async def read_line(stream): +async def read_line(stream: asyncio.StreamReader) -> bytes: """ Read a single line from ``stream``. @@ -199,14 +210,14 @@ class MultipleValuesError(LookupError): """ - def __str__(self): + def __str__(self) -> str: # Implement the same logic as KeyError_str in Objects/exceptions.c. if len(self.args) == 1: return repr(self.args[0]) return super().__str__() -class Headers(collections.abc.MutableMapping): +class Headers(MutableMapping[str, str]): """ Data structure for working with HTTP headers efficiently. @@ -245,19 +256,19 @@ class Headers(collections.abc.MutableMapping): __slots__ = ["_dict", "_list"] - def __init__(self, *args, **kwargs): - self._dict = {} - self._list = [] + def __init__(self, *args: Any, **kwargs: str) -> None: + self._dict: Dict[str, List[str]] = {} + self._list: List[Tuple[str, str]] = [] # MutableMapping.update calls __setitem__ for each (name, value) pair. self.update(*args, **kwargs) - def __str__(self): + def __str__(self) -> str: return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n" - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self._list!r})" - def copy(self): + def copy(self) -> "Headers": copy = self.__class__() copy._dict = self._dict.copy() copy._list = self._list.copy() @@ -265,40 +276,40 @@ def copy(self): # Collection methods - def __contains__(self, key): - return key.lower() in self._dict + def __contains__(self, key: object) -> bool: + return isinstance(key, str) and key.lower() in self._dict - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._dict) - def __len__(self): + def __len__(self) -> int: return len(self._dict) # MutableMapping methods - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: value = self._dict[key.lower()] if len(value) == 1: return value[0] else: raise MultipleValuesError(key) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: str) -> None: self._dict.setdefault(key.lower(), []).append(value) self._list.append((key, value)) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: key_lower = key.lower() self._dict.__delitem__(key_lower) # This is inefficent. Fortunately deleting HTTP headers is uncommon. self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, Headers): return NotImplemented return self._list == other._list - def clear(self): + def clear(self) -> None: """ Remove all headers. @@ -308,16 +319,19 @@ def clear(self): # Methods for handling multiple values - def get_all(self, key): + def get_all(self, key: str) -> List[str]: """ Return the (possibly empty) list of all values for a header. """ return self._dict.get(key.lower(), []) - def raw_items(self): + def raw_items(self) -> Iterator[Tuple[str, str]]: """ Return an iterator of (header name, header value). """ return iter(self._list) + + +HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 5cc1bcc90..b28dcef72 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -10,11 +10,22 @@ import binascii import codecs import collections -import collections.abc import enum import logging import random import struct +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Deque, + Iterable, + List, + Optional, + Union, + cast, +) from .exceptions import ( ConnectionClosed, @@ -22,8 +33,11 @@ PayloadTooBig, WebSocketProtocolError, ) +from .extensions.base import Extension from .framing import * +from .framing import Data from .handshake import * +from .http import Headers __all__ = ["WebSocketCommonProtocol"] @@ -155,20 +169,20 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): def __init__( self, *, - host=None, - port=None, - secure=None, - ping_interval=20, - ping_timeout=20, - close_timeout=None, - max_size=2 ** 20, - max_queue=2 ** 5, - read_limit=2 ** 16, - write_limit=2 ** 16, - loop=None, - legacy_recv=False, - timeout=10, - ): + host: Optional[str] = None, + port: Optional[int] = None, + secure: Optional[bool] = None, + ping_interval: float = 20, + ping_timeout: float = 20, + close_timeout: Optional[float] = None, + max_size: int = 2 ** 20, + max_queue: int = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, + loop: Optional[asyncio.AbstractEventLoop] = None, + legacy_recv: bool = False, + timeout: float = 10, + ) -> None: # Backwards-compatibility: close_timeout used to be called timeout. # If both are specified, timeout is ignored. if close_timeout is None: @@ -200,8 +214,8 @@ def __init__( stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) super().__init__(stream_reader, self.client_connected, loop) - self.reader = None - self.writer = None + self.reader: asyncio.StreamReader + self.writer: asyncio.StreamWriter self._drain_lock = asyncio.Lock(loop=loop) # This class implements the data transfer and closing handshake, which @@ -212,46 +226,50 @@ def __init__( logger.debug("%s - state = CONNECTING", self.side) # HTTP protocol parameters. - self.path = None - self.request_headers = None - self.response_headers = None + self.path: str + self.request_headers: Headers + self.response_headers: Headers # WebSocket protocol parameters. - self.extensions = [] - self.subprotocol = None + self.extensions: List[Extension] = [] + self.subprotocol: Optional[str] = None # The close code and reason are set when receiving a close frame or # losing the TCP connection. - self.close_code = None - self.close_reason = "" + self.close_code: int + self.close_reason: str # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost()` callback to a :class:`~asyncio.Future` # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are # translated by ``self.stream_reader``). - self.connection_lost_waiter = loop.create_future() + self.connection_lost_waiter: asyncio.Future[None] = loop.create_future() # Queue of received messages. - self.messages = collections.deque() - self._pop_message_waiter = None - self._put_message_waiter = None + self.messages: Deque[Data] = collections.deque() + self._pop_message_waiter: Optional[asyncio.Future[None]] = None + self._put_message_waiter: Optional[asyncio.Future[None]] = None # Mapping of ping IDs to waiters, in chronological order. - self.pings = collections.OrderedDict() + self.pings: collections.OrderedDict[ + bytes, asyncio.Future[None] + ] = collections.OrderedDict() # Task running the data transfer. - self.transfer_data_task = None + self.transfer_data_task: asyncio.Task[None] # Exception that occurred during data transfer, if any. - self.transfer_data_exc = None + self.transfer_data_exc: Optional[BaseException] = None # Task sending keepalive pings. - self.keepalive_ping_task = None + self.keepalive_ping_task: asyncio.Task[None] # Task closing the TCP connection. - self.close_connection_task = None + self.close_connection_task: asyncio.Task[None] - def client_connected(self, reader, writer): + def client_connected( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: """ Callback when the TCP connection is established. @@ -263,7 +281,7 @@ def client_connected(self, reader, writer): self.reader = reader self.writer = writer - def connection_open(self): + def connection_open(self) -> None: """ Callback when the WebSocket opening handshake completes. @@ -284,7 +302,7 @@ def connection_open(self): # Public API @property - def local_address(self): + def local_address(self) -> Any: """ Local address of the connection. @@ -297,7 +315,7 @@ def local_address(self): return self.writer.get_extra_info("sockname") @property - def remote_address(self): + def remote_address(self) -> Any: """ Remote address of the connection. @@ -310,7 +328,7 @@ def remote_address(self): return self.writer.get_extra_info("peername") @property - def open(self): + def open(self) -> bool: """ This property is ``True`` when the connection is usable. @@ -324,7 +342,7 @@ def open(self): return self.state is State.OPEN and not self.transfer_data_task.done() @property - def closed(self): + def closed(self) -> bool: """ This property is ``True`` once the connection is closed. @@ -334,7 +352,7 @@ def closed(self): """ return self.state is State.CLOSED - async def wait_closed(self): + async def wait_closed(self) -> None: """ Wait until the connection is closed. @@ -346,7 +364,7 @@ async def wait_closed(self): """ await asyncio.shield(self.connection_lost_waiter) - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator[Data]: """ Iterate on received messages. @@ -364,7 +382,7 @@ async def __aiter__(self): else: raise - async def recv(self): + async def recv(self) -> Data: """ This coroutine receives the next message. @@ -405,7 +423,7 @@ async def recv(self): # Wait until there's a message in the queue (if necessary) or the # connection is closed. while len(self.messages) <= 0: - pop_message_waiter = self.loop.create_future() + pop_message_waiter: asyncio.Future[None] = self.loop.create_future() self._pop_message_waiter = pop_message_waiter try: # If asyncio.wait() is canceled, it doesn't cancel @@ -423,7 +441,7 @@ async def recv(self): # exception (or return None if legacy_recv is enabled). if not pop_message_waiter.done(): if self.legacy_recv: - return + return None # type: ignore else: assert self.state in [State.CLOSING, State.CLOSED] # Wait until the connection is closed to raise @@ -440,7 +458,9 @@ async def recv(self): return message - async def send(self, data): + async def send( + self, message: Union[Data, Iterable[Data], AsyncIterable[Data]] + ) -> None: """ This coroutine sends a message. @@ -462,31 +482,30 @@ async def send(self, data): # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. - try: - opcode, data = prepare_data(data) - except TypeError: - # Perhaps data is an iterator, see below. - pass - else: + if isinstance(message, (str, bytes, bytearray, memoryview)): + opcode, data = prepare_data(message) await self.write_frame(True, opcode, data) - return # Fragmented message -- regular iterator. - if isinstance(data, collections.abc.Iterable): - iter_data = iter(data) + elif isinstance(message, Iterable): + + # Work around https://github.com/python/mypy/issues/6227 + message = cast(Iterable[Data], message) + + iter_message = iter(message) # First fragment. try: - data = next(iter_data) + message_chunk = next(iter_message) except StopIteration: return - opcode, data = prepare_data(data) + opcode, data = prepare_data(message_chunk) await self.write_frame(False, opcode, data) # Other fragments. - for data in iter_data: - confirm_opcode, data = prepare_data(data) + for message_chunk in iter_message: + confirm_opcode, data = prepare_data(message_chunk) if confirm_opcode != opcode: # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. @@ -499,22 +518,22 @@ async def send(self, data): # Fragmented message -- asynchronous iterator - elif isinstance(data, collections.abc.AsyncIterable): - # aiter_data = aiter(data) without aiter - aiter_data = type(data).__aiter__(data) + elif isinstance(message, AsyncIterable): + # aiter_message = aiter(message) without aiter + aiter_message = type(message).__aiter__(message) # First fragment. try: - # data = anext(aiter_data) without anext - data = await type(aiter_data).__anext__(aiter_data) + # message_chunk = anext(aiter_message) without anext + message_chunk = await type(aiter_message).__anext__(aiter_message) except StopAsyncIteration: return - opcode, data = prepare_data(data) + opcode, data = prepare_data(message_chunk) await self.write_frame(False, opcode, data) # Other fragments. - async for data in aiter_data: - confirm_opcode, data = prepare_data(data) + async for message_chunk in aiter_message: + confirm_opcode, data = prepare_data(message_chunk) if confirm_opcode != opcode: # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. @@ -528,7 +547,7 @@ async def send(self, data): else: raise TypeError("data must be bytes, str, or iterable") - async def close(self, code=1000, reason=""): + async def close(self, code: int = 1000, reason: str = "") -> None: """ This coroutine performs the closing handshake. @@ -577,7 +596,7 @@ async def close(self, code=1000, reason=""): # Wait for the close connection task to close the TCP connection. await asyncio.shield(self.close_connection_task) - async def ping(self, data=None): + async def ping(self, data: Optional[bytes] = None) -> Awaitable[None]: """ This coroutine sends a ping. @@ -615,7 +634,7 @@ async def ping(self, data=None): return asyncio.shield(self.pings[data]) - async def pong(self, data=b""): + async def pong(self, data: bytes = b"") -> None: """ This coroutine sends a pong. @@ -634,7 +653,7 @@ async def pong(self, data=b""): # Private methods - no guarantees. - async def ensure_open(self): + async def ensure_open(self) -> None: """ Check that the WebSocket connection is open. @@ -665,8 +684,7 @@ async def ensure_open(self): # will complete within 4 or 5 * close_timeout after close(). The # CLOSING state also occurs when failing the connection. In that # case self.close_connection_task will complete even faster. - if self.close_code is None: - await asyncio.shield(self.close_connection_task) + await asyncio.shield(self.close_connection_task) raise ConnectionClosed( self.close_code, self.close_reason ) from self.transfer_data_exc @@ -675,7 +693,7 @@ async def ensure_open(self): assert self.state is State.CONNECTING raise InvalidState("WebSocket connection isn't established yet") - async def transfer_data(self): + async def transfer_data(self) -> None: """ Read incoming messages and put them in a queue. @@ -742,7 +760,7 @@ async def transfer_data(self): self.transfer_data_exc = exc self.fail_connection(1011) - async def read_message(self): + async def read_message(self) -> Optional[Data]: """ Read a single message from the connection. @@ -755,7 +773,7 @@ async def read_message(self): # A close frame was received. if frame is None: - return + return None if frame.opcode == OP_TEXT: text = True @@ -769,19 +787,21 @@ async def read_message(self): return frame.data.decode("utf-8") if text else frame.data # 5.4. Fragmentation - chunks = [] + chunks: List[Data] = [] max_size = self.max_size if text: - decoder = codecs.getincrementaldecoder("utf-8")(errors="strict") + decoder_factory = codecs.getincrementaldecoder("utf-8") + # https://github.com/python/typeshed/pull/2752 + decoder = decoder_factory(errors="strict") # type: ignore if max_size is None: - def append(frame): + def append(frame: Frame) -> None: nonlocal chunks chunks.append(decoder.decode(frame.data, frame.fin)) else: - def append(frame): + def append(frame: Frame) -> None: nonlocal chunks, max_size chunks.append(decoder.decode(frame.data, frame.fin)) max_size -= len(frame.data) @@ -789,13 +809,13 @@ def append(frame): else: if max_size is None: - def append(frame): + def append(frame: Frame) -> None: nonlocal chunks chunks.append(frame.data) else: - def append(frame): + def append(frame: Frame) -> None: nonlocal chunks, max_size chunks.append(frame.data) max_size -= len(frame.data) @@ -810,9 +830,10 @@ def append(frame): raise WebSocketProtocolError("Unexpected opcode") append(frame) - return ("" if text else b"").join(chunks) + # mypy cannot figure out that chunks have the proper type. + return ("" if text else b"").join(chunks) # type: ignore - async def read_data_frame(self, max_size): + async def read_data_frame(self, max_size: int) -> Optional[Frame]: """ Read a single data frame from the connection. @@ -834,7 +855,7 @@ async def read_data_frame(self, max_size): # serialize_close() because that fails when the close frame is # empty and parse_close() synthetizes a 1005 close code. await self.write_close_frame(frame.data) - return + return None elif frame.opcode == OP_PING: # Answer pings. @@ -851,7 +872,7 @@ async def read_data_frame(self, max_size): ping_id = None ping_ids = [] while ping_id != frame.data: - ping_id, pong_waiter = self.pings.popitem(0) + ping_id, pong_waiter = self.pings.popitem(last=False) ping_ids.append(ping_id) pong_waiter.set_result(None) pong_hex = binascii.hexlify(frame.data).decode() or "[empty]" @@ -881,7 +902,7 @@ async def read_data_frame(self, max_size): else: return frame - async def read_frame(self, max_size): + async def read_frame(self, max_size: int) -> Frame: """ Read a single frame from the connection. @@ -895,7 +916,9 @@ async def read_frame(self, max_size): logger.debug("%s < %r", self.side, frame) return frame - async def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): + async def write_frame( + self, fin: bool, opcode: int, data: bytes, *, _expected_state: int = State.OPEN + ) -> None: # Defensive assertion for protocol compliance. if self.state is not _expected_state: # pragma: no cover raise InvalidState( @@ -920,7 +943,7 @@ async def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN): # with the correct code and reason. await self.ensure_open() - async def write_close_frame(self, data=b""): + async def write_close_frame(self, data: bytes = b"") -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -938,7 +961,7 @@ async def write_close_frame(self, data=b""): # 7.1.2. Start the WebSocket Closing Handshake await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) - async def keepalive_ping(self): + async def keepalive_ping(self) -> None: """ Send a Ping frame and wait for a Pong frame at regular intervals. @@ -978,7 +1001,7 @@ async def keepalive_ping(self): except Exception: logger.warning("Unexpected exception in keepalive ping task", exc_info=True) - async def close_connection(self): + async def close_connection(self) -> None: """ 7.1.1. Close the WebSocket Connection @@ -992,18 +1015,18 @@ async def close_connection(self): """ try: # Wait for the data transfer phase to complete. - if self.transfer_data_task is not None: + if hasattr(self, "transfer_data_task"): try: await self.transfer_data_task except asyncio.CancelledError: pass # Cancel the keepalive ping task. - if self.keepalive_ping_task is not None: + if hasattr(self, "keepalive_ping_task"): self.keepalive_ping_task.cancel() # A client should wait for a TCP close from the server. - if self.is_client and self.transfer_data_task is not None: + if self.is_client and hasattr(self, "transfer_data_task"): if await self.wait_for_connection_lost(): return logger.debug("%s ! timed out waiting for TCP close", self.side) @@ -1037,12 +1060,13 @@ async def close_connection(self): # Abort the TCP connection. Buffers are discarded. logger.debug("%s x aborting TCP connection", self.side) - self.writer.transport.abort() + # mypy thinks self.writer.transport is a BaseTransport, not a Transport. + self.writer.transport.abort() # type: ignore # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() - async def wait_for_connection_lost(self): + async def wait_for_connection_lost(self) -> bool: """ Wait until the TCP connection is closed or ``self.close_timeout`` elapses. @@ -1063,7 +1087,7 @@ async def wait_for_connection_lost(self): # and the moment this coroutine resumes running. return self.connection_lost_waiter.done() - def fail_connection(self, code=1006, reason=""): + def fail_connection(self, code: int = 1006, reason: str = "") -> None: """ 7.1.7. Fail the WebSocket Connection @@ -1091,7 +1115,7 @@ def fail_connection(self, code=1006, reason=""): # Cancel transfer_data_task if the opening handshake succeeded. # cancel() is idempotent and ignored if the task is done already. - if self.transfer_data_task is not None: + if hasattr(self, "transfer_data_task"): self.transfer_data_task.cancel() # Send a close frame when the state is OPEN (a close frame was already @@ -1121,10 +1145,10 @@ def fail_connection(self, code=1006, reason=""): ) # Start close_connection_task if the opening handshake didn't succeed. - if self.close_connection_task is None: + if not hasattr(self, "close_connection_task"): self.close_connection_task = self.loop.create_task(self.close_connection()) - def abort_keepalive_pings(self): + def abort_keepalive_pings(self) -> None: """ Raise ConnectionClosed in pending keepalive pings. @@ -1150,7 +1174,7 @@ def abort_keepalive_pings(self): # asyncio.StreamReaderProtocol methods - def connection_made(self, transport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: """ Configure write buffer limits. @@ -1165,10 +1189,11 @@ def connection_made(self, transport): """ logger.debug("%s - event = connection_made(%s)", self.side, transport) - transport.set_write_buffer_limits(self.write_limit) + # mypy thinks transport is a BaseTransport, not a Transport. + transport.set_write_buffer_limits(self.write_limit) # type: ignore super().connection_made(transport) - def eof_received(self): + def eof_received(self) -> bool: """ Close the transport after receiving EOF. @@ -1193,9 +1218,9 @@ def eof_received(self): """ logger.debug("%s - event = eof_received()", self.side) super().eof_received() - return + return False - def connection_lost(self, exc): + def connection_lost(self, exc: Optional[Exception]) -> None: """ 7.1.4. The WebSocket Connection is Closed. @@ -1203,8 +1228,10 @@ def connection_lost(self, exc): logger.debug("%s - event = connection_lost(%s)", self.side, exc) self.state = State.CLOSED logger.debug("%s - state = CLOSED", self.side) - if self.close_code is None: + if not hasattr(self, "close_code"): self.close_code = 1006 + if not hasattr(self, "close_reason"): + self.close_reason = "" logger.debug( "%s x code = %d, reason = %s", self.side, diff --git a/src/websockets/server.py b/src/websockets/server.py index 7fd32ba1e..efb3ebee3 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -8,7 +8,23 @@ import email.utils import http import logging +import socket import warnings +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Generator, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, + cast, +) from .exceptions import ( AbortHandshake, @@ -19,10 +35,16 @@ InvalidUpgrade, NegotiationError, ) +from .extensions.base import Extension, ServerExtensionFactory from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request -from .headers import build_extension_list, parse_extension_list, parse_subprotocol_list -from .http import USER_AGENT, Headers, MultipleValuesError, read_request +from .headers import ( + ExtensionHeader, + build_extension_list, + parse_extension_list, + parse_subprotocol_list, +) +from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request from .protocol import State, WebSocketCommonProtocol @@ -31,6 +53,11 @@ logger = logging.getLogger(__name__) +HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] + +HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes] + + class WebSocketServerProtocol(WebSocketCommonProtocol): """ Complete WebSocket server implementation as an :class:`asyncio.Protocol`. @@ -48,17 +75,22 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, - ws_handler, - ws_server, + ws_handler: Callable[["WebSocketServerProtocol", str], Awaitable[Any]], + ws_server: "WebSocketServer", *, - origins=None, - extensions=None, - subprotocols=None, - extra_headers=None, - process_request=None, - select_subprotocol=None, - **kwds, - ): + origins: Optional[List[Optional[str]]] = None, + extensions: Optional[List[ServerExtensionFactory]] = None, + subprotocols: Optional[List[str]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + process_request: Optional[ + Callable[ + [str, Headers], + Union[Optional[HTTPResponse], Awaitable[Optional[HTTPResponse]]], + ] + ] = None, + select_subprotocol: Optional[Callable[[List[str], List[str]], str]] = None, + **kwds: Any, + ) -> None: # For backwards-compatibility with 6.0 or earlier. if origins is not None and "" in origins: warnings.warn("use None instead of '' in origins", DeprecationWarning) @@ -73,7 +105,7 @@ def __init__( self._select_subprotocol = select_subprotocol super().__init__(**kwds) - def connection_made(self, transport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: """ Register connection and initialize a task to handle it. @@ -86,7 +118,7 @@ def connection_made(self, transport): self.ws_server.register(self) self.handler_task = self.loop.create_task(self.handler()) - async def handler(self): + async def handler(self) -> None: """ Handle the lifecycle of a WebSocket connection. @@ -114,34 +146,31 @@ async def handler(self): logger.debug("Invalid origin", exc_info=True) status, headers, body = ( http.HTTPStatus.FORBIDDEN, - [], + Headers(), (str(exc) + "\n").encode(), ) elif isinstance(exc, InvalidUpgrade): logger.debug("Invalid upgrade", exc_info=True) status, headers, body = ( http.HTTPStatus.UPGRADE_REQUIRED, - [("Upgrade", "websocket")], + Headers([("Upgrade", "websocket")]), (str(exc) + "\n").encode(), ) elif isinstance(exc, InvalidHandshake): logger.debug("Invalid handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.BAD_REQUEST, - [], + Headers(), (str(exc) + "\n").encode(), ) else: logger.warning("Error in opening handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.INTERNAL_SERVER_ERROR, - [], + Headers(), b"See server log for more information.\n", ) - if not isinstance(headers, Headers): - headers = Headers(headers) - headers.setdefault("Date", email.utils.formatdate(usegmt=True)) headers.setdefault("Server", USER_AGENT) headers.setdefault("Content-Length", str(len(body))) @@ -184,7 +213,7 @@ async def handler(self): # connections before terminating. self.ws_server.unregister(self) - async def read_http_request(self): + async def read_http_request(self) -> Tuple[str, Headers]: """ Read request line and headers from the HTTP request. @@ -209,7 +238,9 @@ async def read_http_request(self): return path, headers - def write_http_response(self, status, headers, body=None): + def write_http_response( + self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None + ) -> None: """ Write status line and headers to the HTTP response. @@ -232,7 +263,9 @@ def write_http_response(self, status, headers, body=None): logger.debug("%s > Body (%d bytes)", self.side, len(body)) self.writer.write(body) - def process_request(self, path, request_headers): + def process_request( + self, path: str, request_headers: Headers + ) -> Union[Optional[HTTPResponse], Awaitable[Optional[HTTPResponse]]]: """ Intercept the HTTP request and return an HTTP response if needed. @@ -264,9 +297,12 @@ def process_request(self, path, request_headers): """ if self._process_request is not None: return self._process_request(path, request_headers) + return None @staticmethod - def process_origin(headers, origins=None): + def process_origin( + headers: Headers, origins: Optional[List[Optional[str]]] = None + ) -> Optional[str]: """ Handle the Origin HTTP request header. @@ -286,7 +322,9 @@ def process_origin(headers, origins=None): return origin @staticmethod - def process_extensions(headers, available_extensions): + def process_extensions( + headers: Headers, available_extensions: Optional[List[ServerExtensionFactory]] + ) -> Tuple[Optional[str], List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -319,14 +357,16 @@ def process_extensions(headers, available_extensions): order of extensions, may be implemented by overriding this method. """ - response_header = [] - accepted_extensions = [] + response_header_value: Optional[str] = None + + extension_headers: List[ExtensionHeader] = [] + accepted_extensions: List[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and available_extensions: - parsed_header_values = sum( + parsed_header_values: List[ExtensionHeader] = sum( [parse_extension_list(header_value) for header_value in header_values], [], ) @@ -348,7 +388,7 @@ def process_extensions(headers, available_extensions): continue # Add matching extension to the final list. - response_header.append((name, response_params)) + extension_headers.append((name, response_params)) accepted_extensions.append(extension) # Break out of the loop once we have a match. @@ -358,15 +398,15 @@ def process_extensions(headers, available_extensions): # matched what the client sent. The extension is declined. # Serialize extension header. - if response_header: - response_header = build_extension_list(response_header) - else: - response_header = None + if extension_headers: + response_header_value = build_extension_list(extension_headers) - return response_header, accepted_extensions + return response_header_value, accepted_extensions # Not @staticmethod because it calls self.select_subprotocol() - def process_subprotocol(self, headers, available_subprotocols): + def process_subprotocol( + self, headers: Headers, available_subprotocols: Optional[List[str]] + ) -> Optional[str]: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -374,13 +414,13 @@ def process_subprotocol(self, headers, available_subprotocols): as the selected subprotocol. """ - subprotocol = None + subprotocol: Optional[str] = None header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: - parsed_header_values = sum( + parsed_header_values: List[str] = sum( [ parse_subprotocol_list(header_value) for header_value in header_values @@ -394,7 +434,9 @@ def process_subprotocol(self, headers, available_subprotocols): return subprotocol - def select_subprotocol(self, client_subprotocols, server_subprotocols): + def select_subprotocol( + self, client_subprotocols: List[str], server_subprotocols: List[str] + ) -> Optional[str]: """ Pick a subprotocol among those offered by the client. @@ -427,11 +469,11 @@ def select_subprotocol(self, client_subprotocols, server_subprotocols): async def handshake( self, - origins=None, - available_extensions=None, - available_subprotocols=None, - extra_headers=None, - ): + origins: Optional[List[Optional[str]]] = None, + available_extensions: Optional[List[ServerExtensionFactory]] = None, + available_subprotocols: Optional[List[str]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + ) -> str: """ Perform the server side of the opening handshake. @@ -460,10 +502,9 @@ async def handshake( # Hook for customizing request handling, for example checking # authentication or treating some paths as plain HTTP endpoints. - if asyncio.iscoroutinefunction(self.process_request): - early_response = await self.process_request(path, request_headers) - else: - early_response = self.process_request(path, request_headers) + early_response = self.process_request(path, request_headers) + if isinstance(early_response, Awaitable): + early_response = await early_response # Change the response to a 503 error if the server is shutting down. if not self.ws_server.is_serving(): @@ -538,20 +579,20 @@ class WebSocketServer: """ - def __init__(self, loop): + def __init__(self, loop: asyncio.AbstractEventLoop): # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop # Keep track of active connections. - self.websockets = set() + self.websockets: Set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. - self.close_task = None + self.close_task: Optional[asyncio.Task[None]] = None # Completed when the server is closed and connections are terminated. - self.closed_waiter = loop.create_future() + self.closed_waiter: asyncio.Future[None] = loop.create_future() - def wrap(self, server): + def wrap(self, server: asyncio.AbstractServer) -> None: """ Attach to a given :class:`~asyncio.Server`. @@ -568,31 +609,33 @@ def wrap(self, server): """ self.server = server - def register(self, protocol): + def register(self, protocol: WebSocketServerProtocol) -> None: """ Register a connection with this server. """ self.websockets.add(protocol) - def unregister(self, protocol): + def unregister(self, protocol: WebSocketServerProtocol) -> None: """ Unregister a connection with this server. """ self.websockets.remove(protocol) - def is_serving(self): + def is_serving(self) -> bool: """ Tell whether the server is accepting new connections or shutting down. """ try: - return self.server.is_serving() # Python ≥ 3.7 + # Python ≥ 3.7 + return self.server.is_serving() # type: ignore except AttributeError: # pragma: no cover - return self.server.sockets is not None # Python < 3.7 + # Python < 3.7 + return self.server.sockets is not None - def close(self): + def close(self) -> None: """ Close the server and terminate connections with close code 1001. @@ -602,7 +645,7 @@ def close(self): if self.close_task is None: self.close_task = self.loop.create_task(self._close()) - async def _close(self): + async def _close(self) -> None: """ Implementation of :meth:`close`. @@ -647,7 +690,7 @@ async def _close(self): # Tell wait_closed() to return. self.closed_waiter.set_result(None) - async def wait_closed(self): + async def wait_closed(self) -> None: """ Wait until the server is closed and all connections are terminated. @@ -658,7 +701,7 @@ async def wait_closed(self): await asyncio.shield(self.closed_waiter) @property - def sockets(self): + def sockets(self) -> Optional[List[socket.socket]]: """ List of :class:`~socket.socket` objects the server is listening to. @@ -754,31 +797,33 @@ class Serve: def __init__( self, - ws_handler, - host=None, - port=None, + ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + host: Optional[Union[str, Sequence[str]]] = None, + port: Optional[int] = None, *, - path=None, - create_protocol=None, - ping_interval=20, - ping_timeout=20, - close_timeout=None, - max_size=2 ** 20, - max_queue=2 ** 5, - read_limit=2 ** 16, - write_limit=2 ** 16, - loop=None, - legacy_recv=False, - klass=WebSocketServerProtocol, - timeout=10, - compression="deflate", - origins=None, - extensions=None, - subprotocols=None, - extra_headers=None, - process_request=None, - select_subprotocol=None, - **kwds, + path: Optional[str] = None, + create_protocol: Optional[Type[WebSocketServerProtocol]] = None, + ping_interval: float = 20, + ping_timeout: float = 20, + close_timeout: Optional[float] = None, + max_size: int = 2 ** 20, + max_queue: int = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, + loop: Optional[asyncio.AbstractEventLoop] = None, + legacy_recv: bool = False, + klass: Type[WebSocketServerProtocol] = WebSocketServerProtocol, + timeout: float = 10, + compression: Optional[str] = "deflate", + origins: Optional[List[Optional[str]]] = None, + extensions: Optional[List[ServerExtensionFactory]] = None, + subprotocols: Optional[List[str]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + process_request: Optional[ + Callable[[str, Headers], Optional[HTTPResponse]] + ] = None, + select_subprotocol: Optional[Callable[[List[str], List[str]], str]] = None, + **kwds: Any, ): # Backwards-compatibility: close_timeout used to be called timeout. # If both are specified, timeout is ignored. @@ -832,6 +877,9 @@ def __init__( ) if path is None: + # https://github.com/python/typeshed/pull/2763 + host = cast(str, host) + port = cast(int, port) creating_server = loop.create_server(factory, host, port, **kwds) else: creating_server = loop.create_unix_server(factory, path, **kwds) @@ -841,22 +889,27 @@ def __init__( self.ws_server = ws_server @asyncio.coroutine - def __iter__(self): - return self.__await_impl__() + def __iter__(self) -> Generator[Any, None, WebSocketServer]: + return (yield from self.__await__()) - async def __aenter__(self): + async def __aenter__(self) -> WebSocketServer: return await self - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: self.ws_server.close() await self.ws_server.wait_closed() - async def __await_impl__(self): + async def __await_impl__(self) -> WebSocketServer: server = await self._creating_server self.ws_server.wrap(server) return self.ws_server - def __await__(self): + def __await__(self) -> Generator[Any, None, WebSocketServer]: # __await__() must return a type that I don't know how to obtain except # by calling __await__() on the return value of an async function. # I'm not finding a better way to take advantage of PEP 492. @@ -866,7 +919,11 @@ def __await__(self): serve = Serve -def unix_serve(ws_handler, path, **kwargs): +def unix_serve( + ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + path: str, + **kwargs: Any, +) -> Serve: """ Similar to :func:`serve()`, but for listening on Unix sockets. diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 730adf54e..cf6b798ee 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -6,17 +6,29 @@ """ -import collections import urllib.parse +from typing import NamedTuple, Optional, Tuple from .exceptions import InvalidURI __all__ = ["parse_uri", "WebSocketURI"] -WebSocketURI = collections.namedtuple( - "WebSocketURI", ["secure", "host", "port", "resource_name", "user_info"] +# Switch to class-based syntax when dropping support for Python < 3.6. + +# Convert to a dataclass when dropping support for Python < 3.7. + +WebSocketURI = NamedTuple( + "WebSocketURI", + [ + ("secure", bool), + ("host", str), + ("port", int), + ("resource_name", str), + ("user_info", Optional[Tuple[str, str]]), + ], ) + WebSocketURI.__doc__ = """WebSocket URI. * ``secure`` is the secure flag @@ -31,7 +43,7 @@ """ -def parse_uri(uri): +def parse_uri(uri: str) -> WebSocketURI: """ This function parses and validates a WebSocket URI. @@ -40,22 +52,22 @@ def parse_uri(uri): Otherwise it raises an :exc:`~websockets.exceptions.InvalidURI` exception. """ - uri = urllib.parse.urlparse(uri) + parsed = urllib.parse.urlparse(uri) try: - assert uri.scheme in ["ws", "wss"] - assert uri.params == "" - assert uri.fragment == "" - assert uri.hostname is not None + assert parsed.scheme in ["ws", "wss"] + assert parsed.params == "" + assert parsed.fragment == "" + assert parsed.hostname is not None except AssertionError as exc: - raise InvalidURI(f"{uri} isn't a valid URI") from exc - - secure = uri.scheme == "wss" - host = uri.hostname - port = uri.port or (443 if secure else 80) - resource_name = uri.path or "/" - if uri.query: - resource_name += "?" + uri.query + raise InvalidURI(uri) from exc + + secure = parsed.scheme == "wss" + host = parsed.hostname + port = parsed.port or (443 if secure else 80) + resource_name = parsed.path or "/" + if parsed.query: + resource_name += "?" + parsed.query user_info = None - if uri.username or uri.password: - user_info = (uri.username, uri.password) + if parsed.username or parsed.password: + user_info = (parsed.username, parsed.password) return WebSocketURI(secure, host, port, resource_name, user_info) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index 193f8fc32..e289e6980 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -4,7 +4,7 @@ __all__ = ["apply_mask"] -def apply_mask(data, mask): +def apply_mask(data: bytes, mask: bytes) -> bytes: """ Apply masking to the data of a WebSocket message. diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 80003ca2d..0ec49c6c0 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -37,6 +37,225 @@ def assertExtensionEqual(self, extension1, extension2): ) +class PerMessageDeflateTests(unittest.TestCase, ExtensionTestsMixin): + def setUp(self): + # Set up an instance of the permessage-deflate extension with the most + # common settings. Since the extension is symmetrical, this instance + # may be used for testing both encoding and decoding. + self.extension = PerMessageDeflate(False, False, 15, 15) + + def test_name(self): + assert self.extension.name == "permessage-deflate" + + def test_repr(self): + self.assertExtensionEqual(eval(repr(self.extension)), self.extension) + + # Control frames aren't encoded or decoded. + + def test_no_encode_decode_ping_frame(self): + frame = Frame(True, OP_PING, b"") + + self.assertEqual(self.extension.encode(frame), frame) + + self.assertEqual(self.extension.decode(frame), frame) + + def test_no_encode_decode_pong_frame(self): + frame = Frame(True, OP_PONG, b"") + + self.assertEqual(self.extension.encode(frame), frame) + + self.assertEqual(self.extension.decode(frame), frame) + + def test_no_encode_decode_close_frame(self): + frame = Frame(True, OP_CLOSE, serialize_close(1000, "")) + + self.assertEqual(self.extension.encode(frame), frame) + + self.assertEqual(self.extension.decode(frame), frame) + + # Data frames are encoded and decoded. + + def test_encode_decode_text_frame(self): + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + + enc_frame = self.extension.encode(frame) + + self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"JNL;\xbc\x12\x00")) + + dec_frame = self.extension.decode(enc_frame) + + self.assertEqual(dec_frame, frame) + + def test_encode_decode_binary_frame(self): + frame = Frame(True, OP_BINARY, b"tea") + + enc_frame = self.extension.encode(frame) + + self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"*IM\x04\x00")) + + dec_frame = self.extension.decode(enc_frame) + + self.assertEqual(dec_frame, frame) + + def test_encode_decode_fragmented_text_frame(self): + frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) + frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) + frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) + + enc_frame1 = self.extension.encode(frame1) + enc_frame2 = self.extension.encode(frame2) + enc_frame3 = self.extension.encode(frame3) + + self.assertEqual( + enc_frame1, + frame1._replace(rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff"), + ) + self.assertEqual( + enc_frame2, frame2._replace(rsv1=True, data=b"RPS\x00\x00\x00\x00\xff\xff") + ) + self.assertEqual( + enc_frame3, frame3._replace(rsv1=True, data=b"J.\xca\xcf,.N\xcc+)\x06\x00") + ) + + dec_frame1 = self.extension.decode(enc_frame1) + dec_frame2 = self.extension.decode(enc_frame2) + dec_frame3 = self.extension.decode(enc_frame3) + + self.assertEqual(dec_frame1, frame1) + self.assertEqual(dec_frame2, frame2) + self.assertEqual(dec_frame3, frame3) + + def test_encode_decode_fragmented_binary_frame(self): + frame1 = Frame(False, OP_TEXT, b"tea ") + frame2 = Frame(True, OP_CONT, b"time") + + enc_frame1 = self.extension.encode(frame1) + enc_frame2 = self.extension.encode(frame2) + + self.assertEqual( + enc_frame1, frame1._replace(rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff") + ) + self.assertEqual( + enc_frame2, frame2._replace(rsv1=True, data=b"*\xc9\xccM\x05\x00") + ) + + dec_frame1 = self.extension.decode(enc_frame1) + dec_frame2 = self.extension.decode(enc_frame2) + + self.assertEqual(dec_frame1, frame1) + self.assertEqual(dec_frame2, frame2) + + def test_no_decode_text_frame(self): + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + + # Try decoding a frame that wasn't encoded. + self.assertEqual(self.extension.decode(frame), frame) + + def test_no_decode_binary_frame(self): + frame = Frame(True, OP_TEXT, b"tea") + + # Try decoding a frame that wasn't encoded. + self.assertEqual(self.extension.decode(frame), frame) + + def test_no_decode_fragmented_text_frame(self): + frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) + frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) + frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) + + dec_frame1 = self.extension.decode(frame1) + dec_frame2 = self.extension.decode(frame2) + dec_frame3 = self.extension.decode(frame3) + + self.assertEqual(dec_frame1, frame1) + self.assertEqual(dec_frame2, frame2) + self.assertEqual(dec_frame3, frame3) + + def test_no_decode_fragmented_binary_frame(self): + frame1 = Frame(False, OP_TEXT, b"tea ") + frame2 = Frame(True, OP_CONT, b"time") + + dec_frame1 = self.extension.decode(frame1) + dec_frame2 = self.extension.decode(frame2) + + self.assertEqual(dec_frame1, frame1) + self.assertEqual(dec_frame2, frame2) + + def test_context_takeover(self): + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + + enc_frame1 = self.extension.encode(frame) + enc_frame2 = self.extension.encode(frame) + + self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") + self.assertEqual(enc_frame2.data, b"J\x06\x11\x00\x00") + + def test_remote_no_context_takeover(self): + # No context takeover when decoding messages. + self.extension = PerMessageDeflate(True, False, 15, 15) + + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + + enc_frame1 = self.extension.encode(frame) + enc_frame2 = self.extension.encode(frame) + + self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") + self.assertEqual(enc_frame2.data, b"J\x06\x11\x00\x00") + + dec_frame1 = self.extension.decode(enc_frame1) + self.assertEqual(dec_frame1, frame) + + with self.assertRaises(zlib.error) as exc: + self.extension.decode(enc_frame2) + self.assertIn("invalid distance too far back", str(exc.exception)) + + def test_local_no_context_takeover(self): + # No context takeover when encoding and decoding messages. + self.extension = PerMessageDeflate(True, True, 15, 15) + + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + + enc_frame1 = self.extension.encode(frame) + enc_frame2 = self.extension.encode(frame) + + self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") + self.assertEqual(enc_frame2.data, b"JNL;\xbc\x12\x00") + + dec_frame1 = self.extension.decode(enc_frame1) + dec_frame2 = self.extension.decode(enc_frame2) + + self.assertEqual(dec_frame1, frame) + self.assertEqual(dec_frame2, frame) + + # Compression settings can be customized. + + def test_compress_settings(self): + # Configure an extension so that no compression actually occurs. + extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) + + frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + + enc_frame = extension.encode(frame) + + self.assertEqual( + enc_frame, + frame._replace( + rsv1=True, data=b"\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00" # not compressed + ), + ) + + # Frames aren't decoded beyond max_length. + + def test_decompress_max_size(self): + frame = Frame(True, OP_TEXT, ("a" * 20).encode("utf-8")) + + enc_frame = self.extension.encode(frame) + + self.assertEqual(enc_frame.data, b"JL\xc4\x04\x00\x00") + + with self.assertRaises(PayloadTooBig): + self.extension.decode(enc_frame, max_size=10) + + class ClientPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): def test_name(self): assert ClientPerMessageDeflateFactory.name == "permessage-deflate" @@ -571,222 +790,3 @@ def test_process_response_params_deduplication(self): factory.process_request_params( [], [PerMessageDeflate(False, False, 15, 15)] ) - - -class PerMessageDeflateTests(unittest.TestCase, ExtensionTestsMixin): - def setUp(self): - # Set up an instance of the permessage-deflate extension with the most - # common settings. Since the extension is symmetrical, this instance - # may be used for testing both encoding and decoding. - self.extension = PerMessageDeflate(False, False, 15, 15) - - def test_name(self): - assert self.extension.name == "permessage-deflate" - - def test_repr(self): - self.assertExtensionEqual(eval(repr(self.extension)), self.extension) - - # Control frames aren't encoded or decoded. - - def test_no_encode_decode_ping_frame(self): - frame = Frame(True, OP_PING, b"") - - self.assertEqual(self.extension.encode(frame), frame) - - self.assertEqual(self.extension.decode(frame), frame) - - def test_no_encode_decode_pong_frame(self): - frame = Frame(True, OP_PONG, b"") - - self.assertEqual(self.extension.encode(frame), frame) - - self.assertEqual(self.extension.decode(frame), frame) - - def test_no_encode_decode_close_frame(self): - frame = Frame(True, OP_CLOSE, serialize_close(1000, "")) - - self.assertEqual(self.extension.encode(frame), frame) - - self.assertEqual(self.extension.decode(frame), frame) - - # Data frames are encoded and decoded. - - def test_encode_decode_text_frame(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) - - enc_frame = self.extension.encode(frame) - - self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"JNL;\xbc\x12\x00")) - - dec_frame = self.extension.decode(enc_frame) - - self.assertEqual(dec_frame, frame) - - def test_encode_decode_binary_frame(self): - frame = Frame(True, OP_BINARY, b"tea") - - enc_frame = self.extension.encode(frame) - - self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"*IM\x04\x00")) - - dec_frame = self.extension.decode(enc_frame) - - self.assertEqual(dec_frame, frame) - - def test_encode_decode_fragmented_text_frame(self): - frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) - frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) - frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) - - enc_frame1 = self.extension.encode(frame1) - enc_frame2 = self.extension.encode(frame2) - enc_frame3 = self.extension.encode(frame3) - - self.assertEqual( - enc_frame1, - frame1._replace(rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff"), - ) - self.assertEqual( - enc_frame2, frame2._replace(rsv1=True, data=b"RPS\x00\x00\x00\x00\xff\xff") - ) - self.assertEqual( - enc_frame3, frame3._replace(rsv1=True, data=b"J.\xca\xcf,.N\xcc+)\x06\x00") - ) - - dec_frame1 = self.extension.decode(enc_frame1) - dec_frame2 = self.extension.decode(enc_frame2) - dec_frame3 = self.extension.decode(enc_frame3) - - self.assertEqual(dec_frame1, frame1) - self.assertEqual(dec_frame2, frame2) - self.assertEqual(dec_frame3, frame3) - - def test_encode_decode_fragmented_binary_frame(self): - frame1 = Frame(False, OP_TEXT, b"tea ") - frame2 = Frame(True, OP_CONT, b"time") - - enc_frame1 = self.extension.encode(frame1) - enc_frame2 = self.extension.encode(frame2) - - self.assertEqual( - enc_frame1, frame1._replace(rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff") - ) - self.assertEqual( - enc_frame2, frame2._replace(rsv1=True, data=b"*\xc9\xccM\x05\x00") - ) - - dec_frame1 = self.extension.decode(enc_frame1) - dec_frame2 = self.extension.decode(enc_frame2) - - self.assertEqual(dec_frame1, frame1) - self.assertEqual(dec_frame2, frame2) - - def test_no_decode_text_frame(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) - - # Try decoding a frame that wasn't encoded. - self.assertEqual(self.extension.decode(frame), frame) - - def test_no_decode_binary_frame(self): - frame = Frame(True, OP_TEXT, b"tea") - - # Try decoding a frame that wasn't encoded. - self.assertEqual(self.extension.decode(frame), frame) - - def test_no_decode_fragmented_text_frame(self): - frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) - frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) - frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) - - dec_frame1 = self.extension.decode(frame1) - dec_frame2 = self.extension.decode(frame2) - dec_frame3 = self.extension.decode(frame3) - - self.assertEqual(dec_frame1, frame1) - self.assertEqual(dec_frame2, frame2) - self.assertEqual(dec_frame3, frame3) - - def test_no_decode_fragmented_binary_frame(self): - frame1 = Frame(False, OP_TEXT, b"tea ") - frame2 = Frame(True, OP_CONT, b"time") - - dec_frame1 = self.extension.decode(frame1) - dec_frame2 = self.extension.decode(frame2) - - self.assertEqual(dec_frame1, frame1) - self.assertEqual(dec_frame2, frame2) - - def test_context_takeover(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) - - enc_frame1 = self.extension.encode(frame) - enc_frame2 = self.extension.encode(frame) - - self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") - self.assertEqual(enc_frame2.data, b"J\x06\x11\x00\x00") - - def test_remote_no_context_takeover(self): - # No context takeover when decoding messages. - self.extension = PerMessageDeflate(True, False, 15, 15) - - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) - - enc_frame1 = self.extension.encode(frame) - enc_frame2 = self.extension.encode(frame) - - self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") - self.assertEqual(enc_frame2.data, b"J\x06\x11\x00\x00") - - dec_frame1 = self.extension.decode(enc_frame1) - self.assertEqual(dec_frame1, frame) - - with self.assertRaises(zlib.error) as exc: - self.extension.decode(enc_frame2) - self.assertIn("invalid distance too far back", str(exc.exception)) - - def test_local_no_context_takeover(self): - # No context takeover when encoding and decoding messages. - self.extension = PerMessageDeflate(True, True, 15, 15) - - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) - - enc_frame1 = self.extension.encode(frame) - enc_frame2 = self.extension.encode(frame) - - self.assertEqual(enc_frame1.data, b"JNL;\xbc\x12\x00") - self.assertEqual(enc_frame2.data, b"JNL;\xbc\x12\x00") - - dec_frame1 = self.extension.decode(enc_frame1) - dec_frame2 = self.extension.decode(enc_frame2) - - self.assertEqual(dec_frame1, frame) - self.assertEqual(dec_frame2, frame) - - # Compression settings can be customized. - - def test_compress_settings(self): - # Configure an extension so that no compression actually occurs. - extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) - - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) - - enc_frame = extension.encode(frame) - - self.assertEqual( - enc_frame, - frame._replace( - rsv1=True, data=b"\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00" # not compressed - ), - ) - - # Frames aren't decoded beyond max_length. - - def test_decompress_max_size(self): - frame = Frame(True, OP_TEXT, ("a" * 20).encode("utf-8")) - - enc_frame = self.extension.encode(frame) - - self.assertEqual(enc_frame.data, b"JL\xc4\x04\x00\x00") - - with self.assertRaises(PayloadTooBig): - self.extension.decode(enc_frame, max_size=10) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 7b935491b..3ccdadb82 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -115,7 +115,7 @@ def test_str(self): "(private use), no reason" ), ( - InvalidURI("| isn't a valid URI"), + InvalidURI("|"), "| isn't a valid URI", ), ( diff --git a/tests/test_http.py b/tests/test_http.py index a3a8cd403..39961d641 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -130,6 +130,9 @@ def test_contains_case_insensitive(self): def test_contains_not_found(self): self.assertNotIn("Date", self.headers) + def test_contains_non_string_key(self): + self.assertNotIn(42, self.headers) + def test_iter(self): self.assertEqual(set(iter(self.headers)), {"connection", "server"}) diff --git a/tox.ini b/tox.ini index 4d085f56c..7397c90ae 100644 --- a/tox.ini +++ b/tox.ini @@ -24,5 +24,5 @@ commands = isort --check-only --recursive src tests deps = isort [testenv:mypy] -commands = mypy src +commands = mypy --strict src deps = mypy From 03c1fb657e406c3f707b5605b44fa63af188f1f8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 10:27:52 +0100 Subject: [PATCH 0548/1539] Simplify NamedTuple declarations with class syntax. --- src/websockets/framing.py | 23 ++++++++++------------- src/websockets/uri.py | 32 ++++++++++++++++---------------- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 8eb1a79bd..15b76eb93 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -64,22 +64,19 @@ Data = Union[str, bytes] +# Remove FrameData when dropping support for Python < 3.6.1 — the first +# version where NamedTuple supports default values, methods, and docstrings. -# Switch to class-based syntax when dropping support for Python < 3.6. +# Consider converting to a dataclass when dropping support for Python < 3.7. -# Convert to a dataclass when dropping support for Python < 3.7. -FrameData = NamedTuple( - "FrameData", - [ - ("fin", bool), - ("opcode", int), - ("data", bytes), - ("rsv1", bool), - ("rsv2", bool), - ("rsv3", bool), - ], -) +class FrameData(NamedTuple): + fin: bool + opcode: int + data: bytes + rsv1: bool + rsv2: bool + rsv3: bool class Frame(FrameData): diff --git a/src/websockets/uri.py b/src/websockets/uri.py index cf6b798ee..16d3d6761 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -14,22 +14,22 @@ __all__ = ["parse_uri", "WebSocketURI"] -# Switch to class-based syntax when dropping support for Python < 3.6. - -# Convert to a dataclass when dropping support for Python < 3.7. - -WebSocketURI = NamedTuple( - "WebSocketURI", - [ - ("secure", bool), - ("host", str), - ("port", int), - ("resource_name", str), - ("user_info", Optional[Tuple[str, str]]), - ], -) - -WebSocketURI.__doc__ = """WebSocket URI. + +# Consider converting to a dataclass when dropping support for Python < 3.7. + + +class WebSocketURI(NamedTuple): + secure: bool + host: str + port: int + resource_name: str + user_info: Optional[Tuple[str, str]] + + +# Declare the docstring normally when dropping support for Python < 3.6.1. + +WebSocketURI.__doc__ = """ +WebSocket URI. * ``secure`` is the secure flag * ``host`` is the lower-case host From 98b5e854d89686bbb33079cefbdf2cd83f3ca1c4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 10:53:07 +0100 Subject: [PATCH 0549/1539] Move shared types to a typing module. --- docs/api.rst | 8 ++++++++ src/websockets/__init__.py | 2 ++ src/websockets/framing.py | 4 +--- src/websockets/protocol.py | 2 +- src/websockets/typing.py | 19 +++++++++++++++++++ 5 files changed, 31 insertions(+), 4 deletions(-) create mode 100644 src/websockets/typing.py diff --git a/docs/api.rst b/docs/api.rst index ce6529d1d..acdc69dab 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -86,6 +86,14 @@ Shared .. autoattribute:: open .. autoattribute:: closed +Types +..... + +.. automodule:: websockets.typing + + .. autodata:: Data + + Per-Message Deflate Extension ............................. diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 5fbff0d41..9bfbdabfe 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -4,6 +4,7 @@ from .exceptions import * from .protocol import * from .server import * +from .typing import * from .uri import * from .version import version as __version__ # noqa @@ -13,5 +14,6 @@ + exceptions.__all__ + protocol.__all__ + server.__all__ + + typing.__all__ + uri.__all__ ) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 15b76eb93..0a778ed53 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -21,10 +21,10 @@ NamedTuple, Optional, Tuple, - Union, ) from .exceptions import PayloadTooBig, WebSocketProtocolError +from .typing import Data if TYPE_CHECKING: # pragma: no cover @@ -62,8 +62,6 @@ EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011] -Data = Union[str, bytes] - # Remove FrameData when dropping support for Python < 3.6.1 — the first # version where NamedTuple supports default values, methods, and docstrings. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index b28dcef72..f4dbbb279 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -35,9 +35,9 @@ ) from .extensions.base import Extension from .framing import * -from .framing import Data from .handshake import * from .http import Headers +from .typing import Data __all__ = ["WebSocketCommonProtocol"] diff --git a/src/websockets/typing.py b/src/websockets/typing.py new file mode 100644 index 000000000..2f0c50c59 --- /dev/null +++ b/src/websockets/typing.py @@ -0,0 +1,19 @@ +from typing import Union + + +__all__ = ["Data"] + +Data = Union[str, bytes] + +Data__doc__ = """ +Types supported in a WebSocket message: + +- :class:`str` for text messages +- :class:`bytes` for binary messages + +""" + +try: + Data.__doc__ = Data__doc__ # type: ignore +except AttributeError: # pragma: no cover + pass From 82b71b782ef776e6643c7e0208cba87538baf389 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 19:01:45 +0100 Subject: [PATCH 0550/1539] Improve typing declarations. * Create new types for the values of various HTTP headers. * Prefer Sequence (generic type) to List in parameter types -- this has the nice side effect of preventing modification of mutable parameters. And refactor a bit the headers module for readability. --- src/websockets/client.py | 57 +++-- src/websockets/exceptions.py | 4 +- src/websockets/extensions/base.py | 16 +- .../extensions/permessage_deflate.py | 22 +- src/websockets/framing.py | 6 +- src/websockets/headers.py | 221 ++++++++++-------- src/websockets/server.py | 69 +++--- src/websockets/typing.py | 13 +- tests/test_headers.py | 28 +-- 9 files changed, 242 insertions(+), 194 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 40c5b0073..9cefaedb8 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -7,7 +7,7 @@ import collections.abc import logging from types import TracebackType -from typing import Any, Generator, List, Optional, Tuple, Type, cast +from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast from .exceptions import ( InvalidHandshake, @@ -22,13 +22,14 @@ from .headers import ( ExtensionHeader, build_basic_auth, - build_extension_list, - build_subprotocol_list, - parse_extension_list, - parse_subprotocol_list, + build_extension, + build_subprotocol, + parse_extension, + parse_subprotocol, ) from .http import USER_AGENT, Headers, HeadersLike, read_response from .protocol import WebSocketCommonProtocol +from .typing import Origin, Subprotocol from .uri import WebSocketURI, parse_uri @@ -52,9 +53,9 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__( self, *, - origin: Optional[str] = None, - extensions: Optional[List[ClientExtensionFactory]] = None, - subprotocols: Optional[List[str]] = None, + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, **kwds: Any, ) -> None: @@ -108,7 +109,8 @@ async def read_http_response(self) -> Tuple[int, Headers]: @staticmethod def process_extensions( - headers: Headers, available_extensions: Optional[List[ClientExtensionFactory]] + headers: Headers, + available_extensions: Optional[Sequence[ClientExtensionFactory]], ) -> List[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -146,8 +148,7 @@ def process_extensions( raise InvalidHandshake("No extensions supported") parsed_header_values: List[ExtensionHeader] = sum( - [parse_extension_list(header_value) for header_value in header_values], - [], + [parse_extension(header_value) for header_value in header_values], [] ) for name, response_params in parsed_header_values: @@ -184,8 +185,8 @@ def process_extensions( @staticmethod def process_subprotocol( - headers: Headers, available_subprotocols: Optional[List[str]] - ) -> Optional[str]: + headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] + ) -> Optional[Subprotocol]: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -194,7 +195,7 @@ def process_subprotocol( Return the selected subprotocol. """ - subprotocol: Optional[str] = None + subprotocol: Optional[Subprotocol] = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -203,12 +204,8 @@ def process_subprotocol( if available_subprotocols is None: raise InvalidHandshake("No subprotocols supported") - parsed_header_values: List[str] = sum( - [ - parse_subprotocol_list(header_value) - for header_value in header_values - ], - [], + parsed_header_values: Sequence[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] ) if len(parsed_header_values) > 1: @@ -225,9 +222,9 @@ def process_subprotocol( async def handshake( self, wsuri: WebSocketURI, - origin: Optional[str] = None, - available_extensions: Optional[List[ClientExtensionFactory]] = None, - available_subprotocols: Optional[List[str]] = None, + origin: Optional[Origin] = None, + available_extensions: Optional[Sequence[ClientExtensionFactory]] = None, + available_subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, ) -> None: """ @@ -266,7 +263,7 @@ async def handshake( key = build_request(request_headers) if available_extensions is not None: - extensions_header = build_extension_list( + extensions_header = build_extension( [ (extension_factory.name, extension_factory.get_request_params()) for extension_factory in available_extensions @@ -275,7 +272,7 @@ async def handshake( request_headers["Sec-WebSocket-Extensions"] = extensions_header if available_subprotocols is not None: - protocol_header = build_subprotocol_list(available_subprotocols) + protocol_header = build_subprotocol(available_subprotocols) request_headers["Sec-WebSocket-Protocol"] = protocol_header if extra_headers is not None: @@ -382,9 +379,9 @@ def __init__( klass: Type[WebSocketClientProtocol] = WebSocketClientProtocol, timeout: float = 10, compression: Optional[str] = "deflate", - origin: Optional[str] = None, - extensions: Optional[List[ClientExtensionFactory]] = None, - subprotocols: Optional[List[str]] = None, + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, **kwds: Any, ): @@ -417,9 +414,9 @@ def __init__( extension_factory.name == ClientPerMessageDeflateFactory.name for extension_factory in extensions ): - extensions.append( + extensions = list(extensions) + [ ClientPerMessageDeflateFactory(client_max_window_bits=True) - ) + ] elif compression is not None: raise ValueError(f"Unsupported compression: {compression}") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 9999527ef..436c594a9 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -94,8 +94,8 @@ class InvalidHeaderFormat(InvalidHeader): """ - def __init__(self, name: str, error: str, string: str, pos: int) -> None: - error = f"{error} at {pos} in {string}" + def __init__(self, name: str, error: str, header: str, pos: int) -> None: + error = f"{error} at {pos} in {header}" super().__init__(name, error) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 707e9317a..ed847c6bc 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -6,10 +6,10 @@ """ -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple from ..framing import Frame -from ..headers import ExtensionParameters +from ..typing import ExtensionParameter __all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] @@ -60,7 +60,7 @@ def name(self) -> str: """ - def get_request_params(self) -> ExtensionParameters: + def get_request_params(self) -> List[ExtensionParameter]: """ Build request parameters. @@ -69,7 +69,9 @@ def get_request_params(self) -> ExtensionParameters: """ def process_response_params( - self, params: ExtensionParameters, accepted_extensions: List[Extension] + self, + params: Sequence[ExtensionParameter], + accepted_extensions: Sequence[Extension], ) -> Extension: """ Process response parameters received from the server. @@ -100,8 +102,10 @@ def name(self) -> str: """ def process_request_params( - self, params: ExtensionParameters, accepted_extensions: List[Extension] - ) -> Tuple[ExtensionParameters, Extension]: + self, + params: Sequence[ExtensionParameter], + accepted_extensions: Sequence[Extension], + ) -> Tuple[List[ExtensionParameter], Extension]: """ Process request parameters received from the client. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 93698a363..145cb2bbe 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -5,7 +5,7 @@ """ import zlib -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from ..exceptions import ( DuplicateParameter, @@ -15,7 +15,7 @@ PayloadTooBig, ) from ..framing import CTRL_OPCODES, OP_CONT, Frame -from ..headers import ExtensionParameters +from ..typing import ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -174,12 +174,12 @@ def _build_parameters( client_no_context_takeover: bool, server_max_window_bits: Optional[int], client_max_window_bits: Optional[Union[int, bool]], -) -> ExtensionParameters: +) -> List[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. """ - params: ExtensionParameters = [] + params: List[ExtensionParameter] = [] if server_no_context_takeover: params.append(("server_no_context_takeover", None)) if client_no_context_takeover: @@ -194,7 +194,7 @@ def _build_parameters( def _extract_parameters( - params: ExtensionParameters, *, is_server: bool + params: Sequence[ExtensionParameter], *, is_server: bool ) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -310,7 +310,7 @@ def __init__( self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings - def get_request_params(self) -> ExtensionParameters: + def get_request_params(self) -> List[ExtensionParameter]: """ Build request parameters. @@ -324,8 +324,8 @@ def get_request_params(self) -> ExtensionParameters: def process_response_params( self, - params: List[Tuple[str, Optional[str]]], - accepted_extensions: List["Extension"], + params: Sequence[ExtensionParameter], + accepted_extensions: Sequence["Extension"], ) -> PerMessageDeflate: """ Process response parameters. @@ -481,9 +481,9 @@ def __init__( def process_request_params( self, - params: List[Tuple[str, Optional[str]]], - accepted_extensions: List["Extension"], - ) -> Tuple[ExtensionParameters, PerMessageDeflate]: + params: Sequence[ExtensionParameter], + accepted_extensions: Sequence["Extension"], + ) -> Tuple[List[ExtensionParameter], PerMessageDeflate]: """ Process request parameters. diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 0a778ed53..1409c7d69 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -17,9 +17,9 @@ Any, Awaitable, Callable, - List, NamedTuple, Optional, + Sequence, Tuple, ) @@ -112,7 +112,7 @@ async def read( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[List[Extension]] = None, + extensions: Optional[Sequence[Extension]] = None, ) -> "Frame": """ Read a WebSocket frame and return a :class:`Frame` object. @@ -184,7 +184,7 @@ def write( writer: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[List[Extension]] = None, + extensions: Optional[Sequence[Extension]] = None, ) -> None: """ Write a WebSocket frame. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index e2addf4c5..663e71d60 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -9,51 +9,51 @@ import base64 import re -from typing import Callable, List, Optional, Tuple, TypeVar +from typing import Callable, List, NewType, Optional, Sequence, Tuple, TypeVar, cast from .exceptions import InvalidHeaderFormat +from .typing import ExtensionHeader, ExtensionParameter, Subprotocol __all__ = [ "parse_connection", "parse_upgrade", - "parse_extension_list", - "build_extension_list", - "parse_subprotocol_list", - "build_subprotocol_list", + "parse_extension", + "build_extension", + "parse_subprotocol", + "build_subprotocol", ] T = TypeVar("T") -ExtensionParameter = Tuple[str, Optional[str]] -ExtensionParameters = List[ExtensionParameter] -ExtensionHeader = Tuple[str, ExtensionParameters] -SubprotocolHeader = str +ConnectionOption = NewType("ConnectionOption", str) +UpgradeProtocol = NewType("UpgradeProtocol", str) + # To avoid a dependency on a parsing library, we implement manually the ABNF # described in https://tools.ietf.org/html/rfc6455#section-9.1 with the # definitions from https://tools.ietf.org/html/rfc7230#appendix-B. -def peek_ahead(string: str, pos: int) -> Optional[str]: +def peek_ahead(header: str, pos: int) -> Optional[str]: """ - Return the next character from ``string`` at the given position. + Return the next character from ``header`` at the given position. - Return ``None`` at the end of ``string``. + Return ``None`` at the end of ``header``. We never need to peek more than one character ahead. """ - return None if pos == len(string) else string[pos] + return None if pos == len(header) else header[pos] _OWS_re = re.compile(r"[\t ]*") -def parse_OWS(string: str, pos: int) -> int: +def parse_OWS(header: str, pos: int) -> int: """ - Parse optional whitespace from ``string`` at the given position. + Parse optional whitespace from ``header`` at the given position. Return the new position. @@ -61,7 +61,7 @@ def parse_OWS(string: str, pos: int) -> int: """ # There's always a match, possibly empty, whose content doesn't matter. - match = _OWS_re.match(string, pos) + match = _OWS_re.match(header, pos) assert match is not None return match.end() @@ -69,18 +69,18 @@ def parse_OWS(string: str, pos: int) -> int: _token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") -def parse_token(string: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: """ - Parse a token from ``string`` at the given position. + Parse a token from ``header`` at the given position. Return the token value and the new position. Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - match = _token_re.match(string, pos) + match = _token_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected token", string=string, pos=pos) + raise InvalidHeaderFormat(header_name, "expected token", header, pos) return match.group(), match.end() @@ -92,31 +92,29 @@ def parse_token(string: str, pos: int, header_name: str) -> Tuple[str, int]: _unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") -def parse_quoted_string(string: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, int]: """ - Parse a quoted string from ``string`` at the given position. + Parse a quoted string from ``header`` at the given position. Return the unquoted value and the new position. Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - match = _quoted_string_re.match(string, pos) + match = _quoted_string_re.match(header, pos) if match is None: - raise InvalidHeaderFormat( - header_name, "expected quoted string", string=string, pos=pos - ) + raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos) return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() def parse_list( parse_item: Callable[[str, int, str], Tuple[T, int]], - string: str, + header: str, pos: int, header_name: str, ) -> List[T]: """ - Parse a comma-separated list from ``string`` at the given position. + Parse a comma-separated list from ``header`` at the given position. This is appropriate for parsing values with the following grammar: @@ -124,7 +122,7 @@ def parse_list( ``parse_item`` parses one item. - ``string`` is assumed not to start or end with whitespace. + ``header`` is assumed not to start or end with whitespace. (This function is designed for parsing an entire header value and :func:`~websockets.http.read_headers` strips whitespace from values.) @@ -139,44 +137,57 @@ def parse_list( # while loops that remove extra delimiters. # Remove extra delimiters before the first item. - while peek_ahead(string, pos) == ",": - pos = parse_OWS(string, pos + 1) + while peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) items = [] while True: - # Loop invariant: a item starts at pos in string. - item, pos = parse_item(string, pos, header_name) + # Loop invariant: a item starts at pos in header. + item, pos = parse_item(header, pos, header_name) items.append(item) - pos = parse_OWS(string, pos) + pos = parse_OWS(header, pos) - # We may have reached the end of the string. - if pos == len(string): + # We may have reached the end of the header. + if pos == len(header): break # There must be a delimiter after each element except the last one. - if peek_ahead(string, pos) == ",": - pos = parse_OWS(string, pos + 1) + if peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) else: - raise InvalidHeaderFormat( - header_name, "expected comma", string=string, pos=pos - ) + raise InvalidHeaderFormat(header_name, "expected comma", header, pos) # Remove extra delimiters before the next item. - while peek_ahead(string, pos) == ",": - pos = parse_OWS(string, pos + 1) + while peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) - # We may have reached the end of the string. - if pos == len(string): + # We may have reached the end of the header. + if pos == len(header): break - # Since we only advance in the string by one character with peek_ahead() + # Since we only advance in the header by one character with peek_ahead() # or with the end position of a regex match, we can't overshoot the end. - assert pos == len(string) + assert pos == len(header) return items -def parse_connection(string: str) -> List[str]: +def parse_connection_option( + header: str, pos: int, header_name: str +) -> Tuple[ConnectionOption, int]: + """ + Parse a Connection option from ``header`` at the given position. + + Return the protocol value and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + + """ + item, pos = parse_token(header, pos, header_name) + return cast(ConnectionOption, item), pos + + +def parse_connection(header: str) -> List[ConnectionOption]: """ Parse a ``Connection`` header. @@ -185,7 +196,7 @@ def parse_connection(string: str) -> List[str]: Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - return parse_list(parse_token, string, 0, "Connection") + return parse_list(parse_connection_option, header, 0, "Connection") _protocol_re = re.compile( @@ -193,40 +204,40 @@ def parse_connection(string: str) -> List[str]: ) -def parse_protocol(string: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_upgrade_protocol( + header: str, pos: int, header_name: str +) -> Tuple[UpgradeProtocol, int]: """ - Parse a protocol from ``string`` at the given position. + Parse an Upgrade protocol from ``header`` at the given position. Return the protocol value and the new position. Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - match = _protocol_re.match(string, pos) + match = _protocol_re.match(header, pos) if match is None: - raise InvalidHeaderFormat( - header_name, "expected protocol", string=string, pos=pos - ) - return match.group(), match.end() + raise InvalidHeaderFormat(header_name, "expected protocol", header, pos) + return cast(UpgradeProtocol, match.group()), match.end() -def parse_upgrade(string: str) -> List[str]: +def parse_upgrade(header: str) -> List[UpgradeProtocol]: """ Parse an ``Upgrade`` header. - Return a list of connection options. + Return a list of protocols. Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - return parse_list(parse_protocol, string, 0, "Upgrade") + return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") -def parse_extension_param( - string: str, pos: int, header_name: str +def parse_extension_item_param( + header: str, pos: int, header_name: str ) -> Tuple[ExtensionParameter, int]: """ - Parse a single extension parameter from ``string`` at the given position. + Parse a single extension parameter from ``header`` at the given position. Return a ``(name, value)`` pair and the new position. @@ -234,36 +245,33 @@ def parse_extension_param( """ # Extract parameter name. - name, pos = parse_token(string, pos, header_name) - pos = parse_OWS(string, pos) + name, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) # Extract parameter value, if there is one. value: Optional[str] = None - if peek_ahead(string, pos) == "=": - pos = parse_OWS(string, pos + 1) - if peek_ahead(string, pos) == '"': + if peek_ahead(header, pos) == "=": + pos = parse_OWS(header, pos + 1) + if peek_ahead(header, pos) == '"': pos_before = pos # for proper error reporting below - value, pos = parse_quoted_string(string, pos, header_name) + value, pos = parse_quoted_string(header, pos, header_name) # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value # after quoted-string unescaping MUST conform to the 'token' ABNF. if _token_re.fullmatch(value) is None: raise InvalidHeaderFormat( - header_name, - "invalid quoted string content", - string=string, - pos=pos_before, + header_name, "invalid quoted header content", header, pos_before ) else: - value, pos = parse_token(string, pos, header_name) - pos = parse_OWS(string, pos) + value, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) return (name, value), pos -def parse_extension( - string: str, pos: int, header_name: str +def parse_extension_item( + header: str, pos: int, header_name: str ) -> Tuple[ExtensionHeader, int]: """ - Parse an extension definition from ``string`` at the given position. + Parse an extension definition from ``header`` at the given position. Return an ``(extension name, parameters)`` pair, where ``parameters`` is a list of ``(name, value)`` pairs, and the new position. @@ -272,18 +280,18 @@ def parse_extension( """ # Extract extension name. - name, pos = parse_token(string, pos, header_name) - pos = parse_OWS(string, pos) + name, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) # Extract all parameters. parameters = [] - while peek_ahead(string, pos) == ";": - pos = parse_OWS(string, pos + 1) - parameter, pos = parse_extension_param(string, pos, header_name) + while peek_ahead(header, pos) == ";": + pos = parse_OWS(header, pos + 1) + parameter, pos = parse_extension_item_param(header, pos, header_name) parameters.append(parameter) return (name, parameters), pos -def parse_extension_list(string: str) -> List[ExtensionHeader]: +def parse_extension(header: str) -> List[ExtensionHeader]: """ Parse a ``Sec-WebSocket-Extensions`` header. @@ -305,14 +313,17 @@ def parse_extension_list(string: str) -> List[ExtensionHeader]: Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - return parse_list(parse_extension, string, 0, "Sec-WebSocket-Extensions") + return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") -def build_extension(name: str, parameters: ExtensionParameters) -> str: +parse_extension_list = parse_extension # alias for backwards-compatibility + + +def build_extension_item(name: str, parameters: List[ExtensionParameter]) -> str: """ Build an extension definition. - This is the reverse of :func:`parse_extension`. + This is the reverse of :func:`parse_extension_item`. """ return "; ".join( @@ -325,38 +336,62 @@ def build_extension(name: str, parameters: ExtensionParameters) -> str: ) -def build_extension_list(extensions: List[ExtensionHeader]) -> str: +def build_extension(extensions: Sequence[ExtensionHeader]) -> str: """ Unparse a ``Sec-WebSocket-Extensions`` header. - This is the reverse of :func:`parse_extension_list`. + This is the reverse of :func:`parse_extension`. """ return ", ".join( - build_extension(name, parameters) for name, parameters in extensions + build_extension_item(name, parameters) for name, parameters in extensions ) -def parse_subprotocol_list(string: str) -> List[SubprotocolHeader]: +build_extension_list = build_extension # alias for backwards-compatibility + + +def parse_subprotocol_item( + header: str, pos: int, header_name: str +) -> Tuple[Subprotocol, int]: + """ + Parse a subprotocol from ``header`` at the given position. + + Return the subprotocol value and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + + """ + item, pos = parse_token(header, pos, header_name) + return cast(Subprotocol, item), pos + + +def parse_subprotocol(header: str) -> List[Subprotocol]: """ Parse a ``Sec-WebSocket-Protocol`` header. Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. """ - return parse_list(parse_token, string, 0, "Sec-WebSocket-Protocol") + return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") + +parse_subprotocol_list = parse_subprotocol # alias for backwards-compatibility -def build_subprotocol_list(protocols: List[SubprotocolHeader]) -> str: + +def build_subprotocol(protocols: Sequence[Subprotocol]) -> str: """ Unparse a ``Sec-WebSocket-Protocol`` header. - This is the reverse of :func:`parse_subprotocol_list`. + This is the reverse of :func:`parse_subprotocol`. """ return ", ".join(protocols) +build_subprotocol_list = build_subprotocol # alias for backwards-compatibility + + def build_basic_auth(username: str, password: str) -> str: """ Build an Authorization header for HTTP Basic Auth. diff --git a/src/websockets/server.py b/src/websockets/server.py index efb3ebee3..b20f4b80d 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -40,12 +40,13 @@ from .handshake import build_response, check_request from .headers import ( ExtensionHeader, - build_extension_list, - parse_extension_list, - parse_subprotocol_list, + build_extension, + parse_extension, + parse_subprotocol, ) from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request from .protocol import State, WebSocketCommonProtocol +from .typing import Origin, Subprotocol __all__ = ["serve", "unix_serve", "WebSocketServerProtocol"] @@ -78,9 +79,9 @@ def __init__( ws_handler: Callable[["WebSocketServerProtocol", str], Awaitable[Any]], ws_server: "WebSocketServer", *, - origins: Optional[List[Optional[str]]] = None, - extensions: Optional[List[ServerExtensionFactory]] = None, - subprotocols: Optional[List[str]] = None, + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, process_request: Optional[ Callable[ @@ -88,7 +89,9 @@ def __init__( Union[Optional[HTTPResponse], Awaitable[Optional[HTTPResponse]]], ] ] = None, - select_subprotocol: Optional[Callable[[List[str], List[str]], str]] = None, + select_subprotocol: Optional[ + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] + ] = None, **kwds: Any, ) -> None: # For backwards-compatibility with 6.0 or earlier. @@ -301,8 +304,8 @@ def process_request( @staticmethod def process_origin( - headers: Headers, origins: Optional[List[Optional[str]]] = None - ) -> Optional[str]: + headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None + ) -> Optional[Origin]: """ Handle the Origin HTTP request header. @@ -313,7 +316,7 @@ def process_origin( # "The user agent MUST NOT include more than one Origin header field" # per https://tools.ietf.org/html/rfc6454#section-7.3. try: - origin = headers.get("Origin") + origin = cast(Origin, headers.get("Origin")) except MultipleValuesError: raise InvalidHeader("Origin", "more than one Origin header found") if origins is not None: @@ -323,7 +326,8 @@ def process_origin( @staticmethod def process_extensions( - headers: Headers, available_extensions: Optional[List[ServerExtensionFactory]] + headers: Headers, + available_extensions: Optional[Sequence[ServerExtensionFactory]], ) -> Tuple[Optional[str], List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -367,8 +371,7 @@ def process_extensions( if header_values and available_extensions: parsed_header_values: List[ExtensionHeader] = sum( - [parse_extension_list(header_value) for header_value in header_values], - [], + [parse_extension(header_value) for header_value in header_values], [] ) for name, request_params in parsed_header_values: @@ -399,14 +402,14 @@ def process_extensions( # Serialize extension header. if extension_headers: - response_header_value = build_extension_list(extension_headers) + response_header_value = build_extension(extension_headers) return response_header_value, accepted_extensions # Not @staticmethod because it calls self.select_subprotocol() def process_subprotocol( - self, headers: Headers, available_subprotocols: Optional[List[str]] - ) -> Optional[str]: + self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] + ) -> Optional[Subprotocol]: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -414,18 +417,14 @@ def process_subprotocol( as the selected subprotocol. """ - subprotocol: Optional[str] = None + subprotocol: Optional[Subprotocol] = None header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: - parsed_header_values: List[str] = sum( - [ - parse_subprotocol_list(header_value) - for header_value in header_values - ], - [], + parsed_header_values: List[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] ) subprotocol = self.select_subprotocol( @@ -435,8 +434,10 @@ def process_subprotocol( return subprotocol def select_subprotocol( - self, client_subprotocols: List[str], server_subprotocols: List[str] - ) -> Optional[str]: + self, + client_subprotocols: Sequence[Subprotocol], + server_subprotocols: Sequence[Subprotocol], + ) -> Optional[Subprotocol]: """ Pick a subprotocol among those offered by the client. @@ -469,9 +470,9 @@ def select_subprotocol( async def handshake( self, - origins: Optional[List[Optional[str]]] = None, - available_extensions: Optional[List[ServerExtensionFactory]] = None, - available_subprotocols: Optional[List[str]] = None, + origins: Optional[Sequence[Optional[Origin]]] = None, + available_extensions: Optional[Sequence[ServerExtensionFactory]] = None, + available_subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, ) -> str: """ @@ -815,14 +816,16 @@ def __init__( klass: Type[WebSocketServerProtocol] = WebSocketServerProtocol, timeout: float = 10, compression: Optional[str] = "deflate", - origins: Optional[List[Optional[str]]] = None, - extensions: Optional[List[ServerExtensionFactory]] = None, - subprotocols: Optional[List[str]] = None, + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, process_request: Optional[ Callable[[str, Headers], Optional[HTTPResponse]] ] = None, - select_subprotocol: Optional[Callable[[List[str], List[str]], str]] = None, + select_subprotocol: Optional[ + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] + ] = None, **kwds: Any, ): # Backwards-compatibility: close_timeout used to be called timeout. @@ -849,7 +852,7 @@ def __init__( ext_factory.name == ServerPerMessageDeflateFactory.name for ext_factory in extensions ): - extensions.append(ServerPerMessageDeflateFactory()) + extensions = list(extensions) + [ServerPerMessageDeflateFactory()] elif compression is not None: raise ValueError(f"Unsupported compression: {compression}") diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 2f0c50c59..651b40bbe 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -1,7 +1,7 @@ -from typing import Union +from typing import List, NewType, Optional, Tuple, Union -__all__ = ["Data"] +__all__ = ["Data", "Origin", "ExtensionHeader", "ExtensionParameter", "Subprotocol"] Data = Union[str, bytes] @@ -17,3 +17,12 @@ Data.__doc__ = Data__doc__ # type: ignore except AttributeError: # pragma: no cover pass + + +Origin = NewType("Origin", str) + +ExtensionParameter = Tuple[str, Optional[str]] + +ExtensionHeader = Tuple[str, List[ExtensionParameter]] + +Subprotocol = NewType("Subprotocol", str) diff --git a/tests/test_headers.py b/tests/test_headers.py index f03dc83cf..51a0f33af 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -41,7 +41,7 @@ def test_parse_upgrade_invalid_header(self): with self.assertRaises(InvalidHeaderFormat): parse_upgrade(header) - def test_parse_extension_list(self): + def test_parse_extension(self): for header, parsed in [ # Synthetic examples ("foo", [("foo", [])]), @@ -78,12 +78,12 @@ def test_parse_extension_list(self): ), ]: with self.subTest(header=header): - self.assertEqual(parse_extension_list(header), parsed) - # Also ensure that build_extension_list round-trips cleanly. - unparsed = build_extension_list(parsed) - self.assertEqual(parse_extension_list(unparsed), parsed) + self.assertEqual(parse_extension(header), parsed) + # Also ensure that build_extension round-trips cleanly. + unparsed = build_extension(parsed) + self.assertEqual(parse_extension(unparsed), parsed) - def test_parse_extension_list_invalid_header(self): + def test_parse_extension_invalid_header(self): for header in [ # Truncated examples "", @@ -99,9 +99,9 @@ def test_parse_extension_list_invalid_header(self): ]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): - parse_extension_list(header) + parse_extension(header) - def test_parse_subprotocol_list(self): + def test_parse_subprotocol(self): for header, parsed in [ # Synthetic examples ("foo", ["foo"]), @@ -110,12 +110,12 @@ def test_parse_subprotocol_list(self): (",\t, , ,foo ,, bar,baz,,", ["foo", "bar", "baz"]), ]: with self.subTest(header=header): - self.assertEqual(parse_subprotocol_list(header), parsed) - # Also ensure that build_subprotocol_list round-trips cleanly. - unparsed = build_subprotocol_list(parsed) - self.assertEqual(parse_subprotocol_list(unparsed), parsed) + self.assertEqual(parse_subprotocol(header), parsed) + # Also ensure that build_subprotocol round-trips cleanly. + unparsed = build_subprotocol(parsed) + self.assertEqual(parse_subprotocol(unparsed), parsed) - def test_parse_subprotocol_list_invalid_header(self): + def test_parse_subprotocol_invalid_header(self): for header in [ # Truncated examples "", @@ -125,7 +125,7 @@ def test_parse_subprotocol_list_invalid_header(self): ]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): - parse_subprotocol_list(header) + parse_subprotocol(header) def test_build_basic_auth(self): # Test vector from RFC 7617. From 15018fdc1413aa7b720009a1b2ecddb56509fdf8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 19:21:57 +0100 Subject: [PATCH 0551/1539] Normalize return value declaration on __init__. -> None is optional, but I included in on most constructors, so I'm adding the missing ones. --- src/websockets/client.py | 2 +- src/websockets/extensions/permessage_deflate.py | 6 +++--- src/websockets/server.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 9cefaedb8..d9ad668ed 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -384,7 +384,7 @@ def __init__( subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, **kwds: Any, - ): + ) -> None: if loop is None: loop = asyncio.get_event_loop() diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 145cb2bbe..2de27260f 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -45,7 +45,7 @@ def __init__( remote_max_window_bits: int, local_max_window_bits: int, compress_settings: Optional[Dict[Any, Any]] = None, - ): + ) -> None: """ Configure the Per-Message Deflate extension. @@ -285,7 +285,7 @@ def __init__( server_max_window_bits: Optional[int] = None, client_max_window_bits: Optional[Union[int, bool]] = None, compress_settings: Optional[Dict[Any, Any]] = None, - ): + ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -458,7 +458,7 @@ def __init__( server_max_window_bits: Optional[int] = None, client_max_window_bits: Optional[int] = None, compress_settings: Optional[Dict[Any, Any]] = None, - ): + ) -> None: """ Configure the Per-Message Deflate extension factory. diff --git a/src/websockets/server.py b/src/websockets/server.py index b20f4b80d..d99308156 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -580,7 +580,7 @@ class WebSocketServer: """ - def __init__(self, loop: asyncio.AbstractEventLoop): + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop @@ -827,7 +827,7 @@ def __init__( Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, **kwds: Any, - ): + ) -> None: # Backwards-compatibility: close_timeout used to be called timeout. # If both are specified, timeout is ignored. if close_timeout is None: From 37ef1172ff09073a083f18e6373e7ec9f5e03063 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 20:52:21 +0100 Subject: [PATCH 0552/1539] Improve display of type hints in docs. This requires further work to add :param: declarations in docstrings. --- docs/conf.py | 1 + docs/requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index 504656afc..f4e81db35 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,6 +29,7 @@ 'sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.viewcode', + 'sphinx_autodoc_typehints', 'sphinxcontrib_trio', ] diff --git a/docs/requirements.txt b/docs/requirements.txt index 954e8c755..0eaf94fbe 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ sphinx +sphinx-autodoc-typehints sphinxcontrib-spelling sphinxcontrib-trio From b1b3917a5e5bce88bc9f87d211928a2eabd437ad Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 21:01:23 +0100 Subject: [PATCH 0553/1539] Removed circular dependency in exceptions and uri. --- src/websockets/client.py | 7 ++++--- src/websockets/exceptions.py | 12 +++--------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index d9ad668ed..8e2bcf36e 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -291,7 +291,7 @@ async def handshake( if status_code in (301, 302, 303, 307, 308): if "Location" not in response_headers: raise InvalidMessage("Redirect response missing Location") - raise RedirectHandshake(parse_uri(response_headers["Location"])) + raise RedirectHandshake(response_headers["Location"]) elif status_code != 101: raise InvalidStatusCode(status_code) @@ -518,9 +518,10 @@ async def __await_impl__(self) -> WebSocketClientProtocol: await protocol.wait_closed() raise except RedirectHandshake as e: - if self._wsuri.secure and not e.wsuri.secure: + wsuri = parse_uri(e.uri) + if self._wsuri.secure and not wsuri.secure: raise InvalidHandshake("Redirect dropped TLS") - self._wsuri = e.wsuri + self._wsuri = wsuri continue # redirection chain continues else: raise InvalidHandshake("Maximum redirects exceeded") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 436c594a9..73eb8bb79 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -1,15 +1,9 @@ import http -from typing import TYPE_CHECKING, Any, Optional +from typing import Optional from .http import Headers, HeadersLike -if TYPE_CHECKING: # pragma: no cover - from .uri import WebSocketURI -else: - WebSocketURI = Any - - __all__ = [ "AbortHandshake", "ConnectionClosed", @@ -61,8 +55,8 @@ class RedirectHandshake(InvalidHandshake): """ - def __init__(self, wsuri: WebSocketURI) -> None: - self.wsuri = wsuri + def __init__(self, uri: str) -> None: + self.uri = uri class InvalidMessage(InvalidHandshake): From 6e1766d61983f0069365de3922243b399876b794 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 21:14:43 +0100 Subject: [PATCH 0554/1539] Handle a circular import less inelegantly. This preserves the correct type annotation even when TYPE_CHECKING is False, which seems better (e.g. for doc generation). --- src/websockets/framing.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 1409c7d69..5b694fd40 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -12,26 +12,12 @@ import io import random import struct -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - NamedTuple, - Optional, - Sequence, - Tuple, -) +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple from .exceptions import PayloadTooBig, WebSocketProtocolError from .typing import Data -if TYPE_CHECKING: # pragma: no cover - from .extensions.base import Extension -else: - Extension = Any - try: from .speedups import apply_mask except ImportError: # pragma: no cover @@ -112,7 +98,7 @@ async def read( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence[Extension]] = None, + extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, ) -> "Frame": """ Read a WebSocket frame and return a :class:`Frame` object. @@ -184,7 +170,7 @@ def write( writer: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence[Extension]] = None, + extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, ) -> None: """ Write a WebSocket frame. @@ -373,3 +359,7 @@ def check_close(code: int) -> None: """ if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): raise WebSocketProtocolError("Invalid status code") + + +# at the bottom to allow circular import, because Extension depends on Frame +import websockets.extensions.base # isort:skip # noqa From dcca6efd750bd42062fe9cb3ecb822a7cae75d19 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 21:40:05 +0100 Subject: [PATCH 0555/1539] Simplify mock. The code that was using this no longer exists. --- tests/test_protocol.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 154948e43..1f35e65a2 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -70,9 +70,6 @@ def write_eof(self): self.loop.call_soon(self.close) self._eof = True - def is_closing(self): - return self._closing - def close(self): # Simulate how actual transports drop the connection. if not self._closing: From e4cec94ceacbb039ced993dd0d4fd00c4761cd90 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 10 Feb 2019 14:14:33 +0100 Subject: [PATCH 0556/1539] Document bugfix releases in the changelog. Mostly so users don't wonder why these releases exist. Fix #572. --- docs/changelog.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 1c4b1bc96..c59e569d1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -149,6 +149,12 @@ Also: * Added compatibility with Python 3.7. +5.0.1 +..... + +* Fixed a regression in the 5.0 release that broke some invocations of + :func:`~server.serve()` and :func:`~client.connect()`. + 5.0 ... @@ -208,6 +214,11 @@ Also: * Prevented processing of incoming frames after failing the connection. +4.0.1 +..... + +* Fixed issues with the packaging of the 4.0 release. + 4.0 ... From ca86a14837d4adbcd4256688f0b0385b65fc304e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 10 Feb 2019 14:30:03 +0100 Subject: [PATCH 0557/1539] Improve changelog. Reduce the "wall of red" by downgrading less important backwards-incompatible changes from warning to note. --- docs/changelog.rst | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c59e569d1..169cb829f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -12,7 +12,7 @@ Changelog **Version 8.0 drops compatibility with Python 3.4 and 3.5.** -.. warning:: +.. note:: **Version 8.0 adds the reason phrase to the return type of the low-level API** :func:`~http.read_response` **.** @@ -65,7 +65,7 @@ Also: closed = asyncio.ensure_future(websocket.wait_closed()) closed.add_done_callback(lambda task: task.cancel()) -.. warning:: +.. note:: **Version 7.0 changes how a** :meth:`~protocol.WebSocketCommonProtocol.ping` **that hasn't received a pong yet behaves when the connection is closed.** @@ -75,7 +75,7 @@ Also: :exc:`~asyncio.CancelledError`. Now ``await ping`` raises :exc:`~exceptions.ConnectionClosed` like other public APIs. -.. warning:: +.. note:: **Version 7.0 raises a** :exc:`RuntimeError` **exception if two coroutines call** :meth:`~protocol.WebSocketCommonProtocol.recv` **concurrently.** @@ -93,7 +93,7 @@ Also: :func:`~server.serve()` and :class:`~server.WebSocketServerProtocol` to customize :meth:`~server.WebSocketServerProtocol.process_request` and :meth:`~server.WebSocketServerProtocol.select_subprotocol` without - subclassing :class:`~server.WebSocketServerProtocol` + subclassing :class:`~server.WebSocketServerProtocol`. * Added support for sending fragmented messages. @@ -142,8 +142,10 @@ Also: * Functions defined in the :mod:`~http` module now return HTTP headers as :class:`~http.Headers` instead of lists of ``(name, value)`` pairs. - Note that :class:`~http.Headers` and :class:`~http.client.HTTPMessage` - provide similar APIs. + Since :class:`~http.Headers` and :class:`~http.client.HTTPMessage` provide + similar APIs, this change won't affect most of the code dealing with HTTP + headers. + Also: @@ -164,9 +166,11 @@ Also: websockets 4.0 was vulnerable to denial of service by memory exhaustion because it didn't enforce ``max_size`` when decompressing compressed - messages (CVE-2018-1000518). + messages (`CVE-2018-1000518`_). -.. warning:: + .. _CVE-2018-1000518: https://nvd.nist.gov/vuln/detail/CVE-2018-1000518 + +.. note:: **Version 5.0 adds a** ``user_info`` **field to the return value of** :func:`~uri.parse_uri` **and** :class:`~uri.WebSocketURI` **.** @@ -188,7 +192,8 @@ Also: * :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. -* Added :meth:`~protocol.WebSocketCommonProtocol.closed` property. +* Added the :attr:`~protocol.WebSocketCommonProtocol.closed` property to + protocols. * If a :meth:`~protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, it's canceled when the connection is closed. @@ -235,6 +240,10 @@ Also: .. warning:: + **Version 4.0 drops compatibility with Python 3.3.** + +.. note:: + **Version 4.0 removes the** ``state_name`` **attribute of protocols.** Use ``protocol.state.name`` instead of ``protocol.state_name``. @@ -246,7 +255,8 @@ Also: * Added :func:`~server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~server.WebSocketServer.sockets` attribute. +* Added the :attr:`~server.WebSocketServer.sockets` attribute to the return + value of :func:`~server.serve`. * Reorganized and extended documentation. @@ -278,7 +288,7 @@ Also: * Rewrote HTTP handling for simplicity and performance. -* Added an optional C extension to speed up low level operations. +* Added an optional C extension to speed up low-level operations. * An invalid response status code during :func:`~client.connect()` now raises :class:`~exceptions.InvalidStatusCode` with a ``code`` attribute. From 569042a682eaa0b5e5a953bea9f3d22832b8dea1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 10 Feb 2019 18:52:14 +0100 Subject: [PATCH 0558/1539] Simplify implementation of __iter__. --- src/websockets/client.py | 18 ++++++++++-------- src/websockets/server.py | 18 ++++++++++-------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 8e2bcf36e..6adb5ca23 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -484,9 +484,7 @@ async def _creating_connection( protocol = cast(WebSocketClientProtocol, protocol) return transport, protocol - @asyncio.coroutine - def __iter__(self) -> Generator[Any, None, WebSocketClientProtocol]: - return (yield from self.__await__()) + # async with connect(...) async def __aenter__(self) -> WebSocketClientProtocol: return await self @@ -499,6 +497,12 @@ async def __aexit__( ) -> None: await self.ws_client.close() + # await connect(...) + + def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): transport, protocol = await self._creating_connection() @@ -529,11 +533,9 @@ async def __await_impl__(self) -> WebSocketClientProtocol: self.ws_client = protocol return protocol - def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: - # __await__() must return a type that I don't know how to obtain except - # by calling __await__() on the return value of an async function. - # I'm not finding a better way to take advantage of PEP 492. - return self.__await_impl__().__await__() + # yield from connect(...) + + __iter__ = __await__ connect = Connect diff --git a/src/websockets/server.py b/src/websockets/server.py index d99308156..7137148a0 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -891,9 +891,7 @@ def __init__( self._creating_server = creating_server self.ws_server = ws_server - @asyncio.coroutine - def __iter__(self) -> Generator[Any, None, WebSocketServer]: - return (yield from self.__await__()) + # async with serve(...) async def __aenter__(self) -> WebSocketServer: return await self @@ -907,16 +905,20 @@ async def __aexit__( self.ws_server.close() await self.ws_server.wait_closed() + # await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + async def __await_impl__(self) -> WebSocketServer: server = await self._creating_server self.ws_server.wrap(server) return self.ws_server - def __await__(self) -> Generator[Any, None, WebSocketServer]: - # __await__() must return a type that I don't know how to obtain except - # by calling __await__() on the return value of an async function. - # I'm not finding a better way to take advantage of PEP 492. - return self.__await_impl__().__await__() + # yield from serve(...) + + __iter__ = __await__ serve = Serve From 217477fa1119c26cccf30bbfb07573444839bb27 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Feb 2019 08:58:02 +0100 Subject: [PATCH 0559/1539] Restore support for unbounded incoming message queues. The API change is backwards incompatible. However, None is a better value than 0 to mean "no limit" and users already hit the backwards incompatibility when they upgraded to 7.0, which broke the feature. For this reason I didn't include a backwards compatibility shim. Thanks @petr-fedorov for the report. Fix #576. --- docs/changelog.rst | 8 +++++++- src/websockets/protocol.py | 32 +++++++++++++++++--------------- tests/test_protocol.py | 15 +++++++++++++++ 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 169cb829f..30f542b54 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -12,12 +12,18 @@ Changelog **Version 8.0 drops compatibility with Python 3.4 and 3.5.** +.. note:: + + **Version 8.0 changes the behavior of the ``max_queue`` parameter.** + + If you were setting ``max_queue=0`` to make the queue of incoming messages + unbounded, change it to ``max_queue=None``. + .. note:: **Version 8.0 adds the reason phrase to the return type of the low-level API** :func:`~http.read_response` **.** - Also: * :meth:`~protocol.WebSocketCommonProtocol.send`, diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index f4dbbb279..a663d2ab2 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -116,15 +116,16 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): raise :exc:`~websockets.exceptions.ConnectionClosed` and the connection will be closed with status code 1009. - The ``max_queue`` parameter sets the maximum length of the queue that holds - incoming messages. The default value is 32. 0 disables the limit. Messages - are added to an in-memory queue when they're received; then :meth:`recv()` - pops from that queue. In order to prevent excessive memory consumption when - messages are received faster than they can be processed, the queue must be - bounded. If the queue fills up, the protocol stops processing incoming data - until :meth:`recv()` is called. In this situation, various receive buffers - (at least in ``asyncio`` and in the OS) will fill up, then the TCP receive - window will shrink, slowing down transmission to avoid packet loss. + The ``max_queue`` parameter sets the maximum length of the queue that + holds incoming messages. The default value is ``32``. ``None`` disables + the limit. Messages are added to an in-memory queue when they're received; + then :meth:`recv()` pops from that queue. In order to prevent excessive + memory consumption when messages are received faster than they can be + processed, the queue must be bounded. If the queue fills up, the protocol + stops processing incoming data until :meth:`recv()` is called. In this + situation, various receive buffers (at least in ``asyncio`` and in the OS) + will fill up, then the TCP receive window will shrink, slowing down + transmission to avoid packet loss. Since Python can use up to 4 bytes of memory to represent a single character, each websocket connection may use up to ``4 * max_size * @@ -709,12 +710,13 @@ async def transfer_data(self) -> None: break # Wait until there's room in the queue (if necessary). - while len(self.messages) >= self.max_queue: - self._put_message_waiter = self.loop.create_future() - try: - await self._put_message_waiter - finally: - self._put_message_waiter = None + if self.max_queue is not None: + while len(self.messages) >= self.max_queue: + self._put_message_waiter = self.loop.create_future() + try: + await self._put_message_waiter + finally: + self._put_message_waiter = None # Put the message in the queue. self.messages.append(message) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 1f35e65a2..0113e4a71 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -474,6 +474,21 @@ def test_recv_queue_full(self): self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(list(self.protocol.messages), []) + def test_recv_queue_no_limit(self): + self.protocol.max_queue = None + + for _ in range(100): + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.run_loop_once() + + # Incoming message queue can contain at least 100 messages. + self.assertEqual(list(self.protocol.messages), ["café"] * 100) + + for _ in range(100): + self.loop.run_until_complete(self.protocol.recv()) + + self.assertEqual(list(self.protocol.messages), []) + def test_recv_other_error(self): async def read_message(): raise Exception("BOOM") From 68e7c04a068827d61f18adfbbb979d80e19e0221 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Feb 2019 10:41:44 +0100 Subject: [PATCH 0560/1539] Add warnings for backwards compatibility shims. This will make it easier to remove them eventually. --- src/websockets/client.py | 23 ++++++++---- src/websockets/protocol.py | 7 +++- src/websockets/server.py | 12 +++++- tests/test_client_server.py | 73 ++++++++++++++++++++++++++++++------- tests/test_protocol.py | 14 +++++++ 5 files changed, 105 insertions(+), 24 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 6adb5ca23..3d057a2e3 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -6,6 +6,7 @@ import asyncio import collections.abc import logging +import warnings from types import TracebackType from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast @@ -376,8 +377,8 @@ def __init__( write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, legacy_recv: bool = False, - klass: Type[WebSocketClientProtocol] = WebSocketClientProtocol, - timeout: float = 10, + klass: Optional[Type[WebSocketClientProtocol]] = None, + timeout: Optional[float] = None, compression: Optional[str] = "deflate", origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, @@ -385,19 +386,27 @@ def __init__( extra_headers: Optional[HeadersLike] = None, **kwds: Any, ) -> None: - if loop is None: - loop = asyncio.get_event_loop() - - # Backwards-compatibility: close_timeout used to be called timeout. + # Backwards compatibility: close_timeout used to be called timeout. + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) # If both are specified, timeout is ignored. if close_timeout is None: close_timeout = timeout - # Backwards-compatibility: create_protocol used to be called klass. + # Backwards compatibility: create_protocol used to be called klass. + if klass is None: + klass = WebSocketClientProtocol + else: + warnings.warn("rename klass to create_protocol", DeprecationWarning) # If both are specified, klass is ignored. if create_protocol is None: create_protocol = klass + if loop is None: + loop = asyncio.get_event_loop() + self._wsuri = parse_uri(uri) if self._wsuri.secure: kwds.setdefault("ssl", True) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index a663d2ab2..b0fff8fad 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -14,6 +14,7 @@ import logging import random import struct +import warnings from typing import ( Any, AsyncIterable, @@ -182,9 +183,13 @@ def __init__( write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, legacy_recv: bool = False, - timeout: float = 10, + timeout: Optional[float] = None, ) -> None: # Backwards-compatibility: close_timeout used to be called timeout. + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) # If both are specified, timeout is ignored. if close_timeout is None: close_timeout = timeout diff --git a/src/websockets/server.py b/src/websockets/server.py index 7137148a0..fca6c2caf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -813,8 +813,8 @@ def __init__( write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, legacy_recv: bool = False, - klass: Type[WebSocketServerProtocol] = WebSocketServerProtocol, - timeout: float = 10, + klass: Optional[Type[WebSocketServerProtocol]] = None, + timeout: Optional[float] = None, compression: Optional[str] = "deflate", origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -829,11 +829,19 @@ def __init__( **kwds: Any, ) -> None: # Backwards-compatibility: close_timeout used to be called timeout. + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) # If both are specified, timeout is ignored. if close_timeout is None: close_timeout = timeout # Backwards-compatibility: create_protocol used to be called klass. + if klass is None: + klass = WebSocketServerProtocol + else: + warnings.warn("rename klass to create_protocol", DeprecationWarning) # If both are specified, klass is ignored. if create_protocol is None: create_protocol = klass diff --git a/tests/test_client_server.py b/tests/test_client_server.py index fc88b3139..83b1e0fd9 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -249,13 +249,23 @@ def run_loop_once(self): def server_context(self): return None - def start_server(self, **kwds): + def start_server(self, expected_warning=None, **kwds): # Disable compression by default in tests. kwds.setdefault("compression", None) # Disable pings by default in tests. kwds.setdefault("ping_interval", None) - start_server = serve(handler, "localhost", 0, **kwds) - self.server = self.loop.run_until_complete(start_server) + + with warnings.catch_warnings(record=True) as recorded_warnings: + start_server = serve(handler, "localhost", 0, **kwds) + self.server = self.loop.run_until_complete(start_server) + + if expected_warning is None: + self.assertEqual(len(recorded_warnings), 0) + else: + self.assertEqual(len(recorded_warnings), 1) + actual_warning = recorded_warnings[0].message + self.assertEqual(str(actual_warning), expected_warning) + self.assertEqual(type(actual_warning), DeprecationWarning) def start_redirecting_server( self, status, include_location=True, force_insecure=False @@ -278,7 +288,9 @@ def _process_request(path, headers): ) self.redirecting_server = self.loop.run_until_complete(start_server) - def start_client(self, resource_name="/", user_info=None, **kwds): + def start_client( + self, resource_name="/", user_info=None, expected_warning=None, **kwds + ): # Disable compression by default in tests. kwds.setdefault("compression", None) # Disable pings by default in tests. @@ -286,8 +298,18 @@ def start_client(self, resource_name="/", user_info=None, **kwds): secure = kwds.get("ssl") is not None server = self.redirecting_server if self.redirecting_server else self.server server_uri = get_server_uri(server, secure, resource_name, user_info) - start_client = connect(server_uri, **kwds) - self.client = self.loop.run_until_complete(start_client) + + with warnings.catch_warnings(record=True) as recorded_warnings: + start_client = connect(server_uri, **kwds) + self.client = self.loop.run_until_complete(start_client) + + if expected_warning is None: + self.assertEqual(len(recorded_warnings), 0) + else: + self.assertEqual(len(recorded_warnings), 1) + actual_warning = recorded_warnings[0].message + self.assertEqual(str(actual_warning), expected_warning) + self.assertEqual(type(actual_warning), DeprecationWarning) def stop_client(self): try: @@ -638,12 +660,17 @@ def test_server_create_protocol(self): def test_server_create_protocol_function(self): self.assert_client_raises_code(401) - @with_server(klass=UnauthorizedServerProtocol) + @with_server( + klass=UnauthorizedServerProtocol, + expected_warning="rename klass to create_protocol", + ) def test_server_klass_backwards_compatibility(self): self.assert_client_raises_code(401) @with_server( - create_protocol=ForbiddenServerProtocol, klass=UnauthorizedServerProtocol + create_protocol=ForbiddenServerProtocol, + klass=UnauthorizedServerProtocol, + expected_warning="rename klass to create_protocol", ) def test_server_create_protocol_over_klass(self): self.assert_client_raises_code(403) @@ -662,12 +689,21 @@ def test_client_create_protocol_function(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client("/path", klass=FooClientProtocol) + @with_client( + "/path", + klass=FooClientProtocol, + expected_warning="rename klass to create_protocol", + ) def test_client_klass(self): self.assertIsInstance(self.client, FooClientProtocol) @with_server() - @with_client("/path", create_protocol=BarClientProtocol, klass=FooClientProtocol) + @with_client( + "/path", + create_protocol=BarClientProtocol, + klass=FooClientProtocol, + expected_warning="rename klass to create_protocol", + ) def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) @@ -677,13 +713,15 @@ def test_server_close_timeout(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 7) - @with_server(timeout=6) + @with_server(timeout=6, expected_warning="rename timeout to close_timeout") @with_client("/close_timeout") def test_server_timeout_backwards_compatibility(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 6) - @with_server(close_timeout=7, timeout=6) + @with_server( + close_timeout=7, timeout=6, expected_warning="rename timeout to close_timeout" + ) @with_client("/close_timeout") def test_server_close_timeout_over_timeout(self): close_timeout = self.loop.run_until_complete(self.client.recv()) @@ -695,12 +733,19 @@ def test_client_close_timeout(self): self.assertEqual(self.client.close_timeout, 7) @with_server() - @with_client("/close_timeout", timeout=6) + @with_client( + "/close_timeout", timeout=6, expected_warning="rename timeout to close_timeout" + ) def test_client_timeout_backwards_compatibility(self): self.assertEqual(self.client.close_timeout, 6) @with_server() - @with_client("/close_timeout", close_timeout=7, timeout=6) + @with_client( + "/close_timeout", + close_timeout=7, + timeout=6, + expected_warning="rename timeout to close_timeout", + ) def test_client_close_timeout_over_timeout(self): self.assertEqual(self.client.close_timeout, 7) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 0113e4a71..976cc7e9b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -5,6 +5,7 @@ import time import unittest import unittest.mock +import warnings from websockets.exceptions import ConnectionClosed, InvalidState from websockets.framing import * @@ -321,6 +322,19 @@ def assertCompletesWithin(self, min_time, max_time): self.assertGreaterEqual(dt, min_time, f"Too fast: {dt} < {min_time}") self.assertLess(dt, max_time, f"Too slow: {dt} >= {max_time}") + # Test constructor. + + def test_timeout_backwards_compatibility(self): + with warnings.catch_warnings(record=True) as recorded_warnings: + protocol = WebSocketCommonProtocol(timeout=5) + + self.assertEqual(protocol.close_timeout, 5) + + self.assertEqual(len(recorded_warnings), 1) + warning = recorded_warnings[0].message + self.assertEqual(str(warning), "rename timeout to close_timeout") + self.assertEqual(type(warning), DeprecationWarning) + # Test public attributes. def test_local_address(self): From 17b3f47549b6f752a1be07fa1ba3037cb59c7d56 Mon Sep 17 00:00:00 2001 From: Pablo Marti Date: Tue, 9 Apr 2019 13:12:18 +0200 Subject: [PATCH 0561/1539] remove extra backtick --- docs/intro.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/intro.rst b/docs/intro.rst index 389896ef4..118167b73 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -71,7 +71,7 @@ This client needs a context because the server uses a self-signed certificate. A client connecting to a secure WebSocket server with a valid certificate (i.e. signed by a CA that your Python installation trusts) can simply pass -``ssl=True`` to :func:`connect`` instead of building a context. +``ssl=True`` to :func:`connect` instead of building a context. Browser-based example --------------------- From acb7a939ae2db4fb977cb92d8e9101c3ee82c0d3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 8 May 2019 13:59:58 +0200 Subject: [PATCH 0562/1539] Small cleanup. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 30dbfd9c1..d389623a7 100644 --- a/Makefile +++ b/Makefile @@ -19,4 +19,4 @@ coverage: clean: find . -name '*.pyc' -o -name '*.so' -delete find . -name __pycache__ -delete - rm -rf .coverage build compliance/reports dist docs/_build htmlcov MANIFEST README src/websockets.egg-info + rm -rf .coverage build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 52872a5485651d606900cddd720b4a4aa658d690 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 8 May 2019 14:03:50 +0200 Subject: [PATCH 0563/1539] Add changelog entry for 7d72dabd. --- docs/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 30f542b54..9618c1d4b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -34,6 +34,8 @@ Also: * :func:`~client.connect()` handles redirects from the server during the handshake. +* Enabled readline in the interactive client. + * Added type hints (:pep:`484`). * Added documentation for extensions. From b5690affb4698d18574221ae68024b2fe995a583 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Feb 2019 22:09:34 +0100 Subject: [PATCH 0564/1539] Add ConnectionClosed subclass for normal closure. Thanks @cjerdonek for the suggestion. Fix #285. --- docs/changelog.rst | 5 ++++ src/websockets/exceptions.py | 29 ++++++++++++++++++++- src/websockets/protocol.py | 49 ++++++++++++++++++++---------------- src/websockets/server.py | 4 +-- tests/test_exceptions.py | 4 +-- 5 files changed, 64 insertions(+), 27 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9618c1d4b..ee407d13e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -31,6 +31,11 @@ Also: :meth:`~protocol.WebSocketCommonProtocol.pong` support bytes-like types :class:`bytearray` and :class:`memoryview` in addition to :class:`bytes`. +* Added :exc:`~exceptions.ConnectionClosedOK` and + :exc:`~exceptions.ConnectionClosedError` subclasses of + :exc:`~exceptions.ConnectionClosed` to tell apart normal connection + termination from errors. + * :func:`~client.connect()` handles redirects from the server during the handshake. diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 73eb8bb79..7fdc97185 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -7,6 +7,8 @@ __all__ = [ "AbortHandshake", "ConnectionClosed", + "ConnectionClosedError", + "ConnectionClosedOK", "DuplicateParameter", "InvalidHandshake", "InvalidHeader", @@ -203,7 +205,6 @@ def format_close(code: int, reason: str) -> str: """ Display a human-readable version of the close code and reason. - """ if 3000 <= code < 4000: explanation = "registered" @@ -238,6 +239,32 @@ def __init__(self, code: int, reason: str) -> None: super().__init__(message) +class ConnectionClosedError(ConnectionClosed): + """ + Like :exc:`ConnectionClosed`, when the connection terminated with an error. + + This means the close code is different from 1000 (OK) and 1001 (going away). + + """ + + def __init__(self, code: int, reason: str) -> None: + assert code != 1000 and code != 1001 + super().__init__(code, reason) + + +class ConnectionClosedOK(ConnectionClosed): + """ + Like :exc:`ConnectionClosed`, when the connection terminated properly. + + This means the close code is 1000 (OK) or 1001 (going away). + + """ + + def __init__(self, code: int, reason: str) -> None: + assert code == 1000 or code == 1001 + super().__init__(code, reason) + + class InvalidURI(Exception): """ Exception raised when an URI isn't a valid websocket URI. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index b0fff8fad..c07aef99f 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -30,6 +30,8 @@ from .exceptions import ( ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, InvalidState, PayloadTooBig, WebSocketProtocolError, @@ -78,8 +80,8 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The iterator yields incoming messages. It exits normally when the connection is closed with the close code 1000 (OK) or 1001 (going away). - It raises a :exc:`~websockets.exceptions.ConnectionClosed` exception when - the connection is closed with any other status code. + It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception + when the connection is closed with any other status code. The ``host``, ``port`` and ``secure`` parameters are simply stored as attributes for handlers that need them. @@ -114,8 +116,8 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1 MiB. ``None`` disables the limit. If a message larger than the maximum size is received, :meth:`recv()` will - raise :exc:`~websockets.exceptions.ConnectionClosed` and the connection - will be closed with status code 1009. + raise :exc:`~websockets.exceptions.ConnectionClosedError` and the + connection will be closed with status code 1009. The ``max_queue`` parameter sets the maximum length of the queue that holds incoming messages. The default value is ``32``. ``None`` disables @@ -382,11 +384,8 @@ async def __aiter__(self) -> AsyncIterator[Data]: try: while True: yield await self.recv() - except ConnectionClosed as exc: - if exc.code == 1000 or exc.code == 1001: - return - else: - raise + except ConnectionClosedOK: + return async def recv(self) -> Data: """ @@ -396,8 +395,11 @@ async def recv(self) -> Data: binary frame. When the end of the message stream is reached, :meth:`recv` raises - :exc:`~websockets.exceptions.ConnectionClosed`. This can happen after - a normal connection closure, a protocol error or a network failure. + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError`after a protocol + error or a network failure. .. versionchanged:: 3.0 @@ -659,6 +661,16 @@ async def pong(self, data: bytes = b"") -> None: # Private methods - no guarantees. + def connection_closed_exc(self) -> ConnectionClosed: + exception: ConnectionClosed + if self.close_code == 1000 or self.close_code == 1001: + exception = ConnectionClosedOK(self.close_code, self.close_reason) + else: + exception = ConnectionClosedError(self.close_code, self.close_reason) + # Chain to the exception that terminated data transfer, if any. + exception.__cause__ = self.transfer_data_exc + return exception + async def ensure_open(self) -> None: """ Check that the WebSocket connection is open. @@ -673,16 +685,12 @@ async def ensure_open(self) -> None: # from OPEN to CLOSED. if self.transfer_data_task.done(): await asyncio.shield(self.close_connection_task) - raise ConnectionClosed( - self.close_code, self.close_reason - ) from self.transfer_data_exc + raise self.connection_closed_exc() else: return if self.state is State.CLOSED: - raise ConnectionClosed( - self.close_code, self.close_reason - ) from self.transfer_data_exc + raise self.connection_closed_exc() if self.state is State.CLOSING: # If we started the closing handshake, wait for its completion to @@ -691,9 +699,7 @@ async def ensure_open(self) -> None: # CLOSING state also occurs when failing the connection. In that # case self.close_connection_task will complete even faster. await asyncio.shield(self.close_connection_task) - raise ConnectionClosed( - self.close_code, self.close_reason - ) from self.transfer_data_exc + raise self.connection_closed_exc() # Control may only reach this point in buggy third-party subclasses. assert self.state is State.CONNECTING @@ -1163,8 +1169,7 @@ def abort_keepalive_pings(self) -> None: """ assert self.state is State.CLOSED - exc = ConnectionClosed(self.close_code, self.close_reason) - exc.__cause__ = self.transfer_data_exc # emulate raise ... from ... + exc = self.connection_closed_exc() for ping in self.pings.values(): ping.set_exception(exc) diff --git a/src/websockets/server.py b/src/websockets/server.py index fca6c2caf..e202ea25b 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -782,8 +782,8 @@ class Serve: When a server is closed with :meth:`~WebSocketServer.close`, it closes all connections with close code 1001 (going away). WebSocket handlers — which are running the coroutine passed in the ``ws_handler`` — will receive a - :exc:`~websockets.exceptions.ConnectionClosed` exception on their current - or next interaction with the WebSocket connection. + :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their + current or next interaction with the WebSocket connection. Since there's no useful way to propagate exceptions triggered in handlers, they're sent to the ``'websockets.server'`` logger instead. Debugging is diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3ccdadb82..6dfbeb7e6 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -90,7 +90,7 @@ def test_str(self): "(OK), no reason", ), ( - ConnectionClosed(1001, 'bye'), + ConnectionClosedOK(1001, 'bye'), "WebSocket connection is closed: code = 1001 " "(going away), reason = bye", ), @@ -100,7 +100,7 @@ def test_str(self): "(connection closed abnormally [internal]), no reason" ), ( - ConnectionClosed(1016, None), + ConnectionClosedError(1016, None), "WebSocket connection is closed: code = 1016 " "(unknown), no reason" ), From f3e40cbfc56d8770e57b65b3cd6b35377e6028c8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 8 May 2019 14:26:35 +0200 Subject: [PATCH 0565/1539] Normalize quotes in a # fmt: off section. --- tests/test_exceptions.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 6dfbeb7e6..4b9830345 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -13,7 +13,7 @@ def test_str(self): "Invalid request", ), ( - AbortHandshake(200, Headers(), b'OK\n'), + AbortHandshake(200, Headers(), b"OK\n"), "HTTP 200, 0 headers, 3 bytes", ), ( @@ -21,44 +21,44 @@ def test_str(self): "Malformed HTTP message", ), ( - InvalidHeader('Name'), + InvalidHeader("Name"), "Missing Name header", ), ( - InvalidHeader('Name', None), + InvalidHeader("Name", None), "Missing Name header", ), ( - InvalidHeader('Name', ''), + InvalidHeader("Name", ""), "Empty Name header", ), ( - InvalidHeader('Name', 'Value'), + InvalidHeader("Name", "Value"), "Invalid Name header: Value", ), ( InvalidHeaderFormat( - 'Sec-WebSocket-Protocol', "expected token", 'a=|', 3 + "Sec-WebSocket-Protocol", "expected token", "a=|", 3 ), "Invalid Sec-WebSocket-Protocol header: " "expected token at 3 in a=|", ), ( - InvalidHeaderValue('Sec-WebSocket-Version', '42'), + InvalidHeaderValue("Sec-WebSocket-Version", "42"), "Invalid Sec-WebSocket-Version header: 42", ), ( - InvalidUpgrade('Upgrade'), + InvalidUpgrade("Upgrade"), "Missing Upgrade header", ), ( - InvalidUpgrade('Connection', 'websocket'), + InvalidUpgrade("Connection", "websocket"), "Invalid Connection header: websocket", ), ( - InvalidOrigin('http://bad.origin'), - 'Invalid Origin header: http://bad.origin', + InvalidOrigin("http://bad.origin"), + "Invalid Origin header: http://bad.origin", ), ( InvalidStatusCode(403), @@ -69,15 +69,15 @@ def test_str(self): "Unsupported subprotocol: spam", ), ( - InvalidParameterName('|'), + InvalidParameterName("|"), "Invalid parameter name: |", ), ( - InvalidParameterValue('a', '|'), + InvalidParameterValue("a", "|"), "Invalid value for parameter a: |", ), ( - DuplicateParameter('a'), + DuplicateParameter("a"), "Duplicate parameter: a", ), ( @@ -85,12 +85,12 @@ def test_str(self): "WebSocket connection isn't established yet", ), ( - ConnectionClosed(1000, ''), + ConnectionClosed(1000, ""), "WebSocket connection is closed: code = 1000 " "(OK), no reason", ), ( - ConnectionClosedOK(1001, 'bye'), + ConnectionClosedOK(1001, "bye"), "WebSocket connection is closed: code = 1001 " "(going away), reason = bye", ), From 423e175cce5dc05f6fab8457fe00492e1f11a34a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 12 May 2019 09:29:36 +0200 Subject: [PATCH 0566/1539] Lock mypy version. Work around https://github.com/python/mypy/issues/6802. --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 7397c90ae..801d4d5d1 100644 --- a/tox.ini +++ b/tox.ini @@ -25,4 +25,4 @@ deps = isort [testenv:mypy] commands = mypy --strict src -deps = mypy +deps = mypy==0.670 From c2649b14037c02b12c6e5756ff8d983b86659b68 Mon Sep 17 00:00:00 2001 From: reallinfo <36298335+reallinfo@users.noreply.github.com> Date: Sun, 12 May 2019 04:26:28 +0300 Subject: [PATCH 0567/1539] Add files via upload --- logo/horizontal.svg | 156 +++++++++++++++++++++++++++++++++++++++++++ logo/icon.svg | 43 ++++++++++++ logo/vertical.svg | 157 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 356 insertions(+) create mode 100644 logo/horizontal.svg create mode 100644 logo/icon.svg create mode 100644 logo/vertical.svg diff --git a/logo/horizontal.svg b/logo/horizontal.svg new file mode 100644 index 000000000..766c706f5 --- /dev/null +++ b/logo/horizontal.svg @@ -0,0 +1,156 @@ + + + + + + + + + + +]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/logo/icon.svg b/logo/icon.svg new file mode 100644 index 000000000..69592fea4 --- /dev/null +++ b/logo/icon.svg @@ -0,0 +1,43 @@ + + + + + + + + + + +]> + + + + + + + + + + + + + diff --git a/logo/vertical.svg b/logo/vertical.svg new file mode 100644 index 000000000..e83e1fefe --- /dev/null +++ b/logo/vertical.svg @@ -0,0 +1,157 @@ + + + + + + + + + + +]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From d2292d04ef21ede8a5eb838f73dbb99245eadc4d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 12 May 2019 11:16:13 +0200 Subject: [PATCH 0568/1539] Fine-tune logo. * Add margins * Use round pixel dimensions * Restore gradient on symbol (I like it!) * Insert in README and docs. --- README.rst | 5 +- docs/_static/websockets.svg | 17 +--- logo/horizontal.svg | 187 ++++++----------------------------- logo/icon.svg | 58 +++-------- logo/old.svg | 14 +++ logo/vertical.svg | 188 ++++++------------------------------ 6 files changed, 95 insertions(+), 374 deletions(-) mode change 100644 => 120000 docs/_static/websockets.svg create mode 100644 logo/old.svg diff --git a/README.rst b/README.rst index ae47c7a48..ecfc2e534 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,6 @@ -WebSockets -========== +.. image:: logo/horizontal.svg + :width: 480px + :alt: websockets |rtd| |pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |circleci| |codecov| diff --git a/docs/_static/websockets.svg b/docs/_static/websockets.svg deleted file mode 100644 index 409afb71d..000000000 --- a/docs/_static/websockets.svg +++ /dev/null @@ -1,16 +0,0 @@ - - - - - - - - - - - - diff --git a/docs/_static/websockets.svg b/docs/_static/websockets.svg new file mode 120000 index 000000000..84c316758 --- /dev/null +++ b/docs/_static/websockets.svg @@ -0,0 +1 @@ +../../logo/vertical.svg \ No newline at end of file diff --git a/logo/horizontal.svg b/logo/horizontal.svg index 766c706f5..ee872dc47 100644 --- a/logo/horizontal.svg +++ b/logo/horizontal.svg @@ -1,156 +1,31 @@ - - - - - - - - - - -]> - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/logo/icon.svg b/logo/icon.svg index 69592fea4..cb760940a 100644 --- a/logo/icon.svg +++ b/logo/icon.svg @@ -1,43 +1,15 @@ - - - - - - - - - - -]> - - - - - - - - - - - - - + + + + + + + + + + + + + + + diff --git a/logo/old.svg b/logo/old.svg new file mode 100644 index 000000000..a073139e3 --- /dev/null +++ b/logo/old.svg @@ -0,0 +1,14 @@ + + + + + + + + + + + + diff --git a/logo/vertical.svg b/logo/vertical.svg index e83e1fefe..b07fb2238 100644 --- a/logo/vertical.svg +++ b/logo/vertical.svg @@ -1,157 +1,31 @@ - - - - - - - - - - -]> - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From a2d0cfd0e418ad75d0de04337047fa88c7101a57 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 12 May 2019 19:30:23 +0200 Subject: [PATCH 0569/1539] The official name is lowercase. --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 040d41598..7679f2e38 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,4 +1,4 @@ -WebSockets +websockets ========== |pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |circleci| |codecov| From 8d51ce2da0cbfa971bc2d74c54283671e0e544b8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 23 May 2019 22:06:56 +0200 Subject: [PATCH 0570/1539] Add Tidelift marketing & security. --- README.rst | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/README.rst b/README.rst index ecfc2e534..8cbe55260 100644 --- a/README.rst +++ b/README.rst @@ -103,16 +103,9 @@ The development of ``websockets`` is shaped by four principles: Documentation is a first class concern in the project. Head over to `Read the Docs`_ and see for yourself. -Professional support is available if you — or your company — are so inclined. -`Get in touch`_. - -(If you contribute to ``websockets`` and would like to become an official -support provider, let me know.) - .. _Read the Docs: https://websockets.readthedocs.io/ .. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers .. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst -.. _Get in touch: https://fractalideas.com/ Why shouldn't I use ``websockets``? ----------------------------------- @@ -127,12 +120,31 @@ Why shouldn't I use ``websockets``? * If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which only works on Python 3. ``websockets`` requires Python ≥ 3.6. + +*Professionally supported websockets is now available* +------------------------------------------------------ + +*Tidelift gives software development teams a single source for purchasing and +maintaining their software, with professional grade assurances from the +experts who know it best, while seamlessly integrating with existing tools.* + +`Get supported websockets with the Tidelift subscription +`_ + +(If you contribute to ``websockets`` and would like to become an official +support provider, `let me know `_.) + What else? ---------- Bug reports, patches and suggestions are welcome! -Please open an issue_ or send a `pull request`_. +To report a security vulnerability, please use the `Tidelift security +contact`_. Tidelift will coordinate the fix and disclosure. + +.. _Tidelift security contact: https://tidelift.com/security + +For anything else, please open an issue_ or send a `pull request`_. .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ From 8cd4449977fc821725edb91df8247785dbb8a4f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 31 May 2019 08:04:53 +0200 Subject: [PATCH 0571/1539] Add Tidelift as sponsoring method --- .github/FUNDING.yml | 1 + 1 file changed, 1 insertion(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..7ae223b3d --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +tidelift: "pypi/websockets" From e262874b3787ea968dc52bdf1f4869bdc272cb17 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Jun 2019 16:37:06 +0200 Subject: [PATCH 0572/1539] Improve Tidelift marketing. * Add Tidelift logo in README. * Improve UX by separating Tidelift advertising clearly in README. * Move message earlier in README. * Add link to the sidebar in docs. --- README.rst | 30 ++++++++++++++---------------- docs/conf.py | 4 +++- logo/tidelift.png | Bin 0 -> 4069 bytes 3 files changed, 17 insertions(+), 17 deletions(-) create mode 100644 logo/tidelift.png diff --git a/README.rst b/README.rst index 8cbe55260..6bdafb2ed 100644 --- a/README.rst +++ b/README.rst @@ -75,9 +75,21 @@ And here's an echo server: Does that look good? -`Start here!`_ +`Get started with the tutorial!`_ -.. _Start here!: https://websockets.readthedocs.io/en/stable/intro.html +.. _Get started with the tutorial!: https://websockets.readthedocs.io/en/stable/intro.html + +.. raw:: html + +
+ +

Professionally supported websockets is now available

+

Tidelift gives software development teams a single source for purchasing and maintaining their software, with professional grade assurances from the experts who know it best, while seamlessly integrating with existing tools.

+

Get supported websockets with the Tidelift Subscription

+
+ +(If you contribute to ``websockets`` and would like to become an official +support provider, `let me know `_.) Why should I use ``websockets``? -------------------------------- @@ -120,20 +132,6 @@ Why shouldn't I use ``websockets``? * If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which only works on Python 3. ``websockets`` requires Python ≥ 3.6. - -*Professionally supported websockets is now available* ------------------------------------------------------- - -*Tidelift gives software development teams a single source for purchasing and -maintaining their software, with professional grade assurances from the -experts who know it best, while seamlessly integrating with existing tools.* - -`Get supported websockets with the Tidelift subscription -`_ - -(If you contribute to ``websockets`` and would like to become an official -support provider, `let me know `_.) - What else? ---------- diff --git a/docs/conf.py b/docs/conf.py index f4e81db35..e5e6ab15f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -113,6 +113,7 @@ 'github_button': True, 'github_user': 'aaugustin', 'github_repo': 'websockets', + 'tidelift_url': 'https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs', } # Add any paths that contain custom themes here, relative to this directory. @@ -151,9 +152,10 @@ html_sidebars = { '**': [ 'about.html', + 'searchbox.html', 'navigation.html', 'relations.html', - 'searchbox.html', + 'donate.html', ] } diff --git a/logo/tidelift.png b/logo/tidelift.png new file mode 100644 index 0000000000000000000000000000000000000000..317dc4d9852df72ba34e10a6f61d1838cbbd969e GIT binary patch literal 4069 zcmeHKX;4$y626285oA+{C5WQL2QzE}vKV%kz!e2?31JbUA|g025fUUiY>rrh;Fz#z z06|dH3QI(hEkM{rKoG$YWc4YEhKLZ18;b7)bWFXc<5)HRd#Oq#_ttm5?$cj)pFYPo zyErPysmnnSqyRgu+X6udg7}M)0dK1PPxphLxxYGjghG&^w)iV`@PZu$f|SzWbq?DO z_C6ah4%Qe^kGo^BFn9}Wweq=KL%w~OQAe>k@xrzRt$Hh~5Uu+OjC)WnQn0o_>#pJ@ z&CP1g1H0nvX+fOgO{XT6ZcF{HwLbQa-z~z()!qhEDHB0ypF#d!iJC(uW`< z8j9v1AZ6g3_%UVR3D76&|7-5cpieG-^y!}gecJ3t$0dA{Fed?$ z3>$F#+aQUE0LQ<=^t~VnmB3O-pCp(hKw@Qoq?BLJoG+^+ekwPKb6_F&q)Jo=qm z0S5!yi4(~rvjy()rT6^M+0%JBUrT zucJ9Rt1jXo)%!c%FI%ikP(8O7X}?Sc2VSd2Y9lxNTQeMqZXD0o%T|@u*}KqYH`pb| zf37}|a6f^EQYNfol=j9TBA)01>CA7HyM5M{<(si|{m0_{%Ph3ry%P!YpeOsOABOcY z;*8ODG5J6p3bHFzCa@gZjiEUV4Jlv+iAE&Y3|mWYrkzpIV(He_p`zFHTlPkC>Ozh{ z>1O5#B>F;`IU*Y>?P_)+E@SD|hnX;R%8ln{*N)LMG^7y;o!;tF9rw%l+?si@oSelX zrNV_GC6nm9*JiMzH(0Q`*dNBigH5>_(F5h1gyzjrK(jYtCZ6hZU7wv%8E-gMveCrJ z#Bz1z6U@_YU$}FsE^2Vo0b;!UGvmEhdqs8&VYJAIs>KQ*?>FEtrP}U!od{%FvTvJ^@`+cF%cb~Oj|pHE=8;3hGd7z>}} zK=~EraA(2pJQB#Z1K33(+Tj;lfFioJQ&0Wwrn}%;^L(F#nN(~P{v5mrfshZlG=(Vw z(CqUGfSm_Z;oz#@LBTMF&AA4(qw|fbt5fbd3oFy+D--M&ykqcJPUQ@T8N&r> zs>%e8VO6FJZe&G9>D%~@iVec$?R$W^p5iHt9oM)MD+;ZwAQ{R7Tx$Qno~{z9%+olL z-xy}X&#Ratr^SlanIvHsk^91&$5qae&km22?LndRiS^xjA%?Fe(u=ZiL(ZZwD(H5H zOZ!&+`lwBDdC*zDcdD8d@Ugn}IoZ1{@xf9(%3!(-APbp@#4gGl>xy6&p$Z)Gd!&%) zakVizTlW|7!cpv!N8nhxuL8*ZaTAwrdY;F=8*@_xOtwrly%37BN9$WRMZ5rXxYx(B z`mPK8WzUgQhA&#C%|Yq+916dTA3Tw`OZ$|WYv=J$k7!P86`9Uwr#PpVa77x?iXs~I>qL|k%~*K2Ez2aVRT+uy*fN@8!NXnQ-emAj zi7NL2FM1nW(g$$eK()*p@tPpsLF^J?p7A*|Cvj9`IEeTXpNP;ss)_OW9$<4Zb5x(2 z!%Z#eH*np{@dA#|L4Rb2dQCUeD}dr>K+nj^*l~+?*hD2j=Q+{g&Dho=wsCyUnxKrp z#`a9HImRPH?nPu1&dtxSZEI$0g|PxKX-bIbO1kOkqLb@0V?xk>NH4^5p=TQG)0n#$ zq?cEzPlV~pcs#C~JOkLjgV&PY?buCiO!BJP%v!o+$&!s=G&sUOGJE)c>^0)XJA00d ztKsXXiDrUV>R~oK=FSftdeR`78(3Ixg{iBU=jxPwM)X_w`bZuVG@R`7W&AuU;tgr7 z(Zr>s{Lmsaka_+CzIC&H^JJVa-+|@|C|Z~@0NPA4a}i!nO54KrP?!oGDF#Wi${=6S zy?)xdd2~(9(OiuD?8_B}R2WS@pDfU7_64gunQ|cI(n#bp!?%;k&g0Iu`El>Bc*TJb zUDw1P0|;#MvDCy{npedrX&0qAnr01L1z8*kbfH@2TkK4|kd|=5nSbOpSXg1?J zNZc1G7JV+wSQx2ITf0B<}GPi%pbdzBm51E+E^}z}=V6 z{;8JQ*v=6@wE%NnGQ03(&!Kcj@+mHRWzVO^Sy5 z$YsT=b{2_vODB=8?2MOI@k~-admqz+xif0a8x1m{l4N$$apkoEH`>t=an)U)N;7&E z{>M5Z`2Xjm`~O)}@^l1E?$gL_{`fj;G}GdV=hz1I{*yKDj3=8kLB(jc)-68=^D4)i zWM`*`p4bQn9yN$ka7I&$mCZ#LiqZ|*6hU5%B)>4cbxR~@ECI(;3qGn?;Db30Fwt!l zKi^jmiG26yWcj_4Knm!;-LiOk*3+=8;ZA|4g_S+g;Sl$BDt42p-9AdzF#Gaz&S_1S zft|Ir^=X7^AjtMSYSY6A=?BCNzYf;11csP2B Date: Sun, 16 Jun 2019 16:49:58 +0200 Subject: [PATCH 0573/1539] Link more prominently to RTD. --- README.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 6bdafb2ed..7395d803a 100644 --- a/README.rst +++ b/README.rst @@ -37,6 +37,8 @@ Python with a focus on correctness and simplicity. Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. +`Documentation is available on Read the Docs. `_ + Here's how a client sends and receives messages: .. copy-pasted because GitHub doesn't support the include directive @@ -75,9 +77,7 @@ And here's an echo server: Does that look good? -`Get started with the tutorial!`_ - -.. _Get started with the tutorial!: https://websockets.readthedocs.io/en/stable/intro.html +`Get started with the tutorial! `_ .. raw:: html From 60c61e0e39c25582d63559f6be905310a8d98bad Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Jun 2019 16:50:08 +0200 Subject: [PATCH 0574/1539] Sync docs with README. --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 7679f2e38..6001d5075 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,7 +30,7 @@ Python with a focus on correctness and simplicity. Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. -Here's a client that says "Hello world!": +Here's how a client sends and receives messages: .. literalinclude:: ../example/hello.py From 05ccc5ee64d5f24ed77809985ae5176d71c6caaf Mon Sep 17 00:00:00 2001 From: Tobin Yehle Date: Sun, 16 Jun 2019 08:53:02 -0700 Subject: [PATCH 0575/1539] Mark package as typed (#590) --- MANIFEST.in | 1 + setup.py | 3 ++- src/websockets/py.typed | 0 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 src/websockets/py.typed diff --git a/MANIFEST.in b/MANIFEST.in index 1aba38f67..1c660b95b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ include LICENSE +include src/websockets/py.typed diff --git a/setup.py b/setup.py index d4fadb240..3c87b2339 100644 --- a/setup.py +++ b/setup.py @@ -48,10 +48,11 @@ 'Programming Language :: Python :: 3.7', ], package_dir = {'': 'src'}, + package_data = {'websockets': ['py.typed']}, packages=packages, ext_modules=ext_modules, include_package_data=True, - zip_safe=True, + zip_safe=False, python_requires='>=3.6', test_loader='unittest:TestLoader', ) diff --git a/src/websockets/py.typed b/src/websockets/py.typed new file mode 100644 index 000000000..e69de29bb From b3d60d75fc2973b67ac39d34f08409207c557a97 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Jun 2019 18:02:11 +0200 Subject: [PATCH 0576/1539] Don't crash if a extra_headers callable returns None. Fix #619. --- docs/changelog.rst | 2 ++ src/websockets/server.py | 4 ++-- tests/test_client_server.py | 6 ++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index ee407d13e..f4cd8a4b6 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -39,6 +39,8 @@ Also: * :func:`~client.connect()` handles redirects from the server during the handshake. +* Avoided a crash of a ``extra_headers`` callable returns ``None``. + * Enabled readline in the interactive client. * Added type hints (:pep:`484`). diff --git a/src/websockets/server.py b/src/websockets/server.py index e202ea25b..73c07cf11 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -540,9 +540,9 @@ async def handshake( if protocol_header is not None: response_headers["Sec-WebSocket-Protocol"] = protocol_header + if callable(extra_headers): + extra_headers = extra_headers(path, self.request_headers) if extra_headers is not None: - if callable(extra_headers): - extra_headers = extra_headers(path, self.request_headers) if isinstance(extra_headers, Headers): extra_headers = extra_headers.raw_items() elif isinstance(extra_headers, collections.abc.Mapping): diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 83b1e0fd9..5c441561f 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -566,6 +566,12 @@ def test_protocol_custom_response_headers_callable_list(self): resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) + @with_server(extra_headers=lambda p, r: None) + @with_client("/headers") + def test_protocol_custom_response_headers_callable(self): + self.loop.run_until_complete(self.client.recv()) # doesn't crash + self.loop.run_until_complete(self.client.recv()) # nothing to check + @with_server(extra_headers=Headers({"X-Spam": "Eggs"})) @with_client("/headers") def test_protocol_custom_response_headers(self): From 595978c75b42a768e8c85accd02884a8ffa9a503 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Jun 2019 18:11:01 +0200 Subject: [PATCH 0577/1539] Add missing changelog entry for 8fc78fee. --- docs/changelog.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index f4cd8a4b6..f02280855 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -39,6 +39,9 @@ Also: * :func:`~client.connect()` handles redirects from the server during the handshake. +* Improved support for sending fragmented messages by accepting asynchronous + iterators in :meth:`~protocol.WebSocketCommonProtocol.send`. + * Avoided a crash of a ``extra_headers`` callable returns ``None``. * Enabled readline in the interactive client. From f255722c158b531415916ae29be16f295987d5d7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 17 Jun 2019 13:17:17 +0200 Subject: [PATCH 0578/1539] Fix copy-paste mistake in b3d60d75. --- tests/test_client_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 5c441561f..21de5486f 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -568,7 +568,7 @@ def test_protocol_custom_response_headers_callable_list(self): @with_server(extra_headers=lambda p, r: None) @with_client("/headers") - def test_protocol_custom_response_headers_callable(self): + def test_protocol_custom_response_headers_callable_none(self): self.loop.run_until_complete(self.client.recv()) # doesn't crash self.loop.run_until_complete(self.client.recv()) # nothing to check From 218f0a9866740773349e8d50f76b4af1d9873d39 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 19 Jun 2019 21:04:29 +0200 Subject: [PATCH 0579/1539] Change process_request to be a coroutine (again). b64fee8e made it possible to use either a function or a coroutine. It was part of the 7.0 release. 4f1a14c3 documented this possibility but wasn't released. However, users may have read the "latest" docs and taken advantage of this. For this reason, include proper deprecation warnings and preserve backwards-compatibility (for the foreseeable future). The deprecation warnings need to be in two locations to account for passing a process_request argument and for overriding the process_request method. Issue #597 shows that `isinstance(..., Awaitable)` is more robust than `asyncio.iscoroutinefunction(...)` because it also supports functions returning awaitables. --- docs/changelog.rst | 16 ++++++- example/health_check_server.py | 2 +- src/websockets/server.py | 49 +++++++++++--------- tests/test_client_server.py | 81 ++++++++++++++++++++++++++++++---- 4 files changed, 116 insertions(+), 32 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index f02280855..53e5a1267 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -14,7 +14,21 @@ Changelog .. note:: - **Version 8.0 changes the behavior of the ``max_queue`` parameter.** + **Version 8.0 expects** ``process_request`` **to be a coroutine.** + + Previously, it could be a function or a coroutine. + + If you're passing a ``process_request`` argument to :func:`~server.serve` + or :class:`~server.WebSocketServerProtocol`, or if you're overriding + :meth:`~protocol.WebSocketServerProtocol.process_request` in a subclass, + define it with ``async def`` instead of ``def``. + + For backwards compatibility, functions are still supported. However, in + some inheritance scenarios, mixing functions and coroutines won't work. + +.. note:: + + **Version 8.0 changes the behavior of the** ``max_queue`` **parameter.** If you were setting ``max_queue=0`` to make the queue of incoming messages unbounded, change it to ``max_queue=None``. diff --git a/example/health_check_server.py b/example/health_check_server.py index 8e70890b5..feb04bccd 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -6,7 +6,7 @@ import http import websockets -def health_check(path, request_headers): +async def health_check(path, request_headers): if path == '/health/': return http.HTTPStatus.OK, [], b'OK\n' diff --git a/src/websockets/server.py b/src/websockets/server.py index 73c07cf11..870e4ec7a 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -84,10 +84,7 @@ def __init__( subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, process_request: Optional[ - Callable[ - [str, Headers], - Union[Optional[HTTPResponse], Awaitable[Optional[HTTPResponse]]], - ] + Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] ] = None, select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] @@ -266,15 +263,15 @@ def write_http_response( logger.debug("%s > Body (%d bytes)", self.side, len(body)) self.writer.write(body) - def process_request( + async def process_request( self, path: str, request_headers: Headers - ) -> Union[Optional[HTTPResponse], Awaitable[Optional[HTTPResponse]]]: + ) -> Optional[HTTPResponse]: """ Intercept the HTTP request and return an HTTP response if needed. ``request_headers`` is a :class:`~websockets.http.Headers` instance. - If this method returns ``None``, the WebSocket handshake continues. + If this coroutine returns ``None``, the WebSocket handshake continues. If it returns a status code, headers and a response body, that HTTP response is sent and the connection is closed. @@ -286,12 +283,10 @@ def process_request( The HTTP response body must be :class:`bytes`. It may be empty. - This method may be overridden to check the request headers and set a - different status, for example to authenticate the request and return - ``HTTPStatus.UNAUTHORIZED`` or ``HTTPStatus.FORBIDDEN``. - - It can be declared as a function or as a coroutine because such - authentication checks are likely to require network requests. + This coroutine may be overridden to check the request headers and set + a different status, for example to authenticate the request and return + :attr:`http.HTTPStatus.UNAUTHORIZED` or + :attr:`http.HTTPStatus.FORBIDDEN`. It may also be overridden by passing a ``process_request`` argument to the :class:`WebSocketServerProtocol` constructor or the :func:`serve` @@ -299,7 +294,15 @@ def process_request( """ if self._process_request is not None: - return self._process_request(path, request_headers) + response = self._process_request(path, request_headers) + if isinstance(response, Awaitable): + return await response + else: + # For backwards-compatibility with 7.0. + warnings.warn( + "declare process_request as a coroutine", DeprecationWarning + ) + return response # type: ignore return None @staticmethod @@ -503,9 +506,13 @@ async def handshake( # Hook for customizing request handling, for example checking # authentication or treating some paths as plain HTTP endpoints. - early_response = self.process_request(path, request_headers) - if isinstance(early_response, Awaitable): - early_response = await early_response + early_response_awaitable = self.process_request(path, request_headers) + if isinstance(early_response_awaitable, Awaitable): + early_response = await early_response_awaitable + else: + # For backwards-compatibility with 7.0. + warnings.warn("declare process_request as a coroutine", DeprecationWarning) + early_response = early_response_awaitable # type: ignore # Change the response to a 503 error if the server is shutting down. if not self.ws_server.is_serving(): @@ -767,9 +774,9 @@ class Serve: :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` pairs, or a callable taking the request path and headers in arguments and returning one of the above - * ``process_request`` is a callable or a coroutine taking the request path - and headers in argument, see - :meth:`~WebSocketServerProtocol.process_request` for details + * ``process_request`` is a coroutine taking the request path and headers + in argument, see :meth:`~WebSocketServerProtocol.process_request` for + details * ``select_subprotocol`` is a callable taking the subprotocols offered by the client and available on the server in argument, see :meth:`~WebSocketServerProtocol.select_subprotocol` for details @@ -821,7 +828,7 @@ def __init__( subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, process_request: Optional[ - Callable[[str, Headers], Optional[HTTPResponse]] + Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] ] = None, select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 21de5486f..a540c373c 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -270,7 +270,7 @@ def start_server(self, expected_warning=None, **kwds): def start_redirecting_server( self, status, include_location=True, force_insecure=False ): - def _process_request(path, headers): + async def process_request(path, headers): server_uri = get_server_uri(self.server, self.secure, path) if force_insecure: server_uri = server_uri.replace("wss:", "ws:") @@ -283,7 +283,7 @@ def _process_request(path, headers): 0, compression=None, ping_interval=None, - process_request=_process_request, + process_request=process_request, ssl=self.server_context, ) self.redirecting_server = self.loop.run_until_complete(start_server) @@ -458,15 +458,65 @@ def test_unix_socket(self): client_socket.close() self.stop_server() - @with_server(process_request=lambda p, rh: (http.HTTPStatus.OK, [], b"OK\n")) + async def process_request_OK(path, request_headers): + return http.HTTPStatus.OK, [], b"OK\n" + + @with_server(process_request=process_request_OK) def test_process_request_argument(self): response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): self.assertEqual(response.code, 200) + def legacy_process_request_OK(path, request_headers): + return http.HTTPStatus.OK, [], b"OK\n" + + @with_server(process_request=legacy_process_request_OK) + def test_process_request_argument_backwards_compatibility(self): + with warnings.catch_warnings(record=True) as recorded_warnings: + response = self.loop.run_until_complete(self.make_http_request("/")) + + with contextlib.closing(response): + self.assertEqual(response.code, 200) + + self.assertEqual(len(recorded_warnings), 1) + warning = recorded_warnings[0].message + self.assertEqual(str(warning), "declare process_request as a coroutine") + self.assertEqual(type(warning), DeprecationWarning) + + class ProcessRequestOKServerProtocol(WebSocketServerProtocol): + async def process_request(self, path, request_headers): + return http.HTTPStatus.OK, [], b"OK\n" + + @with_server(create_protocol=ProcessRequestOKServerProtocol) + def test_process_request_override(self): + response = self.loop.run_until_complete(self.make_http_request("/")) + + with contextlib.closing(response): + self.assertEqual(response.code, 200) + + class LegacyProcessRequestOKServerProtocol(WebSocketServerProtocol): + def process_request(self, path, request_headers): + return http.HTTPStatus.OK, [], b"OK\n" + + @with_server(create_protocol=LegacyProcessRequestOKServerProtocol) + def test_process_request_override_backwards_compatibility(self): + with warnings.catch_warnings(record=True) as recorded_warnings: + response = self.loop.run_until_complete(self.make_http_request("/")) + + with contextlib.closing(response): + self.assertEqual(response.code, 200) + + self.assertEqual(len(recorded_warnings), 1) + warning = recorded_warnings[0].message + self.assertEqual(str(warning), "declare process_request as a coroutine") + self.assertEqual(type(warning), DeprecationWarning) + + def select_subprotocol_chat(client_subprotocols, server_subprotocols): + return "chat" + @with_server( - subprotocols=["superchat", "chat"], select_subprotocol=lambda cs, ss: "chat" + subprotocols=["superchat", "chat"], select_subprotocol=select_subprotocol_chat ) @with_client("/subprotocol", subprotocols=["superchat", "chat"]) def test_select_subprotocol_argument(self): @@ -474,6 +524,20 @@ def test_select_subprotocol_argument(self): self.assertEqual(server_subprotocol, repr("chat")) self.assertEqual(self.client.subprotocol, "chat") + class SelectSubprotocolChatServerProtocol(WebSocketServerProtocol): + def select_subprotocol(self, client_subprotocols, server_subprotocols): + return "chat" + + @with_server( + subprotocols=["superchat", "chat"], + create_protocol=SelectSubprotocolChatServerProtocol, + ) + @with_client("/subprotocol", subprotocols=["superchat", "chat"]) + def test_select_subprotocol_override(self): + server_subprotocol = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(server_subprotocol, repr("chat")) + self.assertEqual(self.client.subprotocol, "chat") + @with_server() @with_client("/attributes") def test_protocol_attributes(self): @@ -658,11 +722,10 @@ def assert_client_raises_code(self, status_code): def test_server_create_protocol(self): self.assert_client_raises_code(401) - @with_server( - create_protocol=( - lambda *args, **kwargs: UnauthorizedServerProtocol(*args, **kwargs) - ) - ) + def create_unauthorized_server_protocol(*args, **kwargs): + return UnauthorizedServerProtocol(*args, **kwargs) + + @with_server(create_protocol=create_unauthorized_server_protocol) def test_server_create_protocol_function(self): self.assert_client_raises_code(401) From b288d651e07e9f1614b10421d55833984b448132 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 21 Jun 2019 21:37:59 +0200 Subject: [PATCH 0580/1539] Factor out test utilities. --- tests/test_client_server.py | 122 +++++++++++++++--------------------- tests/test_framing.py | 9 +-- tests/test_http.py | 10 +-- tests/test_protocol.py | 29 ++------- tests/test_speedups.py | 0 tests/utils.py | 38 +++++++++++ 6 files changed, 96 insertions(+), 112 deletions(-) delete mode 100644 tests/test_speedups.py create mode 100644 tests/utils.py diff --git a/tests/test_client_server.py b/tests/test_client_server.py index a540c373c..8a1177a7e 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -33,6 +33,7 @@ from websockets.server import * from .test_protocol import MS +from .utils import AsyncioTestCase # Avoid displaying stack traces at the ERROR logging level. @@ -226,25 +227,15 @@ def encode(self, frame): return frame -class ClientServerTests(unittest.TestCase): +class ClientServerTestsMixin: secure = False def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + super().setUp() self.server = None self.redirecting_server = None - def tearDown(self): - self.loop.close() - - def run_loop_once(self): - # Process callbacks scheduled with call_soon by appending a callback - # to stop the event loop then running it until it hits that callback. - self.loop.call_soon(self.loop.stop) - self.loop.run_forever() - @property def server_context(self): return None @@ -349,6 +340,40 @@ def temp_client(self, *args, **kwds): with temp_test_client(self, *args, **kwds): yield + +class SecureClientServerTestsMixin(ClientServerTestsMixin): + + secure = True + + @property + def server_context(self): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(testcert) + return ssl_context + + @property + def client_context(self): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_verify_locations(testcert) + return ssl_context + + def start_server(self, **kwds): + kwds.setdefault("ssl", self.server_context) + super().start_server(**kwds) + + def start_client(self, path="/", **kwds): + kwds.setdefault("ssl", self.client_context) + super().start_client(path, **kwds) + + +class CommonClientServerTests: + """ + Mixin that defines most tests but doesn't inherit unittest.TestCase. + + Tests are run by the ClientServerTests and SecureClientServerTests subclasses. + + """ + @with_server() @with_client() def test_basic(self): @@ -1211,29 +1236,15 @@ def test_connection_error_during_closing_handshake(self, close): self.assertEqual(self.client.close_code, 1006) -class SSLClientServerTests(ClientServerTests): - - secure = True - - @property - def server_context(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(testcert) - return ssl_context - - @property - def client_context(self): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(testcert) - return ssl_context +class ClientServerTests( + CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase +): + pass - def start_server(self, **kwds): - kwds.setdefault("ssl", self.server_context) - super().start_server(**kwds) - def start_client(self, path="/", **kwds): - kwds.setdefault("ssl", self.client_context) - super().start_client(path, **kwds) +class SecureClientServerTests( + CommonClientServerTests, SecureClientServerTestsMixin, AsyncioTestCase +): # TLS over Unix sockets doesn't make sense. test_unix_socket = None @@ -1253,14 +1264,7 @@ def test_redirect_insecure(self): self.fail("Did not raise") # pragma: no cover -class ClientServerOriginTests(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - +class ClientServerOriginTests(AsyncioTestCase): def test_checking_origin_succeeds(self): server = self.loop.run_until_complete( serve(handler, "localhost", 0, origins=["http://localhost"]) @@ -1337,14 +1341,7 @@ def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): self.loop.run_until_complete(server.wait_closed()) -class YieldFromTests(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - +class YieldFromTests(AsyncioTestCase): def test_client(self): start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) @@ -1375,14 +1372,7 @@ def run_server(): self.loop.run_until_complete(run_server()) -class AsyncAwaitTests(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - +class AsyncAwaitTests(AsyncioTestCase): def test_client(self): start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) @@ -1411,14 +1401,7 @@ async def run_server(): self.loop.run_until_complete(run_server()) -class ContextManagerTests(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - +class ContextManagerTests(AsyncioTestCase): def test_client(self): start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) @@ -1461,20 +1444,13 @@ async def run_server(path): self.loop.run_until_complete(run_server(path)) -class AsyncIteratorTests(unittest.TestCase): +class AsyncIteratorTests(AsyncioTestCase): # This is a protocol-level feature, but since it's a high-level API, it is # much easier to exercise at the client or server level. MESSAGES = ["3", "2", "1", "Fire!"] - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - def test_iterate_on_messages(self): async def handler(ws, path): for message in self.MESSAGES: diff --git a/tests/test_framing.py b/tests/test_framing.py index 83d0a251a..430faf6e1 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -6,15 +6,10 @@ from websockets.exceptions import PayloadTooBig, WebSocketProtocolError from websockets.framing import * +from .utils import AsyncioTestCase -class FramingTests(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() +class FramingTests(AsyncioTestCase): def decode(self, message, mask=False, max_size=None, extensions=None): self.stream = asyncio.StreamReader(loop=self.loop) self.stream.feed_data(message) diff --git a/tests/test_http.py b/tests/test_http.py index 39961d641..60cdb9a25 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -4,18 +4,14 @@ from websockets.http import * from websockets.http import read_headers +from .utils import AsyncioTestCase -class HTTPAsyncTests(unittest.TestCase): + +class HTTPAsyncTests(AsyncioTestCase): def setUp(self): super().setUp() - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) self.stream = asyncio.StreamReader(loop=self.loop) - def tearDown(self): - self.loop.close() - super().tearDown() - def test_read_request(self): # Example from the protocol overview in RFC 6455 self.stream.feed_data( diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 976cc7e9b..938e54d8d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,8 +1,6 @@ import asyncio import contextlib import logging -import os -import time import unittest import unittest.mock import warnings @@ -11,23 +9,13 @@ from websockets.framing import * from websockets.protocol import State, WebSocketCommonProtocol +from .utils import MS, AsyncioTestCase + # Avoid displaying stack traces at the ERROR logging level. logging.basicConfig(level=logging.CRITICAL) -# Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) - -# asyncio's debug mode has a 10x performance penalty for this test suite. -if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover - MS *= 10 - -# Ensure that timeouts are larger than the clock's resolution (for Windows). -MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) - - async def async_iterable(iterable): for item in iterable: yield item @@ -93,8 +81,6 @@ class CommonTests: def setUp(self): super().setUp() - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) # Disable pings to make it easier to test what frames are sent exactly. self.protocol = WebSocketCommonProtocol(ping_interval=None) self.transport = TransportMock() @@ -103,17 +89,10 @@ def setUp(self): def tearDown(self): self.transport.close() self.loop.run_until_complete(self.protocol.close()) - self.loop.close() super().tearDown() # Utilities for writing tests. - def run_loop_once(self): - # Process callbacks scheduled with call_soon by appending a callback - # to stop the event loop then running it until it hits that callback. - self.loop.call_soon(self.loop.stop) - self.loop.run_forever() - def make_drain_slow(self, delay=MS): # Process connection_made in order to initialize self.protocol.writer. self.run_loop_once() @@ -1248,7 +1227,7 @@ def test_remote_close_during_send(self): # happen, considering that writes are serialized. -class ServerTests(CommonTests, unittest.TestCase): +class ServerTests(CommonTests, AsyncioTestCase): def setUp(self): super().setUp() self.protocol.is_client = False @@ -1299,7 +1278,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.assertConnectionClosed(1000, "close") -class ClientTests(CommonTests, unittest.TestCase): +class ClientTests(CommonTests, AsyncioTestCase): def setUp(self): super().setUp() self.protocol.is_client = True diff --git a/tests/test_speedups.py b/tests/test_speedups.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..0a9f14ce1 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,38 @@ +import asyncio +import os +import time +import unittest + + +class AsyncioTestCase(unittest.TestCase): + """ + Base class for tests that sets up an isolated event loop for each test. + + """ + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def run_loop_once(self): + # Process callbacks scheduled with call_soon by appending a callback + # to stop the event loop then running it until it hits that callback. + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + + +# Unit for timeouts. May be increased on slow machines by setting the +# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. +MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) + +# asyncio's debug mode has a 10x performance penalty for this test suite. +if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover + MS *= 10 + +# Ensure that timeouts are larger than the clock's resolution (for Windows). +MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From 250c0e05694bc57094a9511655af5dc364470ce3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 21 Jun 2019 21:42:35 +0200 Subject: [PATCH 0581/1539] Add string representation for RedirectHandshake. --- src/websockets/exceptions.py | 4 ++++ tests/test_exceptions.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 7fdc97185..22978ec6f 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -24,6 +24,7 @@ "InvalidURI", "NegotiationError", "PayloadTooBig", + "RedirectHandshake", "WebSocketProtocolError", ] @@ -60,6 +61,9 @@ class RedirectHandshake(InvalidHandshake): def __init__(self, uri: str) -> None: self.uri = uri + def __str__(self) -> str: + return f"Redirect to {self.uri}" + class InvalidMessage(InvalidHandshake): """ diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 4b9830345..27e1b53ca 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -16,6 +16,10 @@ def test_str(self): AbortHandshake(200, Headers(), b"OK\n"), "HTTP 200, 0 headers, 3 bytes", ), + ( + RedirectHandshake("wss://example.com"), + "Redirect to wss://example.com", + ), ( InvalidMessage("Malformed HTTP message"), "Malformed HTTP message", @@ -47,7 +51,6 @@ def test_str(self): InvalidHeaderValue("Sec-WebSocket-Version", "42"), "Invalid Sec-WebSocket-Version header: 42", ), - ( InvalidUpgrade("Upgrade"), "Missing Upgrade header", From 3a7e4a3810675a015f783a499305864d0efb8705 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 09:31:57 +0200 Subject: [PATCH 0582/1539] Remove override made unnecessary by 218f0a98. --- docs/api.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/api.rst b/docs/api.rst index acdc69dab..ef02c9a83 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -49,7 +49,6 @@ Server .. automethod:: handshake(origins=None, available_extensions=None, available_subprotocols=None, extra_headers=None) .. automethod:: process_request(path, request_headers) - :async: .. automethod:: select_subprotocol(client_subprotocols, server_subprotocols) Client From 3278eddb7bcbf51637c8ac64680bd7176db57b6d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 09:47:45 +0200 Subject: [PATCH 0583/1539] Remove explicit argument lists. Except those that Sphinx cannot build automatically because of backwards-compatibility hacks. --- docs/api.rst | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index ef02c9a83..9870c5dff 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -41,15 +41,15 @@ Server .. autoclass:: WebSocketServer - .. automethod:: close() - .. automethod:: wait_closed() + .. automethod:: close + .. automethod:: wait_closed .. autoattribute:: sockets .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) - .. automethod:: handshake(origins=None, available_extensions=None, available_subprotocols=None, extra_headers=None) - .. automethod:: process_request(path, request_headers) - .. automethod:: select_subprotocol(client_subprotocols, server_subprotocols) + .. automethod:: handshake + .. automethod:: process_request + .. automethod:: select_subprotocol Client ...... @@ -61,7 +61,7 @@ Client .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) - .. automethod:: handshake(wsuri, origin=None, available_extensions=None, available_subprotocols=None, extra_headers=None) + .. automethod:: handshake Shared ...... @@ -70,14 +70,14 @@ Shared .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) - .. automethod:: close(code=1000, reason='') - .. automethod:: wait_closed() + .. automethod:: close + .. automethod:: wait_closed - .. automethod:: recv() - .. automethod:: send(data) + .. automethod:: recv + .. automethod:: send - .. automethod:: ping(data=None) - .. automethod:: pong(data=b'') + .. automethod:: ping + .. automethod:: pong .. autoattribute:: local_address .. autoattribute:: remote_address From c2b7e1bb5a221fe7fea6efb53ab33d8efc8783ad Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 09:48:20 +0200 Subject: [PATCH 0584/1539] Remove parentheses in :func: and :meth: references. --- docs/changelog.rst | 22 +++++------ docs/deployment.rst | 4 +- docs/design.rst | 90 ++++++++++++++++++++++----------------------- docs/extensions.rst | 2 +- 4 files changed, 59 insertions(+), 59 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 53e5a1267..17ecd5523 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -50,7 +50,7 @@ Also: :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* :func:`~client.connect()` handles redirects from the server during the +* :func:`~client.connect` handles redirects from the server during the handshake. * Improved support for sending fragmented messages by accepting asynchronous @@ -72,7 +72,7 @@ Also: .. warning:: **Version 7.0 renames the** ``timeout`` **argument of** - :func:`~server.serve()` **and** :func:`~client.connect()` **to** + :func:`~server.serve()` **and** :func:`~client.connect` **to** ``close_timeout`` **.** This prevents confusion with ``ping_timeout``. @@ -122,7 +122,7 @@ Also: :class:`~protocol.WebSocketCommonProtocol` for details. * Added ``process_request`` and ``select_subprotocol`` arguments to - :func:`~server.serve()` and :class:`~server.WebSocketServerProtocol` to + :func:`~server.serve` and :class:`~server.WebSocketServerProtocol` to customize :meth:`~server.WebSocketServerProtocol.process_request` and :meth:`~server.WebSocketServerProtocol.select_subprotocol` without subclassing :class:`~server.WebSocketServerProtocol`. @@ -187,7 +187,7 @@ Also: ..... * Fixed a regression in the 5.0 release that broke some invocations of - :func:`~server.serve()` and :func:`~client.connect()`. + :func:`~server.serve()` and :func:`~client.connect`. 5.0 ... @@ -212,7 +212,7 @@ Also: Also: -* :func:`~client.connect()` performs HTTP Basic Auth when the URI contains +* :func:`~client.connect` performs HTTP Basic Auth when the URI contains credentials. * Iterating on incoming messages no longer raises an exception when the @@ -268,7 +268,7 @@ Also: Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~server.serve()` or :func:`~client.connect()`. + :func:`~server.serve()` or :func:`~client.connect`. .. warning:: @@ -306,7 +306,7 @@ Also: 3.4 ... -* Renamed :func:`~server.serve()` and :func:`~client.connect()`'s ``klass`` +* Renamed :func:`~server.serve()` and :func:`~client.connect`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. @@ -314,7 +314,7 @@ Also: Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with - :meth:`~server.WebSocketServerProtocol.process_request()`. + :meth:`~server.WebSocketServerProtocol.process_request`. * Made read and write buffer sizes configurable. @@ -322,10 +322,10 @@ Also: * Added an optional C extension to speed up low-level operations. -* An invalid response status code during :func:`~client.connect()` now raises +* An invalid response status code during :func:`~client.connect` now raises :class:`~exceptions.InvalidStatusCode` with a ``code`` attribute. -* Providing a ``sock`` argument to :func:`~client.connect()` no longer +* Providing a ``sock`` argument to :func:`~client.connect` no longer crashes. 3.3 @@ -341,7 +341,7 @@ Also: ... * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~client.connect()` and :func:`~server.serve()`. + :func:`~client.connect()` and :func:`~server.serve`. * Made server shutdown more robust. diff --git a/docs/deployment.rst b/docs/deployment.rst index b0c05dd73..9aa2d3744 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -127,7 +127,7 @@ Under high load, if a server receives more messages than it can process, bufferbloat can result in excessive memory use. By default ``websockets`` has generous limits. It is strongly recommended to -adapt them to your application. When you call :func:`~server.serve()`: +adapt them to your application. When you call :func:`~server.serve`: - Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of messages your application generates. @@ -150,7 +150,7 @@ The author of ``websockets`` doesn't think that's a good idea, due to the widely different operational characteristics of HTTP and WebSocket. ``websockets`` provide minimal support for responding to HTTP requests with -the :meth:`~server.WebSocketServerProtocol.process_request()` hook. Typical +the :meth:`~server.WebSocketServerProtocol.process_request` hook. Typical use cases include health checks. Here's an example: .. literalinclude:: ../example/health_check_server.py diff --git a/docs/design.rst b/docs/design.rst index c6097f724..19cda16bb 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -32,20 +32,20 @@ WebSocket connections go through a trivial state machine: Transitions happen in the following places: - ``CONNECTING -> OPEN``: in - :meth:`~protocol.WebSocketCommonProtocol.connection_open()` which runs when + :meth:`~protocol.WebSocketCommonProtocol.connection_open` which runs when the :ref:`opening handshake ` completes and the WebSocket connection is established — not to be confused with :meth:`~asyncio.Protocol.connection_made` which runs when the TCP connection is established; - ``OPEN -> CLOSING``: in - :meth:`~protocol.WebSocketCommonProtocol.write_frame()` immediately before + :meth:`~protocol.WebSocketCommonProtocol.write_frame` immediately before sending a close frame; since receiving a close frame triggers sending a close frame, this does the right thing regardless of which side started the :ref:`closing handshake `; also in - :meth:`~protocol.WebSocketCommonProtocol.fail_connection()` which duplicates + :meth:`~protocol.WebSocketCommonProtocol.fail_connection` which duplicates a few lines of code from `write_close_frame()` and `write_frame()`; - ``* -> CLOSED``: in - :meth:`~protocol.WebSocketCommonProtocol.connection_lost()` which is always + :meth:`~protocol.WebSocketCommonProtocol.connection_lost` which is always called exactly once when the TCP connection is closed. Coroutines @@ -58,35 +58,35 @@ connection lifecycle on the client side. :target: _images/lifecycle.svg The lifecycle is identical on the server side, except inversion of control -makes the equivalent of :meth:`~client.connect()` implicit. +makes the equivalent of :meth:`~client.connect` implicit. Coroutines shown in green are called by the application. Multiple coroutines may interact with the WebSocket connection concurrently. Coroutines shown in gray manage the connection. When the opening handshake -succeeds, :meth:`~protocol.WebSocketCommonProtocol.connection_open()` starts +succeeds, :meth:`~protocol.WebSocketCommonProtocol.connection_open` starts two tasks: - :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs - :meth:`~protocol.WebSocketCommonProtocol.transfer_data()` which handles - incoming data and lets :meth:`~protocol.WebSocketCommonProtocol.recv()` + :meth:`~protocol.WebSocketCommonProtocol.transfer_data` which handles + incoming data and lets :meth:`~protocol.WebSocketCommonProtocol.recv` consume it. It may be canceled to terminate the connection. It never exits with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data transfer ` below. - :attr:`~protocol.WebSocketCommonProtocol.keepalive_ping_task` runs - :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping()` which sends Ping + :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping frames at regular intervals and ensures that corresponding Pong frames are received. It is canceled when the connection terminates. It never exits with an exception other than :exc:`~asyncio.CancelledError`. - :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs - :meth:`~protocol.WebSocketCommonProtocol.close_connection()` which waits for + :meth:`~protocol.WebSocketCommonProtocol.close_connection` which waits for the data transfer to terminate, then takes care of closing the TCP connection. It must not be canceled. It never exits with an exception. See :ref:`connection termination ` below. -Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection()` starts +Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection` starts the same :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` when the opening handshake fails, in order to close the TCP connection. @@ -113,7 +113,7 @@ Opening handshake ----------------- ``websockets`` performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~client.connect()` executes it before +connection. On the client side, :meth:`~client.connect` executes it before returning the protocol to the caller. On the server side, it's executed before passing the protocol to the ``ws_handler`` coroutine handling the connection. @@ -122,26 +122,26 @@ request and the server replies with an HTTP Switching Protocols response — ``websockets`` aims at keeping the implementation of both sides consistent with one another. -On the client side, :meth:`~client.WebSocketClientProtocol.handshake()`: +On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: - builds a HTTP request based on the ``uri`` and parameters passed to - :meth:`~client.connect()`; + :meth:`~client.connect`; - writes the HTTP request to the network; - reads a HTTP response from the network; - checks the HTTP response, validates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - moves to the ``OPEN`` state. -On the server side, :meth:`~server.WebSocketServerProtocol.handshake()`: +On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: - reads a HTTP request from the network; -- calls :meth:`~server.WebSocketServerProtocol.process_request()` which may +- calls :meth:`~server.WebSocketServerProtocol.process_request` which may abort the WebSocket handshake and return a HTTP response instead; this hook only makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds a HTTP response based on the above and parameters passed to - :meth:`~server.serve()`; + :meth:`~server.serve`; - writes the HTTP response to the network; - moves to the ``OPEN`` state; - returns the ``path`` part of the ``uri``. @@ -226,10 +226,10 @@ When it encounters a control frame: Running this process in a task guarantees that control frames are processed promptly. Without such a task, ``websockets`` would depend on the application to drive the connection by having exactly one coroutine awaiting -:meth:`~protocol.WebSocketCommonProtocol.recv()` at any time. While this +:meth:`~protocol.WebSocketCommonProtocol.recv` at any time. While this happens naturally in many use cases, it cannot be relied upon. -Then :meth:`~protocol.WebSocketCommonProtocol.recv()` fetches the next message +Then :meth:`~protocol.WebSocketCommonProtocol.recv` fetches the next message from the :attr:`~protocol.WebSocketCommonProtocol.messages` queue, with some complexity added for handling termination correctly. @@ -238,16 +238,16 @@ Sending data The right side of the diagram shows how ``websockets`` sends data. -:meth:`~protocol.WebSocketCommonProtocol.send()` writes a single data frame +:meth:`~protocol.WebSocketCommonProtocol.send` writes a single data frame containing the message. Fragmentation isn't supported at this time. -:meth:`~protocol.WebSocketCommonProtocol.ping()` writes a ping frame and +:meth:`~protocol.WebSocketCommonProtocol.ping` writes a ping frame and yields a :class:`~asyncio.Future` which will be completed when a matching pong frame is received. -:meth:`~protocol.WebSocketCommonProtocol.pong()` writes a pong frame. +:meth:`~protocol.WebSocketCommonProtocol.pong` writes a pong frame. -:meth:`~protocol.WebSocketCommonProtocol.close()` writes a close frame and +:meth:`~protocol.WebSocketCommonProtocol.close` writes a close frame and waits for the TCP connection to terminate. Outgoing data is written to a :class:`~asyncio.StreamWriter` in order to @@ -259,15 +259,15 @@ Closing handshake ................. When the other side of the connection initiates the closing handshake, -:meth:`~protocol.WebSocketCommonProtocol.read_message()` receives a close +:meth:`~protocol.WebSocketCommonProtocol.read_message` receives a close frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a close frame, and returns ``None``, causing :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. When this side of the connection initiates the closing handshake with -:meth:`~protocol.WebSocketCommonProtocol.close()`, it moves to the ``CLOSING`` +:meth:`~protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` state and sends a close frame. When the other side sends a close frame, -:meth:`~protocol.WebSocketCommonProtocol.read_message()` receives it in the +:meth:`~protocol.WebSocketCommonProtocol.read_message` receives it in the ``CLOSING`` state and returns ``None``, also causing :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. @@ -417,30 +417,30 @@ Once the WebSocket connection is established, internal tasks accidentally canceled if a coroutine that awaits them is canceled. In other words, they must be shielded from cancellation. -:meth:`~protocol.WebSocketCommonProtocol.recv()` waits for the next message in +:meth:`~protocol.WebSocketCommonProtocol.recv` waits for the next message in the queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` -to terminate, whichever comes first. It relies on :func:`~asyncio.wait()` for +to terminate, whichever comes first. It relies on :func:`~asyncio.wait` for waiting on two tasks in parallel. As a consequence, even though it's waiting on the transfer data task, it doesn't propagate cancellation to that task. -:meth:`~protocol.WebSocketCommonProtocol.ensure_open()` is called by -:meth:`~protocol.WebSocketCommonProtocol.send()`, -:meth:`~protocol.WebSocketCommonProtocol.ping()`, and -:meth:`~protocol.WebSocketCommonProtocol.pong()`. When the connection state is +:meth:`~protocol.WebSocketCommonProtocol.ensure_open` is called by +:meth:`~protocol.WebSocketCommonProtocol.send`, +:meth:`~protocol.WebSocketCommonProtocol.ping`, and +:meth:`~protocol.WebSocketCommonProtocol.pong`. When the connection state is ``CLOSING``, it waits for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to prevent cancellation. -:meth:`~protocol.WebSocketCommonProtocol.close()` waits for the data transfer +:meth:`~protocol.WebSocketCommonProtocol.close` waits for the data transfer task to terminate with :func:`~asyncio.wait_for`. If it's canceled or if the timeout elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is canceled, which is correct at this point. -:meth:`~protocol.WebSocketCommonProtocol.close()` then waits for +:meth:`~protocol.WebSocketCommonProtocol.close` then waits for :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it to prevent cancellation. -:meth:`~protocol.WebSocketCommonProtocol.close()` and -:func:`~protocol.WebSocketCommonProtocol.fail_connection()` are the only +:meth:`~protocol.WebSocketCommonProtocol.close` and +:func:`~protocol.WebSocketCommonProtocol.fail_connection` are the only places where :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` may be canceled. @@ -515,35 +515,35 @@ For each connection, the receiving side contains these buffers: - OS buffers: tuning them is an advanced optimization. - :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64 KiB. You can set another limit by passing a ``read_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve()`. + :func:`~client.connect()` or :func:`~server.serve`. - Incoming messages :class:`~collections.deque`: its size depends both on the size and the number of messages it contains. By default the maximum UTF-8 encoded size is 1 MiB and the maximum number is 32. In the worst case, after UTF-8 decoding, a single message could take up to 4 MiB of memory and the overall memory consumption could reach 128 MiB. You should adjust these limits by setting the ``max_size`` and ``max_queue`` keyword arguments of - :func:`~client.connect()` or :func:`~server.serve()` according to your + :func:`~client.connect()` or :func:`~server.serve` according to your application's requirements. For each connection, the sending side contains these buffers: - :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64 KiB. You can set another limit by passing a ``write_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve()`. + :func:`~client.connect()` or :func:`~server.serve`. - OS buffers: tuning them is an advanced optimization. Concurrency ----------- -Calling any combination of :meth:`~protocol.WebSocketCommonProtocol.recv()`, -:meth:`~protocol.WebSocketCommonProtocol.send()`, -:meth:`~protocol.WebSocketCommonProtocol.close()` -:meth:`~protocol.WebSocketCommonProtocol.ping()`, or -:meth:`~protocol.WebSocketCommonProtocol.pong()` concurrently is safe, +Calling any combination of :meth:`~protocol.WebSocketCommonProtocol.recv`, +:meth:`~protocol.WebSocketCommonProtocol.send`, +:meth:`~protocol.WebSocketCommonProtocol.close` +:meth:`~protocol.WebSocketCommonProtocol.ping`, or +:meth:`~protocol.WebSocketCommonProtocol.pong` concurrently is safe, including multiple calls to the same method. As shown above, receiving frames is independent from sending frames. That -isolates :meth:`~protocol.WebSocketCommonProtocol.recv()`, which receives +isolates :meth:`~protocol.WebSocketCommonProtocol.recv`, which receives frames, from the other methods, which send frames. Methods that send frames also support concurrent calls. While the connection diff --git a/docs/extensions.rst b/docs/extensions.rst index 7c282ffd0..400034090 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -14,7 +14,7 @@ Per-Message Deflate, specified in :rfc:`7692`. Per-Message Deflate ------------------- -:func:`~server.serve()` and :func:`~client.connect()` enable the Per-Message +:func:`~server.serve()` and :func:`~client.connect` enable the Per-Message Deflate extension by default. You can disable this with ``compression=None``. You can also configure the Per-Message Deflate extension explicitly if you From 56cd365310a341eef59f11659631d7ee73b9f1da Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 17 Jun 2019 13:18:18 +0200 Subject: [PATCH 0585/1539] Support HTTP Basic Auth on the server side. Fix #492. --- docs/api.rst | 11 +++ docs/changelog.rst | 3 + src/websockets/__init__.py | 4 +- src/websockets/auth.py | 151 ++++++++++++++++++++++++++++++++++++ src/websockets/client.py | 6 +- src/websockets/headers.py | 102 +++++++++++++++++++++++- tests/test_auth.py | 136 ++++++++++++++++++++++++++++++++ tests/test_client_server.py | 45 ++++++----- tests/test_headers.py | 66 ++++++++++++++-- 9 files changed, 488 insertions(+), 36 deletions(-) create mode 100644 src/websockets/auth.py create mode 100644 tests/test_auth.py diff --git a/docs/api.rst b/docs/api.rst index 9870c5dff..ef567ed5b 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -102,6 +102,17 @@ Per-Message Deflate Extension .. autoclass:: ClientPerMessageDeflateFactory +HTTP Basic Auth +............... + +.. automodule:: websockets.auth + + .. autofunction:: basic_auth_protocol_factory + + .. autoclass:: BasicAuthWebSocketServerProtocol + + .. automethod:: process_request + Exceptions .......... diff --git a/docs/changelog.rst b/docs/changelog.rst index 17ecd5523..77e9da0de 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -50,6 +50,9 @@ Also: :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. +* Added :func:`~auth.basic_auth_protocol_factory` to provide HTTP Basic Auth + on the server side. + * :func:`~client.connect` handles redirects from the server during the handshake. diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 9bfbdabfe..e7ba31ce5 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,5 +1,6 @@ # This relies on each of the submodules having an __all__ variable. +from .auth import * from .client import * from .exceptions import * from .protocol import * @@ -10,7 +11,8 @@ __all__ = ( - client.__all__ + auth.__all__ + + client.__all__ + exceptions.__all__ + protocol.__all__ + server.__all__ diff --git a/src/websockets/auth.py b/src/websockets/auth.py new file mode 100644 index 000000000..91d3d7420 --- /dev/null +++ b/src/websockets/auth.py @@ -0,0 +1,151 @@ +""" +The :mod:`websockets.auth` module implements HTTP Basic Authentication as +specified in :rfc:`7235` and :rfc:`7617`. + +""" + + +import functools +import http +from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union + +from .exceptions import InvalidHeader +from .headers import build_www_authenticate_basic, parse_authorization_basic +from .http import Headers +from .server import HTTPResponse, WebSocketServerProtocol + + +__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] + +Credentials = Tuple[str, str] + + +def is_credentials(value: Any) -> bool: + try: + username, password = value + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): + """ + WebSocket server protocol that enforces HTTP Basic Auth. + + """ + + def __init__( + self, + *args: Any, + realm: str, + check_credentials: Callable[[str, str], Awaitable[bool]], + **kwargs: Any, + ) -> None: + self.realm = realm + self.check_credentials = check_credentials + super().__init__(*args, **kwargs) + + async def process_request( + self, path: str, request_headers: Headers + ) -> Optional[HTTPResponse]: + """ + Check HTTP Basic Auth and return a HTTP 401 or 403 response if needed. + + If authentication succeeds, the username of the authenticated user is + stored in the ``username`` attribute. + + """ + try: + authorization = request_headers["Authorization"] + except KeyError: + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Missing credentials\n", + ) + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Unsupported credentials\n", + ) + + if not await self.check_credentials(username, password): + return (http.HTTPStatus.FORBIDDEN, [], b"Invalid credentials\n") + + self.username = username + + return await super().process_request(path, request_headers) + + +def basic_auth_protocol_factory( + realm: str, + credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, + check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, + create_protocol: Type[ + BasicAuthWebSocketServerProtocol + ] = BasicAuthWebSocketServerProtocol, +) -> Callable[[Any], BasicAuthWebSocketServerProtocol]: + """ + Protocol factory that enforces HTTP Basic Auth. + + ``basic_auth_protocol_factory`` is designed to integrate with + :func:`~websockets.server.serve` like this:: + + websockets.serve( + ..., + create_protocol=websockets.basic_auth_protocol_factory( + realm="my dev server", + credentials=("hello", "iloveyou"), + ) + ) + + ``realm`` indicates the scope of protection. It should be an ASCII-only + :class:`str` because the encoding of non-ASCII characters is undefined. + Refer to section 2.2 of :rfc:`7235` for details. + + One of ``credentials`` or ``check_credentials`` must be provided but not + both. + + ``credentials`` defines hardcoded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + + ``check_credentials`` defines a coroutine that checks whether credentials + are authorized. This coroutine receives ``username`` and ``password`` + arguments and returns a :class:`bool`. + + By default, ``basic_auth_protocol_factory`` creates instances of + :class:`BasicAuthWebSocketServerProtocol`. You can override this with the + ``create_protocol`` parameter. + + """ + if (credentials is None) == (check_credentials is None): + raise ValueError("Provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + + async def check_credentials(username: str, password: str) -> bool: + return (username, password) == credentials + + elif isinstance(credentials, Iterable): + credentials_list = list(credentials) + if all(is_credentials(item) for item in credentials_list): + credentials_dict = dict(credentials_list) + + async def check_credentials(username: str, password: str) -> bool: + return credentials_dict.get(username) == password + + else: + raise ValueError(f"Invalid credentials argument: {credentials}") + + else: + raise ValueError(f"Invalid credentials argument: {credentials}") + + return functools.partial( + create_protocol, realm=realm, check_credentials=check_credentials + ) diff --git a/src/websockets/client.py b/src/websockets/client.py index 3d057a2e3..e6131ed7a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -22,7 +22,7 @@ from .handshake import build_request, check_response from .headers import ( ExtensionHeader, - build_basic_auth, + build_authorization_basic, build_extension, build_subprotocol, parse_extension, @@ -256,7 +256,9 @@ async def handshake( request_headers["Host"] = f"{wsuri.host}:{wsuri.port}" if wsuri.user_info: - request_headers["Authorization"] = build_basic_auth(*wsuri.user_info) + request_headers["Authorization"] = build_authorization_basic( + *wsuri.user_info + ) if origin is not None: request_headers["Origin"] = origin diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 663e71d60..536cab592 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -7,11 +7,13 @@ """ + import base64 +import binascii import re from typing import Callable, List, NewType, Optional, Sequence, Tuple, TypeVar, cast -from .exceptions import InvalidHeaderFormat +from .exceptions import InvalidHeaderFormat, InvalidHeaderValue from .typing import ExtensionHeader, ExtensionParameter, Subprotocol @@ -22,6 +24,9 @@ "build_extension", "parse_subprotocol", "build_subprotocol", + "build_www_authenticate_basic", + "parse_authorization_basic", + "build_authorization_basic", ] @@ -107,6 +112,25 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() +_quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*") + + +_quote_re = re.compile(r"([\x22\x5c])") + + +def build_quoted_string(value: str) -> str: + """ + Format ``value`` as a quoted string. + + This is the reverse of :func:`parse_quoted_string`. + + """ + match = _quotable_re.fullmatch(value) + if match is None: + raise ValueError("invalid characters for quoted-string encoding") + return '"' + _quote_re.sub(r"\\\1", value) + '"' + + def parse_list( parse_item: Callable[[str, int, str], Tuple[T, int]], header: str, @@ -392,7 +416,18 @@ def build_subprotocol(protocols: Sequence[Subprotocol]) -> str: build_subprotocol_list = build_subprotocol # alias for backwards-compatibility -def build_basic_auth(username: str, password: str) -> str: +def build_www_authenticate_basic(realm: str) -> str: + """ + Build an WWW-Authenticate header for HTTP Basic Auth. + + """ + # https://tools.ietf.org/html/rfc7617#section-2 + realm = build_quoted_string(realm) + charset = build_quoted_string("UTF-8") + return f"Basic realm={realm}, charset={charset}" + + +def build_authorization_basic(username: str, password: str) -> str: """ Build an Authorization header for HTTP Basic Auth. @@ -402,3 +437,66 @@ def build_basic_auth(username: str, password: str) -> str: user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() return "Basic " + basic_credentials + + +_token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*") + + +def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: + """ + Parse a token68 from ``header`` at the given position. + + Return the token value and the new position. + + Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + + """ + match = _token68_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected token68", header, pos) + return match.group(), match.end() + + +def parse_end(header: str, pos: int, header_name: str) -> None: + """ + Check that parsing reached the end of header. + + """ + if pos < len(header): + raise InvalidHeaderFormat(header_name, "trailing data", header, pos) + + +def parse_authorization_basic(header: str) -> Tuple[str, str]: + """ + Parse an Authorization header for HTTP Basic Auth. + + Return a ``(username, password)`` tuple. + + """ + # https://tools.ietf.org/html/rfc7235#section-2.1 + # https://tools.ietf.org/html/rfc7617#section-2 + scheme, pos = parse_token(header, 0, "Authorization") + if scheme.lower() != "basic": + raise InvalidHeaderValue("Authorization", f"unsupported scheme: {scheme}") + if peek_ahead(header, pos) != " ": + raise InvalidHeaderFormat( + "Authorization", "expected space after scheme", header, pos + ) + pos += 1 + basic_credentials, pos = parse_token68(header, pos, "Authorization") + parse_end(header, pos, "Authorization") + + try: + user_pass = base64.b64decode(basic_credentials.encode()).decode() + except binascii.Error: + raise InvalidHeaderValue( + "Authorization", "expected base64-encoded credentials" + ) from None + try: + username, password = user_pass.split(":", 1) + except ValueError: + raise InvalidHeaderValue( + "Authorization", "expected username:password credentials" + ) from None + + return username, password diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 000000000..f6aa5c424 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,136 @@ +import unittest +import urllib.error + +from websockets.auth import * +from websockets.auth import is_credentials +from websockets.exceptions import InvalidStatusCode +from websockets.headers import build_authorization_basic + +from .test_client_server import ClientServerTestsMixin, with_client, with_server +from .utils import AsyncioTestCase + + +class AuthTests(unittest.TestCase): + def test_is_credentials(self): + self.assertTrue(is_credentials(("username", "password"))) + + def test_is_not_credentials(self): + self.assertFalse(is_credentials(None)) + self.assertFalse(is_credentials("username")) + + +class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): + + create_protocol = basic_auth_protocol_factory( + realm="auth-tests", credentials=("hello", "iloveyou") + ) + + @with_server(create_protocol=create_protocol) + @with_client(user_info=("hello", "iloveyou")) + def test_basic_auth(self): + req_headers = self.client.request_headers + resp_headers = self.client.response_headers + self.assertEqual(req_headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") + self.assertNotIn("WWW-Authenticate", resp_headers) + + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + def test_basic_auth_server_no_credentials(self): + with self.assertRaises(ValueError) as raised: + basic_auth_protocol_factory(realm="auth-tests", credentials=None) + self.assertEqual( + str(raised.exception), "Provide either credentials or check_credentials" + ) + + def test_basic_auth_server_bad_credentials(self): + with self.assertRaises(ValueError) as raised: + basic_auth_protocol_factory(realm="auth-tests", credentials=42) + self.assertEqual(str(raised.exception), "Invalid credentials argument: 42") + + create_protocol_multiple_credentials = basic_auth_protocol_factory( + realm="auth-tests", + credentials=[("hello", "iloveyou"), ("goodbye", "stillloveu")], + ) + + @with_server(create_protocol=create_protocol_multiple_credentials) + @with_client(user_info=("hello", "iloveyou")) + def test_basic_auth_server_multiple_credentials(self): + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + def test_basic_auth_bad_multiple_credentials(self): + with self.assertRaises(ValueError) as raised: + basic_auth_protocol_factory( + realm="auth-tests", credentials=[("hello", "iloveyou"), 42] + ) + self.assertEqual( + str(raised.exception), + "Invalid credentials argument: [('hello', 'iloveyou'), 42]", + ) + + async def check_credentials(username, password): + return password == "iloveyou" + + create_protocol_check_credentials = basic_auth_protocol_factory( + realm="auth-tests", check_credentials=check_credentials + ) + + @with_server(create_protocol=create_protocol_check_credentials) + @with_client(user_info=("hello", "iloveyou")) + def test_basic_auth_check_credentials(self): + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_missing_credentials(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client() + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_missing_credentials_details(self): + with self.assertRaises(urllib.error.HTTPError) as raised: + self.loop.run_until_complete(self.make_http_request()) + self.assertEqual(raised.exception.code, 401) + self.assertEqual( + raised.exception.headers["WWW-Authenticate"], + 'Basic realm="auth-tests", charset="UTF-8"', + ) + self.assertEqual(raised.exception.read().decode(), "Missing credentials\n") + + @with_server(create_protocol=create_protocol) + def test_basic_auth_unsupported_credentials(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(extra_headers={"Authorization": "Digest ..."}) + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_unsupported_credentials_details(self): + with self.assertRaises(urllib.error.HTTPError) as raised: + self.loop.run_until_complete( + self.make_http_request(headers={"Authorization": "Digest ..."}) + ) + self.assertEqual(raised.exception.code, 401) + self.assertEqual( + raised.exception.headers["WWW-Authenticate"], + 'Basic realm="auth-tests", charset="UTF-8"', + ) + self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n") + + @with_server(create_protocol=create_protocol) + def test_basic_auth_invalid_credentials(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(user_info=("hello", "ihateyou")) + self.assertEqual(raised.exception.status_code, 403) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_invalid_credentials_details(self): + with self.assertRaises(urllib.error.HTTPError) as raised: + authorization = build_authorization_basic("hello", "ihateyou") + self.loop.run_until_complete( + self.make_http_request(headers={"Authorization": authorization}) + ) + self.assertEqual(raised.exception.code, 403) + self.assertNotIn("WWW-Authenticate", raised.exception.headers) + self.assertEqual(raised.exception.read().decode(), "Invalid credentials\n") diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 8a1177a7e..d82aa6d40 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -340,6 +340,26 @@ def temp_client(self, *args, **kwds): with temp_test_client(self, *args, **kwds): yield + def make_http_request(self, path="/", headers=None): + if headers is None: + headers = {} + + # Set url to 'https?://:'. + url = get_server_uri( + self.server, resource_name=path, secure=self.secure + ).replace("ws", "http") + + request = urllib.request.Request(url=url, headers=headers) + + if self.secure: + open_health_check = functools.partial( + urllib.request.urlopen, request, context=self.client_context + ) + else: + open_health_check = functools.partial(urllib.request.urlopen, request) + + return self.loop.run_in_executor(None, open_health_check) + class SecureClientServerTestsMixin(ClientServerTestsMixin): @@ -586,13 +606,6 @@ def test_protocol_path(self): server_path = self.loop.run_until_complete(self.client.recv()) self.assertEqual(server_path, "/path") - @with_server() - @with_client("/headers", user_info=("user", "pass")) - def test_protocol_basic_auth(self): - self.assertEqual( - self.client.request_headers["Authorization"], "Basic dXNlcjpwYXNz" - ) - @with_server() @with_client("/headers") def test_protocol_headers(self): @@ -690,20 +703,6 @@ def test_protocol_custom_response_user_agent(self): self.assertEqual(resp_headers.count("Server"), 1) self.assertIn("('Server', 'Eggs')", resp_headers) - def make_http_request(self, path="/"): - # Set url to 'https?://:'. - url = get_server_uri(self.server, resource_name=path, secure=self.secure) - url = url.replace("ws", "http") - - if self.secure: - open_health_check = functools.partial( - urllib.request.urlopen, url, context=self.client_context - ) - else: - open_health_check = functools.partial(urllib.request.urlopen, url) - - return self.loop.run_in_executor(None, open_health_check) - @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): # Making a HTTP request to a HTTP endpoint succeeds. @@ -979,12 +978,12 @@ def test_compression_deflate_and_explicit_config(self): def test_compression_unsupported_server(self): with self.assertRaises(ValueError): - self.loop.run_until_complete(self.start_server(compression="xz")) + self.start_server(compression="xz") @with_server() def test_compression_unsupported_client(self): with self.assertRaises(ValueError): - self.loop.run_until_complete(self.start_client(compression="xz")) + self.start_client(compression="xz") @with_server() @with_client("/subprotocol") diff --git a/tests/test_headers.py b/tests/test_headers.py index 51a0f33af..26d85fa5e 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -1,8 +1,7 @@ import unittest -from websockets.exceptions import InvalidHeaderFormat +from websockets.exceptions import InvalidHeaderFormat, InvalidHeaderValue from websockets.headers import * -from websockets.headers import build_basic_auth class HeadersTests(unittest.TestCase): @@ -17,7 +16,7 @@ def test_parse_connection(self): with self.subTest(header=header): self.assertEqual(parse_connection(header), parsed) - def test_parse_connection_invalid_header(self): + def test_parse_connection_invalid_header_format(self): for header in ["???", "keep-alive; Upgrade"]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): @@ -35,7 +34,7 @@ def test_parse_upgrade(self): with self.subTest(header=header): self.assertEqual(parse_upgrade(header), parsed) - def test_parse_upgrade_invalid_header(self): + def test_parse_upgrade_invalid_header_format(self): for header in ["???", "websocket 2", "http/3.0; websocket"]: with self.subTest(header=header): with self.assertRaises(InvalidHeaderFormat): @@ -83,7 +82,7 @@ def test_parse_extension(self): unparsed = build_extension(parsed) self.assertEqual(parse_extension(unparsed), parsed) - def test_parse_extension_invalid_header(self): + def test_parse_extension_invalid_header_format(self): for header in [ # Truncated examples "", @@ -127,9 +126,60 @@ def test_parse_subprotocol_invalid_header(self): with self.assertRaises(InvalidHeaderFormat): parse_subprotocol(header) - def test_build_basic_auth(self): - # Test vector from RFC 7617. + def test_build_www_authenticate_basic(self): + # Test vector from RFC 7617 self.assertEqual( - build_basic_auth("Aladdin", "open sesame"), + build_www_authenticate_basic("foo"), 'Basic realm="foo", charset="UTF-8"' + ) + + def test_build_www_authenticate_basic_invalid_realm(self): + # Realm contains a control character forbidden in quoted-string encoding + with self.assertRaises(ValueError): + build_www_authenticate_basic("\u0007") + + def test_build_authorization_basic(self): + # Test vector from RFC 7617 + self.assertEqual( + build_authorization_basic("Aladdin", "open sesame"), "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", ) + + def test_build_authorization_basic_utf8(self): + # Test vector from RFC 7617 + self.assertEqual( + build_authorization_basic("test", "123£"), "Basic dGVzdDoxMjPCow==" + ) + + def test_parse_authorization_basic(self): + for header, parsed in [ + ("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", ("Aladdin", "open sesame")), + # Password contains non-ASCII character + ("Basic dGVzdDoxMjPCow==", ("test", "123£")), + # Password contains a colon + ("Basic YWxhZGRpbjpvcGVuOnNlc2FtZQ==", ("aladdin", "open:sesame")), + # Scheme name must be case insensitive + ("basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", ("Aladdin", "open sesame")), + ]: + with self.subTest(header=header): + self.assertEqual(parse_authorization_basic(header), parsed) + + def test_parse_authorization_basic_invalid_header_format(self): + for header in [ + "// Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + "Basic\tQWxhZGRpbjpvcGVuIHNlc2FtZQ==", + "Basic ****************************", + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ== //", + ]: + with self.subTest(header=header): + with self.assertRaises(InvalidHeaderFormat): + parse_authorization_basic(header) + + def test_parse_authorization_basic_invalid_header_value(self): + for header in [ + "Digest ...", + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ", + "Basic QWxhZGNlc2FtZQ==", + ]: + with self.subTest(header=header): + with self.assertRaises(InvalidHeaderValue): + parse_authorization_basic(header) From 918d83f6abcd468998a1a6a51387ae5c42a90297 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 13:04:42 +0200 Subject: [PATCH 0586/1539] Add basic auth examples. --- example/basic_auth_client.py | 14 ++++++++++++++ example/basic_auth_server.py | 20 ++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100755 example/basic_auth_client.py create mode 100755 example/basic_auth_server.py diff --git a/example/basic_auth_client.py b/example/basic_auth_client.py new file mode 100755 index 000000000..cc94dbe4b --- /dev/null +++ b/example/basic_auth_client.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python + +# WS client example with HTTP Basic Authentication + +import asyncio +import websockets + +async def hello(): + uri = "ws://mary:p@ssw0rd@localhost:8765" + async with websockets.connect(uri) as websocket: + greeting = await websocket.recv() + print(greeting) + +asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/basic_auth_server.py b/example/basic_auth_server.py new file mode 100755 index 000000000..6740d5798 --- /dev/null +++ b/example/basic_auth_server.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Server example with HTTP Basic Authentication over TLS + +import asyncio +import websockets + +async def hello(websocket, path): + greeting = f"Hello {websocket.username}!" + await websocket.send(greeting) + +start_server = websockets.serve( + hello, "localhost", 8765, + create_protocol=websockets.basic_auth_protocol_factory( + realm="example", credentials=("mary", "p@ssw0rd") + ), +) + +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() From c0c31b89c1eb382ca0604a6edb762e0b2c919ed2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 13:18:41 +0200 Subject: [PATCH 0587/1539] Avoid crash caused by type annotations. --- src/websockets/__main__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 604caa5e4..14bf655b1 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -43,7 +43,7 @@ def win_enable_vt100() -> None: def exit_from_event_loop_thread( - loop: asyncio.AbstractEventLoop, stop: asyncio.Future[None] + loop: asyncio.AbstractEventLoop, stop: "asyncio.Future[None]" ) -> None: loop.stop() if not stop.done(): @@ -91,8 +91,8 @@ def print_over_input(string: str) -> None: async def run_client( uri: str, loop: asyncio.AbstractEventLoop, - inputs: asyncio.Queue[str], - stop: asyncio.Future[None], + inputs: "asyncio.Queue[str]", + stop: "asyncio.Future[None]", ) -> None: try: websocket = await websockets.connect(uri) From 71d476a5141be67daaab82dab729278940085a86 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 13:43:17 +0200 Subject: [PATCH 0588/1539] Handle ConnectionClosed exception in keepalive_ping. Fix #551. Thanks @Harmon758 for reporting this bug and identifying the root cause. --- docs/changelog.rst | 3 +++ src/websockets/protocol.py | 15 ++++++++++----- tests/test_protocol.py | 25 +++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 77e9da0de..5f22a06eb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -59,6 +59,9 @@ Also: * Improved support for sending fragmented messages by accepting asynchronous iterators in :meth:`~protocol.WebSocketCommonProtocol.send`. +* Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` + exceptions in keepalive ping task. + * Avoided a crash of a ``extra_headers`` callable returns ``None``. * Enabled readline in the interactive client. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index c07aef99f..c46faaf94 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -980,6 +980,7 @@ async def keepalive_ping(self) -> None: This coroutine exits when the connection terminates and one of the following happens: + - :meth:`ping` raises :exc:`ConnectionClosed`, or - :meth:`close_connection` cancels :attr:`keepalive_ping_task`. @@ -991,11 +992,12 @@ async def keepalive_ping(self) -> None: while True: await asyncio.sleep(self.ping_interval, loop=self.loop) - # ping() cannot raise ConnectionClosed, only CancelledError: - # - If the connection is CLOSING, keepalive_ping_task will be - # canceled by close_connection() before ping() returns. - # - If the connection is CLOSED, keepalive_ping_task must be - # canceled already. + # ping() raises CancelledError if the connection is closed, + # when close_connection() cancels self.keepalive_ping_task. + + # ping() raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_keepalive_pings(). + ping_waiter = await self.ping() if self.ping_timeout is not None: @@ -1011,6 +1013,9 @@ async def keepalive_ping(self) -> None: except asyncio.CancelledError: raise + except ConnectionClosed: + pass + except Exception: logger.warning("Unexpected exception in keepalive ping task", exc_info=True) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 938e54d8d..57c0c0e6e 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1074,6 +1074,31 @@ def test_keepalive_ping_stops_when_connection_closed(self): # The keepalive ping task terminated. self.assertTrue(self.protocol.keepalive_ping_task.cancelled()) + def test_keepalive_ping_does_not_crash_when_connection_lost(self): + self.restart_protocol_with_keepalive_ping() + # Clog incoming queue. This lets connection_lost() abort pending pings + # with a ConnectionClosed exception before transfer_data_task + # terminates and close_connection cancels keepalive_ping_task. + self.protocol.max_queue = 1 + self.receive_frame(Frame(True, OP_TEXT, b"1")) + self.receive_frame(Frame(True, OP_TEXT, b"2")) + # Ping is sent at 3ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + ping_waiter, = tuple(self.protocol.pings.values()) + # Connection drops. + self.receive_eof() + self.loop.run_until_complete(self.protocol.wait_closed()) + + # The ping waiter receives a ConnectionClosed exception. + with self.assertRaises(ConnectionClosed): + ping_waiter.result() + # The keepalive ping task terminated properly. + self.assertIsNone(self.protocol.keepalive_ping_task.result()) + + # Unclog incoming queue to terminate the test quickly. + self.loop.run_until_complete(self.protocol.recv()) + self.loop.run_until_complete(self.protocol.recv()) + def test_keepalive_ping_with_no_ping_interval(self): self.restart_protocol_with_keepalive_ping(ping_interval=None) From f8d8a61d8e2c7dcd6eb807d952cfa2b5179b29cb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 17:46:22 +0200 Subject: [PATCH 0589/1539] Handle aborted pings when receiving a pong. Fix #551. Thanks @Harmon758 for reporting this bug and identifying the root cause. --- src/websockets/protocol.py | 3 ++- tests/test_protocol.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index c46faaf94..d6462cc16 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -887,7 +887,8 @@ async def read_data_frame(self, max_size: int) -> Optional[Frame]: while ping_id != frame.data: ping_id, pong_waiter = self.pings.popitem(last=False) ping_ids.append(ping_id) - pong_waiter.set_result(None) + if not pong_waiter.done(): + pong_waiter.set_result(None) pong_hex = binascii.hexlify(frame.data).decode() or "[empty]" logger.debug( "%s - received solicited pong: %s", self.side, pong_hex diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 57c0c0e6e..57cef89e0 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -864,6 +864,35 @@ def test_acknowledge_previous_pings(self): self.assertTrue(pings[1][0].done()) self.assertFalse(pings[2][0].done()) + def test_acknowledge_aborted_ping(self): + ping = self.loop.run_until_complete(self.protocol.ping()) + ping_frame = self.last_sent_frame() + # Clog incoming queue. This lets connection_lost() abort pending pings + # with a ConnectionClosed exception before transfer_data_task + # terminates and close_connection cancels keepalive_ping_task. + self.protocol.max_queue = 1 + self.receive_frame(Frame(True, OP_TEXT, b"1")) + self.receive_frame(Frame(True, OP_TEXT, b"2")) + # Add pong frame to the queue. + pong_frame = Frame(True, OP_PONG, ping_frame.data) + self.receive_frame(pong_frame) + # Connection drops. + self.receive_eof() + self.loop.run_until_complete(self.protocol.wait_closed()) + # Ping receives a ConnectionClosed exception. + with self.assertRaises(ConnectionClosed): + ping.result() + + with self.assertLogs("websockets", level=logging.ERROR) as logs: + # We want to test that no error log is emitted. + # Unfortunately assertLogs expects at least one log message. + logging.getLogger("websockets").error("dummy") + # Unclog incoming queue. + self.loop.run_until_complete(self.protocol.recv()) + self.loop.run_until_complete(self.protocol.recv()) + # transfer_data doesn't crash, which would be logged. + self.assertEqual(logs.output[1:], []) + def test_canceled_ping(self): ping = self.loop.run_until_complete(self.protocol.ping()) ping_frame = self.last_sent_frame() From 34aaf6bcbbac62d8c605d5ba768709346ef87c6e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 18:24:12 +0200 Subject: [PATCH 0590/1539] Update code style for example. Start from what black produces, then wrap at 66 chars and don't skip more than one line. --- docs/deployment.rst | 2 +- docs/intro.rst | 6 +++--- example/client.py | 4 ++-- example/counter.py | 34 +++++++++++++++++++++------------- example/echo.py | 5 +++-- example/health_check_server.py | 7 ++++--- example/hello.py | 6 +++--- example/secure_client.py | 8 +++++--- example/secure_server.py | 7 ++++--- example/server.py | 2 +- example/show_time.py | 4 ++-- example/shutdown.py | 2 +- 12 files changed, 50 insertions(+), 37 deletions(-) diff --git a/docs/deployment.rst b/docs/deployment.rst index 9aa2d3744..797284f3d 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -154,4 +154,4 @@ the :meth:`~server.WebSocketServerProtocol.process_request` hook. Typical use cases include health checks. Here's an example: .. literalinclude:: ../example/health_check_server.py - :emphasize-lines: 9-11,17-18 + :emphasize-lines: 9-11,17-19 diff --git a/docs/intro.rst b/docs/intro.rst index 118167b73..8decd462d 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -40,7 +40,7 @@ coroutine returns. Here's a corresponding WebSocket client example. .. literalinclude:: ../example/client.py - :emphasize-lines: 8-10 + :emphasize-lines: 8,10 Using :func:`connect` as an asynchronous context manager ensures the connection is closed before exiting the ``hello`` coroutine. @@ -60,12 +60,12 @@ Here's how to adapt the server example to provide secure connections. See the documentation of the :mod:`ssl` module for configuring the context securely. .. literalinclude:: ../example/secure_server.py - :emphasize-lines: 19,23-24 + :emphasize-lines: 19,23-25 Here's how to adapt the client. .. literalinclude:: ../example/secure_client.py - :emphasize-lines: 10,15-16 + :emphasize-lines: 10,15-18 This client needs a context because the server uses a self-signed certificate. diff --git a/example/client.py b/example/client.py index e71595ff5..4f969c478 100755 --- a/example/client.py +++ b/example/client.py @@ -6,8 +6,8 @@ import websockets async def hello(): - async with websockets.connect( - 'ws://localhost:8765') as websocket: + uri = "ws://localhost:8765" + async with websockets.connect(uri) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/counter.py b/example/counter.py index 9cce009fd..dbbbe5935 100755 --- a/example/counter.py +++ b/example/counter.py @@ -9,34 +9,41 @@ logging.basicConfig() -STATE = {'value': 0} +STATE = {"value": 0} USERS = set() + def state_event(): - return json.dumps({'type': 'state', **STATE}) + return json.dumps({"type": "state", **STATE}) + def users_event(): - return json.dumps({'type': 'users', 'count': len(USERS)}) + return json.dumps({"type": "users", "count": len(USERS)}) + async def notify_state(): - if USERS: # asyncio.wait doesn't accept an empty list + if USERS: # asyncio.wait doesn't accept an empty list message = state_event() await asyncio.wait([user.send(message) for user in USERS]) + async def notify_users(): - if USERS: # asyncio.wait doesn't accept an empty list + if USERS: # asyncio.wait doesn't accept an empty list message = users_event() await asyncio.wait([user.send(message) for user in USERS]) + async def register(websocket): USERS.add(websocket) await notify_users() + async def unregister(websocket): USERS.remove(websocket) await notify_users() + async def counter(websocket, path): # register(websocket) sends user_event() to websocket await register(websocket) @@ -44,18 +51,19 @@ async def counter(websocket, path): await websocket.send(state_event()) async for message in websocket: data = json.loads(message) - if data['action'] == 'minus': - STATE['value'] -= 1 + if data["action"] == "minus": + STATE["value"] -= 1 await notify_state() - elif data['action'] == 'plus': - STATE['value'] += 1 + elif data["action"] == "plus": + STATE["value"] += 1 await notify_state() else: - logging.error( - "unsupported event: {}", data) + logging.error("unsupported event: {}", data) finally: await unregister(websocket) -asyncio.get_event_loop().run_until_complete( - websockets.serve(counter, 'localhost', 6789)) + +start_server = websockets.serve(counter, "localhost", 6789) + +asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() diff --git a/example/echo.py b/example/echo.py index 8fa307dd7..b7ca38d32 100755 --- a/example/echo.py +++ b/example/echo.py @@ -7,6 +7,7 @@ async def echo(websocket, path): async for message in websocket: await websocket.send(message) -asyncio.get_event_loop().run_until_complete( - websockets.serve(echo, 'localhost', 8765)) +start_server = websockets.serve(echo, "localhost", 8765) + +asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() diff --git a/example/health_check_server.py b/example/health_check_server.py index feb04bccd..417063fce 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -7,15 +7,16 @@ import websockets async def health_check(path, request_headers): - if path == '/health/': - return http.HTTPStatus.OK, [], b'OK\n' + if path == "/health/": + return http.HTTPStatus.OK, [], b"OK\n" async def echo(websocket, path): async for message in websocket: await websocket.send(message) start_server = websockets.serve( - echo, 'localhost', 8765, process_request=health_check) + echo, "localhost", 8765, process_request=health_check +) asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() diff --git a/example/hello.py b/example/hello.py index f90c0de55..6c9c839d8 100755 --- a/example/hello.py +++ b/example/hello.py @@ -3,10 +3,10 @@ import asyncio import websockets -async def hello(uri): +async def hello(): + uri = "ws://localhost:8765" async with websockets.connect(uri) as websocket: await websocket.send("Hello world!") await websocket.recv() -asyncio.get_event_loop().run_until_complete( - hello('ws://localhost:8765')) +asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/secure_client.py b/example/secure_client.py index 8e7f57ff9..54971b984 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -8,12 +8,14 @@ import websockets ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.load_verify_locations( - pathlib.Path(__file__).with_name('localhost.pem')) +localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") +ssl_context.load_verify_locations(localhost_pem) async def hello(): + uri = "wss://localhost:8765" async with websockets.connect( - 'wss://localhost:8765', ssl=ssl_context) as websocket: + uri, ssl=ssl_context + ) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/secure_server.py b/example/secure_server.py index 5cbed46c0..2a00bdb50 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -17,11 +17,12 @@ async def hello(websocket, path): print(f"> {greeting}") ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) -ssl_context.load_cert_chain( - pathlib.Path(__file__).with_name('localhost.pem')) +localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") +ssl_context.load_cert_chain(localhost_pem) start_server = websockets.serve( - hello, 'localhost', 8765, ssl=ssl_context) + hello, "localhost", 8765, ssl=ssl_context +) asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() diff --git a/example/server.py b/example/server.py index cc5c8fea8..c8ab69971 100755 --- a/example/server.py +++ b/example/server.py @@ -14,7 +14,7 @@ async def hello(websocket, path): await websocket.send(greeting) print(f"> {greeting}") -start_server = websockets.serve(hello, 'localhost', 8765) +start_server = websockets.serve(hello, "localhost", 8765) asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() diff --git a/example/show_time.py b/example/show_time.py index 6d196deb3..e5d6ac9aa 100755 --- a/example/show_time.py +++ b/example/show_time.py @@ -9,11 +9,11 @@ async def time(websocket, path): while True: - now = datetime.datetime.utcnow().isoformat() + 'Z' + now = datetime.datetime.utcnow().isoformat() + "Z" await websocket.send(now) await asyncio.sleep(random.random() * 3) -start_server = websockets.serve(time, '127.0.0.1', 5678) +start_server = websockets.serve(time, "127.0.0.1", 5678) asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() diff --git a/example/shutdown.py b/example/shutdown.py index 6d75af192..86846abe7 100755 --- a/example/shutdown.py +++ b/example/shutdown.py @@ -9,7 +9,7 @@ async def echo(websocket, path): await websocket.send(message) async def echo_server(stop): - async with websockets.serve(echo, 'localhost', 8765): + async with websockets.serve(echo, "localhost", 8765): await stop loop = asyncio.get_event_loop() From bf6db5ddeda3f2da7a48f69ec9fa6c024fdcbfa8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 18:27:29 +0200 Subject: [PATCH 0591/1539] Encourage users to remove workarounds. Refs #551. --- docs/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5f22a06eb..56d4b9398 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -62,6 +62,8 @@ Also: * Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` exceptions in keepalive ping task. + If you were using ``ping_timeout=None`` as a workaround, you can remove it. + * Avoided a crash of a ``extra_headers`` callable returns ``None``. * Enabled readline in the interactive client. From 9b89de93b1d00fd404439675ecc1f3f385287cc4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 22 Jun 2019 19:04:07 +0200 Subject: [PATCH 0592/1539] Handle ConnectionClosed when echoing a close frame. Fix #606. Thanks @lgrahl for the bug report. --- src/websockets/protocol.py | 13 +++++++++---- tests/test_protocol.py | 21 +++++++++++++++------ tests/utils.py | 17 +++++++++++++++++ 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index d6462cc16..d888a9729 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -864,10 +864,15 @@ async def read_data_frame(self, max_size: int) -> Optional[Frame]: # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason self.close_code, self.close_reason = parse_close(frame.data) - # Echo the original data instead of re-serializing it with - # serialize_close() because that fails when the close frame is - # empty and parse_close() synthetizes a 1005 close code. - await self.write_close_frame(frame.data) + try: + # Echo the original data instead of re-serializing it with + # serialize_close() because that fails when the close frame + # is empty and parse_close() synthetizes a 1005 close code. + await self.write_close_frame(frame.data) + except ConnectionClosed: + # It doesn't really matter if the connection was closed + # before we could send back a close frame. + pass return None elif frame.opcode == OP_PING: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 57cef89e0..0d3185d42 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -883,15 +883,11 @@ def test_acknowledge_aborted_ping(self): with self.assertRaises(ConnectionClosed): ping.result() - with self.assertLogs("websockets", level=logging.ERROR) as logs: - # We want to test that no error log is emitted. - # Unfortunately assertLogs expects at least one log message. - logging.getLogger("websockets").error("dummy") + # transfer_data doesn't crash, which would be logged. + with self.assertNoLogs(): # Unclog incoming queue. self.loop.run_until_complete(self.protocol.recv()) self.loop.run_until_complete(self.protocol.recv()) - # transfer_data doesn't crash, which would be logged. - self.assertEqual(logs.output[1:], []) def test_canceled_ping(self): ping = self.loop.run_until_complete(self.protocol.ping()) @@ -1205,6 +1201,19 @@ def test_remote_close(self): self.assertConnectionClosed(1000, "close") self.assertNoFrameSent() + def test_remote_close_and_connection_lost(self): + self.make_drain_slow() + # Drop the connection right after receiving a close frame, + # which prevents echoing the close frame properly. + self.receive_frame(self.close_frame) + self.receive_eof() + + with self.assertNoLogs(): + self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) + + self.assertConnectionClosed(1000, "close") + self.assertOneFrameSent(*self.close_frame) + def test_simultaneous_close(self): # Receive the incoming close frame right after self.protocol.close() # starts executing. This reproduces the error described in: diff --git a/tests/utils.py b/tests/utils.py index 0a9f14ce1..059efba20 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,6 @@ import asyncio +import contextlib +import logging import os import time import unittest @@ -25,6 +27,21 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() + @contextlib.contextmanager + def assertNoLogs(self, logger="websockets", level=logging.ERROR): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. From 7d429b56b62a263320fc693b6862da757ffb763f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jun 2019 18:56:27 +0200 Subject: [PATCH 0593/1539] Rewrite documentation for process_request. Fix #496. --- src/websockets/server.py | 46 ++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index 870e4ec7a..9882eabef 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -267,30 +267,40 @@ async def process_request( self, path: str, request_headers: Headers ) -> Optional[HTTPResponse]: """ - Intercept the HTTP request and return an HTTP response if needed. + Intercept the HTTP request and return an HTTP response if appropriate. - ``request_headers`` is a :class:`~websockets.http.Headers` instance. + ``path`` is a :class:`str` and ``request_headers`` is a + :class:`~websockets.http.Headers` instance. - If this coroutine returns ``None``, the WebSocket handshake continues. - If it returns a status code, headers and a response body, that HTTP - response is sent and the connection is closed. + If ``process_request`` returns ``None``, the WebSocket handshake + continues. If it returns a status code, headers and a response body, + that HTTP response is sent and the connection is closed. In that case: - The HTTP status must be a :class:`~http.HTTPStatus`. + * The HTTP status must be a :class:`~http.HTTPStatus`. + * HTTP headers must be a :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, or an iterable of ``(name, + value)`` pairs. + * The HTTP response body must be :class:`bytes`. It may be empty. - HTTP headers must be a :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` - pairs. + This coroutine may be overridden in a :class:`WebSocketServerProtocol` + subclass, for example: - The HTTP response body must be :class:`bytes`. It may be empty. + * to return a HTTP 200 :attr:`~http.HTTPStatus.OK` response on a given + path; then a load balancer can use this path for a health check; + * to authenticate the request and return a HTTP 401 + :attr:`~http.HTTPStatus.UNAUTHORIZED` or a HTTP 403 + :attr:`~http.HTTPStatus.FORBIDDEN` when authentication fails. - This coroutine may be overridden to check the request headers and set - a different status, for example to authenticate the request and return - :attr:`http.HTTPStatus.UNAUTHORIZED` or - :attr:`http.HTTPStatus.FORBIDDEN`. - - It may also be overridden by passing a ``process_request`` argument to - the :class:`WebSocketServerProtocol` constructor or the :func:`serve` - function. + Instead of subclassing, it is possible to pass a ``process_request`` + argument to the :class:`WebSocketServerProtocol` constructor or the + :func:`serve` function. This is equivalent, except the + ``process_request`` corountine doesn't have access to the protocol + instance, so it can't store information for later use. + + ``process_request`` is expected to complete quickly. If it may run for + a long time, then it should await :meth:`wait_closed` and exit if + :meth:`wait_closed` completes, or else it could prevent the server + from shutting down. """ if self._process_request is not None: From cac72a7bfdf744fb5d4604317f2ec68caf941751 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Jun 2019 18:42:13 +0200 Subject: [PATCH 0594/1539] Add a FAQ. Fix #621. --- docs/changelog.rst | 2 + docs/faq.rst | 211 +++++++++++++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + 3 files changed, 214 insertions(+) create mode 100644 docs/faq.rst diff --git a/docs/changelog.rst b/docs/changelog.rst index 56d4b9398..c2719560b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -70,6 +70,8 @@ Also: * Added type hints (:pep:`484`). +* Added a FAQ to the documentation. + * Added documentation for extensions. * Documented how to optimize memory usage. diff --git a/docs/faq.rst b/docs/faq.rst new file mode 100644 index 000000000..6c5352668 --- /dev/null +++ b/docs/faq.rst @@ -0,0 +1,211 @@ +FAQ +=== + +.. currentmodule:: websockets + +.. note:: + + Many questions asked in :mod:`websockets`' issue tracker are actually + about :mod:`asyncio`. Python's documentation about `developing with + asyncio`_ is a good complement. + + .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html + +Server side +----------- + +Why does the server close the connection after processing one message? +...................................................................... + +Your connection handler exits after processing one message. Write a loop to +process multiple messages. + +For example, if your handler looks like this:: + + async def handler(websocket, path): + print(websocket.recv()) + +change it like this:: + + async def handler(websocket, path): + async for message in websocket: + print(message) + +*Don't feel bad if this happens to you — it's the most common question in +websockets' issue tracker :-)* + +Why can only one client connect at a time? +.......................................... + +Your connection handler blocks the event loop. Look for blocking calls. +Any call that may take some time must be asynchronous. + +For example, if you have:: + + async def handler(websocket, path): + time.sleep(1) + +change it to:: + + async def handler(websocket, path): + await asyncio.sleep(1) + +This is part of learning asyncio. It isn't specific to websockets. + +See also Python's documentation about `running blocking code`_. + +.. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code + +How do I get access HTTP headers, for example cookies? +...................................................... + +To access HTTP headers during the WebSocket handshake, you can override +:attr:`~server.WebSocketServerProtocol.process_request`:: + + async def process_request(self, path, request_headers): + cookies = request_header["Cookie"] + +See + +Once the connection is established, they're available in +:attr:`~protocol.WebSocketServerProtocol.request_headers`:: + + async def handler(websocket, path): + cookies = websocket.request_headers["Cookie"] + +How do I get the IP address of the client connecting to my server? +.................................................................. + +It's available in :attr:`~protocol.WebSocketCommonProtocol.remote_address`:: + + async def handler(websocket, path): + remote_ip = websocket.remote_address[0] + +How do I set which IP addresses my server listens to? +..................................................... + +Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. + +:func:`serve` accepts the same arguments as +:meth:`~asyncio.loop.create_server`. + +How do I close a connection properly? +..................................... + +websockets takes care of closing the connection when the handler exits. + +How do I run a HTTP server and WebSocket server on the same port? +................................................................. + +This isn't supported. + +Providing a HTTP server is out of scope for websockets. It only aims at +providing a WebSocket server. + +There's limited support for returning HTTP responses with the +:attr:`~server.WebSocketServerProtocol.process_request` hook. +If you need more, pick a HTTP server and run it separately. + +Client side +----------- + +How do I close a connection properly? +..................................... + +The easiest is to use :func:`connect` as a context manager:: + + async with connect(...) as websocket: + ... + +How do I reconnect automatically when the connection drops? +........................................................... + +See `issue 414`_. + +.. _issue 414: https://github.com/aaugustin/websockets/issues/414 + +How do I disable SSL certificate verification? +.............................................. + +Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. + +:func:`connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection`. + +Architecture +------------ + +How do I do two things in parallel? How do I integrate with another coroutine? +.............................................................................. + +You must start two tasks, which the event loop will run concurrently. You can +achieve this with :func:`asyncio.gather` or :func:`asyncio.wait`. + +This is also part of learning asyncio and not specific to websockets. + +Keep track of the tasks and make sure they terminate or you cancel them when +the connection terminates. + +How do I create channels or topics? +................................... + +websockets doesn't have built-in publish / subscribe for these use cases. + +Depending on the scale of your service, a simple in-memory implementation may +do the job or you may need an external publish / subscribe component. + +Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? +............................................................................ + +No, there aren't. + +websockets provides high-level, coroutine-based APIs. Compared to callbacks, +coroutines make it easier to manage control flow in concurrent code. + +If you prefer callback-based APIs, you should use another library. + +Can I use ``websockets`` synchronously, without ``async`` / ``await``? +...................................................................... + +You can convert every asynchronous call to a synchronous call by wrapping it +in ``asyncio.get_event_loop().run_until_complete(...)``. + +If this turns out to be impractical, you should use another library. + +Miscellaneous +------------- + +How do I set a timeout on ``recv()``? +..................................... + +Use :func:`~asyncio.wait_for`:: + + await asyncio.wait_for(websocket.recv(), timeout=10) + +This technique works for most APIs, except for asynchronous context managers. +See `issue 574`_. + +.. _issue 574: https://github.com/aaugustin/websockets/issues/574 + +How do I keep idle connections open? +.................................... + +websockets sends pings at 20 seconds intervals to keep the connection open. + +In closes the connection if it doesn't get a pong within 20 seconds. + +You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. + +How do I respond to pings? +.......................... + +websockets takes care of responding to pings with pongs. + +Is there a Python 2 version? +............................ + +No, there isn't. + +websockets builds upon asyncio which requires Python 3. + + diff --git a/docs/index.rst b/docs/index.rst index 6001d5075..c18af96e4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -49,6 +49,7 @@ If you're new to ``websockets``, this is the place to start. :maxdepth: 2 intro + faq How-to guides ------------- From aa2a2bb52621626c5661f8be5de4985e18e87acf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 24 Jun 2019 22:08:44 +0200 Subject: [PATCH 0595/1539] Improve HTTP parsing error messages. Fix #494. --- docs/changelog.rst | 2 ++ src/websockets/client.py | 4 +-- src/websockets/http.py | 74 ++++++++++++++++++++++++++++------------ src/websockets/server.py | 4 +-- tests/test_http.py | 69 +++++++++++++++++++++++++++++-------- 5 files changed, 113 insertions(+), 40 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c2719560b..92cbce58f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -66,6 +66,8 @@ Also: * Avoided a crash of a ``extra_headers`` callable returns ``None``. +* Improved error messages when HTTP parsing fails. + * Enabled readline in the interactive client. * Added type hints (:pep:`484`). diff --git a/src/websockets/client.py b/src/websockets/client.py index e6131ed7a..79b03d9e7 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -98,8 +98,8 @@ async def read_http_response(self) -> Tuple[int, Headers]: """ try: status_code, reason, headers = await read_response(self.reader) - except ValueError as exc: - raise InvalidMessage("Malformed HTTP message") from exc + except Exception as exc: + raise InvalidMessage("did not receive a valid HTTP response") from exc logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) logger.debug("%s < %r", self.side, headers) diff --git a/src/websockets/http.py b/src/websockets/http.py index f0c58061d..6fbe5eb31 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -39,6 +39,21 @@ USER_AGENT = f"Python/{sys.version[:3]} websockets/{websockets_version}" +class SecurityError(ValueError): + """ + HTTP request or response exceeds security limits. + + """ + + +def d(value: bytes) -> str: + """ + Decode a bytestring for interpolating into an error message. + + """ + return value.decode(errors="backslashreplace") + + # See https://tools.ietf.org/html/rfc7230#appendix-B. # Regex for validating header names. @@ -85,15 +100,20 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]: # version and because path isn't checked. Since WebSocket software tends # to implement HTTP/1.1 strictly, there's little need for lenient parsing. - request_line = await read_line(stream) + try: + request_line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP request line") from exc - # This may raise "ValueError: not enough values to unpack" - method, raw_path, version = request_line.split(b" ", 2) + try: + method, raw_path, version = request_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None if method != b"GET": - raise ValueError("Unsupported HTTP method: %r" % method) + raise ValueError(f"unsupported HTTP method: {d(method)}") if version != b"HTTP/1.1": - raise ValueError("Unsupported HTTP version: %r" % version) + raise ValueError(f"unsupported HTTP version: {d(version)}") path = raw_path.decode("ascii", "surrogateescape") headers = await read_headers(stream) @@ -125,19 +145,26 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, "Header # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. - status_line = await read_line(stream) + try: + status_line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP status line") from exc - # This may raise "ValueError: not enough values to unpack" - version, raw_status_code, raw_reason = status_line.split(b" ", 2) + try: + version, raw_status_code, raw_reason = status_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None if version != b"HTTP/1.1": - raise ValueError("Unsupported HTTP version: %r" % version) - # This may raise "ValueError: invalid literal for int() with base 10" - status_code = int(raw_status_code) + raise ValueError(f"unsupported HTTP version: {d(version)}") + try: + status_code = int(raw_status_code) + except ValueError: # invalid literal for int() with base 10 + raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None if not 100 <= status_code < 1000: - raise ValueError("Unsupported HTTP status code: %d" % status_code) + raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}") if not _value_re.fullmatch(raw_reason): - raise ValueError("Invalid HTTP reason phrase: %r" % raw_reason) + raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") reason = raw_reason.decode() headers = await read_headers(stream) @@ -162,24 +189,29 @@ async def read_headers(stream: asyncio.StreamReader) -> "Headers": headers = Headers() for _ in range(MAX_HEADERS + 1): - line = await read_line(stream) + try: + line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP headers") from exc if line == b"": break - # This may raise "ValueError: not enough values to unpack" - raw_name, raw_value = line.split(b":", 1) + try: + raw_name, raw_value = line.split(b":", 1) + except ValueError: # not enough values to unpack (expected 2, got 1) + raise ValueError(f"invalid HTTP header line: {d(line)}") from None if not _token_re.fullmatch(raw_name): - raise ValueError("Invalid HTTP header name: %r" % raw_name) + raise ValueError(f"invalid HTTP header name: {d(raw_name)}") raw_value = raw_value.strip(b" \t") if not _value_re.fullmatch(raw_value): - raise ValueError("Invalid HTTP header value: %r" % raw_value) + raise ValueError(f"invalid HTTP header value: {d(raw_value)}") name = raw_name.decode("ascii") # guaranteed to be ASCII at this point value = raw_value.decode("ascii", "surrogateescape") headers[name] = value else: - raise ValueError("Too many HTTP headers") + raise SecurityError("too many HTTP headers") return headers @@ -197,10 +229,10 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: line = await stream.readline() # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: - raise ValueError("Line too long") + raise SecurityError("line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): - raise ValueError("Line without CRLF") + raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/server.py b/src/websockets/server.py index 9882eabef..c8eb46351 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -227,8 +227,8 @@ async def read_http_request(self) -> Tuple[str, Headers]: """ try: path, headers = await read_request(self.reader) - except ValueError as exc: - raise InvalidMessage("Malformed HTTP message") from exc + except Exception as exc: + raise InvalidMessage("did not receive a valid HTTP request") from exc logger.debug("%s < GET %s HTTP/1.1", self.side, path) logger.debug("%s < %r", self.side, headers) diff --git a/tests/test_http.py b/tests/test_http.py index 60cdb9a25..8ba1d190f 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -29,6 +29,33 @@ def test_read_request(self): self.assertEqual(path, "/chat") self.assertEqual(headers["Upgrade"], "websocket") + def test_read_request_empty(self): + self.stream.feed_eof() + with self.assertRaisesRegex( + EOFError, "connection closed while reading HTTP request line" + ): + self.loop.run_until_complete(read_request(self.stream)) + + def test_read_request_invalid_request_line(self): + self.stream.feed_data(b"GET /\r\n\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP request line: GET /"): + self.loop.run_until_complete(read_request(self.stream)) + + def test_read_request_unsupported_method(self): + self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") + with self.assertRaisesRegex(ValueError, "unsupported HTTP method: OPTIONS"): + self.loop.run_until_complete(read_request(self.stream)) + + def test_read_request_unsupported_version(self): + self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") + with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + self.loop.run_until_complete(read_request(self.stream)) + + def test_read_request_invalid_header(self): + self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + self.loop.run_until_complete(read_request(self.stream)) + def test_read_response(self): # Example from the protocol overview in RFC 6455 self.stream.feed_data( @@ -46,29 +73,41 @@ def test_read_response(self): self.assertEqual(reason, "Switching Protocols") self.assertEqual(headers["Upgrade"], "websocket") - def test_request_method(self): - self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") - with self.assertRaises(ValueError): - self.loop.run_until_complete(read_request(self.stream)) + def test_read_response_empty(self): + self.stream.feed_eof() + with self.assertRaisesRegex( + EOFError, "connection closed while reading HTTP status line" + ): + self.loop.run_until_complete(read_response(self.stream)) - def test_request_version(self): - self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") - with self.assertRaises(ValueError): - self.loop.run_until_complete(read_request(self.stream)) + def test_read_request_invalid_status_line(self): + self.stream.feed_data(b"Hello!\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP status line: Hello!"): + self.loop.run_until_complete(read_response(self.stream)) - def test_response_version(self): + def test_read_response_unsupported_version(self): self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): self.loop.run_until_complete(read_response(self.stream)) - def test_response_status(self): + def test_read_response_invalid_status(self): + self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP status code: OMG"): + self.loop.run_until_complete(read_response(self.stream)) + + def test_read_response_unsupported_status(self): self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "unsupported HTTP status code: 007"): self.loop.run_until_complete(read_response(self.stream)) - def test_response_reason(self): + def test_read_response_invalid_reason(self): self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "invalid HTTP reason phrase: \\x7f"): + self.loop.run_until_complete(read_response(self.stream)) + + def test_read_response_invalid_header(self): + self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): self.loop.run_until_complete(read_response(self.stream)) def test_header_name(self): @@ -94,7 +133,7 @@ def test_line_limit(self): def test_line_ending(self): self.stream.feed_data(b"foo: bar\n\n") - with self.assertRaises(ValueError): + with self.assertRaises(EOFError): self.loop.run_until_complete(read_headers(self.stream)) From c854564d2d871beba7b1150ef97d7cf2b6e4f872 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 24 Jun 2019 23:09:08 +0200 Subject: [PATCH 0596/1539] Convert tests for HTTP parsing to async style. Refs #403. --- tests/test_http.py | 78 ++++++++++++++++++++++------------------------ tests/utils.py | 27 ++++++++++++++++ 2 files changed, 65 insertions(+), 40 deletions(-) diff --git a/tests/test_http.py b/tests/test_http.py index 8ba1d190f..cff97fc2f 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -12,7 +12,7 @@ def setUp(self): super().setUp() self.stream = asyncio.StreamReader(loop=self.loop) - def test_read_request(self): + async def test_read_request(self): # Example from the protocol overview in RFC 6455 self.stream.feed_data( b"GET /chat HTTP/1.1\r\n" @@ -25,38 +25,38 @@ def test_read_request(self): b"Sec-WebSocket-Version: 13\r\n" b"\r\n" ) - path, headers = self.loop.run_until_complete(read_request(self.stream)) + path, headers = await read_request(self.stream) self.assertEqual(path, "/chat") self.assertEqual(headers["Upgrade"], "websocket") - def test_read_request_empty(self): + async def test_read_request_empty(self): self.stream.feed_eof() with self.assertRaisesRegex( EOFError, "connection closed while reading HTTP request line" ): - self.loop.run_until_complete(read_request(self.stream)) + await read_request(self.stream) - def test_read_request_invalid_request_line(self): + async def test_read_request_invalid_request_line(self): self.stream.feed_data(b"GET /\r\n\r\n") with self.assertRaisesRegex(ValueError, "invalid HTTP request line: GET /"): - self.loop.run_until_complete(read_request(self.stream)) + await read_request(self.stream) - def test_read_request_unsupported_method(self): + async def test_read_request_unsupported_method(self): self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") with self.assertRaisesRegex(ValueError, "unsupported HTTP method: OPTIONS"): - self.loop.run_until_complete(read_request(self.stream)) + await read_request(self.stream) - def test_read_request_unsupported_version(self): + async def test_read_request_unsupported_version(self): self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): - self.loop.run_until_complete(read_request(self.stream)) + await read_request(self.stream) - def test_read_request_invalid_header(self): + async def test_read_request_invalid_header(self): self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): - self.loop.run_until_complete(read_request(self.stream)) + await read_request(self.stream) - def test_read_response(self): + async def test_read_response(self): # Example from the protocol overview in RFC 6455 self.stream.feed_data( b"HTTP/1.1 101 Switching Protocols\r\n" @@ -66,75 +66,73 @@ def test_read_response(self): b"Sec-WebSocket-Protocol: chat\r\n" b"\r\n" ) - status_code, reason, headers = self.loop.run_until_complete( - read_response(self.stream) - ) + status_code, reason, headers = await read_response(self.stream) self.assertEqual(status_code, 101) self.assertEqual(reason, "Switching Protocols") self.assertEqual(headers["Upgrade"], "websocket") - def test_read_response_empty(self): + async def test_read_response_empty(self): self.stream.feed_eof() with self.assertRaisesRegex( EOFError, "connection closed while reading HTTP status line" ): - self.loop.run_until_complete(read_response(self.stream)) + await read_response(self.stream) - def test_read_request_invalid_status_line(self): + async def test_read_request_invalid_status_line(self): self.stream.feed_data(b"Hello!\r\n") with self.assertRaisesRegex(ValueError, "invalid HTTP status line: Hello!"): - self.loop.run_until_complete(read_response(self.stream)) + await read_response(self.stream) - def test_read_response_unsupported_version(self): + async def test_read_response_unsupported_version(self): self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): - self.loop.run_until_complete(read_response(self.stream)) + await read_response(self.stream) - def test_read_response_invalid_status(self): + async def test_read_response_invalid_status(self): self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") with self.assertRaisesRegex(ValueError, "invalid HTTP status code: OMG"): - self.loop.run_until_complete(read_response(self.stream)) + await read_response(self.stream) - def test_read_response_unsupported_status(self): + async def test_read_response_unsupported_status(self): self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") with self.assertRaisesRegex(ValueError, "unsupported HTTP status code: 007"): - self.loop.run_until_complete(read_response(self.stream)) + await read_response(self.stream) - def test_read_response_invalid_reason(self): + async def test_read_response_invalid_reason(self): self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") with self.assertRaisesRegex(ValueError, "invalid HTTP reason phrase: \\x7f"): - self.loop.run_until_complete(read_response(self.stream)) + await read_response(self.stream) - def test_read_response_invalid_header(self): + async def test_read_response_invalid_header(self): self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): - self.loop.run_until_complete(read_response(self.stream)) + await read_response(self.stream) - def test_header_name(self): + async def test_header_name(self): self.stream.feed_data(b"foo bar: baz qux\r\n\r\n") with self.assertRaises(ValueError): - self.loop.run_until_complete(read_headers(self.stream)) + await read_headers(self.stream) - def test_header_value(self): + async def test_header_value(self): self.stream.feed_data(b"foo: \x00\x00\x0f\r\n\r\n") with self.assertRaises(ValueError): - self.loop.run_until_complete(read_headers(self.stream)) + await read_headers(self.stream) - def test_headers_limit(self): + async def test_headers_limit(self): self.stream.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") with self.assertRaises(ValueError): - self.loop.run_until_complete(read_headers(self.stream)) + await read_headers(self.stream) - def test_line_limit(self): + async def test_line_limit(self): # Header line contains 5 + 4090 + 2 = 4097 bytes. self.stream.feed_data(b"foo: " + b"a" * 4090 + b"\r\n\r\n") with self.assertRaises(ValueError): - self.loop.run_until_complete(read_headers(self.stream)) + await read_headers(self.stream) - def test_line_ending(self): + async def test_line_ending(self): self.stream.feed_data(b"foo: bar\n\n") with self.assertRaises(EOFError): - self.loop.run_until_complete(read_headers(self.stream)) + await read_headers(self.stream) class HeadersTests(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py index 059efba20..24cdcfa51 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import functools import logging import os import time @@ -12,6 +13,32 @@ class AsyncioTestCase(unittest.TestCase): """ + def __init_subclass__(cls, **kwargs): + """ + Convert test coroutines to test functions. + + This supports asychronous tests transparently. + + """ + super().__init_subclass__(**kwargs) + for name in unittest.defaultTestLoader.getTestCaseNames(cls): + test = getattr(cls, name) + if asyncio.iscoroutinefunction(test): + setattr(cls, name, cls.convert_async_to_sync(test)) + + @staticmethod + def convert_async_to_sync(test): + """ + Convert a test coroutine to a test function. + + """ + + @functools.wraps(test) + def test_func(self, *args, **kwds): + return self.loop.run_until_complete(test(self, *args, **kwds)) + + return test_func + def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() From 9f9da4478bb2c3f84020abadce118dfad6d53391 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Jun 2019 22:11:36 +0200 Subject: [PATCH 0597/1539] Clarify that extra_headers only applies on success. Refs #611. --- src/websockets/server.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index c8eb46351..d7d294c29 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -500,8 +500,9 @@ async def handshake( If provided, ``available_subprotocols`` is a list of supported subprotocols in order of decreasing preference. - If provided, ``extra_headers`` sets additional HTTP response headers. - It can be a :class:`~websockets.http.Headers` instance, a + If provided, ``extra_headers`` sets additional HTTP response headers + when the handshake succeeds. It can be a + :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` pairs, or a callable taking the request path and headers in arguments and returning one of the above. @@ -779,11 +780,11 @@ class Serve: decreasing preference * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference - * ``extra_headers`` sets additional HTTP response headers — it can be a - :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` - pairs, or a callable taking the request path and headers in arguments - and returning one of the above + * ``extra_headers`` sets additional HTTP response headers when the + handshake succeeds — it can be a :class:`~websockets.http.Headers` + instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, + value)`` pairs, or a callable taking the request path and headers in + arguments and returning one of the above * ``process_request`` is a coroutine taking the request path and headers in argument, see :meth:`~WebSocketServerProtocol.process_request` for details From d5a670019dd7baf966e5822c37f45c0da4971fdd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 26 Jun 2019 08:27:07 +0200 Subject: [PATCH 0598/1539] Improve error messages for hanshake failures. Fix #611. --- src/websockets/__main__.py | 6 +-- src/websockets/auth.py | 6 +-- src/websockets/client.py | 17 +++---- src/websockets/exceptions.py | 16 +++---- .../extensions/permessage_deflate.py | 16 +++---- src/websockets/framing.py | 16 +++---- src/websockets/protocol.py | 8 ++-- src/websockets/server.py | 18 +++++-- tests/test_auth.py | 6 +-- tests/test_client_server.py | 22 ++++++--- tests/test_exceptions.py | 48 +++++++++---------- 11 files changed, 98 insertions(+), 81 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 14bf655b1..57d2a823b 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -27,11 +27,11 @@ def win_enable_vt100() -> None: handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE) if handle == INVALID_HANDLE_VALUE: - raise RuntimeError("Unable to obtain stdout handle") + raise RuntimeError("unable to obtain stdout handle") cur_mode = ctypes.c_uint() if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0: - raise RuntimeError("Unable to query current console mode") + raise RuntimeError("unable to query current console mode") # ctypes ints lack support for the required bit-OR operation. # Temporarily convert to Py int, do the OR and convert back. @@ -39,7 +39,7 @@ def win_enable_vt100() -> None: new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING) if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0: - raise RuntimeError("Unable to set console mode") + raise RuntimeError("unable to set console mode") def exit_from_event_loop_thread( diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 91d3d7420..60f63e9aa 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -124,7 +124,7 @@ def basic_auth_protocol_factory( """ if (credentials is None) == (check_credentials is None): - raise ValueError("Provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): @@ -141,10 +141,10 @@ async def check_credentials(username: str, password: str) -> bool: return credentials_dict.get(username) == password else: - raise ValueError(f"Invalid credentials argument: {credentials}") + raise ValueError(f"invalid credentials argument: {credentials}") else: - raise ValueError(f"Invalid credentials argument: {credentials}") + raise ValueError(f"invalid credentials argument: {credentials}") return functools.partial( create_protocol, realm=realm, check_credentials=check_credentials diff --git a/src/websockets/client.py b/src/websockets/client.py index 79b03d9e7..9c34d5c23 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -12,6 +12,7 @@ from .exceptions import ( InvalidHandshake, + InvalidHeader, InvalidMessage, InvalidStatusCode, NegotiationError, @@ -146,7 +147,7 @@ def process_extensions( if header_values: if available_extensions is None: - raise InvalidHandshake("No extensions supported") + raise InvalidHandshake("no extensions supported") parsed_header_values: List[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] @@ -203,7 +204,7 @@ def process_subprotocol( if header_values: if available_subprotocols is None: - raise InvalidHandshake("No subprotocols supported") + raise InvalidHandshake("no subprotocols supported") parsed_header_values: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] @@ -211,12 +212,12 @@ def process_subprotocol( if len(parsed_header_values) > 1: subprotocols = ", ".join(parsed_header_values) - raise InvalidHandshake(f"Multiple subprotocols: {subprotocols}") + raise InvalidHandshake(f"multiple subprotocols: {subprotocols}") subprotocol = parsed_header_values[0] if subprotocol not in available_subprotocols: - raise NegotiationError(f"Unsupported subprotocol: {subprotocol}") + raise NegotiationError(f"unsupported subprotocol: {subprotocol}") return subprotocol @@ -293,7 +294,7 @@ async def handshake( status_code, response_headers = await self.read_http_response() if status_code in (301, 302, 303, 307, 308): if "Location" not in response_headers: - raise InvalidMessage("Redirect response missing Location") + raise InvalidHeader("Location") raise RedirectHandshake(response_headers["Location"]) elif status_code != 101: raise InvalidStatusCode(status_code) @@ -429,7 +430,7 @@ def __init__( ClientPerMessageDeflateFactory(client_max_window_bits=True) ] elif compression is not None: - raise ValueError(f"Unsupported compression: {compression}") + raise ValueError(f"unsupported compression: {compression}") self._create_protocol = create_protocol self._ping_interval = ping_interval @@ -535,11 +536,11 @@ async def __await_impl__(self) -> WebSocketClientProtocol: except RedirectHandshake as e: wsuri = parse_uri(e.uri) if self._wsuri.secure and not wsuri.secure: - raise InvalidHandshake("Redirect dropped TLS") + raise InvalidHandshake("redirect dropped TLS") self._wsuri = wsuri continue # redirection chain continues else: - raise InvalidHandshake("Maximum redirects exceeded") + raise InvalidHandshake("maximum redirects exceeded") self.ws_client = protocol return protocol diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 22978ec6f..36a8ed4a8 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -62,7 +62,7 @@ def __init__(self, uri: str) -> None: self.uri = uri def __str__(self) -> str: - return f"Redirect to {self.uri}" + return f"redirect to {self.uri}" class InvalidMessage(InvalidHandshake): @@ -80,11 +80,11 @@ class InvalidHeader(InvalidHandshake): def __init__(self, name: str, value: Optional[str] = None) -> None: if value is None: - message = f"Missing {name} header" + message = f"missing {name} header" elif value == "": - message = f"Empty {name} header" + message = f"empty {name} header" else: - message = f"Invalid {name} header: {value}" + message = f"invalid {name} header: {value}" super().__init__(message) @@ -133,7 +133,7 @@ class InvalidStatusCode(InvalidHandshake): def __init__(self, status_code: int) -> None: self.status_code = status_code - message = f"Status code not 101: {status_code}" + message = f"server rejected WebSocket connection: HTTP {status_code}" super().__init__(message) @@ -152,7 +152,7 @@ class InvalidParameterName(NegotiationError): def __init__(self, name: str) -> None: self.name = name - message = f"Invalid parameter name: {name}" + message = f"invalid parameter name: {name}" super().__init__(message) @@ -165,7 +165,7 @@ class InvalidParameterValue(NegotiationError): def __init__(self, name: str, value: Optional[str]) -> None: self.name = name self.value = value - message = f"Invalid value for parameter {name}: {value}" + message = f"invalid value for parameter {name}: {value}" super().__init__(message) @@ -177,7 +177,7 @@ class DuplicateParameter(NegotiationError): def __init__(self, name: str) -> None: self.name = name - message = f"Duplicate parameter: {name}" + message = f"duplicate parameter: {name}" super().__init__(message) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 2de27260f..bd4b3fa53 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -334,7 +334,7 @@ def process_response_params( """ if any(other.name == self.name for other in accepted_extensions): - raise NegotiationError(f"Received duplicate {self.name}") + raise NegotiationError(f"received duplicate {self.name}") # Request parameters are available in instance variables. @@ -360,7 +360,7 @@ def process_response_params( if self.server_no_context_takeover: if not server_no_context_takeover: - raise NegotiationError("Expected server_no_context_takeover") + raise NegotiationError("expected server_no_context_takeover") # client_no_context_takeover # @@ -390,9 +390,9 @@ def process_response_params( else: if server_max_window_bits is None: - raise NegotiationError("Expected server_max_window_bits") + raise NegotiationError("expected server_max_window_bits") elif server_max_window_bits > self.server_max_window_bits: - raise NegotiationError("Unsupported server_max_window_bits") + raise NegotiationError("unsupported server_max_window_bits") # client_max_window_bits @@ -408,7 +408,7 @@ def process_response_params( if self.client_max_window_bits is None: if client_max_window_bits is not None: - raise NegotiationError("Unexpected client_max_window_bits") + raise NegotiationError("unexpected client_max_window_bits") elif self.client_max_window_bits is True: pass @@ -417,7 +417,7 @@ def process_response_params( if client_max_window_bits is None: client_max_window_bits = self.client_max_window_bits elif client_max_window_bits > self.client_max_window_bits: - raise NegotiationError("Unsupported client_max_window_bits") + raise NegotiationError("unsupported client_max_window_bits") return PerMessageDeflate( server_no_context_takeover, # remote_no_context_takeover @@ -491,7 +491,7 @@ def process_request_params( """ if any(other.name == self.name for other in accepted_extensions): - raise NegotiationError(f"Skipped duplicate {self.name}") + raise NegotiationError(f"skipped duplicate {self.name}") # Load request parameters in local variables. ( @@ -569,7 +569,7 @@ def process_request_params( else: if client_max_window_bits is None: - raise NegotiationError("Required client_max_window_bits") + raise NegotiationError("required client_max_window_bits") elif client_max_window_bits is True: client_max_window_bits = self.client_max_window_bits elif self.client_max_window_bits < client_max_window_bits: diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 5b694fd40..d668e0c52 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -133,7 +133,7 @@ async def read( opcode = head1 & 0b00001111 if (True if head2 & 0b10000000 else False) != mask: - raise WebSocketProtocolError("Incorrect masking") + raise WebSocketProtocolError("incorrect masking") length = head2 & 0b01111111 if length == 126: @@ -144,7 +144,7 @@ async def read( length, = struct.unpack("!Q", data) if max_size is not None and length > max_size: raise PayloadTooBig( - f"Payload length exceeds size limit ({length} > {max_size} bytes)" + f"payload length exceeds size limit ({length} > {max_size} bytes)" ) if mask: mask_bits = await reader(4) @@ -252,17 +252,17 @@ def check(frame) -> None: # but it's the instance of class to which this method is bound. if frame.rsv1 or frame.rsv2 or frame.rsv3: - raise WebSocketProtocolError("Reserved bits must be 0") + raise WebSocketProtocolError("reserved bits must be 0") if frame.opcode in DATA_OPCODES: return elif frame.opcode in CTRL_OPCODES: if len(frame.data) > 125: - raise WebSocketProtocolError("Control frame too long") + raise WebSocketProtocolError("control frame too long") if not frame.fin: - raise WebSocketProtocolError("Fragmented control frame") + raise WebSocketProtocolError("fragmented control frame") else: - raise WebSocketProtocolError(f"Invalid opcode: {frame.opcode}") + raise WebSocketProtocolError(f"invalid opcode: {frame.opcode}") def prepare_data(data: Data) -> Tuple[int, bytes]: @@ -338,7 +338,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: return 1005, "" else: assert length == 1 - raise WebSocketProtocolError("Close frame too short") + raise WebSocketProtocolError("close frame too short") def serialize_close(code: int, reason: str) -> bytes: @@ -358,7 +358,7 @@ def check_close(code: int) -> None: """ if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): - raise WebSocketProtocolError("Invalid status code") + raise WebSocketProtocolError("invalid status code") # at the bottom to allow circular import, because Extension depends on Frame diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index d888a9729..43dcbd4ff 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -630,7 +630,7 @@ async def ping(self, data: Optional[bytes] = None) -> Awaitable[None]: # Protect against duplicates if a payload is explicitly set. if data in self.pings: - raise ValueError("Already waiting for a pong with the same data") + raise ValueError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.pings: @@ -793,7 +793,7 @@ async def read_message(self) -> Optional[Data]: elif frame.opcode == OP_BINARY: text = False else: # frame.opcode == OP_CONT - raise WebSocketProtocolError("Unexpected opcode") + raise WebSocketProtocolError("unexpected opcode") # Shortcut for the common case - no fragmentation if frame.fin: @@ -838,9 +838,9 @@ def append(frame: Frame) -> None: while not frame.fin: frame = await self.read_data_frame(max_size=max_size) if frame is None: - raise WebSocketProtocolError("Incomplete fragmented message") + raise WebSocketProtocolError("incomplete fragmented message") if frame.opcode != OP_CONT: - raise WebSocketProtocolError("Unexpected opcode") + raise WebSocketProtocolError("unexpected opcode") append(frame) # mypy cannot figure out that chunks have the proper type. diff --git a/src/websockets/server.py b/src/websockets/server.py index d7d294c29..c37aec93f 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -147,28 +147,36 @@ async def handler(self) -> None: status, headers, body = ( http.HTTPStatus.FORBIDDEN, Headers(), - (str(exc) + "\n").encode(), + f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) elif isinstance(exc, InvalidUpgrade): logger.debug("Invalid upgrade", exc_info=True) status, headers, body = ( http.HTTPStatus.UPGRADE_REQUIRED, Headers([("Upgrade", "websocket")]), - (str(exc) + "\n").encode(), + ( + f"Failed to open a WebSocket connection: {exc}.\n" + f"\n" + f"You cannot access a WebSocket server directly " + f"with a browser. You need a WebSocket client.\n" + ).encode(), ) elif isinstance(exc, InvalidHandshake): logger.debug("Invalid handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), - (str(exc) + "\n").encode(), + f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) else: logger.warning("Error in opening handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.INTERNAL_SERVER_ERROR, Headers(), - b"See server log for more information.\n", + ( + b"Failed to open a WebSocket connection.\n" + b"See server log for more information.\n" + ), ) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) @@ -880,7 +888,7 @@ def __init__( ): extensions = list(extensions) + [ServerPerMessageDeflateFactory()] elif compression is not None: - raise ValueError(f"Unsupported compression: {compression}") + raise ValueError(f"unsupported compression: {compression}") factory = lambda: create_protocol( ws_handler, diff --git a/tests/test_auth.py b/tests/test_auth.py index f6aa5c424..bcd340844 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -40,13 +40,13 @@ def test_basic_auth_server_no_credentials(self): with self.assertRaises(ValueError) as raised: basic_auth_protocol_factory(realm="auth-tests", credentials=None) self.assertEqual( - str(raised.exception), "Provide either credentials or check_credentials" + str(raised.exception), "provide either credentials or check_credentials" ) def test_basic_auth_server_bad_credentials(self): with self.assertRaises(ValueError) as raised: basic_auth_protocol_factory(realm="auth-tests", credentials=42) - self.assertEqual(str(raised.exception), "Invalid credentials argument: 42") + self.assertEqual(str(raised.exception), "invalid credentials argument: 42") create_protocol_multiple_credentials = basic_auth_protocol_factory( realm="auth-tests", @@ -66,7 +66,7 @@ def test_basic_auth_bad_multiple_credentials(self): ) self.assertEqual( str(raised.exception), - "Invalid credentials argument: [('hello', 'iloveyou'), 42]", + "invalid credentials argument: [('hello', 'iloveyou'), 42]", ) async def check_credentials(username, password): diff --git a/tests/test_client_server.py b/tests/test_client_server.py index d82aa6d40..35b662eb9 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -18,7 +18,7 @@ from websockets.exceptions import ( ConnectionClosed, InvalidHandshake, - InvalidMessage, + InvalidHeader, InvalidStatusCode, NegotiationError, ) @@ -155,7 +155,7 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): # The host and port are ignored when connecting to a Unix socket. host, port = "localhost", 0 else: # pragma: no cover - raise ValueError("Expected an IPv6, IPv4, or Unix socket") + raise ValueError("expected an IPv6, IPv4, or Unix socket") return f"{proto}://{user_info}{host}:{port}{resource_name}" @@ -429,7 +429,7 @@ def test_redirect_missing_location(self): with temp_test_redirecting_server( self, http.HTTPStatus.FOUND, include_location=False ): - with self.assertRaises(InvalidMessage): + with self.assertRaises(InvalidHeader): with temp_test_client(self): self.fail("Did not raise") # pragma: no cover @@ -1149,7 +1149,9 @@ def test_server_shuts_down_during_opening_handshake(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() exception = raised.exception - self.assertEqual(str(exception), "Status code not 101: 503") + self.assertEqual( + str(exception), "server rejected WebSocket connection: HTTP 503" + ) self.assertEqual(exception.status_code, 503) @with_server() @@ -1197,7 +1199,9 @@ def test_invalid_status_error_during_client_connect(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() exception = raised.exception - self.assertEqual(str(exception), "Status code not 101: 403") + self.assertEqual( + str(exception), "server rejected WebSocket connection: HTTP 403" + ) self.assertEqual(exception.status_code, 403) @with_server() @@ -1283,7 +1287,9 @@ def test_checking_origin_fails(self): server = self.loop.run_until_complete( serve(handler, "localhost", 0, origins=["http://localhost"]) ) - with self.assertRaisesRegex(InvalidHandshake, "Status code not 101: 403"): + with self.assertRaisesRegex( + InvalidHandshake, "server rejected WebSocket connection: HTTP 403" + ): self.loop.run_until_complete( connect(get_server_uri(server), origin="http://otherhost") ) @@ -1295,7 +1301,9 @@ def test_checking_origins_fails_with_multiple_headers(self): server = self.loop.run_until_complete( serve(handler, "localhost", 0, origins=["http://localhost"]) ) - with self.assertRaisesRegex(InvalidHandshake, "Status code not 101: 400"): + with self.assertRaisesRegex( + InvalidHandshake, "server rejected WebSocket connection: HTTP 400" + ): self.loop.run_until_complete( connect( get_server_uri(server), diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 27e1b53ca..fbc06e576 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -9,8 +9,8 @@ def test_str(self): for exception, exception_str in [ # fmt: off ( - InvalidHandshake("Invalid request"), - "Invalid request", + InvalidHandshake("invalid request"), + "invalid request", ), ( AbortHandshake(200, Headers(), b"OK\n"), @@ -18,70 +18,70 @@ def test_str(self): ), ( RedirectHandshake("wss://example.com"), - "Redirect to wss://example.com", + "redirect to wss://example.com", ), ( - InvalidMessage("Malformed HTTP message"), - "Malformed HTTP message", + InvalidMessage("malformed HTTP message"), + "malformed HTTP message", ), ( InvalidHeader("Name"), - "Missing Name header", + "missing Name header", ), ( InvalidHeader("Name", None), - "Missing Name header", + "missing Name header", ), ( InvalidHeader("Name", ""), - "Empty Name header", + "empty Name header", ), ( InvalidHeader("Name", "Value"), - "Invalid Name header: Value", + "invalid Name header: Value", ), ( InvalidHeaderFormat( "Sec-WebSocket-Protocol", "expected token", "a=|", 3 ), - "Invalid Sec-WebSocket-Protocol header: " + "invalid Sec-WebSocket-Protocol header: " "expected token at 3 in a=|", ), ( InvalidHeaderValue("Sec-WebSocket-Version", "42"), - "Invalid Sec-WebSocket-Version header: 42", + "invalid Sec-WebSocket-Version header: 42", ), ( InvalidUpgrade("Upgrade"), - "Missing Upgrade header", + "missing Upgrade header", ), ( InvalidUpgrade("Connection", "websocket"), - "Invalid Connection header: websocket", + "invalid Connection header: websocket", ), ( InvalidOrigin("http://bad.origin"), - "Invalid Origin header: http://bad.origin", + "invalid Origin header: http://bad.origin", ), ( InvalidStatusCode(403), - "Status code not 101: 403", + "server rejected WebSocket connection: HTTP 403", ), ( - NegotiationError("Unsupported subprotocol: spam"), - "Unsupported subprotocol: spam", + NegotiationError("unsupported subprotocol: spam"), + "unsupported subprotocol: spam", ), ( InvalidParameterName("|"), - "Invalid parameter name: |", + "invalid parameter name: |", ), ( InvalidParameterValue("a", "|"), - "Invalid value for parameter a: |", + "invalid value for parameter a: |", ), ( DuplicateParameter("a"), - "Duplicate parameter: a", + "duplicate parameter: a", ), ( InvalidState("WebSocket connection isn't established yet"), @@ -122,12 +122,12 @@ def test_str(self): "| isn't a valid URI", ), ( - PayloadTooBig("Payload length exceeds limit: 2 > 1 bytes"), - "Payload length exceeds limit: 2 > 1 bytes", + PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), + "payload length exceeds limit: 2 > 1 bytes", ), ( - WebSocketProtocolError("Invalid opcode: 7"), - "Invalid opcode: 7", + WebSocketProtocolError("invalid opcode: 7"), + "invalid opcode: 7", ), # fmt: on ]: From b55ccf8d44911d3d62d55146a72cd40d96138f58 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 26 Jun 2019 20:50:18 +0200 Subject: [PATCH 0599/1539] Standardize to **kwargs. There was a mix of **kwargs and **kwds, perhaps due to Guido using **kwds and websockets mirroring asyncio APIs before it was called asyncio. --- src/websockets/client.py | 18 +++++------ src/websockets/server.py | 12 +++---- tests/test_client_server.py | 62 ++++++++++++++++++------------------- tests/utils.py | 4 +-- 4 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 9c34d5c23..8dd8a0dd1 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -59,13 +59,13 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, - **kwds: Any, + **kwargs: Any, ) -> None: self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers - super().__init__(**kwds) + super().__init__(**kwargs) def write_http_request(self, path: str, headers: Headers) -> None: """ @@ -387,7 +387,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, - **kwds: Any, + **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: @@ -412,8 +412,8 @@ def __init__( self._wsuri = parse_uri(uri) if self._wsuri.secure: - kwds.setdefault("ssl", True) - elif kwds.get("ssl") is not None: + kwargs.setdefault("ssl", True) + elif kwargs.get("ssl") is not None: raise ValueError( "connect() received a SSL context for a ws:// URI, " "use a wss:// URI to enable TLS" @@ -449,13 +449,13 @@ def __init__( self._extensions = extensions self._subprotocols = subprotocols self._extra_headers = extra_headers - self._kwds = kwds + self._kwargs = kwargs async def _creating_connection( self ) -> Tuple[asyncio.Transport, WebSocketClientProtocol]: if self._wsuri.secure: - self._kwds.setdefault("ssl", True) + self._kwargs.setdefault("ssl", True) factory = lambda: self._create_protocol( host=self._wsuri.host, @@ -478,7 +478,7 @@ async def _creating_connection( host: Optional[str] port: Optional[int] - if self._kwds.get("sock") is None: + if self._kwargs.get("sock") is None: host, port = self._wsuri.host, self._wsuri.port else: # If sock is given, host and port mustn't be specified. @@ -490,7 +490,7 @@ async def _creating_connection( # This is a coroutine object. # https://github.com/python/typeshed/pull/2756 transport, protocol = await self._loop.create_connection( # type: ignore - factory, host, port, **self._kwds + factory, host, port, **self._kwargs ) transport = cast(asyncio.Transport, transport) protocol = cast(WebSocketClientProtocol, protocol) diff --git a/src/websockets/server.py b/src/websockets/server.py index c37aec93f..547656e0c 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -89,7 +89,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, - **kwds: Any, + **kwargs: Any, ) -> None: # For backwards-compatibility with 6.0 or earlier. if origins is not None and "" in origins: @@ -103,7 +103,7 @@ def __init__( self.extra_headers = extra_headers self._process_request = process_request self._select_subprotocol = select_subprotocol - super().__init__(**kwds) + super().__init__(**kwargs) def connection_made(self, transport: asyncio.BaseTransport) -> None: """ @@ -852,7 +852,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, - **kwds: Any, + **kwargs: Any, ) -> None: # Backwards-compatibility: close_timeout used to be called timeout. if timeout is None: @@ -877,7 +877,7 @@ def __init__( ws_server = WebSocketServer(loop) - secure = kwds.get("ssl") is not None + secure = kwargs.get("ssl") is not None if compression == "deflate": if extensions is None: @@ -917,9 +917,9 @@ def __init__( # https://github.com/python/typeshed/pull/2763 host = cast(str, host) port = cast(int, port) - creating_server = loop.create_server(factory, host, port, **kwds) + creating_server = loop.create_server(factory, host, port, **kwargs) else: - creating_server = loop.create_unix_server(factory, path, **kwds) + creating_server = loop.create_unix_server(factory, path, **kwargs) # This is a coroutine object. self._creating_server = creating_server diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 35b662eb9..613143dbb 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -71,8 +71,8 @@ async def handler(ws, path): @contextlib.contextmanager -def temp_test_server(test, **kwds): - test.start_server(**kwds) +def temp_test_server(test, **kwargs): + test.start_server(**kwargs) try: yield finally: @@ -91,15 +91,15 @@ def temp_test_redirecting_server( @contextlib.contextmanager -def temp_test_client(test, *args, **kwds): - test.start_client(*args, **kwds) +def temp_test_client(test, *args, **kwargs): + test.start_client(*args, **kwargs) try: yield finally: test.stop_client() -def with_manager(manager, *args, **kwds): +def with_manager(manager, *args, **kwargs): """ Return a decorator that wraps a function with a context manager. @@ -107,29 +107,29 @@ def with_manager(manager, *args, **kwds): def decorate(func): @functools.wraps(func) - def _decorate(self, *_args, **_kwds): - with manager(self, *args, **kwds): - return func(self, *_args, **_kwds) + def _decorate(self, *_args, **_kwargs): + with manager(self, *args, **kwargs): + return func(self, *_args, **_kwargs) return _decorate return decorate -def with_server(**kwds): +def with_server(**kwargs): """ Return a decorator for TestCase methods that starts and stops a server. """ - return with_manager(temp_test_server, **kwds) + return with_manager(temp_test_server, **kwargs) -def with_client(*args, **kwds): +def with_client(*args, **kwargs): """ Return a decorator for TestCase methods that starts and stops a client. """ - return with_manager(temp_test_client, *args, **kwds) + return with_manager(temp_test_client, *args, **kwargs) def get_server_uri(server, secure=False, resource_name="/", user_info=None): @@ -240,14 +240,14 @@ def setUp(self): def server_context(self): return None - def start_server(self, expected_warning=None, **kwds): + def start_server(self, expected_warning=None, **kwargs): # Disable compression by default in tests. - kwds.setdefault("compression", None) + kwargs.setdefault("compression", None) # Disable pings by default in tests. - kwds.setdefault("ping_interval", None) + kwargs.setdefault("ping_interval", None) with warnings.catch_warnings(record=True) as recorded_warnings: - start_server = serve(handler, "localhost", 0, **kwds) + start_server = serve(handler, "localhost", 0, **kwargs) self.server = self.loop.run_until_complete(start_server) if expected_warning is None: @@ -280,18 +280,18 @@ async def process_request(path, headers): self.redirecting_server = self.loop.run_until_complete(start_server) def start_client( - self, resource_name="/", user_info=None, expected_warning=None, **kwds + self, resource_name="/", user_info=None, expected_warning=None, **kwargs ): # Disable compression by default in tests. - kwds.setdefault("compression", None) + kwargs.setdefault("compression", None) # Disable pings by default in tests. - kwds.setdefault("ping_interval", None) - secure = kwds.get("ssl") is not None + kwargs.setdefault("ping_interval", None) + secure = kwargs.get("ssl") is not None server = self.redirecting_server if self.redirecting_server else self.server server_uri = get_server_uri(server, secure, resource_name, user_info) with warnings.catch_warnings(record=True) as recorded_warnings: - start_client = connect(server_uri, **kwds) + start_client = connect(server_uri, **kwargs) self.client = self.loop.run_until_complete(start_client) if expected_warning is None: @@ -331,13 +331,13 @@ def stop_redirecting_server(self): self.redirecting_server = None @contextlib.contextmanager - def temp_server(self, **kwds): - with temp_test_server(self, **kwds): + def temp_server(self, **kwargs): + with temp_test_server(self, **kwargs): yield @contextlib.contextmanager - def temp_client(self, *args, **kwds): - with temp_test_client(self, *args, **kwds): + def temp_client(self, *args, **kwargs): + with temp_test_client(self, *args, **kwargs): yield def make_http_request(self, path="/", headers=None): @@ -377,13 +377,13 @@ def client_context(self): ssl_context.load_verify_locations(testcert) return ssl_context - def start_server(self, **kwds): - kwds.setdefault("ssl", self.server_context) - super().start_server(**kwds) + def start_server(self, **kwargs): + kwargs.setdefault("ssl", self.server_context) + super().start_server(**kwargs) - def start_client(self, path="/", **kwds): - kwds.setdefault("ssl", self.client_context) - super().start_client(path, **kwds) + def start_client(self, path="/", **kwargs): + kwargs.setdefault("ssl", self.client_context) + super().start_client(path, **kwargs) class CommonClientServerTests: diff --git a/tests/utils.py b/tests/utils.py index 24cdcfa51..2c067f8e6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -34,8 +34,8 @@ def convert_async_to_sync(test): """ @functools.wraps(test) - def test_func(self, *args, **kwds): - return self.loop.run_until_complete(test(self, *args, **kwds)) + def test_func(self, *args, **kwargs): + return self.loop.run_until_complete(test(self, *args, **kwargs)) return test_func From c1dd59331749a859bc79201c8da62ea3a71811a9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Jun 2019 10:11:24 +0200 Subject: [PATCH 0600/1539] Refer to TLS consistently. And clarify the relationship with SSL. --- docs/faq.rst | 4 ++-- docs/intro.rst | 3 ++- src/websockets/client.py | 2 +- tests/test_client_server.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/faq.rst b/docs/faq.rst index 6c5352668..3dfdb5bcd 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -124,8 +124,8 @@ See `issue 414`_. .. _issue 414: https://github.com/aaugustin/websockets/issues/414 -How do I disable SSL certificate verification? -.............................................. +How do I disable TLS/SSL certificate verification? +.................................................. Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. diff --git a/docs/intro.rst b/docs/intro.rst index 8decd462d..14ba1b38a 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -54,7 +54,8 @@ Secure WebSocket connections improve confidentiality and also reliability because they reduce the risk of interference by bad proxies. The WSS protocol is to WS what HTTPS is to HTTP: the connection is encrypted -with TLS. WSS requires TLS certificates like HTTPS. +with Transport Layer Security (TLS) — which is often referred to as Secure +Sockets Layer (SSL). WSS requires TLS certificates like HTTPS. Here's how to adapt the server example to provide secure connections. See the documentation of the :mod:`ssl` module for configuring the context securely. diff --git a/src/websockets/client.py b/src/websockets/client.py index 8dd8a0dd1..110e61f69 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -415,7 +415,7 @@ def __init__( kwargs.setdefault("ssl", True) elif kwargs.get("ssl") is not None: raise ValueError( - "connect() received a SSL context for a ws:// URI, " + "connect() received a ssl argument for a ws:// URI, " "use a wss:// URI to enable TLS" ) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 613143dbb..a88002364 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1215,7 +1215,7 @@ def test_connection_error_during_opening_handshake( _read_http_request.side_effect = ConnectionError # This exception is currently platform-dependent. It was observed to - # be ConnectionResetError on Linux in the non-SSL case, and + # be ConnectionResetError on Linux in the non-TLS case, and # InvalidMessage otherwise (including both Linux and macOS). This # doesn't matter though since this test is primarily for testing a # code path on the server side. From e146ace7caf42462af79de4fe3d0e0c4f1e2e8dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Jun 2019 10:15:28 +0200 Subject: [PATCH 0601/1539] Add consistency checks on serve() arguments. It's only possible to hit these assertions by not respecting the documented signatures of serve() and unix_serve(). --- src/websockets/server.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index 547656e0c..7c268c257 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -914,11 +914,16 @@ def __init__( ) if path is None: + # serve(..., host, port) must specify host and port parameters. + # host can be None to listen on all interfaces; port cannot be None. + assert port is not None # https://github.com/python/typeshed/pull/2763 - host = cast(str, host) - port = cast(int, port) - creating_server = loop.create_server(factory, host, port, **kwargs) + creating_server = loop.create_server( # type: ignore + factory, host, port, **kwargs + ) else: + # unix_serve(path) must not specify host and port parameters. + assert host is None and port is None creating_server = loop.create_unix_server(factory, path, **kwargs) # This is a coroutine object. @@ -966,6 +971,8 @@ def unix_serve( """ Similar to :func:`serve()`, but for listening on Unix sockets. + ``path`` is the path to the Unix socket. + This function calls the event loop's :meth:`~asyncio.AbstractEventLoop.create_unix_server` method. From 87a9ec06ce119ad50bb54a250514fc426e8ad370 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 Jun 2019 07:54:17 +0200 Subject: [PATCH 0602/1539] Move SecurityError to exceptions module. --- src/websockets/exceptions.py | 8 ++++++++ src/websockets/http.py | 11 ++++------- tests/test_exceptions.py | 5 +++++ tests/test_http.py | 5 +++-- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 36a8ed4a8..ce2c1e64b 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -25,6 +25,7 @@ "NegotiationError", "PayloadTooBig", "RedirectHandshake", + "SecurityError", "WebSocketProtocolError", ] @@ -52,6 +53,13 @@ def __init__( super().__init__(message) +class SecurityError(InvalidHandshake): + """ + Exception raised when a HTTP request or response breaks security rules. + + """ + + class RedirectHandshake(InvalidHandshake): """ Exception raised when a handshake gets redirected. diff --git a/src/websockets/http.py b/src/websockets/http.py index 6fbe5eb31..04424c6c5 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -39,13 +39,6 @@ USER_AGENT = f"Python/{sys.version[:3]} websockets/{websockets_version}" -class SecurityError(ValueError): - """ - HTTP request or response exceeds security limits. - - """ - - def d(value: bytes) -> str: """ Decode a bytestring for interpolating into an error message. @@ -211,6 +204,8 @@ async def read_headers(stream: asyncio.StreamReader) -> "Headers": headers[name] = value else: + from .exceptions import SecurityError # avoid circular import + raise SecurityError("too many HTTP headers") return headers @@ -229,6 +224,8 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: line = await stream.readline() # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: + from .exceptions import SecurityError # avoid circular import + raise SecurityError("line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index fbc06e576..2cbd78671 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -16,6 +16,11 @@ def test_str(self): AbortHandshake(200, Headers(), b"OK\n"), "HTTP 200, 0 headers, 3 bytes", ), + ( + SecurityError("redirect from WSS to WS"), + "redirect from WSS to WS", + + ), ( RedirectHandshake("wss://example.com"), "redirect to wss://example.com", diff --git a/tests/test_http.py b/tests/test_http.py index cff97fc2f..41b522c3d 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,6 +1,7 @@ import asyncio import unittest +from websockets.exceptions import SecurityError from websockets.http import * from websockets.http import read_headers @@ -120,13 +121,13 @@ async def test_header_value(self): async def test_headers_limit(self): self.stream.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") - with self.assertRaises(ValueError): + with self.assertRaises(SecurityError): await read_headers(self.stream) async def test_line_limit(self): # Header line contains 5 + 4090 + 2 = 4097 bytes. self.stream.feed_data(b"foo: " + b"a" * 4090 + b"\r\n\r\n") - with self.assertRaises(ValueError): + with self.assertRaises(SecurityError): await read_headers(self.stream) async def test_line_ending(self): From 626544bc58565b19dc11f74ebe9b8fe25ff411b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 Jun 2019 21:44:40 +0200 Subject: [PATCH 0603/1539] Refactor redirect handling in connect(). This reverts parts of 00458f27 and uses a less usual but less verbose approach. _redirect() make look like a hack but it uses public APIs. This approach minimizes divergence between the client and server implementations. Also it will make it easier to implement new features in connect(). --- src/websockets/client.py | 132 +++++++++++++++++++-------------------- src/websockets/server.py | 19 +++--- 2 files changed, 75 insertions(+), 76 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 110e61f69..943c4bbe7 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -5,6 +5,7 @@ import asyncio import collections.abc +import functools import logging import warnings from types import TracebackType @@ -17,6 +18,7 @@ InvalidStatusCode, NegotiationError, RedirectHandshake, + SecurityError, ) from .extensions.base import ClientExtensionFactory, Extension from .extensions.permessage_deflate import ClientPerMessageDeflateFactory @@ -410,8 +412,8 @@ def __init__( if loop is None: loop = asyncio.get_event_loop() - self._wsuri = parse_uri(uri) - if self._wsuri.secure: + wsuri = parse_uri(uri) + if wsuri.secure: kwargs.setdefault("ssl", True) elif kwargs.get("ssl") is not None: raise ValueError( @@ -432,69 +434,65 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") - self._create_protocol = create_protocol - self._ping_interval = ping_interval - self._ping_timeout = ping_timeout - self._close_timeout = close_timeout - self._max_size = max_size - self._max_queue = max_queue - self._read_limit = read_limit - self._write_limit = write_limit - self._loop = loop - self._legacy_recv = legacy_recv - self._klass = klass - self._timeout = timeout - self._compression = compression - self._origin = origin - self._extensions = extensions - self._subprotocols = subprotocols - self._extra_headers = extra_headers - self._kwargs = kwargs - - async def _creating_connection( - self - ) -> Tuple[asyncio.Transport, WebSocketClientProtocol]: - if self._wsuri.secure: - self._kwargs.setdefault("ssl", True) - - factory = lambda: self._create_protocol( - host=self._wsuri.host, - port=self._wsuri.port, - secure=self._wsuri.secure, - ping_interval=self._ping_interval, - ping_timeout=self._ping_timeout, - close_timeout=self._close_timeout, - max_size=self._max_size, - max_queue=self._max_queue, - read_limit=self._read_limit, - write_limit=self._write_limit, - loop=self._loop, - legacy_recv=self._legacy_recv, - origin=self._origin, - extensions=self._extensions, - subprotocols=self._subprotocols, - extra_headers=self._extra_headers, + factory = functools.partial( + create_protocol, + host=wsuri.host, + port=wsuri.port, + secure=wsuri.secure, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + loop=loop, + legacy_recv=legacy_recv, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, ) host: Optional[str] port: Optional[int] - if self._kwargs.get("sock") is None: - host, port = self._wsuri.host, self._wsuri.port + if kwargs.get("sock") is None: + host, port = wsuri.host, wsuri.port else: - # If sock is given, host and port mustn't be specified. + # If sock is given, host and port shouldn't be specified. host, port = None, None - self._wsuri = self._wsuri - self._origin = self._origin + # This is a coroutine function. + self._create_connection = functools.partial( + loop.create_connection, factory, host, port, **kwargs + ) + + self._wsuri = wsuri + self._origin = origin - # This is a coroutine object. - # https://github.com/python/typeshed/pull/2756 - transport, protocol = await self._loop.create_connection( # type: ignore - factory, host, port, **self._kwargs + def _redirect(self, uri: str) -> None: + old_wsuri = self._wsuri + factory, old_host, old_port = self._create_connection.args + + new_wsuri = parse_uri(uri) + new_host, new_port = new_wsuri.host, new_wsuri.port + if old_wsuri.secure and not new_wsuri.secure: + raise SecurityError("redirect from WSS to WS") + + # Replace the host and port argument passed to the protocol factory. + factory = self._create_connection.args[0] + factory_keywords = dict(factory.keywords, host=new_host, port=new_port) + factory = functools.partial(factory.func, *factory.args, **factory_keywords) + + # Replace the host and port argument passed to create_connection. + create_connection_args = (factory, new_host, new_port) + self._create_connection = functools.partial( + self._create_connection.func, + *create_connection_args, + **self._create_connection.keywords, ) - transport = cast(asyncio.Transport, transport) - protocol = cast(WebSocketClientProtocol, protocol) - return transport, protocol + + self._wsuri = new_wsuri # async with connect(...) @@ -517,7 +515,10 @@ def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): - transport, protocol = await self._creating_connection() + transport, protocol = await self._create_connection() + # https://github.com/python/typeshed/pull/2756 + transport = cast(asyncio.Transport, transport) + protocol = cast(WebSocketClientProtocol, protocol) try: try: @@ -528,22 +529,17 @@ async def __await_impl__(self) -> WebSocketClientProtocol: available_subprotocols=protocol.available_subprotocols, extra_headers=protocol.extra_headers, ) - break # redirection chain ended except Exception: protocol.fail_connection() await protocol.wait_closed() raise - except RedirectHandshake as e: - wsuri = parse_uri(e.uri) - if self._wsuri.secure and not wsuri.secure: - raise InvalidHandshake("redirect dropped TLS") - self._wsuri = wsuri - continue # redirection chain continues + else: + self.ws_client = protocol + return protocol + except RedirectHandshake as exc: + self._redirect(exc.uri) else: - raise InvalidHandshake("maximum redirects exceeded") - - self.ws_client = protocol - return protocol + raise SecurityError("too many redirects") # yield from connect(...) diff --git a/src/websockets/server.py b/src/websockets/server.py index 7c268c257..c02b67e03 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -6,6 +6,7 @@ import asyncio import collections.abc import email.utils +import functools import http import logging import socket @@ -890,7 +891,8 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") - factory = lambda: create_protocol( + factory = functools.partial( + create_protocol, ws_handler, ws_server, host=host, @@ -917,17 +919,18 @@ def __init__( # serve(..., host, port) must specify host and port parameters. # host can be None to listen on all interfaces; port cannot be None. assert port is not None - # https://github.com/python/typeshed/pull/2763 - creating_server = loop.create_server( # type: ignore - factory, host, port, **kwargs + create_server = functools.partial( + loop.create_server, factory, host, port, **kwargs ) else: # unix_serve(path) must not specify host and port parameters. assert host is None and port is None - creating_server = loop.create_unix_server(factory, path, **kwargs) + create_server = functools.partial( + loop.create_unix_server, factory, path, **kwargs + ) - # This is a coroutine object. - self._creating_server = creating_server + # This is a coroutine function. + self._create_server = create_server self.ws_server = ws_server # async with serve(...) @@ -951,7 +954,7 @@ def __await__(self) -> Generator[Any, None, WebSocketServer]: return self.__await_impl__().__await__() async def __await_impl__(self) -> WebSocketServer: - server = await self._creating_server + server = await self._create_server() self.ws_server.wrap(server) return self.ws_server From 721ef99dab6efebbf1aad29ba127387f5e129855 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Jun 2019 10:31:06 +0200 Subject: [PATCH 0604/1539] Add unix_connect to connect to a Unix socket.. Fix #539. --- docs/api.rst | 3 +++ docs/changelog.rst | 2 ++ example/unix_client.py | 19 ++++++++++++++++ example/unix_server.py | 22 ++++++++++++++++++ src/websockets/client.py | 45 ++++++++++++++++++++++++++++--------- tests/test_client_server.py | 16 +++++-------- 6 files changed, 86 insertions(+), 21 deletions(-) create mode 100755 example/unix_client.py create mode 100755 example/unix_server.py diff --git a/docs/api.rst b/docs/api.rst index ef567ed5b..56372eb11 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -59,6 +59,9 @@ Client .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) :async: + .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + :async: + .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) .. automethod:: handshake diff --git a/docs/changelog.rst b/docs/changelog.rst index 92cbce58f..761c8b8fc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -56,6 +56,8 @@ Also: * :func:`~client.connect` handles redirects from the server during the handshake. +* Added :func:`~client.unix_connect` for connecting to Unix sockets. + * Improved support for sending fragmented messages by accepting asynchronous iterators in :meth:`~protocol.WebSocketCommonProtocol.send`. diff --git a/example/unix_client.py b/example/unix_client.py new file mode 100755 index 000000000..577135b3d --- /dev/null +++ b/example/unix_client.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +# WS client example connecting to a Unix socket + +import asyncio +import os.path +import websockets + +async def hello(): + socket_path = os.path.join(os.path.dirname(__file__), "socket") + async with websockets.unix_connect(socket_path) as websocket: + name = input("What's your name? ") + await websocket.send(name) + print(f"> {name}") + + greeting = await websocket.recv() + print(f"< {greeting}") + +asyncio.get_event_loop().run_until_complete(hello()) diff --git a/example/unix_server.py b/example/unix_server.py new file mode 100755 index 000000000..a6ec0168a --- /dev/null +++ b/example/unix_server.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +# WS server example listening on a Unix socket + +import asyncio +import os.path +import websockets + +async def hello(websocket, path): + name = await websocket.recv() + print(f"< {name}") + + greeting = f"Hello {name}!" + + await websocket.send(greeting) + print(f"> {greeting}") + +socket_path = os.path.join(os.path.dirname(__file__), "socket") +start_server = websockets.unix_serve(hello, socket_path) + +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() diff --git a/src/websockets/client.py b/src/websockets/client.py index 943c4bbe7..4da8c3b50 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -37,7 +37,7 @@ from .uri import WebSocketURI, parse_uri -__all__ = ["connect", "WebSocketClientProtocol"] +__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] logger = logging.getLogger(__name__) @@ -372,6 +372,7 @@ def __init__( self, uri: str, *, + path: Optional[str] = None, create_protocol: Optional[Type[WebSocketClientProtocol]] = None, ping_interval: float = 20, ping_timeout: float = 20, @@ -454,19 +455,24 @@ def __init__( extra_headers=extra_headers, ) - host: Optional[str] - port: Optional[int] - if kwargs.get("sock") is None: - host, port = wsuri.host, wsuri.port + if path is None: + host: Optional[str] + port: Optional[int] + if kwargs.get("sock") is None: + host, port = wsuri.host, wsuri.port + else: + # If sock is given, host and port shouldn't be specified. + host, port = None, None + create_connection = functools.partial( + loop.create_connection, factory, host, port, **kwargs + ) else: - # If sock is given, host and port shouldn't be specified. - host, port = None, None + create_connection = functools.partial( + loop.create_unix_connection, factory, path, **kwargs + ) # This is a coroutine function. - self._create_connection = functools.partial( - loop.create_connection, factory, host, port, **kwargs - ) - + self._create_connection = create_connection self._wsuri = wsuri self._origin = origin @@ -547,3 +553,20 @@ async def __await_impl__(self) -> WebSocketClientProtocol: connect = Connect + + +def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect: + """ + Similar to :func:`connect`, but for connecting to a Unix socket. + + ``path`` is the path to the Unix socket. ``uri`` is the WebSocket URI. + + This function calls the event loop's + :meth:`~asyncio.AbstractEventLoop.create_unix_connection` method. + + It is only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + """ + return connect(uri=uri, path=path, **kwargs) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index a88002364..738d92ff0 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -151,9 +151,6 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): host = f"[{host}]" elif server_socket.family == socket.AF_INET: host, port = server_socket.getsockname() - elif server_socket.family == socket.AF_UNIX: - # The host and port are ignored when connecting to a Unix socket. - host, port = "localhost", 0 else: # pragma: no cover raise ValueError("expected an IPv6, IPv4, or Unix socket") @@ -489,18 +486,17 @@ def test_unix_socket(self): # Like self.start_server() but with unix_serve(). unix_server = unix_serve(handler, path) self.server = self.loop.run_until_complete(unix_server) - - client_socket = socket.socket(socket.AF_UNIX) - client_socket.connect(path) - try: - with self.temp_client(sock=client_socket): + # Like self.start_client() but with unix_connect() + unix_client = unix_connect(path) + self.client = self.loop.run_until_complete(unix_client) + try: self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - + finally: + self.stop_client() finally: - client_socket.close() self.stop_server() async def process_request_OK(path, request_headers): From 752f4145cd06d303f7ac2ddc8c1fdffe9c492bff Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Jun 2019 12:52:11 +0200 Subject: [PATCH 0605/1539] Support overriding host and port in connect(). Fix #540. Thanks @Kirill888 for the report and initial patch. --- docs/changelog.rst | 2 ++ src/websockets/client.py | 41 +++++++++++++++++++++++-------------- tests/test_client_server.py | 23 +++++++++++++++++++-- 3 files changed, 49 insertions(+), 17 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 761c8b8fc..7a02ec0e7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -56,6 +56,8 @@ Also: * :func:`~client.connect` handles redirects from the server during the handshake. +* :func:`~client.connect` supports overriding ``host`` and ``port``. + * Added :func:`~client.unix_connect` for connecting to Unix sockets. * Improved support for sending fragmented messages by accepting asynchronous diff --git a/src/websockets/client.py b/src/websockets/client.py index 4da8c3b50..abcf9dc62 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -334,6 +334,11 @@ class Connect: a ``wss://`` URI, if this argument isn't provided explicitly, it's set to ``True``, which means Python's default :class:`~ssl.SSLContext` is used. + You can connect to a different host and port from those found in ``uri`` + by setting ``host`` and ``port`` keyword arguments. This only changes the + destination of the TCP connection; the hostname from ``uri`` is still used + in the TLS handshake for secure connections and in the ``Host`` header. + The behavior of the ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` optional arguments is described in the documentation of @@ -463,6 +468,9 @@ def __init__( else: # If sock is given, host and port shouldn't be specified. host, port = None, None + # If host and port are given, override values from the URI. + host = kwargs.pop("host", host) + port = kwargs.pop("port", port) create_connection = functools.partial( loop.create_connection, factory, host, port, **kwargs ) @@ -478,25 +486,28 @@ def __init__( def _redirect(self, uri: str) -> None: old_wsuri = self._wsuri - factory, old_host, old_port = self._create_connection.args - new_wsuri = parse_uri(uri) - new_host, new_port = new_wsuri.host, new_wsuri.port + if old_wsuri.secure and not new_wsuri.secure: raise SecurityError("redirect from WSS to WS") - # Replace the host and port argument passed to the protocol factory. - factory = self._create_connection.args[0] - factory_keywords = dict(factory.keywords, host=new_host, port=new_port) - factory = functools.partial(factory.func, *factory.args, **factory_keywords) - - # Replace the host and port argument passed to create_connection. - create_connection_args = (factory, new_host, new_port) - self._create_connection = functools.partial( - self._create_connection.func, - *create_connection_args, - **self._create_connection.keywords, - ) + # Only rewrite the host and port arguments is they change in the URI. + # This preserves connection overrides with the host, port, or sock + # arguments if the redirect points to the same host and port. + if old_wsuri.host != new_wsuri.host or old_wsuri.port != new_wsuri.port: + # Replace the host and port argument passed to the protocol factory. + factory = self._create_connection.args[0] + factory = functools.partial( + factory.func, + *factory.args, + **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), + ) + # Replace the host and port argument passed to create_connection. + self._create_connection = functools.partial( + self._create_connection.func, + *(factory, new_wsuri.host, new_wsuri.port), + **self._create_connection.keywords, + ) self._wsuri = new_wsuri diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 738d92ff0..7281ec6bd 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -31,6 +31,7 @@ from websockets.http import USER_AGENT, Headers, read_response from websockets.protocol import State from websockets.server import * +from websockets.uri import parse_uri from .test_protocol import MS from .utils import AsyncioTestCase @@ -284,8 +285,11 @@ def start_client( # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) secure = kwargs.get("ssl") is not None - server = self.redirecting_server if self.redirecting_server else self.server - server_uri = get_server_uri(server, secure, resource_name, user_info) + try: + server_uri = kwargs.pop("uri") + except KeyError: + server = self.redirecting_server if self.redirecting_server else self.server + server_uri = get_server_uri(server, secure, resource_name, user_info) with warnings.catch_warnings(record=True) as recorded_warnings: start_client = connect(server_uri, **kwargs) @@ -437,6 +441,21 @@ def test_explicit_event_loop(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") + @with_server() + def test_explicit_host_port(self): + uri = get_server_uri(self.server, self.secure) + wsuri = parse_uri(uri) + + # Change host and port to invalid values. + changed_uri = uri.replace(wsuri.host, "example.com").replace( + str(wsuri.port), str(65535 - wsuri.port) + ) + + with self.temp_client(uri=changed_uri, host=wsuri.host, port=wsuri.port): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + @with_server() def test_explicit_socket(self): class TrackedSocket(socket.socket): From f967833ee3c8215e49edd4033d1efb3985a895ad Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 Jun 2019 17:47:20 +0200 Subject: [PATCH 0606/1539] Minor code changes for readability. --- src/websockets/client.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index abcf9dc62..10435c1ff 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -484,17 +484,23 @@ def __init__( self._wsuri = wsuri self._origin = origin - def _redirect(self, uri: str) -> None: + def handle_redirect(self, uri: str) -> None: + # Update the state of this instance to connect to a new URI. old_wsuri = self._wsuri new_wsuri = parse_uri(uri) + # Forbid TLS downgrade. if old_wsuri.secure and not new_wsuri.secure: raise SecurityError("redirect from WSS to WS") - # Only rewrite the host and port arguments is they change in the URI. - # This preserves connection overrides with the host, port, or sock + same_origin = ( + old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port + ) + + # Rewrite the host and port arguments for cross-origin redirects. + # This preserves connection overrides with the host and port # arguments if the redirect points to the same host and port. - if old_wsuri.host != new_wsuri.host or old_wsuri.port != new_wsuri.port: + if not same_origin: # Replace the host and port argument passed to the protocol factory. factory = self._create_connection.args[0] factory = functools.partial( @@ -509,6 +515,7 @@ def _redirect(self, uri: str) -> None: **self._create_connection.keywords, ) + # Set the new WebSocket URI. This suffices for same-origin redirects. self._wsuri = new_wsuri # async with connect(...) @@ -554,7 +561,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: self.ws_client = protocol return protocol except RedirectHandshake as exc: - self._redirect(exc.uri) + self.handle_redirect(exc.uri) else: raise SecurityError("too many redirects") From dd653dbe551a88dec4491fd5d83c9eefa236213a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 Jun 2019 14:32:22 +0200 Subject: [PATCH 0607/1539] Add WebSocketServer to server.__all__. Fix #562. Thanks @lgrahl for the suggestion. --- src/websockets/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index c02b67e03..8e1db9b7c 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -50,7 +50,7 @@ from .typing import Origin, Subprotocol -__all__ = ["serve", "unix_serve", "WebSocketServerProtocol"] +__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] logger = logging.getLogger(__name__) From 6386867594a685c026099fb307ae0efd36ca6095 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 Jun 2019 21:48:44 +0200 Subject: [PATCH 0608/1539] Handle import loops consistently. --- src/websockets/http.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/http.py b/src/websockets/http.py index 04424c6c5..46b09c2e6 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -204,9 +204,7 @@ async def read_headers(stream: asyncio.StreamReader) -> "Headers": headers[name] = value else: - from .exceptions import SecurityError # avoid circular import - - raise SecurityError("too many HTTP headers") + raise websockets.exceptions.SecurityError("too many HTTP headers") return headers @@ -224,9 +222,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: line = await stream.readline() # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: - from .exceptions import SecurityError # avoid circular import - - raise SecurityError("line too long") + raise websockets.exceptions.SecurityError("line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") @@ -364,3 +360,7 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] + + +# at the bottom to allow circular import, because AbortHandshake depends on HeadersLike +import websockets.exceptions # isort:skip # noqa From e832c565b6ac85b5c1a80c7e6eab15eaead31440 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jul 2019 21:37:09 +0200 Subject: [PATCH 0609/1539] Fix InvalidStateError when failing the connection. This exception occurred when: - the incoming queue was full - the connection terminated with an error - recv() was called at the wrong time Fix #634. --- src/websockets/protocol.py | 2 +- tests/test_protocol.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 43dcbd4ff..5161017b2 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -725,7 +725,7 @@ async def transfer_data(self) -> None: while len(self.messages) >= self.max_queue: self._put_message_waiter = self.loop.create_future() try: - await self._put_message_waiter + await asyncio.shield(self._put_message_waiter) finally: self._put_message_waiter = None diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 0d3185d42..321d20f63 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -459,12 +459,15 @@ def test_recv_queue_full(self): self.assertEqual(list(self.protocol.messages), ["café", b"tea"]) self.loop.run_until_complete(self.protocol.recv()) + self.run_loop_once() self.assertEqual(list(self.protocol.messages), [b"tea", b"milk"]) self.loop.run_until_complete(self.protocol.recv()) + self.run_loop_once() self.assertEqual(list(self.protocol.messages), [b"milk"]) self.loop.run_until_complete(self.protocol.recv()) + self.run_loop_once() self.assertEqual(list(self.protocol.messages), []) def test_recv_queue_no_limit(self): @@ -519,6 +522,27 @@ def test_recv_canceled_race_condition(self): # If we're getting "tea" there, it means "café" was swallowed (ha, ha). self.assertEqual(data, "café") + def test_recv_when_transfer_data_cancelled(self): + # Clog incoming queue. + self.protocol.max_queue = 1 + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_BINARY, b"tea")) + self.run_loop_once() + + # Flow control kicks in (check with an implementation detail). + self.assertFalse(self.protocol._put_message_waiter.done()) + + # Schedule recv(). + recv = self.loop.create_task(self.protocol.recv()) + + # Cancel transfer_data_task (again, implementation detail). + self.protocol.fail_connection() + self.run_loop_once() + self.assertTrue(self.protocol.transfer_data_task.cancelled()) + + # recv() completes properly. + self.assertEqual(self.loop.run_until_complete(recv), "café") + def test_recv_prevents_concurrent_calls(self): recv = self.loop.create_task(self.protocol.recv()) From d601f68b7edfed92fbb7566511bea927f324b3c2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jul 2019 21:54:46 +0200 Subject: [PATCH 0610/1539] Serialize sending fragmented messages. While sending a fragmented message, no other data frame can be sent. Fix #542. --- src/websockets/protocol.py | 23 ++++++++++++++++++++++- tests/test_protocol.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 5161017b2..8eab48651 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -258,6 +258,9 @@ def __init__( self._pop_message_waiter: Optional[asyncio.Future[None]] = None self._put_message_waiter: Optional[asyncio.Future[None]] = None + # Flag that protects sending fragmented messages. + self.sending_fragmented_message = False + # Mapping of ping IDs to waiters, in chronological order. self.pings: collections.OrderedDict[ bytes, asyncio.Future[None] @@ -418,7 +421,7 @@ async def recv(self) -> Data: """ if self._pop_message_waiter is not None: raise RuntimeError( - "cannot call recv() while another coroutine " + "cannot call recv while another coroutine " "is already waiting for the next message" ) @@ -487,6 +490,13 @@ async def send( """ await self.ensure_open() + # Prevent sending other messages until all fragments are sent. + if self.sending_fragmented_message: + raise RuntimeError( + "cannot call send while another coroutine " + "is sending a fragmented message" + ) + # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -503,6 +513,8 @@ async def send( iter_message = iter(message) + self.sending_fragmented_message = True + # First fragment. try: message_chunk = next(iter_message) @@ -521,6 +533,9 @@ async def send( raise TypeError("data contains inconsistent types") await self.write_frame(False, OP_CONT, data) + # write_frame() will write to the buffer before yielding control. + self.sending_fragmented_message = False + # Final fragment. await self.write_frame(True, OP_CONT, b"") @@ -530,6 +545,9 @@ async def send( # aiter_message = aiter(message) without aiter aiter_message = type(message).__aiter__(message) + # Prevent sending other messages until all fragments are sent. + self.sending_fragmented_message = True + # First fragment. try: # message_chunk = anext(aiter_message) without anext @@ -549,6 +567,9 @@ async def send( raise TypeError("data contains inconsistent types") await self.write_frame(False, OP_CONT, data) + # write_frame() will write to the buffer before yielding control. + self.sending_fragmented_message = False + # Final fragment. await self.write_frame(True, OP_CONT, b"") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 321d20f63..342b3255e 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -546,7 +546,11 @@ def test_recv_when_transfer_data_cancelled(self): def test_recv_prevents_concurrent_calls(self): recv = self.loop.create_task(self.protocol.recv()) - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex( + RuntimeError, + "cannot call recv while another coroutine " + "is already waiting for the next message", + ): self.loop.run_until_complete(self.protocol.recv()) recv.cancel() @@ -633,6 +637,19 @@ def test_send_iterable_mixed_type_error(self): (True, OP_CLOSE, serialize_close(1011, "")), ) + def test_send_iterable_prevents_concurrent_send(self): + self.make_drain_slow() + send = self.loop.create_task(self.protocol.send(["ca", "fé"])) + + with self.assertRaisesRegex( + RuntimeError, + "cannot call send while another coroutine " + "is sending a fragmented message", + ): + self.loop.run_until_complete(self.protocol.send("tea")) + + send.cancel() + def test_send_async_iterable_text(self): self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"]))) self.assertFramesSent( @@ -692,6 +709,19 @@ def test_send_async_iterable_mixed_type_error(self): (True, OP_CLOSE, serialize_close(1011, "")), ) + def test_send_async_iterable_prevents_concurrent_send(self): + self.make_drain_slow() + send = self.loop.create_task(self.protocol.send(async_iterable(["ca", "fé"]))) + + with self.assertRaisesRegex( + RuntimeError, + "cannot call send while another coroutine " + "is sending a fragmented message", + ): + self.loop.run_until_complete(self.protocol.send("tea")) + + send.cancel() + def test_send_on_closing_connection_local(self): close_task = self.half_close_connection_local() From 7ef66541192be89373a904c27908a411400f5d68 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jul 2019 22:06:32 +0200 Subject: [PATCH 0611/1539] Serialize sending fragmented messages. When sending a fragmented message, wait until it's finished to send other messages. Fix #542 (more elegantly). --- src/websockets/protocol.py | 101 +++++++++++++++++++------------------ tests/test_protocol.py | 43 ++++++++-------- 2 files changed, 74 insertions(+), 70 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8eab48651..6f2399283 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -258,8 +258,8 @@ def __init__( self._pop_message_waiter: Optional[asyncio.Future[None]] = None self._put_message_waiter: Optional[asyncio.Future[None]] = None - # Flag that protects sending fragmented messages. - self.sending_fragmented_message = False + # Protect sending fragmented messages. + self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None # Mapping of ping IDs to waiters, in chronological order. self.pings: collections.OrderedDict[ @@ -490,12 +490,10 @@ async def send( """ await self.ensure_open() - # Prevent sending other messages until all fragments are sent. - if self.sending_fragmented_message: - raise RuntimeError( - "cannot call send while another coroutine " - "is sending a fragmented message" - ) + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self._fragmented_message_waiter is not None: + await asyncio.shield(self._fragmented_message_waiter) # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -512,66 +510,73 @@ async def send( message = cast(Iterable[Data], message) iter_message = iter(message) - - self.sending_fragmented_message = True - - # First fragment. try: message_chunk = next(iter_message) except StopIteration: return opcode, data = prepare_data(message_chunk) - await self.write_frame(False, opcode, data) - # Other fragments. - for message_chunk in iter_message: - confirm_opcode, data = prepare_data(message_chunk) - if confirm_opcode != opcode: - # We're half-way through a fragmented message and we can't - # complete it. This makes the connection unusable. - self.fail_connection(1011) - raise TypeError("data contains inconsistent types") - await self.write_frame(False, OP_CONT, data) + self._fragmented_message_waiter = asyncio.Future() + try: + # First fragment. + await self.write_frame(False, opcode, data) - # write_frame() will write to the buffer before yielding control. - self.sending_fragmented_message = False + # Other fragments. + for message_chunk in iter_message: + confirm_opcode, data = prepare_data(message_chunk) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) - # Final fragment. - await self.write_frame(True, OP_CONT, b"") + # Final fragment. + await self.write_frame(True, OP_CONT, b"") + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(1011) + raise + + finally: + self._fragmented_message_waiter.set_result(None) + self._fragmented_message_waiter = None # Fragmented message -- asynchronous iterator elif isinstance(message, AsyncIterable): # aiter_message = aiter(message) without aiter aiter_message = type(message).__aiter__(message) - - # Prevent sending other messages until all fragments are sent. - self.sending_fragmented_message = True - - # First fragment. try: # message_chunk = anext(aiter_message) without anext message_chunk = await type(aiter_message).__anext__(aiter_message) except StopAsyncIteration: return opcode, data = prepare_data(message_chunk) - await self.write_frame(False, opcode, data) - - # Other fragments. - async for message_chunk in aiter_message: - confirm_opcode, data = prepare_data(message_chunk) - if confirm_opcode != opcode: - # We're half-way through a fragmented message and we can't - # complete it. This makes the connection unusable. - self.fail_connection(1011) - raise TypeError("data contains inconsistent types") - await self.write_frame(False, OP_CONT, data) - - # write_frame() will write to the buffer before yielding control. - self.sending_fragmented_message = False - - # Final fragment. - await self.write_frame(True, OP_CONT, b"") + + self._fragmented_message_waiter = asyncio.Future() + try: + # First fragment. + await self.write_frame(False, opcode, data) + + # Other fragments. + async for message_chunk in aiter_message: + confirm_opcode, data = prepare_data(message_chunk) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) + + # Final fragment. + await self.write_frame(True, OP_CONT, b"") + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(1011) + raise + + finally: + self._fragmented_message_waiter.set_result(None) + self._fragmented_message_waiter = None else: raise TypeError("data must be bytes, str, or iterable") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 342b3255e..d0156fd74 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -638,17 +638,15 @@ def test_send_iterable_mixed_type_error(self): ) def test_send_iterable_prevents_concurrent_send(self): - self.make_drain_slow() - send = self.loop.create_task(self.protocol.send(["ca", "fé"])) - - with self.assertRaisesRegex( - RuntimeError, - "cannot call send while another coroutine " - "is sending a fragmented message", - ): - self.loop.run_until_complete(self.protocol.send("tea")) - - send.cancel() + self.loop.run_until_complete( + asyncio.gather(self.protocol.send(["ca", "fé"]), self.protocol.send(b"tea")) + ) + self.assertFramesSent( + (False, OP_TEXT, "ca".encode("utf-8")), + (False, OP_CONT, "fé".encode("utf-8")), + (True, OP_CONT, "".encode("utf-8")), + (True, OP_BINARY, b"tea"), + ) def test_send_async_iterable_text(self): self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"]))) @@ -710,17 +708,18 @@ def test_send_async_iterable_mixed_type_error(self): ) def test_send_async_iterable_prevents_concurrent_send(self): - self.make_drain_slow() - send = self.loop.create_task(self.protocol.send(async_iterable(["ca", "fé"]))) - - with self.assertRaisesRegex( - RuntimeError, - "cannot call send while another coroutine " - "is sending a fragmented message", - ): - self.loop.run_until_complete(self.protocol.send("tea")) - - send.cancel() + self.loop.run_until_complete( + asyncio.gather( + self.protocol.send(async_iterable(["ca", "fé"])), + self.protocol.send(b"tea"), + ) + ) + self.assertFramesSent( + (False, OP_TEXT, "ca".encode("utf-8")), + (False, OP_CONT, "fé".encode("utf-8")), + (True, OP_CONT, "".encode("utf-8")), + (True, OP_BINARY, b"tea"), + ) def test_send_on_closing_connection_local(self): close_task = self.half_close_connection_local() From e3452230eaf67ec5c4c253682eecb26aebee2223 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 Jun 2019 14:22:58 +0200 Subject: [PATCH 0612/1539] Discourage cancellation of APIs that write frames. If writing is stuck (and closing the connection counts as a write), then cancelling won't achieve anything. There's only one way out: closing the connection, waiting until all timeouts elapse, and eventually websockets gives up and aborts the TCP connection. Fix #278. --- src/websockets/protocol.py | 41 ++++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 6f2399283..d7e16dc4a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -479,14 +479,25 @@ async def send( object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a binary frame. - It also accepts an iterable or an asynchronous iterator of strings or - bytes-like objects. Each item is treated as a message fragment and - sent in its own frame. All items must be of the same type, or else - :meth:`send` will raise a :exc:`TypeError` and the connection will be - closed. + It also accepts an iterable or an asynchronous iterable of strings or + bytes-like objects. In that case the message is fragmented. Each item + is treated as a message fragment and sent in its own frame. All items + must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. It raises a :exc:`TypeError` for other inputs. + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there only two situations where + :meth:`send` yields control to the event loop: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator. Stopping in the middle of + a fragmented message will cause a protocol error. Closing the + connection has the same effect. + """ await self.ensure_open() @@ -589,13 +600,17 @@ async def close(self, code: int = 1000, reason: str = "") -> None: connection to terminate. As a consequence, there's no need to await :meth:`wait_closed`; :meth:`close` already does it. + ``code`` must be an :class:`int` and ``reason`` a :class:`str`. + :meth:`close` is idempotent: it doesn't do anything once the connection is closed. - It's safe to wrap this coroutine in :func:`~asyncio.create_task` since - errors during connection termination aren't particularly useful. + Wrapping :func:`close` in :func:`~asyncio.create_task` is safe, given + that errors during connection termination aren't particularly useful. - ``code`` must be an :class:`int` and ``reason`` a :class:`str`. + Canceling :meth:`close` is discouraged. If it takes too long, you can + set a shorter ``close_timeout``. If you don't want to wait, let the + Python process exit, then the OS will close the TCP connection. """ try: @@ -648,6 +663,13 @@ async def ping(self, data: Optional[bytes] = None) -> Awaitable[None]: overridden with the optional ``data`` argument which must be a string (which will be encoded to UTF-8) or a bytes-like object. + Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return + immediately, it means the write buffer is full. If you don't want to + wait, you should close the connection. + + Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no + effect. + """ await self.ensure_open() @@ -678,6 +700,9 @@ async def pong(self, data: bytes = b"") -> None: which must be a string (which will be encoded to UTF-8) or a bytes-like object. + Canceling :meth:`pong` is discouraged for the same reason as + :meth:`ping`. + """ await self.ensure_open() From 5a1b0bb890cb3d6c0ba2dcf05996f7fed8d3b751 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 Jul 2019 20:22:57 +0200 Subject: [PATCH 0613/1539] Change status code for invalid credentials to 401. 403 means the credentials are valid but don't provide permissions. --- src/websockets/auth.py | 6 +++++- tests/test_auth.py | 9 ++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 60f63e9aa..9cb673132 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -75,7 +75,11 @@ async def process_request( ) if not await self.check_credentials(username, password): - return (http.HTTPStatus.FORBIDDEN, [], b"Invalid credentials\n") + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Invalid credentials\n", + ) self.username = username diff --git a/tests/test_auth.py b/tests/test_auth.py index bcd340844..07341df56 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -122,7 +122,7 @@ def test_basic_auth_unsupported_credentials_details(self): def test_basic_auth_invalid_credentials(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(user_info=("hello", "ihateyou")) - self.assertEqual(raised.exception.status_code, 403) + self.assertEqual(raised.exception.status_code, 401) @with_server(create_protocol=create_protocol) def test_basic_auth_invalid_credentials_details(self): @@ -131,6 +131,9 @@ def test_basic_auth_invalid_credentials_details(self): self.loop.run_until_complete( self.make_http_request(headers={"Authorization": authorization}) ) - self.assertEqual(raised.exception.code, 403) - self.assertNotIn("WWW-Authenticate", raised.exception.headers) + self.assertEqual(raised.exception.code, 401) + self.assertEqual( + raised.exception.headers["WWW-Authenticate"], + 'Basic realm="auth-tests", charset="UTF-8"', + ) self.assertEqual(raised.exception.read().decode(), "Invalid credentials\n") From be04e2fe397ba0dd4b7f7a7f33b84cd4c2c2efd2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 Jul 2019 20:39:58 +0200 Subject: [PATCH 0614/1539] Try to make tests less flaky. Fix #639. --- tests/test_protocol.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d0156fd74..7cb593702 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -638,9 +638,16 @@ def test_send_iterable_mixed_type_error(self): ) def test_send_iterable_prevents_concurrent_send(self): - self.loop.run_until_complete( - asyncio.gather(self.protocol.send(["ca", "fé"]), self.protocol.send(b"tea")) - ) + self.make_drain_slow(2 * MS) + + async def send_iterable(): + await self.protocol.send(["ca", "fé"]) + + async def send_concurrent(): + await asyncio.sleep(MS) + await self.protocol.send(b"tea") + + self.loop.run_until_complete(asyncio.gather(send_iterable(), send_concurrent())) self.assertFramesSent( (False, OP_TEXT, "ca".encode("utf-8")), (False, OP_CONT, "fé".encode("utf-8")), @@ -708,11 +715,17 @@ def test_send_async_iterable_mixed_type_error(self): ) def test_send_async_iterable_prevents_concurrent_send(self): + self.make_drain_slow(2 * MS) + + async def send_async_iterable(): + await self.protocol.send(async_iterable(["ca", "fé"])) + + async def send_concurrent(): + await asyncio.sleep(MS) + await self.protocol.send(b"tea") + self.loop.run_until_complete( - asyncio.gather( - self.protocol.send(async_iterable(["ca", "fé"])), - self.protocol.send(b"tea"), - ) + asyncio.gather(send_async_iterable(), send_concurrent()) ) self.assertFramesSent( (False, OP_TEXT, "ca".encode("utf-8")), From 3718311049eb32a46a0e6b40c1132eeef85fd369 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 2 Jul 2019 22:44:15 +0200 Subject: [PATCH 0615/1539] Avoid logging ping exceptions that aren't retreived. This is a bit of a hack: it relies on the implementation of asyncio. Fix #637. --- src/websockets/protocol.py | 11 ++++++++--- tests/test_protocol.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index d7e16dc4a..fdadb9398 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -1053,7 +1053,7 @@ async def keepalive_ping(self) -> None: # when close_connection() cancels self.keepalive_ping_task. # ping() raises ConnectionClosed if the connection is lost, - # when connection_lost() calls abort_keepalive_pings(). + # when connection_lost() calls abort_pings(). ping_waiter = await self.ping() @@ -1223,7 +1223,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: if not hasattr(self, "close_connection_task"): self.close_connection_task = self.loop.create_task(self.close_connection()) - def abort_keepalive_pings(self) -> None: + def abort_pings(self) -> None: """ Raise ConnectionClosed in pending keepalive pings. @@ -1235,6 +1235,11 @@ def abort_keepalive_pings(self) -> None: for ping in self.pings.values(): ping.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + ping.cancel() if self.pings: pings_hex = ", ".join( @@ -1312,7 +1317,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self.close_code, self.close_reason or "[no reason]", ) - self.abort_keepalive_pings() + self.abort_pings() # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7cb593702..a6c420181 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -910,6 +910,16 @@ def test_abort_ping(self): self.assertTrue(ping.done()) self.assertIsInstance(ping.exception(), ConnectionClosed) + def test_abort_ping_does_not_log_exception_if_not_retreived(self): + self.loop.run_until_complete(self.protocol.ping()) + # Get the internal Future, which isn't directly returned by ping(). + ping, = self.protocol.pings.values() + # Remove the frame from the buffer, else close_connection() complains. + self.last_sent_frame() + self.close_connection() + # Check a private attribute, for lack of a better solution. + self.assertFalse(ping._log_traceback) + def test_acknowledge_previous_pings(self): pings = [ (self.loop.run_until_complete(self.protocol.ping()), self.last_sent_frame()) From 45e9a86e5dfdb772ce40a64863083b2664d2e44b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 Jul 2019 23:34:11 +0200 Subject: [PATCH 0616/1539] Fix references in changelog. --- docs/changelog.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 7a02ec0e7..2556d70cc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -98,7 +98,7 @@ Also: .. warning:: **Version 7.0 changes how a server terminates connections when it's - closed with** :meth:`~websockets.server.WebSocketServer.close` **.** + closed with** :meth:`~server.WebSocketServer.close` **.** Previously, connections handlers were canceled. Now, connections are closed with close code 1001 (going away). From the perspective of the @@ -223,7 +223,7 @@ Also: **Version 5.0 adds a** ``user_info`` **field to the return value of** :func:`~uri.parse_uri` **and** :class:`~uri.WebSocketURI` **.** - If you're unpacking :class:`~websockets.WebSocketURI` into four variables, + If you're unpacking :class:`~exceptions.WebSocketURI` into four variables, adjust your code to account for that fifth field. Also: From b0f6a3ec70fdfe87be8fbb877ca9bc46dc877c39 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 3 Jul 2019 23:33:22 +0200 Subject: [PATCH 0617/1539] Close connections properly in WebSocketServer.close. Thanks @lburg for the first iteration of this patch. Fix #541. --- docs/changelog.rst | 3 +++ src/websockets/protocol.py | 1 - src/websockets/server.py | 30 +++++++++++++----------------- tests/test_client_server.py | 4 +++- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 2556d70cc..b99d3d058 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -68,6 +68,9 @@ Also: If you were using ``ping_timeout=None`` as a workaround, you can remove it. +* Changed :meth:`~server.WebSocketServer.close` to perform a proper closing + handshake instead of failing the connection. + * Avoided a crash of a ``extra_headers`` callable returns ``None``. * Improved error messages when HTTP parsing fails. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index fdadb9398..acc45e87b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -635,7 +635,6 @@ async def close(self, code: int = 1000, reason: str = "") -> None: try: # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses. - # This helps closing connections when shutting down a server. await asyncio.wait_for( self.transfer_data_task, self.close_timeout, loop=self.loop ) diff --git a/src/websockets/server.py b/src/websockets/server.py index 8e1db9b7c..42487480a 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -46,7 +46,7 @@ parse_subprotocol, ) from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request -from .protocol import State, WebSocketCommonProtocol +from .protocol import WebSocketCommonProtocol from .typing import Origin, Subprotocol @@ -692,26 +692,22 @@ async def _close(self) -> None: # register(). See https://bugs.python.org/issue34852 for details. await asyncio.sleep(0) - # Close open connections. fail_connection() will cancel the transfer - # data task, which is expected to cause the handler task to terminate. - for websocket in self.websockets: - if websocket.state is State.OPEN: - websocket.fail_connection(1001) + # Close OPEN connections with status code 1001. Since the server was + # closed, handshake() closes OPENING conections with a HTTP 503 error. + # Wait until all connections are closed. + + # asyncio.wait doesn't accept an empty first argument + if self.websockets: + await asyncio.wait( + [websocket.close(1001) for websocket in self.websockets], loop=self.loop + ) + + # Wait until all connection handlers are complete. # asyncio.wait doesn't accept an empty first argument. if self.websockets: - # The connection handler can terminate before or after the - # connection closes. Wait until both are done to avoid leaking - # running tasks. - # TODO: it would be nicer to wait only for the connection handler - # and let the handler wait for the connection to close. await asyncio.wait( - [websocket.handler_task for websocket in self.websockets] - + [ - websocket.close_connection_task - for websocket in self.websockets - if websocket.state is State.OPEN - ], + [websocket.handler_task for websocket in self.websockets], loop=self.loop, ) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 7281ec6bd..aa4bebdc2 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1172,12 +1172,14 @@ def test_server_shuts_down_during_opening_handshake(self): @with_server() def test_server_shuts_down_during_connection_handling(self): with self.temp_client(): + server_ws = next(iter(self.server.websockets)) self.server.close() with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.client.recv()) - # Websocket connection terminates with 1001 Going Away. + # Websocket connection closes properly with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) + self.assertEqual(server_ws.close_code, 1001) @with_server() @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") From fc245f269e7108392a4437b1d1b02ed5b99dd9fa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 Jun 2019 22:30:59 +0200 Subject: [PATCH 0618/1539] Improve API docs. * Add info fields for parameters and exceptions. * Rewrite significant parts for clarity. * Make minor consistency fixes. Refs #567. --- docs/api.rst | 7 +- docs/changelog.rst | 2 + docs/spelling_wordlist.txt | 4 + src/websockets/auth.py | 33 ++- src/websockets/client.py | 88 +++--- src/websockets/extensions/base.py | 39 ++- .../extensions/permessage_deflate.py | 50 ++-- src/websockets/framing.py | 105 ++++--- src/websockets/handshake.py | 51 ++-- src/websockets/headers.py | 91 ++++--- src/websockets/http.py | 74 +++-- src/websockets/protocol.py | 156 ++++++----- src/websockets/server.py | 256 ++++++++++-------- src/websockets/typing.py | 19 +- src/websockets/uri.py | 30 +- src/websockets/utils.py | 5 +- tests/test_auth.py | 6 +- 17 files changed, 532 insertions(+), 484 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 56372eb11..28f41cc40 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -13,12 +13,11 @@ of low-level APIs reflecting the two phases of the WebSocket protocol: 2. Data transfer, as framed messages, ending with a closing handshake. The first phase is designed to integrate with existing HTTP software. -``websockets`` provides functions to build and validate the request and -response headers. +``websockets`` provides a minimal implementation to build, parse and validate +HTTP requests and responses. The second phase is the core of the WebSocket protocol. ``websockets`` -provides a standalone implementation on top of ``asyncio`` with a very simple -API. +provides a complete implementation on top of ``asyncio`` with a simple API. For convenience, public APIs can be imported directly from the :mod:`websockets` package, unless noted otherwise. Anything that isn't listed diff --git a/docs/changelog.rst b/docs/changelog.rst index b99d3d058..aa4a76259 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -85,6 +85,8 @@ Also: * Documented how to optimize memory usage. +* Improved API documentation. + 7.0 ... diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index c2988ead5..1eacc491d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -9,6 +9,8 @@ Bitcoin bufferbloat Bufferbloat bugfix +bytestring +bytestrings changelog cryptocurrency daemonize @@ -22,6 +24,7 @@ MiB nginx permessage pong +pongs Pythonic serializers subclassing @@ -30,6 +33,7 @@ subprotocols TLS Unparse uple +username websocket WebSocket websockets diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 9cb673132..ae204b8d9 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,6 +1,6 @@ """ -The :mod:`websockets.auth` module implements HTTP Basic Authentication as -specified in :rfc:`7235` and :rfc:`7617`. +:mod:`websockets.auth` provides HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. """ @@ -108,27 +108,32 @@ def basic_auth_protocol_factory( ) ) - ``realm`` indicates the scope of protection. It should be an ASCII-only - :class:`str` because the encoding of non-ASCII characters is undefined. + ``realm`` indicates the scope of protection. It should contain only ASCII + characters because the encoding of non-ASCII characters is undefined. Refer to section 2.2 of :rfc:`7235` for details. - One of ``credentials`` or ``check_credentials`` must be provided but not - both. - - ``credentials`` defines hardcoded authorized credentials. It can be a + ``credentials`` defines hard coded authorized credentials. It can be a ``(username, password)`` pair or a list of such pairs. ``check_credentials`` defines a coroutine that checks whether credentials are authorized. This coroutine receives ``username`` and ``password`` arguments and returns a :class:`bool`. - By default, ``basic_auth_protocol_factory`` creates instances of - :class:`BasicAuthWebSocketServerProtocol`. You can override this with the - ``create_protocol`` parameter. + One of ``credentials`` or ``check_credentials`` must be provided but not + both. + + By default, ``basic_auth_protocol_factory`` creates a factory for building + :class:`BasicAuthWebSocketServerProtocol` instances. You can override this + with the ``create_protocol`` parameter. + + :param realm: scope of protection + :param credentials: hard coded credentials + :param check_credentials: coroutine that verifies credentials + :raises TypeError: if the credentials argument has the wrong type """ if (credentials is None) == (check_credentials is None): - raise ValueError("provide either credentials or check_credentials") + raise TypeError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): @@ -145,10 +150,10 @@ async def check_credentials(username: str, password: str) -> bool: return credentials_dict.get(username) == password else: - raise ValueError(f"invalid credentials argument: {credentials}") + raise TypeError(f"invalid credentials argument: {credentials}") else: - raise ValueError(f"invalid credentials argument: {credentials}") + raise TypeError(f"invalid credentials argument: {credentials}") return functools.partial( create_protocol, realm=realm, check_credentials=check_credentials diff --git a/src/websockets/client.py b/src/websockets/client.py index 10435c1ff..89a624511 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,5 +1,5 @@ """ -The :mod:`websockets.client` module defines a simple WebSocket client API. +:mod:`websockets.client` defines the WebSocket client APIs. """ @@ -44,7 +44,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ - Complete WebSocket client implementation as an :class:`asyncio.Protocol`. + :class:`~asyncio.Protocol` subclass implementing a WebSocket client. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -91,12 +91,11 @@ async def read_http_response(self) -> Tuple[int, Headers]: """ Read status line and headers from the HTTP response. - Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message - is malformed or isn't an HTTP/1.1 GET request. + If the response contains a body, it may be read from ``self.reader`` + after this coroutine returns. - Don't attempt to read the response body because WebSocket handshake - responses don't have one. If the response contains a body, it may be - read from ``self.reader`` after this coroutine returns. + :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is + malformed or isn't an HTTP/1.1 GET response """ try: @@ -234,21 +233,17 @@ async def handshake( """ Perform the client side of the opening handshake. - If provided, ``origin`` sets the Origin HTTP header. - - If provided, ``available_extensions`` is a list of supported - extensions in the order in which they should be used. - - If provided, ``available_subprotocols`` is a list of supported - subprotocols in order of decreasing preference. - - If provided, ``extra_headers`` sets additional HTTP request headers. - It must be a :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` - pairs. - - Raise :exc:`~websockets.exceptions.InvalidHandshake` if the handshake - fails. + :param origin: sets the Origin HTTP header + :param available_extensions: list of supported extensions in the order + in which they should be used + :param available_subprotocols: list of supported subprotocols in order + of decreasing preference + :param extra_headers: sets additional HTTP request headers; it must be + a :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, or an iterable of ``(name, + value)`` pairs + :raises ~websockets.exceptions.InvalidHandshake: if the handshake + fails """ request_headers = Headers() @@ -318,16 +313,15 @@ class Connect: """ Connect to the WebSocket server at the given ``uri``. - :func:`connect` returns an awaitable. Awaiting it yields an instance of - :class:`WebSocketClientProtocol` which can then be used to send and - receive messages. + Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which + can then be used to send and receive messages. :func:`connect` can also be used as a asynchronous context manager. In that case, the connection is closed when exiting the context. :func:`connect` is a wrapper around the event loop's - :meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword - arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`. + :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments + are passed to :meth:`~asyncio.loop.create_connection`. For example, you can set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to @@ -336,20 +330,21 @@ class Connect: You can connect to a different host and port from those found in ``uri`` by setting ``host`` and ``port`` keyword arguments. This only changes the - destination of the TCP connection; the hostname from ``uri`` is still used - in the TLS handshake for secure connections and in the ``Host`` header. - - The behavior of the ``ping_interval``, ``ping_timeout``, ``close_timeout``, - ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` optional - arguments is described in the documentation of - :class:`~websockets.protocol.WebSocketCommonProtocol`. - - The ``create_protocol`` parameter allows customizing the asyncio protocol - that manages the connection. It should be a callable or class accepting - the same arguments as :class:`WebSocketClientProtocol` and returning a - :class:`WebSocketClientProtocol` instance. It defaults to + destination of the TCP connection. The host name from ``uri`` is still + used in the TLS handshake for secure connections and in the ``Host`` HTTP + header. + + The ``create_protocol`` parameter allows customizing the + :class:`~asyncio.Protocol` that manages the connection. It should be a + callable or class accepting the same arguments as + :class:`WebSocketClientProtocol` and returning an instance of + :class:`WebSocketClientProtocol` or a subclass. It defaults to :class:`WebSocketClientProtocol`. + The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is + described in :class:`~websockets.protocol.WebSocketCommonProtocol`. + :func:`connect` also accepts the following optional arguments: * ``compression`` is a shortcut to configure compression extensions; @@ -360,14 +355,14 @@ class Connect: decreasing preference * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference - * ``extra_headers`` sets additional HTTP request headers – it can be a + * ``extra_headers`` sets additional HTTP request headers; it can be a :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` pairs - :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is - invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening - handshake fails. + :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid + :raises ~websockets.handshake.InvalidHandshake: if the opening handshake + fails """ @@ -577,14 +572,15 @@ def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Conn """ Similar to :func:`connect`, but for connecting to a Unix socket. - ``path`` is the path to the Unix socket. ``uri`` is the WebSocket URI. - This function calls the event loop's - :meth:`~asyncio.AbstractEventLoop.create_unix_connection` method. + :meth:`~asyncio.loop.create_unix_connection` method. It is only available on Unix. It's mainly useful for debugging servers listening on Unix sockets. + :param path: file system path to the Unix socket + :param uri: WebSocket URI + """ return connect(uri=uri, path=path, **kwargs) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index ed847c6bc..7d46687c6 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,8 @@ """ -The :mod:`websockets.extensions.base` module defines abstract classes for -implementing extensions as specified in `section 9 of RFC 6455`_. +:mod:`websockets.extensions.base` defines abstract classes for implementing +extensions. + +See `section 9 of RFC 6455`_. .. _section 9 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-9 @@ -32,8 +34,8 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: """ Decode an incoming frame. - The ``frame`` parameter and the return value are - :class:`~websockets.framing.Frame` instances. + :param frame: incoming frame + :param max_size: maximum payload size in bytes """ @@ -41,8 +43,7 @@ def encode(self, frame: Frame) -> Frame: """ Encode an outgoing frame. - The ``frame`` parameter and the return value are - :class:`~websockets.framing.Frame` instances. + :param frame: outgoing frame """ @@ -64,7 +65,7 @@ def get_request_params(self) -> List[ExtensionParameter]: """ Build request parameters. - Return a list of (name, value) pairs. + Return a list of ``(name, value)`` pairs. """ @@ -76,14 +77,10 @@ def process_response_params( """ Process response parameters received from the server. - ``params`` is a list of (name, value) pairs. - - ``accepted_extensions`` is a list of previously accepted extensions. - - If parameters are acceptable, return an extension: an instance of a - subclass of :class:`Extension`. - - If they aren't, raise :exc:`~websockets.exceptions.NegotiationError`. + :param params: list of ``(name, value)`` pairs. + :param accepted_extensions: list of previously accepted extensions. + :raises ~websockets.exceptions.NegotiationError: if parameters aren't + acceptable """ @@ -109,16 +106,14 @@ def process_request_params( """ Process request parameters received from the client. - ``params`` is a list of (name, value) pairs. - - ``accepted_extensions`` is a list of previously accepted extensions. - To accept the offer, return a 2-uple containing: - - response parameters: a list of (name, value) pairs + - response parameters: a list of ``(name, value)`` pairs - an extension: an instance of a subclass of :class:`Extension` - To reject the offer, raise - :exc:`~websockets.exceptions.NegotiationError`. + :param params: list of ``(name, value)`` pairs. + :param accepted_extensions: list of previously accepted extensions. + :raises ~websockets.exceptions.NegotiationError: to reject the offer, + if parameters aren't acceptable """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index bd4b3fa53..a41fd56ca 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -1,6 +1,6 @@ """ -The :mod:`websockets.extensions.permessage_deflate` module implements the -Compression Extensions for WebSocket as specified in :rfc:`7692`. +:mod:`websockets.extensions.permessage_deflate` implements the Compression +Extensions for WebSocket as specified in :rfc:`7692`. """ @@ -257,22 +257,20 @@ def _extract_parameters( class ClientPerMessageDeflateFactory(ClientExtensionFactory): """ - Client-side extension factory for Per-Message Deflate extension. + Client-side extension factory for the Per-Message Deflate extension. - These parameters behave as described in `section 7.1 of RFC 7692`_: - - - ``server_no_context_takeover`` - - ``client_no_context_takeover`` - - ``server_max_window_bits`` - - ``client_max_window_bits`` - - Set them to ``True`` to include them in the negotiation offer without a - value or to an integer value to include them with this value. + Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to + ``True`` to include them in the negotiation offer without a value or to an + integer value to include them with this value. .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1 - ``compress_settings`` is an optional :class:`dict` of keyword arguments - for :func:`zlib.compressobj`, excluding ``wbits``. + :param server_no_context_takeover: defaults to ``False`` + :param client_no_context_takeover: defaults to ``False`` + :param server_max_window_bits: optional, defaults to ``None`` + :param client_max_window_bits: optional, defaults to ``None`` + :param compress_settings: optional, keyword arguments for + :func:`zlib.compressobj`, excluding ``wbits`` """ @@ -284,7 +282,7 @@ def __init__( client_no_context_takeover: bool = False, server_max_window_bits: Optional[int] = None, client_max_window_bits: Optional[Union[int, bool]] = None, - compress_settings: Optional[Dict[Any, Any]] = None, + compress_settings: Optional[Dict[str, Any]] = None, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -432,20 +430,18 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): """ Server-side extension factory for the Per-Message Deflate extension. - These parameters behave as described in `section 7.1 of RFC 7692`_: - - - ``server_no_context_takeover`` - - ``client_no_context_takeover`` - - ``server_max_window_bits`` - - ``client_max_window_bits`` - - Set them to ``True`` to include them in the negotiation offer without a - value or to an integer value to include them with this value. + Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to + ``True`` to include them in the negotiation offer without a value or to an + integer value to include them with this value. .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1 - ``compress_settings`` is an optional :class:`dict` of keyword arguments - for :func:`zlib.compressobj`, excluding ``wbits``. + :param server_no_context_takeover: defaults to ``False`` + :param client_no_context_takeover: defaults to ``False`` + :param server_max_window_bits: optional, defaults to ``None`` + :param client_max_window_bits: optional, defaults to ``None`` + :param compress_settings: optional, keyword arguments for + :func:`zlib.compressobj`, excluding ``wbits`` """ @@ -457,7 +453,7 @@ def __init__( client_no_context_takeover: bool = False, server_max_window_bits: Optional[int] = None, client_max_window_bits: Optional[int] = None, - compress_settings: Optional[Dict[Any, Any]] = None, + compress_settings: Optional[Dict[str, Any]] = None, ) -> None: """ Configure the Per-Message Deflate extension factory. diff --git a/src/websockets/framing.py b/src/websockets/framing.py index d668e0c52..ec87665ef 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -1,10 +1,11 @@ """ -The :mod:`websockets.framing` module implements data framing as specified in -`section 5 of RFC 6455`_. +:mod:`websockets.framing` reads and writes WebSocket frames. It deals with a single frame at a time. Anything that depends on the sequence of frames is implemented in :mod:`websockets.protocol`. +See `section 5 of RFC 6455`_. + .. _section 5 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-5 """ @@ -67,16 +68,15 @@ class Frame(FrameData): """ WebSocket frame. - * ``fin`` is the FIN bit - * ``rsv1`` is the RSV1 bit - * ``rsv2`` is the RSV2 bit - * ``rsv3`` is the RSV3 bit - * ``opcode`` is the opcode - * ``data`` is the payload data + :param bool fin: FIN bit + :param bool rsv1: RSV1 bit + :param bool rsv2: RSV2 bit + :param bool rsv3: RSV3 bit + :param int opcode: opcode + :param bytes data: payload data - Only these fields are needed by higher level code. The MASK bit, payload - length and masking-key are handled on the fly by :meth:`read` and - :meth:`write`. + Only these fields are needed. The MASK bit, payload length and masking-key + are handled on the fly by :meth:`read` and :meth:`write`. """ @@ -101,24 +101,20 @@ async def read( extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, ) -> "Frame": """ - Read a WebSocket frame and return a :class:`Frame` object. - - ``reader`` is a coroutine taking an integer argument and reading - exactly this number of bytes, unless the end of file is reached. - - ``mask`` is a :class:`bool` telling whether the frame should be masked - i.e. whether the read happens on the server side. - - If ``max_size`` is set and the payload exceeds this size in bytes, - :exc:`~websockets.exceptions.PayloadTooBig` is raised. - - If ``extensions`` is provided, it's a list of classes with an - ``decode()`` method that transform the frame and return a new frame. - They are applied in reverse order. - - This function validates the frame before returning it and raises - :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains - incorrect values. + Read a WebSocket frame. + + :param reader: coroutine that reads exactly the requested number of + bytes, unless the end of file is reached + :param mask: whether the frame should be masked i.e. whether the read + happens on the server side + :param max_size: maximum payload size in bytes + :param extensions: list of classes with a ``decode()`` method that + transforms the frame and return a new frame; extensions are applied + in reverse order + :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds + ``max_size`` + :raises ~websockets.exceptions.WebSocketProtocolError: if the frame + contains incorrect values """ # Read the header. @@ -175,20 +171,15 @@ def write( """ Write a WebSocket frame. - ``frame`` is the :class:`Frame` object to write. - - ``writer`` is a function accepting bytes. - - ``mask`` is a :class:`bool` telling whether the frame should be masked - i.e. whether the write happens on the client side. - - If ``extensions`` is provided, it's a list of classes with an - ``encode()`` method that transform the frame and return a new frame. - They are applied in order. - - This function validates the frame before sending it and raises - :exc:`~websockets.exceptions.WebSocketProtocolError` if it contains - incorrect values. + :param frame: frame to write + :param writer: function that writes bytes + :param mask: whether the frame should be masked i.e. whether the write + happens on the client side + :param extensions: list of classes with an ``encode()`` method that + transform the frame and return a new frame; extensions are applied + in order + :raises ~websockets.exceptions.WebSocketProtocolError: if the frame + contains incorrect values """ # The first parameter is called `frame` rather than `self`, @@ -242,10 +233,10 @@ def write( def check(frame) -> None: """ - Check that this frame contains acceptable values. + Check that reserved bits and opcode have acceptable values. - Raise :exc:`~websockets.exceptions.WebSocketProtocolError` if this - frame contains incorrect values. + :raises ~websockets.exceptions.WebSocketProtocolError: if a reserved + bit or the opcode is invalid """ # The first parameter is called `frame` rather than `self`, @@ -277,7 +268,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like object. - Raise :exc:`TypeError` for other inputs. + :raises TypeError: if ``data`` doesn't have a supported type """ if isinstance(data, str): @@ -297,14 +288,14 @@ def encode_data(data: Data) -> bytes: """ Convert a string or byte-like object to bytes. - This function is designed for ping and pon g frames. + This function is designed for ping and pong frames. If ``data`` is a :class:`str`, return a :class:`bytes` object encoding ``data`` in UTF-8. If ``data`` is a bytes-like object, return a :class:`bytes` object. - Raise :exc:`TypeError` for other inputs. + :raises TypeError: if ``data`` doesn't have a supported type """ if isinstance(data, str): @@ -319,13 +310,12 @@ def encode_data(data: Data) -> bytes: def parse_close(data: bytes) -> Tuple[int, str]: """ - Parse the data in a close frame. + Parse the payload from a close frame. - Return ``(code, reason)`` when ``code`` is an :class:`int` and ``reason`` - a :class:`str`. + Return ``(code, reason)``. - Raise :exc:`~websockets.exceptions.WebSocketProtocolError` or - :exc:`UnicodeDecodeError` if the data is invalid. + :raises ~websockets.exceptions.WebSocketProtocolError: if data is ill-formed + :raises UnicodeDecodeError: if the reason isn't valid UTF-8 """ length = len(data) @@ -343,7 +333,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: def serialize_close(code: int, reason: str) -> bytes: """ - Serialize the data for a close frame. + Serialize the payload for a close frame. This is the reverse of :func:`parse_close`. @@ -354,7 +344,10 @@ def serialize_close(code: int, reason: str) -> bytes: def check_close(code: int) -> None: """ - Check the close code for a close frame. + Check that the close code has an acceptable value for a close frame. + + :raises ~websockets.exceptions.WebSocketProtocolError: if the close code + is invalid """ if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index f04d81d59..17332d155 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -1,15 +1,9 @@ """ -The :mod:`websockets.handshake` module deals with the WebSocket opening -handshake according to `section 4 of RFC 6455`_. +:mod:`websockets.handshake` provides helpers for the WebSocket handshake. -.. _section 4 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 - -Functions defined in this module manipulate HTTP headers. The ``headers`` -argument must implement ``get`` and ``__setitem__`` and ``get`` — a small -subset of the :class:`~collections.abc.MutableMapping` abstract base class. +See `section 4 of RFC 6455`_. -Headers names and values are :class:`str` objects containing only ASCII -characters. +.. _section 4 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 Some checks cannot be performed because they depend too much on the context; instead, they're documented below. @@ -50,7 +44,10 @@ def build_request(headers: Headers) -> str: """ Build a handshake request to send to the server. - Return the ``key`` which must be passed to :func:`check_response`. + Update request headers passed in argument. + + :param headers: request headers + :returns: ``key`` which must be passed to :func:`check_response` """ raw_key = bytes(random.getrandbits(8) for _ in range(16)) @@ -66,16 +63,15 @@ def check_request(headers: Headers) -> str: """ Check a handshake request received from the client. - If the handshake is valid, this function returns the ``key`` which must be - passed to :func:`build_response`. - - Otherwise it raises an :exc:`~websockets.exceptions.InvalidHandshake` - exception and the server must return an error like 400 Bad Request. - This function doesn't verify that the request is an HTTP/1.1 or higher GET - request and doesn't perform Host and Origin checks. These controls are - usually performed earlier in the HTTP request handling code. They're the - responsibility of the caller. + request and doesn't perform ``Host`` and ``Origin`` checks. These controls + are usually performed earlier in the HTTP request handling code. They're + the responsibility of the caller. + + :param headers: request headers + :returns: ``key`` which must be passed to :func:`build_response` + :raises ~websockets.exceptions.InvalidHandshake: if the handshake request + is invalid; then the server must return 400 Bad Request error """ connection = sum( @@ -127,7 +123,10 @@ def build_response(headers: Headers, key: str) -> None: """ Build a handshake response to send to the client. - ``key`` comes from :func:`check_request`. + Update response headers passed in argument. + + :param headers: response headers + :param key: comes from :func:`check_request` """ headers["Upgrade"] = "websocket" @@ -139,17 +138,15 @@ def check_response(headers: Headers, key: str) -> None: """ Check a handshake response received from the server. - ``key`` comes from :func:`build_request`. - - If the handshake is valid, this function returns ``None``. - - Otherwise it raises an :exc:`~websockets.exceptions.InvalidHandshake` - exception. - This function doesn't verify that the response is an HTTP/1.1 or higher response with a 101 status code. These controls are the responsibility of the caller. + :param headers: response headers + :param key: comes from :func:`build_request` + :raises ~websockets.exceptions.InvalidHandshake: if the handshake response + is invalid + """ connection = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 536cab592..ac850654e 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -1,13 +1,12 @@ """ -The :mod:`websockets.headers` module provides parsers and serializers for HTTP -headers used in WebSocket handshake messages. +:mod:`websockets.headers` provides parsers and serializers for HTTP headers +used in WebSocket handshake messages. -Its functions cannot be imported from :mod:`websockets`. They must be imported +These APIs cannot be imported from :mod:`websockets`. They must be imported from :mod:`websockets.headers`. """ - import base64 import binascii import re @@ -80,7 +79,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ match = _token_re.match(header, pos) @@ -103,7 +102,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i Return the unquoted value and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ match = _quoted_string_re.match(header, pos) @@ -153,7 +152,7 @@ def parse_list( Return a list of items. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST @@ -204,7 +203,7 @@ def parse_connection_option( Return the protocol value and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -215,9 +214,10 @@ def parse_connection(header: str) -> List[ConnectionOption]: """ Parse a ``Connection`` header. - Return a list of connection options. + Return a list of HTTP connection options. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :param header: value of the ``Connection`` header + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_connection_option, header, 0, "Connection") @@ -236,7 +236,7 @@ def parse_upgrade_protocol( Return the protocol value and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ match = _protocol_re.match(header, pos) @@ -249,9 +249,10 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: """ Parse an ``Upgrade`` header. - Return a list of protocols. + Return a list of HTTP protocols. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :param header: value of the ``Upgrade`` header + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") @@ -265,7 +266,7 @@ def parse_extension_item_param( Return a ``(name, value)`` pair and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ # Extract parameter name. @@ -300,7 +301,7 @@ def parse_extension_item( Return an ``(extension name, parameters)`` pair, where ``parameters`` is a list of ``(name, value)`` pairs, and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ # Extract extension name. @@ -319,7 +320,7 @@ def parse_extension(header: str) -> List[ExtensionHeader]: """ Parse a ``Sec-WebSocket-Extensions`` header. - Return a value with the following format:: + Return a list of WebSocket extensions and their parameters in this format:: [ ( @@ -334,13 +335,13 @@ def parse_extension(header: str) -> List[ExtensionHeader]: Parameter values are ``None`` when no value is provided. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") -parse_extension_list = parse_extension # alias for backwards-compatibility +parse_extension_list = parse_extension # alias for backwards compatibility def build_extension_item(name: str, parameters: List[ExtensionParameter]) -> str: @@ -362,7 +363,7 @@ def build_extension_item(name: str, parameters: List[ExtensionParameter]) -> str def build_extension(extensions: Sequence[ExtensionHeader]) -> str: """ - Unparse a ``Sec-WebSocket-Extensions`` header. + Build a ``Sec-WebSocket-Extensions`` header. This is the reverse of :func:`parse_extension`. @@ -372,7 +373,7 @@ def build_extension(extensions: Sequence[ExtensionHeader]) -> str: ) -build_extension_list = build_extension # alias for backwards-compatibility +build_extension_list = build_extension # alias for backwards compatibility def parse_subprotocol_item( @@ -383,7 +384,7 @@ def parse_subprotocol_item( Return the subprotocol value and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -394,18 +395,20 @@ def parse_subprotocol(header: str) -> List[Subprotocol]: """ Parse a ``Sec-WebSocket-Protocol`` header. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + Return a list of WebSocket subprotocols. + + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") -parse_subprotocol_list = parse_subprotocol # alias for backwards-compatibility +parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility def build_subprotocol(protocols: Sequence[Subprotocol]) -> str: """ - Unparse a ``Sec-WebSocket-Protocol`` header. + Build a ``Sec-WebSocket-Protocol`` header. This is the reverse of :func:`parse_subprotocol`. @@ -413,12 +416,14 @@ def build_subprotocol(protocols: Sequence[Subprotocol]) -> str: return ", ".join(protocols) -build_subprotocol_list = build_subprotocol # alias for backwards-compatibility +build_subprotocol_list = build_subprotocol # alias for backwards compatibility def build_www_authenticate_basic(realm: str) -> str: """ - Build an WWW-Authenticate header for HTTP Basic Auth. + Build a ``WWW-Authenticate`` header for HTTP Basic Auth. + + :param realm: authentication realm """ # https://tools.ietf.org/html/rfc7617#section-2 @@ -427,18 +432,6 @@ def build_www_authenticate_basic(realm: str) -> str: return f"Basic realm={realm}, charset={charset}" -def build_authorization_basic(username: str, password: str) -> str: - """ - Build an Authorization header for HTTP Basic Auth. - - """ - # https://tools.ietf.org/html/rfc7617#section-2 - assert ":" not in username - user_pass = f"{username}:{password}" - basic_credentials = base64.b64encode(user_pass.encode()).decode() - return "Basic " + basic_credentials - - _token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*") @@ -448,7 +441,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. - Raise :exc:`~websockets.exceptions.InvalidHeaderFormat` on invalid inputs. + :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ match = _token68_re.match(header, pos) @@ -468,10 +461,14 @@ def parse_end(header: str, pos: int, header_name: str) -> None: def parse_authorization_basic(header: str) -> Tuple[str, str]: """ - Parse an Authorization header for HTTP Basic Auth. + Parse an ``Authorization`` header for HTTP Basic Auth. Return a ``(username, password)`` tuple. + :param header: value of the ``Authorization`` header + :raises InvalidHeaderFormat: on invalid inputs + :raises InvalidHeaderValue: on unsupported inputs + """ # https://tools.ietf.org/html/rfc7235#section-2.1 # https://tools.ietf.org/html/rfc7617#section-2 @@ -500,3 +497,17 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: ) from None return username, password + + +def build_authorization_basic(username: str, password: str) -> str: + """ + Build an ``Authorization`` header for HTTP Basic Auth. + + This is the reverse of :func:`parse_authorization_basic`. + + """ + # https://tools.ietf.org/html/rfc7617#section-2 + assert ":" not in username + user_pass = f"{username}:{password}" + basic_credentials = base64.b64encode(user_pass.encode()).decode() + return "Basic " + basic_credentials diff --git a/src/websockets/http.py b/src/websockets/http.py index 46b09c2e6..e78a149ed 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,8 +1,8 @@ """ -The :mod:`websockets.http` module provides basic HTTP parsing and -serialization. It is merely adequate for WebSocket handshake messages. +:mod:`websockets.http` module provides basic HTTP/1.1 support. It is merely +:adequate for WebSocket handshake messages. -Its functions cannot be imported from :mod:`websockets`. They must be imported +These APIs cannot be imported from :mod:`websockets`. They must be imported from :mod:`websockets.http`. """ @@ -26,10 +26,10 @@ __all__ = [ - "Headers", - "MultipleValuesError", "read_request", "read_response", + "Headers", + "MultipleValuesError", "USER_AGENT", ] @@ -69,22 +69,21 @@ def d(value: bytes) -> str: async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]: """ - Read an HTTP/1.1 GET request from ``stream``. - - ``stream`` is an :class:`~asyncio.StreamReader`. - - Return ``(path, headers)`` where ``path`` is a :class:`str` and - ``headers`` is a :class:`Headers` instance. + Read an HTTP/1.1 GET request and returns ``(path, headers)``. ``path`` isn't URL-decoded or validated in any way. - Non-ASCII characters are represented with surrogate escapes. + ``path`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. - Raise an exception if the request isn't well formatted. + :func:`read_request` doesn't attempt to read the request body because + WebSocket handshake requests don't have one. If the request contains a + body, it may be read from ``stream`` after this coroutine returns. - Don't attempt to read the request body because WebSocket handshake - requests don't have one. If the request contains a body, it may be - read from ``stream`` after this coroutine returns. + :param stream: input to read the request from + :raises EOFError: if the connection is closed without a full HTTP request + :raises SecurityError: if the request exceeds a security limit + :raises ValueError: if the request isn't well formatted """ # https://tools.ietf.org/html/rfc7230#section-3.1.1 @@ -116,21 +115,19 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]: async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, "Headers"]: """ - Read an HTTP/1.1 response from ``stream``. + Read an HTTP/1.1 response and returns ``(status_code, reason, headers)``. - ``stream`` is an :class:`~asyncio.StreamReader`. + ``reason`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. - Return ``(status_code, reason, headers)`` where ``status_code`` is an - :class:`int`, ``reason`` is a :class:`str`, and ``headers`` is a - :class:`Headers` instance. + :func:`read_request` doesn't attempt to read the response body because + WebSocket handshake responses don't have one. If the response contains a + body, it may be read from ``stream`` after this coroutine returns. - Non-ASCII characters are represented with surrogate escapes. - - Raise an exception if the response isn't well formatted. - - Don't attempt to read the response body, because WebSocket handshake - responses don't have one. If the response contains a body, it may be - read from ``stream`` after this coroutine returns. + :param stream: input to read the response from + :raises EOFError: if the connection is closed without a full HTTP response + :raises SecurityError: if the response exceeds a security limit + :raises ValueError: if the response isn't well formatted """ # https://tools.ietf.org/html/rfc7230#section-3.1.2 @@ -169,10 +166,6 @@ async def read_headers(stream: asyncio.StreamReader) -> "Headers": """ Read HTTP headers from ``stream``. - ``stream`` is an :class:`~asyncio.StreamReader`. - - Return a :class:`Headers` instance - Non-ASCII characters are represented with surrogate escapes. """ @@ -213,9 +206,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: """ Read a single line from ``stream``. - ``stream`` is an :class:`~asyncio.StreamReader`. - - Return :class:`bytes` without CRLF. + CRLF is stripped from the return value. """ # Security: this is bounded by the StreamReader's limit (default = 32 KiB). @@ -244,7 +235,7 @@ def __str__(self) -> str: class Headers(MutableMapping[str, str]): """ - Data structure for working with HTTP headers efficiently. + Efficient data structure for manipulating HTTP headers. A :class:`list` of ``(name, values)`` is inefficient for lookups. @@ -273,9 +264,10 @@ class Headers(MutableMapping[str, str]): As long as no header occurs multiple times, :class:`Headers` behaves like :class:`dict`, except keys are lower-cased to provide case-insensitivity. - :meth:`get_all()` returns a list of all values for a header and - :meth:`raw_items()` returns an iterator of ``(name, values)`` pairs, - similar to :meth:`http.client.HTTPMessage`. + Two methods support support manipulating multiple values explicitly: + + - :meth:`get_all` returns a list of all values for a header; + - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs. """ @@ -348,12 +340,14 @@ def get_all(self, key: str) -> List[str]: """ Return the (possibly empty) list of all values for a header. + :param key: header name + """ return self._dict.get(key.lower(), []) def raw_items(self) -> Iterator[Tuple[str, str]]: """ - Return an iterator of (header name, header value). + Return an iterator of all values as ``(name, value)`` pairs. """ return iter(self._list) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index acc45e87b..fa369450b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -1,6 +1,7 @@ """ -The :mod:`websockets.protocol` module handles WebSocket control and data -frames as specified in `sections 4 to 8 of RFC 6455`_. +:mod:`websockets.protocol` handles WebSocket control and data frames. + +See `sections 4 to 8 of RFC 6455`_. .. _sections 4 to 8 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 @@ -62,16 +63,24 @@ class State(enum.IntEnum): class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): """ - This class implements common parts of the WebSocket protocol. + :class:`~asyncio.Protocol` subclass implementing the data transfer phase. + + Once the WebSocket connection is established, during the data transfer + phase, the protocol is almost symmetrical between the server side and the + client side. :class:`WebSocketCommonProtocol` implements logic that's + shared between servers and clients.. + + Subclasses such as :class:`~websockets.server.WebSocketServerProtocol` and + :class:`~websockets.client.WebSocketClientProtocol` implement the opening + handshake, which is different between servers and clients. - It assumes that the WebSocket connection is established. The handshake is - managed in subclasses such as - :class:`~websockets.server.WebSocketServerProtocol` and - :class:`~websockets.client.WebSocketClientProtocol`. + :class:`WebSocketCommonProtocol` performs four functions: - It runs a task that stores incoming data frames in a queue and deals with - control frames automatically. It sends outgoing data frames and performs - the closing handshake. + * It runs a task that stores incoming data frames in a queue and makes + them available with the :meth:`recv` coroutine. + * It sends outgoing data frames with the :meth:`send` coroutine. + * It deals with control frames automatically. + * It performs the closing handshake. :class:`WebSocketCommonProtocol` supports asynchronous iteration:: @@ -81,20 +90,23 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The iterator yields incoming messages. It exits normally when the connection is closed with the close code 1000 (OK) or 1001 (going away). It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other status code. + when the connection is closed with any other code. - The ``host``, ``port`` and ``secure`` parameters are simply stored as - attributes for handlers that need them. + When initializing a :class:`WebSocketCommonProtocol`, the ``host``, + ``port``, and ``secure`` parameters are stored as attributes for backwards + compatibility. Consider using :attr:`local_address` on the server side and + :attr:`remote_address` on the client side instead. Once the connection is open, a `Ping frame`_ is sent every ``ping_interval`` seconds. This serves as a keepalive. It helps keeping the connection open, especially in the presence of proxies with short - timeouts. Set ``ping_interval`` to ``None`` to disable this behavior. + timeouts on inactive connections. Set ``ping_interval`` to ``None`` to + disable this behavior. .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with status + seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to ``None`` to disable this behavior. @@ -102,11 +114,11 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``close_timeout`` parameter defines a maximum wait time in seconds for completing the closing handshake and terminating the TCP connection. - :meth:`close()` completes in at most ``4 * close_timeout`` on the server + :meth:`close` completes in at most ``4 * close_timeout`` on the server side and ``5 * close_timeout`` on the client side. ``close_timeout`` needs to be a parameter of the protocol because - websockets usually calls :meth:`close()` implicitly: + websockets usually calls :meth:`close` implicitly: - on the server side, when the connection handler terminates, - on the client side, when exiting the context manager for the connection. @@ -115,26 +127,26 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1 MiB. ``None`` disables the limit. If a - message larger than the maximum size is received, :meth:`recv()` will + message larger than the maximum size is received, :meth:`recv` will raise :exc:`~websockets.exceptions.ConnectionClosedError` and the - connection will be closed with status code 1009. + connection will be closed with code 1009. The ``max_queue`` parameter sets the maximum length of the queue that holds incoming messages. The default value is ``32``. ``None`` disables the limit. Messages are added to an in-memory queue when they're received; - then :meth:`recv()` pops from that queue. In order to prevent excessive + then :meth:`recv` pops from that queue. In order to prevent excessive memory consumption when messages are received faster than they can be processed, the queue must be bounded. If the queue fills up, the protocol - stops processing incoming data until :meth:`recv()` is called. In this + stops processing incoming data until :meth:`recv` is called. In this situation, various receive buffers (at least in ``asyncio`` and in the OS) will fill up, then the TCP receive window will shrink, slowing down transmission to avoid packet loss. Since Python can use up to 4 bytes of memory to represent a single - character, each websocket connection may use up to ``4 * max_size * - max_queue`` bytes of memory to store incoming messages. By default, - this is 128 MiB. You may want to lower the limits, depending on your - application's requirements. + character, each connection may use up to ``4 * max_size * max_queue`` + bytes of memory to store incoming messages. By default, this is 128 MiB. + You may want to lower the limits, depending on your application's + requirements. The ``read_limit`` argument sets the high-water limit of the buffer for incoming bytes. The low-water limit is half the high-water limit. The @@ -154,14 +166,14 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): :attr:`request_headers` and :attr:`response_headers` attributes, which are :class:`~websockets.http.Headers` instances. - These attributes must be treated as immutable. - If a subprotocol was negotiated, it's available in the :attr:`subprotocol` attribute. - Once the connection is closed, the status code is available in the + Once the connection is closed, the code is available in the :attr:`close_code` attribute and the reason in :attr:`close_reason`. + All these attributes must be treated as read-only. + """ # There are only two differences between the client-side and server-side @@ -187,7 +199,7 @@ def __init__( legacy_recv: bool = False, timeout: Optional[float] = None, ) -> None: - # Backwards-compatibility: close_timeout used to be called timeout. + # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: timeout = 10 else: @@ -229,7 +241,7 @@ def __init__( # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute - # :meth:`connection_open()` to change the state to OPEN. + # :meth:`connection_open` to change the state to OPEN. self.state = State.CONNECTING logger.debug("%s - state = CONNECTING", self.side) @@ -248,7 +260,7 @@ def __init__( self.close_reason: str # Completed when the connection state becomes CLOSED. Translates the - # :meth:`connection_lost()` callback to a :class:`~asyncio.Future` + # :meth:`connection_lost` callback to a :class:`~asyncio.Future` # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are # translated by ``self.stream_reader``). self.connection_lost_waiter: asyncio.Future[None] = loop.create_future() @@ -341,11 +353,13 @@ def remote_address(self) -> Any: @property def open(self) -> bool: """ - This property is ``True`` when the connection is usable. + ``True`` when the connection is usable. - It may be used to detect disconnections but this is discouraged per - the EAFP_ principle. When ``open`` is ``False``, using the connection - raises a :exc:`~websockets.exceptions.ConnectionClosed` exception. + It may be used to detect disconnections. However, this approach is + discouraged per the EAFP_ principle. + + When ``open`` is ``False``, using the connection raises a + :exc:`~websockets.exceptions.ConnectionClosed` exception. .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp @@ -355,7 +369,7 @@ def open(self) -> bool: @property def closed(self) -> bool: """ - This property is ``True`` once the connection is closed. + ``True`` once the connection is closed. Be aware that both :attr:`open` and :attr:`closed` are ``False`` during the opening and closing sequences. @@ -392,16 +406,16 @@ async def __aiter__(self) -> AsyncIterator[Data]: async def recv(self) -> Data: """ - This coroutine receives the next message. + Receive the next message. - It returns a :class:`str` for a text frame and :class:`bytes` for a - binary frame. + Return a :class:`str` for a text frame and :class:`bytes` for a binary + frame. When the end of the message stream is reached, :meth:`recv` raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal connection closure and - :exc:`~websockets.exceptions.ConnectionClosedError`after a protocol + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol error or a network failure. .. versionchanged:: 3.0 @@ -414,9 +428,9 @@ async def recv(self) -> Data: makes it possible to enforce a timeout by wrapping :meth:`recv` in :func:`~asyncio.wait_for`. - .. versionchanged:: 7.0 - - Calling :meth:`recv` concurrently raises :exc:`RuntimeError`. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises RuntimeError: if two coroutines call :meth:`recv` concurrently """ if self._pop_message_waiter is not None: @@ -473,19 +487,21 @@ async def send( self, message: Union[Data, Iterable[Data], AsyncIterable[Data]] ) -> None: """ - This coroutine sends a message. + Send a message. - It sends a string (:class:`str`) as a text frame and a bytes-like - object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) - as a binary frame. + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. - It also accepts an iterable or an asynchronous iterable of strings or - bytes-like objects. In that case the message is fragmented. Each item - is treated as a message fragment and sent in its own frame. All items - must be of the same type, or else :meth:`send` will raise a - :exc:`TypeError` and the connection will be closed. + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 - It raises a :exc:`TypeError` for other inputs. + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects. In that case the message + is fragmented. Each item is treated as a message fragment and sent in + its own frame. All items must be of the same type, or else + :meth:`send` will raise a :exc:`TypeError` and the connection will be + closed. Canceling :meth:`send` is discouraged. Instead, you should close the connection with :meth:`close`. Indeed, there only two situations where @@ -498,6 +514,8 @@ async def send( a fragmented message will cause a protocol error. Closing the connection has the same effect. + :raises TypeError: for unsupported inputs + """ await self.ensure_open() @@ -594,13 +612,11 @@ async def send( async def close(self, code: int = 1000, reason: str = "") -> None: """ - This coroutine performs the closing handshake. + Perform the closing handshake. - It waits for the other end to complete the handshake and for the TCP - connection to terminate. As a consequence, there's no need to await - :meth:`wait_closed`; :meth:`close` already does it. - - ``code`` must be an :class:`int` and ``reason`` a :class:`str`. + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. As a consequence, there's no need + to await :meth:`wait_closed`; :meth:`close` already does it. :meth:`close` is idempotent: it doesn't do anything once the connection is closed. @@ -612,6 +628,9 @@ async def close(self, code: int = 1000, reason: str = "") -> None: set a shorter ``close_timeout``. If you don't want to wait, let the Python process exit, then the OS will close the TCP connection. + :param code: WebSocket close code + :param reason: WebSocket close reason + """ try: await asyncio.wait_for( @@ -644,11 +663,11 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # Wait for the close connection task to close the TCP connection. await asyncio.shield(self.close_connection_task) - async def ping(self, data: Optional[bytes] = None) -> Awaitable[None]: + async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: """ - This coroutine sends a ping. + Send a ping. - It returns a :class:`~asyncio.Future` which will be completed when the + Return a :class:`~asyncio.Future` which will be completed when the corresponding pong is received and which you may ignore if you don't want to wait. @@ -658,7 +677,7 @@ async def ping(self, data: Optional[bytes] = None) -> Awaitable[None]: pong_waiter = await ws.ping() await pong_waiter # only if you want to wait for the pong - By default, the ping contains four random bytes. The content may be + By default, the ping contains four random bytes. This payload may be overridden with the optional ``data`` argument which must be a string (which will be encoded to UTF-8) or a bytes-like object. @@ -689,15 +708,14 @@ async def ping(self, data: Optional[bytes] = None) -> Awaitable[None]: return asyncio.shield(self.pings[data]) - async def pong(self, data: bytes = b"") -> None: + async def pong(self, data: Data = b"") -> None: """ - This coroutine sends a pong. + Send a pong. An unsolicited pong may serve as a unidirectional heartbeat. - The content may be overridden with the optional ``data`` argument - which must be a string (which will be encoded to UTF-8) or a - bytes-like object. + The payload may be set with the optional ``data`` argument which must + be a string (which will be encoded to UTF-8) or a bytes-like object. Canceling :meth:`pong` is discouraged for the same reason as :meth:`ping`. @@ -744,7 +762,7 @@ async def ensure_open(self) -> None: if self.state is State.CLOSING: # If we started the closing handshake, wait for its completion to - # get the proper close code and status. self.close_connection_task + # get the proper close code and reason. self.close_connection_task # will complete within 4 or 5 * close_timeout after close(). The # CLOSING state also occurs when failing the connection. In that # case self.close_connection_task will complete even faster. diff --git a/src/websockets/server.py b/src/websockets/server.py index 42487480a..446f1db7f 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -1,5 +1,5 @@ """ -The :mod:`websockets.server` module defines a simple WebSocket server API. +:mod:`websockets.server` defines the WebSocket server APIs. """ @@ -62,7 +62,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ - Complete WebSocket server implementation as an :class:`asyncio.Protocol`. + :class:`~asyncio.Protocol` subclass implementing a WebSocket server. This class inherits most of its methods from :class:`~websockets.protocol.WebSocketCommonProtocol`. @@ -92,7 +92,7 @@ def __init__( ] = None, **kwargs: Any, ) -> None: - # For backwards-compatibility with 6.0 or earlier. + # For backwards compatibility with 6.0 or earlier. if origins is not None and "" in origins: warnings.warn("use None instead of '' in origins", DeprecationWarning) origins = [None if origin == "" else origin for origin in origins] @@ -226,12 +226,11 @@ async def read_http_request(self) -> Tuple[str, Headers]: """ Read request line and headers from the HTTP request. - Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message - is malformed or isn't an HTTP/1.1 GET request. + If the request contains a body, it may be read from ``self.reader`` + after this coroutine returns. - Don't attempt to read the request body because WebSocket handshake - requests don't have one. If the request contains a body, it may be - read from ``self.reader`` after this coroutine returns. + :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is + malformed or isn't an HTTP/1.1 GET request """ try: @@ -269,7 +268,7 @@ def write_http_response( self.writer.write(response.encode()) if body is not None: - logger.debug("%s > Body (%d bytes)", self.side, len(body)) + logger.debug("%s > body (%d bytes)", self.side, len(body)) self.writer.write(body) async def process_request( @@ -278,12 +277,10 @@ async def process_request( """ Intercept the HTTP request and return an HTTP response if appropriate. - ``path`` is a :class:`str` and ``request_headers`` is a - :class:`~websockets.http.Headers` instance. - If ``process_request`` returns ``None``, the WebSocket handshake - continues. If it returns a status code, headers and a response body, - that HTTP response is sent and the connection is closed. In that case: + continues. If it returns 3-uple containing a status code, response + headers and a response body, that HTTP response is sent and the + connection is closed. In that case: * The HTTP status must be a :class:`~http.HTTPStatus`. * HTTP headers must be a :class:`~websockets.http.Headers` instance, a @@ -294,30 +291,32 @@ async def process_request( This coroutine may be overridden in a :class:`WebSocketServerProtocol` subclass, for example: - * to return a HTTP 200 :attr:`~http.HTTPStatus.OK` response on a given - path; then a load balancer can use this path for a health check; - * to authenticate the request and return a HTTP 401 - :attr:`~http.HTTPStatus.UNAUTHORIZED` or a HTTP 403 - :attr:`~http.HTTPStatus.FORBIDDEN` when authentication fails. + * to return a HTTP 200 OK response on a given path; then a load + balancer can use this path for a health check; + * to authenticate the request and return a HTTP 401 Unauthorized or a + HTTP 403 Forbidden when authentication fails. - Instead of subclassing, it is possible to pass a ``process_request`` - argument to the :class:`WebSocketServerProtocol` constructor or the - :func:`serve` function. This is equivalent, except the - ``process_request`` corountine doesn't have access to the protocol - instance, so it can't store information for later use. + Instead of subclassing, it is possible to override this method by + passing a ``process_request`` argument to the :func:`serve` function + or the :class:`WebSocketServerProtocol` constructor. This is + equivalent, except ``process_request`` won't have access to the + protocol instance, so it can't store information for later use. ``process_request`` is expected to complete quickly. If it may run for a long time, then it should await :meth:`wait_closed` and exit if :meth:`wait_closed` completes, or else it could prevent the server from shutting down. + :param path: request path, including optional query string + :param request_headers: request headers + """ if self._process_request is not None: response = self._process_request(path, request_headers) if isinstance(response, Awaitable): return await response else: - # For backwards-compatibility with 7.0. + # For backwards compatibility with 7.0. warnings.warn( "declare process_request as a coroutine", DeprecationWarning ) @@ -331,8 +330,10 @@ def process_origin( """ Handle the Origin HTTP request header. - Raise :exc:`~websockets.exceptions.InvalidOrigin` if the origin isn't - acceptable. + :param headers: request headers + :param origins: optional list of acceptable origins + :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't + acceptable """ # "The user agent MUST NOT include more than one Origin header field" @@ -360,10 +361,6 @@ def process_extensions( Return the Sec-WebSocket-Extensions HTTP response header and the list of accepted extensions. - Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the - handshake with an HTTP 400 error code. (The default implementation - never does this.) - :rfc:`6455` leaves the rules up to the specification of each :extension. @@ -382,6 +379,11 @@ def process_extensions( Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. + :param headers: request headers + :param extensions: optional list of supported extensions + :raises ~websockets.exceptions.InvalidHandshake: to abort the + handshake with an HTTP 400 error code + """ response_header_value: Optional[str] = None @@ -438,6 +440,11 @@ def process_subprotocol( Return Sec-WebSocket-Protocol HTTP response header, which is the same as the selected subprotocol. + :param headers: request headers + :param available_subprotocols: optional list of supported subprotocols + :raises ~websockets.exceptions.InvalidHandshake: to abort the + handshake with an HTTP 400 error code + """ subprotocol: Optional[Subprotocol] = None @@ -467,16 +474,19 @@ def select_subprotocol( the default implementation selects the preferred subprotocols by giving equal value to the priorities of the client and the server. - If no subprotocols are supported by the client and the server, it + If no subprotocol is supported by the client and the server, it proceeds without a subprotocol. This is unlikely to be the most useful implementation in practice, as many servers providing a subprotocol will require that the client uses that subprotocol. Such rules can be implemented in a subclass. - This method may be overridden by passing a ``select_subprotocol`` - argument to the :class:`WebSocketServerProtocol` constructor or the - :func:`serve` function. + Instead of subclassing, it is possible to override this method by + passing a ``select_subprotocol`` argument to the :func:`serve` + function or the :class:`WebSocketServerProtocol` constructor + + :param client_subprotocols: list of subprotocols offered by the client + :param server_subprotocols: list of subprotocols available on the server """ if self._select_subprotocol is not None: @@ -500,27 +510,22 @@ async def handshake( """ Perform the server side of the opening handshake. - If provided, ``origins`` is a list of acceptable HTTP Origin values. - Include ``None`` if the lack of an origin is acceptable. - - If provided, ``available_extensions`` is a list of supported - extensions in the order in which they should be used. - - If provided, ``available_subprotocols`` is a list of supported - subprotocols in order of decreasing preference. - - If provided, ``extra_headers`` sets additional HTTP response headers - when the handshake succeeds. It can be a - :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` - pairs, or a callable taking the request path and headers in arguments - and returning one of the above. - - Raise :exc:`~websockets.exceptions.InvalidHandshake` if the handshake - fails. - Return the path of the URI of the request. + :param origins: list of acceptable values of the Origin HTTP header; + include ``None`` if the lack of an origin is acceptable + :param available_extensions: list of supported extensions in the order + in which they should be used + :param available_subprotocols: list of supported subprotocols in order + of decreasing preference + :param extra_headers: sets additional HTTP response headers when the + handshake succeeds; it can be a :class:`~websockets.http.Headers` + instance, a :class:`~collections.abc.Mapping`, an iterable of + ``(name, value)`` pairs, or a callable taking the request path and + headers in arguments and returning one of the above. + :raises ~websockets.exceptions.InvalidHandshake: if the handshake + fails + """ path, request_headers = await self.read_http_request() @@ -530,7 +535,7 @@ async def handshake( if isinstance(early_response_awaitable, Awaitable): early_response = await early_response_awaitable else: - # For backwards-compatibility with 7.0. + # For backwards compatibility with 7.0. warnings.warn("declare process_request as a coroutine", DeprecationWarning) early_response = early_response_awaitable # type: ignore @@ -589,21 +594,21 @@ async def handshake( class WebSocketServer: """ - Wrapper for :class:`~asyncio.Server` that closes connections on exit. + WebSocket server returned by :func:`~websockets.server.serve`. - This class provides the return type of :func:`~websockets.server.serve`. + This class provides the same interface as + :class:`~asyncio.AbstractServer`, namely the + :meth:`~asyncio.AbstractServer.close` and + :meth:`~asyncio.AbstractServer.wait_closed` methods. - It mimics the interface of :class:`~asyncio.AbstractServer`, namely its - :meth:`~asyncio.AbstractServer.close()` and - :meth:`~asyncio.AbstractServer.wait_closed()` methods, to close WebSocket - connections properly on exit, in addition to closing the underlying - :class:`~asyncio.Server`. + It keeps track of WebSocket connections in order to close them properly + when shutting down. Instances of this class store a reference to the :class:`~asyncio.Server` - object returned by :meth:`~asyncio.AbstractEventLoop.create_server` rather - than inherit from :class:`~asyncio.Server` in part because - :meth:`~asyncio.AbstractEventLoop.create_server` doesn't support passing a - custom :class:`~asyncio.Server` class. + object returned by :meth:`~asyncio.loop.create_server` rather than inherit + from :class:`~asyncio.Server` in part because + :meth:`~asyncio.loop.create_server` doesn't support passing a custom + :class:`~asyncio.Server` class. """ @@ -624,14 +629,13 @@ def wrap(self, server: asyncio.AbstractServer) -> None: """ Attach to a given :class:`~asyncio.Server`. - Since :meth:`~asyncio.AbstractEventLoop.create_server` doesn't support - injecting a custom ``Server`` class, the easiest solution that doesn't - rely on private :mod:`asyncio` APIs is to: + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: - instantiate a :class:`WebSocketServer` - give the protocol factory a reference to that instance - - call :meth:`~asyncio.AbstractEventLoop.create_server` with the - factory + - call :meth:`~asyncio.loop.create_server` with the factory - attach the resulting :class:`~asyncio.Server` with this method """ @@ -665,9 +669,18 @@ def is_serving(self) -> bool: def close(self) -> None: """ - Close the server and terminate connections with close code 1001. + Close the server. + + This method: + + * closes the underlying :class:`~asyncio.Server`; + * rejects new WebSocket connections with an HTTP 503 (service + unavailable) error; this happens when the server accepted the TCP + connection but didn't complete the WebSocket opening handshake prior + to closing; + * closes open WebSocket connections with close code 1001 (going away). - This method is idempotent. + :meth:`close` is idempotent. """ if self.close_task is None: @@ -716,10 +729,10 @@ async def _close(self) -> None: async def wait_closed(self) -> None: """ - Wait until the server is closed and all connections are terminated. + Wait until the server is closed. - When :meth:`wait_closed()` returns, all TCP connections are closed and - there are no pending tasks left. + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. """ await asyncio.shield(self.closed_waiter) @@ -737,77 +750,80 @@ def sockets(self) -> Optional[List[socket.socket]]: class Serve: """ - Create, start, and return a :class:`WebSocketServer`. - :func:`serve` returns an awaitable. Awaiting it yields an instance of - :class:`WebSocketServer` which provides - :meth:`~websockets.server.WebSocketServer.close` and + Create, start, and return a WebSocket server on ``host`` and ``port``. + + Whenever a client connects, the server accepts the connection, creates a + :class:`WebSocketServerProtocol`, performs the opening handshake, and + delegates to the connection handler defined by ``ws_handler``. Once the + handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + Awaiting :func:`serve` yields a :class:`WebSocketServer`. This instance + provides :meth:`~websockets.server.WebSocketServer.close` and :meth:`~websockets.server.WebSocketServer.wait_closed` methods for terminating the server and cleaning up its resources. + When a server is closed with :meth:`~WebSocketServer.close`, it closes all + connections with close code 1001 (going away). Connections handlers, which + are running the ``ws_handler`` coroutine, will receive a + :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their + current or next interaction with the WebSocket connection. + :func:`serve` can also be used as an asynchronous context manager. In this case, the server is shut down when exiting the context. :func:`serve` is a wrapper around the event loop's - :meth:`~asyncio.AbstractEventLoop.create_server` method. Internally, it - creates and starts a :class:`~asyncio.Server` object by calling - :meth:`~asyncio.AbstractEventLoop.create_server`. The - :class:`WebSocketServer` it returns keeps a reference to this object. + :meth:`~asyncio.loop.create_server` method. It creates and starts a + :class:`~asyncio.Server` with :meth:`~asyncio.loop.create_server`. Then it + wraps the :class:`~asyncio.Server` in a :class:`WebSocketServer` and + returns the :class:`WebSocketServer`. The ``ws_handler`` argument is the WebSocket handler. It must be a coroutine accepting two arguments: a :class:`WebSocketServerProtocol` and the request URI. The ``host`` and ``port`` arguments, as well as unrecognized keyword - arguments, are passed along to - :meth:`~asyncio.AbstractEventLoop.create_server`. For example, you can set - the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. - - The ``create_protocol`` parameter allows customizing the asyncio protocol - that manages the connection. It should be a callable or class accepting - the same arguments as :class:`WebSocketServerProtocol` and returning a - :class:`WebSocketServerProtocol` instance. It defaults to + arguments, are passed along to :meth:`~asyncio.loop.create_server`. + + For example, you can set the ``ssl`` keyword argument to a + :class:`~ssl.SSLContext` to enable TLS. + + The ``create_protocol`` parameter allows customizing the + :class:`~asyncio.Protocol` that manages the connection. It should be a + callable or class accepting the same arguments as + :class:`WebSocketServerProtocol` and returning an instance of + :class:`WebSocketServerProtocol` or a subclass. It defaults to :class:`WebSocketServerProtocol`. - The behavior of the ``ping_interval``, ``ping_timeout``, ``close_timeout``, - ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` optional - arguments is described in the documentation of - :class:`~websockets.protocol.WebSocketCommonProtocol`. + The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is + described in :class:`~websockets.protocol.WebSocketCommonProtocol`. :func:`serve` also accepts the following optional arguments: * ``compression`` is a shortcut to configure compression extensions; by default it enables the "permessage-deflate" extension; set it to ``None`` to disable compression - * ``origins`` defines acceptable Origin HTTP headers — include ``None`` if + * ``origins`` defines acceptable Origin HTTP headers; include ``None`` if the lack of an origin is acceptable * ``extensions`` is a list of supported extensions in order of decreasing preference * ``subprotocols`` is a list of supported subprotocols in order of decreasing preference * ``extra_headers`` sets additional HTTP response headers when the - handshake succeeds — it can be a :class:`~websockets.http.Headers` + handshake succeeds; it can be a :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` pairs, or a callable taking the request path and headers in arguments and returning one of the above - * ``process_request`` is a coroutine taking the request path and headers - in argument, see :meth:`~WebSocketServerProtocol.process_request` for - details - * ``select_subprotocol`` is a callable taking the subprotocols offered by - the client and available on the server in argument, see + * ``process_request`` allows intercepting the HTTP request; it must be a + coroutine taking the request path and headers in argument; see + :meth:`~WebSocketServerProtocol.process_request` for details + * ``select_subprotocol`` allows customizing the logic for selecting a + subprotocol; it must be a callable taking the subprotocols offered by + the client and available on the server in argument; see :meth:`~WebSocketServerProtocol.select_subprotocol` for details - Whenever a client connects, the server accepts the connection, creates a - :class:`WebSocketServerProtocol`, performs the opening handshake, and - delegates to the WebSocket handler. Once the handler completes, the server - performs the closing handshake and closes the connection. - - When a server is closed with :meth:`~WebSocketServer.close`, it closes all - connections with close code 1001 (going away). WebSocket handlers — which - are running the coroutine passed in the ``ws_handler`` — will receive a - :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their - current or next interaction with the WebSocket connection. - Since there's no useful way to propagate exceptions triggered in handlers, they're sent to the ``'websockets.server'`` logger instead. Debugging is much easier if you configure logging to print them:: @@ -851,7 +867,7 @@ def __init__( ] = None, **kwargs: Any, ) -> None: - # Backwards-compatibility: close_timeout used to be called timeout. + # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: timeout = 10 else: @@ -860,7 +876,7 @@ def __init__( if close_timeout is None: close_timeout = timeout - # Backwards-compatibility: create_protocol used to be called klass. + # Backwards compatibility: create_protocol used to be called klass. if klass is None: klass = WebSocketServerProtocol else: @@ -968,16 +984,16 @@ def unix_serve( **kwargs: Any, ) -> Serve: """ - Similar to :func:`serve()`, but for listening on Unix sockets. - - ``path`` is the path to the Unix socket. + Similar to :func:`serve`, but for listening on Unix sockets. This function calls the event loop's - :meth:`~asyncio.AbstractEventLoop.create_unix_server` method. + :meth:`~asyncio.loop.create_unix_server` method. It is only available on Unix. It's useful for deploying a server behind a reverse proxy such as nginx. + :param path: file system path to the Unix socket + """ return serve(ws_handler, path=path, **kwargs) diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 651b40bbe..3847701b2 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -12,7 +12,7 @@ - :class:`bytes` for binary messages """ - +# Remove try / except when dropping support for Python < 3.7 try: Data.__doc__ = Data__doc__ # type: ignore except AttributeError: # pragma: no cover @@ -20,9 +20,26 @@ Origin = NewType("Origin", str) +Origin.__doc__ = """Value of a Origin header""" + ExtensionParameter = Tuple[str, Optional[str]] +ExtensionParameter__doc__ = """Parameter of a WebSocket extension""" +try: + ExtensionParameter.__doc__ = ExtensionParameter__doc__ # type: ignore +except AttributeError: # pragma: no cover + pass + + ExtensionHeader = Tuple[str, List[ExtensionParameter]] +ExtensionHeader__doc__ = """Item parsed in a Sec-WebSocket-Extensions header""" +try: + ExtensionHeader.__doc__ = ExtensionHeader__doc__ # type: ignore +except AttributeError: # pragma: no cover + pass + + Subprotocol = NewType("Subprotocol", str) +Subprotocol.__doc__ = """Items parsed in a Sec-WebSocket-Protocol header""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 16d3d6761..cbb56524b 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -1,6 +1,7 @@ """ -The :mod:`websockets.uri` module implements parsing of WebSocket URIs -according to `section 3 of RFC 6455`_. +:mod:`websockets.uri` parses WebSocket URIs. + +See `section 3 of RFC 6455`_. .. _section 3 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-3 @@ -31,25 +32,30 @@ class WebSocketURI(NamedTuple): WebSocketURI.__doc__ = """ WebSocket URI. -* ``secure`` is the secure flag -* ``host`` is the lower-case host -* ``port`` if the integer port, it's always provided even if it's the default -* ``resource_name`` is the resource name, that is, the path and optional query -* ``user_info`` is an ``(username, password)`` tuple when the URI contains +:param bool secure: secure flag +:param str host: lower-case host +:param int port: port, always set even if it's the default +:param str resource_name: path and optional query +:param str user_info: ``(username, password)`` tuple when the URI contains `User Information`_, else ``None``. .. _User Information: https://tools.ietf.org/html/rfc3986#section-3.2.1 - """ +# Work around https://bugs.python.org/issue19931 + +WebSocketURI.secure.__doc__ = "" +WebSocketURI.host.__doc__ = "" +WebSocketURI.port.__doc__ = "" +WebSocketURI.resource_name.__doc__ = "" +WebSocketURI.user_info.__doc__ = "" + def parse_uri(uri: str) -> WebSocketURI: """ - This function parses and validates a WebSocket URI. - - If the URI is valid, it returns a :class:`WebSocketURI`. + Parse and validate a WebSocket URI. - Otherwise it raises an :exc:`~websockets.exceptions.InvalidURI` exception. + :raises ValueError: if ``uri`` isn't a valid WebSocket URI. """ parsed = urllib.parse.urlparse(uri) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index e289e6980..40ac8559f 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -8,9 +8,8 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: """ Apply masking to the data of a WebSocket message. - ``data`` and ``mask`` are bytes-like objects. - - Return :class:`bytes`. + :param data: Data to mask + :param mask: 4-bytes mask """ if len(mask) != 4: diff --git a/tests/test_auth.py b/tests/test_auth.py index 07341df56..97a4485a0 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -37,14 +37,14 @@ def test_basic_auth(self): self.loop.run_until_complete(self.client.recv()) def test_basic_auth_server_no_credentials(self): - with self.assertRaises(ValueError) as raised: + with self.assertRaises(TypeError) as raised: basic_auth_protocol_factory(realm="auth-tests", credentials=None) self.assertEqual( str(raised.exception), "provide either credentials or check_credentials" ) def test_basic_auth_server_bad_credentials(self): - with self.assertRaises(ValueError) as raised: + with self.assertRaises(TypeError) as raised: basic_auth_protocol_factory(realm="auth-tests", credentials=42) self.assertEqual(str(raised.exception), "invalid credentials argument: 42") @@ -60,7 +60,7 @@ def test_basic_auth_server_multiple_credentials(self): self.loop.run_until_complete(self.client.recv()) def test_basic_auth_bad_multiple_credentials(self): - with self.assertRaises(ValueError) as raised: + with self.assertRaises(TypeError) as raised: basic_auth_protocol_factory( realm="auth-tests", credentials=[("hello", "iloveyou"), 42] ) From 2f10791b875746ba9a8f59ea6e1f3129ffa37740 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 10:31:19 +0200 Subject: [PATCH 0619/1539] Use monospace font consistently for the project name. --- compliance/README.rst | 4 ++-- docs/changelog.rst | 6 +++--- docs/contributing.rst | 4 ++-- src/websockets/protocol.py | 8 ++++---- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index cbb4ca2c7..8570f9176 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -30,8 +30,8 @@ Then kill the first one with Ctrl-C. The test client or server shouldn't display any exceptions. The results are stored in reports/clients/index.html. -Note that the Autobahn software only supports Python 2, while websockets only -supports Python 3; you need two different environments. +Note that the Autobahn software only supports Python 2, while ``websockets`` +only supports Python 3; you need two different environments. Conformance notes ----------------- diff --git a/docs/changelog.rst b/docs/changelog.rst index aa4a76259..c79f0f0dd 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -138,8 +138,8 @@ Also: Also: -* websockets sends Ping frames at regular intervals and closes the connection - if it doesn't receive a matching Pong frame. See +* ``websockets`` sends Ping frames at regular intervals and closes the + connection if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details. * Added ``process_request`` and ``select_subprotocol`` arguments to @@ -217,7 +217,7 @@ Also: **Version 5.0 fixes a security issue introduced in version 4.0.** - websockets 4.0 was vulnerable to denial of service by memory exhaustion + Version 4.0 was vulnerable to denial of service by memory exhaustion because it didn't enforce ``max_size`` when decompressing compressed messages (`CVE-2018-1000518`_). diff --git a/docs/contributing.rst b/docs/contributing.rst index 00a529243..40f1dbb54 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -55,7 +55,7 @@ cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. Please stop heating the planet where my children are supposed to live, thanks. -Since websockets is released under an open-source license, you can use it for -any purpose you like. However, I won't spend any of my time to help. +Since ``websockets`` is released under an open-source license, you can use it +for any purpose you like. However, I won't spend any of my time to help. I will summarily close issues related to Bitcoin or cryptocurrency in any way. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index fa369450b..7d1560927 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -118,7 +118,7 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): side and ``5 * close_timeout`` on the client side. ``close_timeout`` needs to be a parameter of the protocol because - websockets usually calls :meth:`close` implicitly: + ``websockets`` usually calls :meth:`close` implicitly: - on the server side, when the connection handler terminates, - on the client side, when exiting the context manager for the connection. @@ -1298,15 +1298,15 @@ def eof_received(self) -> bool: See http://bugs.python.org/issue24539 for more information. - This is inappropriate for websockets for at least three reasons: + This is inappropriate for ``websockets`` for at least three reasons: 1. The use case is to read data until EOF with self.reader.read(-1). - Since websockets is a TLV protocol, this never happens. + Since WebSocket is a TLV protocol, this never happens. 2. It doesn't work on TLS connections. A falsy value must be returned to have the same behavior on TLS and plain connections. - 3. The websockets protocol has its own closing handshake. Endpoints + 3. The WebSocket protocol has its own closing handshake. Endpoints close the TCP connection after sending a close frame. As a consequence we revert to the previous, more useful behavior. From c7d795ee91804fac5f869b31b41001695f265a36 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 10:53:22 +0200 Subject: [PATCH 0620/1539] Improve pings dict in protocol diagram. It's more like a queue than like a coroutine. --- docs/protocol.graffle | Bin 4664 -> 4740 bytes docs/protocol.svg | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/protocol.graffle b/docs/protocol.graffle index 13fdb307ef5907b16265b4fcb669180e509c71cc..df76f49607e5743b809e6f41787262932b6b0743 100644 GIT binary patch literal 4740 zcmV-~5_|0*iwFP!000030PS6CbK5r7{@nZuz4^FX)3_7prrB|vbdzqHcw;Boc0ALe zC0OQ$B2|)#<7V>T-vdgPcmXX+m9K)^iEQxz1VNnVTygN^pI=5%@3T&`Fo}PBj6L+Y zr{jJy2;bcdN8dg9q%WTC_4DwvejQxs^i@32U!LrmZ{cDX=X$8qr|8M< z<*%1b;ND?B$f3&*)1K#zf;0`xFW)`M@)REVDU8rdMsYYylgZe73%~s<4NgxZEuQSo zNJ`XLD34GWKJkSRq(t&$_p2XMRbS=gVG z&7OTmW+<}DYx7}z{ffTtp>N;!zG8~m_dOVv{PwLfVAEnh3bJP7!z2n9cda#f^?X5p zIjf{R!Beus{36l|%F8iNFG82er^VIM<>~e3X>h(6)?#39-j;_0GdiD&%ik8iOy@Au zza`12!;07s;?F_$HVud2Vl7-wW-*uLSGj0$82+gn^rbm(v`V7MDBg?0VcZxMZjP#; zJ{$*q=weQv>L7Ybxu=RB!W1W5D)nTyxNzAN{xIwNd5|{-ND%gSklI14hw%>#|3F3W z?+AVxqcxo*X}$g4#Ym@NzqdCT!umB6R&Dk?%*Ii0ao7){#=U`o;iuOh^?vNlf3J7& zQyA%^i*ciI08CS-Ef(tYq(2$yIInZ_d}>K!Zf6TTGk{~^{_2XYGRk~Bufted7-jYl zNu0lmPm}g_ZYV6v{H2aQ>m2Z@(&vn?O6$#hSG;obbhCF~=86WxF6Kk{6_%Y+9whKp z`6=hKQ34jqkcYX?d`T(cSaJG&k4oYR-^YSt!jNQKeX9-qvNa>z!%9fzE6fE`3@>lJ z08T5mG-fqB3kJz~d6E_cc6@0AnDSiOk*nGFOC1dM(zG&w1+v0h{WQB;zR#j*g=b1{ z*yE*FJfnZDhe7|MLVOxTSs_E_L`OjwwT=hbhVd~Rthkn%ASiZUb0B=-r%h`^h8?b^1o19~}O2)R~LX%3LfgT)c!kPQxg=Cf>MJzU$og zm5aQ_XVx`7<176oZ=dCgfG@u?RDEMJ;0*?i)gUY5Xx?$bLC0x+ijRXh3yXjH;~^%T z9QTuc5|~RIvFRV#1a zgNetLWP~d2OTv8IX4ch>LSMimH4MdWtqGP*n^58hYZF>+uui6(;kB}tq=LVshQGIA z4KD84{$|R|xEMXRijOLGEU4?+!4{$y*`z;*VeBD71y+J%UxB*1ia%#rYv<2-(Ec1E zsRuHRz@kuu5K4`UuO{+@G-i%3JWjzPB7y+G3x&kCn>pCR#RijBg->DEz8fmTd?|o% zN5H>AQ4snZDbi5N-z9kwPpAo;HhsVTyv|s2#@dz`YgrTl@WUjql?}H(Jb#)*$B%Br zRrTPMAaqkEoHb@qQ+tHCVG|CPyN9vFNcs#$1jAT9sCGS&k!Bc%x3=yQ#HqJ-Hry@S za4U$FHB30dx85q_YF(?W85aO08f-1&jo}ChDL_Z&ZQRI(@Du@M8*$FC0yTVjYsZcE zhqX9a#XCYBAIeSfp~>qaCdCh!i~0yt!kD0|@Zqx7jt}pXn-V)lT!|6=O$x;B@d+y< zE@|x;@xjImwu%wW%2;(eDZAR6XJ>A4Q|7mgota=@ZDJ{X#ElV3{2MmEZ3uS zF^NEK7du>h%icW4#o~5CB&jR=~BL z){D5dPo{irr+INz5Lp%^+Be?)3Wt;^rX(lj%X8I)q3Ms9eO97hvz9|kOKhP_t#^>q zjxko#UuZS@TZXi|6~MaYh}7*wk7mqwcSA^j^xI-ZCnj3NX3?T3$@I5nTjJJ-T%!jx zg9>*Oa_hX`SjZ<~c8`Lu4pw>aH^UpkB|4gRi!ODR9;6na<-H<<1!b5^6sR(6?FbZ;m3r|$TWUWCh zOY3_4fVquHP2fQIws$6V0k7@hAGS9CG{VO#@sFi--5D&UW1t7Ov+uz%5M7&rgzx!S za7ujS6U?YsiGeJw>uxjYB#>jEd!8${T8ao+Be9IFy^X?}+0>nbEv=g#4wo0@G{PGd zXl(7>40TVO&z3IeD;oRDdC9TEea!je8v%4DAJ>`AbrP?czfR(HX??D+@tJtMT<@i@ z3X$@BL498W3?NP$8!4_k^SV|pQ4tP*C9M}jLDF!NjY?bTbGM_Q)2I^AgU|r z);r96Pzo^j0k#G9xnEX2l`oj!vd}#l_e}jpYOo)G9~Vk$!%ehB1svHu+*WZ}*s%_X z3e^FzcstFPwto?Ypc|yk_jTVi+GUz*yN-*UM%!95l5MPJr10@NitaIV`Vdol%i8cr zXIU`R+3`&sLeMrWe0)-_XgSb-O!TUCEJ1=?fxKq*ywg!&PJP@1Dw_z~K^{TXAAETr zq5)cqNGiH*!9wCU!0;mospg%l9O}#vyDHO3mQuCJFqfqQb=NqTXQn9 zi*s$1$@w}=zR_5AJz=B0LZ-EK?JNVq9FyOFCZ~?cH^}6eUAIG#`nOp{tohx#{zTs+ zq!g=DSgz$??09HkTUD`o6|dz_Se~C*Twf3*l#j52){1+?=UnQa!}(3rUtq`Q)Wqg|{Fqb*I9+g9hEH@Es2q90YlALnjNa26q<6 zyMsFm++hcI4(=S>b%eXS2z1Q@Jsi+kfDSpJb3o^St|QQ$T;#WibGD6&^|OS1;oO96 zU`JdXQ0FFeZo-GOFEJSWK1;{mpr6o+v(2PdhqG-s>vn`YoOL++z;M?1G&dfSu|aJ; z56@||?MO=<(mJGdNc*sm_9o)q!e^d=PQd|T8K|?$pe&Kpt0F2;N`Pq06wnIuc#Pfw zl>MO7K@bPmhpgBd0czQm9VL^&#h?5H)*4(53f>NsB=Bj=}P3FafNYW_|u zICg20uF~YDAanjMp~DZ_Fruy1SL`ud)L~#gN1SkK&SGnP_ibz8I1Zx{7S9;+zIb}O z9O48@)OnJA+WKQf2)U_TcS^XYBqJQOG$}w1@eUhVt~=Mgo*}0AD!J~qNAktAm@*w*( zD{R)g_3~=-tgj;&Rht)Jv1$&Ypsz=|qR!cK_STS-^}KzCn_HJy;lj1;7q^@p^C0Eq z%b%jqQYr4&7l+Mp+hwYM4`u4TDHRIB1aT?YEE2bZ5T=BhSOw#MrBvulMvZ|@t#0Wo z9iH5hOL}u|<4))4-HGy|WpN4b^i)7h9S)sIN;o+Wvh!%`T1EJo)lb+2(s28dt;7BJ~3l6|%HBAGam>_oB?$xbBS zfy~ysE0W(8qjMs;vq-M(y^zee?#Cz`($(6W^sd%ixb~c4zWG`UcladItMA~FeO(cu;_NuF9cc|E z_me)!lKv+!v5me9%I-&fa#(z^pNvLHe2^wN%wcrPNd+}r^6lxIT~s)Xu$0>XOZ3jD zaQNCi2u_7_D%^cS-Bgc~T=#T57@KVNr3&1_?>O5$84k*zGk-DjtB2#D5B<&G&dlca zWYv~kEQ6V|^?ylfZ9ezW){uO<+>`Unv|5ncvzeyUaWKw~k`DBXr!Ssr5)O3s+K8rk zLn|AWs6T{Rc(MSzmfIBRSO?~q)H$IQL#I$X0zyy+;M=r8jvM6v)_Q!D9QK3KDCXJn z(@go`LZ^#6&;83a(P$ognQvrD`6d`y{LIB|dbmk`kXpBOSQF6>!S2vmw_$t;_4(3n z+ZH&CHkst@OMEX5M?tO|3#jHRGFSUy+E0E7UBgk~~jFzXa(pEG>spRBP7XX_N$cm8)mJGhF511!7I_@G?om zKh3T#7PeXz9#)*C(B%l{m_{5ksdBS!ka&!XR6y-fDlf1ds4)!~b?y|cIP z5yvMVe~$XYefWFu_C5aX?BwN#Xb-OMzu-qlJX_#geeyrnqjo$yb4~r-~oM z1SMQZFt}88!<^-&X7#<(pFyB$3rpOG*Z zrs~~Z-^8mrXm>?pqW}bPU$5BqeH{KV(F+Uvil={{=@)SxrdLh-hpvj|$1nyXvK`9z zCus^ZQOsmhPq(P)FJ>mu@6#zR>2V9$JM&ni6?j^tn=(KFz=9&=Ckgz4rL2X+<`zN4 zDYDpDJiq>K5-p^1g8*L9NwDm4)x2E@Tx5BAV7tR8IC-fJf7kRNN^5GCdlOg#Kfa#i9;lPsXVquW)EWO)&FJ&-l0LS8U S908|1ef0mR>Y&W>^8f(A2xIa9 literal 4664 zcmV-8636WyiwFP!000030PS7tbJIE&|9tr?vV7d7uHLp-E<2@#g}Xqzq`+RLGk4rX zO?`3fU^|pvhW~w!oR=SRvoAN=ya~Bal2=h*eS5O4zlDoIl&OJAo}wq)vyZbT zaPOe!XVB$`Nzd~}evq@uq(F^@ zbO?pv6ITd93M5aqXI)Q*FwWwg{-5JCo8LN1#%g3$N^I?4djvjT;k4K&Fm}2&*1H+O(ev}5RTkM8@T5sHq!(egON|RT63;NltlHvqU z$hz5is1}s7F;6Z+m+|Mt)xzb;^}WPDTMTP4u($ig;lPZ}r(*VL@nJHDsrnSh!xk%I z*N?vX>3$Lng2h^xO=doq#aFp#(GC7mHTuGw*ILElco^-3!62%Q3fD)KQ+G#x54xDs zC&~|>Qtrt7hcL+r_oRHXonM$Wg+KKA-t)8C013kG7Lr?tbufO+@M9`E|3L6l8?Ejn zPO9y9&W9=qdYzr|0M@UbuyV7#ARUGNdAH|>wR-~t!%wF-?EKo9f39}$A_&#t`KZ>| z2d1gg<_mQ%?u~~k%BtKvpITCz+vx&NHQ*Syzr12AY*yei#;L;@rHC>3pG)dJ>M##G zo+n*_q>GV|lxa@%xPz#qLUNB{B0Y~j>JZPNOfW7uaa|;(z-2BqvFUUUgl&{2||;3GdugZ#hVXn+YPN4>Zg`}z_` zZ1P7sK8C+NiXMDFI)Rm&9fwCJ=;&k^_f--do%BxqBvl#el-KV+e*k?i4&wwG{PQSL zeZu+CAWS?nll+(P^e;kK{>wx9 z7o|A~l{pDxISGs(OJyh zcY-=6i9b~74+B2}>SdtLgr*n4j9gjx^fN%mspoKy2$vxum?w#_KxYCvW`WKEodvp< zKzDqeT{+G*AkFI_QkRc2FID(e>5;dF7_u_NPy@v5C)sH{h$BCI<|o?V9D+&EE3}?j zoF@8%I4C~Md>K=Q9L!zjdXy55LCbyAseZ?)0o*ISr$JWG)Mrn=g^N#~Go2ns@m0-r z^x~id2qG$hqxQA}(_b$KuVEz- zk2;cZ0*gWtLMVkLL8Wj!Ee$aBTnT@9+DdYT!zmb4L=fg-A(7a0D+wF8SYy(v@F~p7 zdrjk*F9pyI5b&>%6m$lTB&q52@8T@?rd9+_letiRUTf=FTW>>by)+B~_+b**%64B5 zoAH7H5V|f4&=~Tls69elvjqst-ND$y$a5Kt2!^qIP|bQEBlRc_Z*1Kr z22^crZMYk@;Z_hUYnX6^Z@g8;<+@f`J*)sqG}v0&nZpqho&X)0H*q5u!jS}&ZNxdl z64dbJtt~g+9oFJx74HbOeCS;kAL^_Yq7$5exu}aUC5#EW3Lh?OZTavnxhb(_#FZG) zy-b1FKD=RM#3ijQBi`G1!A3EnUKy)SCuJ9V^UTaGuFL#3u`?44tWE5BF5=nEGCiL1rc<}i{e zpMhdn(ikIxJ;Xc&ay!rOcaei!&!rLzK|rLFI}k?NdKQQM_OYZgCbdFiQ$+ZUKZw*) zBtrmPO;`cfwi++u$}WN8wXOQaRY7D~kZ4zX_e&g7pqP*x7cb9M6NV-~V*0xj{hGEM z7+PWjU1Ge0H|ZE-Ir%`#(cdDZ-L3%E6-T6QCVDhwzS|o@`qFEP6&>qn5t~Mf!Z=kw zmTd@HA9A%GP!B5HPROnDer+Ki8wIVyOrmzM0&UWYw#IpvnrX0xbih_RSm|(8(qR>2 z{}zgeyD1qKa$yulgDVsZS`)|8b6B1(D0DiPbg2~!_dzTmD;BI+uwvnUiiO~6#e&}O z&#>cCLOdd<$C>o3f2Ij)nXpi6q1O6m?y-L+RZ)Lsz`b`Nf4*+WpGn~$L8K=MWs)!o za)S#;l*43=K`le;YW#rNjY)OjK=`J2Cba>t&EX%mHviPZ$1CxVp>@?6?0J@f?%mG5 zd&fX@Z3Ysq<6^-nagj?fqhcinGPJI`%{(iCECb!~T(Q+sM93P6Wo+$j6voV^_8e?z zUH5RLdg5HRbU|Ox*j>&`mL2Y5&L6)NK)3R7o#|XB@tV2oBwiQR z=L#F2iNlNaUUI7tNyin`bv=Lq#A$6K#Z_lsSE?;a!eLNEIJ%}32b)#GKesXaxmv3c zljV8a7S}c;{duMRUHB5ypR21f-%VhnYpo%X+vu8&uG#3?J&vv=s`qu1qHAmEp7_dr zGvwNIPa<7MtC^B=h9qZP?+5}>r#Jn9f~Q@q+~v2R=njm}p%l4Da0vruZsWeaYHR&S z%QRQ3du~YmoEzJEhnWjX0p>2iw!kiTi>jw|1ryxMbx+0}UB8iP><8e-xsqCQ6Kzld zM|KCdRm=)IRsm6>Dj?=>r~cA*&%*$8gQWhx?wUrsOjB*vak15C8*4_gORE_vT)d8= zdjy@{$JE}UHat{m>JLrZqYLJF}uiM_S_i!IL_Y^o|&t>U%(35)YHjq3}7gme*>&{}ebxSV?~^1w!h zr)q3%D2FDHbimB;FsF>`uCYbfO>}KiW^r~kvDFSbudx7QHSe#++zgAGOar2TI_`4B zX?P@pP@e4H8tpKq4&XfC9mdF`(pw2>joRquoStB@@*<6k+*AsZsqRa?ET~OEYC>aX zQ;=>;3X%=a*o32>4zZjiL8ku5Pv*MBGAgER6$W>+JyhO+u+$zZZx0%@hsw7+T(IwF z{xzK}xEkCU9PbwH3~+}n+*!D@aMu#DD>cdzk2+;U1xgPf z8Z!yBLO&j(w*X~7sHE>l{?#EXHb#INc4de0xPR$evR}aU#ymS}%(H`e?guK$R>#Qs zsab-#h|7w<(+G}jnk1_YV{pE4Ci$in9C6-oa(dKYTy0XSU8S? zuzxxy<{R|pLAO)v&=qw9JQ z28=J6Ht&g`4Dn($>q&}-V5Vx;jBJuI$80rgTg}>5v%ZxwYmDyGgu14 zkNN0qjq5eemjDrrz;+}pMY1b5UL+GMlC4O#BH4=MTaXiadqwiQd~{YMw-(8jT?-y_ zjXN1~S8lngCB3cH3a&gGmtTIBf?Irs=f!uhX|gs=_BQ!I{n01{nAf1%3FAS~yTqLf zCakR2aB0FyLsj4nz)Hqe2TKbZ*9;r=c}6Zhj9z~#nb0cH4geJeXT{)yrGc|3)CNiG&o)WUc+tjM5Y6M z!04RNh@lgx9ReY!1MqE9BgZxJ|0p#&jJrL*Fp7De`6S2OKUc}(&U62AMKtOMU*;Q` zP`>qt20wFtn+~p%A0);N5ylj*F4!F^Z8wY$fjV2-ZPNmW(Z=Jfd5P~t!O+iCZ2?t$ zMfz$lNP6*`AleJktile(W2+XlT?Oo0c3Bx*XgPhD5A%662%`E}`f*kpOr-oQ*1HN` zt7uSw=GhHq4mPl|kUn9HRc@HKUtxgRSv(3_(U+MGhWNROd+N+KjnS+?p91%EPgNnX zGg+egT>n-@>!qe_uTO*Xktund#A>a#8;WU``|@j5c&X z`DN@s^0z{jW#x!r!);2Igy}ZZc_|Wzu2@I(0q_?P8dycRM{DuWF&)<&Cuh;72{RV%zso z@MWwP7WM^C|2|dEqbx`+n)VNE70r)91V&^tl<$s{1ZE - Produced by OmniGraffle 6.6.2 2017-09-24 19:39:13 +0000Canvas 1Layer 1remote endpointwebsocketsWebSocketCommonProtocolapplication logicreaderStreamReaderwriterStreamWriterpingsdicttransfer_data_taskTasknetworkread_frameread_data_frameread_messagebytesframesdataframeswrite_framemessagesdequerecvsendpingpongclosecontrolframesbytesframes + Produced by OmniGraffle 6.6.2 2019-07-07 08:38:24 +0000Canvas 1Layer 1remote endpointwebsocketsWebSocketCommonProtocolapplication logicreaderStreamReaderwriterStreamWriterpingsdicttransfer_data_taskTasknetworkread_frameread_data_frameread_messagebytesframesdataframeswrite_framemessagesdequerecvsendpingpongclosecontrolframesbytesframes From 8afffd60f4fa8993f6d29965767dcedec4bfceb9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 11:38:03 +0200 Subject: [PATCH 0621/1539] Update design document. It wasn't updated when fragmentation was implemented. Fix #642. --- docs/design.rst | 57 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/docs/design.rst b/docs/design.rst index 19cda16bb..75887d453 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -9,7 +9,7 @@ with the specification of the WebSocket protocol in :rfc:`6455`. It's primarily intended at maintainers. It may also be useful for users who wish to understand what happens under the hood. -.. warning: +.. warning:: Internals described in this document may change at any time. @@ -43,7 +43,7 @@ Transitions happen in the following places: close frame, this does the right thing regardless of which side started the :ref:`closing handshake `; also in :meth:`~protocol.WebSocketCommonProtocol.fail_connection` which duplicates - a few lines of code from `write_close_frame()` and `write_frame()`; + a few lines of code from ``write_close_frame()`` and ``write_frame()``; - ``* -> CLOSED``: in :meth:`~protocol.WebSocketCommonProtocol.connection_lost` which is always called exactly once when the TCP connection is closed. @@ -231,15 +231,17 @@ happens naturally in many use cases, it cannot be relied upon. Then :meth:`~protocol.WebSocketCommonProtocol.recv` fetches the next message from the :attr:`~protocol.WebSocketCommonProtocol.messages` queue, with some -complexity added for handling termination correctly. +complexity added for handling backpressure and termination correctly. Sending data ............ The right side of the diagram shows how ``websockets`` sends data. -:meth:`~protocol.WebSocketCommonProtocol.send` writes a single data frame -containing the message. Fragmentation isn't supported at this time. +:meth:`~protocol.WebSocketCommonProtocol.send` writes one or several data +frames containing the message. While sending a fragmented message, concurrent +calls to :meth:`~protocol.WebSocketCommonProtocol.send` are put on hold until +all fragments are sent. This makes concurrent calls safe. :meth:`~protocol.WebSocketCommonProtocol.ping` writes a ping frame and yields a :class:`~asyncio.Future` which will be completed when a matching pong @@ -420,8 +422,10 @@ words, they must be shielded from cancellation. :meth:`~protocol.WebSocketCommonProtocol.recv` waits for the next message in the queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, whichever comes first. It relies on :func:`~asyncio.wait` for -waiting on two tasks in parallel. As a consequence, even though it's waiting -on the transfer data task, it doesn't propagate cancellation to that task. +waiting on two futures in parallel. As a consequence, even though it's waiting +on a :class:`~asyncio.Future` signalling the next message and on +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't +propagate cancellation to them. :meth:`~protocol.WebSocketCommonProtocol.ensure_open` is called by :meth:`~protocol.WebSocketCommonProtocol.send`, @@ -535,18 +539,33 @@ For each connection, the sending side contains these buffers: Concurrency ----------- -Calling any combination of :meth:`~protocol.WebSocketCommonProtocol.recv`, +Awaiting any combination of :meth:`~protocol.WebSocketCommonProtocol.recv`, :meth:`~protocol.WebSocketCommonProtocol.send`, :meth:`~protocol.WebSocketCommonProtocol.close` :meth:`~protocol.WebSocketCommonProtocol.ping`, or -:meth:`~protocol.WebSocketCommonProtocol.pong` concurrently is safe, -including multiple calls to the same method. - -As shown above, receiving frames is independent from sending frames. That -isolates :meth:`~protocol.WebSocketCommonProtocol.recv`, which receives -frames, from the other methods, which send frames. - -Methods that send frames also support concurrent calls. While the connection -is open, each frame is sent with a single write. Combined with the concurrency -model of :mod:`asyncio`, this enforces serialization. After the connection is -closed, sending a frame raises :exc:`~websockets.exceptions.ConnectionClosed`. +:meth:`~protocol.WebSocketCommonProtocol.pong` concurrently is safe, including +multiple calls to the same method, with one exception and one limitation. + +* **Only one coroutine can receive messages at a time.** This constraint + avoids non-deterministic behavior (and simplifies the implementation). If a + coroutine is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, + awaiting it again in another coroutine raises :exc:`RuntimeError`. + +* **Sending a fragmented message forces serialization.** Indeed, the WebSocket + protocol doesn't support multiplexing messages. If a coroutine is awaiting + :meth:`~protocol.WebSocketCommonProtocol.send` to send a fragmented message, + awaiting it again in another coroutine waits until the first call completes. + This will be transparent in many cases. It may be a concern if the + fragmented message is generated slowly by an asynchronous iterator. + +Receiving frames is independent from sending frames. This isolates +:meth:`~protocol.WebSocketCommonProtocol.recv`, which receives frames, from +the other methods, which send frames. + +While the connection is open, each frame is sent with a single write. Combined +with the concurrency model of :mod:`asyncio`, this enforces serialization. The +only other requirement is to prevent interleaving other data frames in the +middle of a fragmented message. + +After the connection is closed, sending a frame raises +:exc:`~websockets.exceptions.ConnectionClosed`, which is safe. From 1585da2aa7da6f7984ac1c55a05784619cf974c4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 14:45:20 +0200 Subject: [PATCH 0622/1539] Improve exception hierarchy. * Group and sort exceptions from most common to least common. * Add a base WebSocketException. * Rename WebSocketProtocolError. * Document exception tree. Fix #270. --- docs/changelog.rst | 7 + docs/design.rst | 2 +- src/websockets/exceptions.py | 314 +++++++++++++++++++++-------------- src/websockets/framing.py | 26 +-- src/websockets/protocol.py | 10 +- tests/test_exceptions.py | 95 ++++++----- tests/test_framing.py | 20 +-- 7 files changed, 278 insertions(+), 196 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c79f0f0dd..12fc57749 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -33,6 +33,13 @@ Changelog If you were setting ``max_queue=0`` to make the queue of incoming messages unbounded, change it to ``max_queue=None``. +.. note:: + + **Version 8.0 renames the** ``WebSocketProtocolError`` **exception** + :exc:`ProtocolError` **.** + + For backwards compatibility, a ``WebSocketProtocolError`` is provided. + .. note:: **Version 8.0 adds the reason phrase to the return type of the low-level diff --git a/docs/design.rst b/docs/design.rst index 75887d453..74279b87f 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -423,7 +423,7 @@ words, they must be shielded from cancellation. the queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, whichever comes first. It relies on :func:`~asyncio.wait` for waiting on two futures in parallel. As a consequence, even though it's waiting -on a :class:`~asyncio.Future` signalling the next message and on +on a :class:`~asyncio.Future` signaling the next message and on :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't propagate cancellation to them. diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index ce2c1e64b..f03ab72f2 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -1,3 +1,32 @@ +""" +:mod:`websockets.exceptions` defines the following exception hierarchy: + +* :exc:`WebSocketException` + * :exc:`ConnectionClosed` + * :exc:`ConnectionClosedError` + * :exc:`ConnectionClosedOK` + * :exc:`InvalidHandshake` + * :exc:`SecurityError` + * :exc:`InvalidMessage` + * :exc:`InvalidHeader` + * :exc:`InvalidHeaderFormat` + * :exc:`InvalidHeaderValue` + * :exc:`InvalidOrigin` + * :exc:`InvalidUpgrade` + * :exc:`InvalidStatusCode` + * :exc:`NegotiationError` + * :exc:`DuplicateParameter` + * :exc:`InvalidParameterName` + * :exc:`InvalidParameterValue` + * :exc:`AbortHandshake` + * :exc:`RedirectHandshake` + * :exc:`InvalidState` + * :exc:`InvalidURI` + * :exc:`PayloadTooBig` + * :exc:`ProtocolError` + +""" + import http from typing import Optional @@ -5,88 +34,152 @@ __all__ = [ - "AbortHandshake", + "WebSocketException", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", - "DuplicateParameter", "InvalidHandshake", + "SecurityError", + "InvalidMessage", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", - "InvalidMessage", "InvalidOrigin", + "InvalidUpgrade", + "InvalidStatusCode", + "NegotiationError", + "DuplicateParameter", "InvalidParameterName", "InvalidParameterValue", + "AbortHandshake", + "RedirectHandshake", "InvalidState", - "InvalidStatusCode", - "InvalidUpgrade", "InvalidURI", - "NegotiationError", "PayloadTooBig", - "RedirectHandshake", - "SecurityError", - "WebSocketProtocolError", + "ProtocolError", ] -class InvalidHandshake(Exception): +class WebSocketException(Exception): """ - Exception raised when a handshake request or response is invalid. + Base class for all exceptions defined by :mod:`websockets`. """ -class AbortHandshake(InvalidHandshake): +CLOSE_CODES = { + 1000: "OK", + 1001: "going away", + 1002: "protocol error", + 1003: "unsupported type", + # 1004 is reserved + 1005: "no status code [internal]", + 1006: "connection closed abnormally [internal]", + 1007: "invalid data", + 1008: "policy violation", + 1009: "message too big", + 1010: "extension required", + 1011: "unexpected error", + 1015: "TLS failure [internal]", +} + + +def format_close(code: int, reason: str) -> str: """ - Exception raised to abort a handshake and return a HTTP response. + Display a human-readable version of the close code and reason. """ + if 3000 <= code < 4000: + explanation = "registered" + elif 4000 <= code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODES.get(code, "unknown") + result = f"code = {code} ({explanation}), " - def __init__( - self, status: http.HTTPStatus, headers: HeadersLike, body: bytes = b"" - ) -> None: - self.status = status - self.headers = Headers(headers) - self.body = body - message = f"HTTP {status}, {len(self.headers)} headers, {len(body)} bytes" + if reason: + result += f"reason = {reason}" + else: + result += "no reason" + + return result + + +class ConnectionClosed(WebSocketException): + """ + Raised when trying to interact with a closed connection. + + Provides the connection close code and reason in its ``code`` and + ``reason`` attributes respectively. + + """ + + def __init__(self, code: int, reason: str) -> None: + self.code = code + self.reason = reason + message = "WebSocket connection is closed: " + message += format_close(code, reason) super().__init__(message) -class SecurityError(InvalidHandshake): +class ConnectionClosedError(ConnectionClosed): """ - Exception raised when a HTTP request or response breaks security rules. + Like :exc:`ConnectionClosed`, when the connection terminated with an error. + + This means the close code is different from 1000 (OK) and 1001 (going away). """ + def __init__(self, code: int, reason: str) -> None: + assert code != 1000 and code != 1001 + super().__init__(code, reason) + -class RedirectHandshake(InvalidHandshake): +class ConnectionClosedOK(ConnectionClosed): """ - Exception raised when a handshake gets redirected. + Like :exc:`ConnectionClosed`, when the connection terminated properly. + + This means the close code is 1000 (OK) or 1001 (going away). """ - def __init__(self, uri: str) -> None: - self.uri = uri + def __init__(self, code: int, reason: str) -> None: + assert code == 1000 or code == 1001 + super().__init__(code, reason) - def __str__(self) -> str: - return f"redirect to {self.uri}" + +class InvalidHandshake(WebSocketException): + """ + Raised during the handshake when the WebSocket connection fails. + + """ + + +class SecurityError(InvalidHandshake): + """ + Raised when a handshake request or response breaks a security rule. + + Security limits are hard coded. + + """ class InvalidMessage(InvalidHandshake): """ - Exception raised when the HTTP message in a handshake request is malformed. + Raised when a handshake request or response is malformed. """ class InvalidHeader(InvalidHandshake): """ - Exception raised when a HTTP header doesn't have a valid format or value. + Raised when a HTTP header doesn't have a valid format or value. """ def __init__(self, name: str, value: Optional[str] = None) -> None: + self.name = name + self.value = value if value is None: message = f"missing {name} header" elif value == "": @@ -98,32 +191,30 @@ def __init__(self, name: str, value: Optional[str] = None) -> None: class InvalidHeaderFormat(InvalidHeader): """ - Exception raised when a Sec-WebSocket-* HTTP header cannot be parsed. + Raised when a HTTP header cannot be parsed. + + The format of the header doesn't match the grammar for that header. """ def __init__(self, name: str, error: str, header: str, pos: int) -> None: + self.name = name error = f"{error} at {pos} in {header}" super().__init__(name, error) class InvalidHeaderValue(InvalidHeader): """ - Exception raised when a Sec-WebSocket-* HTTP header has a wrong value. - - """ + Raised when a HTTP header has a wrong value. - -class InvalidUpgrade(InvalidHeader): - """ - Exception raised when a Upgrade or Connection header isn't correct. + The format of the header is correct but a value isn't acceptable. """ class InvalidOrigin(InvalidHeader): """ - Exception raised when the Origin header in a request isn't allowed. + Raised when the Origin header in a request isn't allowed. """ @@ -131,11 +222,18 @@ def __init__(self, origin: Optional[str]) -> None: super().__init__("Origin", origin) +class InvalidUpgrade(InvalidHeader): + """ + Raised when the Upgrade or Connection header isn't correct. + + """ + + class InvalidStatusCode(InvalidHandshake): """ - Exception raised when a handshake response status code is invalid. + Raised when a handshake response status code is invalid. - Provides the integer status code in its ``status_code`` attribute. + The integer status code is available in the ``status_code`` attribute. """ @@ -147,139 +245,102 @@ def __init__(self, status_code: int) -> None: class NegotiationError(InvalidHandshake): """ - Exception raised when negotiating an extension fails. + Raised when negotiating an extension fails. """ -class InvalidParameterName(NegotiationError): +class DuplicateParameter(NegotiationError): """ - Exception raised when a parameter name in an extension header is invalid. + Raised when a parameter name is repeated in an extension header. """ def __init__(self, name: str) -> None: self.name = name - message = f"invalid parameter name: {name}" + message = f"duplicate parameter: {name}" super().__init__(message) -class InvalidParameterValue(NegotiationError): +class InvalidParameterName(NegotiationError): """ - Exception raised when a parameter value in an extension header is invalid. + Raised when a parameter name in an extension header is invalid. """ - def __init__(self, name: str, value: Optional[str]) -> None: + def __init__(self, name: str) -> None: self.name = name - self.value = value - message = f"invalid value for parameter {name}: {value}" + message = f"invalid parameter name: {name}" super().__init__(message) -class DuplicateParameter(NegotiationError): +class InvalidParameterValue(NegotiationError): """ - Exception raised when a parameter name is repeated in an extension header. + Raised when a parameter value in an extension header is invalid. """ - def __init__(self, name: str) -> None: + def __init__(self, name: str, value: Optional[str]) -> None: self.name = name - message = f"duplicate parameter: {name}" + self.value = value + if value is None: + message = f"missing value for parameter {name}" + elif value == "": + message = f"empty value for parameter {name}" + else: + message = f"invalid value for parameter {name}: {value}" super().__init__(message) -class InvalidState(Exception): +class AbortHandshake(InvalidHandshake): """ - Exception raised when an operation is forbidden in the current state. + Raised to abort the handshake on purpose and return a HTTP response. - """ + This exception is an implementation detail. + The public API is :meth:`~server.WebSocketServerProtocol.process_request`. -CLOSE_CODES = { - 1000: "OK", - 1001: "going away", - 1002: "protocol error", - 1003: "unsupported type", - # 1004 is reserved - 1005: "no status code [internal]", - 1006: "connection closed abnormally [internal]", - 1007: "invalid data", - 1008: "policy violation", - 1009: "message too big", - 1010: "extension required", - 1011: "unexpected error", - 1015: "TLS failure [internal]", -} - - -def format_close(code: int, reason: str) -> str: """ - Display a human-readable version of the close code and reason. - """ - if 3000 <= code < 4000: - explanation = "registered" - elif 4000 <= code < 5000: - explanation = "private use" - else: - explanation = CLOSE_CODES.get(code, "unknown") - result = f"code = {code} ({explanation}), " - - if reason: - result += f"reason = {reason}" - else: - result += "no reason" - - return result + def __init__( + self, status: http.HTTPStatus, headers: HeadersLike, body: bytes = b"" + ) -> None: + self.status = status + self.headers = Headers(headers) + self.body = body + message = f"HTTP {status}, {len(self.headers)} headers, {len(body)} bytes" + super().__init__(message) -class ConnectionClosed(InvalidState): +class RedirectHandshake(InvalidHandshake): """ - Exception raised when trying to read or write on a closed connection. + Raised when a handshake gets redirected. - Provides the connection close code and reason in its ``code`` and - ``reason`` attributes respectively. + This exception is an implementation detail. """ - def __init__(self, code: int, reason: str) -> None: - self.code = code - self.reason = reason - message = "WebSocket connection is closed: " - message += format_close(code, reason) - super().__init__(message) - + def __init__(self, uri: str) -> None: + self.uri = uri -class ConnectionClosedError(ConnectionClosed): - """ - Like :exc:`ConnectionClosed`, when the connection terminated with an error. + def __str__(self) -> str: + return f"redirect to {self.uri}" - This means the close code is different from 1000 (OK) and 1001 (going away). +class InvalidState(WebSocketException, AssertionError): """ + Raised when an operation is forbidden in the current state. - def __init__(self, code: int, reason: str) -> None: - assert code != 1000 and code != 1001 - super().__init__(code, reason) - + This exception is an implementation detail. -class ConnectionClosedOK(ConnectionClosed): - """ - Like :exc:`ConnectionClosed`, when the connection terminated properly. - - This means the close code is 1000 (OK) or 1001 (going away). + It should never be raised in normal circumstances. """ - def __init__(self, code: int, reason: str) -> None: - assert code == 1000 or code == 1001 - super().__init__(code, reason) - -class InvalidURI(Exception): +class InvalidURI(WebSocketException): """ - Exception raised when an URI isn't a valid websocket URI. + Raised when connecting to an URI that isn't a valid WebSocket URI. """ @@ -289,15 +350,18 @@ def __init__(self, uri: str) -> None: super().__init__(message) -class PayloadTooBig(Exception): +class PayloadTooBig(WebSocketException): """ - Exception raised when a frame's payload exceeds the maximum size. + Raised when receiving a frame with a payload exceeding the maximum size. """ -class WebSocketProtocolError(Exception): +class ProtocolError(WebSocketException): """ - Internal exception raised when the remote side breaks the protocol. + Raised when the other side breaks the protocol. """ + + +WebSocketProtocolError = ProtocolError # for backwards compatibility diff --git a/src/websockets/framing.py b/src/websockets/framing.py index ec87665ef..478a7b05a 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -15,7 +15,7 @@ import struct from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple -from .exceptions import PayloadTooBig, WebSocketProtocolError +from .exceptions import PayloadTooBig, ProtocolError from .typing import Data @@ -113,7 +113,7 @@ async def read( in reverse order :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds ``max_size`` - :raises ~websockets.exceptions.WebSocketProtocolError: if the frame + :raises ~websockets.exceptions.ProtocolError: if the frame contains incorrect values """ @@ -129,7 +129,7 @@ async def read( opcode = head1 & 0b00001111 if (True if head2 & 0b10000000 else False) != mask: - raise WebSocketProtocolError("incorrect masking") + raise ProtocolError("incorrect masking") length = head2 & 0b01111111 if length == 126: @@ -178,7 +178,7 @@ def write( :param extensions: list of classes with an ``encode()`` method that transform the frame and return a new frame; extensions are applied in order - :raises ~websockets.exceptions.WebSocketProtocolError: if the frame + :raises ~websockets.exceptions.ProtocolError: if the frame contains incorrect values """ @@ -235,7 +235,7 @@ def check(frame) -> None: """ Check that reserved bits and opcode have acceptable values. - :raises ~websockets.exceptions.WebSocketProtocolError: if a reserved + :raises ~websockets.exceptions.ProtocolError: if a reserved bit or the opcode is invalid """ @@ -243,17 +243,17 @@ def check(frame) -> None: # but it's the instance of class to which this method is bound. if frame.rsv1 or frame.rsv2 or frame.rsv3: - raise WebSocketProtocolError("reserved bits must be 0") + raise ProtocolError("reserved bits must be 0") if frame.opcode in DATA_OPCODES: return elif frame.opcode in CTRL_OPCODES: if len(frame.data) > 125: - raise WebSocketProtocolError("control frame too long") + raise ProtocolError("control frame too long") if not frame.fin: - raise WebSocketProtocolError("fragmented control frame") + raise ProtocolError("fragmented control frame") else: - raise WebSocketProtocolError(f"invalid opcode: {frame.opcode}") + raise ProtocolError(f"invalid opcode: {frame.opcode}") def prepare_data(data: Data) -> Tuple[int, bytes]: @@ -314,7 +314,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: Return ``(code, reason)``. - :raises ~websockets.exceptions.WebSocketProtocolError: if data is ill-formed + :raises ~websockets.exceptions.ProtocolError: if data is ill-formed :raises UnicodeDecodeError: if the reason isn't valid UTF-8 """ @@ -328,7 +328,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: return 1005, "" else: assert length == 1 - raise WebSocketProtocolError("close frame too short") + raise ProtocolError("close frame too short") def serialize_close(code: int, reason: str) -> bytes: @@ -346,12 +346,12 @@ def check_close(code: int) -> None: """ Check that the close code has an acceptable value for a close frame. - :raises ~websockets.exceptions.WebSocketProtocolError: if the close code + :raises ~websockets.exceptions.ProtocolError: if the close code is invalid """ if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): - raise WebSocketProtocolError("invalid status code") + raise ProtocolError("invalid status code") # at the bottom to allow circular import, because Extension depends on Frame diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 7d1560927..42ddf0763 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -35,7 +35,7 @@ ConnectionClosedOK, InvalidState, PayloadTooBig, - WebSocketProtocolError, + ProtocolError, ) from .extensions.base import Extension from .framing import * @@ -811,7 +811,7 @@ async def transfer_data(self) -> None: # twice and failing the connection again. raise - except WebSocketProtocolError as exc: + except ProtocolError as exc: self.transfer_data_exc = exc self.fail_connection(1002) @@ -861,7 +861,7 @@ async def read_message(self) -> Optional[Data]: elif frame.opcode == OP_BINARY: text = False else: # frame.opcode == OP_CONT - raise WebSocketProtocolError("unexpected opcode") + raise ProtocolError("unexpected opcode") # Shortcut for the common case - no fragmentation if frame.fin: @@ -906,9 +906,9 @@ def append(frame: Frame) -> None: while not frame.fin: frame = await self.read_data_frame(max_size=max_size) if frame is None: - raise WebSocketProtocolError("incomplete fragmented message") + raise ProtocolError("incomplete fragmented message") if frame.opcode != OP_CONT: - raise WebSocketProtocolError("unexpected opcode") + raise ProtocolError("unexpected opcode") append(frame) # mypy cannot figure out that chunks have the proper type. diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 2cbd78671..72b1076ab 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -9,21 +9,46 @@ def test_str(self): for exception, exception_str in [ # fmt: off ( - InvalidHandshake("invalid request"), - "invalid request", + WebSocketException("something went wrong"), + "something went wrong", ), ( - AbortHandshake(200, Headers(), b"OK\n"), - "HTTP 200, 0 headers, 3 bytes", + ConnectionClosed(1000, ""), + "WebSocket connection is closed: code = 1000 " + "(OK), no reason", ), ( - SecurityError("redirect from WSS to WS"), - "redirect from WSS to WS", - + ConnectionClosed(1006, None), + "WebSocket connection is closed: code = 1006 " + "(connection closed abnormally [internal]), no reason" ), ( - RedirectHandshake("wss://example.com"), - "redirect to wss://example.com", + ConnectionClosed(3000, None), + "WebSocket connection is closed: code = 3000 " + "(registered), no reason" + ), + ( + ConnectionClosed(4000, None), + "WebSocket connection is closed: code = 4000 " + "(private use), no reason" + ), + ( + ConnectionClosedError(1016, None), + "WebSocket connection is closed: code = 1016 " + "(unknown), no reason" + ), + ( + ConnectionClosedOK(1001, "bye"), + "WebSocket connection is closed: code = 1001 " + "(going away), reason = bye", + ), + ( + InvalidHandshake("invalid request"), + "invalid request", + ), + ( + SecurityError("redirect from WSS to WS"), + "redirect from WSS to WS", ), ( InvalidMessage("malformed HTTP message"), @@ -56,6 +81,10 @@ def test_str(self): InvalidHeaderValue("Sec-WebSocket-Version", "42"), "invalid Sec-WebSocket-Version header: 42", ), + ( + InvalidOrigin("http://bad.origin"), + "invalid Origin header: http://bad.origin", + ), ( InvalidUpgrade("Upgrade"), "missing Upgrade header", @@ -64,10 +93,6 @@ def test_str(self): InvalidUpgrade("Connection", "websocket"), "invalid Connection header: websocket", ), - ( - InvalidOrigin("http://bad.origin"), - "invalid Origin header: http://bad.origin", - ), ( InvalidStatusCode(403), "server rejected WebSocket connection: HTTP 403", @@ -76,51 +101,37 @@ def test_str(self): NegotiationError("unsupported subprotocol: spam"), "unsupported subprotocol: spam", ), - ( - InvalidParameterName("|"), - "invalid parameter name: |", - ), - ( - InvalidParameterValue("a", "|"), - "invalid value for parameter a: |", - ), ( DuplicateParameter("a"), "duplicate parameter: a", ), ( - InvalidState("WebSocket connection isn't established yet"), - "WebSocket connection isn't established yet", + InvalidParameterName("|"), + "invalid parameter name: |", ), ( - ConnectionClosed(1000, ""), - "WebSocket connection is closed: code = 1000 " - "(OK), no reason", + InvalidParameterValue("a", None), + "missing value for parameter a", ), ( - ConnectionClosedOK(1001, "bye"), - "WebSocket connection is closed: code = 1001 " - "(going away), reason = bye", + InvalidParameterValue("a", ""), + "empty value for parameter a", ), ( - ConnectionClosed(1006, None), - "WebSocket connection is closed: code = 1006 " - "(connection closed abnormally [internal]), no reason" + InvalidParameterValue("a", "|"), + "invalid value for parameter a: |", ), ( - ConnectionClosedError(1016, None), - "WebSocket connection is closed: code = 1016 " - "(unknown), no reason" + AbortHandshake(200, Headers(), b"OK\n"), + "HTTP 200, 0 headers, 3 bytes", ), ( - ConnectionClosed(3000, None), - "WebSocket connection is closed: code = 3000 " - "(registered), no reason" + RedirectHandshake("wss://example.com"), + "redirect to wss://example.com", ), ( - ConnectionClosed(4000, None), - "WebSocket connection is closed: code = 4000 " - "(private use), no reason" + InvalidState("WebSocket connection isn't established yet"), + "WebSocket connection isn't established yet", ), ( InvalidURI("|"), @@ -131,7 +142,7 @@ def test_str(self): "payload length exceeds limit: 2 > 1 bytes", ), ( - WebSocketProtocolError("invalid opcode: 7"), + ProtocolError("invalid opcode: 7"), "invalid opcode: 7", ), # fmt: on diff --git a/tests/test_framing.py b/tests/test_framing.py index 430faf6e1..9e6f1871d 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -3,7 +3,7 @@ import unittest import unittest.mock -from websockets.exceptions import PayloadTooBig, WebSocketProtocolError +from websockets.exceptions import PayloadTooBig, ProtocolError from websockets.framing import * from .utils import AsyncioTestCase @@ -112,7 +112,7 @@ def test_payload_too_big(self): def test_bad_reserved_bits(self): for encoded in [b"\xc0\x00", b"\xa0\x00", b"\x90\x00"]: with self.subTest(encoded=encoded): - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): self.decode(encoded) def test_good_opcode(self): @@ -125,26 +125,26 @@ def test_bad_opcode(self): for opcode in list(range(0x03, 0x08)) + list(range(0x0B, 0x10)): encoded = bytes([0x80 | opcode, 0]) with self.subTest(encoded=encoded): - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): self.decode(encoded) def test_mask_flag(self): # Mask flag correctly set. self.decode(b"\x80\x80\x00\x00\x00\x00", mask=True) # Mask flag incorrectly unset. - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): self.decode(b"\x80\x80\x00\x00\x00\x00") # Mask flag correctly unset. self.decode(b"\x80\x00") # Mask flag incorrectly set. - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): self.decode(b"\x80\x00", mask=True) def test_control_frame_max_length(self): # At maximum allowed length. self.decode(b"\x88\x7e\x00\x7d" + 125 * b"a") # Above maximum allowed length. - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): self.decode(b"\x88\x7e\x00\x7e" + 126 * b"a") def test_prepare_data_str(self): @@ -201,7 +201,7 @@ def test_fragmented_control_frame(self): # Fin bit correctly set. self.decode(b"\x88\x00") # Fin bit incorrectly unset. - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): self.decode(b"\x08\x00") def test_parse_close_and_serialize_close(self): @@ -212,15 +212,15 @@ def test_parse_close_empty(self): self.assertEqual(parse_close(b""), (1005, "")) def test_parse_close_errors(self): - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): parse_close(b"\x03") - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): parse_close(b"\x03\xe7") with self.assertRaises(UnicodeDecodeError): parse_close(b"\x03\xe8\xff\xff") def test_serialize_close_errors(self): - with self.assertRaises(WebSocketProtocolError): + with self.assertRaises(ProtocolError): serialize_close(999, "") def test_extensions(self): From c3681322989aab7c49b3bf94082690764f10c0a2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 16:56:06 +0200 Subject: [PATCH 0623/1539] Use a plain dict to store pings. This is possible since Python 3.6 because dict preserves order. Also remove dependency on binascii for converting bytes to hex with bytes.hex() which is available since Python 3.5. Fix #645. --- src/websockets/protocol.py | 42 ++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 42ddf0763..ef935caf5 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -8,7 +8,6 @@ """ import asyncio -import binascii import codecs import collections import enum @@ -22,6 +21,7 @@ AsyncIterator, Awaitable, Deque, + Dict, Iterable, List, Optional, @@ -274,9 +274,7 @@ def __init__( self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None # Mapping of ping IDs to waiters, in chronological order. - self.pings: collections.OrderedDict[ - bytes, asyncio.Future[None] - ] = collections.OrderedDict() + self.pings: Dict[bytes, asyncio.Future[None]] = {} # Task running the data transfer. self.transfer_data_task: asyncio.Task[None] @@ -954,23 +952,29 @@ async def read_data_frame(self, max_size: int) -> Optional[Frame]: elif frame.opcode == OP_PONG: # Acknowledge pings on solicited pongs. if frame.data in self.pings: + logger.debug( + "%s - received solicited pong: %s", + self.side, + frame.data.hex() or "[empty]", + ) # Acknowledge all pings up to the one matching this pong. ping_id = None ping_ids = [] - while ping_id != frame.data: - ping_id, pong_waiter = self.pings.popitem(last=False) + for ping_id, ping in self.pings.items(): ping_ids.append(ping_id) - if not pong_waiter.done(): - pong_waiter.set_result(None) - pong_hex = binascii.hexlify(frame.data).decode() or "[empty]" - logger.debug( - "%s - received solicited pong: %s", self.side, pong_hex - ) + if not ping.done(): + ping.set_result(None) + if ping_id == frame.data: + break + else: # pragma: no cover + assert False, "ping_id is in self.pings" + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] ping_ids = ping_ids[:-1] if ping_ids: pings_hex = ", ".join( - binascii.hexlify(ping_id).decode() or "[empty]" - for ping_id in ping_ids + ping_id.hex() or "[empty]" for ping_id in ping_ids ) plural = "s" if len(ping_ids) > 1 else "" logger.debug( @@ -980,9 +984,10 @@ async def read_data_frame(self, max_size: int) -> Optional[Frame]: pings_hex, ) else: - pong_hex = binascii.hexlify(frame.data).decode() or "[empty]" logger.debug( - "%s - received unsolicited pong: %s", self.side, pong_hex + "%s - received unsolicited pong: %s", + self.side, + frame.data.hex() or "[empty]", ) # 5.6. Data Frames @@ -1259,10 +1264,7 @@ def abort_pings(self) -> None: ping.cancel() if self.pings: - pings_hex = ", ".join( - binascii.hexlify(ping_id).decode() or "[empty]" - for ping_id in self.pings - ) + pings_hex = ", ".join(ping_id.hex() or "[empty]" for ping_id in self.pings) plural = "s" if len(self.pings) > 1 else "" logger.debug( "%s - aborted pending ping%s: %s", self.side, plural, pings_hex From 31ba3fad91a6add437a02a369d282683d5333840 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 17:28:53 +0200 Subject: [PATCH 0624/1539] Add a task to build the C extension. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Declare all tasks as phony — this is only really necessary for build, but it can't hurt. --- Makefile | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Makefile b/Makefile index d389623a7..c06de468e 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +.PHONY: default style test coverage build clean + export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src @@ -16,6 +18,9 @@ coverage: python -m coverage html python -m coverage report --show-missing --fail-under=100 +build: + python setup.py build_ext --inplace + clean: find . -name '*.pyc' -o -name '*.so' -delete find . -name __pycache__ -delete From c01ae626a30891b2302b5b2df80296b5345f118a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 17:30:47 +0200 Subject: [PATCH 0625/1539] Run all quality checks by default with make. --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index c06de468e..d9e16fefe 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,8 @@ export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src +default: coverage style + style: isort --recursive src tests black src tests From d8a3a98bddedb1949d8da3c902fecdf7ce020c50 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 17:26:17 +0200 Subject: [PATCH 0626/1539] Deprecate host, port and secure attrs of protocols. Also factor out logic for testing deprecations. Fix #644. --- docs/api.rst | 6 +-- docs/changelog.rst | 9 ++++ src/websockets/client.py | 6 +-- src/websockets/protocol.py | 42 ++++++++++++----- tests/test_client_server.py | 93 +++++++++++++++++++++---------------- tests/utils.py | 11 +++++ 6 files changed, 108 insertions(+), 59 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 28f41cc40..d265a91c2 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -44,7 +44,7 @@ Server .. automethod:: wait_closed .. autoattribute:: sockets - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) .. automethod:: handshake .. automethod:: process_request @@ -61,7 +61,7 @@ Client .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) :async: - .. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) .. automethod:: handshake @@ -70,7 +70,7 @@ Shared .. automodule:: websockets.protocol - .. autoclass:: WebSocketCommonProtocol(*, host=None, port=None, secure=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) + .. autoclass:: WebSocketCommonProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) .. automethod:: close .. automethod:: wait_closed diff --git a/docs/changelog.rst b/docs/changelog.rst index 12fc57749..cfad4a5b5 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -33,6 +33,15 @@ Changelog If you were setting ``max_queue=0`` to make the queue of incoming messages unbounded, change it to ``max_queue=None``. +.. note:: + + **Version 8.0 deprecates the** ``host`` **,** ``port`` **, and** ``secure`` + **attributes of** :class:`~protocol.WebSocketCommonProtocol`. + + Use :attr:`~protocol.WebSocketCommonProtocol.local_address` in servers and + :attr:`~protocol.WebSocketCommonProtocol.remote_address` in clients + instead of ``host`` and ``port``. + .. note:: **Version 8.0 renames the** ``WebSocketProtocolError`` **exception** diff --git a/src/websockets/client.py b/src/websockets/client.py index 89a624511..4d4a04cb8 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -437,9 +437,6 @@ def __init__( factory = functools.partial( create_protocol, - host=wsuri.host, - port=wsuri.port, - secure=wsuri.secure, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, @@ -448,6 +445,9 @@ def __init__( read_limit=read_limit, write_limit=write_limit, loop=loop, + host=wsuri.host, + port=wsuri.port, + secure=wsuri.secure, legacy_recv=legacy_recv, origin=origin, extensions=extensions, diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index ef935caf5..77dad5e1d 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -92,11 +92,6 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception when the connection is closed with any other code. - When initializing a :class:`WebSocketCommonProtocol`, the ``host``, - ``port``, and ``secure`` parameters are stored as attributes for backwards - compatibility. Consider using :attr:`local_address` on the server side and - :attr:`remote_address` on the client side instead. - Once the connection is open, a `Ping frame`_ is sent every ``ping_interval`` seconds. This serves as a keepalive. It helps keeping the connection open, especially in the presence of proxies with short @@ -185,9 +180,6 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): def __init__( self, *, - host: Optional[str] = None, - port: Optional[int] = None, - secure: Optional[bool] = None, ping_interval: float = 20, ping_timeout: float = 20, close_timeout: Optional[float] = None, @@ -196,6 +188,10 @@ def __init__( read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, + # The following arguments are kept only for backwards compatibility. + host: Optional[str] = None, + port: Optional[int] = None, + secure: Optional[bool] = None, legacy_recv: bool = False, timeout: Optional[float] = None, ) -> None: @@ -208,9 +204,6 @@ def __init__( if close_timeout is None: close_timeout = timeout - self.host = host - self.port = port - self.secure = secure self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout @@ -225,6 +218,9 @@ def __init__( loop = asyncio.get_event_loop() self.loop = loop + self._host = host + self._port = port + self._secure = secure self.legacy_recv = legacy_recv # Configure read buffer limits. The high-water limit is defined by @@ -320,6 +316,23 @@ def connection_open(self) -> None: # Start the task that eventually closes the TCP connection. self.close_connection_task = self.loop.create_task(self.close_connection()) + @property + def host(self) -> Optional[str]: + alternative = "remote_address" if self.is_client else "local_address" + warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) + return self._host + + @property + def port(self) -> Optional[int]: + alternative = "remote_address" if self.is_client else "local_address" + warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) + return self._port + + @property + def secure(self) -> Optional[bool]: + warnings.warn(f"don't use secure", DeprecationWarning) + return self._secure + # Public API @property @@ -1144,7 +1157,12 @@ async def close_connection(self) -> None: # If connection_lost() was called, the TCP connection is closed. # However, if TLS is enabled, the transport still needs closing. # Else asyncio complains: ResourceWarning: unclosed transport. - if self.connection_lost_waiter.done() and not self.secure: + try: + writer_is_closing = self.writer.is_closing # type: ignore + except AttributeError: # pragma: no cover + # Python < 3.7 + writer_is_closing = self.writer.transport.is_closing + if self.connection_lost_waiter.done() and writer_is_closing(): return # Close the TCP connection. Buffers are flushed asynchronously. diff --git a/tests/test_client_server.py b/tests/test_client_server.py index aa4bebdc2..e74ec6bf6 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -51,7 +51,8 @@ async def handler(ws, path): - if path == "/attributes": + if path == "/deprecated_attributes": + await ws.recv() # delay that allows catching warnings await ws.send(repr((ws.host, ws.port, ws.secure))) elif path == "/close_timeout": await ws.send(repr(ws.close_timeout)) @@ -238,7 +239,7 @@ def setUp(self): def server_context(self): return None - def start_server(self, expected_warning=None, **kwargs): + def start_server(self, deprecation_warnings=None, **kwargs): # Disable compression by default in tests. kwargs.setdefault("compression", None) # Disable pings by default in tests. @@ -248,13 +249,8 @@ def start_server(self, expected_warning=None, **kwargs): start_server = serve(handler, "localhost", 0, **kwargs) self.server = self.loop.run_until_complete(start_server) - if expected_warning is None: - self.assertEqual(len(recorded_warnings), 0) - else: - self.assertEqual(len(recorded_warnings), 1) - actual_warning = recorded_warnings[0].message - self.assertEqual(str(actual_warning), expected_warning) - self.assertEqual(type(actual_warning), DeprecationWarning) + expected_warnings = [] if deprecation_warnings is None else deprecation_warnings + self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_redirecting_server( self, status, include_location=True, force_insecure=False @@ -278,7 +274,7 @@ async def process_request(path, headers): self.redirecting_server = self.loop.run_until_complete(start_server) def start_client( - self, resource_name="/", user_info=None, expected_warning=None, **kwargs + self, resource_name="/", user_info=None, deprecation_warnings=None, **kwargs ): # Disable compression by default in tests. kwargs.setdefault("compression", None) @@ -295,13 +291,8 @@ def start_client( start_client = connect(server_uri, **kwargs) self.client = self.loop.run_until_complete(start_client) - if expected_warning is None: - self.assertEqual(len(recorded_warnings), 0) - else: - self.assertEqual(len(recorded_warnings), 1) - actual_warning = recorded_warnings[0].message - self.assertEqual(str(actual_warning), expected_warning) - self.assertEqual(type(actual_warning), DeprecationWarning) + expected_warnings = [] if deprecation_warnings is None else deprecation_warnings + self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): try: @@ -539,10 +530,9 @@ def test_process_request_argument_backwards_compatibility(self): with contextlib.closing(response): self.assertEqual(response.code, 200) - self.assertEqual(len(recorded_warnings), 1) - warning = recorded_warnings[0].message - self.assertEqual(str(warning), "declare process_request as a coroutine") - self.assertEqual(type(warning), DeprecationWarning) + self.assertDeprecationWarnings( + recorded_warnings, ["declare process_request as a coroutine"] + ) class ProcessRequestOKServerProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): @@ -567,10 +557,9 @@ def test_process_request_override_backwards_compatibility(self): with contextlib.closing(response): self.assertEqual(response.code, 200) - self.assertEqual(len(recorded_warnings), 1) - warning = recorded_warnings[0].message - self.assertEqual(str(warning), "declare process_request as a coroutine") - self.assertEqual(type(warning), DeprecationWarning) + self.assertDeprecationWarnings( + recorded_warnings, ["declare process_request as a coroutine"] + ) def select_subprotocol_chat(client_subprotocols, server_subprotocols): return "chat" @@ -599,18 +588,37 @@ def test_select_subprotocol_override(self): self.assertEqual(self.client.subprotocol, "chat") @with_server() - @with_client("/attributes") - def test_protocol_attributes(self): + @with_client("/deprecated_attributes") + def test_protocol_deprecated_attributes(self): # The test could be connecting with IPv6 or IPv4. expected_client_attrs = [ server_socket.getsockname()[:2] + (self.secure,) for server_socket in self.server.sockets ] - client_attrs = (self.client.host, self.client.port, self.client.secure) + with warnings.catch_warnings(record=True) as recorded_warnings: + client_attrs = (self.client.host, self.client.port, self.client.secure) + self.assertDeprecationWarnings( + recorded_warnings, + [ + "use remote_address[0] instead of host", + "use remote_address[1] instead of port", + "don't use secure", + ], + ) self.assertIn(client_attrs, expected_client_attrs) expected_server_attrs = ("localhost", 0, self.secure) - server_attrs = self.loop.run_until_complete(self.client.recv()) + with warnings.catch_warnings(record=True) as recorded_warnings: + self.loop.run_until_complete(self.client.send("")) + server_attrs = self.loop.run_until_complete(self.client.recv()) + self.assertDeprecationWarnings( + recorded_warnings, + [ + "use local_address[0] instead of host", + "use local_address[1] instead of port", + "don't use secure", + ], + ) self.assertEqual(server_attrs, repr(expected_server_attrs)) @with_server() @@ -770,7 +778,7 @@ def test_server_create_protocol_function(self): @with_server( klass=UnauthorizedServerProtocol, - expected_warning="rename klass to create_protocol", + deprecation_warnings=["rename klass to create_protocol"], ) def test_server_klass_backwards_compatibility(self): self.assert_client_raises_code(401) @@ -778,7 +786,7 @@ def test_server_klass_backwards_compatibility(self): @with_server( create_protocol=ForbiddenServerProtocol, klass=UnauthorizedServerProtocol, - expected_warning="rename klass to create_protocol", + deprecation_warnings=["rename klass to create_protocol"], ) def test_server_create_protocol_over_klass(self): self.assert_client_raises_code(403) @@ -800,7 +808,7 @@ def test_client_create_protocol_function(self): @with_client( "/path", klass=FooClientProtocol, - expected_warning="rename klass to create_protocol", + deprecation_warnings=["rename klass to create_protocol"], ) def test_client_klass(self): self.assertIsInstance(self.client, FooClientProtocol) @@ -810,7 +818,7 @@ def test_client_klass(self): "/path", create_protocol=BarClientProtocol, klass=FooClientProtocol, - expected_warning="rename klass to create_protocol", + deprecation_warnings=["rename klass to create_protocol"], ) def test_client_create_protocol_over_klass(self): self.assertIsInstance(self.client, BarClientProtocol) @@ -821,14 +829,16 @@ def test_server_close_timeout(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 7) - @with_server(timeout=6, expected_warning="rename timeout to close_timeout") + @with_server(timeout=6, deprecation_warnings=["rename timeout to close_timeout"]) @with_client("/close_timeout") def test_server_timeout_backwards_compatibility(self): close_timeout = self.loop.run_until_complete(self.client.recv()) self.assertEqual(eval(close_timeout), 6) @with_server( - close_timeout=7, timeout=6, expected_warning="rename timeout to close_timeout" + close_timeout=7, + timeout=6, + deprecation_warnings=["rename timeout to close_timeout"], ) @with_client("/close_timeout") def test_server_close_timeout_over_timeout(self): @@ -842,7 +852,9 @@ def test_client_close_timeout(self): @with_server() @with_client( - "/close_timeout", timeout=6, expected_warning="rename timeout to close_timeout" + "/close_timeout", + timeout=6, + deprecation_warnings=["rename timeout to close_timeout"], ) def test_client_timeout_backwards_compatibility(self): self.assertEqual(self.client.close_timeout, 6) @@ -852,7 +864,7 @@ def test_client_timeout_backwards_compatibility(self): "/close_timeout", close_timeout=7, timeout=6, - expected_warning="rename timeout to close_timeout", + deprecation_warnings=["rename timeout to close_timeout"], ) def test_client_close_timeout_over_timeout(self): self.assertEqual(self.client.close_timeout, 7) @@ -1352,10 +1364,9 @@ def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): ) client = self.loop.run_until_complete(connect(get_server_uri(server))) - self.assertEqual(len(recorded_warnings), 1) - warning = recorded_warnings[0].message - self.assertEqual(str(warning), "use None instead of '' in origins") - self.assertEqual(type(warning), DeprecationWarning) + self.assertDeprecationWarnings( + recorded_warnings, ["use None instead of '' in origins"] + ) self.loop.run_until_complete(client.send("Hello!")) self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") diff --git a/tests/utils.py b/tests/utils.py index 2c067f8e6..983a91edf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -69,6 +69,17 @@ def assertNoLogs(self, logger="websockets", level=logging.ERROR): level_name = logging.getLevelName(level) self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): + """ + Check recorded deprecation warnings match a list of expected messages. + + """ + self.assertEqual(len(recorded_warnings), len(expected_warnings)) + for recorded, expected in zip(recorded_warnings, expected_warnings): + actual = recorded.message + self.assertEqual(str(actual), expected) + self.assertEqual(type(actual), DeprecationWarning) + # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. From 8d907e029996a5563ceb5b65e02406c442674733 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 19:15:41 +0200 Subject: [PATCH 0627/1539] Explain close code 1006. Also remove redundant message from ConnectionClosed exception: it's pretty clear that websockets.exceptions.ConnectionClosed[Error|OK] means that a WebSocket connection is closed. This makes the close code more prominent and increases the chances that users will find the explanation in the FAQ. Fix #579. Fix #624. --- docs/cheatsheet.rst | 2 ++ docs/faq.rst | 54 ++++++++++++++++++++++++++++++++++-- src/websockets/exceptions.py | 4 +-- tests/test_exceptions.py | 18 ++++-------- 4 files changed, 61 insertions(+), 17 deletions(-) diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 15a731084..f897326a6 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -60,6 +60,8 @@ Client * If you aren't using :func:`~client.connect` as a context manager, call :meth:`~protocol.WebSocketCommonProtocol.close` to terminate the connection. +.. _debugging: + Debugging --------- diff --git a/docs/faq.rst b/docs/faq.rst index 3dfdb5bcd..cea3f5358 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -132,8 +132,8 @@ Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. :func:`connect` accepts the same arguments as :meth:`~asyncio.loop.create_connection`. -Architecture ------------- +Both sides +---------- How do I do two things in parallel? How do I integrate with another coroutine? .............................................................................. @@ -154,6 +154,56 @@ websockets doesn't have built-in publish / subscribe for these use cases. Depending on the scale of your service, a simple in-memory implementation may do the job or you may need an external publish / subscribe component. +What does ``ConnectionClosedError: code = 1006`` mean? +...................................................... + +If you're seeing this traceback in the logs of a server: + +.. code-block:: pytb + + Error in connection handler + Traceback (most recent call last): + ... + asyncio.streams.IncompleteReadError: 0 bytes read on a total of 2 expected bytes + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: code = 1006 (connection closed abnormally [internal]), no reason + +or if a client crashes with this traceback: + +.. code-block:: pytb + + Traceback (most recent call last): + ... + ConnectionResetError: [Errno 54] Connection reset by peer + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: code = 1006 (connection closed abnormally [internal]), no reason + +it means that the TCP connection was lost. As a consequence, the WebSocket +connection was closed without receiving a close frame, which is abnormal. + +You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it +from being logged. + +There are several reasons why long-lived connections may be lost: + +* End-user devices tend to lose network connectivity often and unpredictably + because they can move out of wireless network coverage, get unplugged from + a wired network, enter airplane mode, be put to sleep, etc. +* HTTP load balancers or proxies that aren't configured for long-lived + connections may terminate connections after a short amount of time, usually + 30 seconds. + +If you're facing a reproducible issue, :ref:`enable debug logs ` to +see when and how connections are closed. + Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f03ab72f2..558bdec24 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -117,9 +117,7 @@ class ConnectionClosed(WebSocketException): def __init__(self, code: int, reason: str) -> None: self.code = code self.reason = reason - message = "WebSocket connection is closed: " - message += format_close(code, reason) - super().__init__(message) + super().__init__(format_close(code, reason)) class ConnectionClosedError(ConnectionClosed): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 72b1076ab..7ad5ad833 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -14,33 +14,27 @@ def test_str(self): ), ( ConnectionClosed(1000, ""), - "WebSocket connection is closed: code = 1000 " - "(OK), no reason", + "code = 1000 (OK), no reason", ), ( ConnectionClosed(1006, None), - "WebSocket connection is closed: code = 1006 " - "(connection closed abnormally [internal]), no reason" + "code = 1006 (connection closed abnormally [internal]), no reason" ), ( ConnectionClosed(3000, None), - "WebSocket connection is closed: code = 3000 " - "(registered), no reason" + "code = 3000 (registered), no reason" ), ( ConnectionClosed(4000, None), - "WebSocket connection is closed: code = 4000 " - "(private use), no reason" + "code = 4000 (private use), no reason" ), ( ConnectionClosedError(1016, None), - "WebSocket connection is closed: code = 1016 " - "(unknown), no reason" + "code = 1016 (unknown), no reason" ), ( ConnectionClosedOK(1001, "bye"), - "WebSocket connection is closed: code = 1001 " - "(going away), reason = bye", + "code = 1001 (going away), reason = bye", ), ( InvalidHandshake("invalid request"), From a28fed3694f45fbfbc367afa9c51beb2f296a82d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 19:29:56 +0200 Subject: [PATCH 0628/1539] Proof-read changelog. --- docs/changelog.rst | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index cfad4a5b5..59914b8ba 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -23,8 +23,8 @@ Changelog :meth:`~protocol.WebSocketServerProtocol.process_request` in a subclass, define it with ``async def`` instead of ``def``. - For backwards compatibility, functions are still supported. However, in - some inheritance scenarios, mixing functions and coroutines won't work. + For backwards compatibility, functions are still mostly supported, but + mixing functions and coroutines won't work in some inheritance scenarios. .. note:: @@ -45,9 +45,9 @@ Changelog .. note:: **Version 8.0 renames the** ``WebSocketProtocolError`` **exception** - :exc:`ProtocolError` **.** + to :exc:`ProtocolError` **.** - For backwards compatibility, a ``WebSocketProtocolError`` is provided. + A ``WebSocketProtocolError`` alias provides backwards compatibility. .. note:: @@ -66,7 +66,7 @@ Also: :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* Added :func:`~auth.basic_auth_protocol_factory` to provide HTTP Basic Auth +* Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth on the server side. * :func:`~client.connect` handles redirects from the server during the @@ -80,14 +80,13 @@ Also: iterators in :meth:`~protocol.WebSocketCommonProtocol.send`. * Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` - exceptions in keepalive ping task. + exceptions in keepalive ping task. If you were using ``ping_timeout=None`` + as a workaround, you can remove it. - If you were using ``ping_timeout=None`` as a workaround, you can remove it. +* Changed :meth:`WebSocketServer.close() ` to + perform a proper closing handshake instead of failing the connection. -* Changed :meth:`~server.WebSocketServer.close` to perform a proper closing - handshake instead of failing the connection. - -* Avoided a crash of a ``extra_headers`` callable returns ``None``. +* Avoided a crash when a ``extra_headers`` callable returns ``None``. * Improved error messages when HTTP parsing fails. From faa4c55c17c11fbb0441985613cf519eccb51c2c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 19:31:04 +0200 Subject: [PATCH 0629/1539] Bump version number. --- docs/changelog.rst | 5 ++++- docs/conf.py | 4 ++-- src/websockets/version.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 59914b8ba..e81d80d85 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,14 @@ Changelog .. currentmodule:: websockets -8.0 +8.1 ... *In development* +8.0 +... + .. warning:: **Version 8.0 drops compatibility with Python 3.4 and 3.5.** diff --git a/docs/conf.py b/docs/conf.py index e5e6ab15f..1241a49fb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,9 +59,9 @@ # built documents. # # The short X.Y version. -version = '7.0' +version = '8.0' # The full version, including alpha/beta/rc tags. -release = '7.0' +release = '8.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/src/websockets/version.py b/src/websockets/version.py index 96b948d8a..1aa0a5ebc 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "7.0" +version = "8.0" From 02af45351df41603c2767b004b29ab158337a667 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 Jul 2019 20:06:27 +0200 Subject: [PATCH 0630/1539] PyPI disables the "raw" directive. --- README.rst | 4 +--- setup.py | 9 +++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 7395d803a..e2ea6df69 100644 --- a/README.rst +++ b/README.rst @@ -87,9 +87,7 @@ Does that look good?

Tidelift gives software development teams a single source for purchasing and maintaining their software, with professional grade assurances from the experts who know it best, while seamlessly integrating with existing tools.

Get supported websockets with the Tidelift Subscription


- -(If you contribute to ``websockets`` and would like to become an official -support provider, `let me know `_.) +

(If you contribute to ``websockets`` and would like to become an official support provider, let me know.)

Why should I use ``websockets``? -------------------------------- diff --git a/setup.py b/setup.py index 3c87b2339..1ea735cb6 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import pathlib +import re import sys import setuptools @@ -10,6 +11,14 @@ long_description = (root_dir / 'README.rst').read_text(encoding='utf-8') +# PyPI disables the "raw" directive. +long_description = re.sub( + r"^\.\. raw:: html.*?^(?=\w)", + "", + long_description, + flags=re.DOTALL | re.MULTILINE, +) + exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) py_version = sys.version_info[:2] From ec50f6b2b965f9ffa48a5760ed72376796728ede Mon Sep 17 00:00:00 2001 From: Manu NALEPA Date: Fri, 12 Jul 2019 17:29:50 +0200 Subject: [PATCH 0631/1539] __main.py__: Fix typo --- src/websockets/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 57d2a823b..bccb8aa52 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -48,7 +48,7 @@ def exit_from_event_loop_thread( loop.stop() if not stop.done(): # When exiting the thread that runs the event loop, raise - # KeyboardInterrupt in the main thead to exit the program. + # KeyboardInterrupt in the main thread to exit the program. try: ctrl_c = signal.CTRL_C_EVENT # Windows except AttributeError: From c1af276ab1e9fb1c323fe232e6ed768a912b61b8 Mon Sep 17 00:00:00 2001 From: Harmon Date: Mon, 15 Jul 2019 15:13:27 -0500 Subject: [PATCH 0632/1539] Re-expose WebSocketProtocolError --- src/websockets/exceptions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 558bdec24..9873a1717 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -57,6 +57,7 @@ "InvalidURI", "PayloadTooBig", "ProtocolError", + "WebSocketProtocolError", ] From f1f5d7d37927b020dd39c37bc75415c79b0d5b59 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2019 07:37:12 +0200 Subject: [PATCH 0633/1539] Add changelog for #649.. --- docs/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index e81d80d85..8e862c5ec 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,12 @@ Changelog *In development* +8.0.1 +..... + +* Restored the ability to import ``WebSocketProtocolError`` from + ``websockets``. + 8.0 ... From 5d059da31f0d967ddf300a15f03fd00a92c8712f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2019 07:38:10 +0200 Subject: [PATCH 0634/1539] Bump version number. --- docs/conf.py | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 1241a49fb..560140f9b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,7 +61,7 @@ # The short X.Y version. version = '8.0' # The full version, including alpha/beta/rc tags. -release = '8.0' +release = '8.0.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/src/websockets/version.py b/src/websockets/version.py index 1aa0a5ebc..add721549 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "8.0" +version = "8.0.1" From 7e0a651a06963c0a30f6c4888a30a9e7d3a7ad68 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 31 Jul 2019 20:40:51 +0200 Subject: [PATCH 0635/1539] Remove incorrect assertion. create_server must receive either host + port or sock. It does its own checks anyway; we don't need to replicate them. Fix #659. --- docs/changelog.rst | 6 ++++++ src/websockets/server.py | 3 --- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 8e862c5ec..6ed63b654 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,12 @@ Changelog *In development* +8.0.2 +..... + +* Restored the ability to pass a socket with the ``sock`` parameter of + :func:`~server.serve`. + 8.0.1 ..... diff --git a/src/websockets/server.py b/src/websockets/server.py index 446f1db7f..b220a1b88 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -928,9 +928,6 @@ def __init__( ) if path is None: - # serve(..., host, port) must specify host and port parameters. - # host can be None to listen on all interfaces; port cannot be None. - assert port is not None create_server = functools.partial( loop.create_server, factory, host, port, **kwargs ) From fac562ddd5e6004949acd504c48fc91f2558593f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 31 Jul 2019 21:11:33 +0200 Subject: [PATCH 0636/1539] Remove incorrect assertion. Fix #646. See the ticket for details. --- docs/changelog.rst | 2 ++ src/websockets/protocol.py | 5 ++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6ed63b654..87b2e4380 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -14,6 +14,8 @@ Changelog * Restored the ability to pass a socket with the ``sock`` parameter of :func:`~server.serve`. +* Removed an incorrect assertion when a connection drops. + 8.0.1 ..... diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 77dad5e1d..e25f4aaee 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -479,7 +479,6 @@ async def recv(self) -> Data: if self.legacy_recv: return None # type: ignore else: - assert self.state in [State.CLOSING, State.CLOSED] # Wait until the connection is closed to raise # ConnectionClosed with the correct code and reason. await self.ensure_open() @@ -760,8 +759,8 @@ async def ensure_open(self) -> None: # Handle cases from most common to least common for performance. if self.state is State.OPEN: # If self.transfer_data_task exited without a closing handshake, - # self.close_connection_task may be closing it, going straight - # from OPEN to CLOSED. + # self.close_connection_task may be closing the connection, going + # straight from OPEN to CLOSED. if self.transfer_data_task.done(): await asyncio.shield(self.close_connection_task) raise self.connection_closed_exc() From e8deaf9a93302c291eb8c05456a5bf90e94d7b63 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 31 Jul 2019 21:14:25 +0200 Subject: [PATCH 0637/1539] Bump version number. --- docs/conf.py | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 560140f9b..617989cb1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,7 +61,7 @@ # The short X.Y version. version = '8.0' # The full version, including alpha/beta/rc tags. -release = '8.0.1' +release = '8.0.2' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/src/websockets/version.py b/src/websockets/version.py index add721549..cd8898041 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "8.0.1" +version = "8.0.2" From 3f444b1629237a6795c30d55f0775f4e75728bf3 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Mon, 12 Aug 2019 21:58:08 +0000 Subject: [PATCH 0638/1539] fix: permit None in type annotations Fix type annotations for four parameters which are documented to accept `None`. --- src/websockets/protocol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index e25f4aaee..0b48d0dca 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -180,11 +180,11 @@ class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): def __init__( self, *, - ping_interval: float = 20, - ping_timeout: float = 20, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: int = 2 ** 20, - max_queue: int = 2 ** 5, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, From 05d256da094759200016f123d787d315d86fc5c2 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Tue, 13 Aug 2019 09:47:54 +0000 Subject: [PATCH 0639/1539] fix: downstream type annotations/assertions --- src/websockets/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0b48d0dca..1f0edcce2 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -895,6 +895,7 @@ def append(frame: Frame) -> None: def append(frame: Frame) -> None: nonlocal chunks, max_size chunks.append(decoder.decode(frame.data, frame.fin)) + assert isinstance(max_size, int) max_size -= len(frame.data) else: @@ -909,6 +910,7 @@ def append(frame: Frame) -> None: def append(frame: Frame) -> None: nonlocal chunks, max_size chunks.append(frame.data) + assert isinstance(max_size, int) max_size -= len(frame.data) append(frame) @@ -924,7 +926,7 @@ def append(frame: Frame) -> None: # mypy cannot figure out that chunks have the proper type. return ("" if text else b"").join(chunks) # type: ignore - async def read_data_frame(self, max_size: int) -> Optional[Frame]: + async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: """ Read a single data frame from the connection. @@ -1006,7 +1008,7 @@ async def read_data_frame(self, max_size: int) -> Optional[Frame]: else: return frame - async def read_frame(self, max_size: int) -> Frame: + async def read_frame(self, max_size: Optional[int]) -> Frame: """ Read a single frame from the connection. From 4ccc512861e1b56d9152b93e133f2ec9c6118c21 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 21 Aug 2019 15:18:43 +0200 Subject: [PATCH 0640/1539] Fix typo in docstring. --- src/websockets/http.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/http.py b/src/websockets/http.py index e78a149ed..ba6d274bf 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -69,7 +69,7 @@ def d(value: bytes) -> str: async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]: """ - Read an HTTP/1.1 GET request and returns ``(path, headers)``. + Read an HTTP/1.1 GET request and return ``(path, headers)``. ``path`` isn't URL-decoded or validated in any way. @@ -115,7 +115,7 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]: async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, "Headers"]: """ - Read an HTTP/1.1 response and returns ``(status_code, reason, headers)``. + Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. ``reason`` and ``headers`` are expected to contain only ASCII characters. Other characters are represented with surrogate escapes. From a693ec8cfcf206dfe7b917711e20600cdceb802e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 21 Aug 2019 15:20:28 +0200 Subject: [PATCH 0641/1539] Update description of default TLS contexts. --- src/websockets/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 4d4a04cb8..c1fdf88a0 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -325,8 +325,8 @@ class Connect: For example, you can set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to - a ``wss://`` URI, if this argument isn't provided explicitly, it's set to - ``True``, which means Python's default :class:`~ssl.SSLContext` is used. + a ``wss://`` URI, if this argument isn't provided explicitly, + :func:`ssl.create_default_context` is called to create a context. You can connect to a different host and port from those found in ``uri`` by setting ``host`` and ``port`` keyword arguments. This only changes the From 46ddc64b3ab02f38579880a812b9c04da6d89ae1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 21 Aug 2019 15:32:14 +0200 Subject: [PATCH 0642/1539] Add a new type for extension names. --- src/websockets/extensions/base.py | 8 ++++---- src/websockets/extensions/permessage_deflate.py | 8 ++++---- src/websockets/headers.py | 10 ++++++---- src/websockets/typing.py | 6 +++++- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 7d46687c6..aa52a7adb 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -11,7 +11,7 @@ from typing import List, Optional, Sequence, Tuple from ..framing import Frame -from ..typing import ExtensionParameter +from ..typing import ExtensionName, ExtensionParameter __all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] @@ -24,7 +24,7 @@ class Extension: """ @property - def name(self) -> str: + def name(self) -> ExtensionName: """ Extension identifier. @@ -55,7 +55,7 @@ class ClientExtensionFactory: """ @property - def name(self) -> str: + def name(self) -> ExtensionName: """ Extension identifier. @@ -92,7 +92,7 @@ class ServerExtensionFactory: """ @property - def name(self) -> str: + def name(self) -> ExtensionName: """ Extension identifier. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index a41fd56ca..e38d9edab 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -15,7 +15,7 @@ PayloadTooBig, ) from ..framing import CTRL_OPCODES, OP_CONT, Frame -from ..typing import ExtensionParameter +from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -36,7 +36,7 @@ class PerMessageDeflate(Extension): """ - name = "permessage-deflate" + name = ExtensionName("permessage-deflate") def __init__( self, @@ -274,7 +274,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): """ - name = "permessage-deflate" + name = ExtensionName("permessage-deflate") def __init__( self, @@ -445,7 +445,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): """ - name = "permessage-deflate" + name = ExtensionName("permessage-deflate") def __init__( self, diff --git a/src/websockets/headers.py b/src/websockets/headers.py index ac850654e..f33c94c04 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -13,7 +13,7 @@ from typing import Callable, List, NewType, Optional, Sequence, Tuple, TypeVar, cast from .exceptions import InvalidHeaderFormat, InvalidHeaderValue -from .typing import ExtensionHeader, ExtensionParameter, Subprotocol +from .typing import ExtensionHeader, ExtensionName, ExtensionParameter, Subprotocol __all__ = [ @@ -313,7 +313,7 @@ def parse_extension_item( pos = parse_OWS(header, pos + 1) parameter, pos = parse_extension_item_param(header, pos, header_name) parameters.append(parameter) - return (name, parameters), pos + return (cast(ExtensionName, name), parameters), pos def parse_extension(header: str) -> List[ExtensionHeader]: @@ -344,7 +344,9 @@ def parse_extension(header: str) -> List[ExtensionHeader]: parse_extension_list = parse_extension # alias for backwards compatibility -def build_extension_item(name: str, parameters: List[ExtensionParameter]) -> str: +def build_extension_item( + name: ExtensionName, parameters: List[ExtensionParameter] +) -> str: """ Build an extension definition. @@ -352,7 +354,7 @@ def build_extension_item(name: str, parameters: List[ExtensionParameter]) -> str """ return "; ".join( - [name] + [cast(str, name)] + [ # Quoted strings aren't necessary because values are always tokens. name if value is None else f"{name}={value}" diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 3847701b2..4a60f93f6 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -23,6 +23,10 @@ Origin.__doc__ = """Value of a Origin header""" +ExtensionName = NewType("ExtensionName", str) +ExtensionName.__doc__ = """Name of a WebSocket extension""" + + ExtensionParameter = Tuple[str, Optional[str]] ExtensionParameter__doc__ = """Parameter of a WebSocket extension""" @@ -32,7 +36,7 @@ pass -ExtensionHeader = Tuple[str, List[ExtensionParameter]] +ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] ExtensionHeader__doc__ = """Item parsed in a Sec-WebSocket-Extensions header""" try: From a181964557eb94a17f7162a51810a8480bc1c896 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2019 12:12:27 +0200 Subject: [PATCH 0643/1539] Build docs with Python 3.7. --- .readthedocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index e5e224afd..109affab4 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -2,6 +2,6 @@ build: image: latest python: - version: 3.6 + version: 3.7 requirements_file: docs/requirements.txt From c6ee4a4111b5d17d5a63dd33e941f2b0d97837b4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2019 12:10:44 +0200 Subject: [PATCH 0644/1539] =?UTF-8?q?Require=20Python=20=E2=89=A5=203.6.1.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There've been multiple regressions where websockets stops working with Python 3.6 but works fine with Python 3.6.1 or higher. I don't have a good way to detect such regressions. I don't think it's a good practice to run anything but the latest 3.6.x (or 3.7.x, etc.) anyway. Instead of fighting a useless and losing battle, I'm moving the minimum requirement to 3.6.1. Strictly speaking, this isn't a backwards incompatible change. The incompatibility with Python 3.6 appeared in websockets 8.0, which was the first release that included 94945fec. Sure, I'm accepting the backwards incompatibility instead of fixing it... Judge me if you'd like, or support websockets on Tidelift — if it becomes profitable, I'll have an incentive to provide better support for older Python versions. Refs #655, #664, #667. --- README.rst | 2 +- docs/intro.rst | 2 +- setup.py | 8 +++----- src/websockets/framing.py | 30 +++++++----------------------- src/websockets/uri.py | 28 +++++++++++++--------------- 5 files changed, 25 insertions(+), 45 deletions(-) diff --git a/README.rst b/README.rst index e2ea6df69..5dc9a745d 100644 --- a/README.rst +++ b/README.rst @@ -128,7 +128,7 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. * If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which - only works on Python 3. ``websockets`` requires Python ≥ 3.6. + only works on Python 3. ``websockets`` requires Python ≥ 3.6.1. What else? ---------- diff --git a/docs/intro.rst b/docs/intro.rst index 14ba1b38a..8be700239 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -``websockets`` requires Python ≥ 3.6. +``websockets`` requires Python ≥ 3.6.1. You should use the latest version of Python if possible. If you're using an older version, be aware that for each minor version (3.x), only the latest diff --git a/setup.py b/setup.py index 1ea735cb6..c76430104 100644 --- a/setup.py +++ b/setup.py @@ -21,10 +21,8 @@ exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) -py_version = sys.version_info[:2] - -if py_version < (3, 6): - raise Exception("websockets requires Python >= 3.6.") +if sys.version_info[:3] < (3, 6, 1): + raise Exception("websockets requires Python >= 3.6.1.") packages = ['websockets', 'websockets/extensions'] @@ -62,6 +60,6 @@ ext_modules=ext_modules, include_package_data=True, zip_safe=False, - python_requires='>=3.6', + python_requires='>=3.6.1', test_loader='unittest:TestLoader', ) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 478a7b05a..81a3185b0 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -49,22 +49,10 @@ EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011] -# Remove FrameData when dropping support for Python < 3.6.1 — the first -# version where NamedTuple supports default values, methods, and docstrings. - # Consider converting to a dataclass when dropping support for Python < 3.7. -class FrameData(NamedTuple): - fin: bool - opcode: int - data: bytes - rsv1: bool - rsv2: bool - rsv3: bool - - -class Frame(FrameData): +class Frame(NamedTuple): """ WebSocket frame. @@ -80,16 +68,12 @@ class Frame(FrameData): """ - def __new__( - cls, - fin: bool, - opcode: int, - data: bytes, - rsv1: bool = False, - rsv2: bool = False, - rsv3: bool = False, - ) -> "Frame": - return FrameData.__new__(cls, fin, opcode, data, rsv1, rsv2, rsv3) + fin: bool + opcode: int + data: bytes + rsv1: bool = False + rsv2: bool = False + rsv3: bool = False @classmethod async def read( diff --git a/src/websockets/uri.py b/src/websockets/uri.py index cbb56524b..f5bbafa96 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -20,6 +20,19 @@ class WebSocketURI(NamedTuple): + """ + WebSocket URI. + + :param bool secure: secure flag + :param str host: lower-case host + :param int port: port, always set even if it's the default + :param str resource_name: path and optional query + :param str user_info: ``(username, password)`` tuple when the URI contains + `User Information`_, else ``None``. + + .. _User Information: https://tools.ietf.org/html/rfc3986#section-3.2.1 + """ + secure: bool host: str port: int @@ -27,21 +40,6 @@ class WebSocketURI(NamedTuple): user_info: Optional[Tuple[str, str]] -# Declare the docstring normally when dropping support for Python < 3.6.1. - -WebSocketURI.__doc__ = """ -WebSocket URI. - -:param bool secure: secure flag -:param str host: lower-case host -:param int port: port, always set even if it's the default -:param str resource_name: path and optional query -:param str user_info: ``(username, password)`` tuple when the URI contains - `User Information`_, else ``None``. - -.. _User Information: https://tools.ietf.org/html/rfc3986#section-3.2.1 -""" - # Work around https://bugs.python.org/issue19931 WebSocketURI.secure.__doc__ = "" From d72322764ee6a53fdd3a8a13a1a9bf324f7f844b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2019 15:19:20 +0200 Subject: [PATCH 0645/1539] Clarify why we leave SIGINT alone. Ref #658. --- docs/deployment.rst | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/deployment.rst b/docs/deployment.rst index 797284f3d..5b05afff1 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -32,11 +32,16 @@ with the object returned by :func:`~server.serve`: On Unix systems, shutdown is usually triggered by sending a signal. -Here's a full example (Unix-only): +Here's a full example for handling SIGTERM on Unix: .. literalinclude:: ../example/shutdown.py :emphasize-lines: 13,17-19 +This example is easily adapted to handle other signals. If you override the +default handler for SIGINT, which raises :exc:`KeyboardInterrupt`, be aware +that you won't be able to interrupt a program with Ctrl-C anymore when it's +stuck in a loop. + It's more difficult to achieve the same effect on Windows. Some third-party projects try to help with this problem. From 8800c0cb250897feda7c6e0db2767ff67bd480a2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:41:21 +0200 Subject: [PATCH 0646/1539] Copy FlowControlMixin and StreamReaderProtocol. This is the official recommendation of Python core devs. The code is taken from the current 3.7 branch. --- src/websockets/protocol.py | 127 ++++++++++++++++++++++++++++++++++++- 1 file changed, 126 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1f0edcce2..2f74cd23b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -61,7 +61,132 @@ class State(enum.IntEnum): # between the check and the assignment. -class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): +class FlowControlMixin(asyncio.Protocol): + """Reusable flow control logic for StreamWriter.drain(). + This implements the protocol methods pause_writing(), + resume_writing() and connection_lost(). If the subclass overrides + these it must call the super methods. + StreamWriter.drain() must wait for _drain_helper() coroutine. + """ + + def __init__(self, loop=None): + if loop is None: + self._loop = asyncio.get_event_loop() + else: + self._loop = loop + self._paused = False + self._drain_waiter = None + self._connection_lost = False + + def pause_writing(self): + assert not self._paused + self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses writing", self) + + def resume_writing(self): + assert self._paused + self._paused = False + if self._loop.get_debug(): + logger.debug("%r resumes writing", self) + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def connection_lost(self, exc): + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + async def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self._loop.create_future() + self._drain_waiter = waiter + await waiter + + +class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): + """Helper class to adapt between Protocol and StreamReader. + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader, client_connected_cb=None, loop=None): + super().__init__(loop=loop) + self._stream_reader = stream_reader + self._stream_writer = None + self._client_connected_cb = client_connected_cb + self._over_ssl = False + self._closed = self._loop.create_future() + + def connection_made(self, transport): + self._stream_reader.set_transport(transport) + self._over_ssl = transport.get_extra_info("sslcontext") is not None + if self._client_connected_cb is not None: + self._stream_writer = asyncio.StreamWriter( + transport, self, self._stream_reader, self._loop + ) + res = self._client_connected_cb(self._stream_reader, self._stream_writer) + if asyncio.iscoroutine(res): + self._loop.create_task(res) + + def connection_lost(self, exc): + if self._stream_reader is not None: + if exc is None: + self._stream_reader.feed_eof() + else: + self._stream_reader.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) + super().connection_lost(exc) + self._stream_reader = None + self._stream_writer = None + + def data_received(self, data): + self._stream_reader.feed_data(data) + + def eof_received(self): + self._stream_reader.feed_eof() + if self._over_ssl: + # Prevent a warning in SSLProtocol.eof_received: + # "returning true from eof_received() + # has no effect when using ssl" + return False + return True + + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._closed + if closed.done() and not closed.cancelled(): + closed.exception() + + +class WebSocketCommonProtocol(StreamReaderProtocol): """ :class:`~asyncio.Protocol` subclass implementing the data transfer phase. From 7e7f747ca5267755ccb1bef397c3071baad9a2e1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:54:04 +0200 Subject: [PATCH 0647/1539] Remove docstrings and debug logs. --- src/websockets/protocol.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2f74cd23b..d74c81576 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -62,12 +62,6 @@ class State(enum.IntEnum): class FlowControlMixin(asyncio.Protocol): - """Reusable flow control logic for StreamWriter.drain(). - This implements the protocol methods pause_writing(), - resume_writing() and connection_lost(). If the subclass overrides - these it must call the super methods. - StreamWriter.drain() must wait for _drain_helper() coroutine. - """ def __init__(self, loop=None): if loop is None: @@ -81,14 +75,10 @@ def __init__(self, loop=None): def pause_writing(self): assert not self._paused self._paused = True - if self._loop.get_debug(): - logger.debug("%r pauses writing", self) def resume_writing(self): assert self._paused self._paused = False - if self._loop.get_debug(): - logger.debug("%r resumes writing", self) waiter = self._drain_waiter if waiter is not None: @@ -125,12 +115,6 @@ async def _drain_helper(self): class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): - """Helper class to adapt between Protocol and StreamReader. - (This is a helper class instead of making StreamReader itself a - Protocol subclass, because the StreamReader has other potential - uses, and to prevent the user of the StreamReader to accidentally - call inappropriate methods of the protocol.) - """ def __init__(self, stream_reader, client_connected_cb=None, loop=None): super().__init__(loop=loop) From e7282008796ae30d3c3df5715b97f49e35309825 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:54:43 +0200 Subject: [PATCH 0648/1539] Merge FlowControlMixin in StreamReaderProtocol. --- src/websockets/protocol.py | 54 +++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index d74c81576..49d8b4f2d 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -61,9 +61,9 @@ class State(enum.IntEnum): # between the check and the assignment. -class FlowControlMixin(asyncio.Protocol): +class StreamReaderProtocol(asyncio.Protocol): - def __init__(self, loop=None): + def __init__(self, stream_reader, client_connected_cb=None, loop=None): if loop is None: self._loop = asyncio.get_event_loop() else: @@ -72,6 +72,12 @@ def __init__(self, loop=None): self._drain_waiter = None self._connection_lost = False + self._stream_reader = stream_reader + self._stream_writer = None + self._client_connected_cb = client_connected_cb + self._over_ssl = False + self._closed = self._loop.create_future() + def pause_writing(self): assert not self._paused self._paused = True @@ -86,22 +92,6 @@ def resume_writing(self): if not waiter.done(): waiter.set_result(None) - def connection_lost(self, exc): - self._connection_lost = True - # Wake up the writer if currently paused. - if not self._paused: - return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) - async def _drain_helper(self): if self._connection_lost: raise ConnectionResetError("Connection lost") @@ -113,17 +103,6 @@ async def _drain_helper(self): self._drain_waiter = waiter await waiter - -class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): - - def __init__(self, stream_reader, client_connected_cb=None, loop=None): - super().__init__(loop=loop) - self._stream_reader = stream_reader - self._stream_writer = None - self._client_connected_cb = client_connected_cb - self._over_ssl = False - self._closed = self._loop.create_future() - def connection_made(self, transport): self._stream_reader.set_transport(transport) self._over_ssl = transport.get_extra_info("sslcontext") is not None @@ -146,7 +125,22 @@ def connection_lost(self, exc): self._closed.set_result(None) else: self._closed.set_exception(exc) - super().connection_lost(exc) + + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + self._stream_reader = None self._stream_writer = None From 42a436ce1dd37f4388a13d0c1591af7544c8bb1f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:55:53 +0200 Subject: [PATCH 0649/1539] Deduplicate loop and _loop attributes. --- src/websockets/protocol.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 49d8b4f2d..98c23ab1c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -63,11 +63,7 @@ class State(enum.IntEnum): class StreamReaderProtocol(asyncio.Protocol): - def __init__(self, stream_reader, client_connected_cb=None, loop=None): - if loop is None: - self._loop = asyncio.get_event_loop() - else: - self._loop = loop + def __init__(self, stream_reader, client_connected_cb=None): self._paused = False self._drain_waiter = None self._connection_lost = False @@ -76,7 +72,7 @@ def __init__(self, stream_reader, client_connected_cb=None, loop=None): self._stream_writer = None self._client_connected_cb = client_connected_cb self._over_ssl = False - self._closed = self._loop.create_future() + self._closed = self.loop.create_future() def pause_writing(self): assert not self._paused @@ -99,7 +95,7 @@ async def _drain_helper(self): return waiter = self._drain_waiter assert waiter is None or waiter.cancelled() - waiter = self._loop.create_future() + waiter = self.loop.create_future() self._drain_waiter = waiter await waiter @@ -108,11 +104,11 @@ def connection_made(self, transport): self._over_ssl = transport.get_extra_info("sslcontext") is not None if self._client_connected_cb is not None: self._stream_writer = asyncio.StreamWriter( - transport, self, self._stream_reader, self._loop + transport, self, self._stream_reader, self.loop ) res = self._client_connected_cb(self._stream_reader, self._stream_writer) if asyncio.iscoroutine(res): - self._loop.create_task(res) + self.loop.create_task(res) def connection_lost(self, exc): if self._stream_reader is not None: @@ -315,8 +311,6 @@ def __init__( self.read_limit = read_limit self.write_limit = write_limit - # Store a reference to loop to avoid relying on self._loop, a private - # attribute of StreamReaderProtocol, inherited from FlowControlMixin. if loop is None: loop = asyncio.get_event_loop() self.loop = loop @@ -331,7 +325,7 @@ def __init__( # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - super().__init__(stream_reader, self.client_connected, loop) + super().__init__(stream_reader, self.client_connected) self.reader: asyncio.StreamReader self.writer: asyncio.StreamWriter From 5ed6a458b1992e4e00a2a25b0a2c378f22c3e2e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:57:34 +0200 Subject: [PATCH 0650/1539] Remove client_connected callback. --- src/websockets/protocol.py | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 98c23ab1c..bfc354a82 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -63,14 +63,13 @@ class State(enum.IntEnum): class StreamReaderProtocol(asyncio.Protocol): - def __init__(self, stream_reader, client_connected_cb=None): + def __init__(self, stream_reader): self._paused = False self._drain_waiter = None self._connection_lost = False self._stream_reader = stream_reader self._stream_writer = None - self._client_connected_cb = client_connected_cb self._over_ssl = False self._closed = self.loop.create_future() @@ -102,13 +101,11 @@ async def _drain_helper(self): def connection_made(self, transport): self._stream_reader.set_transport(transport) self._over_ssl = transport.get_extra_info("sslcontext") is not None - if self._client_connected_cb is not None: - self._stream_writer = asyncio.StreamWriter( - transport, self, self._stream_reader, self.loop - ) - res = self._client_connected_cb(self._stream_reader, self._stream_writer) - if asyncio.iscoroutine(res): - self.loop.create_task(res) + self._stream_writer = asyncio.StreamWriter( + transport, self, self._stream_reader, self.loop + ) + self.reader = self._stream_reader + self.writer = self._stream_writer def connection_lost(self, exc): if self._stream_reader is not None: @@ -325,7 +322,7 @@ def __init__( # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - super().__init__(stream_reader, self.client_connected) + super().__init__(stream_reader) self.reader: asyncio.StreamReader self.writer: asyncio.StreamWriter @@ -381,20 +378,6 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] - def client_connected( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - """ - Callback when the TCP connection is established. - - Record references to the stream reader and the stream writer to avoid - using private attributes ``_stream_reader`` and ``_stream_writer`` of - :class:`~asyncio.StreamReaderProtocol`. - - """ - self.reader = reader - self.writer = writer - def connection_open(self) -> None: """ Callback when the WebSocket opening handshake completes. From 00ef5c3525442a943ec471f3f9f2edae8163cf7c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 12:00:31 +0200 Subject: [PATCH 0651/1539] Deduplicate reader/writer and _stream_reader/writer attributes. --- src/websockets/protocol.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index bfc354a82..9c61a409b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -62,14 +62,11 @@ class State(enum.IntEnum): class StreamReaderProtocol(asyncio.Protocol): - - def __init__(self, stream_reader): + def __init__(self): self._paused = False self._drain_waiter = None self._connection_lost = False - self._stream_reader = stream_reader - self._stream_writer = None self._over_ssl = False self._closed = self.loop.create_future() @@ -99,20 +96,16 @@ async def _drain_helper(self): await waiter def connection_made(self, transport): - self._stream_reader.set_transport(transport) + self.reader.set_transport(transport) self._over_ssl = transport.get_extra_info("sslcontext") is not None - self._stream_writer = asyncio.StreamWriter( - transport, self, self._stream_reader, self.loop - ) - self.reader = self._stream_reader - self.writer = self._stream_writer + self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) def connection_lost(self, exc): - if self._stream_reader is not None: + if self.reader is not None: if exc is None: - self._stream_reader.feed_eof() + self.reader.feed_eof() else: - self._stream_reader.set_exception(exc) + self.reader.set_exception(exc) if not self._closed.done(): if exc is None: self._closed.set_result(None) @@ -134,14 +127,14 @@ def connection_lost(self, exc): else: waiter.set_exception(exc) - self._stream_reader = None - self._stream_writer = None + del self.reader + del self.writer def data_received(self, data): - self._stream_reader.feed_data(data) + self.reader.feed_data(data) def eof_received(self): - self._stream_reader.feed_eof() + self.reader.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() @@ -321,13 +314,14 @@ def __init__( # ``self.read_limit``. The ``limit`` argument controls the line length # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. - stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - super().__init__(stream_reader) - - self.reader: asyncio.StreamReader + self.reader: asyncio.StreamReader = asyncio.StreamReader( + limit=read_limit // 2, loop=loop + ) self.writer: asyncio.StreamWriter self._drain_lock = asyncio.Lock(loop=loop) + super().__init__() + # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute From 2707b51fec077b88060c9fca4dc2a5a50b55eda5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 12:25:40 +0200 Subject: [PATCH 0652/1539] Merge asyncio.Protocol methods. --- src/websockets/protocol.py | 160 ++++++++++++++++--------------------- 1 file changed, 70 insertions(+), 90 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 9c61a409b..89e3464a6 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -70,20 +70,6 @@ def __init__(self): self._over_ssl = False self._closed = self.loop.create_future() - def pause_writing(self): - assert not self._paused - self._paused = True - - def resume_writing(self): - assert self._paused - self._paused = False - - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None - if not waiter.done(): - waiter.set_result(None) - async def _drain_helper(self): if self._connection_lost: raise ConnectionResetError("Connection lost") @@ -95,53 +81,6 @@ async def _drain_helper(self): self._drain_waiter = waiter await waiter - def connection_made(self, transport): - self.reader.set_transport(transport) - self._over_ssl = transport.get_extra_info("sslcontext") is not None - self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) - - def connection_lost(self, exc): - if self.reader is not None: - if exc is None: - self.reader.feed_eof() - else: - self.reader.set_exception(exc) - if not self._closed.done(): - if exc is None: - self._closed.set_result(None) - else: - self._closed.set_exception(exc) - - self._connection_lost = True - # Wake up the writer if currently paused. - if not self._paused: - return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) - - del self.reader - del self.writer - - def data_received(self, data): - self.reader.feed_data(data) - - def eof_received(self): - self.reader.feed_eof() - if self._over_ssl: - # Prevent a warning in SSLProtocol.eof_received: - # "returning true from eof_received() - # has no effect when using ssl" - return False - return True - def __del__(self): # Prevent reports about unhandled exceptions. # Better than self._closed._log_traceback = False hack @@ -1363,7 +1302,7 @@ def abort_pings(self) -> None: "%s - aborted pending ping%s: %s", self.side, plural, pings_hex ) - # asyncio.StreamReaderProtocol methods + # asyncio.Protocol methods def connection_made(self, transport: asyncio.BaseTransport) -> None: """ @@ -1382,34 +1321,11 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: logger.debug("%s - event = connection_made(%s)", self.side, transport) # mypy thinks transport is a BaseTransport, not a Transport. transport.set_write_buffer_limits(self.write_limit) # type: ignore - super().connection_made(transport) - - def eof_received(self) -> bool: - """ - Close the transport after receiving EOF. - - Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns - ``True`` on non-TLS connections. - - See http://bugs.python.org/issue24539 for more information. - - This is inappropriate for ``websockets`` for at least three reasons: - - 1. The use case is to read data until EOF with self.reader.read(-1). - Since WebSocket is a TLV protocol, this never happens. - - 2. It doesn't work on TLS connections. A falsy value must be - returned to have the same behavior on TLS and plain connections. - - 3. The WebSocket protocol has its own closing handshake. Endpoints - close the TCP connection after sending a close frame. - - As a consequence we revert to the previous, more useful behavior. - """ - logger.debug("%s - event = eof_received()", self.side) - super().eof_received() - return False + # Copied from asyncio.StreamReaderProtocol + self.reader.set_transport(transport) + self._over_ssl = transport.get_extra_info("sslcontext") is not None + self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) def connection_lost(self, exc: Optional[Exception]) -> None: """ @@ -1434,4 +1350,68 @@ def connection_lost(self, exc: Optional[Exception]) -> None: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. self.connection_lost_waiter.set_result(None) - super().connection_lost(exc) + + # Copied from asyncio.StreamReaderProtocol + if self.reader is not None: + if exc is None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) + + # Copied from asyncio.FlowControlMixin + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + del self.reader + del self.writer + + def pause_writing(self) -> None: + assert not self._paused + self._paused = True + + def resume_writing(self) -> None: + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def data_received(self, data: bytes) -> None: + logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data)) + self.reader.feed_data(data) + + def eof_received(self) -> None: + """ + Close the transport after receiving EOF. + + The WebSocket protocol has its own closing handshake: endpoints close + the TCP or TLS connection after sending and receiving a close frame. + + As a consequence, they never need to write after receiving EOF, so + there's no reason to keep the transport open by returning ``True``. + + Besides, that doesn't work on TLS connections. + + """ + logger.debug("%s - event = eof_received()", self.side) + self.reader.feed_eof() From 5330199df186ab2516ea63a6588ceb50e8e4404e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:09:40 +0200 Subject: [PATCH 0653/1539] Finish merging StreamReaderProtocol and FlowControlMixin. --- src/websockets/protocol.py | 59 +++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 89e3464a6..2db44e5d8 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -61,35 +61,7 @@ class State(enum.IntEnum): # between the check and the assignment. -class StreamReaderProtocol(asyncio.Protocol): - def __init__(self): - self._paused = False - self._drain_waiter = None - self._connection_lost = False - - self._over_ssl = False - self._closed = self.loop.create_future() - - async def _drain_helper(self): - if self._connection_lost: - raise ConnectionResetError("Connection lost") - if not self._paused: - return - waiter = self._drain_waiter - assert waiter is None or waiter.cancelled() - waiter = self.loop.create_future() - self._drain_waiter = waiter - await waiter - - def __del__(self): - # Prevent reports about unhandled exceptions. - # Better than self._closed._log_traceback = False hack - closed = self._closed - if closed.done() and not closed.cancelled(): - closed.exception() - - -class WebSocketCommonProtocol(StreamReaderProtocol): +class WebSocketCommonProtocol(asyncio.Protocol): """ :class:`~asyncio.Protocol` subclass implementing the data transfer phase. @@ -259,7 +231,14 @@ def __init__( self.writer: asyncio.StreamWriter self._drain_lock = asyncio.Lock(loop=loop) - super().__init__() + # Copied from asyncio.FlowControlMixin + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None + self._connection_lost = False + + # Copied from asyncio.StreamReaderProtocol + self._over_ssl = False + self._closed = self.loop.create_future() # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -311,6 +290,26 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] + # Copied from asyncio.StreamReaderProtocol + def __del__(self) -> None: + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._closed + if closed.done() and not closed.cancelled(): + closed.exception() + + # Copied from asyncio.FlowControlMixin + async def _drain_helper(self) -> None: + if self._connection_lost: + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self.loop.create_future() + self._drain_waiter = waiter + await waiter + def connection_open(self) -> None: """ Callback when the WebSocket opening handshake completes. From 25e0a5968529bdddaf240267da367a83cef5fc35 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:17:26 +0200 Subject: [PATCH 0654/1539] Deduplicate connection termination tracking. --- src/websockets/protocol.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2db44e5d8..1cd5a91c2 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -234,11 +234,9 @@ def __init__( # Copied from asyncio.FlowControlMixin self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._connection_lost = False # Copied from asyncio.StreamReaderProtocol self._over_ssl = False - self._closed = self.loop.create_future() # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -290,17 +288,14 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] - # Copied from asyncio.StreamReaderProtocol - def __del__(self) -> None: - # Prevent reports about unhandled exceptions. - # Better than self._closed._log_traceback = False hack - closed = self._closed - if closed.done() and not closed.cancelled(): - closed.exception() + # asyncio.StreamWriter expects this attribute on the Protocol + @property + def _closed(self) -> asyncio.Future: + return self.connection_lost_waiter # Copied from asyncio.FlowControlMixin async def _drain_helper(self) -> None: - if self._connection_lost: + if self.connection_lost_waiter.done(): raise ConnectionResetError("Connection lost") if not self._paused: return @@ -1356,14 +1351,8 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self.reader.feed_eof() else: self.reader.set_exception(exc) - if not self._closed.done(): - if exc is None: - self._closed.set_result(None) - else: - self._closed.set_exception(exc) # Copied from asyncio.FlowControlMixin - self._connection_lost = True # Wake up the writer if currently paused. if not self._paused: return From d89721dd429e2fa64288ea638b0b79829f9cd222 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:21:25 +0200 Subject: [PATCH 0655/1539] Remove unused attribute. --- src/websockets/protocol.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1cd5a91c2..a1c90916b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -235,9 +235,6 @@ def __init__( self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - # Copied from asyncio.StreamReaderProtocol - self._over_ssl = False - # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute @@ -1318,7 +1315,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: # Copied from asyncio.StreamReaderProtocol self.reader.set_transport(transport) - self._over_ssl = transport.get_extra_info("sslcontext") is not None self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) def connection_lost(self, exc: Optional[Exception]) -> None: From e13387478c474396950259c4b2552a0e5469ae5e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:28:49 +0200 Subject: [PATCH 0656/1539] Ignore quality checks for code copied from asyncio. --- src/websockets/protocol.py | 51 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index a1c90916b..0bb12fd5a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -287,11 +287,11 @@ def __init__( # asyncio.StreamWriter expects this attribute on the Protocol @property - def _closed(self) -> asyncio.Future: + def _closed(self) -> Any: # pragma: no cover return self.connection_lost_waiter # Copied from asyncio.FlowControlMixin - async def _drain_helper(self) -> None: + async def _drain_helper(self) -> None: # pragma: no cover if self.connection_lost_waiter.done(): raise ConnectionResetError("Connection lost") if not self._paused: @@ -1341,36 +1341,35 @@ def connection_lost(self, exc: Optional[Exception]) -> None: # - it must never be canceled. self.connection_lost_waiter.set_result(None) - # Copied from asyncio.StreamReaderProtocol - if self.reader is not None: - if exc is None: - self.reader.feed_eof() - else: - self.reader.set_exception(exc) + if True: # pragma: no cover - # Copied from asyncio.FlowControlMixin - # Wake up the writer if currently paused. - if not self._paused: - return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) + # Copied from asyncio.StreamReaderProtocol + if self.reader is not None: + if exc is None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) - del self.reader - del self.writer + # Copied from asyncio.FlowControlMixin + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) - def pause_writing(self) -> None: + def pause_writing(self) -> None: # pragma: no cover assert not self._paused self._paused = True - def resume_writing(self) -> None: + def resume_writing(self) -> None: # pragma: no cover assert self._paused self._paused = False From 94d43ebc3309176ed9b57dbfa8e9cd44fa1697ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:54:13 +0200 Subject: [PATCH 0657/1539] Remove asyncio.StreamWriter. It adds only one method for flow control. Copy it, as we've already copied the rest of the flow control implementation. --- src/websockets/client.py | 2 +- src/websockets/protocol.py | 64 +++++++++++++++++++++---------------- src/websockets/server.py | 6 ++-- tests/test_client_server.py | 2 +- tests/test_protocol.py | 18 +++++------ 5 files changed, 51 insertions(+), 41 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index c1fdf88a0..34cd86240 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -85,7 +85,7 @@ def write_http_request(self, path: str, headers: Headers) -> None: request = f"GET {path} HTTP/1.1\r\n" request += str(headers) - self.writer.write(request.encode()) + self.transport.write(request.encode()) async def read_http_response(self) -> Tuple[int, Headers]: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0bb12fd5a..eb3d6bcc7 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -228,7 +228,8 @@ def __init__( self.reader: asyncio.StreamReader = asyncio.StreamReader( limit=read_limit // 2, loop=loop ) - self.writer: asyncio.StreamWriter + + self.transport: asyncio.Transport self._drain_lock = asyncio.Lock(loop=loop) # Copied from asyncio.FlowControlMixin @@ -285,11 +286,6 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] - # asyncio.StreamWriter expects this attribute on the Protocol - @property - def _closed(self) -> Any: # pragma: no cover - return self.connection_lost_waiter - # Copied from asyncio.FlowControlMixin async def _drain_helper(self) -> None: # pragma: no cover if self.connection_lost_waiter.done(): @@ -302,6 +298,23 @@ async def _drain_helper(self) -> None: # pragma: no cover self._drain_waiter = waiter await waiter + # Copied from asyncio.StreamWriter + async def _drain(self) -> None: # pragma: no cover + if self.reader is not None: + exc = self.reader.exception() + if exc is not None: + raise exc + if self.transport is not None: + if self.transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); yield from drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await asyncio.sleep(0) + await self._drain_helper() + def connection_open(self) -> None: """ Callback when the WebSocket opening handshake completes. @@ -348,9 +361,9 @@ def local_address(self) -> Any: been established yet. """ - if self.writer is None: + if self.transport is None: return None - return self.writer.get_extra_info("sockname") + return self.transport.get_extra_info("sockname") @property def remote_address(self) -> Any: @@ -361,9 +374,9 @@ def remote_address(self) -> Any: been established yet. """ - if self.writer is None: + if self.transport is None: return None - return self.writer.get_extra_info("peername") + return self.transport.get_extra_info("peername") @property def open(self) -> bool: @@ -1037,7 +1050,9 @@ async def write_frame( frame = Frame(fin, opcode, data) logger.debug("%s > %r", self.side, frame) - frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions) + frame.write( + self.transport.write, mask=self.is_client, extensions=self.extensions + ) try: # drain() cannot be called concurrently by multiple coroutines: @@ -1045,7 +1060,7 @@ async def write_frame( # version of Python where this bugs exists is supported anymore. async with self._drain_lock: # Handle flow control automatically. - await self.writer.drain() + await self._drain() except ConnectionError: # Terminate the connection if the socket died. self.fail_connection() @@ -1147,9 +1162,9 @@ async def close_connection(self) -> None: logger.debug("%s ! timed out waiting for TCP close", self.side) # Half-close the TCP connection if possible (when there's no TLS). - if self.writer.can_write_eof(): + if self.transport.can_write_eof(): logger.debug("%s x half-closing TCP connection", self.side) - self.writer.write_eof() + self.transport.write_eof() if await self.wait_for_connection_lost(): return @@ -1162,17 +1177,12 @@ async def close_connection(self) -> None: # If connection_lost() was called, the TCP connection is closed. # However, if TLS is enabled, the transport still needs closing. # Else asyncio complains: ResourceWarning: unclosed transport. - try: - writer_is_closing = self.writer.is_closing # type: ignore - except AttributeError: # pragma: no cover - # Python < 3.7 - writer_is_closing = self.writer.transport.is_closing - if self.connection_lost_waiter.done() and writer_is_closing(): + if self.connection_lost_waiter.done() and self.transport.is_closing(): return # Close the TCP connection. Buffers are flushed asynchronously. logger.debug("%s x closing TCP connection", self.side) - self.writer.close() + self.transport.close() if await self.wait_for_connection_lost(): return @@ -1180,8 +1190,7 @@ async def close_connection(self) -> None: # Abort the TCP connection. Buffers are discarded. logger.debug("%s x aborting TCP connection", self.side) - # mypy thinks self.writer.transport is a BaseTransport, not a Transport. - self.writer.transport.abort() # type: ignore + self.transport.abort() # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() @@ -1261,7 +1270,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: frame = Frame(True, OP_CLOSE, frame_data) logger.debug("%s > %r", self.side, frame) frame.write( - self.writer.write, mask=self.is_client, extensions=self.extensions + self.transport.write, mask=self.is_client, extensions=self.extensions ) # Start close_connection_task if the opening handshake didn't succeed. @@ -1310,12 +1319,13 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: """ logger.debug("%s - event = connection_made(%s)", self.side, transport) - # mypy thinks transport is a BaseTransport, not a Transport. - transport.set_write_buffer_limits(self.write_limit) # type: ignore + + transport = cast(asyncio.Transport, transport) + transport.set_write_buffer_limits(self.write_limit) + self.transport = transport # Copied from asyncio.StreamReaderProtocol self.reader.set_transport(transport) - self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) def connection_lost(self, exc: Optional[Exception]) -> None: """ diff --git a/src/websockets/server.py b/src/websockets/server.py index b220a1b88..1e8ae8617 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -211,7 +211,7 @@ async def handler(self) -> None: except Exception: # Last-ditch attempt to avoid leaking connections on errors. try: - self.writer.close() + self.transport.close() except Exception: # pragma: no cover pass @@ -265,11 +265,11 @@ def write_http_response( response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" response += str(headers) - self.writer.write(response.encode()) + self.transport.write(response.encode()) if body is not None: logger.debug("%s > body (%d bytes)", self.side, len(body)) - self.writer.write(body) + self.transport.write(body) async def process_request( self, path: str, request_headers: Headers diff --git a/tests/test_client_server.py b/tests/test_client_server.py index e74ec6bf6..6171f21b0 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1166,7 +1166,7 @@ def test_server_close_crashes(self, close): def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. - self.client.writer.close() + self.client.transport.close() # The server should stop properly anyway. It used to hang because the # task handling the connection was waiting for the opening handshake. diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a6c420181..dfc2c6d45 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -94,16 +94,16 @@ def tearDown(self): # Utilities for writing tests. def make_drain_slow(self, delay=MS): - # Process connection_made in order to initialize self.protocol.writer. + # Process connection_made in order to initialize self.protocol.transport. self.run_loop_once() - original_drain = self.protocol.writer.drain + original_drain = self.protocol._drain async def delayed_drain(): await asyncio.sleep(delay, loop=self.loop) await original_drain() - self.protocol.writer.drain = delayed_drain + self.protocol._drain = delayed_drain close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) @@ -321,32 +321,32 @@ def test_local_address(self): self.transport.get_extra_info = get_extra_info self.assertEqual(self.protocol.local_address, ("host", 4312)) - get_extra_info.assert_called_with("sockname", None) + get_extra_info.assert_called_with("sockname") def test_local_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.writer, _writer = None, self.protocol.writer + self.protocol.transport, _transport = None, self.protocol.transport try: self.assertEqual(self.protocol.local_address, None) finally: - self.protocol.writer = _writer + self.protocol.transport = _transport def test_remote_address(self): get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) self.transport.get_extra_info = get_extra_info self.assertEqual(self.protocol.remote_address, ("host", 4312)) - get_extra_info.assert_called_with("peername", None) + get_extra_info.assert_called_with("peername") def test_remote_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.writer, _writer = None, self.protocol.writer + self.protocol.transport, _transport = None, self.protocol.transport try: self.assertEqual(self.protocol.remote_address, None) finally: - self.protocol.writer = _writer + self.protocol.transport = _transport def test_open(self): self.assertTrue(self.protocol.open) From 8952c3a78a0cbf98501c94c30920a3eb4162c5d2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:55:16 +0200 Subject: [PATCH 0658/1539] Rename writer to write. It's a better name for a function that writes bytes. --- src/websockets/framing.py | 8 ++++---- tests/test_framing.py | 14 +++++++------- tests/test_protocol.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 81a3185b0..c24b8a73d 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -147,7 +147,7 @@ async def read( def write( frame, - writer: Callable[[bytes], Any], + write: Callable[[bytes], Any], *, mask: bool, extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, @@ -156,7 +156,7 @@ def write( Write a WebSocket frame. :param frame: frame to write - :param writer: function that writes bytes + :param write: function that writes bytes :param mask: whether the frame should be masked i.e. whether the write happens on the client side :param extensions: list of classes with an ``encode()`` method that @@ -210,10 +210,10 @@ def write( # Send the frame. - # The frame is written in a single call to writer in order to prevent + # The frame is written in a single call to write in order to prevent # TCP fragmentation. See #68 for details. This also makes it safe to # send frames concurrently from multiple coroutines. - writer(output.getvalue()) + write(output.getvalue()) def check(frame) -> None: """ diff --git a/tests/test_framing.py b/tests/test_framing.py index 9e6f1871d..5def415d2 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -27,15 +27,15 @@ def decode(self, message, mask=False, max_size=None, extensions=None): return frame def encode(self, frame, mask=False, extensions=None): - writer = unittest.mock.Mock() - frame.write(writer, mask=mask, extensions=extensions) - # Ensure the entire frame is sent with a single call to writer(). + write = unittest.mock.Mock() + frame.write(write, mask=mask, extensions=extensions) + # Ensure the entire frame is sent with a single call to write(). # Multiple calls cause TCP fragmentation and degrade performance. - self.assertEqual(writer.call_count, 1) + self.assertEqual(write.call_count, 1) # The frame data is the single positional argument of that call. - self.assertEqual(len(writer.call_args[0]), 1) - self.assertEqual(len(writer.call_args[1]), 0) - return writer.call_args[0][0] + self.assertEqual(len(write.call_args[0]), 1) + self.assertEqual(len(write.call_args[1]), 0) + return write.call_args[0][0] def round_trip(self, message, expected, mask=False, extensions=None): decoded = self.decode(message, mask, extensions=extensions) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index dfc2c6d45..d2793faf5 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -114,9 +114,9 @@ def receive_frame(self, frame): Make the protocol receive a frame. """ - writer = self.protocol.data_received + write = self.protocol.data_received mask = not self.protocol.is_client - frame.write(writer, mask=mask) + frame.write(write, mask=mask) def receive_eof(self): """ From e679490cf2af87bc060fc63a0f2898444f26d5c3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 15:49:08 +0200 Subject: [PATCH 0659/1539] Update to the latest version of mypy. The bugs that were locking us on an old version are fixed. --- src/websockets/__init__.py | 15 ++++++++------- src/websockets/__main__.py | 8 ++++---- src/websockets/client.py | 3 +-- src/websockets/handshake.py | 15 ++++++++++----- src/websockets/protocol.py | 14 +++++++++----- src/websockets/server.py | 11 +++-------- tox.ini | 2 +- 7 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index e7ba31ce5..6bad0f7bc 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,12 +1,13 @@ # This relies on each of the submodules having an __all__ variable. -from .auth import * -from .client import * -from .exceptions import * -from .protocol import * -from .server import * -from .typing import * -from .uri import * +from . import auth, client, exceptions, protocol, server, typing, uri +from .auth import * # noqa +from .client import * # noqa +from .exceptions import * # noqa +from .protocol import * # noqa +from .server import * # noqa +from .typing import * # noqa +from .uri import * # noqa from .version import version as __version__ # noqa diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index bccb8aa52..394f7ac79 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -6,8 +6,8 @@ import threading from typing import Any, Set -import websockets -from websockets.exceptions import format_close +from .client import connect +from .exceptions import ConnectionClosed, format_close if sys.platform == "win32": @@ -95,7 +95,7 @@ async def run_client( stop: "asyncio.Future[None]", ) -> None: try: - websocket = await websockets.connect(uri) + websocket = await connect(uri) except Exception as exc: print_over_input(f"Failed to connect to {uri}: {exc}.") exit_from_event_loop_thread(loop, stop) @@ -122,7 +122,7 @@ async def run_client( if incoming in done: try: message = incoming.result() - except websockets.ConnectionClosed: + except ConnectionClosed: break else: if isinstance(message, str): diff --git a/src/websockets/client.py b/src/websockets/client.py index 34cd86240..725ec1e7a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -24,7 +24,6 @@ from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response from .headers import ( - ExtensionHeader, build_authorization_basic, build_extension, build_subprotocol, @@ -33,7 +32,7 @@ ) from .http import USER_AGENT, Headers, HeadersLike, read_response from .protocol import WebSocketCommonProtocol -from .typing import Origin, Subprotocol +from .typing import ExtensionHeader, Origin, Subprotocol from .uri import WebSocketURI, parse_uri diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index 17332d155..9bfe27754 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -29,9 +29,10 @@ import binascii import hashlib import random +from typing import List from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade -from .headers import parse_connection, parse_upgrade +from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade from .http import Headers, MultipleValuesError @@ -74,14 +75,16 @@ def check_request(headers: Headers) -> str: is invalid; then the server must return 400 Bad Request error """ - connection = sum( + connection: List[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", ", ".join(connection)) - upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. @@ -148,14 +151,16 @@ def check_response(headers: Headers, key: str) -> None: is invalid """ - connection = sum( + connection: List[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", " ".join(connection)) - upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index eb3d6bcc7..b7c1f19c9 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -601,10 +601,14 @@ async def send( elif isinstance(message, AsyncIterable): # aiter_message = aiter(message) without aiter - aiter_message = type(message).__aiter__(message) + # https://github.com/python/mypy/issues/5738 + aiter_message = type(message).__aiter__(message) # type: ignore try: # message_chunk = anext(aiter_message) without anext - message_chunk = await type(aiter_message).__anext__(aiter_message) + # https://github.com/python/mypy/issues/5738 + message_chunk = await type(aiter_message).__anext__( # type: ignore + aiter_message + ) except StopAsyncIteration: return opcode, data = prepare_data(message_chunk) @@ -615,7 +619,8 @@ async def send( await self.write_frame(False, opcode, data) # Other fragments. - async for message_chunk in aiter_message: + # https://github.com/python/mypy/issues/5738 + async for message_chunk in aiter_message: # type: ignore confirm_opcode, data = prepare_data(message_chunk) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") @@ -899,8 +904,7 @@ async def read_message(self) -> Optional[Data]: max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") - # https://github.com/python/typeshed/pull/2752 - decoder = decoder_factory(errors="strict") # type: ignore + decoder = decoder_factory(errors="strict") if max_size is None: def append(frame: Frame) -> None: diff --git a/src/websockets/server.py b/src/websockets/server.py index 1e8ae8617..5114646dd 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -39,15 +39,10 @@ from .extensions.base import Extension, ServerExtensionFactory from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request -from .headers import ( - ExtensionHeader, - build_extension, - parse_extension, - parse_subprotocol, -) +from .headers import build_extension, parse_extension, parse_subprotocol from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request from .protocol import WebSocketCommonProtocol -from .typing import Origin, Subprotocol +from .typing import ExtensionHeader, Origin, Subprotocol __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] @@ -662,7 +657,7 @@ def is_serving(self) -> bool: """ try: # Python ≥ 3.7 - return self.server.is_serving() # type: ignore + return self.server.is_serving() except AttributeError: # pragma: no cover # Python < 3.7 return self.server.sockets is not None diff --git a/tox.ini b/tox.ini index 801d4d5d1..7397c90ae 100644 --- a/tox.ini +++ b/tox.ini @@ -25,4 +25,4 @@ deps = isort [testenv:mypy] commands = mypy --strict src -deps = mypy==0.670 +deps = mypy From 65ae7cd42ca5bcd1796e33c42909752b26b197f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 15:49:51 +0200 Subject: [PATCH 0660/1539] Fix deprecation warnings on Python 3.8. * Don't pass the deprecated loop argument. * Ignore deprecation warnings for @asyncio.coroutine. --- src/websockets/protocol.py | 28 ++++++++++++++++++++-------- src/websockets/server.py | 10 +++++++--- tests/test_client_server.py | 36 +++++++++++++++++++++--------------- tests/test_protocol.py | 5 ++++- 4 files changed, 52 insertions(+), 27 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index b7c1f19c9..76d46ad9c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -14,6 +14,7 @@ import logging import random import struct +import sys import warnings from typing import ( Any, @@ -230,7 +231,9 @@ def __init__( ) self.transport: asyncio.Transport - self._drain_lock = asyncio.Lock(loop=loop) + self._drain_lock = asyncio.Lock( + loop=loop if sys.version_info[:2] < (3, 8) else None + ) # Copied from asyncio.FlowControlMixin self._paused = False @@ -312,7 +315,9 @@ async def _drain(self) -> None: # pragma: no cover # write(...); yield from drain() # in a loop would never call connection_lost(), so it # would not see an error when the socket is closed. - await asyncio.sleep(0) + await asyncio.sleep( + 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) await self._drain_helper() def connection_open(self) -> None: @@ -483,7 +488,7 @@ async def recv(self) -> Data: # pop_message_waiter and self.transfer_data_task. await asyncio.wait( [pop_message_waiter, self.transfer_data_task], - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, return_when=asyncio.FIRST_COMPLETED, ) finally: @@ -668,7 +673,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: await asyncio.wait_for( self.write_close_frame(serialize_close(code, reason)), self.close_timeout, - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers @@ -687,7 +692,9 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses. await asyncio.wait_for( - self.transfer_data_task, self.close_timeout, loop=self.loop + self.transfer_data_task, + self.close_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1106,7 +1113,10 @@ async def keepalive_ping(self) -> None: try: while True: - await asyncio.sleep(self.ping_interval, loop=self.loop) + await asyncio.sleep( + self.ping_interval, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) # ping() raises CancelledError if the connection is closed, # when close_connection() cancels self.keepalive_ping_task. @@ -1119,7 +1129,9 @@ async def keepalive_ping(self) -> None: if self.ping_timeout is not None: try: await asyncio.wait_for( - ping_waiter, self.ping_timeout, loop=self.loop + ping_waiter, + self.ping_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: logger.debug("%s ! timed out waiting for pong", self.side) @@ -1211,7 +1223,7 @@ async def wait_for_connection_lost(self) -> bool: await asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), self.close_timeout, - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: pass diff --git a/src/websockets/server.py b/src/websockets/server.py index 5114646dd..4f5e9e0ef 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -10,6 +10,7 @@ import http import logging import socket +import sys import warnings from types import TracebackType from typing import ( @@ -698,7 +699,9 @@ async def _close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0) + await asyncio.sleep( + 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING conections with a HTTP 503 error. @@ -707,7 +710,8 @@ async def _close(self) -> None: # asyncio.wait doesn't accept an empty first argument if self.websockets: await asyncio.wait( - [websocket.close(1001) for websocket in self.websockets], loop=self.loop + [websocket.close(1001) for websocket in self.websockets], + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) # Wait until all connection handlers are complete. @@ -716,7 +720,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) # Tell wait_closed() to return. diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 6171f21b0..85828bdbc 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1381,13 +1381,16 @@ def test_client(self): start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) - @asyncio.coroutine - def run_client(): - # Yield from connect. - client = yield from connect(get_server_uri(server)) - self.assertEqual(client.state, State.OPEN) - yield from client.close() - self.assertEqual(client.state, State.CLOSED) + # @asyncio.coroutine is deprecated on Python ≥ 3.8 + with warnings.catch_warnings(record=True): + + @asyncio.coroutine + def run_client(): + # Yield from connect. + client = yield from connect(get_server_uri(server)) + self.assertEqual(client.state, State.OPEN) + yield from client.close() + self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) @@ -1395,14 +1398,17 @@ def run_client(): self.loop.run_until_complete(server.wait_closed()) def test_server(self): - @asyncio.coroutine - def run_server(): - # Yield from serve. - server = yield from serve(handler, "localhost", 0) - self.assertTrue(server.sockets) - server.close() - yield from server.wait_closed() - self.assertFalse(server.sockets) + # @asyncio.coroutine is deprecated on Python ≥ 3.8 + with warnings.catch_warnings(record=True): + + @asyncio.coroutine + def run_server(): + # Yield from serve. + server = yield from serve(handler, "localhost", 0) + self.assertTrue(server.sockets) + server.close() + yield from server.wait_closed() + self.assertFalse(server.sockets) self.loop.run_until_complete(run_server()) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d2793faf5..04e2a38fa 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,6 +1,7 @@ import asyncio import contextlib import logging +import sys import unittest import unittest.mock import warnings @@ -100,7 +101,9 @@ def make_drain_slow(self, delay=MS): original_drain = self.protocol._drain async def delayed_drain(): - await asyncio.sleep(delay, loop=self.loop) + await asyncio.sleep( + delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) await original_drain() self.protocol._drain = delayed_drain From aa7c21497ce58c03c9d10eaeb70768c484d7d6ae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 14:02:51 +0200 Subject: [PATCH 0661/1539] Document and test support for Python 3.8. --- .circleci/config.yml | 12 ++++++++++++ docs/changelog.rst | 2 ++ setup.py | 1 + tox.ini | 2 +- 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index a6c85d237..0877c161a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -29,6 +29,15 @@ jobs: - checkout - run: sudo pip install tox - run: tox -e py37 + py38: + docker: + - image: circleci/python:3.8.0rc1 + steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc + - checkout + - run: sudo pip install tox + - run: tox -e py38 workflows: version: 2 @@ -41,3 +50,6 @@ workflows: - py37: requires: - main + - py38: + requires: + - main diff --git a/docs/changelog.rst b/docs/changelog.rst index 87b2e4380..2a106fbc0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,8 @@ Changelog *In development* +* Added compatibility with Python 3.8. + 8.0.2 ..... diff --git a/setup.py b/setup.py index c76430104..f35819247 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', ], package_dir = {'': 'src'}, package_data = {'websockets': ['py.typed']}, diff --git a/tox.ini b/tox.ini index 7397c90ae..825e34061 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36,py37,coverage,black,flake8,isort,mypy +envlist = py36,py37,py38,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} From a9ef745899b8346526eb3e29a95b5e0f7db9a1f2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 21:18:30 +0200 Subject: [PATCH 0662/1539] Move test logging configuration to a single place. --- tests/__init__.py | 5 +++++ tests/test_client_server.py | 5 ----- tests/test_protocol.py | 5 ----- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb..dd78609f5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +import logging + + +# Avoid displaying stack traces at the ERROR logging level. +logging.basicConfig(level=logging.CRITICAL) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 85828bdbc..ce0f66ce2 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -2,7 +2,6 @@ import contextlib import functools import http -import logging import pathlib import random import socket @@ -37,10 +36,6 @@ from .utils import AsyncioTestCase -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) - - # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ # -out test_localhost.crt -keyout test_localhost.key diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 04e2a38fa..d95260a84 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import logging import sys import unittest import unittest.mock @@ -13,10 +12,6 @@ from .utils import MS, AsyncioTestCase -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) - - async def async_iterable(iterable): for item in iterable: yield item From 1d673debfd306e3e1953f0312390fa5456e09b5a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 21:33:27 +0200 Subject: [PATCH 0663/1539] Remove test that no longer makes sense. Since version 7.0, when the server closes, it terminates connections with close code 1001 instead of canceling them. --- tests/test_client_server.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index ce0f66ce2..35913666c 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -173,7 +173,7 @@ async def process_request(self, path, request_headers): return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" -class SlowServerProtocol(WebSocketServerProtocol): +class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): await asyncio.sleep(10 * MS) @@ -1165,7 +1165,7 @@ def test_client_closes_connection_before_handshake(self, handshake): # The server should stop properly anyway. It used to hang because the # task handling the connection was waiting for the opening handshake. - @with_server(create_protocol=SlowServerProtocol) + @with_server(create_protocol=SlowOpeningHandshakeProtocol) def test_server_shuts_down_during_opening_handshake(self): self.loop.call_later(5 * MS, self.server.close) with self.assertRaises(InvalidStatusCode) as raised: @@ -1188,20 +1188,6 @@ def test_server_shuts_down_during_connection_handling(self): self.assertEqual(self.client.close_code, 1001) self.assertEqual(server_ws.close_code, 1001) - @with_server() - @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") - def test_server_shuts_down_during_connection_close(self, _close): - _close.side_effect = asyncio.CancelledError - - self.server.closing = True - with self.temp_client(): - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - - # Websocket connection terminates abnormally. - self.assertEqual(self.client.close_code, 1006) - @with_server() def test_server_shuts_down_waits_until_handlers_terminate(self): # This handler waits a bit after the connection is closed in order From d537c26ac380a1b74444f83f31cd744f7f24bf15 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 21:41:13 +0200 Subject: [PATCH 0664/1539] Fix refactoring error. WebSocketCommonProtocol.transport can be unset, but it cannot be None. --- src/websockets/protocol.py | 14 ++++++++++---- tests/test_protocol.py | 8 ++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 76d46ad9c..0623e1364 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -366,9 +366,12 @@ def local_address(self) -> Any: been established yet. """ - if self.transport is None: + try: + transport = self.transport + except AttributeError: return None - return self.transport.get_extra_info("sockname") + else: + return transport.get_extra_info("sockname") @property def remote_address(self) -> Any: @@ -379,9 +382,12 @@ def remote_address(self) -> Any: been established yet. """ - if self.transport is None: + try: + transport = self.transport + except AttributeError: return None - return self.transport.get_extra_info("peername") + else: + return transport.get_extra_info("peername") @property def open(self) -> bool: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d95260a84..66a822e79 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -323,8 +323,8 @@ def test_local_address(self): def test_local_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.transport, _transport = None, self.protocol.transport - + _transport = self.protocol.transport + del self.protocol.transport try: self.assertEqual(self.protocol.local_address, None) finally: @@ -339,8 +339,8 @@ def test_remote_address(self): def test_remote_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.transport, _transport = None, self.protocol.transport - + _transport = self.protocol.transport + del self.protocol.transport try: self.assertEqual(self.protocol.remote_address, None) finally: From 154c5fa964fe407341edad5e70367e64913023bb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 21:39:14 +0200 Subject: [PATCH 0665/1539] Remove useless type declaration. --- src/websockets/protocol.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0623e1364..6c29b2a52 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -226,19 +226,16 @@ def __init__( # ``self.read_limit``. The ``limit`` argument controls the line length # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. - self.reader: asyncio.StreamReader = asyncio.StreamReader( - limit=read_limit // 2, loop=loop - ) - - self.transport: asyncio.Transport - self._drain_lock = asyncio.Lock( - loop=loop if sys.version_info[:2] < (3, 8) else None - ) + self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) # Copied from asyncio.FlowControlMixin self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None + self._drain_lock = asyncio.Lock( + loop=loop if sys.version_info[:2] < (3, 8) else None + ) + # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute From 3dab1fbe3705ba2c24cc7672d5ca3d7f02ea3535 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Oct 2019 13:54:42 +0200 Subject: [PATCH 0666/1539] Small simplification. --- src/websockets/client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 725ec1e7a..eb58f9f48 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -476,7 +476,6 @@ def __init__( # This is a coroutine function. self._create_connection = create_connection self._wsuri = wsuri - self._origin = origin def handle_redirect(self, uri: str) -> None: # Update the state of this instance to connect to a new URI. @@ -542,7 +541,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: try: await protocol.handshake( self._wsuri, - origin=self._origin, + origin=protocol.origin, available_extensions=protocol.available_extensions, available_subprotocols=protocol.available_subprotocols, extra_headers=protocol.extra_headers, From 2a87496cd80b273205bf5226ab0f9c12078b775d Mon Sep 17 00:00:00 2001 From: Anton Agestam Date: Tue, 8 Oct 2019 17:51:50 +0200 Subject: [PATCH 0667/1539] hardcoded top-level export --- src/websockets/__init__.py | 53 +++++++++++++++++++++++++++++++------- tests/test_exports.py | 22 ++++++++++++++++ 2 files changed, 65 insertions(+), 10 deletions(-) create mode 100644 tests/test_exports.py diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 6bad0f7bc..ea1d829a3 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,6 +1,5 @@ # This relies on each of the submodules having an __all__ variable. -from . import auth, client, exceptions, protocol, server, typing, uri from .auth import * # noqa from .client import * # noqa from .exceptions import * # noqa @@ -11,12 +10,46 @@ from .version import version as __version__ # noqa -__all__ = ( - auth.__all__ - + client.__all__ - + exceptions.__all__ - + protocol.__all__ - + server.__all__ - + typing.__all__ - + uri.__all__ -) +__all__ = [ + "AbortHandshake", + "basic_auth_protocol_factory", + "BasicAuthWebSocketServerProtocol", + "connect", + "ConnectionClosed", + "ConnectionClosedError", + "ConnectionClosedOK", + "Data", + "DuplicateParameter", + "ExtensionHeader", + "ExtensionParameter", + "InvalidHandshake", + "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", + "InvalidMessage", + "InvalidOrigin", + "InvalidParameterName", + "InvalidParameterValue", + "InvalidState", + "InvalidStatusCode", + "InvalidUpgrade", + "InvalidURI", + "NegotiationError", + "Origin", + "parse_uri", + "PayloadTooBig", + "ProtocolError", + "RedirectHandshake", + "SecurityError", + "serve", + "Subprotocol", + "unix_connect", + "unix_serve", + "WebSocketClientProtocol", + "WebSocketCommonProtocol", + "WebSocketException", + "WebSocketProtocolError", + "WebSocketServer", + "WebSocketServerProtocol", + "WebSocketURI", +] diff --git a/tests/test_exports.py b/tests/test_exports.py new file mode 100644 index 000000000..7fcbc80e3 --- /dev/null +++ b/tests/test_exports.py @@ -0,0 +1,22 @@ +import unittest + +import websockets + + +combined_exports = ( + websockets.auth.__all__ + + websockets.client.__all__ + + websockets.exceptions.__all__ + + websockets.protocol.__all__ + + websockets.server.__all__ + + websockets.typing.__all__ + + websockets.uri.__all__ +) + + +class TestExportsAllSubmodules(unittest.TestCase): + def test_top_level_module_reexports_all_submodule_exports(self): + self.assertEqual(set(combined_exports), set(websockets.__all__)) + + def test_submodule_exports_are_globally_unique(self): + self.assertEqual(len(set(combined_exports)), len(combined_exports)) From d62ef45facfc07aedf1f630b891f8c06212c5c59 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2019 09:11:06 +0100 Subject: [PATCH 0668/1539] Use the new Tidelift copy in README. --- README.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index 5dc9a745d..a9f54a35e 100644 --- a/README.rst +++ b/README.rst @@ -83,11 +83,11 @@ Does that look good?
-

Professionally supported websockets is now available

-

Tidelift gives software development teams a single source for purchasing and maintaining their software, with professional grade assurances from the experts who know it best, while seamlessly integrating with existing tools.

-

Get supported websockets with the Tidelift Subscription

+

websockets for enterprise

+

Available as part of the Tidelift Subscription

+

The maintainers of websockets and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. Learn more.


-

(If you contribute to ``websockets`` and would like to become an official support provider, let me know.)

+

(If you contribute to `websockets` and would like to become an official support provider, let me know.)

Why should I use ``websockets``? -------------------------------- From 0b5de4e3d11928115c56d52a983d0fc356559925 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2019 10:05:23 +0100 Subject: [PATCH 0669/1539] Add websockets for enterprise page to the docs. --- docs/_static/tidelift.png | 1 + docs/index.rst | 3 +- docs/tidelift.rst | 112 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) create mode 120000 docs/_static/tidelift.png create mode 100644 docs/tidelift.rst diff --git a/docs/_static/tidelift.png b/docs/_static/tidelift.png new file mode 120000 index 000000000..2d1ed4a2c --- /dev/null +++ b/docs/_static/tidelift.png @@ -0,0 +1 @@ +../../logo/tidelift.png \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index c18af96e4..1b2f85f0a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -93,6 +93,7 @@ This is about websockets-the-project rather than websockets-the-software. .. toctree:: :maxdepth: 2 - contributing changelog + contributing license + For enterprise diff --git a/docs/tidelift.rst b/docs/tidelift.rst new file mode 100644 index 000000000..43b457aaf --- /dev/null +++ b/docs/tidelift.rst @@ -0,0 +1,112 @@ +websockets for enterprise +========================= + +Available as part of the Tidelift Subscription +---------------------------------------------- + +.. image:: _static/tidelift.png + :height: 150px + :width: 150px + :align: left + +Tidelift is working with the maintainers of websockets and thousands of other +open source projects to deliver commercial support and maintenance for the +open source dependencies you use to build your applications. Save time, reduce +risk, and improve code health, while paying the maintainers of the exact +dependencies you use. + +.. raw:: html + + + + + +Enterprise-ready open source software—managed for you +----------------------------------------------------- + +The Tidelift Subscription is a managed open source subscription for +application dependencies covering millions of open source projects across +JavaScript, Python, Java, PHP, Ruby, .NET, and more. + +Your subscription includes: + +* **Security updates** + + * Tidelift’s security response team coordinates patches for new breaking + security vulnerabilities and alerts immediately through a private channel, + so your software supply chain is always secure. + +* **Licensing verification and indemnification** + + * Tidelift verifies license information to enable easy policy enforcement + and adds intellectual property indemnification to cover creators and users + in case something goes wrong. You always have a 100% up-to-date bill of + materials for your dependencies to share with your legal team, customers, + or partners. + +* **Maintenance and code improvement** + + * Tidelift ensures the software you rely on keeps working as long as you + need it to work. Your managed dependencies are actively maintained and we + recruit additional maintainers where required. + +* **Package selection and version guidance** + + * We help you choose the best open source packages from the start—and then + guide you through updates to stay on the best releases as new issues + arise. + +* **Roadmap input** + + * Take a seat at the table with the creators behind the software you use. + Tidelift’s participating maintainers earn more income as their software is + used by more subscribers, so they’re interested in knowing what you need. + +* **Tooling and cloud integration** + + * Tidelift works with GitHub, GitLab, BitBucket, and more. We support every + cloud platform (and other deployment targets, too). + +The end result? All of the capabilities you expect from commercial-grade +software, for the full breadth of open source you use. That means less time +grappling with esoteric open source trivia, and more time building your own +applications—and your business. + +.. raw:: html + + From 2a3c8581a3689326d31386804b100710623526c8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2019 10:52:21 +0100 Subject: [PATCH 0670/1539] Reject invalid Basic Auth credentials. Either both username and password are provided, or none of them. --- src/websockets/uri.py | 6 +++++- tests/test_uri.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index f5bbafa96..6669e5668 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -72,6 +72,10 @@ def parse_uri(uri: str) -> WebSocketURI: if parsed.query: resource_name += "?" + parsed.query user_info = None - if parsed.username or parsed.password: + if parsed.username is not None: + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if parsed.password is None: + raise InvalidURI(uri) user_info = (parsed.username, parsed.password) return WebSocketURI(secure, host, port, resource_name, user_info) diff --git a/tests/test_uri.py b/tests/test_uri.py index b7b69c3c1..e41860b8e 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -16,6 +16,7 @@ "http://localhost/", "https://localhost/", "ws://localhost/path#fragment", + "ws://user@localhost/", ] From b4f6efaf829c6b6acd33294fb6cab14bdc61584b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2019 14:32:54 +0100 Subject: [PATCH 0671/1539] Make single-element tuple unpacking more explicit. The latest version of black does this. It's a good. --- src/websockets/framing.py | 6 +++--- tests/test_protocol.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index c24b8a73d..26e58cdbf 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -118,10 +118,10 @@ async def read( length = head2 & 0b01111111 if length == 126: data = await reader(2) - length, = struct.unpack("!H", data) + (length,) = struct.unpack("!H", data) elif length == 127: data = await reader(8) - length, = struct.unpack("!Q", data) + (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: raise PayloadTooBig( f"payload length exceeds size limit ({length} > {max_size} bytes)" @@ -304,7 +304,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: """ length = len(data) if length >= 2: - code, = struct.unpack("!H", data[:2]) + (code,) = struct.unpack("!H", data[:2]) check_close(code) reason = data[2:].decode("utf-8") return code, reason diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 66a822e79..d32c1f72e 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -911,7 +911,7 @@ def test_abort_ping(self): def test_abort_ping_does_not_log_exception_if_not_retreived(self): self.loop.run_until_complete(self.protocol.ping()) # Get the internal Future, which isn't directly returned by ping(). - ping, = self.protocol.pings.values() + (ping,) = self.protocol.pings.values() # Remove the frame from the buffer, else close_connection() complains. self.last_sent_frame() self.close_connection() @@ -1126,13 +1126,13 @@ def test_keepalive_ping(self): # Ping is sent at 3ms and acknowledged at 4ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - ping_1, = tuple(self.protocol.pings) + (ping_1,) = tuple(self.protocol.pings) self.assertOneFrameSent(True, OP_PING, ping_1) self.receive_frame(Frame(True, OP_PONG, ping_1)) # Next ping is sent at 7ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - ping_2, = tuple(self.protocol.pings) + (ping_2,) = tuple(self.protocol.pings) self.assertOneFrameSent(True, OP_PING, ping_2) # The keepalive ping task goes on. @@ -1143,7 +1143,7 @@ def test_keepalive_ping_not_acknowledged_closes_connection(self): # Ping is sent at 3ms and not acknowleged. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - ping_1, = tuple(self.protocol.pings) + (ping_1,) = tuple(self.protocol.pings) self.assertOneFrameSent(True, OP_PING, ping_1) # Connection is closed at 6ms. @@ -1183,7 +1183,7 @@ def test_keepalive_ping_does_not_crash_when_connection_lost(self): self.receive_frame(Frame(True, OP_TEXT, b"2")) # Ping is sent at 3ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - ping_waiter, = tuple(self.protocol.pings.values()) + (ping_waiter,) = tuple(self.protocol.pings.values()) # Connection drops. self.receive_eof() self.loop.run_until_complete(self.protocol.wait_closed()) @@ -1210,7 +1210,7 @@ def test_keepalive_ping_with_no_ping_timeout(self): # Ping is sent at 3ms and not acknowleged. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - ping_1, = tuple(self.protocol.pings) + (ping_1,) = tuple(self.protocol.pings) self.assertOneFrameSent(True, OP_PING, ping_1) # Next ping is sent at 7ms anyway. From 20d1eb2e5afcc03b49aafbf113250ffdc9f432e2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2019 14:36:51 +0100 Subject: [PATCH 0672/1539] RST doesn't work inside raw HTML. --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index a9f54a35e..1e15ba198 100644 --- a/README.rst +++ b/README.rst @@ -87,7 +87,7 @@ Does that look good?

Available as part of the Tidelift Subscription

The maintainers of websockets and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. Learn more.


-

(If you contribute to `websockets` and would like to become an official support provider, let me know.)

+

(If you contribute to websockets and would like to become an official support provider, let me know.)

Why should I use ``websockets``? -------------------------------- From 139085fe2624192a5a6c72b1e5db211dcec6ced1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2019 14:39:33 +0100 Subject: [PATCH 0673/1539] Bump version number. --- docs/changelog.rst | 5 ++++- docs/conf.py | 4 ++-- src/websockets/version.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 2a106fbc0..04f18a765 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,14 @@ Changelog .. currentmodule:: websockets -8.1 +8.2 ... *In development* +8.1 +... + * Added compatibility with Python 3.8. 8.0.2 diff --git a/docs/conf.py b/docs/conf.py index 617989cb1..064c657bf 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,9 +59,9 @@ # built documents. # # The short X.Y version. -version = '8.0' +version = '8.1' # The full version, including alpha/beta/rc tags. -release = '8.0.2' +release = '8.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/src/websockets/version.py b/src/websockets/version.py index cd8898041..7377332e1 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "8.0.2" +version = "8.1" From 93ad88a9a8fe2ea8d96fb1d2a0f1625a3c5fee7c Mon Sep 17 00:00:00 2001 From: Alex Coplan Date: Mon, 4 Nov 2019 11:54:49 +0000 Subject: [PATCH 0674/1539] fix type hints on client/server args * Make ping_interval et al. optional so that code that passes None here will type check. --- src/websockets/client.py | 8 ++++---- src/websockets/server.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index eb58f9f48..831b70805 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -373,11 +373,11 @@ def __init__( *, path: Optional[str] = None, create_protocol: Optional[Type[WebSocketClientProtocol]] = None, - ping_interval: float = 20, - ping_timeout: float = 20, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: int = 2 ** 20, - max_queue: int = 2 ** 5, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, diff --git a/src/websockets/server.py b/src/websockets/server.py index 4f5e9e0ef..0313fa848 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -842,11 +842,11 @@ def __init__( *, path: Optional[str] = None, create_protocol: Optional[Type[WebSocketServerProtocol]] = None, - ping_interval: float = 20, - ping_timeout: float = 20, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: int = 2 ** 20, - max_queue: int = 2 ** 5, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, From 3bab7fd155636c73b79b258de752b36687bba347 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 16 Nov 2019 20:37:14 +0100 Subject: [PATCH 0675/1539] Clarify local/remote_address after connection is closed. Fix #688. --- src/websockets/protocol.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 6c29b2a52..e065bef67 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -357,10 +357,9 @@ def secure(self) -> Optional[bool]: @property def local_address(self) -> Any: """ - Local address of the connection. + Local address of the connection as a ``(host, port)`` tuple. - This is a ``(host, port)`` tuple or ``None`` if the connection hasn't - been established yet. + When the connection isn't open, ``local_address`` is ``None``. """ try: @@ -373,10 +372,9 @@ def local_address(self) -> Any: @property def remote_address(self) -> Any: """ - Remote address of the connection. + Remote address of the connection as a ``(host, port)`` tuple. - This is a ``(host, port)`` tuple or ``None`` if the connection hasn't - been established yet. + When the connection isn't open, ``remote_address`` is ``None``. """ try: From 910f417c9179150c5ab4b44c7361dbf1e51ec322 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 16 Nov 2019 20:40:15 +0100 Subject: [PATCH 0676/1539] Always reraise CancelledError. It's really hard to write tests for this :-( Fix #672. --- src/websockets/client.py | 2 ++ src/websockets/server.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/websockets/client.py b/src/websockets/client.py index 831b70805..f92350249 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -99,6 +99,8 @@ async def read_http_response(self) -> Tuple[int, Headers]: """ try: status_code, reason, headers = await read_response(self.reader) + except asyncio.CancelledError: # pragma: no cover + raise except Exception as exc: raise InvalidMessage("did not receive a valid HTTP response") from exc diff --git a/src/websockets/server.py b/src/websockets/server.py index 0313fa848..f872262ef 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -133,6 +133,8 @@ async def handler(self) -> None: available_subprotocols=self.available_subprotocols, extra_headers=self.extra_headers, ) + except asyncio.CancelledError: # pragma: no cover + raise except ConnectionError: logger.debug("Connection error in opening handshake", exc_info=True) raise @@ -231,6 +233,8 @@ async def read_http_request(self) -> Tuple[str, Headers]: """ try: path, headers = await read_request(self.reader) + except asyncio.CancelledError: # pragma: no cover + raise except Exception as exc: raise InvalidMessage("did not receive a valid HTTP request") from exc From a1615b47fcd416e5016d7e471976314c267f4349 Mon Sep 17 00:00:00 2001 From: Hugo Date: Tue, 14 Jan 2020 20:53:41 +0200 Subject: [PATCH 0677/1539] Fix for Python 3.10: use sys.version_info instead of sys.version --- src/websockets/http.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/http.py b/src/websockets/http.py index ba6d274bf..f87bfb76a 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -36,7 +36,8 @@ MAX_HEADERS = 256 MAX_LINE = 4096 -USER_AGENT = f"Python/{sys.version[:3]} websockets/{websockets_version}" +PYTHON_VERSION = "{}.{}".format(*sys.version_info) +USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}" def d(value: bytes) -> str: From 160dfbec7dd582c12817de5c85e6bf3fbbc34826 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2020 21:10:09 +0100 Subject: [PATCH 0678/1539] Clarify comment about RFC inconsistency. --- src/websockets/handshake.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index 9bfe27754..646b6dba4 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -87,7 +87,8 @@ def check_request(headers: Headers) -> str: ) # For compatibility with non-strict implementations, ignore case when - # checking the Upgrade header. It's supposed to be 'WebSocket'. + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) @@ -163,7 +164,8 @@ def check_response(headers: Headers, key: str) -> None: ) # For compatibility with non-strict implementations, ignore case when - # checking the Upgrade header. It's supposed to be 'WebSocket'. + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) From 4f1964295ad0e81c8c96b99c3fe9dafc96f11f28 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 18 Feb 2020 22:09:55 +0100 Subject: [PATCH 0679/1539] Speculation about proof-of-stake gets old. Meanwhile, bitcoin still heats the planet. Sorry crypto buffs. Refs #480 and several others. --- docs/contributing.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/contributing.rst b/docs/contributing.rst index 40f1dbb54..61c0b979c 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -53,6 +53,9 @@ Bitcoin users websockets appears to be quite popular for interfacing with Bitcoin or other cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. +I'm aware of efforts to build proof-of-stake models. I'll care once the total +carbon footprint of all cryptocurrencies drops to a non-bullshit level. + Please stop heating the planet where my children are supposed to live, thanks. Since ``websockets`` is released under an open-source license, you can use it From 6b5cbaf41cdbc9a2074e357ccc613ef25517dd32 Mon Sep 17 00:00:00 2001 From: Tim Gates Date: Sun, 1 Mar 2020 19:10:49 +1100 Subject: [PATCH 0680/1539] Fix simple typo: severel -> several There is a small typo in src/websockets/client.py, src/websockets/server.py. Should read `several` rather than `severel`. --- src/websockets/client.py | 2 +- src/websockets/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index f92350249..be055310d 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -134,7 +134,7 @@ def process_extensions( client configuration. If no match is found, an exception is raised. If several variants of the same extension are accepted by the server, - it may be configured severel times, which won't make sense in general. + it may be configured several times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. diff --git a/src/websockets/server.py b/src/websockets/server.py index f872262ef..1d8de8914 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -369,7 +369,7 @@ def process_extensions( server configuration. If no match is found, the extension is ignored. If several variants of the same extension are proposed by the client, - it may be accepted severel times, which won't make sense in general. + it may be accepted several times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. From 18dbc49c935285e35a54e46030d326f3a49ea7b7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 May 2020 07:46:06 +0200 Subject: [PATCH 0681/1539] Run tests against the latest Python 3.8. --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 0877c161a..68d02416d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -31,7 +31,7 @@ jobs: - run: tox -e py37 py38: docker: - - image: circleci/python:3.8.0rc1 + - image: circleci/python:3.8 steps: # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc From 6170e235723f27a5aaa42ea86828f0266cc004f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 May 2020 09:30:16 +0200 Subject: [PATCH 0682/1539] Don't attempt to build wheels on PyPy 2.7. --- .appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.appveyor.yml b/.appveyor.yml index 7954ee4be..2db489a76 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -6,7 +6,7 @@ skip_branch_with_pr: true environment: # websockets only works on Python >= 3.6. - CIBW_SKIP: cp27-* cp33-* cp34-* cp35-* + CIBW_SKIP: cp27-* cp33-* cp34-* cp35-* pp27-* CIBW_TEST_COMMAND: python -W default -m unittest WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 From 46e8fb5cecb474991e18f7b809378b7d76477df2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 May 2020 09:58:58 +0200 Subject: [PATCH 0683/1539] Fix flake8 violation. --- src/websockets/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index e065bef67..2082c81fc 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -349,7 +349,7 @@ def port(self) -> Optional[int]: @property def secure(self) -> Optional[bool]: - warnings.warn(f"don't use secure", DeprecationWarning) + warnings.warn("don't use secure", DeprecationWarning) return self._secure # Public API From 68dfb14963ea12e0068aefbbb43f101113d0750d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 May 2020 10:15:51 +0200 Subject: [PATCH 0684/1539] Don't attempt to build wheels on PyPy 2.7 (bis). --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 030693759..6234bb649 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ env: global: # websockets only works on Python >= 3.6. - - CIBW_SKIP="cp27-* cp33-* cp34-* cp35-*" + - CIBW_SKIP="cp27-* cp33-* cp34-* cp35-* pp27-*" - CIBW_TEST_COMMAND="python3 -W default -m unittest" - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 From fafcf65d430149a8b94379f9557655828a0dcdab Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 May 2020 12:56:22 +0200 Subject: [PATCH 0685/1539] Only build wheels on supported CPython versions. PyPy 3 wheels were failing to build on macOS. --- .appveyor.yml | 2 +- .travis.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index 2db489a76..d34b15aed 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -6,7 +6,7 @@ skip_branch_with_pr: true environment: # websockets only works on Python >= 3.6. - CIBW_SKIP: cp27-* cp33-* cp34-* cp35-* pp27-* + CIBW_BUILD: cp36-* cp37-* cp38-* CIBW_TEST_COMMAND: python -W default -m unittest WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 diff --git a/.travis.yml b/.travis.yml index 6234bb649..26e1de60e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ env: global: # websockets only works on Python >= 3.6. - - CIBW_SKIP="cp27-* cp33-* cp34-* cp35-* pp27-*" + - CIBW_BUILD="cp36-* cp37-* cp38-*" - CIBW_TEST_COMMAND="python3 -W default -m unittest" - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 From 69c94af5c0ad19402e0bedcc6b61a23fa070c946 Mon Sep 17 00:00:00 2001 From: David Bordeynik Date: Mon, 18 May 2020 10:38:08 +0300 Subject: [PATCH 0686/1539] Future-proof asyncio.wait usage. Fix #762. --- .circleci/config.yml | 12 ++++++++++++ .gitignore | 1 + src/websockets/server.py | 5 ++++- tox.ini | 2 +- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 68d02416d..7be85d7f9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -38,6 +38,15 @@ jobs: - checkout - run: sudo pip install tox - run: tox -e py38 + py39: + docker: + - image: circleci/python:3.9.0b1 + steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc + - checkout + - run: sudo pip install tox + - run: tox -e py39 workflows: version: 2 @@ -53,3 +62,6 @@ workflows: - py38: requires: - main + - py39: + requires: + - main diff --git a/.gitignore b/.gitignore index ef0d16520..c23cf5210 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.pyc *.so .coverage +.idea/ .mypy_cache .tox build/ diff --git a/src/websockets/server.py b/src/websockets/server.py index 1d8de8914..e9318a4df 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -714,7 +714,10 @@ async def _close(self) -> None: # asyncio.wait doesn't accept an empty first argument if self.websockets: await asyncio.wait( - [websocket.close(1001) for websocket in self.websockets], + [ + asyncio.ensure_future(websocket.close(1001)) + for websocket in self.websockets + ], loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) diff --git a/tox.ini b/tox.ini index 825e34061..cc224f9c6 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36,py37,py38,coverage,black,flake8,isort,mypy +envlist = py36,py37,py38,py39,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} From 24a77def7097cb7ae651edf35582c8def5a6ad3e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 13 Jun 2020 17:41:06 +0200 Subject: [PATCH 0687/1539] Update to mypy 0.780. --- src/websockets/__main__.py | 8 ++++---- src/websockets/typing.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 394f7ac79..1a720498d 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -49,10 +49,10 @@ def exit_from_event_loop_thread( if not stop.done(): # When exiting the thread that runs the event loop, raise # KeyboardInterrupt in the main thread to exit the program. - try: - ctrl_c = signal.CTRL_C_EVENT # Windows - except AttributeError: - ctrl_c = signal.SIGINT # POSIX + if sys.platform == "win32": + ctrl_c = signal.CTRL_C_EVENT + else: + ctrl_c = signal.SIGINT os.kill(os.getpid(), ctrl_c) diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 4a60f93f6..a5062bc4b 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -14,7 +14,7 @@ """ # Remove try / except when dropping support for Python < 3.7 try: - Data.__doc__ = Data__doc__ # type: ignore + Data.__doc__ = Data__doc__ except AttributeError: # pragma: no cover pass @@ -31,7 +31,7 @@ ExtensionParameter__doc__ = """Parameter of a WebSocket extension""" try: - ExtensionParameter.__doc__ = ExtensionParameter__doc__ # type: ignore + ExtensionParameter.__doc__ = ExtensionParameter__doc__ except AttributeError: # pragma: no cover pass @@ -40,7 +40,7 @@ ExtensionHeader__doc__ = """Item parsed in a Sec-WebSocket-Extensions header""" try: - ExtensionHeader.__doc__ = ExtensionHeader__doc__ # type: ignore + ExtensionHeader.__doc__ = ExtensionHeader__doc__ except AttributeError: # pragma: no cover pass From 017a072705408d3df945e333e5edd93e0aa8c706 Mon Sep 17 00:00:00 2001 From: Ram Rachum Date: Fri, 12 Jun 2020 23:16:57 +0300 Subject: [PATCH 0688/1539] Fix exception causes in server.py --- src/websockets/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index e9318a4df..0f0b51a7c 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -340,8 +340,8 @@ def process_origin( # per https://tools.ietf.org/html/rfc6454#section-7.3. try: origin = cast(Origin, headers.get("Origin")) - except MultipleValuesError: - raise InvalidHeader("Origin", "more than one Origin header found") + except MultipleValuesError as exc: + raise InvalidHeader("Origin", "more than one Origin header found") from exc if origins is not None: if origin not in origins: raise InvalidOrigin(origin) From 17499930cec591778d13e594b0cb978a9961e276 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 13:39:43 +0200 Subject: [PATCH 0689/1539] Ignore coverage measurement issue. --- src/websockets/protocol.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2082c81fc..803970205 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -1175,7 +1175,9 @@ async def close_connection(self) -> None: # A client should wait for a TCP close from the server. if self.is_client and hasattr(self, "transfer_data_task"): if await self.wait_for_connection_lost(): - return + # Coverage marks this line as a partially executed branch. + # I supect a bug in coverage. Ignore it for now. + return # pragma: no cover logger.debug("%s ! timed out waiting for TCP close", self.side) # Half-close the TCP connection if possible (when there's no TLS). @@ -1184,7 +1186,9 @@ async def close_connection(self) -> None: self.transport.write_eof() if await self.wait_for_connection_lost(): - return + # Coverage marks this line as a partially executed branch. + # I supect a bug in coverage. Ignore it for now. + return # pragma: no cover logger.debug("%s ! timed out waiting for TCP close", self.side) finally: @@ -1210,7 +1214,9 @@ async def close_connection(self) -> None: self.transport.abort() # connection_lost() is called quickly after aborting. - await self.wait_for_connection_lost() + # Coverage marks this line as a partially executed branch. + # I supect a bug in coverage. Ignore it for now. + await self.wait_for_connection_lost() # pragma: no cover async def wait_for_connection_lost(self) -> bool: """ From f0cfa6ba2abf6d4b032b30cfae9d321e583d546e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 18:04:08 +0200 Subject: [PATCH 0690/1539] Realign docstring with Python version. --- src/websockets/speedups.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index d1c2b37e6..ede181e5d 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -181,7 +181,7 @@ static PyMethodDef speedups_methods[] = { "apply_mask", (PyCFunction)apply_mask, METH_VARARGS | METH_KEYWORDS, - "Apply masking to websocket message.", + "Apply masking to the data of a WebSocket message.", }, {NULL, NULL, 0, NULL}, /* Sentinel */ }; From daad5180e09af5d860edf4191fb1791eb6b57cc8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 20:21:51 +0200 Subject: [PATCH 0691/1539] =?UTF-8?q?Upgrade=20to=20isort=20=E2=89=A5=205.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Makefile | 2 +- setup.cfg | 6 +----- tox.ini | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index d9e16fefe..06832945c 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ export PYTHONPATH=src default: coverage style style: - isort --recursive src tests + isort src tests black src tests flake8 src tests mypy --strict src diff --git a/setup.cfg b/setup.cfg index c306b2d4f..02e70cdf5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,13 +9,9 @@ ignore = E731,F403,F405,W503 max-line-length = 88 [isort] +profile = black combine_as_imports = True -force_grid_wrap = 0 -include_trailing_comma = True -known_standard_library = asyncio -line_length = 88 lines_after_imports = 2 -multi_line_output = 3 [coverage:run] branch = True diff --git a/tox.ini b/tox.ini index cc224f9c6..b5488e5b0 100644 --- a/tox.ini +++ b/tox.ini @@ -20,7 +20,7 @@ commands = flake8 src tests deps = flake8 [testenv:isort] -commands = isort --check-only --recursive src tests +commands = isort --check-only src tests deps = isort [testenv:mypy] From 85b3fd67490bc1e5aa9e46c292c00aceeaa0d40b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Oct 2019 10:28:03 +0200 Subject: [PATCH 0692/1539] Move Headers class to its own module. This allows breaking an import loop. --- docs/api.rst | 6 ++ docs/changelog.rst | 12 ++- src/websockets/__init__.py | 1 + src/websockets/auth.py | 2 +- src/websockets/client.py | 3 +- src/websockets/datastructures.py | 159 ++++++++++++++++++++++++++++ src/websockets/exceptions.py | 2 +- src/websockets/handshake.py | 2 +- src/websockets/http.py | 173 ++----------------------------- src/websockets/protocol.py | 2 +- src/websockets/server.py | 3 +- tests/test_client_server.py | 3 +- tests/test_datastructures.py | 131 +++++++++++++++++++++++ tests/test_exceptions.py | 2 +- tests/test_handshake.py | 2 +- tests/test_http.py | 114 -------------------- 16 files changed, 330 insertions(+), 287 deletions(-) create mode 100644 src/websockets/datastructures.py create mode 100644 tests/test_datastructures.py diff --git a/docs/api.rst b/docs/api.rst index d265a91c2..f7706ee2c 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -115,6 +115,12 @@ HTTP Basic Auth .. automethod:: process_request +Data structures +............... + +.. automodule:: websockets.datastructures + :members: + Exceptions .......... diff --git a/docs/changelog.rst b/docs/changelog.rst index 04f18a765..5de7357ca 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,11 +3,21 @@ Changelog .. currentmodule:: websockets -8.2 +9.0 ... *In development* +.. note:: + + **Version 9.0 moves or deprecates several low-level APIs.** + + * Import :class:`~datastructures.Headers` and + :exc:`~datastructures.MultipleValuesError` from + :mod:`websockets.datastructures` instead of :mod:`websockets.http`. + + Aliases provide backwards compatibility for all previously public APIs. + 8.1 ... diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index ea1d829a3..89829235c 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -2,6 +2,7 @@ from .auth import * # noqa from .client import * # noqa +from .datastructures import * # noqa from .exceptions import * # noqa from .protocol import * # noqa from .server import * # noqa diff --git a/src/websockets/auth.py b/src/websockets/auth.py index ae204b8d9..8198cd9d0 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -9,9 +9,9 @@ import http from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union +from .datastructures import Headers from .exceptions import InvalidHeader from .headers import build_www_authenticate_basic, parse_authorization_basic -from .http import Headers from .server import HTTPResponse, WebSocketServerProtocol diff --git a/src/websockets/client.py b/src/websockets/client.py index be055310d..26a369c47 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -11,6 +11,7 @@ from types import TracebackType from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast +from .datastructures import Headers, HeadersLike from .exceptions import ( InvalidHandshake, InvalidHeader, @@ -30,7 +31,7 @@ parse_extension, parse_subprotocol, ) -from .http import USER_AGENT, Headers, HeadersLike, read_response +from .http import USER_AGENT, read_response from .protocol import WebSocketCommonProtocol from .typing import ExtensionHeader, Origin, Subprotocol from .uri import WebSocketURI, parse_uri diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py new file mode 100644 index 000000000..f70d92ad7 --- /dev/null +++ b/src/websockets/datastructures.py @@ -0,0 +1,159 @@ +""" +This module defines a data structure for manipulating HTTP headers. + +""" + +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Tuple, + Union, +) + + +__all__ = ["Headers", "MultipleValuesError"] + + +class MultipleValuesError(LookupError): + """ + Exception raised when :class:`Headers` has more than one value for a key. + + """ + + def __str__(self) -> str: + # Implement the same logic as KeyError_str in Objects/exceptions.c. + if len(self.args) == 1: + return repr(self.args[0]) + return super().__str__() + + +class Headers(MutableMapping[str, str]): + """ + Efficient data structure for manipulating HTTP headers. + + A :class:`list` of ``(name, values)`` is inefficient for lookups. + + A :class:`dict` doesn't suffice because header names are case-insensitive + and multiple occurrences of headers with the same name are possible. + + :class:`Headers` stores HTTP headers in a hybrid data structure to provide + efficient insertions and lookups while preserving the original data. + + In order to account for multiple values with minimal hassle, + :class:`Headers` follows this logic: + + - When getting a header with ``headers[name]``: + - if there's no value, :exc:`KeyError` is raised; + - if there's exactly one value, it's returned; + - if there's more than one value, :exc:`MultipleValuesError` is raised. + + - When setting a header with ``headers[name] = value``, the value is + appended to the list of values for that header. + + - When deleting a header with ``del headers[name]``, all values for that + header are removed (this is slow). + + Other methods for manipulating headers are consistent with this logic. + + As long as no header occurs multiple times, :class:`Headers` behaves like + :class:`dict`, except keys are lower-cased to provide case-insensitivity. + + Two methods support support manipulating multiple values explicitly: + + - :meth:`get_all` returns a list of all values for a header; + - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs. + + """ + + __slots__ = ["_dict", "_list"] + + def __init__(self, *args: Any, **kwargs: str) -> None: + self._dict: Dict[str, List[str]] = {} + self._list: List[Tuple[str, str]] = [] + # MutableMapping.update calls __setitem__ for each (name, value) pair. + self.update(*args, **kwargs) + + def __str__(self) -> str: + return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._list!r})" + + def copy(self) -> "Headers": + copy = self.__class__() + copy._dict = self._dict.copy() + copy._list = self._list.copy() + return copy + + def serialize(self) -> bytes: + # Headers only contain ASCII characters. + return str(self).encode() + + # Collection methods + + def __contains__(self, key: object) -> bool: + return isinstance(key, str) and key.lower() in self._dict + + def __iter__(self) -> Iterator[str]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + # MutableMapping methods + + def __getitem__(self, key: str) -> str: + value = self._dict[key.lower()] + if len(value) == 1: + return value[0] + else: + raise MultipleValuesError(key) + + def __setitem__(self, key: str, value: str) -> None: + self._dict.setdefault(key.lower(), []).append(value) + self._list.append((key, value)) + + def __delitem__(self, key: str) -> None: + key_lower = key.lower() + self._dict.__delitem__(key_lower) + # This is inefficent. Fortunately deleting HTTP headers is uncommon. + self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Headers): + return NotImplemented + return self._list == other._list + + def clear(self) -> None: + """ + Remove all headers. + + """ + self._dict = {} + self._list = [] + + # Methods for handling multiple values + + def get_all(self, key: str) -> List[str]: + """ + Return the (possibly empty) list of all values for a header. + + :param key: header name + + """ + return self._dict.get(key.lower(), []) + + def raw_items(self) -> Iterator[Tuple[str, str]]: + """ + Return an iterator of all values as ``(name, value)`` pairs. + + """ + return iter(self._list) + + +HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 9873a1717..e593f1adc 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -30,7 +30,7 @@ import http from typing import Optional -from .http import Headers, HeadersLike +from .datastructures import Headers, HeadersLike __all__ = [ diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index 646b6dba4..e30a67125 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -31,9 +31,9 @@ import random from typing import List +from .datastructures import Headers, MultipleValuesError from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade -from .http import Headers, MultipleValuesError __all__ = ["build_request", "check_request", "build_response", "check_response"] diff --git a/src/websockets/http.py b/src/websockets/http.py index f87bfb76a..ddb2afcfa 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -10,28 +10,15 @@ import asyncio import re import sys -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, - Tuple, - Union, -) +from typing import Tuple +# For backwards compatibility - should be deprecated +from .datastructures import Headers, MultipleValuesError # noqa +from .exceptions import SecurityError from .version import version as websockets_version -__all__ = [ - "read_request", - "read_response", - "Headers", - "MultipleValuesError", - "USER_AGENT", -] +__all__ = ["read_request", "read_response", "USER_AGENT"] MAX_HEADERS = 256 MAX_LINE = 4096 @@ -68,7 +55,7 @@ def d(value: bytes) -> str: _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") -async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]: +async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: """ Read an HTTP/1.1 GET request and return ``(path, headers)``. @@ -114,7 +101,7 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]: return path, headers -async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, "Headers"]: +async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]: """ Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. @@ -163,7 +150,7 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, "Header return status_code, reason, headers -async def read_headers(stream: asyncio.StreamReader) -> "Headers": +async def read_headers(stream: asyncio.StreamReader) -> Headers: """ Read HTTP headers from ``stream``. @@ -198,7 +185,7 @@ async def read_headers(stream: asyncio.StreamReader) -> "Headers": headers[name] = value else: - raise websockets.exceptions.SecurityError("too many HTTP headers") + raise SecurityError("too many HTTP headers") return headers @@ -214,148 +201,8 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: line = await stream.readline() # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: - raise websockets.exceptions.SecurityError("line too long") + raise SecurityError("line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] - - -class MultipleValuesError(LookupError): - """ - Exception raised when :class:`Headers` has more than one value for a key. - - """ - - def __str__(self) -> str: - # Implement the same logic as KeyError_str in Objects/exceptions.c. - if len(self.args) == 1: - return repr(self.args[0]) - return super().__str__() - - -class Headers(MutableMapping[str, str]): - """ - Efficient data structure for manipulating HTTP headers. - - A :class:`list` of ``(name, values)`` is inefficient for lookups. - - A :class:`dict` doesn't suffice because header names are case-insensitive - and multiple occurrences of headers with the same name are possible. - - :class:`Headers` stores HTTP headers in a hybrid data structure to provide - efficient insertions and lookups while preserving the original data. - - In order to account for multiple values with minimal hassle, - :class:`Headers` follows this logic: - - - When getting a header with ``headers[name]``: - - if there's no value, :exc:`KeyError` is raised; - - if there's exactly one value, it's returned; - - if there's more than one value, :exc:`MultipleValuesError` is raised. - - - When setting a header with ``headers[name] = value``, the value is - appended to the list of values for that header. - - - When deleting a header with ``del headers[name]``, all values for that - header are removed (this is slow). - - Other methods for manipulating headers are consistent with this logic. - - As long as no header occurs multiple times, :class:`Headers` behaves like - :class:`dict`, except keys are lower-cased to provide case-insensitivity. - - Two methods support support manipulating multiple values explicitly: - - - :meth:`get_all` returns a list of all values for a header; - - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs. - - """ - - __slots__ = ["_dict", "_list"] - - def __init__(self, *args: Any, **kwargs: str) -> None: - self._dict: Dict[str, List[str]] = {} - self._list: List[Tuple[str, str]] = [] - # MutableMapping.update calls __setitem__ for each (name, value) pair. - self.update(*args, **kwargs) - - def __str__(self) -> str: - return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._list!r})" - - def copy(self) -> "Headers": - copy = self.__class__() - copy._dict = self._dict.copy() - copy._list = self._list.copy() - return copy - - # Collection methods - - def __contains__(self, key: object) -> bool: - return isinstance(key, str) and key.lower() in self._dict - - def __iter__(self) -> Iterator[str]: - return iter(self._dict) - - def __len__(self) -> int: - return len(self._dict) - - # MutableMapping methods - - def __getitem__(self, key: str) -> str: - value = self._dict[key.lower()] - if len(value) == 1: - return value[0] - else: - raise MultipleValuesError(key) - - def __setitem__(self, key: str, value: str) -> None: - self._dict.setdefault(key.lower(), []).append(value) - self._list.append((key, value)) - - def __delitem__(self, key: str) -> None: - key_lower = key.lower() - self._dict.__delitem__(key_lower) - # This is inefficent. Fortunately deleting HTTP headers is uncommon. - self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Headers): - return NotImplemented - return self._list == other._list - - def clear(self) -> None: - """ - Remove all headers. - - """ - self._dict = {} - self._list = [] - - # Methods for handling multiple values - - def get_all(self, key: str) -> List[str]: - """ - Return the (possibly empty) list of all values for a header. - - :param key: header name - - """ - return self._dict.get(key.lower(), []) - - def raw_items(self) -> Iterator[Tuple[str, str]]: - """ - Return an iterator of all values as ``(name, value)`` pairs. - - """ - return iter(self._list) - - -HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] - - -# at the bottom to allow circular import, because AbortHandshake depends on HeadersLike -import websockets.exceptions # isort:skip # noqa diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 803970205..60235643e 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -30,6 +30,7 @@ cast, ) +from .datastructures import Headers from .exceptions import ( ConnectionClosed, ConnectionClosedError, @@ -41,7 +42,6 @@ from .extensions.base import Extension from .framing import * from .handshake import * -from .http import Headers from .typing import Data diff --git a/src/websockets/server.py b/src/websockets/server.py index 0f0b51a7c..da98cac05 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -28,6 +28,7 @@ cast, ) +from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( AbortHandshake, InvalidHandshake, @@ -41,7 +42,7 @@ from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request from .headers import build_extension, parse_extension, parse_subprotocol -from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request +from .http import USER_AGENT, read_request from .protocol import WebSocketCommonProtocol from .typing import ExtensionHeader, Origin, Subprotocol diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 35913666c..ba0984c80 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -14,6 +14,7 @@ import warnings from websockets.client import * +from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, InvalidHandshake, @@ -27,7 +28,7 @@ ServerPerMessageDeflateFactory, ) from websockets.handshake import build_response -from websockets.http import USER_AGENT, Headers, read_response +from websockets.http import USER_AGENT, read_response from websockets.protocol import State from websockets.server import * from websockets.uri import parse_uri diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py new file mode 100644 index 000000000..628cbcb02 --- /dev/null +++ b/tests/test_datastructures.py @@ -0,0 +1,131 @@ +import unittest + +from websockets.datastructures import * + + +class HeadersTests(unittest.TestCase): + def setUp(self): + self.headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) + + def test_str(self): + self.assertEqual( + str(self.headers), "Connection: Upgrade\r\nServer: websockets\r\n\r\n" + ) + + def test_repr(self): + self.assertEqual( + repr(self.headers), + "Headers([('Connection', 'Upgrade'), ('Server', 'websockets')])", + ) + + def test_copy(self): + self.assertEqual(repr(self.headers.copy()), repr(self.headers)) + + def test_serialize(self): + self.assertEqual( + self.headers.serialize(), + b"Connection: Upgrade\r\nServer: websockets\r\n\r\n", + ) + + def test_multiple_values_error_str(self): + self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'") + self.assertEqual(str(MultipleValuesError()), "") + + def test_contains(self): + self.assertIn("Server", self.headers) + + def test_contains_case_insensitive(self): + self.assertIn("server", self.headers) + + def test_contains_not_found(self): + self.assertNotIn("Date", self.headers) + + def test_contains_non_string_key(self): + self.assertNotIn(42, self.headers) + + def test_iter(self): + self.assertEqual(set(iter(self.headers)), {"connection", "server"}) + + def test_len(self): + self.assertEqual(len(self.headers), 2) + + def test_getitem(self): + self.assertEqual(self.headers["Server"], "websockets") + + def test_getitem_case_insensitive(self): + self.assertEqual(self.headers["server"], "websockets") + + def test_getitem_key_error(self): + with self.assertRaises(KeyError): + self.headers["Upgrade"] + + def test_getitem_multiple_values_error(self): + self.headers["Server"] = "2" + with self.assertRaises(MultipleValuesError): + self.headers["Server"] + + def test_setitem(self): + self.headers["Upgrade"] = "websocket" + self.assertEqual(self.headers["Upgrade"], "websocket") + + def test_setitem_case_insensitive(self): + self.headers["upgrade"] = "websocket" + self.assertEqual(self.headers["Upgrade"], "websocket") + + def test_setitem_multiple_values(self): + self.headers["Connection"] = "close" + with self.assertRaises(MultipleValuesError): + self.headers["Connection"] + + def test_delitem(self): + del self.headers["Connection"] + with self.assertRaises(KeyError): + self.headers["Connection"] + + def test_delitem_case_insensitive(self): + del self.headers["connection"] + with self.assertRaises(KeyError): + self.headers["Connection"] + + def test_delitem_multiple_values(self): + self.headers["Connection"] = "close" + del self.headers["Connection"] + with self.assertRaises(KeyError): + self.headers["Connection"] + + def test_eq(self): + other_headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) + self.assertEqual(self.headers, other_headers) + + def test_eq_not_equal(self): + other_headers = Headers([("Connection", "close"), ("Server", "websockets")]) + self.assertNotEqual(self.headers, other_headers) + + def test_eq_other_type(self): + self.assertNotEqual( + self.headers, "Connection: Upgrade\r\nServer: websockets\r\n\r\n" + ) + + def test_clear(self): + self.headers.clear() + self.assertFalse(self.headers) + self.assertEqual(self.headers, Headers()) + + def test_get_all(self): + self.assertEqual(self.headers.get_all("Connection"), ["Upgrade"]) + + def test_get_all_case_insensitive(self): + self.assertEqual(self.headers.get_all("connection"), ["Upgrade"]) + + def test_get_all_no_values(self): + self.assertEqual(self.headers.get_all("Upgrade"), []) + + def test_get_all_multiple_values(self): + self.headers["Connection"] = "close" + self.assertEqual(self.headers.get_all("Connection"), ["Upgrade", "close"]) + + def test_raw_items(self): + self.assertEqual( + list(self.headers.raw_items()), + [("Connection", "Upgrade"), ("Server", "websockets")], + ) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 7ad5ad833..b800d4f91 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,7 +1,7 @@ import unittest +from websockets.datastructures import Headers from websockets.exceptions import * -from websockets.http import Headers class ExceptionsTests(unittest.TestCase): diff --git a/tests/test_handshake.py b/tests/test_handshake.py index 7d0477715..6850fec9a 100644 --- a/tests/test_handshake.py +++ b/tests/test_handshake.py @@ -1,6 +1,7 @@ import contextlib import unittest +from websockets.datastructures import Headers from websockets.exceptions import ( InvalidHandshake, InvalidHeader, @@ -9,7 +10,6 @@ ) from websockets.handshake import * from websockets.handshake import accept # private API -from websockets.http import Headers class HandshakeTests(unittest.TestCase): diff --git a/tests/test_http.py b/tests/test_http.py index 41b522c3d..b09247c3e 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,5 +1,4 @@ import asyncio -import unittest from websockets.exceptions import SecurityError from websockets.http import * @@ -134,116 +133,3 @@ async def test_line_ending(self): self.stream.feed_data(b"foo: bar\n\n") with self.assertRaises(EOFError): await read_headers(self.stream) - - -class HeadersTests(unittest.TestCase): - def setUp(self): - self.headers = Headers([("Connection", "Upgrade"), ("Server", USER_AGENT)]) - - def test_str(self): - self.assertEqual( - str(self.headers), f"Connection: Upgrade\r\nServer: {USER_AGENT}\r\n\r\n" - ) - - def test_repr(self): - self.assertEqual( - repr(self.headers), - f"Headers([('Connection', 'Upgrade'), " f"('Server', '{USER_AGENT}')])", - ) - - def test_multiple_values_error_str(self): - self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'") - self.assertEqual(str(MultipleValuesError()), "") - - def test_contains(self): - self.assertIn("Server", self.headers) - - def test_contains_case_insensitive(self): - self.assertIn("server", self.headers) - - def test_contains_not_found(self): - self.assertNotIn("Date", self.headers) - - def test_contains_non_string_key(self): - self.assertNotIn(42, self.headers) - - def test_iter(self): - self.assertEqual(set(iter(self.headers)), {"connection", "server"}) - - def test_len(self): - self.assertEqual(len(self.headers), 2) - - def test_getitem(self): - self.assertEqual(self.headers["Server"], USER_AGENT) - - def test_getitem_case_insensitive(self): - self.assertEqual(self.headers["server"], USER_AGENT) - - def test_getitem_key_error(self): - with self.assertRaises(KeyError): - self.headers["Upgrade"] - - def test_getitem_multiple_values_error(self): - self.headers["Server"] = "2" - with self.assertRaises(MultipleValuesError): - self.headers["Server"] - - def test_setitem(self): - self.headers["Upgrade"] = "websocket" - self.assertEqual(self.headers["Upgrade"], "websocket") - - def test_setitem_case_insensitive(self): - self.headers["upgrade"] = "websocket" - self.assertEqual(self.headers["Upgrade"], "websocket") - - def test_setitem_multiple_values(self): - self.headers["Connection"] = "close" - with self.assertRaises(MultipleValuesError): - self.headers["Connection"] - - def test_delitem(self): - del self.headers["Connection"] - with self.assertRaises(KeyError): - self.headers["Connection"] - - def test_delitem_case_insensitive(self): - del self.headers["connection"] - with self.assertRaises(KeyError): - self.headers["Connection"] - - def test_delitem_multiple_values(self): - self.headers["Connection"] = "close" - del self.headers["Connection"] - with self.assertRaises(KeyError): - self.headers["Connection"] - - def test_eq(self): - other_headers = self.headers.copy() - self.assertEqual(self.headers, other_headers) - - def test_eq_not_equal(self): - self.assertNotEqual(self.headers, []) - - def test_clear(self): - self.headers.clear() - self.assertFalse(self.headers) - self.assertEqual(self.headers, Headers()) - - def test_get_all(self): - self.assertEqual(self.headers.get_all("Connection"), ["Upgrade"]) - - def test_get_all_case_insensitive(self): - self.assertEqual(self.headers.get_all("connection"), ["Upgrade"]) - - def test_get_all_no_values(self): - self.assertEqual(self.headers.get_all("Upgrade"), []) - - def test_get_all_multiple_values(self): - self.headers["Connection"] = "close" - self.assertEqual(self.headers.get_all("Connection"), ["Upgrade", "close"]) - - def test_raw_items(self): - self.assertEqual( - list(self.headers.raw_items()), - [("Connection", "Upgrade"), ("Server", USER_AGENT)], - ) From 1f19838c81c3bb30f94881143c43842ac09162ec Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Oct 2019 12:19:20 +0200 Subject: [PATCH 0693/1539] Move the handshake and http modules out of the way. --- docs/api.rst | 9 -- docs/changelog.rst | 4 + src/websockets/client.py | 5 +- src/websockets/handshake.py | 191 ++++----------------------- src/websockets/handshake_legacy.py | 186 ++++++++++++++++++++++++++ src/websockets/http.py | 205 +++-------------------------- src/websockets/http_legacy.py | 193 +++++++++++++++++++++++++++ src/websockets/protocol.py | 2 +- src/websockets/server.py | 5 +- tests/test_client_server.py | 5 +- tests/test_handshake.py | 192 +-------------------------- tests/test_handshake_legacy.py | 190 ++++++++++++++++++++++++++ tests/test_http.py | 137 +------------------ tests/test_http_legacy.py | 135 +++++++++++++++++++ 14 files changed, 765 insertions(+), 694 deletions(-) create mode 100644 src/websockets/handshake_legacy.py create mode 100644 src/websockets/http_legacy.py create mode 100644 tests/test_handshake_legacy.py create mode 100644 tests/test_http_legacy.py diff --git a/docs/api.rst b/docs/api.rst index f7706ee2c..b4bddaf38 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -130,12 +130,6 @@ Exceptions Low-level --------- -Opening handshake -................. - -.. automodule:: websockets.handshake - :members: - Data transfer ............. @@ -153,6 +147,3 @@ Utilities .. automodule:: websockets.headers :members: - -.. automodule:: websockets.http - :members: diff --git a/docs/changelog.rst b/docs/changelog.rst index 5de7357ca..3cda4919f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -16,6 +16,10 @@ Changelog :exc:`~datastructures.MultipleValuesError` from :mod:`websockets.datastructures` instead of :mod:`websockets.http`. + * :mod:`websockets.handshake` is deprecated. + + * :mod:`websockets.http` is deprecated. + Aliases provide backwards compatibility for all previously public APIs. 8.1 diff --git a/src/websockets/client.py b/src/websockets/client.py index 26a369c47..f95dae060 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -23,7 +23,7 @@ ) from .extensions.base import ClientExtensionFactory, Extension from .extensions.permessage_deflate import ClientPerMessageDeflateFactory -from .handshake import build_request, check_response +from .handshake_legacy import build_request, check_response from .headers import ( build_authorization_basic, build_extension, @@ -31,7 +31,8 @@ parse_extension, parse_subprotocol, ) -from .http import USER_AGENT, read_response +from .http import USER_AGENT +from .http_legacy import read_response from .protocol import WebSocketCommonProtocol from .typing import ExtensionHeader, Origin, Subprotocol from .uri import WebSocketURI, parse_uri diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index e30a67125..f27bd1b84 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -1,187 +1,48 @@ -""" -:mod:`websockets.handshake` provides helpers for the WebSocket handshake. +import warnings -See `section 4 of RFC 6455`_. - -.. _section 4 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 - -Some checks cannot be performed because they depend too much on the -context; instead, they're documented below. - -To accept a connection, a server must: - -- Read the request, check that the method is GET, and check the headers with - :func:`check_request`, -- Send a 101 response to the client with the headers created by - :func:`build_response` if the request is valid; otherwise, send an - appropriate HTTP error code. - -To open a connection, a client must: - -- Send a GET request to the server with the headers created by - :func:`build_request`, -- Read the response, check that the status code is 101, and check the headers - with :func:`check_response`. - -""" - -import base64 -import binascii -import hashlib -import random -from typing import List - -from .datastructures import Headers, MultipleValuesError -from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade -from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade +from .datastructures import Headers __all__ = ["build_request", "check_request", "build_response", "check_response"] -GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - - -def build_request(headers: Headers) -> str: - """ - Build a handshake request to send to the server. - - Update request headers passed in argument. - - :param headers: request headers - :returns: ``key`` which must be passed to :func:`check_response` - - """ - raw_key = bytes(random.getrandbits(8) for _ in range(16)) - key = base64.b64encode(raw_key).decode() - headers["Upgrade"] = "websocket" - headers["Connection"] = "Upgrade" - headers["Sec-WebSocket-Key"] = key - headers["Sec-WebSocket-Version"] = "13" - return key +GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" -def check_request(headers: Headers) -> str: - """ - Check a handshake request received from the client. - This function doesn't verify that the request is an HTTP/1.1 or higher GET - request and doesn't perform ``Host`` and ``Origin`` checks. These controls - are usually performed earlier in the HTTP request handling code. They're - the responsibility of the caller. +# Backwards compatibility with previously documented public APIs - :param headers: request headers - :returns: ``key`` which must be passed to :func:`build_response` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake request - is invalid; then the server must return 400 Bad Request error - """ - connection: List[ConnectionOption] = sum( - [parse_connection(value) for value in headers.get_all("Connection")], [] +def build_request(headers: Headers) -> str: # pragma: no cover + warnings.warn( + "websockets.handshake.build_request is deprecated", DeprecationWarning ) + from .handshake_legacy import build_request - if not any(value.lower() == "upgrade" for value in connection): - raise InvalidUpgrade("Connection", ", ".join(connection)) + return build_request(headers) - upgrade: List[UpgradeProtocol] = sum( - [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] - ) - # For compatibility with non-strict implementations, ignore case when - # checking the Upgrade header. The RFC always uses "websocket", except - # in section 11.2. (IANA registration) where it uses "WebSocket". - if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): - raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) - - try: - s_w_key = headers["Sec-WebSocket-Key"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Key") - except MultipleValuesError: - raise InvalidHeader( - "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" - ) - - try: - raw_key = base64.b64decode(s_w_key.encode(), validate=True) - except binascii.Error: - raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) - if len(raw_key) != 16: - raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) - - try: - s_w_version = headers["Sec-WebSocket-Version"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Version") - except MultipleValuesError: - raise InvalidHeader( - "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found" - ) - - if s_w_version != "13": - raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) - - return s_w_key - - -def build_response(headers: Headers, key: str) -> None: - """ - Build a handshake response to send to the client. - - Update response headers passed in argument. - - :param headers: response headers - :param key: comes from :func:`check_request` - - """ - headers["Upgrade"] = "websocket" - headers["Connection"] = "Upgrade" - headers["Sec-WebSocket-Accept"] = accept(key) - - -def check_response(headers: Headers, key: str) -> None: - """ - Check a handshake response received from the server. - - This function doesn't verify that the response is an HTTP/1.1 or higher - response with a 101 status code. These controls are the responsibility of - the caller. - - :param headers: response headers - :param key: comes from :func:`build_request` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake response - is invalid - - """ - connection: List[ConnectionOption] = sum( - [parse_connection(value) for value in headers.get_all("Connection")], [] +def check_request(headers: Headers) -> str: # pragma: no cover + warnings.warn( + "websockets.handshake.check_request is deprecated", DeprecationWarning ) + from .handshake_legacy import check_request - if not any(value.lower() == "upgrade" for value in connection): - raise InvalidUpgrade("Connection", " ".join(connection)) + return check_request(headers) - upgrade: List[UpgradeProtocol] = sum( - [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] - ) - # For compatibility with non-strict implementations, ignore case when - # checking the Upgrade header. The RFC always uses "websocket", except - # in section 11.2. (IANA registration) where it uses "WebSocket". - if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): - raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) +def build_response(headers: Headers, key: str) -> None: # pragma: no cover + warnings.warn( + "websockets.handshake.build_response is deprecated", DeprecationWarning + ) + from .handshake_legacy import build_response - try: - s_w_accept = headers["Sec-WebSocket-Accept"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Accept") - except MultipleValuesError: - raise InvalidHeader( - "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found" - ) + return build_response(headers, key) - if s_w_accept != accept(key): - raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) +def check_response(headers: Headers, key: str) -> None: # pragma: no cover + warnings.warn( + "websockets.handshake.check_response is deprecated", DeprecationWarning + ) + from .handshake_legacy import check_response -def accept(key: str) -> str: - sha1 = hashlib.sha1((key + GUID).encode()).digest() - return base64.b64encode(sha1).decode() + return check_response(headers, key) diff --git a/src/websockets/handshake_legacy.py b/src/websockets/handshake_legacy.py new file mode 100644 index 000000000..3fca45545 --- /dev/null +++ b/src/websockets/handshake_legacy.py @@ -0,0 +1,186 @@ +""" +:mod:`websockets.handshake` provides helpers for the WebSocket handshake. + +See `section 4 of RFC 6455`_. + +.. _section 4 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 + +Some checks cannot be performed because they depend too much on the +context; instead, they're documented below. + +To accept a connection, a server must: + +- Read the request, check that the method is GET, and check the headers with + :func:`check_request`, +- Send a 101 response to the client with the headers created by + :func:`build_response` if the request is valid; otherwise, send an + appropriate HTTP error code. + +To open a connection, a client must: + +- Send a GET request to the server with the headers created by + :func:`build_request`, +- Read the response, check that the status code is 101, and check the headers + with :func:`check_response`. + +""" + +import base64 +import binascii +import hashlib +import random +from typing import List + +from .datastructures import Headers, MultipleValuesError +from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade +from .handshake import GUID +from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade + + +__all__ = ["build_request", "check_request", "build_response", "check_response"] + + +def build_request(headers: Headers) -> str: + """ + Build a handshake request to send to the server. + + Update request headers passed in argument. + + :param headers: request headers + :returns: ``key`` which must be passed to :func:`check_response` + + """ + raw_key = bytes(random.getrandbits(8) for _ in range(16)) + key = base64.b64encode(raw_key).decode() + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Key"] = key + headers["Sec-WebSocket-Version"] = "13" + return key + + +def check_request(headers: Headers) -> str: + """ + Check a handshake request received from the client. + + This function doesn't verify that the request is an HTTP/1.1 or higher GET + request and doesn't perform ``Host`` and ``Origin`` checks. These controls + are usually performed earlier in the HTTP request handling code. They're + the responsibility of the caller. + + :param headers: request headers + :returns: ``key`` which must be passed to :func:`build_response` + :raises ~websockets.exceptions.InvalidHandshake: if the handshake request + is invalid; then the server must return 400 Bad Request error + + """ + connection: List[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", ", ".join(connection)) + + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) + + try: + s_w_key = headers["Sec-WebSocket-Key"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Key") + except MultipleValuesError: + raise InvalidHeader( + "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" + ) + + try: + raw_key = base64.b64decode(s_w_key.encode(), validate=True) + except binascii.Error: + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) + if len(raw_key) != 16: + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) + + try: + s_w_version = headers["Sec-WebSocket-Version"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Version") + except MultipleValuesError: + raise InvalidHeader( + "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found" + ) + + if s_w_version != "13": + raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) + + return s_w_key + + +def build_response(headers: Headers, key: str) -> None: + """ + Build a handshake response to send to the client. + + Update response headers passed in argument. + + :param headers: response headers + :param key: comes from :func:`check_request` + + """ + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Accept"] = accept(key) + + +def check_response(headers: Headers, key: str) -> None: + """ + Check a handshake response received from the server. + + This function doesn't verify that the response is an HTTP/1.1 or higher + response with a 101 status code. These controls are the responsibility of + the caller. + + :param headers: response headers + :param key: comes from :func:`build_request` + :raises ~websockets.exceptions.InvalidHandshake: if the handshake response + is invalid + + """ + connection: List[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", " ".join(connection)) + + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) + + try: + s_w_accept = headers["Sec-WebSocket-Accept"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Accept") + except MultipleValuesError: + raise InvalidHeader( + "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found" + ) + + if s_w_accept != accept(key): + raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) + + +def accept(key: str) -> str: + sha1 = hashlib.sha1((key + GUID).encode()).digest() + return base64.b64encode(sha1).decode() diff --git a/src/websockets/http.py b/src/websockets/http.py index ddb2afcfa..850b9beaa 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,208 +1,37 @@ -""" -:mod:`websockets.http` module provides basic HTTP/1.1 support. It is merely -:adequate for WebSocket handshake messages. - -These APIs cannot be imported from :mod:`websockets`. They must be imported -from :mod:`websockets.http`. - -""" - import asyncio -import re import sys +import warnings from typing import Tuple -# For backwards compatibility - should be deprecated +# For backwards compatibility: +# Headers and MultipleValuesError used to be defined in this module from .datastructures import Headers, MultipleValuesError # noqa -from .exceptions import SecurityError from .version import version as websockets_version -__all__ = ["read_request", "read_response", "USER_AGENT"] +__all__ = ["USER_AGENT"] -MAX_HEADERS = 256 -MAX_LINE = 4096 PYTHON_VERSION = "{}.{}".format(*sys.version_info) USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}" -def d(value: bytes) -> str: - """ - Decode a bytestring for interpolating into an error message. - - """ - return value.decode(errors="backslashreplace") - - -# See https://tools.ietf.org/html/rfc7230#appendix-B. - -# Regex for validating header names. - -_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") - -# Regex for validating header values. - -# We don't attempt to support obsolete line folding. - -# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). - -# The ABNF is complicated because it attempts to express that optional -# whitespace is ignored. We strip whitespace and don't revalidate that. - -# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 - -_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") - - -async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: - """ - Read an HTTP/1.1 GET request and return ``(path, headers)``. - - ``path`` isn't URL-decoded or validated in any way. - - ``path`` and ``headers`` are expected to contain only ASCII characters. - Other characters are represented with surrogate escapes. - - :func:`read_request` doesn't attempt to read the request body because - WebSocket handshake requests don't have one. If the request contains a - body, it may be read from ``stream`` after this coroutine returns. - - :param stream: input to read the request from - :raises EOFError: if the connection is closed without a full HTTP request - :raises SecurityError: if the request exceeds a security limit - :raises ValueError: if the request isn't well formatted - - """ - # https://tools.ietf.org/html/rfc7230#section-3.1.1 - - # Parsing is simple because fixed values are expected for method and - # version and because path isn't checked. Since WebSocket software tends - # to implement HTTP/1.1 strictly, there's little need for lenient parsing. - - try: - request_line = await read_line(stream) - except EOFError as exc: - raise EOFError("connection closed while reading HTTP request line") from exc - - try: - method, raw_path, version = request_line.split(b" ", 2) - except ValueError: # not enough values to unpack (expected 3, got 1-2) - raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None - - if method != b"GET": - raise ValueError(f"unsupported HTTP method: {d(method)}") - if version != b"HTTP/1.1": - raise ValueError(f"unsupported HTTP version: {d(version)}") - path = raw_path.decode("ascii", "surrogateescape") - - headers = await read_headers(stream) - - return path, headers - - -async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]: - """ - Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. - - ``reason`` and ``headers`` are expected to contain only ASCII characters. - Other characters are represented with surrogate escapes. - - :func:`read_request` doesn't attempt to read the response body because - WebSocket handshake responses don't have one. If the response contains a - body, it may be read from ``stream`` after this coroutine returns. - - :param stream: input to read the response from - :raises EOFError: if the connection is closed without a full HTTP response - :raises SecurityError: if the response exceeds a security limit - :raises ValueError: if the response isn't well formatted - - """ - # https://tools.ietf.org/html/rfc7230#section-3.1.2 - - # As in read_request, parsing is simple because a fixed value is expected - # for version, status_code is a 3-digit number, and reason can be ignored. - - try: - status_line = await read_line(stream) - except EOFError as exc: - raise EOFError("connection closed while reading HTTP status line") from exc - - try: - version, raw_status_code, raw_reason = status_line.split(b" ", 2) - except ValueError: # not enough values to unpack (expected 3, got 1-2) - raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None - - if version != b"HTTP/1.1": - raise ValueError(f"unsupported HTTP version: {d(version)}") - try: - status_code = int(raw_status_code) - except ValueError: # invalid literal for int() with base 10 - raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None - if not 100 <= status_code < 1000: - raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}") - if not _value_re.fullmatch(raw_reason): - raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") - reason = raw_reason.decode() - - headers = await read_headers(stream) - - return status_code, reason, headers - - -async def read_headers(stream: asyncio.StreamReader) -> Headers: - """ - Read HTTP headers from ``stream``. - - Non-ASCII characters are represented with surrogate escapes. - - """ - # https://tools.ietf.org/html/rfc7230#section-3.2 - - # We don't attempt to support obsolete line folding. - - headers = Headers() - for _ in range(MAX_HEADERS + 1): - try: - line = await read_line(stream) - except EOFError as exc: - raise EOFError("connection closed while reading HTTP headers") from exc - if line == b"": - break - - try: - raw_name, raw_value = line.split(b":", 1) - except ValueError: # not enough values to unpack (expected 2, got 1) - raise ValueError(f"invalid HTTP header line: {d(line)}") from None - if not _token_re.fullmatch(raw_name): - raise ValueError(f"invalid HTTP header name: {d(raw_name)}") - raw_value = raw_value.strip(b" \t") - if not _value_re.fullmatch(raw_value): - raise ValueError(f"invalid HTTP header value: {d(raw_value)}") - - name = raw_name.decode("ascii") # guaranteed to be ASCII at this point - value = raw_value.decode("ascii", "surrogateescape") - headers[name] = value +# Backwards compatibility with previously documented public APIs - else: - raise SecurityError("too many HTTP headers") - return headers +async def read_request( + stream: asyncio.StreamReader, +) -> Tuple[str, Headers]: # pragma: no cover + warnings.warn("websockets.http.read_request is deprecated", DeprecationWarning) + from .http_legacy import read_request + return await read_request(stream) -async def read_line(stream: asyncio.StreamReader) -> bytes: - """ - Read a single line from ``stream``. - CRLF is stripped from the return value. +async def read_response( + stream: asyncio.StreamReader, +) -> Tuple[int, str, Headers]: # pragma: no cover + warnings.warn("websockets.http.read_response is deprecated", DeprecationWarning) + from .http_legacy import read_response - """ - # Security: this is bounded by the StreamReader's limit (default = 32 KiB). - line = await stream.readline() - # Security: this guarantees header values are small (hard-coded = 4 KiB) - if len(line) > MAX_LINE: - raise SecurityError("line too long") - # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 - if not line.endswith(b"\r\n"): - raise EOFError("line without CRLF") - return line[:-2] + return await read_response(stream) diff --git a/src/websockets/http_legacy.py b/src/websockets/http_legacy.py new file mode 100644 index 000000000..3630d3593 --- /dev/null +++ b/src/websockets/http_legacy.py @@ -0,0 +1,193 @@ +import asyncio +import re +from typing import Tuple + +from .datastructures import Headers +from .exceptions import SecurityError + + +__all__ = ["read_request", "read_response"] + +MAX_HEADERS = 256 +MAX_LINE = 4096 + + +def d(value: bytes) -> str: + """ + Decode a bytestring for interpolating into an error message. + + """ + return value.decode(errors="backslashreplace") + + +# See https://tools.ietf.org/html/rfc7230#appendix-B. + +# Regex for validating header names. + +_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") + +# Regex for validating header values. + +# We don't attempt to support obsolete line folding. + +# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). + +# The ABNF is complicated because it attempts to express that optional +# whitespace is ignored. We strip whitespace and don't revalidate that. + +# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + +_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") + + +async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: + """ + Read an HTTP/1.1 GET request and return ``(path, headers)``. + + ``path`` isn't URL-decoded or validated in any way. + + ``path`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. + + :func:`read_request` doesn't attempt to read the request body because + WebSocket handshake requests don't have one. If the request contains a + body, it may be read from ``stream`` after this coroutine returns. + + :param stream: input to read the request from + :raises EOFError: if the connection is closed without a full HTTP request + :raises SecurityError: if the request exceeds a security limit + :raises ValueError: if the request isn't well formatted + + """ + # https://tools.ietf.org/html/rfc7230#section-3.1.1 + + # Parsing is simple because fixed values are expected for method and + # version and because path isn't checked. Since WebSocket software tends + # to implement HTTP/1.1 strictly, there's little need for lenient parsing. + + try: + request_line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP request line") from exc + + try: + method, raw_path, version = request_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None + + if method != b"GET": + raise ValueError(f"unsupported HTTP method: {d(method)}") + if version != b"HTTP/1.1": + raise ValueError(f"unsupported HTTP version: {d(version)}") + path = raw_path.decode("ascii", "surrogateescape") + + headers = await read_headers(stream) + + return path, headers + + +async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]: + """ + Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. + + ``reason`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. + + :func:`read_request` doesn't attempt to read the response body because + WebSocket handshake responses don't have one. If the response contains a + body, it may be read from ``stream`` after this coroutine returns. + + :param stream: input to read the response from + :raises EOFError: if the connection is closed without a full HTTP response + :raises SecurityError: if the response exceeds a security limit + :raises ValueError: if the response isn't well formatted + + """ + # https://tools.ietf.org/html/rfc7230#section-3.1.2 + + # As in read_request, parsing is simple because a fixed value is expected + # for version, status_code is a 3-digit number, and reason can be ignored. + + try: + status_line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP status line") from exc + + try: + version, raw_status_code, raw_reason = status_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None + + if version != b"HTTP/1.1": + raise ValueError(f"unsupported HTTP version: {d(version)}") + try: + status_code = int(raw_status_code) + except ValueError: # invalid literal for int() with base 10 + raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None + if not 100 <= status_code < 1000: + raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}") + if not _value_re.fullmatch(raw_reason): + raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") + reason = raw_reason.decode() + + headers = await read_headers(stream) + + return status_code, reason, headers + + +async def read_headers(stream: asyncio.StreamReader) -> Headers: + """ + Read HTTP headers from ``stream``. + + Non-ASCII characters are represented with surrogate escapes. + + """ + # https://tools.ietf.org/html/rfc7230#section-3.2 + + # We don't attempt to support obsolete line folding. + + headers = Headers() + for _ in range(MAX_HEADERS + 1): + try: + line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP headers") from exc + if line == b"": + break + + try: + raw_name, raw_value = line.split(b":", 1) + except ValueError: # not enough values to unpack (expected 2, got 1) + raise ValueError(f"invalid HTTP header line: {d(line)}") from None + if not _token_re.fullmatch(raw_name): + raise ValueError(f"invalid HTTP header name: {d(raw_name)}") + raw_value = raw_value.strip(b" \t") + if not _value_re.fullmatch(raw_value): + raise ValueError(f"invalid HTTP header value: {d(raw_value)}") + + name = raw_name.decode("ascii") # guaranteed to be ASCII at this point + value = raw_value.decode("ascii", "surrogateescape") + headers[name] = value + + else: + raise SecurityError("too many HTTP headers") + + return headers + + +async def read_line(stream: asyncio.StreamReader) -> bytes: + """ + Read a single line from ``stream``. + + CRLF is stripped from the return value. + + """ + # Security: this is bounded by the StreamReader's limit (default = 32 KiB). + line = await stream.readline() + # Security: this guarantees header values are small (hard-coded = 4 KiB) + if len(line) > MAX_LINE: + raise SecurityError("line too long") + # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 + if not line.endswith(b"\r\n"): + raise EOFError("line without CRLF") + return line[:-2] diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 60235643e..cc4416ba8 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -41,7 +41,7 @@ ) from .extensions.base import Extension from .framing import * -from .handshake import * +from .handshake_legacy import * from .typing import Data diff --git a/src/websockets/server.py b/src/websockets/server.py index da98cac05..522c76114 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -40,9 +40,10 @@ ) from .extensions.base import Extension, ServerExtensionFactory from .extensions.permessage_deflate import ServerPerMessageDeflateFactory -from .handshake import build_response, check_request +from .handshake_legacy import build_response, check_request from .headers import build_extension, parse_extension, parse_subprotocol -from .http import USER_AGENT, read_request +from .http import USER_AGENT +from .http_legacy import read_request from .protocol import WebSocketCommonProtocol from .typing import ExtensionHeader, Origin, Subprotocol diff --git a/tests/test_client_server.py b/tests/test_client_server.py index ba0984c80..db26d6583 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -27,8 +27,9 @@ PerMessageDeflate, ServerPerMessageDeflateFactory, ) -from websockets.handshake import build_response -from websockets.http import USER_AGENT, read_response +from websockets.handshake_legacy import build_response +from websockets.http import USER_AGENT +from websockets.http_legacy import read_response from websockets.protocol import State from websockets.server import * from websockets.uri import parse_uri diff --git a/tests/test_handshake.py b/tests/test_handshake.py index 6850fec9a..8c35c9714 100644 --- a/tests/test_handshake.py +++ b/tests/test_handshake.py @@ -1,190 +1,2 @@ -import contextlib -import unittest - -from websockets.datastructures import Headers -from websockets.exceptions import ( - InvalidHandshake, - InvalidHeader, - InvalidHeaderValue, - InvalidUpgrade, -) -from websockets.handshake import * -from websockets.handshake import accept # private API - - -class HandshakeTests(unittest.TestCase): - def test_accept(self): - # Test vector from RFC 6455 - key = "dGhlIHNhbXBsZSBub25jZQ==" - acc = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" - self.assertEqual(accept(key), acc) - - def test_round_trip(self): - request_headers = Headers() - request_key = build_request(request_headers) - response_key = check_request(request_headers) - self.assertEqual(request_key, response_key) - response_headers = Headers() - build_response(response_headers, response_key) - check_response(response_headers, request_key) - - @contextlib.contextmanager - def assertValidRequestHeaders(self): - """ - Provide request headers for modification. - - Assert that the transformation kept them valid. - - """ - headers = Headers() - build_request(headers) - yield headers - check_request(headers) - - @contextlib.contextmanager - def assertInvalidRequestHeaders(self, exc_type): - """ - Provide request headers for modification. - - Assert that the transformation made them invalid. - - """ - headers = Headers() - build_request(headers) - yield headers - assert issubclass(exc_type, InvalidHandshake) - with self.assertRaises(exc_type): - check_request(headers) - - def test_request_invalid_connection(self): - with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers["Connection"] - headers["Connection"] = "Downgrade" - - def test_request_missing_connection(self): - with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers["Connection"] - - def test_request_additional_connection(self): - with self.assertValidRequestHeaders() as headers: - headers["Connection"] = "close" - - def test_request_invalid_upgrade(self): - with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers["Upgrade"] - headers["Upgrade"] = "socketweb" - - def test_request_missing_upgrade(self): - with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - del headers["Upgrade"] - - def test_request_additional_upgrade(self): - with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: - headers["Upgrade"] = "socketweb" - - def test_request_invalid_key_not_base64(self): - with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: - del headers["Sec-WebSocket-Key"] - headers["Sec-WebSocket-Key"] = "!@#$%^&*()" - - def test_request_invalid_key_not_well_padded(self): - with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: - del headers["Sec-WebSocket-Key"] - headers["Sec-WebSocket-Key"] = "CSIRmL8dWYxeAdr/XpEHRw" - - def test_request_invalid_key_not_16_bytes_long(self): - with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: - del headers["Sec-WebSocket-Key"] - headers["Sec-WebSocket-Key"] = "ZLpprpvK4PE=" - - def test_request_missing_key(self): - with self.assertInvalidRequestHeaders(InvalidHeader) as headers: - del headers["Sec-WebSocket-Key"] - - def test_request_additional_key(self): - with self.assertInvalidRequestHeaders(InvalidHeader) as headers: - # This duplicates the Sec-WebSocket-Key header. - headers["Sec-WebSocket-Key"] = headers["Sec-WebSocket-Key"] - - def test_request_invalid_version(self): - with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: - del headers["Sec-WebSocket-Version"] - headers["Sec-WebSocket-Version"] = "42" - - def test_request_missing_version(self): - with self.assertInvalidRequestHeaders(InvalidHeader) as headers: - del headers["Sec-WebSocket-Version"] - - def test_request_additional_version(self): - with self.assertInvalidRequestHeaders(InvalidHeader) as headers: - # This duplicates the Sec-WebSocket-Version header. - headers["Sec-WebSocket-Version"] = headers["Sec-WebSocket-Version"] - - @contextlib.contextmanager - def assertValidResponseHeaders(self, key="CSIRmL8dWYxeAdr/XpEHRw=="): - """ - Provide response headers for modification. - - Assert that the transformation kept them valid. - - """ - headers = Headers() - build_response(headers, key) - yield headers - check_response(headers, key) - - @contextlib.contextmanager - def assertInvalidResponseHeaders(self, exc_type, key="CSIRmL8dWYxeAdr/XpEHRw=="): - """ - Provide response headers for modification. - - Assert that the transformation made them invalid. - - """ - headers = Headers() - build_response(headers, key) - yield headers - assert issubclass(exc_type, InvalidHandshake) - with self.assertRaises(exc_type): - check_response(headers, key) - - def test_response_invalid_connection(self): - with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers["Connection"] - headers["Connection"] = "Downgrade" - - def test_response_missing_connection(self): - with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers["Connection"] - - def test_response_additional_connection(self): - with self.assertValidResponseHeaders() as headers: - headers["Connection"] = "close" - - def test_response_invalid_upgrade(self): - with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers["Upgrade"] - headers["Upgrade"] = "socketweb" - - def test_response_missing_upgrade(self): - with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - del headers["Upgrade"] - - def test_response_additional_upgrade(self): - with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: - headers["Upgrade"] = "socketweb" - - def test_response_invalid_accept(self): - with self.assertInvalidResponseHeaders(InvalidHeaderValue) as headers: - del headers["Sec-WebSocket-Accept"] - other_key = "1Eq4UDEFQYg3YspNgqxv5g==" - headers["Sec-WebSocket-Accept"] = accept(other_key) - - def test_response_missing_accept(self): - with self.assertInvalidResponseHeaders(InvalidHeader) as headers: - del headers["Sec-WebSocket-Accept"] - - def test_response_additional_accept(self): - with self.assertInvalidResponseHeaders(InvalidHeader) as headers: - # This duplicates the Sec-WebSocket-Accept header. - headers["Sec-WebSocket-Accept"] = headers["Sec-WebSocket-Accept"] +# Check that the legacy handshake module imports without an exception. +from websockets.handshake import * # noqa diff --git a/tests/test_handshake_legacy.py b/tests/test_handshake_legacy.py new file mode 100644 index 000000000..361410d3f --- /dev/null +++ b/tests/test_handshake_legacy.py @@ -0,0 +1,190 @@ +import contextlib +import unittest + +from websockets.datastructures import Headers +from websockets.exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidHeaderValue, + InvalidUpgrade, +) +from websockets.handshake_legacy import * +from websockets.handshake_legacy import accept # private API + + +class HandshakeTests(unittest.TestCase): + def test_accept(self): + # Test vector from RFC 6455 + key = "dGhlIHNhbXBsZSBub25jZQ==" + acc = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + self.assertEqual(accept(key), acc) + + def test_round_trip(self): + request_headers = Headers() + request_key = build_request(request_headers) + response_key = check_request(request_headers) + self.assertEqual(request_key, response_key) + response_headers = Headers() + build_response(response_headers, response_key) + check_response(response_headers, request_key) + + @contextlib.contextmanager + def assertValidRequestHeaders(self): + """ + Provide request headers for modification. + + Assert that the transformation kept them valid. + + """ + headers = Headers() + build_request(headers) + yield headers + check_request(headers) + + @contextlib.contextmanager + def assertInvalidRequestHeaders(self, exc_type): + """ + Provide request headers for modification. + + Assert that the transformation made them invalid. + + """ + headers = Headers() + build_request(headers) + yield headers + assert issubclass(exc_type, InvalidHandshake) + with self.assertRaises(exc_type): + check_request(headers) + + def test_request_invalid_connection(self): + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: + del headers["Connection"] + headers["Connection"] = "Downgrade" + + def test_request_missing_connection(self): + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: + del headers["Connection"] + + def test_request_additional_connection(self): + with self.assertValidRequestHeaders() as headers: + headers["Connection"] = "close" + + def test_request_invalid_upgrade(self): + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: + del headers["Upgrade"] + headers["Upgrade"] = "socketweb" + + def test_request_missing_upgrade(self): + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: + del headers["Upgrade"] + + def test_request_additional_upgrade(self): + with self.assertInvalidRequestHeaders(InvalidUpgrade) as headers: + headers["Upgrade"] = "socketweb" + + def test_request_invalid_key_not_base64(self): + with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: + del headers["Sec-WebSocket-Key"] + headers["Sec-WebSocket-Key"] = "!@#$%^&*()" + + def test_request_invalid_key_not_well_padded(self): + with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: + del headers["Sec-WebSocket-Key"] + headers["Sec-WebSocket-Key"] = "CSIRmL8dWYxeAdr/XpEHRw" + + def test_request_invalid_key_not_16_bytes_long(self): + with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: + del headers["Sec-WebSocket-Key"] + headers["Sec-WebSocket-Key"] = "ZLpprpvK4PE=" + + def test_request_missing_key(self): + with self.assertInvalidRequestHeaders(InvalidHeader) as headers: + del headers["Sec-WebSocket-Key"] + + def test_request_additional_key(self): + with self.assertInvalidRequestHeaders(InvalidHeader) as headers: + # This duplicates the Sec-WebSocket-Key header. + headers["Sec-WebSocket-Key"] = headers["Sec-WebSocket-Key"] + + def test_request_invalid_version(self): + with self.assertInvalidRequestHeaders(InvalidHeaderValue) as headers: + del headers["Sec-WebSocket-Version"] + headers["Sec-WebSocket-Version"] = "42" + + def test_request_missing_version(self): + with self.assertInvalidRequestHeaders(InvalidHeader) as headers: + del headers["Sec-WebSocket-Version"] + + def test_request_additional_version(self): + with self.assertInvalidRequestHeaders(InvalidHeader) as headers: + # This duplicates the Sec-WebSocket-Version header. + headers["Sec-WebSocket-Version"] = headers["Sec-WebSocket-Version"] + + @contextlib.contextmanager + def assertValidResponseHeaders(self, key="CSIRmL8dWYxeAdr/XpEHRw=="): + """ + Provide response headers for modification. + + Assert that the transformation kept them valid. + + """ + headers = Headers() + build_response(headers, key) + yield headers + check_response(headers, key) + + @contextlib.contextmanager + def assertInvalidResponseHeaders(self, exc_type, key="CSIRmL8dWYxeAdr/XpEHRw=="): + """ + Provide response headers for modification. + + Assert that the transformation made them invalid. + + """ + headers = Headers() + build_response(headers, key) + yield headers + assert issubclass(exc_type, InvalidHandshake) + with self.assertRaises(exc_type): + check_response(headers, key) + + def test_response_invalid_connection(self): + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: + del headers["Connection"] + headers["Connection"] = "Downgrade" + + def test_response_missing_connection(self): + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: + del headers["Connection"] + + def test_response_additional_connection(self): + with self.assertValidResponseHeaders() as headers: + headers["Connection"] = "close" + + def test_response_invalid_upgrade(self): + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: + del headers["Upgrade"] + headers["Upgrade"] = "socketweb" + + def test_response_missing_upgrade(self): + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: + del headers["Upgrade"] + + def test_response_additional_upgrade(self): + with self.assertInvalidResponseHeaders(InvalidUpgrade) as headers: + headers["Upgrade"] = "socketweb" + + def test_response_invalid_accept(self): + with self.assertInvalidResponseHeaders(InvalidHeaderValue) as headers: + del headers["Sec-WebSocket-Accept"] + other_key = "1Eq4UDEFQYg3YspNgqxv5g==" + headers["Sec-WebSocket-Accept"] = accept(other_key) + + def test_response_missing_accept(self): + with self.assertInvalidResponseHeaders(InvalidHeader) as headers: + del headers["Sec-WebSocket-Accept"] + + def test_response_additional_accept(self): + with self.assertInvalidResponseHeaders(InvalidHeader) as headers: + # This duplicates the Sec-WebSocket-Accept header. + headers["Sec-WebSocket-Accept"] = headers["Sec-WebSocket-Accept"] diff --git a/tests/test_http.py b/tests/test_http.py index b09247c3e..322650354 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,135 +1,2 @@ -import asyncio - -from websockets.exceptions import SecurityError -from websockets.http import * -from websockets.http import read_headers - -from .utils import AsyncioTestCase - - -class HTTPAsyncTests(AsyncioTestCase): - def setUp(self): - super().setUp() - self.stream = asyncio.StreamReader(loop=self.loop) - - async def test_read_request(self): - # Example from the protocol overview in RFC 6455 - self.stream.feed_data( - b"GET /chat HTTP/1.1\r\n" - b"Host: server.example.com\r\n" - b"Upgrade: websocket\r\n" - b"Connection: Upgrade\r\n" - b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" - b"Origin: http://example.com\r\n" - b"Sec-WebSocket-Protocol: chat, superchat\r\n" - b"Sec-WebSocket-Version: 13\r\n" - b"\r\n" - ) - path, headers = await read_request(self.stream) - self.assertEqual(path, "/chat") - self.assertEqual(headers["Upgrade"], "websocket") - - async def test_read_request_empty(self): - self.stream.feed_eof() - with self.assertRaisesRegex( - EOFError, "connection closed while reading HTTP request line" - ): - await read_request(self.stream) - - async def test_read_request_invalid_request_line(self): - self.stream.feed_data(b"GET /\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP request line: GET /"): - await read_request(self.stream) - - async def test_read_request_unsupported_method(self): - self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP method: OPTIONS"): - await read_request(self.stream) - - async def test_read_request_unsupported_version(self): - self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): - await read_request(self.stream) - - async def test_read_request_invalid_header(self): - self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): - await read_request(self.stream) - - async def test_read_response(self): - # Example from the protocol overview in RFC 6455 - self.stream.feed_data( - b"HTTP/1.1 101 Switching Protocols\r\n" - b"Upgrade: websocket\r\n" - b"Connection: Upgrade\r\n" - b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" - b"Sec-WebSocket-Protocol: chat\r\n" - b"\r\n" - ) - status_code, reason, headers = await read_response(self.stream) - self.assertEqual(status_code, 101) - self.assertEqual(reason, "Switching Protocols") - self.assertEqual(headers["Upgrade"], "websocket") - - async def test_read_response_empty(self): - self.stream.feed_eof() - with self.assertRaisesRegex( - EOFError, "connection closed while reading HTTP status line" - ): - await read_response(self.stream) - - async def test_read_request_invalid_status_line(self): - self.stream.feed_data(b"Hello!\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP status line: Hello!"): - await read_response(self.stream) - - async def test_read_response_unsupported_version(self): - self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): - await read_response(self.stream) - - async def test_read_response_invalid_status(self): - self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP status code: OMG"): - await read_response(self.stream) - - async def test_read_response_unsupported_status(self): - self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP status code: 007"): - await read_response(self.stream) - - async def test_read_response_invalid_reason(self): - self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP reason phrase: \\x7f"): - await read_response(self.stream) - - async def test_read_response_invalid_header(self): - self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): - await read_response(self.stream) - - async def test_header_name(self): - self.stream.feed_data(b"foo bar: baz qux\r\n\r\n") - with self.assertRaises(ValueError): - await read_headers(self.stream) - - async def test_header_value(self): - self.stream.feed_data(b"foo: \x00\x00\x0f\r\n\r\n") - with self.assertRaises(ValueError): - await read_headers(self.stream) - - async def test_headers_limit(self): - self.stream.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") - with self.assertRaises(SecurityError): - await read_headers(self.stream) - - async def test_line_limit(self): - # Header line contains 5 + 4090 + 2 = 4097 bytes. - self.stream.feed_data(b"foo: " + b"a" * 4090 + b"\r\n\r\n") - with self.assertRaises(SecurityError): - await read_headers(self.stream) - - async def test_line_ending(self): - self.stream.feed_data(b"foo: bar\n\n") - with self.assertRaises(EOFError): - await read_headers(self.stream) +# Check that the legacy http module imports without an exception. +from websockets.http import * # noqa diff --git a/tests/test_http_legacy.py b/tests/test_http_legacy.py new file mode 100644 index 000000000..3b43a6274 --- /dev/null +++ b/tests/test_http_legacy.py @@ -0,0 +1,135 @@ +import asyncio + +from websockets.exceptions import SecurityError +from websockets.http_legacy import * +from websockets.http_legacy import read_headers + +from .utils import AsyncioTestCase + + +class HTTPAsyncTests(AsyncioTestCase): + def setUp(self): + super().setUp() + self.stream = asyncio.StreamReader(loop=self.loop) + + async def test_read_request(self): + # Example from the protocol overview in RFC 6455 + self.stream.feed_data( + b"GET /chat HTTP/1.1\r\n" + b"Host: server.example.com\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + b"Origin: http://example.com\r\n" + b"Sec-WebSocket-Protocol: chat, superchat\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"\r\n" + ) + path, headers = await read_request(self.stream) + self.assertEqual(path, "/chat") + self.assertEqual(headers["Upgrade"], "websocket") + + async def test_read_request_empty(self): + self.stream.feed_eof() + with self.assertRaisesRegex( + EOFError, "connection closed while reading HTTP request line" + ): + await read_request(self.stream) + + async def test_read_request_invalid_request_line(self): + self.stream.feed_data(b"GET /\r\n\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP request line: GET /"): + await read_request(self.stream) + + async def test_read_request_unsupported_method(self): + self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") + with self.assertRaisesRegex(ValueError, "unsupported HTTP method: OPTIONS"): + await read_request(self.stream) + + async def test_read_request_unsupported_version(self): + self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") + with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + await read_request(self.stream) + + async def test_read_request_invalid_header(self): + self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + await read_request(self.stream) + + async def test_read_response(self): + # Example from the protocol overview in RFC 6455 + self.stream.feed_data( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + b"Sec-WebSocket-Protocol: chat\r\n" + b"\r\n" + ) + status_code, reason, headers = await read_response(self.stream) + self.assertEqual(status_code, 101) + self.assertEqual(reason, "Switching Protocols") + self.assertEqual(headers["Upgrade"], "websocket") + + async def test_read_response_empty(self): + self.stream.feed_eof() + with self.assertRaisesRegex( + EOFError, "connection closed while reading HTTP status line" + ): + await read_response(self.stream) + + async def test_read_request_invalid_status_line(self): + self.stream.feed_data(b"Hello!\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP status line: Hello!"): + await read_response(self.stream) + + async def test_read_response_unsupported_version(self): + self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") + with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + await read_response(self.stream) + + async def test_read_response_invalid_status(self): + self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP status code: OMG"): + await read_response(self.stream) + + async def test_read_response_unsupported_status(self): + self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") + with self.assertRaisesRegex(ValueError, "unsupported HTTP status code: 007"): + await read_response(self.stream) + + async def test_read_response_invalid_reason(self): + self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP reason phrase: \\x7f"): + await read_response(self.stream) + + async def test_read_response_invalid_header(self): + self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") + with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + await read_response(self.stream) + + async def test_header_name(self): + self.stream.feed_data(b"foo bar: baz qux\r\n\r\n") + with self.assertRaises(ValueError): + await read_headers(self.stream) + + async def test_header_value(self): + self.stream.feed_data(b"foo: \x00\x00\x0f\r\n\r\n") + with self.assertRaises(ValueError): + await read_headers(self.stream) + + async def test_headers_limit(self): + self.stream.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") + with self.assertRaises(SecurityError): + await read_headers(self.stream) + + async def test_line_limit(self): + # Header line contains 5 + 4090 + 2 = 4097 bytes. + self.stream.feed_data(b"foo: " + b"a" * 4090 + b"\r\n\r\n") + with self.assertRaises(SecurityError): + await read_headers(self.stream) + + async def test_line_ending(self): + self.stream.feed_data(b"foo: bar\n\n") + with self.assertRaises(EOFError): + await read_headers(self.stream) From 1c99e5b9fabd3b431c5697a90193ef8e1cd17d58 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Oct 2019 12:34:07 +0200 Subject: [PATCH 0694/1539] Move all type definitions to the typing module. --- src/websockets/handshake_legacy.py | 3 ++- src/websockets/headers.py | 14 +++++++++----- src/websockets/typing.py | 14 ++++++++++---- tests/test_typing.py | 1 + 4 files changed, 22 insertions(+), 10 deletions(-) create mode 100644 tests/test_typing.py diff --git a/src/websockets/handshake_legacy.py b/src/websockets/handshake_legacy.py index 3fca45545..9683e8556 100644 --- a/src/websockets/handshake_legacy.py +++ b/src/websockets/handshake_legacy.py @@ -34,7 +34,8 @@ from .datastructures import Headers, MultipleValuesError from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade from .handshake import GUID -from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade +from .headers import parse_connection, parse_upgrade +from .typing import ConnectionOption, UpgradeProtocol __all__ = ["build_request", "check_request", "build_response", "check_response"] diff --git a/src/websockets/headers.py b/src/websockets/headers.py index f33c94c04..256c66bb1 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -10,10 +10,17 @@ import base64 import binascii import re -from typing import Callable, List, NewType, Optional, Sequence, Tuple, TypeVar, cast +from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast from .exceptions import InvalidHeaderFormat, InvalidHeaderValue -from .typing import ExtensionHeader, ExtensionName, ExtensionParameter, Subprotocol +from .typing import ( + ConnectionOption, + ExtensionHeader, + ExtensionName, + ExtensionParameter, + Subprotocol, + UpgradeProtocol, +) __all__ = [ @@ -31,9 +38,6 @@ T = TypeVar("T") -ConnectionOption = NewType("ConnectionOption", str) -UpgradeProtocol = NewType("UpgradeProtocol", str) - # To avoid a dependency on a parsing library, we implement manually the ABNF # described in https://tools.ietf.org/html/rfc6455#section-9.1 with the diff --git a/src/websockets/typing.py b/src/websockets/typing.py index a5062bc4b..ca66a8c54 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -28,7 +28,6 @@ ExtensionParameter = Tuple[str, Optional[str]] - ExtensionParameter__doc__ = """Parameter of a WebSocket extension""" try: ExtensionParameter.__doc__ = ExtensionParameter__doc__ @@ -37,8 +36,7 @@ ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] - -ExtensionHeader__doc__ = """Item parsed in a Sec-WebSocket-Extensions header""" +ExtensionHeader__doc__ = """Extension in a Sec-WebSocket-Extensions header""" try: ExtensionHeader.__doc__ = ExtensionHeader__doc__ except AttributeError: # pragma: no cover @@ -46,4 +44,12 @@ Subprotocol = NewType("Subprotocol", str) -Subprotocol.__doc__ = """Items parsed in a Sec-WebSocket-Protocol header""" +Subprotocol.__doc__ = """Subprotocol value in a Sec-WebSocket-Protocol header""" + + +ConnectionOption = NewType("ConnectionOption", str) +ConnectionOption.__doc__ = """Connection option in a Connection header""" + + +UpgradeProtocol = NewType("UpgradeProtocol", str) +UpgradeProtocol.__doc__ = """Upgrade protocol in an Upgrade header""" diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 000000000..6eb1fe6c5 --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1 @@ +from websockets.typing import * # noqa From 80aea12a584b504f77e5a186c4c6b26444233297 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2020 19:37:12 +0100 Subject: [PATCH 0695/1539] Add a StreamReader based on generator coroutines. --- src/websockets/streams.py | 115 ++++++++++++++++++++++++++++++ tests/test_streams.py | 146 ++++++++++++++++++++++++++++++++++++++ tests/utils.py | 18 +++++ 3 files changed, 279 insertions(+) create mode 100644 src/websockets/streams.py create mode 100644 tests/test_streams.py diff --git a/src/websockets/streams.py b/src/websockets/streams.py new file mode 100644 index 000000000..6f3163034 --- /dev/null +++ b/src/websockets/streams.py @@ -0,0 +1,115 @@ +from typing import Generator + + +class StreamReader: + """ + Generator-based stream reader. + + This class doesn't support concurrent calls to :meth:`read_line()`, + :meth:`read_exact()`, or :meth:`read_to_eof()`. Make sure calls are + serialized. + + """ + + def __init__(self) -> None: + self.buffer = bytearray() + self.eof = False + + def read_line(self) -> Generator[None, None, bytes]: + """ + Read a LF-terminated line from the stream. + + The return value includes the LF character. + + This is a generator-based coroutine. + + :raises EOFError: if the stream ends without a LF + + """ + n = 0 # number of bytes to read + p = 0 # number of bytes without a newline + while True: + n = self.buffer.find(b"\n", p) + 1 + if n > 0: + break + p = len(self.buffer) + if self.eof: + raise EOFError(f"stream ends after {p} bytes, before end of line") + yield + r = self.buffer[:n] + del self.buffer[:n] + return r + + def read_exact(self, n: int) -> Generator[None, None, bytes]: + """ + Read ``n`` bytes from the stream. + + This is a generator-based coroutine. + + :raises EOFError: if the stream ends in less than ``n`` bytes + + """ + assert n >= 0 + while len(self.buffer) < n: + if self.eof: + p = len(self.buffer) + raise EOFError(f"stream ends after {p} bytes, expected {n} bytes") + yield + r = self.buffer[:n] + del self.buffer[:n] + return r + + def read_to_eof(self) -> Generator[None, None, bytes]: + """ + Read all bytes from the stream. + + This is a generator-based coroutine. + + """ + while not self.eof: + yield + r = self.buffer[:] + del self.buffer[:] + return r + + def at_eof(self) -> Generator[None, None, bool]: + """ + Tell whether the stream has ended and all data was read. + + This is a generator-based coroutine. + + """ + while True: + if self.buffer: + return False + if self.eof: + return True + # When all data was read but the stream hasn't ended, we can't + # tell if until either feed_data() or feed_eof() is called. + yield + + def feed_data(self, data: bytes) -> None: + """ + Write ``data`` to the stream. + + :meth:`feed_data()` cannot be called after :meth:`feed_eof()`. + + :raises EOFError: if the stream has ended + + """ + if self.eof: + raise EOFError("stream ended") + self.buffer += data + + def feed_eof(self) -> None: + """ + End the stream. + + :meth:`feed_eof()` must be called at must once. + + :raises EOFError: if the stream has ended + + """ + if self.eof: + raise EOFError("stream ended") + self.eof = True diff --git a/tests/test_streams.py b/tests/test_streams.py new file mode 100644 index 000000000..566deb2db --- /dev/null +++ b/tests/test_streams.py @@ -0,0 +1,146 @@ +from websockets.streams import StreamReader + +from .utils import GeneratorTestCase + + +class StreamReaderTests(GeneratorTestCase): + def setUp(self): + self.reader = StreamReader() + + def test_read_line(self): + self.reader.feed_data(b"spam\neggs\n") + + gen = self.reader.read_line() + line = self.assertGeneratorReturns(gen) + self.assertEqual(line, b"spam\n") + + gen = self.reader.read_line() + line = self.assertGeneratorReturns(gen) + self.assertEqual(line, b"eggs\n") + + def test_read_line_need_more_data(self): + self.reader.feed_data(b"spa") + + gen = self.reader.read_line() + self.assertGeneratorRunning(gen) + self.reader.feed_data(b"m\neg") + line = self.assertGeneratorReturns(gen) + self.assertEqual(line, b"spam\n") + + gen = self.reader.read_line() + self.assertGeneratorRunning(gen) + self.reader.feed_data(b"gs\n") + line = self.assertGeneratorReturns(gen) + self.assertEqual(line, b"eggs\n") + + def test_read_line_not_enough_data(self): + self.reader.feed_data(b"spa") + self.reader.feed_eof() + + gen = self.reader.read_line() + with self.assertRaises(EOFError) as raised: + next(gen) + self.assertEqual( + str(raised.exception), "stream ends after 3 bytes, before end of line" + ) + + def test_read_exact(self): + self.reader.feed_data(b"spameggs") + + gen = self.reader.read_exact(4) + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"spam") + + gen = self.reader.read_exact(4) + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"eggs") + + def test_read_exact_need_more_data(self): + self.reader.feed_data(b"spa") + + gen = self.reader.read_exact(4) + self.assertGeneratorRunning(gen) + self.reader.feed_data(b"meg") + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"spam") + + gen = self.reader.read_exact(4) + self.assertGeneratorRunning(gen) + self.reader.feed_data(b"gs") + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"eggs") + + def test_read_exact_not_enough_data(self): + self.reader.feed_data(b"spa") + self.reader.feed_eof() + + gen = self.reader.read_exact(4) + with self.assertRaises(EOFError) as raised: + next(gen) + self.assertEqual( + str(raised.exception), "stream ends after 3 bytes, expected 4 bytes" + ) + + def test_read_to_eof(self): + gen = self.reader.read_to_eof() + + self.reader.feed_data(b"spam") + self.assertGeneratorRunning(gen) + + self.reader.feed_eof() + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"spam") + + def test_read_to_eof_at_eof(self): + self.reader.feed_eof() + + gen = self.reader.read_to_eof() + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"") + + def test_at_eof_after_feed_data(self): + gen = self.reader.at_eof() + self.assertGeneratorRunning(gen) + self.reader.feed_data(b"spam") + eof = self.assertGeneratorReturns(gen) + self.assertFalse(eof) + + def test_at_eof_after_feed_eof(self): + gen = self.reader.at_eof() + self.assertGeneratorRunning(gen) + self.reader.feed_eof() + eof = self.assertGeneratorReturns(gen) + self.assertTrue(eof) + + def test_feed_data_after_feed_data(self): + self.reader.feed_data(b"spam") + self.reader.feed_data(b"eggs") + + gen = self.reader.read_exact(8) + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"spameggs") + gen = self.reader.at_eof() + self.assertGeneratorRunning(gen) + + def test_feed_eof_after_feed_data(self): + self.reader.feed_data(b"spam") + self.reader.feed_eof() + + gen = self.reader.read_exact(4) + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"spam") + gen = self.reader.at_eof() + eof = self.assertGeneratorReturns(gen) + self.assertTrue(eof) + + def test_feed_data_after_feed_eof(self): + self.reader.feed_eof() + with self.assertRaises(EOFError) as raised: + self.reader.feed_data(b"spam") + self.assertEqual(str(raised.exception), "stream ended") + + def test_feed_eof_after_feed_eof(self): + self.reader.feed_eof() + with self.assertRaises(EOFError) as raised: + self.reader.feed_eof() + self.assertEqual(str(raised.exception), "stream ended") diff --git a/tests/utils.py b/tests/utils.py index 983a91edf..bbffa8649 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,6 +7,24 @@ import unittest +class GeneratorTestCase(unittest.TestCase): + def assertGeneratorRunning(self, gen): + """ + Check that a generator-based coroutine hasn't completed yet. + + """ + next(gen) + + def assertGeneratorReturns(self, gen): + """ + Check that a generator-based coroutine completes and return its value. + + """ + with self.assertRaises(StopIteration) as raised: + next(gen) + return raised.exception.value + + class AsyncioTestCase(unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. From 624b9d20061c78df81f659af2c87557c764ebb19 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Oct 2019 21:27:06 +0200 Subject: [PATCH 0696/1539] Add a sans-I/O compatible framing implementation. --- docs/changelog.rst | 2 + src/websockets/extensions/base.py | 2 +- .../extensions/permessage_deflate.py | 2 +- src/websockets/frames.py | 322 ++++++++++++++++++ src/websockets/framing.py | 233 +------------ src/websockets/protocol.py | 19 +- tests/__init__.py | 10 + tests/extensions/test_permessage_deflate.py | 2 +- tests/test_frames.py | 232 +++++++++++++ tests/test_framing.py | 103 +----- tests/test_protocol.py | 11 +- 11 files changed, 624 insertions(+), 314 deletions(-) create mode 100644 src/websockets/frames.py create mode 100644 tests/test_frames.py diff --git a/docs/changelog.rst b/docs/changelog.rst index 3cda4919f..68ec6f80c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -20,6 +20,8 @@ Changelog * :mod:`websockets.http` is deprecated. + * :mod:`websocket.framing` is deprecated. + Aliases provide backwards compatibility for all previously public APIs. 8.1 diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index aa52a7adb..cfc090799 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -10,7 +10,7 @@ from typing import List, Optional, Sequence, Tuple -from ..framing import Frame +from ..frames import Frame from ..typing import ExtensionName, ExtensionParameter diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index e38d9edab..f1adf8bb6 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -14,7 +14,7 @@ NegotiationError, PayloadTooBig, ) -from ..framing import CTRL_OPCODES, OP_CONT, Frame +from ..frames import CTRL_OPCODES, OP_CONT, Frame from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory diff --git a/src/websockets/frames.py b/src/websockets/frames.py new file mode 100644 index 000000000..5ed8e483f --- /dev/null +++ b/src/websockets/frames.py @@ -0,0 +1,322 @@ +""" +Parse and serialize WebSocket frames. + +""" + +import io +import random +import struct +from typing import Callable, Generator, NamedTuple, Optional, Sequence, Tuple + +from .exceptions import PayloadTooBig, ProtocolError +from .typing import Data + + +try: + from .speedups import apply_mask +except ImportError: # pragma: no cover + from .utils import apply_mask + + +__all__ = [ + "DATA_OPCODES", + "CTRL_OPCODES", + "OP_CONT", + "OP_TEXT", + "OP_BINARY", + "OP_CLOSE", + "OP_PING", + "OP_PONG", + "Frame", + "prepare_data", + "prepare_ctrl", + "parse_close", + "serialize_close", +] + +DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 +CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG = 0x08, 0x09, 0x0A + +# Close code that are allowed in a close frame. +# Using a list optimizes `code in EXTERNAL_CLOSE_CODES`. +EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011] + + +# Consider converting to a dataclass when dropping support for Python < 3.7. + + +class Frame(NamedTuple): + """ + WebSocket frame. + + :param bool fin: FIN bit + :param bool rsv1: RSV1 bit + :param bool rsv2: RSV2 bit + :param bool rsv3: RSV3 bit + :param int opcode: opcode + :param bytes data: payload data + + Only these fields are needed. The MASK bit, payload length and masking-key + are handled on the fly by :func:`parse_frame` and :meth:`serialize_frame`. + + """ + + fin: bool + opcode: int + data: bytes + rsv1: bool = False + rsv2: bool = False + rsv3: bool = False + + @classmethod + def parse( + cls, + read_exact: Callable[[int], Generator[None, None, bytes]], + *, + mask: bool, + max_size: Optional[int] = None, + extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + ) -> Generator[None, None, "Frame"]: + """ + Read a WebSocket frame. + + :param read_exact: generator-based coroutine that reads the requested + number of bytes or raises an exception if there isn't enough data + :param mask: whether the frame should be masked i.e. whether the read + happens on the server side + :param max_size: maximum payload size in bytes + :param extensions: list of classes with a ``decode()`` method that + transforms the frame and return a new frame; extensions are applied + in reverse order + :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds + ``max_size`` + :raises ~websockets.exceptions.ProtocolError: if the frame + contains incorrect values + + """ + # Read the header. + data = yield from read_exact(2) + head1, head2 = struct.unpack("!BB", data) + + # While not Pythonic, this is marginally faster than calling bool(). + fin = True if head1 & 0b10000000 else False + rsv1 = True if head1 & 0b01000000 else False + rsv2 = True if head1 & 0b00100000 else False + rsv3 = True if head1 & 0b00010000 else False + opcode = head1 & 0b00001111 + + if (True if head2 & 0b10000000 else False) != mask: + raise ProtocolError("incorrect masking") + + length = head2 & 0b01111111 + if length == 126: + data = yield from read_exact(2) + (length,) = struct.unpack("!H", data) + elif length == 127: + data = yield from read_exact(8) + (length,) = struct.unpack("!Q", data) + if max_size is not None and length > max_size: + raise PayloadTooBig( + f"payload length exceeds size limit ({length} > {max_size} bytes)" + ) + if mask: + mask_bits = yield from read_exact(4) + + # Read the data. + data = yield from read_exact(length) + if mask: + data = apply_mask(data, mask_bits) + + frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) + + if extensions is None: + extensions = [] + for extension in reversed(extensions): + frame = extension.decode(frame, max_size=max_size) + + frame.check() + + return frame + + def serialize( + self, + *, + mask: bool, + extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + ) -> bytes: + """ + Write a WebSocket frame. + + :param frame: frame to write + :param mask: whether the frame should be masked i.e. whether the write + happens on the client side + :param extensions: list of classes with an ``encode()`` method that + transform the frame and return a new frame; extensions are applied + in order + :raises ~websockets.exceptions.ProtocolError: if the frame + contains incorrect values + + """ + self.check() + + if extensions is None: + extensions = [] + for extension in extensions: + self = extension.encode(self) + + output = io.BytesIO() + + # Prepare the header. + head1 = ( + (0b10000000 if self.fin else 0) + | (0b01000000 if self.rsv1 else 0) + | (0b00100000 if self.rsv2 else 0) + | (0b00010000 if self.rsv3 else 0) + | self.opcode + ) + + head2 = 0b10000000 if mask else 0 + + length = len(self.data) + if length < 126: + output.write(struct.pack("!BB", head1, head2 | length)) + elif length < 65536: + output.write(struct.pack("!BBH", head1, head2 | 126, length)) + else: + output.write(struct.pack("!BBQ", head1, head2 | 127, length)) + + if mask: + mask_bits = struct.pack("!I", random.getrandbits(32)) + output.write(mask_bits) + + # Prepare the data. + if mask: + data = apply_mask(self.data, mask_bits) + else: + data = self.data + output.write(data) + + return output.getvalue() + + def check(self) -> None: + """ + Check that reserved bits and opcode have acceptable values. + + :raises ~websockets.exceptions.ProtocolError: if a reserved + bit or the opcode is invalid + + """ + if self.rsv1 or self.rsv2 or self.rsv3: + raise ProtocolError("reserved bits must be 0") + + if self.opcode in DATA_OPCODES: + return + elif self.opcode in CTRL_OPCODES: + if len(self.data) > 125: + raise ProtocolError("control frame too long") + if not self.fin: + raise ProtocolError("fragmented control frame") + else: + raise ProtocolError(f"invalid opcode: {self.opcode}") + + +def prepare_data(data: Data) -> Tuple[int, bytes]: + """ + Convert a string or byte-like object to an opcode and a bytes-like object. + + This function is designed for data frames. + + If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` + object encoding ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like + object. + + :raises TypeError: if ``data`` doesn't have a supported type + + """ + if isinstance(data, str): + return OP_TEXT, data.encode("utf-8") + elif isinstance(data, (bytes, bytearray)): + return OP_BINARY, data + elif isinstance(data, memoryview): + if data.c_contiguous: + return OP_BINARY, data + else: + return OP_BINARY, data.tobytes() + else: + raise TypeError("data must be bytes-like or str") + + +def prepare_ctrl(data: Data) -> bytes: + """ + Convert a string or byte-like object to bytes. + + This function is designed for ping and pong frames. + + If ``data`` is a :class:`str`, return a :class:`bytes` object encoding + ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return a :class:`bytes` object. + + :raises TypeError: if ``data`` doesn't have a supported type + + """ + if isinstance(data, str): + return data.encode("utf-8") + elif isinstance(data, (bytes, bytearray)): + return bytes(data) + elif isinstance(data, memoryview): + return data.tobytes() + else: + raise TypeError("data must be bytes-like or str") + + +def parse_close(data: bytes) -> Tuple[int, str]: + """ + Parse the payload from a close frame. + + Return ``(code, reason)``. + + :raises ~websockets.exceptions.ProtocolError: if data is ill-formed + :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + + """ + length = len(data) + if length >= 2: + (code,) = struct.unpack("!H", data[:2]) + check_close(code) + reason = data[2:].decode("utf-8") + return code, reason + elif length == 0: + return 1005, "" + else: + assert length == 1 + raise ProtocolError("close frame too short") + + +def serialize_close(code: int, reason: str) -> bytes: + """ + Serialize the payload for a close frame. + + This is the reverse of :func:`parse_close`. + + """ + check_close(code) + return struct.pack("!H", code) + reason.encode("utf-8") + + +def check_close(code: int) -> None: + """ + Check that the close code has an acceptable value for a close frame. + + :raises ~websockets.exceptions.ProtocolError: if the close code + is invalid + + """ + if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): + raise ProtocolError("invalid status code") + + +# at the bottom to allow circular import, because Extension depends on Frame +import websockets.extensions.base # isort:skip # noqa diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 26e58cdbf..221afad6f 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -10,13 +10,12 @@ """ -import io -import random import struct -from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple +import warnings +from typing import Any, Awaitable, Callable, Optional, Sequence from .exceptions import PayloadTooBig, ProtocolError -from .typing import Data +from .frames import Frame as NewFrame try: @@ -25,56 +24,10 @@ from .utils import apply_mask -__all__ = [ - "DATA_OPCODES", - "CTRL_OPCODES", - "OP_CONT", - "OP_TEXT", - "OP_BINARY", - "OP_CLOSE", - "OP_PING", - "OP_PONG", - "Frame", - "prepare_data", - "encode_data", - "parse_close", - "serialize_close", -] +warnings.warn("websockets.framing is deprecated", DeprecationWarning) -DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 -CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG = 0x08, 0x09, 0x0A - -# Close code that are allowed in a close frame. -# Using a list optimizes `code in EXTERNAL_CLOSE_CODES`. -EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011] - - -# Consider converting to a dataclass when dropping support for Python < 3.7. - - -class Frame(NamedTuple): - """ - WebSocket frame. - - :param bool fin: FIN bit - :param bool rsv1: RSV1 bit - :param bool rsv2: RSV2 bit - :param bool rsv3: RSV3 bit - :param int opcode: opcode - :param bytes data: payload data - - Only these fields are needed. The MASK bit, payload length and masking-key - are handled on the fly by :meth:`read` and :meth:`write`. - - """ - - fin: bool - opcode: int - data: bytes - rsv1: bool = False - rsv2: bool = False - rsv3: bool = False +class Frame(NewFrame): @classmethod async def read( cls, @@ -101,6 +54,7 @@ async def read( contains incorrect values """ + # Read the header. data = await reader(2) head1, head2 = struct.unpack("!BB", data) @@ -139,14 +93,14 @@ async def read( if extensions is None: extensions = [] for extension in reversed(extensions): - frame = extension.decode(frame, max_size=max_size) + frame = cls(*extension.decode(frame, max_size=max_size)) frame.check() return frame def write( - frame, + self, write: Callable[[bytes], Any], *, mask: bool, @@ -166,176 +120,17 @@ def write( contains incorrect values """ - # The first parameter is called `frame` rather than `self`, - # but it's the instance of class to which this method is bound. - - frame.check() - - if extensions is None: - extensions = [] - for extension in extensions: - frame = extension.encode(frame) - - output = io.BytesIO() - - # Prepare the header. - head1 = ( - (0b10000000 if frame.fin else 0) - | (0b01000000 if frame.rsv1 else 0) - | (0b00100000 if frame.rsv2 else 0) - | (0b00010000 if frame.rsv3 else 0) - | frame.opcode - ) - - head2 = 0b10000000 if mask else 0 - - length = len(frame.data) - if length < 126: - output.write(struct.pack("!BB", head1, head2 | length)) - elif length < 65536: - output.write(struct.pack("!BBH", head1, head2 | 126, length)) - else: - output.write(struct.pack("!BBQ", head1, head2 | 127, length)) - - if mask: - mask_bits = struct.pack("!I", random.getrandbits(32)) - output.write(mask_bits) - - # Prepare the data. - if mask: - data = apply_mask(frame.data, mask_bits) - else: - data = frame.data - output.write(data) - - # Send the frame. - # The frame is written in a single call to write in order to prevent # TCP fragmentation. See #68 for details. This also makes it safe to # send frames concurrently from multiple coroutines. - write(output.getvalue()) - - def check(frame) -> None: - """ - Check that reserved bits and opcode have acceptable values. - - :raises ~websockets.exceptions.ProtocolError: if a reserved - bit or the opcode is invalid - - """ - # The first parameter is called `frame` rather than `self`, - # but it's the instance of class to which this method is bound. - - if frame.rsv1 or frame.rsv2 or frame.rsv3: - raise ProtocolError("reserved bits must be 0") - - if frame.opcode in DATA_OPCODES: - return - elif frame.opcode in CTRL_OPCODES: - if len(frame.data) > 125: - raise ProtocolError("control frame too long") - if not frame.fin: - raise ProtocolError("fragmented control frame") - else: - raise ProtocolError(f"invalid opcode: {frame.opcode}") - - -def prepare_data(data: Data) -> Tuple[int, bytes]: - """ - Convert a string or byte-like object to an opcode and a bytes-like object. - - This function is designed for data frames. - - If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` - object encoding ``data`` in UTF-8. - - If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like - object. - - :raises TypeError: if ``data`` doesn't have a supported type - - """ - if isinstance(data, str): - return OP_TEXT, data.encode("utf-8") - elif isinstance(data, (bytes, bytearray)): - return OP_BINARY, data - elif isinstance(data, memoryview): - if data.c_contiguous: - return OP_BINARY, data - else: - return OP_BINARY, data.tobytes() - else: - raise TypeError("data must be bytes-like or str") - - -def encode_data(data: Data) -> bytes: - """ - Convert a string or byte-like object to bytes. - - This function is designed for ping and pong frames. - - If ``data`` is a :class:`str`, return a :class:`bytes` object encoding - ``data`` in UTF-8. - - If ``data`` is a bytes-like object, return a :class:`bytes` object. - - :raises TypeError: if ``data`` doesn't have a supported type - - """ - if isinstance(data, str): - return data.encode("utf-8") - elif isinstance(data, (bytes, bytearray)): - return bytes(data) - elif isinstance(data, memoryview): - return data.tobytes() - else: - raise TypeError("data must be bytes-like or str") - - -def parse_close(data: bytes) -> Tuple[int, str]: - """ - Parse the payload from a close frame. - - Return ``(code, reason)``. - - :raises ~websockets.exceptions.ProtocolError: if data is ill-formed - :raises UnicodeDecodeError: if the reason isn't valid UTF-8 - - """ - length = len(data) - if length >= 2: - (code,) = struct.unpack("!H", data[:2]) - check_close(code) - reason = data[2:].decode("utf-8") - return code, reason - elif length == 0: - return 1005, "" - else: - assert length == 1 - raise ProtocolError("close frame too short") - - -def serialize_close(code: int, reason: str) -> bytes: - """ - Serialize the payload for a close frame. - - This is the reverse of :func:`parse_close`. - - """ - check_close(code) - return struct.pack("!H", code) + reason.encode("utf-8") - - -def check_close(code: int) -> None: - """ - Check that the close code has an acceptable value for a close frame. + write(self.serialize(mask=mask, extensions=extensions)) - :raises ~websockets.exceptions.ProtocolError: if the close code - is invalid - """ - if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): - raise ProtocolError("invalid status code") +# Backwards compatibility with previously documented public APIs +from .frames import parse_close # isort:skip # noqa +from .frames import prepare_ctrl as encode_data # isort:skip # noqa +from .frames import prepare_data # isort:skip # noqa +from .frames import serialize_close # isort:skip # noqa # at the bottom to allow circular import, because Extension depends on Frame diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index cc4416ba8..748c1ae66 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -40,8 +40,19 @@ ProtocolError, ) from .extensions.base import Extension -from .framing import * -from .handshake_legacy import * +from .frames import ( + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + parse_close, + prepare_ctrl, + prepare_data, + serialize_close, +) +from .framing import Frame from .typing import Data @@ -732,7 +743,7 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: await self.ensure_open() if data is not None: - data = encode_data(data) + data = prepare_ctrl(data) # Protect against duplicates if a payload is explicitly set. if data in self.pings: @@ -763,7 +774,7 @@ async def pong(self, data: Data = b"") -> None: """ await self.ensure_open() - data = encode_data(data) + data = prepare_ctrl(data) await self.write_frame(True, OP_PONG, data) diff --git a/tests/__init__.py b/tests/__init__.py index dd78609f5..76c869f50 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,15 @@ import logging +import warnings # Avoid displaying stack traces at the ERROR logging level. logging.basicConfig(level=logging.CRITICAL) + + +# Ignore deprecation warnings while refactoring is in progress +warnings.filterwarnings( + action="ignore", + message=r"websockets\.framing is deprecated", + category=DeprecationWarning, + module="websockets.framing", +) diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 0ec49c6c0..e1193e672 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -9,7 +9,7 @@ PayloadTooBig, ) from websockets.extensions.permessage_deflate import * -from websockets.framing import ( +from websockets.frames import ( OP_BINARY, OP_CLOSE, OP_CONT, diff --git a/tests/test_frames.py b/tests/test_frames.py new file mode 100644 index 000000000..39d4055a8 --- /dev/null +++ b/tests/test_frames.py @@ -0,0 +1,232 @@ +import codecs +import struct +import unittest +import unittest.mock + +from websockets.exceptions import PayloadTooBig, ProtocolError +from websockets.frames import * +from websockets.streams import StreamReader + +from .utils import GeneratorTestCase + + +class FrameTests(GeneratorTestCase): + def parse(self, data, mask=False, max_size=None, extensions=None): + reader = StreamReader() + reader.feed_data(data) + reader.feed_eof() + parser = Frame.parse( + reader.read_exact, mask=mask, max_size=max_size, extensions=extensions, + ) + return self.assertGeneratorReturns(parser) + + def round_trip(self, data, frame, mask=False, extensions=None): + parsed = self.parse(data, mask=mask, extensions=extensions) + self.assertEqual(parsed, frame) + + # Make masking deterministic by reusing the same "random" mask. + # This has an effect only when mask is True. + randbits = struct.unpack("!I", data[2:6])[0] if mask else 0 + with unittest.mock.patch("random.getrandbits", return_value=randbits): + serialized = parsed.serialize(mask=mask, extensions=extensions) + self.assertEqual(serialized, data) + + def test_text(self): + self.round_trip(b"\x81\x04Spam", Frame(True, OP_TEXT, b"Spam")) + + def test_text_masked(self): + self.round_trip( + b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", + Frame(True, OP_TEXT, b"Spam"), + mask=True, + ) + + def test_binary(self): + self.round_trip(b"\x82\x04Eggs", Frame(True, OP_BINARY, b"Eggs")) + + def test_binary_masked(self): + self.round_trip( + b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", + Frame(True, OP_BINARY, b"Eggs"), + mask=True, + ) + + def test_non_ascii_text(self): + self.round_trip( + b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode("utf-8")) + ) + + def test_non_ascii_text_masked(self): + self.round_trip( + b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", + Frame(True, OP_TEXT, "café".encode("utf-8")), + mask=True, + ) + + def test_close(self): + self.round_trip(b"\x88\x00", Frame(True, OP_CLOSE, b"")) + + def test_ping(self): + self.round_trip(b"\x89\x04ping", Frame(True, OP_PING, b"ping")) + + def test_pong(self): + self.round_trip(b"\x8a\x04pong", Frame(True, OP_PONG, b"pong")) + + def test_long(self): + self.round_trip( + b"\x82\x7e\x00\x7e" + 126 * b"a", Frame(True, OP_BINARY, 126 * b"a") + ) + + def test_very_long(self): + self.round_trip( + b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", + Frame(True, OP_BINARY, 65536 * b"a"), + ) + + def test_payload_too_big(self): + with self.assertRaises(PayloadTooBig): + self.parse(b"\x82\x7e\x04\x01" + 1025 * b"a", max_size=1024) + + def test_bad_reserved_bits(self): + for data in [b"\xc0\x00", b"\xa0\x00", b"\x90\x00"]: + with self.subTest(data=data): + with self.assertRaises(ProtocolError): + self.parse(data) + + def test_good_opcode(self): + for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0B)): + data = bytes([0x80 | opcode, 0]) + with self.subTest(data=data): + self.parse(data) # does not raise an exception + + def test_bad_opcode(self): + for opcode in list(range(0x03, 0x08)) + list(range(0x0B, 0x10)): + data = bytes([0x80 | opcode, 0]) + with self.subTest(data=data): + with self.assertRaises(ProtocolError): + self.parse(data) + + def test_mask_flag(self): + # Mask flag correctly set. + self.parse(b"\x80\x80\x00\x00\x00\x00", mask=True) + # Mask flag incorrectly unset. + with self.assertRaises(ProtocolError): + self.parse(b"\x80\x80\x00\x00\x00\x00") + # Mask flag correctly unset. + self.parse(b"\x80\x00") + # Mask flag incorrectly set. + with self.assertRaises(ProtocolError): + self.parse(b"\x80\x00", mask=True) + + def test_control_frame_max_length(self): + # At maximum allowed length. + self.parse(b"\x88\x7e\x00\x7d" + 125 * b"a") + # Above maximum allowed length. + with self.assertRaises(ProtocolError): + self.parse(b"\x88\x7e\x00\x7e" + 126 * b"a") + + def test_fragmented_control_frame(self): + # Fin bit correctly set. + self.parse(b"\x88\x00") + # Fin bit incorrectly unset. + with self.assertRaises(ProtocolError): + self.parse(b"\x08\x00") + + def test_extensions(self): + class Rot13: + @staticmethod + def encode(frame): + assert frame.opcode == OP_TEXT + text = frame.data.decode() + data = codecs.encode(text, "rot13").encode() + return frame._replace(data=data) + + # This extensions is symmetrical. + @staticmethod + def decode(frame, *, max_size=None): + return Rot13.encode(frame) + + self.round_trip( + b"\x81\x05uryyb", Frame(True, OP_TEXT, b"hello"), extensions=[Rot13()] + ) + + +class PrepareDataTests(unittest.TestCase): + def test_prepare_data_str(self): + self.assertEqual(prepare_data("café"), (OP_TEXT, b"caf\xc3\xa9")) + + def test_prepare_data_bytes(self): + self.assertEqual(prepare_data(b"tea"), (OP_BINARY, b"tea")) + + def test_prepare_data_bytearray(self): + self.assertEqual( + prepare_data(bytearray(b"tea")), (OP_BINARY, bytearray(b"tea")) + ) + + def test_prepare_data_memoryview(self): + self.assertEqual( + prepare_data(memoryview(b"tea")), (OP_BINARY, memoryview(b"tea")) + ) + + def test_prepare_data_non_contiguous_memoryview(self): + self.assertEqual(prepare_data(memoryview(b"tteeaa")[::2]), (OP_BINARY, b"tea")) + + def test_prepare_data_list(self): + with self.assertRaises(TypeError): + prepare_data([]) + + def test_prepare_data_none(self): + with self.assertRaises(TypeError): + prepare_data(None) + + +class PrepareCtrlTests(unittest.TestCase): + def test_prepare_ctrl_str(self): + self.assertEqual(prepare_ctrl("café"), b"caf\xc3\xa9") + + def test_prepare_ctrl_bytes(self): + self.assertEqual(prepare_ctrl(b"tea"), b"tea") + + def test_prepare_ctrl_bytearray(self): + self.assertEqual(prepare_ctrl(bytearray(b"tea")), b"tea") + + def test_prepare_ctrl_memoryview(self): + self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea") + + def test_prepare_ctrl_non_contiguous_memoryview(self): + self.assertEqual(prepare_ctrl(memoryview(b"tteeaa")[::2]), b"tea") + + def test_prepare_ctrl_list(self): + with self.assertRaises(TypeError): + prepare_ctrl([]) + + def test_prepare_ctrl_none(self): + with self.assertRaises(TypeError): + prepare_ctrl(None) + + +class ParseAndSerializeCloseTests(unittest.TestCase): + def round_trip(self, data, code, reason): + parsed = parse_close(data) + self.assertEqual(parsed, (code, reason)) + serialized = serialize_close(code, reason) + self.assertEqual(serialized, data) + + def test_parse_close_and_serialize_close(self): + self.round_trip(b"\x03\xe8", 1000, "") + self.round_trip(b"\x03\xe8OK", 1000, "OK") + + def test_parse_close_empty(self): + self.assertEqual(parse_close(b""), (1005, "")) + + def test_parse_close_errors(self): + with self.assertRaises(ProtocolError): + parse_close(b"\x03") + with self.assertRaises(ProtocolError): + parse_close(b"\x03\xe7") + with self.assertRaises(UnicodeDecodeError): + parse_close(b"\x03\xe8\xff\xff") + + def test_serialize_close_errors(self): + with self.assertRaises(ProtocolError): + serialize_close(999, "") diff --git a/tests/test_framing.py b/tests/test_framing.py index 5def415d2..231cbf718 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -2,8 +2,10 @@ import codecs import unittest import unittest.mock +import warnings from websockets.exceptions import PayloadTooBig, ProtocolError +from websockets.frames import OP_BINARY, OP_CLOSE, OP_PING, OP_PONG, OP_TEXT from websockets.framing import * from .utils import AsyncioTestCase @@ -11,24 +13,26 @@ class FramingTests(AsyncioTestCase): def decode(self, message, mask=False, max_size=None, extensions=None): - self.stream = asyncio.StreamReader(loop=self.loop) - self.stream.feed_data(message) - self.stream.feed_eof() - frame = self.loop.run_until_complete( - Frame.read( - self.stream.readexactly, - mask=mask, - max_size=max_size, - extensions=extensions, + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(message) + stream.feed_eof() + with warnings.catch_warnings(record=True): + frame = self.loop.run_until_complete( + Frame.read( + stream.readexactly, + mask=mask, + max_size=max_size, + extensions=extensions, + ) ) - ) # Make sure all the data was consumed. - self.assertTrue(self.stream.at_eof()) + self.assertTrue(stream.at_eof()) return frame def encode(self, frame, mask=False, extensions=None): write = unittest.mock.Mock() - frame.write(write, mask=mask, extensions=extensions) + with warnings.catch_warnings(record=True): + frame.write(write, mask=mask, extensions=extensions) # Ensure the entire frame is sent with a single call to write(). # Multiple calls cause TCP fragmentation and degrade performance. self.assertEqual(write.call_count, 1) @@ -47,12 +51,6 @@ def round_trip(self, message, expected, mask=False, extensions=None): else: # deterministic encoding self.assertEqual(encoded, message) - def round_trip_close(self, data, code, reason): - parsed = parse_close(data) - self.assertEqual(parsed, (code, reason)) - serialized = serialize_close(code, reason) - self.assertEqual(serialized, data) - def test_text(self): self.round_trip(b"\x81\x04Spam", Frame(True, OP_TEXT, b"Spam")) @@ -147,56 +145,6 @@ def test_control_frame_max_length(self): with self.assertRaises(ProtocolError): self.decode(b"\x88\x7e\x00\x7e" + 126 * b"a") - def test_prepare_data_str(self): - self.assertEqual(prepare_data("café"), (OP_TEXT, b"caf\xc3\xa9")) - - def test_prepare_data_bytes(self): - self.assertEqual(prepare_data(b"tea"), (OP_BINARY, b"tea")) - - def test_prepare_data_bytearray(self): - self.assertEqual( - prepare_data(bytearray(b"tea")), (OP_BINARY, bytearray(b"tea")) - ) - - def test_prepare_data_memoryview(self): - self.assertEqual( - prepare_data(memoryview(b"tea")), (OP_BINARY, memoryview(b"tea")) - ) - - def test_prepare_data_non_contiguous_memoryview(self): - self.assertEqual(prepare_data(memoryview(b"tteeaa")[::2]), (OP_BINARY, b"tea")) - - def test_prepare_data_list(self): - with self.assertRaises(TypeError): - prepare_data([]) - - def test_prepare_data_none(self): - with self.assertRaises(TypeError): - prepare_data(None) - - def test_encode_data_str(self): - self.assertEqual(encode_data("café"), b"caf\xc3\xa9") - - def test_encode_data_bytes(self): - self.assertEqual(encode_data(b"tea"), b"tea") - - def test_encode_data_bytearray(self): - self.assertEqual(encode_data(bytearray(b"tea")), b"tea") - - def test_encode_data_memoryview(self): - self.assertEqual(encode_data(memoryview(b"tea")), b"tea") - - def test_encode_data_non_contiguous_memoryview(self): - self.assertEqual(encode_data(memoryview(b"tteeaa")[::2]), b"tea") - - def test_encode_data_list(self): - with self.assertRaises(TypeError): - encode_data([]) - - def test_encode_data_none(self): - with self.assertRaises(TypeError): - encode_data(None) - def test_fragmented_control_frame(self): # Fin bit correctly set. self.decode(b"\x88\x00") @@ -204,25 +152,6 @@ def test_fragmented_control_frame(self): with self.assertRaises(ProtocolError): self.decode(b"\x08\x00") - def test_parse_close_and_serialize_close(self): - self.round_trip_close(b"\x03\xe8", 1000, "") - self.round_trip_close(b"\x03\xe8OK", 1000, "OK") - - def test_parse_close_empty(self): - self.assertEqual(parse_close(b""), (1005, "")) - - def test_parse_close_errors(self): - with self.assertRaises(ProtocolError): - parse_close(b"\x03") - with self.assertRaises(ProtocolError): - parse_close(b"\x03\xe7") - with self.assertRaises(UnicodeDecodeError): - parse_close(b"\x03\xe8\xff\xff") - - def test_serialize_close_errors(self): - with self.assertRaises(ProtocolError): - serialize_close(999, "") - def test_extensions(self): class Rot13: @staticmethod diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d32c1f72e..91fb02a50 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -6,7 +6,16 @@ import warnings from websockets.exceptions import ConnectionClosed, InvalidState -from websockets.framing import * +from websockets.frames import ( + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + serialize_close, +) +from websockets.framing import Frame from websockets.protocol import State, WebSocketCommonProtocol from .utils import MS, AsyncioTestCase From 7b67307ec9f324535cea7e141c6d1a43cb47f4ff Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 14 Oct 2019 21:55:06 +0200 Subject: [PATCH 0697/1539] Add a sans-I/O compatible HTTP/1.1 implementation. --- src/websockets/http11.py | 295 +++++++++++++++++++++++++++++++++++++++ tests/test_http11.py | 271 +++++++++++++++++++++++++++++++++++ 2 files changed, 566 insertions(+) create mode 100644 src/websockets/http11.py create mode 100644 tests/test_http11.py diff --git a/src/websockets/http11.py b/src/websockets/http11.py new file mode 100644 index 000000000..e1d004881 --- /dev/null +++ b/src/websockets/http11.py @@ -0,0 +1,295 @@ +import re +from typing import Callable, Generator, NamedTuple, Optional + +from .datastructures import Headers +from .exceptions import SecurityError + + +MAX_HEADERS = 256 +MAX_LINE = 4096 + + +def d(value: bytes) -> str: + """ + Decode a bytestring for interpolating into an error message. + + """ + return value.decode(errors="backslashreplace") + + +# See https://tools.ietf.org/html/rfc7230#appendix-B. + +# Regex for validating header names. + +_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") + +# Regex for validating header values. + +# We don't attempt to support obsolete line folding. + +# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). + +# The ABNF is complicated because it attempts to express that optional +# whitespace is ignored. We strip whitespace and don't revalidate that. + +# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + +_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") + + +# Consider converting to dataclasses when dropping support for Python < 3.7. + + +class Request(NamedTuple): + """ + WebSocket handshake request. + + :param path: path and optional query + :param headers: + """ + + path: str + headers: Headers + # body isn't useful is the context of this library + + @classmethod + def parse( + cls, read_line: Callable[[], Generator[None, None, bytes]] + ) -> Generator[None, None, "Request"]: + """ + Parse an HTTP/1.1 GET request and return ``(path, headers)``. + + ``path`` isn't URL-decoded or validated in any way. + + ``path`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. + + :func:`parse_request` doesn't attempt to read the request body because + WebSocket handshake requests don't have one. If the request contains a + body, it may be read from ``stream`` after this coroutine returns. + + :param read_line: generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data + :raises EOFError: if the connection is closed without a full HTTP request + :raises SecurityError: if the request exceeds a security limit + :raises ValueError: if the request isn't well formatted + + """ + # https://tools.ietf.org/html/rfc7230#section-3.1.1 + + # Parsing is simple because fixed values are expected for method and + # version and because path isn't checked. Since WebSocket software tends + # to implement HTTP/1.1 strictly, there's little need for lenient parsing. + + try: + request_line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP request line") from exc + + try: + method, raw_path, version = request_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None + + if method != b"GET": + raise ValueError(f"unsupported HTTP method: {d(method)}") + if version != b"HTTP/1.1": + raise ValueError(f"unsupported HTTP version: {d(version)}") + path = raw_path.decode("ascii", "surrogateescape") + + headers = yield from parse_headers(read_line) + + return cls(path, headers) + + def serialize(self) -> bytes: + """ + Serialize an HTTP/1.1 GET request. + + """ + # Since the path and headers only contain ASCII characters, + # we can keep this simple. + request = f"GET {self.path} HTTP/1.1\r\n".encode() + request += self.headers.serialize() + return request + + +# Consider converting to dataclasses when dropping support for Python < 3.7. + + +class Response(NamedTuple): + """ + WebSocket handshake response. + + """ + + status_code: int + reason_phrase: str + headers: Headers + body: Optional[bytes] = None + + @classmethod + def parse( + cls, + read_line: Callable[[], Generator[None, None, bytes]], + read_exact: Callable[[int], Generator[None, None, bytes]], + read_to_eof: Callable[[], Generator[None, None, bytes]], + ) -> Generator[None, None, "Response"]: + """ + Parse an HTTP/1.1 response and return ``(status_code, reason, headers)``. + + ``reason`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. + + :func:`parse_request` doesn't attempt to read the response body because + WebSocket handshake responses don't have one. If the response contains a + body, it may be read from ``stream`` after this coroutine returns. + + :param read_line: generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data + :param read_exact: generator-based coroutine that reads the requested + number of bytes or raises an exception if there isn't enough data + :raises EOFError: if the connection is closed without a full HTTP response + :raises SecurityError: if the response exceeds a security limit + :raises LookupError: if the response isn't well formatted + :raises ValueError: if the response isn't well formatted + + """ + # https://tools.ietf.org/html/rfc7230#section-3.1.2 + + # As in parse_request, parsing is simple because a fixed value is expected + # for version, status_code is a 3-digit number, and reason can be ignored. + + try: + status_line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP status line") from exc + + try: + version, raw_status_code, raw_reason = status_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None + + if version != b"HTTP/1.1": + raise ValueError(f"unsupported HTTP version: {d(version)}") + try: + status_code = int(raw_status_code) + except ValueError: # invalid literal for int() with base 10 + raise ValueError( + f"invalid HTTP status code: {d(raw_status_code)}" + ) from None + if not 100 <= status_code < 1000: + raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}") + if not _value_re.fullmatch(raw_reason): + raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") + reason = raw_reason.decode() + + headers = yield from parse_headers(read_line) + + # https://tools.ietf.org/html/rfc7230#section-3.3.3 + + if "Transfer-Encoding" in headers: + raise NotImplementedError("transfer codings aren't supported") + + # Since websockets only does GET requests (no HEAD, no CONNECT), all + # responses except 1xx, 204, and 304 include a message body. + if 100 <= status_code < 200 or status_code == 204 or status_code == 304: + body = None + else: + content_length: Optional[int] + try: + # MultipleValuesError is sufficiently unlikely that we don't + # attempt to handle it. Instead we document that its parent + # class, LookupError, may be raised. + raw_content_length = headers["Content-Length"] + except KeyError: + content_length = None + else: + content_length = int(raw_content_length) + + if content_length is None: + body = yield from read_to_eof() + else: + body = yield from read_exact(content_length) + + return cls(status_code, reason, headers, body) + + def serialize(self) -> bytes: + """ + Serialize an HTTP/1.1 GET response. + + """ + # Since the status line and headers only contain ASCII characters, + # we can keep this simple. + response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode() + response += self.headers.serialize() + if self.body is not None: + response += self.body + return response + + +def parse_headers( + read_line: Callable[[], Generator[None, None, bytes]] +) -> Generator[None, None, Headers]: + """ + Parse HTTP headers. + + Non-ASCII characters are represented with surrogate escapes. + + :param read_line: generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data + + """ + # https://tools.ietf.org/html/rfc7230#section-3.2 + + # We don't attempt to support obsolete line folding. + + headers = Headers() + for _ in range(MAX_HEADERS + 1): + try: + line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP headers") from exc + if line == b"": + break + + try: + raw_name, raw_value = line.split(b":", 1) + except ValueError: # not enough values to unpack (expected 2, got 1) + raise ValueError(f"invalid HTTP header line: {d(line)}") from None + if not _token_re.fullmatch(raw_name): + raise ValueError(f"invalid HTTP header name: {d(raw_name)}") + raw_value = raw_value.strip(b" \t") + if not _value_re.fullmatch(raw_value): + raise ValueError(f"invalid HTTP header value: {d(raw_value)}") + + name = raw_name.decode("ascii") # guaranteed to be ASCII at this point + value = raw_value.decode("ascii", "surrogateescape") + headers[name] = value + + else: + raise SecurityError("too many HTTP headers") + + return headers + + +def parse_line( + read_line: Callable[[], Generator[None, None, bytes]] +) -> Generator[None, None, bytes]: + """ + Parse a single line. + + CRLF is stripped from the return value. + + :param read_line: generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data + + """ + # Security: TODO: add a limit here + line = yield from read_line() + # Security: this guarantees header values are small (hard-coded = 4 KiB) + if len(line) > MAX_LINE: + raise SecurityError("line too long") + # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 + if not line.endswith(b"\r\n"): + raise EOFError("line without CRLF") + return line[:-2] diff --git a/tests/test_http11.py b/tests/test_http11.py new file mode 100644 index 000000000..bca874aee --- /dev/null +++ b/tests/test_http11.py @@ -0,0 +1,271 @@ +from websockets.datastructures import Headers +from websockets.exceptions import SecurityError +from websockets.http11 import * +from websockets.http11 import parse_headers +from websockets.streams import StreamReader + +from .utils import GeneratorTestCase + + +class RequestTests(GeneratorTestCase): + def setUp(self): + super().setUp() + self.reader = StreamReader() + + def parse(self): + return Request.parse(self.reader.read_line) + + def test_parse(self): + # Example from the protocol overview in RFC 6455 + self.reader.feed_data( + b"GET /chat HTTP/1.1\r\n" + b"Host: server.example.com\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + b"Origin: http://example.com\r\n" + b"Sec-WebSocket-Protocol: chat, superchat\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"\r\n" + ) + request = self.assertGeneratorReturns(self.parse()) + self.assertEqual(request.path, "/chat") + self.assertEqual(request.headers["Upgrade"], "websocket") + + def test_parse_empty(self): + self.reader.feed_eof() + with self.assertRaises(EOFError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), "connection closed while reading HTTP request line" + ) + + def test_parse_invalid_request_line(self): + self.reader.feed_data(b"GET /\r\n\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "invalid HTTP request line: GET /") + + def test_parse_unsupported_method(self): + self.reader.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "unsupported HTTP method: OPTIONS") + + def test_parse_unsupported_version(self): + self.reader.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "unsupported HTTP version: HTTP/1.0") + + def test_parse_invalid_header(self): + self.reader.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "invalid HTTP header line: Oops") + + def test_serialize(self): + # Example from the protocol overview in RFC 6455 + request = Request( + "/chat", + Headers( + [ + ("Host", "server.example.com"), + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="), + ("Origin", "http://example.com"), + ("Sec-WebSocket-Protocol", "chat, superchat"), + ("Sec-WebSocket-Version", "13"), + ] + ), + ) + self.assertEqual( + request.serialize(), + b"GET /chat HTTP/1.1\r\n" + b"Host: server.example.com\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + b"Origin: http://example.com\r\n" + b"Sec-WebSocket-Protocol: chat, superchat\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"\r\n", + ) + + +class ResponseTests(GeneratorTestCase): + def setUp(self): + super().setUp() + self.reader = StreamReader() + + def parse(self): + return Response.parse( + self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof + ) + + def test_parse(self): + # Example from the protocol overview in RFC 6455 + self.reader.feed_data( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + b"Sec-WebSocket-Protocol: chat\r\n" + b"\r\n" + ) + response = self.assertGeneratorReturns(self.parse()) + self.assertEqual(response.status_code, 101) + self.assertEqual(response.reason_phrase, "Switching Protocols") + self.assertEqual(response.headers["Upgrade"], "websocket") + self.assertIsNone(response.body) + + def test_parse_empty(self): + self.reader.feed_eof() + with self.assertRaises(EOFError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), "connection closed while reading HTTP status line" + ) + + def test_parse_invalid_status_line(self): + self.reader.feed_data(b"Hello!\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "invalid HTTP status line: Hello!") + + def test_parse_unsupported_version(self): + self.reader.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "unsupported HTTP version: HTTP/1.0") + + def test_parse_invalid_status(self): + self.reader.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "invalid HTTP status code: OMG") + + def test_parse_unsupported_status(self): + self.reader.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "unsupported HTTP status code: 007") + + def test_parse_invalid_reason(self): + self.reader.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "invalid HTTP reason phrase: \x7f") + + def test_parse_invalid_header(self): + self.reader.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "invalid HTTP header line: Oops") + + def test_parse_body_with_content_length(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello world!\n" + ) + response = self.assertGeneratorReturns(self.parse()) + self.assertEqual(response.body, b"Hello world!\n") + + def test_parse_body_without_content_length(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\nHello world!\n") + gen = self.parse() + self.assertGeneratorRunning(gen) + self.reader.feed_eof() + response = self.assertGeneratorReturns(gen) + self.assertEqual(response.body, b"Hello world!\n") + + def test_parse_body_with_transfer_encoding(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n") + with self.assertRaises(NotImplementedError) as raised: + next(self.parse()) + self.assertEqual(str(raised.exception), "transfer codings aren't supported") + + def test_parse_body_no_content(self): + self.reader.feed_data(b"HTTP/1.1 204 No Content\r\n\r\n") + response = self.assertGeneratorReturns(self.parse()) + self.assertIsNone(response.body) + + def test_parse_body_not_modified(self): + self.reader.feed_data(b"HTTP/1.1 304 Not Modified\r\n\r\n") + response = self.assertGeneratorReturns(self.parse()) + self.assertIsNone(response.body) + + def test_serialize(self): + # Example from the protocol overview in RFC 6455 + response = Response( + 101, + "Switching Protocols", + Headers( + [ + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="), + ("Sec-WebSocket-Protocol", "chat"), + ] + ), + ) + self.assertEqual( + response.serialize(), + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + b"Sec-WebSocket-Protocol: chat\r\n" + b"\r\n", + ) + + def test_serialize_with_body(self): + response = Response( + 200, + "OK", + Headers([("Content-Length", "13"), ("Content-Type", "text/plain")]), + b"Hello world!\n", + ) + self.assertEqual( + response.serialize(), + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 13\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"Hello world!\n", + ) + + +class HeadersTests(GeneratorTestCase): + def setUp(self): + super().setUp() + self.reader = StreamReader() + + def parse_headers(self): + return parse_headers(self.reader.read_line) + + def test_parse_invalid_name(self): + self.reader.feed_data(b"foo bar: baz qux\r\n\r\n") + with self.assertRaises(ValueError): + next(self.parse_headers()) + + def test_parse_invalid_value(self): + self.reader.feed_data(b"foo: \x00\x00\x0f\r\n\r\n") + with self.assertRaises(ValueError): + next(self.parse_headers()) + + def test_parse_too_long_value(self): + self.reader.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") + with self.assertRaises(SecurityError): + next(self.parse_headers()) + + def test_parse_too_long_line(self): + # Header line contains 5 + 4090 + 2 = 4097 bytes. + self.reader.feed_data(b"foo: " + b"a" * 4090 + b"\r\n\r\n") + with self.assertRaises(SecurityError): + next(self.parse_headers()) + + def test_parse_invalid_line_ending(self): + self.reader.feed_data(b"foo: bar\n\n") + with self.assertRaises(EOFError): + next(self.parse_headers()) From e4bc504a880110b0d3cd1dbc8e55b69a5f44ee7c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Oct 2019 20:00:10 +0200 Subject: [PATCH 0698/1539] Salvage accept() from the legacy handshake module. --- src/websockets/handshake.py | 3 --- src/websockets/handshake_legacy.py | 8 +------- src/websockets/utils.py | 18 +++++++++++++++++- tests/test_handshake_legacy.py | 10 ++-------- tests/test_utils.py | 16 +++++++++++++--- 5 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index f27bd1b84..3ff6c005d 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -6,9 +6,6 @@ __all__ = ["build_request", "check_request", "build_response", "check_response"] -GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - - # Backwards compatibility with previously documented public APIs diff --git a/src/websockets/handshake_legacy.py b/src/websockets/handshake_legacy.py index 9683e8556..1f6c58e1b 100644 --- a/src/websockets/handshake_legacy.py +++ b/src/websockets/handshake_legacy.py @@ -27,15 +27,14 @@ import base64 import binascii -import hashlib import random from typing import List from .datastructures import Headers, MultipleValuesError from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade -from .handshake import GUID from .headers import parse_connection, parse_upgrade from .typing import ConnectionOption, UpgradeProtocol +from .utils import accept_key as accept __all__ = ["build_request", "check_request", "build_response", "check_response"] @@ -180,8 +179,3 @@ def check_response(headers: Headers, key: str) -> None: if s_w_accept != accept(key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) - - -def accept(key: str) -> str: - sha1 = hashlib.sha1((key + GUID).encode()).digest() - return base64.b64encode(sha1).decode() diff --git a/src/websockets/utils.py b/src/websockets/utils.py index 40ac8559f..f9d0ca763 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -1,7 +1,23 @@ +import base64 +import hashlib import itertools -__all__ = ["apply_mask"] +__all__ = ["accept_key", "apply_mask"] + + +GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def accept_key(key: str) -> str: + """ + Compute the value of the Sec-WebSocket-Accept header. + + :param key: value of the Sec-WebSocket-Key header + + """ + sha1 = hashlib.sha1((key + GUID).encode()).digest() + return base64.b64encode(sha1).decode() def apply_mask(data: bytes, mask: bytes) -> bytes: diff --git a/tests/test_handshake_legacy.py b/tests/test_handshake_legacy.py index 361410d3f..c34b94e41 100644 --- a/tests/test_handshake_legacy.py +++ b/tests/test_handshake_legacy.py @@ -9,16 +9,10 @@ InvalidUpgrade, ) from websockets.handshake_legacy import * -from websockets.handshake_legacy import accept # private API +from websockets.utils import accept_key class HandshakeTests(unittest.TestCase): - def test_accept(self): - # Test vector from RFC 6455 - key = "dGhlIHNhbXBsZSBub25jZQ==" - acc = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" - self.assertEqual(accept(key), acc) - def test_round_trip(self): request_headers = Headers() request_key = build_request(request_headers) @@ -178,7 +172,7 @@ def test_response_invalid_accept(self): with self.assertInvalidResponseHeaders(InvalidHeaderValue) as headers: del headers["Sec-WebSocket-Accept"] other_key = "1Eq4UDEFQYg3YspNgqxv5g==" - headers["Sec-WebSocket-Accept"] = accept(other_key) + headers["Sec-WebSocket-Accept"] = accept_key(other_key) def test_response_missing_accept(self): with self.assertInvalidResponseHeaders(InvalidHeader) as headers: diff --git a/tests/test_utils.py b/tests/test_utils.py index e5570f098..7d5417d79 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,20 @@ import itertools import unittest -from websockets.utils import apply_mask as py_apply_mask +from websockets.utils import accept_key, apply_mask as py_apply_mask -class UtilsTests(unittest.TestCase): +# Test vector from RFC 6455 +KEY = "dGhlIHNhbXBsZSBub25jZQ==" +ACCEPT = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + + +class AcceptKeyTests(unittest.TestCase): + def test_accept_key(self): + self.assertEqual(accept_key(KEY), ACCEPT) + + +class ApplyMaskTests(unittest.TestCase): @staticmethod def apply_mask(*args, **kwargs): return py_apply_mask(*args, **kwargs) @@ -73,7 +83,7 @@ def test_apply_mask_check_mask_length(self): pass else: - class SpeedupsTests(UtilsTests): + class SpeedupsTests(ApplyMaskTests): @staticmethod def apply_mask(*args, **kwargs): return c_apply_mask(*args, **kwargs) From cf5af352200e6800f0152c8399af067a45053d76 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2020 21:32:21 +0100 Subject: [PATCH 0699/1539] Extract generate_key() from the legacy handshake module. --- src/websockets/handshake_legacy.py | 6 ++---- src/websockets/utils.py | 10 ++++++++++ tests/test_utils.py | 9 +++++++-- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/websockets/handshake_legacy.py b/src/websockets/handshake_legacy.py index 1f6c58e1b..7e6acc77d 100644 --- a/src/websockets/handshake_legacy.py +++ b/src/websockets/handshake_legacy.py @@ -27,14 +27,13 @@ import base64 import binascii -import random from typing import List from .datastructures import Headers, MultipleValuesError from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade from .headers import parse_connection, parse_upgrade from .typing import ConnectionOption, UpgradeProtocol -from .utils import accept_key as accept +from .utils import accept_key as accept, generate_key __all__ = ["build_request", "check_request", "build_response", "check_response"] @@ -50,8 +49,7 @@ def build_request(headers: Headers) -> str: :returns: ``key`` which must be passed to :func:`check_response` """ - raw_key = bytes(random.getrandbits(8) for _ in range(16)) - key = base64.b64encode(raw_key).decode() + key = generate_key() headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Key"] = key diff --git a/src/websockets/utils.py b/src/websockets/utils.py index f9d0ca763..a2fe8cc7f 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -1,6 +1,7 @@ import base64 import hashlib import itertools +import random __all__ = ["accept_key", "apply_mask"] @@ -9,6 +10,15 @@ GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +def generate_key() -> str: + """ + Generate a random key for the Sec-WebSocket-Key header. + + """ + key = bytes(random.getrandbits(8) for _ in range(16)) + return base64.b64encode(key).decode() + + def accept_key(key: str) -> str: """ Compute the value of the Sec-WebSocket-Accept header. diff --git a/tests/test_utils.py b/tests/test_utils.py index 7d5417d79..b490c2409 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,8 @@ +import base64 import itertools import unittest -from websockets.utils import accept_key, apply_mask as py_apply_mask +from websockets.utils import accept_key, apply_mask as py_apply_mask, generate_key # Test vector from RFC 6455 @@ -9,7 +10,11 @@ ACCEPT = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" -class AcceptKeyTests(unittest.TestCase): +class UtilsTests(unittest.TestCase): + def test_generate_key(self): + key = generate_key() + self.assertEqual(len(base64.b64decode(key.encode())), 16) + def test_accept_key(self): self.assertEqual(accept_key(KEY), ACCEPT) From 1af2296159b0e5165bbcf4b636ed7a06520928ab Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Jun 2020 10:27:30 +0200 Subject: [PATCH 0700/1539] Take advantage of the secrets module. Per RFC 6455, "the masking key MUST be derived from a strong source of entropy." There is no such requirement Sec-WebSocket-Key but it seems better anyway. --- src/websockets/frames.py | 12 ++++++------ src/websockets/utils.py | 4 ++-- tests/test_frames.py | 5 ++--- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 5ed8e483f..56dcf6171 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -4,7 +4,7 @@ """ import io -import random +import secrets import struct from typing import Callable, Generator, NamedTuple, Optional, Sequence, Tuple @@ -120,12 +120,12 @@ def parse( f"payload length exceeds size limit ({length} > {max_size} bytes)" ) if mask: - mask_bits = yield from read_exact(4) + mask_bytes = yield from read_exact(4) # Read the data. data = yield from read_exact(length) if mask: - data = apply_mask(data, mask_bits) + data = apply_mask(data, mask_bytes) frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) @@ -186,12 +186,12 @@ def serialize( output.write(struct.pack("!BBQ", head1, head2 | 127, length)) if mask: - mask_bits = struct.pack("!I", random.getrandbits(32)) - output.write(mask_bits) + mask_bytes = secrets.token_bytes(4) + output.write(mask_bytes) # Prepare the data. if mask: - data = apply_mask(self.data, mask_bits) + data = apply_mask(self.data, mask_bytes) else: data = self.data output.write(data) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index a2fe8cc7f..59210e438 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -1,7 +1,7 @@ import base64 import hashlib import itertools -import random +import secrets __all__ = ["accept_key", "apply_mask"] @@ -15,7 +15,7 @@ def generate_key() -> str: Generate a random key for the Sec-WebSocket-Key header. """ - key = bytes(random.getrandbits(8) for _ in range(16)) + key = secrets.token_bytes(16) return base64.b64encode(key).decode() diff --git a/tests/test_frames.py b/tests/test_frames.py index 39d4055a8..37a73b2df 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -1,5 +1,4 @@ import codecs -import struct import unittest import unittest.mock @@ -26,8 +25,8 @@ def round_trip(self, data, frame, mask=False, extensions=None): # Make masking deterministic by reusing the same "random" mask. # This has an effect only when mask is True. - randbits = struct.unpack("!I", data[2:6])[0] if mask else 0 - with unittest.mock.patch("random.getrandbits", return_value=randbits): + mask_bytes = data[2:6] if mask else b"" + with unittest.mock.patch("secrets.token_bytes", return_value=mask_bytes): serialized = parsed.serialize(mask=mask, extensions=extensions) self.assertEqual(serialized, data) From 6bce2489660daf09f1e6bdf121cabdea83128e4e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2020 09:40:33 +0100 Subject: [PATCH 0701/1539] Move asyncio client and server out of the way. --- src/websockets/asyncio_client.py | 588 ++++++++++ src/websockets/asyncio_server.py | 1004 ++++++++++++++++ src/websockets/auth.py | 2 +- src/websockets/client.py | 592 +--------- src/websockets/server.py | 1009 +---------------- ...erver.py => test_asyncio_client_server.py} | 28 +- tests/test_auth.py | 2 +- 7 files changed, 1623 insertions(+), 1602 deletions(-) create mode 100644 src/websockets/asyncio_client.py create mode 100644 src/websockets/asyncio_server.py rename tests/{test_client_server.py => test_asyncio_client_server.py} (98%) diff --git a/src/websockets/asyncio_client.py b/src/websockets/asyncio_client.py new file mode 100644 index 000000000..f95dae060 --- /dev/null +++ b/src/websockets/asyncio_client.py @@ -0,0 +1,588 @@ +""" +:mod:`websockets.client` defines the WebSocket client APIs. + +""" + +import asyncio +import collections.abc +import functools +import logging +import warnings +from types import TracebackType +from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast + +from .datastructures import Headers, HeadersLike +from .exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidStatusCode, + NegotiationError, + RedirectHandshake, + SecurityError, +) +from .extensions.base import ClientExtensionFactory, Extension +from .extensions.permessage_deflate import ClientPerMessageDeflateFactory +from .handshake_legacy import build_request, check_response +from .headers import ( + build_authorization_basic, + build_extension, + build_subprotocol, + parse_extension, + parse_subprotocol, +) +from .http import USER_AGENT +from .http_legacy import read_response +from .protocol import WebSocketCommonProtocol +from .typing import ExtensionHeader, Origin, Subprotocol +from .uri import WebSocketURI, parse_uri + + +__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] + +logger = logging.getLogger(__name__) + + +class WebSocketClientProtocol(WebSocketCommonProtocol): + """ + :class:`~asyncio.Protocol` subclass implementing a WebSocket client. + + This class inherits most of its methods from + :class:`~websockets.protocol.WebSocketCommonProtocol`. + + """ + + is_client = True + side = "client" + + def __init__( + self, + *, + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, + **kwargs: Any, + ) -> None: + self.origin = origin + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + super().__init__(**kwargs) + + def write_http_request(self, path: str, headers: Headers) -> None: + """ + Write request line and headers to the HTTP request. + + """ + self.path = path + self.request_headers = headers + + logger.debug("%s > GET %s HTTP/1.1", self.side, path) + logger.debug("%s > %r", self.side, headers) + + # Since the path and headers only contain ASCII characters, + # we can keep this simple. + request = f"GET {path} HTTP/1.1\r\n" + request += str(headers) + + self.transport.write(request.encode()) + + async def read_http_response(self) -> Tuple[int, Headers]: + """ + Read status line and headers from the HTTP response. + + If the response contains a body, it may be read from ``self.reader`` + after this coroutine returns. + + :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is + malformed or isn't an HTTP/1.1 GET response + + """ + try: + status_code, reason, headers = await read_response(self.reader) + except asyncio.CancelledError: # pragma: no cover + raise + except Exception as exc: + raise InvalidMessage("did not receive a valid HTTP response") from exc + + logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) + logger.debug("%s < %r", self.side, headers) + + self.response_headers = headers + + return status_code, self.response_headers + + @staticmethod + def process_extensions( + headers: Headers, + available_extensions: Optional[Sequence[ClientExtensionFactory]], + ) -> List[Extension]: + """ + Handle the Sec-WebSocket-Extensions HTTP response header. + + Check that each extension is supported, as well as its parameters. + + Return the list of accepted extensions. + + Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the + connection. + + :rfc:`6455` leaves the rules up to the specification of each + :extension. + + To provide this level of flexibility, for each extension accepted by + the server, we check for a match with each extension available in the + client configuration. If no match is found, an exception is raised. + + If several variants of the same extension are accepted by the server, + it may be configured several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + """ + accepted_extensions: List[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values: + + if available_extensions is None: + raise InvalidHandshake("no extensions supported") + + parsed_header_values: List[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, response_params in parsed_header_values: + + for extension_factory in available_extensions: + + # Skip non-matching extensions based on their name. + if extension_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + extension = extension_factory.process_response_params( + response_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the server sent. Fail the connection. + else: + raise NegotiationError( + f"Unsupported extension: " + f"name = {name}, params = {response_params}" + ) + + return accepted_extensions + + @staticmethod + def process_subprotocol( + headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] + ) -> Optional[Subprotocol]: + """ + Handle the Sec-WebSocket-Protocol HTTP response header. + + Check that it contains exactly one supported subprotocol. + + Return the selected subprotocol. + + """ + subprotocol: Optional[Subprotocol] = None + + header_values = headers.get_all("Sec-WebSocket-Protocol") + + if header_values: + + if available_subprotocols is None: + raise InvalidHandshake("no subprotocols supported") + + parsed_header_values: Sequence[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] + ) + + if len(parsed_header_values) > 1: + subprotocols = ", ".join(parsed_header_values) + raise InvalidHandshake(f"multiple subprotocols: {subprotocols}") + + subprotocol = parsed_header_values[0] + + if subprotocol not in available_subprotocols: + raise NegotiationError(f"unsupported subprotocol: {subprotocol}") + + return subprotocol + + async def handshake( + self, + wsuri: WebSocketURI, + origin: Optional[Origin] = None, + available_extensions: Optional[Sequence[ClientExtensionFactory]] = None, + available_subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, + ) -> None: + """ + Perform the client side of the opening handshake. + + :param origin: sets the Origin HTTP header + :param available_extensions: list of supported extensions in the order + in which they should be used + :param available_subprotocols: list of supported subprotocols in order + of decreasing preference + :param extra_headers: sets additional HTTP request headers; it must be + a :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, or an iterable of ``(name, + value)`` pairs + :raises ~websockets.exceptions.InvalidHandshake: if the handshake + fails + + """ + request_headers = Headers() + + if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover + request_headers["Host"] = wsuri.host + else: + request_headers["Host"] = f"{wsuri.host}:{wsuri.port}" + + if wsuri.user_info: + request_headers["Authorization"] = build_authorization_basic( + *wsuri.user_info + ) + + if origin is not None: + request_headers["Origin"] = origin + + key = build_request(request_headers) + + if available_extensions is not None: + extensions_header = build_extension( + [ + (extension_factory.name, extension_factory.get_request_params()) + for extension_factory in available_extensions + ] + ) + request_headers["Sec-WebSocket-Extensions"] = extensions_header + + if available_subprotocols is not None: + protocol_header = build_subprotocol(available_subprotocols) + request_headers["Sec-WebSocket-Protocol"] = protocol_header + + if extra_headers is not None: + if isinstance(extra_headers, Headers): + extra_headers = extra_headers.raw_items() + elif isinstance(extra_headers, collections.abc.Mapping): + extra_headers = extra_headers.items() + for name, value in extra_headers: + request_headers[name] = value + + request_headers.setdefault("User-Agent", USER_AGENT) + + self.write_http_request(wsuri.resource_name, request_headers) + + status_code, response_headers = await self.read_http_response() + if status_code in (301, 302, 303, 307, 308): + if "Location" not in response_headers: + raise InvalidHeader("Location") + raise RedirectHandshake(response_headers["Location"]) + elif status_code != 101: + raise InvalidStatusCode(status_code) + + check_response(response_headers, key) + + self.extensions = self.process_extensions( + response_headers, available_extensions + ) + + self.subprotocol = self.process_subprotocol( + response_headers, available_subprotocols + ) + + self.connection_open() + + +class Connect: + """ + Connect to the WebSocket server at the given ``uri``. + + Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which + can then be used to send and receive messages. + + :func:`connect` can also be used as a asynchronous context manager. In + that case, the connection is closed when exiting the context. + + :func:`connect` is a wrapper around the event loop's + :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments + are passed to :meth:`~asyncio.loop.create_connection`. + + For example, you can set the ``ssl`` keyword argument to a + :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to + a ``wss://`` URI, if this argument isn't provided explicitly, + :func:`ssl.create_default_context` is called to create a context. + + You can connect to a different host and port from those found in ``uri`` + by setting ``host`` and ``port`` keyword arguments. This only changes the + destination of the TCP connection. The host name from ``uri`` is still + used in the TLS handshake for secure connections and in the ``Host`` HTTP + header. + + The ``create_protocol`` parameter allows customizing the + :class:`~asyncio.Protocol` that manages the connection. It should be a + callable or class accepting the same arguments as + :class:`WebSocketClientProtocol` and returning an instance of + :class:`WebSocketClientProtocol` or a subclass. It defaults to + :class:`WebSocketClientProtocol`. + + The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is + described in :class:`~websockets.protocol.WebSocketCommonProtocol`. + + :func:`connect` also accepts the following optional arguments: + + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression + * ``origin`` sets the Origin HTTP header + * ``extensions`` is a list of supported extensions in order of + decreasing preference + * ``subprotocols`` is a list of supported subprotocols in order of + decreasing preference + * ``extra_headers`` sets additional HTTP request headers; it can be a + :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` + pairs + + :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid + :raises ~websockets.handshake.InvalidHandshake: if the opening handshake + fails + + """ + + MAX_REDIRECTS_ALLOWED = 10 + + def __init__( + self, + uri: str, + *, + path: Optional[str] = None, + create_protocol: Optional[Type[WebSocketClientProtocol]] = None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, + loop: Optional[asyncio.AbstractEventLoop] = None, + legacy_recv: bool = False, + klass: Optional[Type[WebSocketClientProtocol]] = None, + timeout: Optional[float] = None, + compression: Optional[str] = "deflate", + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, + **kwargs: Any, + ) -> None: + # Backwards compatibility: close_timeout used to be called timeout. + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + # Backwards compatibility: create_protocol used to be called klass. + if klass is None: + klass = WebSocketClientProtocol + else: + warnings.warn("rename klass to create_protocol", DeprecationWarning) + # If both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + if loop is None: + loop = asyncio.get_event_loop() + + wsuri = parse_uri(uri) + if wsuri.secure: + kwargs.setdefault("ssl", True) + elif kwargs.get("ssl") is not None: + raise ValueError( + "connect() received a ssl argument for a ws:// URI, " + "use a wss:// URI to enable TLS" + ) + + if compression == "deflate": + if extensions is None: + extensions = [] + if not any( + extension_factory.name == ClientPerMessageDeflateFactory.name + for extension_factory in extensions + ): + extensions = list(extensions) + [ + ClientPerMessageDeflateFactory(client_max_window_bits=True) + ] + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + factory = functools.partial( + create_protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + loop=loop, + host=wsuri.host, + port=wsuri.port, + secure=wsuri.secure, + legacy_recv=legacy_recv, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, + ) + + if path is None: + host: Optional[str] + port: Optional[int] + if kwargs.get("sock") is None: + host, port = wsuri.host, wsuri.port + else: + # If sock is given, host and port shouldn't be specified. + host, port = None, None + # If host and port are given, override values from the URI. + host = kwargs.pop("host", host) + port = kwargs.pop("port", port) + create_connection = functools.partial( + loop.create_connection, factory, host, port, **kwargs + ) + else: + create_connection = functools.partial( + loop.create_unix_connection, factory, path, **kwargs + ) + + # This is a coroutine function. + self._create_connection = create_connection + self._wsuri = wsuri + + def handle_redirect(self, uri: str) -> None: + # Update the state of this instance to connect to a new URI. + old_wsuri = self._wsuri + new_wsuri = parse_uri(uri) + + # Forbid TLS downgrade. + if old_wsuri.secure and not new_wsuri.secure: + raise SecurityError("redirect from WSS to WS") + + same_origin = ( + old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port + ) + + # Rewrite the host and port arguments for cross-origin redirects. + # This preserves connection overrides with the host and port + # arguments if the redirect points to the same host and port. + if not same_origin: + # Replace the host and port argument passed to the protocol factory. + factory = self._create_connection.args[0] + factory = functools.partial( + factory.func, + *factory.args, + **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), + ) + # Replace the host and port argument passed to create_connection. + self._create_connection = functools.partial( + self._create_connection.func, + *(factory, new_wsuri.host, new_wsuri.port), + **self._create_connection.keywords, + ) + + # Set the new WebSocket URI. This suffices for same-origin redirects. + self._wsuri = new_wsuri + + # async with connect(...) + + async def __aenter__(self) -> WebSocketClientProtocol: + return await self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + await self.ws_client.close() + + # await connect(...) + + def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketClientProtocol: + for redirects in range(self.MAX_REDIRECTS_ALLOWED): + transport, protocol = await self._create_connection() + # https://github.com/python/typeshed/pull/2756 + transport = cast(asyncio.Transport, transport) + protocol = cast(WebSocketClientProtocol, protocol) + + try: + try: + await protocol.handshake( + self._wsuri, + origin=protocol.origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + except Exception: + protocol.fail_connection() + await protocol.wait_closed() + raise + else: + self.ws_client = protocol + return protocol + except RedirectHandshake as exc: + self.handle_redirect(exc.uri) + else: + raise SecurityError("too many redirects") + + # yield from connect(...) + + __iter__ = __await__ + + +connect = Connect + + +def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect: + """ + Similar to :func:`connect`, but for connecting to a Unix socket. + + This function calls the event loop's + :meth:`~asyncio.loop.create_unix_connection` method. + + It is only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + :param path: file system path to the Unix socket + :param uri: WebSocket URI + + """ + return connect(uri=uri, path=path, **kwargs) diff --git a/src/websockets/asyncio_server.py b/src/websockets/asyncio_server.py new file mode 100644 index 000000000..1eeddf0eb --- /dev/null +++ b/src/websockets/asyncio_server.py @@ -0,0 +1,1004 @@ +""" +:mod:`websockets.server` defines the WebSocket server APIs. + +""" + +import asyncio +import collections.abc +import email.utils +import functools +import http +import logging +import socket +import sys +import warnings +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Generator, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, + cast, +) + +from .datastructures import Headers, HeadersLike, MultipleValuesError +from .exceptions import ( + AbortHandshake, + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, +) +from .extensions.base import Extension, ServerExtensionFactory +from .extensions.permessage_deflate import ServerPerMessageDeflateFactory +from .handshake_legacy import build_response, check_request +from .headers import build_extension, parse_extension, parse_subprotocol +from .http import USER_AGENT +from .http_legacy import read_request +from .protocol import WebSocketCommonProtocol +from .typing import ExtensionHeader, Origin, Subprotocol + + +__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] + +logger = logging.getLogger(__name__) + + +HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] + +HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes] + + +class WebSocketServerProtocol(WebSocketCommonProtocol): + """ + :class:`~asyncio.Protocol` subclass implementing a WebSocket server. + + This class inherits most of its methods from + :class:`~websockets.protocol.WebSocketCommonProtocol`. + + For the sake of simplicity, it doesn't rely on a full HTTP implementation. + Its support for HTTP responses is very limited. + + """ + + is_client = False + side = "server" + + def __init__( + self, + ws_handler: Callable[["WebSocketServerProtocol", str], Awaitable[Any]], + ws_server: "WebSocketServer", + *, + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + process_request: Optional[ + Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] + ] = None, + select_subprotocol: Optional[ + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] + ] = None, + **kwargs: Any, + ) -> None: + # For backwards compatibility with 6.0 or earlier. + if origins is not None and "" in origins: + warnings.warn("use None instead of '' in origins", DeprecationWarning) + origins = [None if origin == "" else origin for origin in origins] + self.ws_handler = ws_handler + self.ws_server = ws_server + self.origins = origins + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + self._process_request = process_request + self._select_subprotocol = select_subprotocol + super().__init__(**kwargs) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + Register connection and initialize a task to handle it. + + """ + super().connection_made(transport) + # Register the connection with the server before creating the handler + # task. Registering at the beginning of the handler coroutine would + # create a race condition between the creation of the task, which + # schedules its execution, and the moment the handler starts running. + self.ws_server.register(self) + self.handler_task = self.loop.create_task(self.handler()) + + async def handler(self) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller able to handle exceptions, it + attemps to log relevant ones and guarantees that the TCP connection is + closed before exiting. + + """ + try: + + try: + path = await self.handshake( + origins=self.origins, + available_extensions=self.available_extensions, + available_subprotocols=self.available_subprotocols, + extra_headers=self.extra_headers, + ) + except asyncio.CancelledError: # pragma: no cover + raise + except ConnectionError: + logger.debug("Connection error in opening handshake", exc_info=True) + raise + except Exception as exc: + if isinstance(exc, AbortHandshake): + status, headers, body = exc.status, exc.headers, exc.body + elif isinstance(exc, InvalidOrigin): + logger.debug("Invalid origin", exc_info=True) + status, headers, body = ( + http.HTTPStatus.FORBIDDEN, + Headers(), + f"Failed to open a WebSocket connection: {exc}.\n".encode(), + ) + elif isinstance(exc, InvalidUpgrade): + logger.debug("Invalid upgrade", exc_info=True) + status, headers, body = ( + http.HTTPStatus.UPGRADE_REQUIRED, + Headers([("Upgrade", "websocket")]), + ( + f"Failed to open a WebSocket connection: {exc}.\n" + f"\n" + f"You cannot access a WebSocket server directly " + f"with a browser. You need a WebSocket client.\n" + ).encode(), + ) + elif isinstance(exc, InvalidHandshake): + logger.debug("Invalid handshake", exc_info=True) + status, headers, body = ( + http.HTTPStatus.BAD_REQUEST, + Headers(), + f"Failed to open a WebSocket connection: {exc}.\n".encode(), + ) + else: + logger.warning("Error in opening handshake", exc_info=True) + status, headers, body = ( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + Headers(), + ( + b"Failed to open a WebSocket connection.\n" + b"See server log for more information.\n" + ), + ) + + headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + headers.setdefault("Server", USER_AGENT) + headers.setdefault("Content-Length", str(len(body))) + headers.setdefault("Content-Type", "text/plain") + headers.setdefault("Connection", "close") + + self.write_http_response(status, headers, body) + self.fail_connection() + await self.wait_closed() + return + + try: + await self.ws_handler(self, path) + except Exception: + logger.error("Error in connection handler", exc_info=True) + if not self.closed: + self.fail_connection(1011) + raise + + try: + await self.close() + except ConnectionError: + logger.debug("Connection error in closing handshake", exc_info=True) + raise + except Exception: + logger.warning("Error in closing handshake", exc_info=True) + raise + + except Exception: + # Last-ditch attempt to avoid leaking connections on errors. + try: + self.transport.close() + except Exception: # pragma: no cover + pass + + finally: + # Unregister the connection with the server when the handler task + # terminates. Registration is tied to the lifecycle of the handler + # task because the server waits for tasks attached to registered + # connections before terminating. + self.ws_server.unregister(self) + + async def read_http_request(self) -> Tuple[str, Headers]: + """ + Read request line and headers from the HTTP request. + + If the request contains a body, it may be read from ``self.reader`` + after this coroutine returns. + + :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is + malformed or isn't an HTTP/1.1 GET request + + """ + try: + path, headers = await read_request(self.reader) + except asyncio.CancelledError: # pragma: no cover + raise + except Exception as exc: + raise InvalidMessage("did not receive a valid HTTP request") from exc + + logger.debug("%s < GET %s HTTP/1.1", self.side, path) + logger.debug("%s < %r", self.side, headers) + + self.path = path + self.request_headers = headers + + return path, headers + + def write_http_response( + self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None + ) -> None: + """ + Write status line and headers to the HTTP response. + + This coroutine is also able to write a response body. + + """ + self.response_headers = headers + + logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) + logger.debug("%s > %r", self.side, headers) + + # Since the status line and headers only contain ASCII characters, + # we can keep this simple. + response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" + response += str(headers) + + self.transport.write(response.encode()) + + if body is not None: + logger.debug("%s > body (%d bytes)", self.side, len(body)) + self.transport.write(body) + + async def process_request( + self, path: str, request_headers: Headers + ) -> Optional[HTTPResponse]: + """ + Intercept the HTTP request and return an HTTP response if appropriate. + + If ``process_request`` returns ``None``, the WebSocket handshake + continues. If it returns 3-uple containing a status code, response + headers and a response body, that HTTP response is sent and the + connection is closed. In that case: + + * The HTTP status must be a :class:`~http.HTTPStatus`. + * HTTP headers must be a :class:`~websockets.http.Headers` instance, a + :class:`~collections.abc.Mapping`, or an iterable of ``(name, + value)`` pairs. + * The HTTP response body must be :class:`bytes`. It may be empty. + + This coroutine may be overridden in a :class:`WebSocketServerProtocol` + subclass, for example: + + * to return a HTTP 200 OK response on a given path; then a load + balancer can use this path for a health check; + * to authenticate the request and return a HTTP 401 Unauthorized or a + HTTP 403 Forbidden when authentication fails. + + Instead of subclassing, it is possible to override this method by + passing a ``process_request`` argument to the :func:`serve` function + or the :class:`WebSocketServerProtocol` constructor. This is + equivalent, except ``process_request`` won't have access to the + protocol instance, so it can't store information for later use. + + ``process_request`` is expected to complete quickly. If it may run for + a long time, then it should await :meth:`wait_closed` and exit if + :meth:`wait_closed` completes, or else it could prevent the server + from shutting down. + + :param path: request path, including optional query string + :param request_headers: request headers + + """ + if self._process_request is not None: + response = self._process_request(path, request_headers) + if isinstance(response, Awaitable): + return await response + else: + # For backwards compatibility with 7.0. + warnings.warn( + "declare process_request as a coroutine", DeprecationWarning + ) + return response # type: ignore + return None + + @staticmethod + def process_origin( + headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None + ) -> Optional[Origin]: + """ + Handle the Origin HTTP request header. + + :param headers: request headers + :param origins: optional list of acceptable origins + :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't + acceptable + + """ + # "The user agent MUST NOT include more than one Origin header field" + # per https://tools.ietf.org/html/rfc6454#section-7.3. + try: + origin = cast(Origin, headers.get("Origin")) + except MultipleValuesError as exc: + raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origins is not None: + if origin not in origins: + raise InvalidOrigin(origin) + return origin + + @staticmethod + def process_extensions( + headers: Headers, + available_extensions: Optional[Sequence[ServerExtensionFactory]], + ) -> Tuple[Optional[str], List[Extension]]: + """ + Handle the Sec-WebSocket-Extensions HTTP request header. + + Accept or reject each extension proposed in the client request. + Negotiate parameters for accepted extensions. + + Return the Sec-WebSocket-Extensions HTTP response header and the list + of accepted extensions. + + :rfc:`6455` leaves the rules up to the specification of each + :extension. + + To provide this level of flexibility, for each extension proposed by + the client, we check for a match with each extension available in the + server configuration. If no match is found, the extension is ignored. + + If several variants of the same extension are proposed by the client, + it may be accepted several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + This process doesn't allow the server to reorder extensions. It can + only select a subset of the extensions proposed by the client. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + :param headers: request headers + :param extensions: optional list of supported extensions + :raises ~websockets.exceptions.InvalidHandshake: to abort the + handshake with an HTTP 400 error code + + """ + response_header_value: Optional[str] = None + + extension_headers: List[ExtensionHeader] = [] + accepted_extensions: List[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values and available_extensions: + + parsed_header_values: List[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, request_params in parsed_header_values: + + for ext_factory in available_extensions: + + # Skip non-matching extensions based on their name. + if ext_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + response_params, extension = ext_factory.process_request_params( + request_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + extension_headers.append((name, response_params)) + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the client sent. The extension is declined. + + # Serialize extension header. + if extension_headers: + response_header_value = build_extension(extension_headers) + + return response_header_value, accepted_extensions + + # Not @staticmethod because it calls self.select_subprotocol() + def process_subprotocol( + self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] + ) -> Optional[Subprotocol]: + """ + Handle the Sec-WebSocket-Protocol HTTP request header. + + Return Sec-WebSocket-Protocol HTTP response header, which is the same + as the selected subprotocol. + + :param headers: request headers + :param available_subprotocols: optional list of supported subprotocols + :raises ~websockets.exceptions.InvalidHandshake: to abort the + handshake with an HTTP 400 error code + + """ + subprotocol: Optional[Subprotocol] = None + + header_values = headers.get_all("Sec-WebSocket-Protocol") + + if header_values and available_subprotocols: + + parsed_header_values: List[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] + ) + + subprotocol = self.select_subprotocol( + parsed_header_values, available_subprotocols + ) + + return subprotocol + + def select_subprotocol( + self, + client_subprotocols: Sequence[Subprotocol], + server_subprotocols: Sequence[Subprotocol], + ) -> Optional[Subprotocol]: + """ + Pick a subprotocol among those offered by the client. + + If several subprotocols are supported by the client and the server, + the default implementation selects the preferred subprotocols by + giving equal value to the priorities of the client and the server. + + If no subprotocol is supported by the client and the server, it + proceeds without a subprotocol. + + This is unlikely to be the most useful implementation in practice, as + many servers providing a subprotocol will require that the client uses + that subprotocol. Such rules can be implemented in a subclass. + + Instead of subclassing, it is possible to override this method by + passing a ``select_subprotocol`` argument to the :func:`serve` + function or the :class:`WebSocketServerProtocol` constructor + + :param client_subprotocols: list of subprotocols offered by the client + :param server_subprotocols: list of subprotocols available on the server + + """ + if self._select_subprotocol is not None: + return self._select_subprotocol(client_subprotocols, server_subprotocols) + + subprotocols = set(client_subprotocols) & set(server_subprotocols) + if not subprotocols: + return None + priority = lambda p: ( + client_subprotocols.index(p) + server_subprotocols.index(p) + ) + return sorted(subprotocols, key=priority)[0] + + async def handshake( + self, + origins: Optional[Sequence[Optional[Origin]]] = None, + available_extensions: Optional[Sequence[ServerExtensionFactory]] = None, + available_subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + ) -> str: + """ + Perform the server side of the opening handshake. + + Return the path of the URI of the request. + + :param origins: list of acceptable values of the Origin HTTP header; + include ``None`` if the lack of an origin is acceptable + :param available_extensions: list of supported extensions in the order + in which they should be used + :param available_subprotocols: list of supported subprotocols in order + of decreasing preference + :param extra_headers: sets additional HTTP response headers when the + handshake succeeds; it can be a :class:`~websockets.http.Headers` + instance, a :class:`~collections.abc.Mapping`, an iterable of + ``(name, value)`` pairs, or a callable taking the request path and + headers in arguments and returning one of the above. + :raises ~websockets.exceptions.InvalidHandshake: if the handshake + fails + + """ + path, request_headers = await self.read_http_request() + + # Hook for customizing request handling, for example checking + # authentication or treating some paths as plain HTTP endpoints. + early_response_awaitable = self.process_request(path, request_headers) + if isinstance(early_response_awaitable, Awaitable): + early_response = await early_response_awaitable + else: + # For backwards compatibility with 7.0. + warnings.warn("declare process_request as a coroutine", DeprecationWarning) + early_response = early_response_awaitable # type: ignore + + # Change the response to a 503 error if the server is shutting down. + if not self.ws_server.is_serving(): + early_response = ( + http.HTTPStatus.SERVICE_UNAVAILABLE, + [], + b"Server is shutting down.\n", + ) + + if early_response is not None: + raise AbortHandshake(*early_response) + + key = check_request(request_headers) + + self.origin = self.process_origin(request_headers, origins) + + extensions_header, self.extensions = self.process_extensions( + request_headers, available_extensions + ) + + protocol_header = self.subprotocol = self.process_subprotocol( + request_headers, available_subprotocols + ) + + response_headers = Headers() + + build_response(response_headers, key) + + if extensions_header is not None: + response_headers["Sec-WebSocket-Extensions"] = extensions_header + + if protocol_header is not None: + response_headers["Sec-WebSocket-Protocol"] = protocol_header + + if callable(extra_headers): + extra_headers = extra_headers(path, self.request_headers) + if extra_headers is not None: + if isinstance(extra_headers, Headers): + extra_headers = extra_headers.raw_items() + elif isinstance(extra_headers, collections.abc.Mapping): + extra_headers = extra_headers.items() + for name, value in extra_headers: + response_headers[name] = value + + response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + response_headers.setdefault("Server", USER_AGENT) + + self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) + + self.connection_open() + + return path + + +class WebSocketServer: + """ + WebSocket server returned by :func:`~websockets.server.serve`. + + This class provides the same interface as + :class:`~asyncio.AbstractServer`, namely the + :meth:`~asyncio.AbstractServer.close` and + :meth:`~asyncio.AbstractServer.wait_closed` methods. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Instances of this class store a reference to the :class:`~asyncio.Server` + object returned by :meth:`~asyncio.loop.create_server` rather than inherit + from :class:`~asyncio.Server` in part because + :meth:`~asyncio.loop.create_server` doesn't support passing a custom + :class:`~asyncio.Server` class. + + """ + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + # Store a reference to loop to avoid relying on self.server._loop. + self.loop = loop + + # Keep track of active connections. + self.websockets: Set[WebSocketServerProtocol] = set() + + # Task responsible for closing the server and terminating connections. + self.close_task: Optional[asyncio.Task[None]] = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] = loop.create_future() + + def wrap(self, server: asyncio.AbstractServer) -> None: + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + + """ + self.server = server + + def register(self, protocol: WebSocketServerProtocol) -> None: + """ + Register a connection with this server. + + """ + self.websockets.add(protocol) + + def unregister(self, protocol: WebSocketServerProtocol) -> None: + """ + Unregister a connection with this server. + + """ + self.websockets.remove(protocol) + + def is_serving(self) -> bool: + """ + Tell whether the server is accepting new connections or shutting down. + + """ + try: + # Python ≥ 3.7 + return self.server.is_serving() + except AttributeError: # pragma: no cover + # Python < 3.7 + return self.server.sockets is not None + + def close(self) -> None: + """ + Close the server. + + This method: + + * closes the underlying :class:`~asyncio.Server`; + * rejects new WebSocket connections with an HTTP 503 (service + unavailable) error; this happens when the server accepted the TCP + connection but didn't complete the WebSocket opening handshake prior + to closing; + * closes open WebSocket connections with close code 1001 (going away). + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.loop.create_task(self._close()) + + async def _close(self) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. + + """ + # Stop accepting new connections. + self.server.close() + + # Wait until self.server.close() completes. + await self.server.wait_closed() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://bugs.python.org/issue34852 for details. + await asyncio.sleep( + 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) + + # Close OPEN connections with status code 1001. Since the server was + # closed, handshake() closes OPENING conections with a HTTP 503 error. + # Wait until all connections are closed. + + # asyncio.wait doesn't accept an empty first argument + if self.websockets: + await asyncio.wait( + [ + asyncio.ensure_future(websocket.close(1001)) + for websocket in self.websockets + ], + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + + # Wait until all connection handlers are complete. + + # asyncio.wait doesn't accept an empty first argument. + if self.websockets: + await asyncio.wait( + [websocket.handler_task for websocket in self.websockets], + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + """ + await asyncio.shield(self.closed_waiter) + + @property + def sockets(self) -> Optional[List[socket.socket]]: + """ + List of :class:`~socket.socket` objects the server is listening to. + + ``None`` if the server is closed. + + """ + return self.server.sockets + + +class Serve: + """ + + Create, start, and return a WebSocket server on ``host`` and ``port``. + + Whenever a client connects, the server accepts the connection, creates a + :class:`WebSocketServerProtocol`, performs the opening handshake, and + delegates to the connection handler defined by ``ws_handler``. Once the + handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + Awaiting :func:`serve` yields a :class:`WebSocketServer`. This instance + provides :meth:`~websockets.server.WebSocketServer.close` and + :meth:`~websockets.server.WebSocketServer.wait_closed` methods for + terminating the server and cleaning up its resources. + + When a server is closed with :meth:`~WebSocketServer.close`, it closes all + connections with close code 1001 (going away). Connections handlers, which + are running the ``ws_handler`` coroutine, will receive a + :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their + current or next interaction with the WebSocket connection. + + :func:`serve` can also be used as an asynchronous context manager. In + this case, the server is shut down when exiting the context. + + :func:`serve` is a wrapper around the event loop's + :meth:`~asyncio.loop.create_server` method. It creates and starts a + :class:`~asyncio.Server` with :meth:`~asyncio.loop.create_server`. Then it + wraps the :class:`~asyncio.Server` in a :class:`WebSocketServer` and + returns the :class:`WebSocketServer`. + + The ``ws_handler`` argument is the WebSocket handler. It must be a + coroutine accepting two arguments: a :class:`WebSocketServerProtocol` and + the request URI. + + The ``host`` and ``port`` arguments, as well as unrecognized keyword + arguments, are passed along to :meth:`~asyncio.loop.create_server`. + + For example, you can set the ``ssl`` keyword argument to a + :class:`~ssl.SSLContext` to enable TLS. + + The ``create_protocol`` parameter allows customizing the + :class:`~asyncio.Protocol` that manages the connection. It should be a + callable or class accepting the same arguments as + :class:`WebSocketServerProtocol` and returning an instance of + :class:`WebSocketServerProtocol` or a subclass. It defaults to + :class:`WebSocketServerProtocol`. + + The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is + described in :class:`~websockets.protocol.WebSocketCommonProtocol`. + + :func:`serve` also accepts the following optional arguments: + + * ``compression`` is a shortcut to configure compression extensions; + by default it enables the "permessage-deflate" extension; set it to + ``None`` to disable compression + * ``origins`` defines acceptable Origin HTTP headers; include ``None`` if + the lack of an origin is acceptable + * ``extensions`` is a list of supported extensions in order of + decreasing preference + * ``subprotocols`` is a list of supported subprotocols in order of + decreasing preference + * ``extra_headers`` sets additional HTTP response headers when the + handshake succeeds; it can be a :class:`~websockets.http.Headers` + instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, + value)`` pairs, or a callable taking the request path and headers in + arguments and returning one of the above + * ``process_request`` allows intercepting the HTTP request; it must be a + coroutine taking the request path and headers in argument; see + :meth:`~WebSocketServerProtocol.process_request` for details + * ``select_subprotocol`` allows customizing the logic for selecting a + subprotocol; it must be a callable taking the subprotocols offered by + the client and available on the server in argument; see + :meth:`~WebSocketServerProtocol.select_subprotocol` for details + + Since there's no useful way to propagate exceptions triggered in handlers, + they're sent to the ``'websockets.asyncio_server'`` logger instead. + Debugging is much easier if you configure logging to print them:: + + import logging + logger = logging.getLogger("websockets.asyncio_server") + logger.setLevel(logging.ERROR) + logger.addHandler(logging.StreamHandler()) + + """ + + def __init__( + self, + ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + host: Optional[Union[str, Sequence[str]]] = None, + port: Optional[int] = None, + *, + path: Optional[str] = None, + create_protocol: Optional[Type[WebSocketServerProtocol]] = None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, + loop: Optional[asyncio.AbstractEventLoop] = None, + legacy_recv: bool = False, + klass: Optional[Type[WebSocketServerProtocol]] = None, + timeout: Optional[float] = None, + compression: Optional[str] = "deflate", + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + process_request: Optional[ + Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] + ] = None, + select_subprotocol: Optional[ + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] + ] = None, + **kwargs: Any, + ) -> None: + # Backwards compatibility: close_timeout used to be called timeout. + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + # Backwards compatibility: create_protocol used to be called klass. + if klass is None: + klass = WebSocketServerProtocol + else: + warnings.warn("rename klass to create_protocol", DeprecationWarning) + # If both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + if loop is None: + loop = asyncio.get_event_loop() + + ws_server = WebSocketServer(loop) + + secure = kwargs.get("ssl") is not None + + if compression == "deflate": + if extensions is None: + extensions = [] + if not any( + ext_factory.name == ServerPerMessageDeflateFactory.name + for ext_factory in extensions + ): + extensions = list(extensions) + [ServerPerMessageDeflateFactory()] + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + factory = functools.partial( + create_protocol, + ws_handler, + ws_server, + host=host, + port=port, + secure=secure, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + loop=loop, + legacy_recv=legacy_recv, + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, + process_request=process_request, + select_subprotocol=select_subprotocol, + ) + + if path is None: + create_server = functools.partial( + loop.create_server, factory, host, port, **kwargs + ) + else: + # unix_serve(path) must not specify host and port parameters. + assert host is None and port is None + create_server = functools.partial( + loop.create_unix_server, factory, path, **kwargs + ) + + # This is a coroutine function. + self._create_server = create_server + self.ws_server = ws_server + + # async with serve(...) + + async def __aenter__(self) -> WebSocketServer: + return await self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.ws_server.close() + await self.ws_server.wait_closed() + + # await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketServer: + server = await self._create_server() + self.ws_server.wrap(server) + return self.ws_server + + # yield from serve(...) + + __iter__ = __await__ + + +serve = Serve + + +def unix_serve( + ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + path: str, + **kwargs: Any, +) -> Serve: + """ + Similar to :func:`serve`, but for listening on Unix sockets. + + This function calls the event loop's + :meth:`~asyncio.loop.create_unix_server` method. + + It is only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + :param path: file system path to the Unix socket + + """ + return serve(ws_handler, path=path, **kwargs) diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 8198cd9d0..03e8536c5 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -9,10 +9,10 @@ import http from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union +from .asyncio_server import HTTPResponse, WebSocketServerProtocol from .datastructures import Headers from .exceptions import InvalidHeader from .headers import build_www_authenticate_basic, parse_authorization_basic -from .server import HTTPResponse, WebSocketServerProtocol __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] diff --git a/src/websockets/client.py b/src/websockets/client.py index f95dae060..c7d153f13 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,588 +1,8 @@ -""" -:mod:`websockets.client` defines the WebSocket client APIs. +from .asyncio_client import WebSocketClientProtocol, connect, unix_connect -""" -import asyncio -import collections.abc -import functools -import logging -import warnings -from types import TracebackType -from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast - -from .datastructures import Headers, HeadersLike -from .exceptions import ( - InvalidHandshake, - InvalidHeader, - InvalidMessage, - InvalidStatusCode, - NegotiationError, - RedirectHandshake, - SecurityError, -) -from .extensions.base import ClientExtensionFactory, Extension -from .extensions.permessage_deflate import ClientPerMessageDeflateFactory -from .handshake_legacy import build_request, check_response -from .headers import ( - build_authorization_basic, - build_extension, - build_subprotocol, - parse_extension, - parse_subprotocol, -) -from .http import USER_AGENT -from .http_legacy import read_response -from .protocol import WebSocketCommonProtocol -from .typing import ExtensionHeader, Origin, Subprotocol -from .uri import WebSocketURI, parse_uri - - -__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] - -logger = logging.getLogger(__name__) - - -class WebSocketClientProtocol(WebSocketCommonProtocol): - """ - :class:`~asyncio.Protocol` subclass implementing a WebSocket client. - - This class inherits most of its methods from - :class:`~websockets.protocol.WebSocketCommonProtocol`. - - """ - - is_client = True - side = "client" - - def __init__( - self, - *, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - **kwargs: Any, - ) -> None: - self.origin = origin - self.available_extensions = extensions - self.available_subprotocols = subprotocols - self.extra_headers = extra_headers - super().__init__(**kwargs) - - def write_http_request(self, path: str, headers: Headers) -> None: - """ - Write request line and headers to the HTTP request. - - """ - self.path = path - self.request_headers = headers - - logger.debug("%s > GET %s HTTP/1.1", self.side, path) - logger.debug("%s > %r", self.side, headers) - - # Since the path and headers only contain ASCII characters, - # we can keep this simple. - request = f"GET {path} HTTP/1.1\r\n" - request += str(headers) - - self.transport.write(request.encode()) - - async def read_http_response(self) -> Tuple[int, Headers]: - """ - Read status line and headers from the HTTP response. - - If the response contains a body, it may be read from ``self.reader`` - after this coroutine returns. - - :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is - malformed or isn't an HTTP/1.1 GET response - - """ - try: - status_code, reason, headers = await read_response(self.reader) - except asyncio.CancelledError: # pragma: no cover - raise - except Exception as exc: - raise InvalidMessage("did not receive a valid HTTP response") from exc - - logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) - logger.debug("%s < %r", self.side, headers) - - self.response_headers = headers - - return status_code, self.response_headers - - @staticmethod - def process_extensions( - headers: Headers, - available_extensions: Optional[Sequence[ClientExtensionFactory]], - ) -> List[Extension]: - """ - Handle the Sec-WebSocket-Extensions HTTP response header. - - Check that each extension is supported, as well as its parameters. - - Return the list of accepted extensions. - - Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the - connection. - - :rfc:`6455` leaves the rules up to the specification of each - :extension. - - To provide this level of flexibility, for each extension accepted by - the server, we check for a match with each extension available in the - client configuration. If no match is found, an exception is raised. - - If several variants of the same extension are accepted by the server, - it may be configured several times, which won't make sense in general. - Extensions must implement their own requirements. For this purpose, - the list of previously accepted extensions is provided. - - Other requirements, for example related to mandatory extensions or the - order of extensions, may be implemented by overriding this method. - - """ - accepted_extensions: List[Extension] = [] - - header_values = headers.get_all("Sec-WebSocket-Extensions") - - if header_values: - - if available_extensions is None: - raise InvalidHandshake("no extensions supported") - - parsed_header_values: List[ExtensionHeader] = sum( - [parse_extension(header_value) for header_value in header_values], [] - ) - - for name, response_params in parsed_header_values: - - for extension_factory in available_extensions: - - # Skip non-matching extensions based on their name. - if extension_factory.name != name: - continue - - # Skip non-matching extensions based on their params. - try: - extension = extension_factory.process_response_params( - response_params, accepted_extensions - ) - except NegotiationError: - continue - - # Add matching extension to the final list. - accepted_extensions.append(extension) - - # Break out of the loop once we have a match. - break - - # If we didn't break from the loop, no extension in our list - # matched what the server sent. Fail the connection. - else: - raise NegotiationError( - f"Unsupported extension: " - f"name = {name}, params = {response_params}" - ) - - return accepted_extensions - - @staticmethod - def process_subprotocol( - headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: - """ - Handle the Sec-WebSocket-Protocol HTTP response header. - - Check that it contains exactly one supported subprotocol. - - Return the selected subprotocol. - - """ - subprotocol: Optional[Subprotocol] = None - - header_values = headers.get_all("Sec-WebSocket-Protocol") - - if header_values: - - if available_subprotocols is None: - raise InvalidHandshake("no subprotocols supported") - - parsed_header_values: Sequence[Subprotocol] = sum( - [parse_subprotocol(header_value) for header_value in header_values], [] - ) - - if len(parsed_header_values) > 1: - subprotocols = ", ".join(parsed_header_values) - raise InvalidHandshake(f"multiple subprotocols: {subprotocols}") - - subprotocol = parsed_header_values[0] - - if subprotocol not in available_subprotocols: - raise NegotiationError(f"unsupported subprotocol: {subprotocol}") - - return subprotocol - - async def handshake( - self, - wsuri: WebSocketURI, - origin: Optional[Origin] = None, - available_extensions: Optional[Sequence[ClientExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - ) -> None: - """ - Perform the client side of the opening handshake. - - :param origin: sets the Origin HTTP header - :param available_extensions: list of supported extensions in the order - in which they should be used - :param available_subprotocols: list of supported subprotocols in order - of decreasing preference - :param extra_headers: sets additional HTTP request headers; it must be - a :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, - value)`` pairs - :raises ~websockets.exceptions.InvalidHandshake: if the handshake - fails - - """ - request_headers = Headers() - - if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover - request_headers["Host"] = wsuri.host - else: - request_headers["Host"] = f"{wsuri.host}:{wsuri.port}" - - if wsuri.user_info: - request_headers["Authorization"] = build_authorization_basic( - *wsuri.user_info - ) - - if origin is not None: - request_headers["Origin"] = origin - - key = build_request(request_headers) - - if available_extensions is not None: - extensions_header = build_extension( - [ - (extension_factory.name, extension_factory.get_request_params()) - for extension_factory in available_extensions - ] - ) - request_headers["Sec-WebSocket-Extensions"] = extensions_header - - if available_subprotocols is not None: - protocol_header = build_subprotocol(available_subprotocols) - request_headers["Sec-WebSocket-Protocol"] = protocol_header - - if extra_headers is not None: - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - request_headers[name] = value - - request_headers.setdefault("User-Agent", USER_AGENT) - - self.write_http_request(wsuri.resource_name, request_headers) - - status_code, response_headers = await self.read_http_response() - if status_code in (301, 302, 303, 307, 308): - if "Location" not in response_headers: - raise InvalidHeader("Location") - raise RedirectHandshake(response_headers["Location"]) - elif status_code != 101: - raise InvalidStatusCode(status_code) - - check_response(response_headers, key) - - self.extensions = self.process_extensions( - response_headers, available_extensions - ) - - self.subprotocol = self.process_subprotocol( - response_headers, available_subprotocols - ) - - self.connection_open() - - -class Connect: - """ - Connect to the WebSocket server at the given ``uri``. - - Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which - can then be used to send and receive messages. - - :func:`connect` can also be used as a asynchronous context manager. In - that case, the connection is closed when exiting the context. - - :func:`connect` is a wrapper around the event loop's - :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments - are passed to :meth:`~asyncio.loop.create_connection`. - - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to - a ``wss://`` URI, if this argument isn't provided explicitly, - :func:`ssl.create_default_context` is called to create a context. - - You can connect to a different host and port from those found in ``uri`` - by setting ``host`` and ``port`` keyword arguments. This only changes the - destination of the TCP connection. The host name from ``uri`` is still - used in the TLS handshake for secure connections and in the ``Host`` HTTP - header. - - The ``create_protocol`` parameter allows customizing the - :class:`~asyncio.Protocol` that manages the connection. It should be a - callable or class accepting the same arguments as - :class:`WebSocketClientProtocol` and returning an instance of - :class:`WebSocketClientProtocol` or a subclass. It defaults to - :class:`WebSocketClientProtocol`. - - The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, - ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is - described in :class:`~websockets.protocol.WebSocketCommonProtocol`. - - :func:`connect` also accepts the following optional arguments: - - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression - * ``origin`` sets the Origin HTTP header - * ``extensions`` is a list of supported extensions in order of - decreasing preference - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference - * ``extra_headers`` sets additional HTTP request headers; it can be a - :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` - pairs - - :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid - :raises ~websockets.handshake.InvalidHandshake: if the opening handshake - fails - - """ - - MAX_REDIRECTS_ALLOWED = 10 - - def __init__( - self, - uri: str, - *, - path: Optional[str] = None, - create_protocol: Optional[Type[WebSocketClientProtocol]] = None, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, - legacy_recv: bool = False, - klass: Optional[Type[WebSocketClientProtocol]] = None, - timeout: Optional[float] = None, - compression: Optional[str] = "deflate", - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - **kwargs: Any, - ) -> None: - # Backwards compatibility: close_timeout used to be called timeout. - if timeout is None: - timeout = 10 - else: - warnings.warn("rename timeout to close_timeout", DeprecationWarning) - # If both are specified, timeout is ignored. - if close_timeout is None: - close_timeout = timeout - - # Backwards compatibility: create_protocol used to be called klass. - if klass is None: - klass = WebSocketClientProtocol - else: - warnings.warn("rename klass to create_protocol", DeprecationWarning) - # If both are specified, klass is ignored. - if create_protocol is None: - create_protocol = klass - - if loop is None: - loop = asyncio.get_event_loop() - - wsuri = parse_uri(uri) - if wsuri.secure: - kwargs.setdefault("ssl", True) - elif kwargs.get("ssl") is not None: - raise ValueError( - "connect() received a ssl argument for a ws:// URI, " - "use a wss:// URI to enable TLS" - ) - - if compression == "deflate": - if extensions is None: - extensions = [] - if not any( - extension_factory.name == ClientPerMessageDeflateFactory.name - for extension_factory in extensions - ): - extensions = list(extensions) + [ - ClientPerMessageDeflateFactory(client_max_window_bits=True) - ] - elif compression is not None: - raise ValueError(f"unsupported compression: {compression}") - - factory = functools.partial( - create_protocol, - ping_interval=ping_interval, - ping_timeout=ping_timeout, - close_timeout=close_timeout, - max_size=max_size, - max_queue=max_queue, - read_limit=read_limit, - write_limit=write_limit, - loop=loop, - host=wsuri.host, - port=wsuri.port, - secure=wsuri.secure, - legacy_recv=legacy_recv, - origin=origin, - extensions=extensions, - subprotocols=subprotocols, - extra_headers=extra_headers, - ) - - if path is None: - host: Optional[str] - port: Optional[int] - if kwargs.get("sock") is None: - host, port = wsuri.host, wsuri.port - else: - # If sock is given, host and port shouldn't be specified. - host, port = None, None - # If host and port are given, override values from the URI. - host = kwargs.pop("host", host) - port = kwargs.pop("port", port) - create_connection = functools.partial( - loop.create_connection, factory, host, port, **kwargs - ) - else: - create_connection = functools.partial( - loop.create_unix_connection, factory, path, **kwargs - ) - - # This is a coroutine function. - self._create_connection = create_connection - self._wsuri = wsuri - - def handle_redirect(self, uri: str) -> None: - # Update the state of this instance to connect to a new URI. - old_wsuri = self._wsuri - new_wsuri = parse_uri(uri) - - # Forbid TLS downgrade. - if old_wsuri.secure and not new_wsuri.secure: - raise SecurityError("redirect from WSS to WS") - - same_origin = ( - old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port - ) - - # Rewrite the host and port arguments for cross-origin redirects. - # This preserves connection overrides with the host and port - # arguments if the redirect points to the same host and port. - if not same_origin: - # Replace the host and port argument passed to the protocol factory. - factory = self._create_connection.args[0] - factory = functools.partial( - factory.func, - *factory.args, - **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), - ) - # Replace the host and port argument passed to create_connection. - self._create_connection = functools.partial( - self._create_connection.func, - *(factory, new_wsuri.host, new_wsuri.port), - **self._create_connection.keywords, - ) - - # Set the new WebSocket URI. This suffices for same-origin redirects. - self._wsuri = new_wsuri - - # async with connect(...) - - async def __aenter__(self) -> WebSocketClientProtocol: - return await self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - await self.ws_client.close() - - # await connect(...) - - def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: - # Create a suitable iterator by calling __await__ on a coroutine. - return self.__await_impl__().__await__() - - async def __await_impl__(self) -> WebSocketClientProtocol: - for redirects in range(self.MAX_REDIRECTS_ALLOWED): - transport, protocol = await self._create_connection() - # https://github.com/python/typeshed/pull/2756 - transport = cast(asyncio.Transport, transport) - protocol = cast(WebSocketClientProtocol, protocol) - - try: - try: - await protocol.handshake( - self._wsuri, - origin=protocol.origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - except Exception: - protocol.fail_connection() - await protocol.wait_closed() - raise - else: - self.ws_client = protocol - return protocol - except RedirectHandshake as exc: - self.handle_redirect(exc.uri) - else: - raise SecurityError("too many redirects") - - # yield from connect(...) - - __iter__ = __await__ - - -connect = Connect - - -def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect: - """ - Similar to :func:`connect`, but for connecting to a Unix socket. - - This function calls the event loop's - :meth:`~asyncio.loop.create_unix_connection` method. - - It is only available on Unix. - - It's mainly useful for debugging servers listening on Unix sockets. - - :param path: file system path to the Unix socket - :param uri: WebSocket URI - - """ - return connect(uri=uri, path=path, **kwargs) +__all__ = [ + "connect", + "unix_connect", + "WebSocketClientProtocol", +] diff --git a/src/websockets/server.py b/src/websockets/server.py index 522c76114..ec94a2fbf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -1,1004 +1,9 @@ -""" -:mod:`websockets.server` defines the WebSocket server APIs. +from .asyncio_server import WebSocketServer, WebSocketServerProtocol, serve, unix_serve -""" -import asyncio -import collections.abc -import email.utils -import functools -import http -import logging -import socket -import sys -import warnings -from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, - cast, -) - -from .datastructures import Headers, HeadersLike, MultipleValuesError -from .exceptions import ( - AbortHandshake, - InvalidHandshake, - InvalidHeader, - InvalidMessage, - InvalidOrigin, - InvalidUpgrade, - NegotiationError, -) -from .extensions.base import Extension, ServerExtensionFactory -from .extensions.permessage_deflate import ServerPerMessageDeflateFactory -from .handshake_legacy import build_response, check_request -from .headers import build_extension, parse_extension, parse_subprotocol -from .http import USER_AGENT -from .http_legacy import read_request -from .protocol import WebSocketCommonProtocol -from .typing import ExtensionHeader, Origin, Subprotocol - - -__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] - -logger = logging.getLogger(__name__) - - -HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] - -HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes] - - -class WebSocketServerProtocol(WebSocketCommonProtocol): - """ - :class:`~asyncio.Protocol` subclass implementing a WebSocket server. - - This class inherits most of its methods from - :class:`~websockets.protocol.WebSocketCommonProtocol`. - - For the sake of simplicity, it doesn't rely on a full HTTP implementation. - Its support for HTTP responses is very limited. - - """ - - is_client = False - side = "server" - - def __init__( - self, - ws_handler: Callable[["WebSocketServerProtocol", str], Awaitable[Any]], - ws_server: "WebSocketServer", - *, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - **kwargs: Any, - ) -> None: - # For backwards compatibility with 6.0 or earlier. - if origins is not None and "" in origins: - warnings.warn("use None instead of '' in origins", DeprecationWarning) - origins = [None if origin == "" else origin for origin in origins] - self.ws_handler = ws_handler - self.ws_server = ws_server - self.origins = origins - self.available_extensions = extensions - self.available_subprotocols = subprotocols - self.extra_headers = extra_headers - self._process_request = process_request - self._select_subprotocol = select_subprotocol - super().__init__(**kwargs) - - def connection_made(self, transport: asyncio.BaseTransport) -> None: - """ - Register connection and initialize a task to handle it. - - """ - super().connection_made(transport) - # Register the connection with the server before creating the handler - # task. Registering at the beginning of the handler coroutine would - # create a race condition between the creation of the task, which - # schedules its execution, and the moment the handler starts running. - self.ws_server.register(self) - self.handler_task = self.loop.create_task(self.handler()) - - async def handler(self) -> None: - """ - Handle the lifecycle of a WebSocket connection. - - Since this method doesn't have a caller able to handle exceptions, it - attemps to log relevant ones and guarantees that the TCP connection is - closed before exiting. - - """ - try: - - try: - path = await self.handshake( - origins=self.origins, - available_extensions=self.available_extensions, - available_subprotocols=self.available_subprotocols, - extra_headers=self.extra_headers, - ) - except asyncio.CancelledError: # pragma: no cover - raise - except ConnectionError: - logger.debug("Connection error in opening handshake", exc_info=True) - raise - except Exception as exc: - if isinstance(exc, AbortHandshake): - status, headers, body = exc.status, exc.headers, exc.body - elif isinstance(exc, InvalidOrigin): - logger.debug("Invalid origin", exc_info=True) - status, headers, body = ( - http.HTTPStatus.FORBIDDEN, - Headers(), - f"Failed to open a WebSocket connection: {exc}.\n".encode(), - ) - elif isinstance(exc, InvalidUpgrade): - logger.debug("Invalid upgrade", exc_info=True) - status, headers, body = ( - http.HTTPStatus.UPGRADE_REQUIRED, - Headers([("Upgrade", "websocket")]), - ( - f"Failed to open a WebSocket connection: {exc}.\n" - f"\n" - f"You cannot access a WebSocket server directly " - f"with a browser. You need a WebSocket client.\n" - ).encode(), - ) - elif isinstance(exc, InvalidHandshake): - logger.debug("Invalid handshake", exc_info=True) - status, headers, body = ( - http.HTTPStatus.BAD_REQUEST, - Headers(), - f"Failed to open a WebSocket connection: {exc}.\n".encode(), - ) - else: - logger.warning("Error in opening handshake", exc_info=True) - status, headers, body = ( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - Headers(), - ( - b"Failed to open a WebSocket connection.\n" - b"See server log for more information.\n" - ), - ) - - headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Server", USER_AGENT) - headers.setdefault("Content-Length", str(len(body))) - headers.setdefault("Content-Type", "text/plain") - headers.setdefault("Connection", "close") - - self.write_http_response(status, headers, body) - self.fail_connection() - await self.wait_closed() - return - - try: - await self.ws_handler(self, path) - except Exception: - logger.error("Error in connection handler", exc_info=True) - if not self.closed: - self.fail_connection(1011) - raise - - try: - await self.close() - except ConnectionError: - logger.debug("Connection error in closing handshake", exc_info=True) - raise - except Exception: - logger.warning("Error in closing handshake", exc_info=True) - raise - - except Exception: - # Last-ditch attempt to avoid leaking connections on errors. - try: - self.transport.close() - except Exception: # pragma: no cover - pass - - finally: - # Unregister the connection with the server when the handler task - # terminates. Registration is tied to the lifecycle of the handler - # task because the server waits for tasks attached to registered - # connections before terminating. - self.ws_server.unregister(self) - - async def read_http_request(self) -> Tuple[str, Headers]: - """ - Read request line and headers from the HTTP request. - - If the request contains a body, it may be read from ``self.reader`` - after this coroutine returns. - - :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is - malformed or isn't an HTTP/1.1 GET request - - """ - try: - path, headers = await read_request(self.reader) - except asyncio.CancelledError: # pragma: no cover - raise - except Exception as exc: - raise InvalidMessage("did not receive a valid HTTP request") from exc - - logger.debug("%s < GET %s HTTP/1.1", self.side, path) - logger.debug("%s < %r", self.side, headers) - - self.path = path - self.request_headers = headers - - return path, headers - - def write_http_response( - self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None - ) -> None: - """ - Write status line and headers to the HTTP response. - - This coroutine is also able to write a response body. - - """ - self.response_headers = headers - - logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) - logger.debug("%s > %r", self.side, headers) - - # Since the status line and headers only contain ASCII characters, - # we can keep this simple. - response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" - response += str(headers) - - self.transport.write(response.encode()) - - if body is not None: - logger.debug("%s > body (%d bytes)", self.side, len(body)) - self.transport.write(body) - - async def process_request( - self, path: str, request_headers: Headers - ) -> Optional[HTTPResponse]: - """ - Intercept the HTTP request and return an HTTP response if appropriate. - - If ``process_request`` returns ``None``, the WebSocket handshake - continues. If it returns 3-uple containing a status code, response - headers and a response body, that HTTP response is sent and the - connection is closed. In that case: - - * The HTTP status must be a :class:`~http.HTTPStatus`. - * HTTP headers must be a :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, - value)`` pairs. - * The HTTP response body must be :class:`bytes`. It may be empty. - - This coroutine may be overridden in a :class:`WebSocketServerProtocol` - subclass, for example: - - * to return a HTTP 200 OK response on a given path; then a load - balancer can use this path for a health check; - * to authenticate the request and return a HTTP 401 Unauthorized or a - HTTP 403 Forbidden when authentication fails. - - Instead of subclassing, it is possible to override this method by - passing a ``process_request`` argument to the :func:`serve` function - or the :class:`WebSocketServerProtocol` constructor. This is - equivalent, except ``process_request`` won't have access to the - protocol instance, so it can't store information for later use. - - ``process_request`` is expected to complete quickly. If it may run for - a long time, then it should await :meth:`wait_closed` and exit if - :meth:`wait_closed` completes, or else it could prevent the server - from shutting down. - - :param path: request path, including optional query string - :param request_headers: request headers - - """ - if self._process_request is not None: - response = self._process_request(path, request_headers) - if isinstance(response, Awaitable): - return await response - else: - # For backwards compatibility with 7.0. - warnings.warn( - "declare process_request as a coroutine", DeprecationWarning - ) - return response # type: ignore - return None - - @staticmethod - def process_origin( - headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None - ) -> Optional[Origin]: - """ - Handle the Origin HTTP request header. - - :param headers: request headers - :param origins: optional list of acceptable origins - :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't - acceptable - - """ - # "The user agent MUST NOT include more than one Origin header field" - # per https://tools.ietf.org/html/rfc6454#section-7.3. - try: - origin = cast(Origin, headers.get("Origin")) - except MultipleValuesError as exc: - raise InvalidHeader("Origin", "more than one Origin header found") from exc - if origins is not None: - if origin not in origins: - raise InvalidOrigin(origin) - return origin - - @staticmethod - def process_extensions( - headers: Headers, - available_extensions: Optional[Sequence[ServerExtensionFactory]], - ) -> Tuple[Optional[str], List[Extension]]: - """ - Handle the Sec-WebSocket-Extensions HTTP request header. - - Accept or reject each extension proposed in the client request. - Negotiate parameters for accepted extensions. - - Return the Sec-WebSocket-Extensions HTTP response header and the list - of accepted extensions. - - :rfc:`6455` leaves the rules up to the specification of each - :extension. - - To provide this level of flexibility, for each extension proposed by - the client, we check for a match with each extension available in the - server configuration. If no match is found, the extension is ignored. - - If several variants of the same extension are proposed by the client, - it may be accepted several times, which won't make sense in general. - Extensions must implement their own requirements. For this purpose, - the list of previously accepted extensions is provided. - - This process doesn't allow the server to reorder extensions. It can - only select a subset of the extensions proposed by the client. - - Other requirements, for example related to mandatory extensions or the - order of extensions, may be implemented by overriding this method. - - :param headers: request headers - :param extensions: optional list of supported extensions - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code - - """ - response_header_value: Optional[str] = None - - extension_headers: List[ExtensionHeader] = [] - accepted_extensions: List[Extension] = [] - - header_values = headers.get_all("Sec-WebSocket-Extensions") - - if header_values and available_extensions: - - parsed_header_values: List[ExtensionHeader] = sum( - [parse_extension(header_value) for header_value in header_values], [] - ) - - for name, request_params in parsed_header_values: - - for ext_factory in available_extensions: - - # Skip non-matching extensions based on their name. - if ext_factory.name != name: - continue - - # Skip non-matching extensions based on their params. - try: - response_params, extension = ext_factory.process_request_params( - request_params, accepted_extensions - ) - except NegotiationError: - continue - - # Add matching extension to the final list. - extension_headers.append((name, response_params)) - accepted_extensions.append(extension) - - # Break out of the loop once we have a match. - break - - # If we didn't break from the loop, no extension in our list - # matched what the client sent. The extension is declined. - - # Serialize extension header. - if extension_headers: - response_header_value = build_extension(extension_headers) - - return response_header_value, accepted_extensions - - # Not @staticmethod because it calls self.select_subprotocol() - def process_subprotocol( - self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: - """ - Handle the Sec-WebSocket-Protocol HTTP request header. - - Return Sec-WebSocket-Protocol HTTP response header, which is the same - as the selected subprotocol. - - :param headers: request headers - :param available_subprotocols: optional list of supported subprotocols - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code - - """ - subprotocol: Optional[Subprotocol] = None - - header_values = headers.get_all("Sec-WebSocket-Protocol") - - if header_values and available_subprotocols: - - parsed_header_values: List[Subprotocol] = sum( - [parse_subprotocol(header_value) for header_value in header_values], [] - ) - - subprotocol = self.select_subprotocol( - parsed_header_values, available_subprotocols - ) - - return subprotocol - - def select_subprotocol( - self, - client_subprotocols: Sequence[Subprotocol], - server_subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: - """ - Pick a subprotocol among those offered by the client. - - If several subprotocols are supported by the client and the server, - the default implementation selects the preferred subprotocols by - giving equal value to the priorities of the client and the server. - - If no subprotocol is supported by the client and the server, it - proceeds without a subprotocol. - - This is unlikely to be the most useful implementation in practice, as - many servers providing a subprotocol will require that the client uses - that subprotocol. Such rules can be implemented in a subclass. - - Instead of subclassing, it is possible to override this method by - passing a ``select_subprotocol`` argument to the :func:`serve` - function or the :class:`WebSocketServerProtocol` constructor - - :param client_subprotocols: list of subprotocols offered by the client - :param server_subprotocols: list of subprotocols available on the server - - """ - if self._select_subprotocol is not None: - return self._select_subprotocol(client_subprotocols, server_subprotocols) - - subprotocols = set(client_subprotocols) & set(server_subprotocols) - if not subprotocols: - return None - priority = lambda p: ( - client_subprotocols.index(p) + server_subprotocols.index(p) - ) - return sorted(subprotocols, key=priority)[0] - - async def handshake( - self, - origins: Optional[Sequence[Optional[Origin]]] = None, - available_extensions: Optional[Sequence[ServerExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - ) -> str: - """ - Perform the server side of the opening handshake. - - Return the path of the URI of the request. - - :param origins: list of acceptable values of the Origin HTTP header; - include ``None`` if the lack of an origin is acceptable - :param available_extensions: list of supported extensions in the order - in which they should be used - :param available_subprotocols: list of supported subprotocols in order - of decreasing preference - :param extra_headers: sets additional HTTP response headers when the - handshake succeeds; it can be a :class:`~websockets.http.Headers` - instance, a :class:`~collections.abc.Mapping`, an iterable of - ``(name, value)`` pairs, or a callable taking the request path and - headers in arguments and returning one of the above. - :raises ~websockets.exceptions.InvalidHandshake: if the handshake - fails - - """ - path, request_headers = await self.read_http_request() - - # Hook for customizing request handling, for example checking - # authentication or treating some paths as plain HTTP endpoints. - early_response_awaitable = self.process_request(path, request_headers) - if isinstance(early_response_awaitable, Awaitable): - early_response = await early_response_awaitable - else: - # For backwards compatibility with 7.0. - warnings.warn("declare process_request as a coroutine", DeprecationWarning) - early_response = early_response_awaitable # type: ignore - - # Change the response to a 503 error if the server is shutting down. - if not self.ws_server.is_serving(): - early_response = ( - http.HTTPStatus.SERVICE_UNAVAILABLE, - [], - b"Server is shutting down.\n", - ) - - if early_response is not None: - raise AbortHandshake(*early_response) - - key = check_request(request_headers) - - self.origin = self.process_origin(request_headers, origins) - - extensions_header, self.extensions = self.process_extensions( - request_headers, available_extensions - ) - - protocol_header = self.subprotocol = self.process_subprotocol( - request_headers, available_subprotocols - ) - - response_headers = Headers() - - build_response(response_headers, key) - - if extensions_header is not None: - response_headers["Sec-WebSocket-Extensions"] = extensions_header - - if protocol_header is not None: - response_headers["Sec-WebSocket-Protocol"] = protocol_header - - if callable(extra_headers): - extra_headers = extra_headers(path, self.request_headers) - if extra_headers is not None: - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - response_headers[name] = value - - response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - response_headers.setdefault("Server", USER_AGENT) - - self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) - - self.connection_open() - - return path - - -class WebSocketServer: - """ - WebSocket server returned by :func:`~websockets.server.serve`. - - This class provides the same interface as - :class:`~asyncio.AbstractServer`, namely the - :meth:`~asyncio.AbstractServer.close` and - :meth:`~asyncio.AbstractServer.wait_closed` methods. - - It keeps track of WebSocket connections in order to close them properly - when shutting down. - - Instances of this class store a reference to the :class:`~asyncio.Server` - object returned by :meth:`~asyncio.loop.create_server` rather than inherit - from :class:`~asyncio.Server` in part because - :meth:`~asyncio.loop.create_server` doesn't support passing a custom - :class:`~asyncio.Server` class. - - """ - - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: - # Store a reference to loop to avoid relying on self.server._loop. - self.loop = loop - - # Keep track of active connections. - self.websockets: Set[WebSocketServerProtocol] = set() - - # Task responsible for closing the server and terminating connections. - self.close_task: Optional[asyncio.Task[None]] = None - - # Completed when the server is closed and connections are terminated. - self.closed_waiter: asyncio.Future[None] = loop.create_future() - - def wrap(self, server: asyncio.AbstractServer) -> None: - """ - Attach to a given :class:`~asyncio.Server`. - - Since :meth:`~asyncio.loop.create_server` doesn't support injecting a - custom ``Server`` class, the easiest solution that doesn't rely on - private :mod:`asyncio` APIs is to: - - - instantiate a :class:`WebSocketServer` - - give the protocol factory a reference to that instance - - call :meth:`~asyncio.loop.create_server` with the factory - - attach the resulting :class:`~asyncio.Server` with this method - - """ - self.server = server - - def register(self, protocol: WebSocketServerProtocol) -> None: - """ - Register a connection with this server. - - """ - self.websockets.add(protocol) - - def unregister(self, protocol: WebSocketServerProtocol) -> None: - """ - Unregister a connection with this server. - - """ - self.websockets.remove(protocol) - - def is_serving(self) -> bool: - """ - Tell whether the server is accepting new connections or shutting down. - - """ - try: - # Python ≥ 3.7 - return self.server.is_serving() - except AttributeError: # pragma: no cover - # Python < 3.7 - return self.server.sockets is not None - - def close(self) -> None: - """ - Close the server. - - This method: - - * closes the underlying :class:`~asyncio.Server`; - * rejects new WebSocket connections with an HTTP 503 (service - unavailable) error; this happens when the server accepted the TCP - connection but didn't complete the WebSocket opening handshake prior - to closing; - * closes open WebSocket connections with close code 1001 (going away). - - :meth:`close` is idempotent. - - """ - if self.close_task is None: - self.close_task = self.loop.create_task(self._close()) - - async def _close(self) -> None: - """ - Implementation of :meth:`close`. - - This calls :meth:`~asyncio.Server.close` on the underlying - :class:`~asyncio.Server` object to stop accepting new connections and - then closes open connections with close code 1001. - - """ - # Stop accepting new connections. - self.server.close() - - # Wait until self.server.close() completes. - await self.server.wait_closed() - - # Wait until all accepted connections reach connection_made() and call - # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep( - 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) - - # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING conections with a HTTP 503 error. - # Wait until all connections are closed. - - # asyncio.wait doesn't accept an empty first argument - if self.websockets: - await asyncio.wait( - [ - asyncio.ensure_future(websocket.close(1001)) - for websocket in self.websockets - ], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, - ) - - # Wait until all connection handlers are complete. - - # asyncio.wait doesn't accept an empty first argument. - if self.websockets: - await asyncio.wait( - [websocket.handler_task for websocket in self.websockets], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, - ) - - # Tell wait_closed() to return. - self.closed_waiter.set_result(None) - - async def wait_closed(self) -> None: - """ - Wait until the server is closed. - - When :meth:`wait_closed` returns, all TCP connections are closed and - all connection handlers have returned. - - """ - await asyncio.shield(self.closed_waiter) - - @property - def sockets(self) -> Optional[List[socket.socket]]: - """ - List of :class:`~socket.socket` objects the server is listening to. - - ``None`` if the server is closed. - - """ - return self.server.sockets - - -class Serve: - """ - - Create, start, and return a WebSocket server on ``host`` and ``port``. - - Whenever a client connects, the server accepts the connection, creates a - :class:`WebSocketServerProtocol`, performs the opening handshake, and - delegates to the connection handler defined by ``ws_handler``. Once the - handler completes, either normally or with an exception, the server - performs the closing handshake and closes the connection. - - Awaiting :func:`serve` yields a :class:`WebSocketServer`. This instance - provides :meth:`~websockets.server.WebSocketServer.close` and - :meth:`~websockets.server.WebSocketServer.wait_closed` methods for - terminating the server and cleaning up its resources. - - When a server is closed with :meth:`~WebSocketServer.close`, it closes all - connections with close code 1001 (going away). Connections handlers, which - are running the ``ws_handler`` coroutine, will receive a - :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their - current or next interaction with the WebSocket connection. - - :func:`serve` can also be used as an asynchronous context manager. In - this case, the server is shut down when exiting the context. - - :func:`serve` is a wrapper around the event loop's - :meth:`~asyncio.loop.create_server` method. It creates and starts a - :class:`~asyncio.Server` with :meth:`~asyncio.loop.create_server`. Then it - wraps the :class:`~asyncio.Server` in a :class:`WebSocketServer` and - returns the :class:`WebSocketServer`. - - The ``ws_handler`` argument is the WebSocket handler. It must be a - coroutine accepting two arguments: a :class:`WebSocketServerProtocol` and - the request URI. - - The ``host`` and ``port`` arguments, as well as unrecognized keyword - arguments, are passed along to :meth:`~asyncio.loop.create_server`. - - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enable TLS. - - The ``create_protocol`` parameter allows customizing the - :class:`~asyncio.Protocol` that manages the connection. It should be a - callable or class accepting the same arguments as - :class:`WebSocketServerProtocol` and returning an instance of - :class:`WebSocketServerProtocol` or a subclass. It defaults to - :class:`WebSocketServerProtocol`. - - The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, - ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is - described in :class:`~websockets.protocol.WebSocketCommonProtocol`. - - :func:`serve` also accepts the following optional arguments: - - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression - * ``origins`` defines acceptable Origin HTTP headers; include ``None`` if - the lack of an origin is acceptable - * ``extensions`` is a list of supported extensions in order of - decreasing preference - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference - * ``extra_headers`` sets additional HTTP response headers when the - handshake succeeds; it can be a :class:`~websockets.http.Headers` - instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, - value)`` pairs, or a callable taking the request path and headers in - arguments and returning one of the above - * ``process_request`` allows intercepting the HTTP request; it must be a - coroutine taking the request path and headers in argument; see - :meth:`~WebSocketServerProtocol.process_request` for details - * ``select_subprotocol`` allows customizing the logic for selecting a - subprotocol; it must be a callable taking the subprotocols offered by - the client and available on the server in argument; see - :meth:`~WebSocketServerProtocol.select_subprotocol` for details - - Since there's no useful way to propagate exceptions triggered in handlers, - they're sent to the ``'websockets.server'`` logger instead. Debugging is - much easier if you configure logging to print them:: - - import logging - logger = logging.getLogger('websockets.server') - logger.setLevel(logging.ERROR) - logger.addHandler(logging.StreamHandler()) - - """ - - def __init__( - self, - ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], - host: Optional[Union[str, Sequence[str]]] = None, - port: Optional[int] = None, - *, - path: Optional[str] = None, - create_protocol: Optional[Type[WebSocketServerProtocol]] = None, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, - legacy_recv: bool = False, - klass: Optional[Type[WebSocketServerProtocol]] = None, - timeout: Optional[float] = None, - compression: Optional[str] = "deflate", - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - **kwargs: Any, - ) -> None: - # Backwards compatibility: close_timeout used to be called timeout. - if timeout is None: - timeout = 10 - else: - warnings.warn("rename timeout to close_timeout", DeprecationWarning) - # If both are specified, timeout is ignored. - if close_timeout is None: - close_timeout = timeout - - # Backwards compatibility: create_protocol used to be called klass. - if klass is None: - klass = WebSocketServerProtocol - else: - warnings.warn("rename klass to create_protocol", DeprecationWarning) - # If both are specified, klass is ignored. - if create_protocol is None: - create_protocol = klass - - if loop is None: - loop = asyncio.get_event_loop() - - ws_server = WebSocketServer(loop) - - secure = kwargs.get("ssl") is not None - - if compression == "deflate": - if extensions is None: - extensions = [] - if not any( - ext_factory.name == ServerPerMessageDeflateFactory.name - for ext_factory in extensions - ): - extensions = list(extensions) + [ServerPerMessageDeflateFactory()] - elif compression is not None: - raise ValueError(f"unsupported compression: {compression}") - - factory = functools.partial( - create_protocol, - ws_handler, - ws_server, - host=host, - port=port, - secure=secure, - ping_interval=ping_interval, - ping_timeout=ping_timeout, - close_timeout=close_timeout, - max_size=max_size, - max_queue=max_queue, - read_limit=read_limit, - write_limit=write_limit, - loop=loop, - legacy_recv=legacy_recv, - origins=origins, - extensions=extensions, - subprotocols=subprotocols, - extra_headers=extra_headers, - process_request=process_request, - select_subprotocol=select_subprotocol, - ) - - if path is None: - create_server = functools.partial( - loop.create_server, factory, host, port, **kwargs - ) - else: - # unix_serve(path) must not specify host and port parameters. - assert host is None and port is None - create_server = functools.partial( - loop.create_unix_server, factory, path, **kwargs - ) - - # This is a coroutine function. - self._create_server = create_server - self.ws_server = ws_server - - # async with serve(...) - - async def __aenter__(self) -> WebSocketServer: - return await self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - self.ws_server.close() - await self.ws_server.wait_closed() - - # await serve(...) - - def __await__(self) -> Generator[Any, None, WebSocketServer]: - # Create a suitable iterator by calling __await__ on a coroutine. - return self.__await_impl__().__await__() - - async def __await_impl__(self) -> WebSocketServer: - server = await self._create_server() - self.ws_server.wrap(server) - return self.ws_server - - # yield from serve(...) - - __iter__ = __await__ - - -serve = Serve - - -def unix_serve( - ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], - path: str, - **kwargs: Any, -) -> Serve: - """ - Similar to :func:`serve`, but for listening on Unix sockets. - - This function calls the event loop's - :meth:`~asyncio.loop.create_unix_server` method. - - It is only available on Unix. - - It's useful for deploying a server behind a reverse proxy such as nginx. - - :param path: file system path to the Unix socket - - """ - return serve(ws_handler, path=path, **kwargs) +__all__ = [ + "serve", + "unix_serve", + "WebSocketServerProtocol", + "WebSocketServer", +] diff --git a/tests/test_client_server.py b/tests/test_asyncio_client_server.py similarity index 98% rename from tests/test_client_server.py rename to tests/test_asyncio_client_server.py index db26d6583..cff76d1f2 100644 --- a/tests/test_client_server.py +++ b/tests/test_asyncio_client_server.py @@ -13,7 +13,8 @@ import urllib.request import warnings -from websockets.client import * +from websockets.asyncio_client import * +from websockets.asyncio_server import * from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -31,7 +32,6 @@ from websockets.http import USER_AGENT from websockets.http_legacy import read_response from websockets.protocol import State -from websockets.server import * from websockets.uri import parse_uri from .test_protocol import MS @@ -1072,7 +1072,7 @@ def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.server.read_request") + @unittest.mock.patch("websockets.asyncio_server.read_request") def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") @@ -1080,7 +1080,7 @@ def test_server_receives_malformed_request(self, _read_request): self.start_client() @with_server() - @unittest.mock.patch("websockets.client.read_response") + @unittest.mock.patch("websockets.asyncio_client.read_response") def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") @@ -1089,7 +1089,7 @@ def test_client_receives_malformed_response(self, _read_response): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.client.build_request") + @unittest.mock.patch("websockets.asyncio_client.build_request") def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(headers): return "42" @@ -1100,7 +1100,7 @@ def wrong_build_request(headers): self.start_client() @with_server() - @unittest.mock.patch("websockets.server.build_response") + @unittest.mock.patch("websockets.asyncio_server.build_response") def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(headers, key): return build_response(headers, "42") @@ -1111,7 +1111,7 @@ def wrong_build_response(headers, key): self.start_client() @with_server() - @unittest.mock.patch("websockets.client.read_response") + @unittest.mock.patch("websockets.asyncio_client.read_response") def test_server_does_not_switch_protocols(self, _read_response): async def wrong_read_response(stream): status_code, reason, headers = await read_response(stream) @@ -1124,7 +1124,9 @@ async def wrong_read_response(stream): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.server.WebSocketServerProtocol.process_request") + @unittest.mock.patch( + "websockets.asyncio_server.WebSocketServerProtocol.process_request" + ) def test_server_error_in_handshake(self, _process_request): _process_request.side_effect = Exception("process_request crashed") @@ -1132,7 +1134,7 @@ def test_server_error_in_handshake(self, _process_request): self.start_client() @with_server() - @unittest.mock.patch("websockets.server.WebSocketServerProtocol.send") + @unittest.mock.patch("websockets.asyncio_server.WebSocketServerProtocol.send") def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") @@ -1145,7 +1147,7 @@ def test_server_handler_crashes(self, send): self.assertEqual(self.client.close_code, 1011) @with_server() - @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") + @unittest.mock.patch("websockets.asyncio_server.WebSocketServerProtocol.close") def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") @@ -1220,7 +1222,9 @@ def test_invalid_status_error_during_client_connect(self): @unittest.mock.patch( "websockets.server.WebSocketServerProtocol.write_http_response" ) - @unittest.mock.patch("websockets.server.WebSocketServerProtocol.read_http_request") + @unittest.mock.patch( + "websockets.asyncio_server.WebSocketServerProtocol.read_http_request" + ) def test_connection_error_during_opening_handshake( self, _read_http_request, _write_http_response ): @@ -1238,7 +1242,7 @@ def test_connection_error_during_opening_handshake( _write_http_response.assert_not_called() @with_server() - @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") + @unittest.mock.patch("websockets.asyncio_server.WebSocketServerProtocol.close") def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError diff --git a/tests/test_auth.py b/tests/test_auth.py index 97a4485a0..c693c9f45 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -6,7 +6,7 @@ from websockets.exceptions import InvalidStatusCode from websockets.headers import build_authorization_basic -from .test_client_server import ClientServerTestsMixin, with_client, with_server +from .test_asyncio_client_server import ClientServerTestsMixin, with_client, with_server from .utils import AsyncioTestCase From 80a8ac8194a9b3591549c6c5bc023f14f1f2c168 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 18 Feb 2020 22:04:00 +0100 Subject: [PATCH 0702/1539] Implement sans-I/O handshake. --- src/websockets/__init__.py | 2 + src/websockets/asyncio_server.py | 2 +- src/websockets/client.py | 291 ++++++++++++++ src/websockets/connection.py | 88 +++++ src/websockets/events.py | 27 ++ src/websockets/server.py | 426 ++++++++++++++++++++ tests/extensions/utils.py | 76 ++++ tests/test_client.py | 545 ++++++++++++++++++++++++++ tests/test_http11.py | 2 +- tests/test_protocol.py | 10 +- tests/test_server.py | 649 +++++++++++++++++++++++++++++++ tests/utils.py | 4 + 12 files changed, 2115 insertions(+), 7 deletions(-) create mode 100644 src/websockets/connection.py create mode 100644 src/websockets/events.py create mode 100644 tests/extensions/utils.py create mode 100644 tests/test_client.py create mode 100644 tests/test_server.py diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 89829235c..c4accaca1 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -15,6 +15,7 @@ "AbortHandshake", "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", + "ClientConnection", "connect", "ConnectionClosed", "ConnectionClosedError", @@ -43,6 +44,7 @@ "RedirectHandshake", "SecurityError", "serve", + "ServerConnection", "Subprotocol", "unix_connect", "unix_serve", diff --git a/src/websockets/asyncio_server.py b/src/websockets/asyncio_server.py index 1eeddf0eb..89ddf6c7d 100644 --- a/src/websockets/asyncio_server.py +++ b/src/websockets/asyncio_server.py @@ -341,7 +341,7 @@ def process_origin( # "The user agent MUST NOT include more than one Origin header field" # per https://tools.ietf.org/html/rfc6454#section-7.3. try: - origin = cast(Origin, headers.get("Origin")) + origin = cast(Optional[Origin], headers.get("Origin")) except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc if origins is not None: diff --git a/src/websockets/client.py b/src/websockets/client.py index c7d153f13..ec4eb88f5 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,8 +1,299 @@ +import collections +import logging +from typing import Generator, List, Optional, Sequence + from .asyncio_client import WebSocketClientProtocol, connect, unix_connect +from .connection import CLIENT, CONNECTING, OPEN, Connection +from .datastructures import Headers, HeadersLike, MultipleValuesError +from .events import Accept, Connect, Event, Reject +from .exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidHeaderValue, + InvalidStatusCode, + InvalidUpgrade, + NegotiationError, +) +from .extensions.base import ClientExtensionFactory, Extension +from .headers import ( + build_authorization_basic, + build_extension, + build_subprotocol, + parse_connection, + parse_extension, + parse_subprotocol, + parse_upgrade, +) +from .http import USER_AGENT +from .http11 import Request, Response +from .typing import ( + ConnectionOption, + ExtensionHeader, + Origin, + Subprotocol, + UpgradeProtocol, +) +from .uri import parse_uri +from .utils import accept_key, generate_key __all__ = [ "connect", "unix_connect", + "ClientConnection", "WebSocketClientProtocol", ] + +logger = logging.getLogger(__name__) + + +class ClientConnection(Connection): + + side = CLIENT + + def __init__( + self, + uri: str, + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, + ): + super().__init__(state=CONNECTING) + self.wsuri = parse_uri(uri) + self.origin = origin + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + self.key = generate_key() + + def connect(self) -> Connect: + """ + Create a Connect event to send to the server. + + """ + headers = Headers() + + if self.wsuri.port == (443 if self.wsuri.secure else 80): + headers["Host"] = self.wsuri.host + else: + headers["Host"] = f"{self.wsuri.host}:{self.wsuri.port}" + + if self.wsuri.user_info: + headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info) + + if self.origin is not None: + headers["Origin"] = self.origin + + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Key"] = self.key + headers["Sec-WebSocket-Version"] = "13" + + if self.available_extensions is not None: + extensions_header = build_extension( + [ + (extension_factory.name, extension_factory.get_request_params()) + for extension_factory in self.available_extensions + ] + ) + headers["Sec-WebSocket-Extensions"] = extensions_header + + if self.available_subprotocols is not None: + protocol_header = build_subprotocol(self.available_subprotocols) + headers["Sec-WebSocket-Protocol"] = protocol_header + + if self.extra_headers is not None: + extra_headers = self.extra_headers + if isinstance(extra_headers, Headers): + extra_headers = extra_headers.raw_items() + elif isinstance(extra_headers, collections.abc.Mapping): + extra_headers = extra_headers.items() + for name, value in extra_headers: + headers[name] = value + + headers.setdefault("User-Agent", USER_AGENT) + + request = Request(self.wsuri.resource_name, headers) + return Connect(request) + + def process_response(self, response: Response) -> None: + """ + Check a handshake response received from the server. + + :param response: response + :param key: comes from :func:`build_request` + :raises ~websockets.exceptions.InvalidHandshake: if the handshake response + is invalid + + """ + + if response.status_code != 101: + raise InvalidStatusCode(response.status_code) + + headers = response.headers + + connection: List[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade( + "Connection", ", ".join(connection) if connection else None + ) + + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. It's supposed to be 'WebSocket'. + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) + + try: + s_w_accept = headers["Sec-WebSocket-Accept"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Accept") + except MultipleValuesError: + raise InvalidHeader( + "Sec-WebSocket-Accept", + "more than one Sec-WebSocket-Accept header found", + ) + + if s_w_accept != accept_key(self.key): + raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) + + self.extensions = self.process_extensions(headers) + + self.subprotocol = self.process_subprotocol(headers) + + def process_extensions(self, headers: Headers) -> List[Extension]: + """ + Handle the Sec-WebSocket-Extensions HTTP response header. + + Check that each extension is supported, as well as its parameters. + + Return the list of accepted extensions. + + Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the + connection. + + :rfc:`6455` leaves the rules up to the specification of each + extension. + + To provide this level of flexibility, for each extension accepted by + the server, we check for a match with each extension available in the + client configuration. If no match is found, an exception is raised. + + If several variants of the same extension are accepted by the server, + it may be configured severel times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + """ + accepted_extensions: List[Extension] = [] + + extensions = headers.get_all("Sec-WebSocket-Extensions") + + if extensions: + + if self.available_extensions is None: + raise InvalidHandshake("no extensions supported") + + parsed_extensions: List[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in extensions], [] + ) + + for name, response_params in parsed_extensions: + + for extension_factory in self.available_extensions: + + # Skip non-matching extensions based on their name. + if extension_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + extension = extension_factory.process_response_params( + response_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the server sent. Fail the connection. + else: + raise NegotiationError( + f"Unsupported extension: " + f"name = {name}, params = {response_params}" + ) + + return accepted_extensions + + def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + """ + Handle the Sec-WebSocket-Protocol HTTP response header. + + Check that it contains exactly one supported subprotocol. + + Return the selected subprotocol. + + """ + subprotocol: Optional[Subprotocol] = None + + subprotocols = headers.get_all("Sec-WebSocket-Protocol") + + if subprotocols: + + if self.available_subprotocols is None: + raise InvalidHandshake("no subprotocols supported") + + parsed_subprotocols: Sequence[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in subprotocols], [] + ) + + if len(parsed_subprotocols) > 1: + subprotocols_display = ", ".join(parsed_subprotocols) + raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}") + + subprotocol = parsed_subprotocols[0] + + if subprotocol not in self.available_subprotocols: + raise NegotiationError(f"unsupported subprotocol: {subprotocol}") + + return subprotocol + + def send_in_connecting_state(self, event: Event) -> bytes: + assert isinstance(event, Connect) + + request = event.request + + logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) + logger.debug("%s > %r", self.side, request.headers) + + return request.serialize() + + def parse(self) -> Generator[None, None, None]: + response = yield from Response.parse( + self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, + ) + assert self.state == CONNECTING + try: + self.process_response(response) + except InvalidHandshake as exc: + self.events.append(Reject(response, exc)) + return + else: + self.events.append(Accept(response)) + self.state = OPEN + yield from super().parse() diff --git a/src/websockets/connection.py b/src/websockets/connection.py new file mode 100644 index 000000000..5789b6ea1 --- /dev/null +++ b/src/websockets/connection.py @@ -0,0 +1,88 @@ +import enum +from typing import Generator, Iterable, List, Tuple + +from .events import Event +from .exceptions import InvalidState +from .streams import StreamReader + + +__all__ = ["Connection"] + + +# A WebSocket connection is either a server or a client. + + +class Side(enum.IntEnum): + SERVER, CLIENT = range(2) + + +SERVER = Side.SERVER +CLIENT = Side.CLIENT + + +# A WebSocket connection goes through the following four states, in order: + + +class State(enum.IntEnum): + CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + +CONNECTING = State.CONNECTING +OPEN = State.OPEN +CLOSING = State.CLOSING +CLOSED = State.CLOSED + + +class Connection: + + side: Side + + def __init__(self, state: State = OPEN) -> None: + self.state = state + self.reader = StreamReader() + self.events: List[Event] = [] + self.parser = self.parse() + next(self.parser) # start coroutine + + # Public APIs for receiving data and producing events + + def receive_data(self, data: bytes) -> Tuple[Iterable[Event], bytes]: + self.reader.feed_data(data) + return self.receive() + + def receive_eof(self) -> Tuple[Iterable[Event], bytes]: + self.reader.feed_eof() + return self.receive() + + # Public APIs for receiving events and producing data + + def send(self, event: Event) -> bytes: + """ + Send an event to the remote endpoint. + + """ + if self.state == OPEN: + raise NotImplementedError # not implemented yet + elif self.state == CONNECTING: + return self.send_in_connecting_state(event) + else: + raise InvalidState( + f"Cannot write to a WebSocket in the {self.state.name} state" + ) + + # Private APIs + + def send_in_connecting_state(self, event: Event) -> bytes: + raise NotImplementedError + + def receive(self) -> Tuple[List[Event], bytes]: + # Run parser until more data is needed or EOF + try: + next(self.parser) + except StopIteration: + pass + events, self.events = self.events, [] + return events, b"" + + def parse(self) -> Generator[None, None, None]: + yield # not implemented yet diff --git a/src/websockets/events.py b/src/websockets/events.py new file mode 100644 index 000000000..196de9421 --- /dev/null +++ b/src/websockets/events.py @@ -0,0 +1,27 @@ +from typing import NamedTuple, Optional, Union + +from .http11 import Request, Response + + +__all__ = [ + "Accept", + "Connect", + "Event", + "Reject", +] + + +class Connect(NamedTuple): + request: Request + + +class Accept(NamedTuple): + response: Response + + +class Reject(NamedTuple): + response: Response + exception: Optional[Exception] + + +Event = Union[Connect, Accept, Reject] diff --git a/src/websockets/server.py b/src/websockets/server.py index ec94a2fbf..f668ff5e7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -1,9 +1,435 @@ +import base64 +import binascii +import collections +import email.utils +import http +import logging +from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast + from .asyncio_server import WebSocketServer, WebSocketServerProtocol, serve, unix_serve +from .connection import CONNECTING, OPEN, SERVER, Connection +from .datastructures import Headers, HeadersLike, MultipleValuesError +from .events import Accept, Connect, Event, Reject +from .exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidHeaderValue, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, +) +from .extensions.base import Extension, ServerExtensionFactory +from .headers import ( + build_extension, + parse_connection, + parse_extension, + parse_subprotocol, + parse_upgrade, +) +from .http import USER_AGENT +from .http11 import Request, Response +from .typing import ( + ConnectionOption, + ExtensionHeader, + Origin, + Subprotocol, + UpgradeProtocol, +) +from .utils import accept_key __all__ = [ "serve", "unix_serve", + "ServerConnection", "WebSocketServerProtocol", "WebSocketServer", ] + +logger = logging.getLogger(__name__) + + +HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] + + +class ServerConnection(Connection): + + side = SERVER + + def __init__( + self, + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + ): + super().__init__(state=CONNECTING) + self.origins = origins + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + + def accept(self, connect: Connect) -> Union[Accept, Reject]: + """ + Create an ``Accept`` or ``Reject`` event to send to the client. + + If the connection cannot be established, this method returns a + :class:`~websockets.events.Reject` event, which may be unexpected. + + """ + request = connect.request + try: + key, extensions_header, protocol_header = self.process_request(request) + except InvalidOrigin as exc: + logger.debug("Invalid origin", exc_info=True) + return self.reject( + http.HTTPStatus.FORBIDDEN, + f"Failed to open a WebSocket connection: {exc}.\n", + exception=exc, + ) + except InvalidUpgrade as exc: + logger.debug("Invalid upgrade", exc_info=True) + return self.reject( + http.HTTPStatus.UPGRADE_REQUIRED, + ( + f"Failed to open a WebSocket connection: {exc}.\n" + f"\n" + f"You cannot access a WebSocket server directly " + f"with a browser. You need a WebSocket client.\n" + ), + headers=Headers([("Upgrade", "websocket")]), + exception=exc, + ) + except InvalidHandshake as exc: + logger.debug("Invalid handshake", exc_info=True) + return self.reject( + http.HTTPStatus.BAD_REQUEST, + f"Failed to open a WebSocket connection: {exc}.\n", + exception=exc, + ) + except Exception as exc: + logger.warning("Error in opening handshake", exc_info=True) + return self.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + exception=exc, + ) + + headers = Headers() + + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Accept"] = accept_key(key) + + if extensions_header is not None: + headers["Sec-WebSocket-Extensions"] = extensions_header + + if protocol_header is not None: + headers["Sec-WebSocket-Protocol"] = protocol_header + + extra_headers: Optional[HeadersLike] + if callable(self.extra_headers): + extra_headers = self.extra_headers(request.path, request.headers) + else: + extra_headers = self.extra_headers + if extra_headers is not None: + if isinstance(extra_headers, Headers): + extra_headers = extra_headers.raw_items() + elif isinstance(extra_headers, collections.abc.Mapping): + extra_headers = extra_headers.items() + for name, value in extra_headers: + headers[name] = value + + headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + headers.setdefault("Server", USER_AGENT) + + response = Response(101, "Switching Protocols", headers) + return Accept(response) + + def process_request( + self, request: Request + ) -> Tuple[str, Optional[str], Optional[str]]: + """ + Check a handshake request received from the client. + + This function doesn't verify that the request is an HTTP/1.1 or higher GET + request and doesn't perform ``Host`` and ``Origin`` checks. These controls + are usually performed earlier in the HTTP request handling code. They're + the responsibility of the caller. + + :param request: request + :returns: ``key`` which must be passed to :func:`build_response` + :raises ~websockets.exceptions.InvalidHandshake: if the handshake request + is invalid; then the server must return 400 Bad Request error + + """ + headers = request.headers + + connection: List[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade( + "Connection", ", ".join(connection) if connection else None + ) + + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) + + try: + key = headers["Sec-WebSocket-Key"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Key") + except MultipleValuesError: + raise InvalidHeader( + "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" + ) + + try: + raw_key = base64.b64decode(key.encode(), validate=True) + except binascii.Error: + raise InvalidHeaderValue("Sec-WebSocket-Key", key) + if len(raw_key) != 16: + raise InvalidHeaderValue("Sec-WebSocket-Key", key) + + try: + version = headers["Sec-WebSocket-Version"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Version") + except MultipleValuesError: + raise InvalidHeader( + "Sec-WebSocket-Version", + "more than one Sec-WebSocket-Version header found", + ) + + if version != "13": + raise InvalidHeaderValue("Sec-WebSocket-Version", version) + + self.origin = self.process_origin(headers) + + extensions_header, self.extensions = self.process_extensions(headers) + + protocol_header = self.subprotocol = self.process_subprotocol(headers) + + return key, extensions_header, protocol_header + + def process_origin(self, headers: Headers) -> Optional[Origin]: + """ + Handle the Origin HTTP request header. + + :param headers: request headers + :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't + acceptable + + """ + # "The user agent MUST NOT include more than one Origin header field" + # per https://tools.ietf.org/html/rfc6454#section-7.3. + try: + origin = cast(Optional[Origin], headers.get("Origin")) + except MultipleValuesError as exc: + raise InvalidHeader("Origin", "more than one Origin header found") from exc + if self.origins is not None: + if origin not in self.origins: + raise InvalidOrigin(origin) + return origin + + def process_extensions( + self, headers: Headers, + ) -> Tuple[Optional[str], List[Extension]]: + """ + Handle the Sec-WebSocket-Extensions HTTP request header. + + Accept or reject each extension proposed in the client request. + Negotiate parameters for accepted extensions. + + Return the Sec-WebSocket-Extensions HTTP response header and the list + of accepted extensions. + + :rfc:`6455` leaves the rules up to the specification of each + :extension. + + To provide this level of flexibility, for each extension proposed by + the client, we check for a match with each extension available in the + server configuration. If no match is found, the extension is ignored. + + If several variants of the same extension are proposed by the client, + it may be accepted several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + This process doesn't allow the server to reorder extensions. It can + only select a subset of the extensions proposed by the client. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + :param headers: request headers + :raises ~websockets.exceptions.InvalidHandshake: to abort the + handshake with an HTTP 400 error code + + """ + response_header_value: Optional[str] = None + + extension_headers: List[ExtensionHeader] = [] + accepted_extensions: List[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values and self.available_extensions: + + parsed_header_values: List[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, request_params in parsed_header_values: + + for ext_factory in self.available_extensions: + + # Skip non-matching extensions based on their name. + if ext_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + response_params, extension = ext_factory.process_request_params( + request_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + extension_headers.append((name, response_params)) + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the client sent. The extension is declined. + + # Serialize extension header. + if extension_headers: + response_header_value = build_extension(extension_headers) + + return response_header_value, accepted_extensions + + def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + """ + Handle the Sec-WebSocket-Protocol HTTP request header. + + Return Sec-WebSocket-Protocol HTTP response header, which is the same + as the selected subprotocol. + + :param headers: request headers + :raises ~websockets.exceptions.InvalidHandshake: to abort the + handshake with an HTTP 400 error code + + """ + subprotocol: Optional[Subprotocol] = None + + header_values = headers.get_all("Sec-WebSocket-Protocol") + + if header_values and self.available_subprotocols: + + parsed_header_values: List[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] + ) + + subprotocol = self.select_subprotocol( + parsed_header_values, self.available_subprotocols + ) + + return subprotocol + + def select_subprotocol( + self, + client_subprotocols: Sequence[Subprotocol], + server_subprotocols: Sequence[Subprotocol], + ) -> Optional[Subprotocol]: + """ + Pick a subprotocol among those offered by the client. + + If several subprotocols are supported by the client and the server, + the default implementation selects the preferred subprotocols by + giving equal value to the priorities of the client and the server. + + If no common subprotocol is supported by the client and the server, it + proceeds without a subprotocol. + + This is unlikely to be the most useful implementation in practice, as + many servers providing a subprotocol will require that the client uses + that subprotocol. + + :param client_subprotocols: list of subprotocols offered by the client + :param server_subprotocols: list of subprotocols available on the server + + """ + subprotocols = set(client_subprotocols) & set(server_subprotocols) + if not subprotocols: + return None + priority = lambda p: ( + client_subprotocols.index(p) + server_subprotocols.index(p) + ) + return sorted(subprotocols, key=priority)[0] + + def reject( + self, + status: http.HTTPStatus, + text: str, + headers: Optional[Headers] = None, + exception: Optional[Exception] = None, + ) -> Reject: + """ + Create a ``Reject`` event to send to the client. + + A short plain text response is the best fallback when failing to + establish a WebSocket connection. + + """ + body = text.encode() + if headers is None: + headers = Headers() + headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + headers.setdefault("Server", USER_AGENT) + headers.setdefault("Content-Length", str(len(body))) + headers.setdefault("Content-Type", "text/plain; charset=utf-8") + headers.setdefault("Connection", "close") + response = Response(status.value, status.phrase, headers, body) + return Reject(response, exception) + + def send_in_connecting_state(self, event: Event) -> bytes: + assert isinstance(event, (Accept, Reject)) + + if isinstance(event, Accept): + self.state = OPEN + + response = event.response + + logger.debug( + "%s > HTTP/1.1 %d %s", + self.side, + response.status_code, + response.reason_phrase, + ) + logger.debug("%s > %r", self.side, response.headers) + if response.body is not None: + logger.debug("%s > body (%d bytes)", self.side, len(response.body)) + + return response.serialize() + + def parse(self) -> Generator[None, None, None]: + request = yield from Request.parse(self.reader.read_line) + assert self.state == CONNECTING + self.events.append(Connect(request)) + yield from super().parse() diff --git a/tests/extensions/utils.py b/tests/extensions/utils.py new file mode 100644 index 000000000..81990bb07 --- /dev/null +++ b/tests/extensions/utils.py @@ -0,0 +1,76 @@ +from websockets.exceptions import NegotiationError + + +class OpExtension: + name = "x-op" + + def __init__(self, op=None): + self.op = op + + def decode(self, frame, *, max_size=None): + return frame # pragma: no cover + + def encode(self, frame): + return frame # pragma: no cover + + def __eq__(self, other): + return isinstance(other, OpExtension) and self.op == other.op + + +class ClientOpExtensionFactory: + name = "x-op" + + def __init__(self, op=None): + self.op = op + + def get_request_params(self): + return [("op", self.op)] + + def process_response_params(self, params, accepted_extensions): + if params != [("op", self.op)]: + raise NegotiationError() + return OpExtension(self.op) + + +class ServerOpExtensionFactory: + name = "x-op" + + def __init__(self, op=None): + self.op = op + + def process_request_params(self, params, accepted_extensions): + if params != [("op", self.op)]: + raise NegotiationError() + return [("op", self.op)], OpExtension(self.op) + + +class Rsv2Extension: + name = "x-rsv2" + + def decode(self, frame, *, max_size=None): + assert frame.rsv2 + return frame._replace(rsv2=False) + + def encode(self, frame): + assert not frame.rsv2 + return frame._replace(rsv2=True) + + def __eq__(self, other): + return isinstance(other, Rsv2Extension) + + +class ClientRsv2ExtensionFactory: + name = "x-rsv2" + + def get_request_params(self): + return [] + + def process_response_params(self, params, accepted_extensions): + return Rsv2Extension() + + +class ServerRsv2ExtensionFactory: + name = "x-rsv2" + + def process_request_params(self, params, accepted_extensions): + return [], Rsv2Extension() diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 000000000..1cf27349d --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,545 @@ +import unittest +import unittest.mock + +from websockets.client import * +from websockets.connection import CONNECTING, OPEN +from websockets.datastructures import Headers +from websockets.events import Accept, Connect, Reject +from websockets.exceptions import InvalidHandshake, InvalidHeader +from websockets.http import USER_AGENT +from websockets.http11 import Request, Response +from websockets.utils import accept_key + +from .extensions.utils import ( + ClientOpExtensionFactory, + ClientRsv2ExtensionFactory, + OpExtension, + Rsv2Extension, +) +from .test_utils import ACCEPT, KEY +from .utils import DATE + + +class ConnectTests(unittest.TestCase): + def test_send_connect(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientConnection("wss://example.com/test") + connect = client.connect() + self.assertIsInstance(connect, Connect) + bytes_to_send = client.send(connect) + self.assertEqual( + bytes_to_send, + ( + f"GET /test HTTP/1.1\r\n" + f"Host: example.com\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {KEY}\r\n" + f"Sec-WebSocket-Version: 13\r\n" + f"User-Agent: {USER_AGENT}\r\n" + f"\r\n" + ).encode(), + ) + + def test_connect_request(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientConnection("wss://example.com/test") + connect = client.connect() + self.assertIsInstance(connect.request, Request) + self.assertEqual(connect.request.path, "/test") + self.assertEqual( + connect.request.headers, + Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + "User-Agent": USER_AGENT, + } + ), + ) + + def test_path(self): + client = ClientConnection("wss://example.com/endpoint?test=1") + request = client.connect().request + + self.assertEqual(request.path, "/endpoint?test=1") + + def test_port(self): + for uri, host in [ + ("ws://example.com/", "example.com"), + ("ws://example.com:80/", "example.com"), + ("ws://example.com:8080/", "example.com:8080"), + ("wss://example.com/", "example.com"), + ("wss://example.com:443/", "example.com"), + ("wss://example.com:8443/", "example.com:8443"), + ]: + with self.subTest(uri=uri): + client = ClientConnection(uri) + request = client.connect().request + + self.assertEqual(request.headers["Host"], host) + + def test_user_info(self): + client = ClientConnection("wss://hello:iloveyou@example.com/") + request = client.connect().request + + self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") + + def test_origin(self): + client = ClientConnection("wss://example.com/", origin="https://example.com") + request = client.connect().request + + self.assertEqual(request.headers["Origin"], "https://example.com") + + def test_extensions(self): + client = ClientConnection( + "wss://example.com/", extensions=[ClientOpExtensionFactory()] + ) + request = client.connect().request + + self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") + + def test_subprotocols(self): + client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + request = client.connect().request + + self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") + + def test_extra_headers(self): + for extra_headers in [ + Headers({"X-Spam": "Eggs"}), + {"X-Spam": "Eggs"}, + [("X-Spam", "Eggs")], + ]: + with self.subTest(extra_headers=extra_headers): + client = ClientConnection( + "wss://example.com/", extra_headers=extra_headers + ) + request = client.connect().request + + self.assertEqual(request.headers["X-Spam"], "Eggs") + + def test_extra_headers_overrides_user_agent(self): + client = ClientConnection( + "wss://example.com/", extra_headers={"User-Agent": "Other"} + ) + request = client.connect().request + + self.assertEqual(request.headers["User-Agent"], "Other") + + +class AcceptRejectTests(unittest.TestCase): + def test_receive_accept(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientConnection("ws://example.com/test") + client.connect() + [accept], bytes_to_send = client.receive_data( + ( + f"HTTP/1.1 101 Switching Protocols\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {ACCEPT}\r\n" + f"Date: {DATE}\r\n" + f"Server: {USER_AGENT}\r\n" + f"\r\n" + ).encode(), + ) + self.assertIsInstance(accept, Accept) + self.assertEqual(bytes_to_send, b"") + self.assertEqual(client.state, OPEN) + + def test_receive_reject(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientConnection("ws://example.com/test") + client.connect() + [reject], bytes_to_send = client.receive_data( + ( + f"HTTP/1.1 404 Not Found\r\n" + f"Date: {DATE}\r\n" + f"Server: {USER_AGENT}\r\n" + f"Content-Length: 13\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"Connection: close\r\n" + f"\r\n" + f"Sorry folks.\n" + ).encode(), + ) + self.assertIsInstance(reject, Reject) + self.assertEqual(bytes_to_send, b"") + self.assertEqual(client.state, CONNECTING) + + def test_accept_response(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientConnection("ws://example.com/test") + client.connect() + [accept], _bytes_to_send = client.receive_data( + ( + f"HTTP/1.1 101 Switching Protocols\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {ACCEPT}\r\n" + f"Date: {DATE}\r\n" + f"Server: {USER_AGENT}\r\n" + f"\r\n" + ).encode(), + ) + self.assertEqual(accept.response.status_code, 101) + self.assertEqual(accept.response.reason_phrase, "Switching Protocols") + self.assertEqual( + accept.response.headers, + Headers( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Accept": ACCEPT, + "Date": DATE, + "Server": USER_AGENT, + } + ), + ) + self.assertIsNone(accept.response.body) + + def test_reject_response(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientConnection("ws://example.com/test") + client.connect() + [reject], _bytes_to_send = client.receive_data( + ( + f"HTTP/1.1 404 Not Found\r\n" + f"Date: {DATE}\r\n" + f"Server: {USER_AGENT}\r\n" + f"Content-Length: 13\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"Connection: close\r\n" + f"\r\n" + f"Sorry folks.\n" + ).encode(), + ) + self.assertEqual(reject.response.status_code, 404) + self.assertEqual(reject.response.reason_phrase, "Not Found") + self.assertEqual( + reject.response.headers, + Headers( + { + "Date": DATE, + "Server": USER_AGENT, + "Content-Length": "13", + "Content-Type": "text/plain; charset=utf-8", + "Connection": "close", + } + ), + ) + self.assertEqual(reject.response.body, b"Sorry folks.\n") + + def make_accept_response(self, client): + request = client.connect().request + return Response( + status_code=101, + reason_phrase="Switching Protocols", + headers=Headers( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Accept": accept_key( + request.headers["Sec-WebSocket-Key"] + ), + } + ), + ) + + def test_basic(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + + def test_missing_connection(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + del response.headers["Connection"] + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "missing Connection header") + + def test_invalid_connection(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + del response.headers["Connection"] + response.headers["Connection"] = "close" + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "invalid Connection header: close") + + def test_missing_upgrade(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + del response.headers["Upgrade"] + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "missing Upgrade header") + + def test_invalid_upgrade(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + del response.headers["Upgrade"] + response.headers["Upgrade"] = "h2c" + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") + + def test_missing_accept(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + del response.headers["Sec-WebSocket-Accept"] + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") + + def test_multiple_accept(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Accept"] = ACCEPT + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), + "invalid Sec-WebSocket-Accept header: " + "more than one Sec-WebSocket-Accept header found", + ) + + def test_invalid_accept(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + del response.headers["Sec-WebSocket-Accept"] + response.headers["Sec-WebSocket-Accept"] = ACCEPT + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), f"invalid Sec-WebSocket-Accept header: {ACCEPT}" + ) + + def test_no_extensions(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.extensions, []) + + def test_no_extension(self): + client = ClientConnection( + "wss://example.com/", extensions=[ClientOpExtensionFactory()] + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.extensions, [OpExtension()]) + + def test_extension(self): + client = ClientConnection( + "wss://example.com/", extensions=[ClientRsv2ExtensionFactory()] + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.extensions, [Rsv2Extension()]) + + def test_unexpected_extension(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHandshake) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "no extensions supported") + + def test_unsupported_extension(self): + client = ClientConnection( + "wss://example.com/", extensions=[ClientRsv2ExtensionFactory()] + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHandshake) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), + "Unsupported extension: name = x-op, params = [('op', None)]", + ) + + def test_supported_extension_parameters(self): + client = ClientConnection( + "wss://example.com/", extensions=[ClientOpExtensionFactory("this")] + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.extensions, [OpExtension("this")]) + + def test_unsupported_extension_parameters(self): + client = ClientConnection( + "wss://example.com/", extensions=[ClientOpExtensionFactory("this")] + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHandshake) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), + "Unsupported extension: name = x-op, params = [('op', 'that')]", + ) + + def test_multiple_supported_extension_parameters(self): + client = ClientConnection( + "wss://example.com/", + extensions=[ + ClientOpExtensionFactory("this"), + ClientOpExtensionFactory("that"), + ], + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.extensions, [OpExtension("that")]) + + def test_multiple_extensions(self): + client = ClientConnection( + "wss://example.com/", + extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) + + def test_multiple_extensions_order(self): + client = ClientConnection( + "wss://example.com/", + extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) + + def test_no_subprotocols(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertIsNone(client.subprotocol) + + def test_no_subprotocol(self): + client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + response = self.make_accept_response(client) + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertIsNone(client.subprotocol) + + def test_subprotocol(self): + client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Protocol"] = "chat" + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.subprotocol, "chat") + + def test_unexpected_subprotocol(self): + client = ClientConnection("wss://example.com/") + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Protocol"] = "chat" + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHandshake) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "no subprotocols supported") + + def test_multiple_subprotocols(self): + client = ClientConnection( + "wss://example.com/", subprotocols=["superchat", "chat"] + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Protocol"] = "superchat" + response.headers["Sec-WebSocket-Protocol"] = "chat" + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHandshake) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), "multiple subprotocols: superchat, chat" + ) + + def test_supported_subprotocol(self): + client = ClientConnection( + "wss://example.com/", subprotocols=["superchat", "chat"] + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Protocol"] = "chat" + [accept], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(accept, Accept) + self.assertEqual(client.subprotocol, "chat") + + def test_unsupported_subprotocol(self): + client = ClientConnection( + "wss://example.com/", subprotocols=["superchat", "chat"] + ) + response = self.make_accept_response(client) + response.headers["Sec-WebSocket-Protocol"] = "otherchat" + [reject], _bytes_to_send = client.receive_data(response.serialize()) + + self.assertIsInstance(reject, Reject) + with self.assertRaises(InvalidHandshake) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") diff --git a/tests/test_http11.py b/tests/test_http11.py index bca874aee..4574cf97e 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -101,7 +101,7 @@ def setUp(self): def parse(self): return Response.parse( - self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof + self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, ) def test_parse(self): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 91fb02a50..3054600e1 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -553,13 +553,13 @@ def test_recv_when_transfer_data_cancelled(self): def test_recv_prevents_concurrent_calls(self): recv = self.loop.create_task(self.protocol.recv()) - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual( + str(raised.exception), "cannot call recv while another coroutine " "is already waiting for the next message", - ): - self.loop.run_until_complete(self.protocol.recv()) - + ) recv.cancel() # Test the send coroutine. diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 000000000..1d094a86d --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,649 @@ +import http +import unittest +import unittest.mock + +from websockets.connection import CONNECTING, OPEN +from websockets.datastructures import Headers +from websockets.events import Accept, Connect, Reject +from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade +from websockets.http import USER_AGENT +from websockets.http11 import Request, Response +from websockets.server import * + +from .extensions.utils import ( + OpExtension, + Rsv2Extension, + ServerOpExtensionFactory, + ServerRsv2ExtensionFactory, +) +from .test_utils import ACCEPT, KEY +from .utils import DATE + + +class ConnectTests(unittest.TestCase): + def test_receive_connect(self): + server = ServerConnection() + [connect], bytes_to_send = server.receive_data( + ( + f"GET /test HTTP/1.1\r\n" + f"Host: example.com\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {KEY}\r\n" + f"Sec-WebSocket-Version: 13\r\n" + f"User-Agent: {USER_AGENT}\r\n" + f"\r\n" + ).encode(), + ) + self.assertIsInstance(connect, Connect) + self.assertEqual(bytes_to_send, b"") + + def test_connect_request(self): + server = ServerConnection() + [connect], bytes_to_send = server.receive_data( + ( + f"GET /test HTTP/1.1\r\n" + f"Host: example.com\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {KEY}\r\n" + f"Sec-WebSocket-Version: 13\r\n" + f"User-Agent: {USER_AGENT}\r\n" + f"\r\n" + ).encode(), + ) + self.assertEqual(connect.request.path, "/test") + self.assertEqual( + connect.request.headers, + Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + "User-Agent": USER_AGENT, + } + ), + ) + + +class AcceptRejectTests(unittest.TestCase): + def make_connect_request(self): + return Request( + path="/test", + headers=Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + "User-Agent": USER_AGENT, + } + ), + ) + + def test_send_accept(self): + server = ServerConnection() + with unittest.mock.patch("email.utils.formatdate", return_value=DATE): + accept = server.accept(Connect(self.make_connect_request())) + self.assertIsInstance(accept, Accept) + bytes_to_send = server.send(accept) + self.assertEqual( + bytes_to_send, + ( + f"HTTP/1.1 101 Switching Protocols\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {ACCEPT}\r\n" + f"Date: {DATE}\r\n" + f"Server: {USER_AGENT}\r\n" + f"\r\n" + ).encode(), + ) + self.assertEqual(server.state, OPEN) + + def test_send_reject(self): + server = ServerConnection() + with unittest.mock.patch("email.utils.formatdate", return_value=DATE): + reject = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + self.assertIsInstance(reject, Reject) + bytes_to_send = server.send(reject) + self.assertEqual( + bytes_to_send, + ( + f"HTTP/1.1 404 Not Found\r\n" + f"Date: {DATE}\r\n" + f"Server: {USER_AGENT}\r\n" + f"Content-Length: 13\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"Connection: close\r\n" + f"\r\n" + f"Sorry folks.\n" + ).encode(), + ) + self.assertEqual(server.state, CONNECTING) + + def test_accept_response(self): + server = ServerConnection() + with unittest.mock.patch("email.utils.formatdate", return_value=DATE): + accept = server.accept(Connect(self.make_connect_request())) + self.assertIsInstance(accept.response, Response) + self.assertEqual(accept.response.status_code, 101) + self.assertEqual(accept.response.reason_phrase, "Switching Protocols") + self.assertEqual( + accept.response.headers, + Headers( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Accept": ACCEPT, + "Date": DATE, + "Server": USER_AGENT, + } + ), + ) + self.assertIsNone(accept.response.body) + + def test_reject_response(self): + server = ServerConnection() + with unittest.mock.patch("email.utils.formatdate", return_value=DATE): + reject = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + self.assertIsInstance(reject.response, Response) + self.assertEqual(reject.response.status_code, 404) + self.assertEqual(reject.response.reason_phrase, "Not Found") + self.assertEqual( + reject.response.headers, + Headers( + { + "Date": DATE, + "Server": USER_AGENT, + "Content-Length": "13", + "Content-Type": "text/plain; charset=utf-8", + "Connection": "close", + } + ), + ) + self.assertEqual(reject.response.body, b"Sorry folks.\n") + + def test_basic(self): + server = ServerConnection() + request = self.make_connect_request() + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + + def test_unexpected_exception(self): + server = ServerConnection() + request = self.make_connect_request() + with unittest.mock.patch( + "websockets.server.ServerConnection.process_request", + side_effect=Exception("BOOM"), + ): + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 500) + with self.assertRaises(Exception) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "BOOM") + + def test_missing_connection(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Connection"] + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 426) + self.assertEqual(reject.response.headers["Upgrade"], "websocket") + with self.assertRaises(InvalidUpgrade) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "missing Connection header") + + def test_invalid_connection(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Connection"] + request.headers["Connection"] = "close" + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 426) + self.assertEqual(reject.response.headers["Upgrade"], "websocket") + with self.assertRaises(InvalidUpgrade) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "invalid Connection header: close") + + def test_missing_upgrade(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Upgrade"] + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 426) + self.assertEqual(reject.response.headers["Upgrade"], "websocket") + with self.assertRaises(InvalidUpgrade) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "missing Upgrade header") + + def test_invalid_upgrade(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Upgrade"] + request.headers["Upgrade"] = "h2c" + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 426) + self.assertEqual(reject.response.headers["Upgrade"], "websocket") + with self.assertRaises(InvalidUpgrade) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") + + def test_missing_key(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Sec-WebSocket-Key"] + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 400) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") + + def test_multiple_key(self): + server = ServerConnection() + request = self.make_connect_request() + request.headers["Sec-WebSocket-Key"] = KEY + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 400) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), + "invalid Sec-WebSocket-Key header: " + "more than one Sec-WebSocket-Key header found", + ) + + def test_invalid_key(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Sec-WebSocket-Key"] + request.headers["Sec-WebSocket-Key"] = "not Base64 data!" + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 400) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" + ) + + def test_truncated_key(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Sec-WebSocket-Key"] + request.headers["Sec-WebSocket-Key"] = KEY[ + :16 + ] # 12 bytes instead of 16, Base64-encoded + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 400) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" + ) + + def test_missing_version(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Sec-WebSocket-Version"] + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 400) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") + + def test_multiple_version(self): + server = ServerConnection() + request = self.make_connect_request() + request.headers["Sec-WebSocket-Version"] = "11" + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 400) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), + "invalid Sec-WebSocket-Version header: " + "more than one Sec-WebSocket-Version header found", + ) + + def test_invalid_version(self): + server = ServerConnection() + request = self.make_connect_request() + del request.headers["Sec-WebSocket-Version"] + request.headers["Sec-WebSocket-Version"] = "11" + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 400) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), "invalid Sec-WebSocket-Version header: 11" + ) + + def test_no_origin(self): + server = ServerConnection(origins=["https://example.com"]) + request = self.make_connect_request() + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 403) + with self.assertRaises(InvalidOrigin) as raised: + raise reject.exception + self.assertEqual(str(raised.exception), "missing Origin header") + + def test_origin(self): + server = ServerConnection(origins=["https://example.com"]) + request = self.make_connect_request() + request.headers["Origin"] = "https://example.com" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual(server.origin, "https://example.com") + + def test_unexpected_origin(self): + server = ServerConnection(origins=["https://example.com"]) + request = self.make_connect_request() + request.headers["Origin"] = "https://other.example.com" + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 403) + with self.assertRaises(InvalidOrigin) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), "invalid Origin header: https://other.example.com" + ) + + def test_multiple_origin(self): + server = ServerConnection( + origins=["https://example.com", "https://other.example.com"] + ) + request = self.make_connect_request() + request.headers["Origin"] = "https://example.com" + request.headers["Origin"] = "https://other.example.com" + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + # This is prohibited by the HTTP specification, so the return code is + # 400 Bad Request rather than 403 Forbidden. + self.assertEqual(reject.response.status_code, 400) + with self.assertRaises(InvalidHeader) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), + "invalid Origin header: more than one Origin header found", + ) + + def test_supported_origin(self): + server = ServerConnection( + origins=["https://example.com", "https://other.example.com"] + ) + request = self.make_connect_request() + request.headers["Origin"] = "https://other.example.com" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual(server.origin, "https://other.example.com") + + def test_unsupported_origin(self): + server = ServerConnection( + origins=["https://example.com", "https://other.example.com"] + ) + request = self.make_connect_request() + request.headers["Origin"] = "https://original.example.com" + reject = server.accept(Connect(request)) + + self.assertIsInstance(reject, Reject) + self.assertEqual(reject.response.status_code, 403) + with self.assertRaises(InvalidOrigin) as raised: + raise reject.exception + self.assertEqual( + str(raised.exception), "invalid Origin header: https://original.example.com" + ) + + def test_no_origin_accepted(self): + server = ServerConnection(origins=[None]) + request = self.make_connect_request() + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertIsNone(server.origin) + + def test_no_extensions(self): + server = ServerConnection() + request = self.make_connect_request() + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(server.extensions, []) + + def test_no_extension(self): + server = ServerConnection(extensions=[ServerOpExtensionFactory()]) + request = self.make_connect_request() + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(server.extensions, []) + + def test_extension(self): + server = ServerConnection(extensions=[ServerOpExtensionFactory()]) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Extensions"] = "x-op; op" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual( + accept.response.headers["Sec-WebSocket-Extensions"], "x-op; op" + ) + self.assertEqual(server.extensions, [OpExtension()]) + + def test_unexpected_extension(self): + server = ServerConnection() + request = self.make_connect_request() + request.headers["Sec-WebSocket-Extensions"] = "x-op; op" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(server.extensions, []) + + def test_unsupported_extension(self): + server = ServerConnection(extensions=[ServerRsv2ExtensionFactory()]) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Extensions"] = "x-op; op" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(server.extensions, []) + + def test_supported_extension_parameters(self): + server = ServerConnection(extensions=[ServerOpExtensionFactory("this")]) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual( + accept.response.headers["Sec-WebSocket-Extensions"], "x-op; op=this" + ) + self.assertEqual(server.extensions, [OpExtension("this")]) + + def test_unsupported_extension_parameters(self): + server = ServerConnection(extensions=[ServerOpExtensionFactory("this")]) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(server.extensions, []) + + def test_multiple_supported_extension_parameters(self): + server = ServerConnection( + extensions=[ + ServerOpExtensionFactory("this"), + ServerOpExtensionFactory("that"), + ] + ) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual( + accept.response.headers["Sec-WebSocket-Extensions"], "x-op; op=that" + ) + self.assertEqual(server.extensions, [OpExtension("that")]) + + def test_multiple_extensions(self): + server = ServerConnection( + extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] + ) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Extensions"] = "x-op; op" + request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual( + accept.response.headers["Sec-WebSocket-Extensions"], "x-op; op, x-rsv2" + ) + self.assertEqual(server.extensions, [OpExtension(), Rsv2Extension()]) + + def test_multiple_extensions_order(self): + server = ServerConnection( + extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] + ) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + request.headers["Sec-WebSocket-Extensions"] = "x-op; op" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual( + accept.response.headers["Sec-WebSocket-Extensions"], "x-rsv2, x-op; op" + ) + self.assertEqual(server.extensions, [Rsv2Extension(), OpExtension()]) + + def test_no_subprotocols(self): + server = ServerConnection() + request = self.make_connect_request() + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Protocol", accept.response.headers) + self.assertIsNone(server.subprotocol) + + def test_no_subprotocol(self): + server = ServerConnection(subprotocols=["chat"]) + request = self.make_connect_request() + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Protocol", accept.response.headers) + self.assertIsNone(server.subprotocol) + + def test_subprotocol(self): + server = ServerConnection(subprotocols=["chat"]) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Protocol"] = "chat" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual(accept.response.headers["Sec-WebSocket-Protocol"], "chat") + self.assertEqual(server.subprotocol, "chat") + + def test_unexpected_subprotocol(self): + server = ServerConnection() + request = self.make_connect_request() + request.headers["Sec-WebSocket-Protocol"] = "chat" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Protocol", accept.response.headers) + self.assertIsNone(server.subprotocol) + + def test_multiple_subprotocols(self): + server = ServerConnection(subprotocols=["superchat", "chat"]) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Protocol"] = "superchat" + request.headers["Sec-WebSocket-Protocol"] = "chat" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual(accept.response.headers["Sec-WebSocket-Protocol"], "superchat") + self.assertEqual(server.subprotocol, "superchat") + + def test_supported_subprotocol(self): + server = ServerConnection(subprotocols=["superchat", "chat"]) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Protocol"] = "chat" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual(accept.response.headers["Sec-WebSocket-Protocol"], "chat") + self.assertEqual(server.subprotocol, "chat") + + def test_unsupported_subprotocol(self): + server = ServerConnection(subprotocols=["superchat", "chat"]) + request = self.make_connect_request() + request.headers["Sec-WebSocket-Protocol"] = "otherchat" + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertNotIn("Sec-WebSocket-Protocol", accept.response.headers) + self.assertIsNone(server.subprotocol) + + def test_extra_headers(self): + for extra_headers in [ + Headers({"X-Spam": "Eggs"}), + {"X-Spam": "Eggs"}, + [("X-Spam", "Eggs")], + lambda path, headers: Headers({"X-Spam": "Eggs"}), + lambda path, headers: {"X-Spam": "Eggs"}, + lambda path, headers: [("X-Spam", "Eggs")], + ]: + with self.subTest(extra_headers=extra_headers): + server = ServerConnection(extra_headers=extra_headers) + request = self.make_connect_request() + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual(accept.response.headers["X-Spam"], "Eggs") + + def test_extra_headers_overrides_server(self): + server = ServerConnection(extra_headers={"Server": "Other"}) + request = self.make_connect_request() + accept = server.accept(Connect(request)) + + self.assertIsInstance(accept, Accept) + self.assertEqual(accept.response.headers["Server"], "Other") diff --git a/tests/utils.py b/tests/utils.py index bbffa8649..790d25687 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import email.utils import functools import logging import os @@ -7,6 +8,9 @@ import unittest +DATE = email.utils.formatdate(usegmt=True) + + class GeneratorTestCase(unittest.TestCase): def assertGeneratorRunning(self, gen): """ From 1033db5d402ed3a241356f97d642cda0df82ce45 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jun 2020 10:48:43 +0200 Subject: [PATCH 0703/1539] Drop Event class. It was too thin. It didn't add any value. Using the same abstractions for connection events and wire messages is good enough for our purposes. --- src/websockets/client.py | 31 ++- src/websockets/connection.py | 31 ++- src/websockets/events.py | 27 --- src/websockets/http11.py | 3 + src/websockets/protocol.py | 4 +- src/websockets/server.py | 67 +++--- tests/test_client.py | 190 +++++++++-------- tests/test_server.py | 385 ++++++++++++++++------------------- 8 files changed, 341 insertions(+), 397 deletions(-) delete mode 100644 src/websockets/events.py diff --git a/src/websockets/client.py b/src/websockets/client.py index ec4eb88f5..50203f27c 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -5,7 +5,6 @@ from .asyncio_client import WebSocketClientProtocol, connect, unix_connect from .connection import CLIENT, CONNECTING, OPEN, Connection from .datastructures import Headers, HeadersLike, MultipleValuesError -from .events import Accept, Connect, Event, Reject from .exceptions import ( InvalidHandshake, InvalidHeader, @@ -67,9 +66,9 @@ def __init__( self.extra_headers = extra_headers self.key = generate_key() - def connect(self) -> Connect: + def connect(self) -> Request: """ - Create a Connect event to send to the server. + Create a WebSocket handshake request event to send to the server. """ headers = Headers() @@ -114,8 +113,7 @@ def connect(self) -> Connect: headers.setdefault("User-Agent", USER_AGENT) - request = Request(self.wsuri.resource_name, headers) - return Connect(request) + return Request(self.wsuri.resource_name, headers) def process_response(self, response: Response) -> None: """ @@ -153,13 +151,13 @@ def process_response(self, response: Response) -> None: try: s_w_accept = headers["Sec-WebSocket-Accept"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Accept") - except MultipleValuesError: + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Accept") from exc + except MultipleValuesError as exc: raise InvalidHeader( "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found", - ) + ) from exc if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) @@ -273,11 +271,11 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: return subprotocol - def send_in_connecting_state(self, event: Event) -> bytes: - assert isinstance(event, Connect) - - request = event.request + def send_request(self, request: Request) -> bytes: + """ + Convert a WebSocket handshake request to bytes to send to the server. + """ logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) logger.debug("%s > %r", self.side, request.headers) @@ -291,9 +289,10 @@ def parse(self) -> Generator[None, None, None]: try: self.process_response(response) except InvalidHandshake as exc: - self.events.append(Reject(response, exc)) - return + response = response._replace(exception=exc) + logger.debug("Invalid handshake", exc_info=True) else: - self.events.append(Accept(response)) self.state = OPEN + finally: + self.events.append(response) yield from super().parse() diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 5789b6ea1..ac9aedd6b 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,14 +1,18 @@ import enum -from typing import Generator, Iterable, List, Tuple +from typing import Generator, List, Tuple, Union -from .events import Event from .exceptions import InvalidState +from .frames import Frame +from .http11 import Request, Response from .streams import StreamReader __all__ = ["Connection"] +Event = Union[Request, Response, Frame] + + # A WebSocket connection is either a server or a client. @@ -46,43 +50,38 @@ def __init__(self, state: State = OPEN) -> None: # Public APIs for receiving data and producing events - def receive_data(self, data: bytes) -> Tuple[Iterable[Event], bytes]: + def receive_data(self, data: bytes) -> Tuple[List[Event], List[bytes]]: self.reader.feed_data(data) return self.receive() - def receive_eof(self) -> Tuple[Iterable[Event], bytes]: + def receive_eof(self) -> Tuple[List[Event], List[bytes]]: self.reader.feed_eof() return self.receive() # Public APIs for receiving events and producing data - def send(self, event: Event) -> bytes: + def send_frame(self, frame: Frame) -> bytes: """ - Send an event to the remote endpoint. + Convert a WebSocket handshake response to bytes to send. """ - if self.state == OPEN: - raise NotImplementedError # not implemented yet - elif self.state == CONNECTING: - return self.send_in_connecting_state(event) - else: + # Defensive assertion for protocol compliance. + if self.state != OPEN: raise InvalidState( f"Cannot write to a WebSocket in the {self.state.name} state" ) + raise NotImplementedError # not implemented yet # Private APIs - def send_in_connecting_state(self, event: Event) -> bytes: - raise NotImplementedError - - def receive(self) -> Tuple[List[Event], bytes]: + def receive(self) -> Tuple[List[Event], List[bytes]]: # Run parser until more data is needed or EOF try: next(self.parser) except StopIteration: pass events, self.events = self.events, [] - return events, b"" + return events, [] def parse(self) -> Generator[None, None, None]: yield # not implemented yet diff --git a/src/websockets/events.py b/src/websockets/events.py deleted file mode 100644 index 196de9421..000000000 --- a/src/websockets/events.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import NamedTuple, Optional, Union - -from .http11 import Request, Response - - -__all__ = [ - "Accept", - "Connect", - "Event", - "Reject", -] - - -class Connect(NamedTuple): - request: Request - - -class Accept(NamedTuple): - response: Response - - -class Reject(NamedTuple): - response: Response - exception: Optional[Exception] - - -Event = Union[Connect, Accept, Reject] diff --git a/src/websockets/http11.py b/src/websockets/http11.py index e1d004881..58ee09253 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -127,6 +127,9 @@ class Response(NamedTuple): headers: Headers body: Optional[bytes] = None + # If processing the response triggers an exception, it's stored here. + exception: Optional[Exception] = None + @classmethod def parse( cls, diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 748c1ae66..58c4569d0 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -53,7 +53,7 @@ serialize_close, ) from .framing import Frame -from .typing import Data +from .typing import Data, Subprotocol __all__ = ["WebSocketCommonProtocol"] @@ -261,7 +261,7 @@ def __init__( # WebSocket protocol parameters. self.extensions: List[Extension] = [] - self.subprotocol: Optional[str] = None + self.subprotocol: Optional[Subprotocol] = None # The close code and reason are set when receiving a close frame or # losing the TCP connection. diff --git a/src/websockets/server.py b/src/websockets/server.py index f668ff5e7..095d9a17d 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -9,7 +9,6 @@ from .asyncio_server import WebSocketServer, WebSocketServerProtocol, serve, unix_serve from .connection import CONNECTING, OPEN, SERVER, Connection from .datastructures import Headers, HeadersLike, MultipleValuesError -from .events import Accept, Connect, Event, Reject from .exceptions import ( InvalidHandshake, InvalidHeader, @@ -69,15 +68,17 @@ def __init__( self.available_subprotocols = subprotocols self.extra_headers = extra_headers - def accept(self, connect: Connect) -> Union[Accept, Reject]: + def accept(self, request: Request) -> Response: """ - Create an ``Accept`` or ``Reject`` event to send to the client. + Create a WebSocket handshake response event to send to the client. - If the connection cannot be established, this method returns a - :class:`~websockets.events.Reject` event, which may be unexpected. + If the connection cannot be established, the response rejects the + connection, which may be unexpected. """ - request = connect.request + # TODO: when changing Request to a dataclass, set the exception + # attribute on the request rather than the Response, which will + # be semantically more correct. try: key, extensions_header, protocol_header = self.process_request(request) except InvalidOrigin as exc: @@ -85,8 +86,7 @@ def accept(self, connect: Connect) -> Union[Accept, Reject]: return self.reject( http.HTTPStatus.FORBIDDEN, f"Failed to open a WebSocket connection: {exc}.\n", - exception=exc, - ) + )._replace(exception=exc) except InvalidUpgrade as exc: logger.debug("Invalid upgrade", exc_info=True) return self.reject( @@ -98,15 +98,13 @@ def accept(self, connect: Connect) -> Union[Accept, Reject]: f"with a browser. You need a WebSocket client.\n" ), headers=Headers([("Upgrade", "websocket")]), - exception=exc, - ) + )._replace(exception=exc) except InvalidHandshake as exc: logger.debug("Invalid handshake", exc_info=True) return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc}.\n", - exception=exc, - ) + )._replace(exception=exc) except Exception as exc: logger.warning("Error in opening handshake", exc_info=True) return self.reject( @@ -115,8 +113,7 @@ def accept(self, connect: Connect) -> Union[Accept, Reject]: "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ), - exception=exc, - ) + )._replace(exception=exc) headers = Headers() @@ -146,8 +143,7 @@ def accept(self, connect: Connect) -> Union[Accept, Reject]: headers.setdefault("Date", email.utils.formatdate(usegmt=True)) headers.setdefault("Server", USER_AGENT) - response = Response(101, "Switching Protocols", headers) - return Accept(response) + return Response(101, "Switching Protocols", headers) def process_request( self, request: Request @@ -189,29 +185,29 @@ def process_request( try: key = headers["Sec-WebSocket-Key"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Key") - except MultipleValuesError: + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Key") from exc + except MultipleValuesError as exc: raise InvalidHeader( "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" - ) + ) from exc try: raw_key = base64.b64decode(key.encode(), validate=True) - except binascii.Error: - raise InvalidHeaderValue("Sec-WebSocket-Key", key) + except binascii.Error as exc: + raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc if len(raw_key) != 16: raise InvalidHeaderValue("Sec-WebSocket-Key", key) try: version = headers["Sec-WebSocket-Version"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Version") - except MultipleValuesError: + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Version") from exc + except MultipleValuesError as exc: raise InvalidHeader( "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found", - ) + ) from exc if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) @@ -389,9 +385,9 @@ def reject( text: str, headers: Optional[Headers] = None, exception: Optional[Exception] = None, - ) -> Reject: + ) -> Response: """ - Create a ``Reject`` event to send to the client. + Create a HTTP response event to send to the client. A short plain text response is the best fallback when failing to establish a WebSocket connection. @@ -405,17 +401,16 @@ def reject( headers.setdefault("Content-Length", str(len(body))) headers.setdefault("Content-Type", "text/plain; charset=utf-8") headers.setdefault("Connection", "close") - response = Response(status.value, status.phrase, headers, body) - return Reject(response, exception) + return Response(status.value, status.phrase, headers, body) - def send_in_connecting_state(self, event: Event) -> bytes: - assert isinstance(event, (Accept, Reject)) + def send_response(self, response: Response) -> bytes: + """ + Convert a WebSocket handshake response to bytes to send to the client. - if isinstance(event, Accept): + """ + if response.status_code == 101: self.state = OPEN - response = event.response - logger.debug( "%s > HTTP/1.1 %d %s", self.side, @@ -431,5 +426,5 @@ def send_in_connecting_state(self, event: Event) -> bytes: def parse(self) -> Generator[None, None, None]: request = yield from Request.parse(self.reader.read_line) assert self.state == CONNECTING - self.events.append(Connect(request)) + self.events.append(request) yield from super().parse() diff --git a/tests/test_client.py b/tests/test_client.py index 1cf27349d..eef8eb13e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,6 @@ from websockets.client import * from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers -from websockets.events import Accept, Connect, Reject from websockets.exceptions import InvalidHandshake, InvalidHeader from websockets.http import USER_AGENT from websockets.http11 import Request, Response @@ -24,9 +23,9 @@ class ConnectTests(unittest.TestCase): def test_send_connect(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("wss://example.com/test") - connect = client.connect() - self.assertIsInstance(connect, Connect) - bytes_to_send = client.send(connect) + request = client.connect() + self.assertIsInstance(request, Request) + bytes_to_send = client.send_request(request) self.assertEqual( bytes_to_send, ( @@ -44,11 +43,10 @@ def test_send_connect(self): def test_connect_request(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("wss://example.com/test") - connect = client.connect() - self.assertIsInstance(connect.request, Request) - self.assertEqual(connect.request.path, "/test") + request = client.connect() + self.assertEqual(request.path, "/test") self.assertEqual( - connect.request.headers, + request.headers, Headers( { "Host": "example.com", @@ -63,7 +61,7 @@ def test_connect_request(self): def test_path(self): client = ClientConnection("wss://example.com/endpoint?test=1") - request = client.connect().request + request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") @@ -78,19 +76,19 @@ def test_port(self): ]: with self.subTest(uri=uri): client = ClientConnection(uri) - request = client.connect().request + request = client.connect() self.assertEqual(request.headers["Host"], host) def test_user_info(self): client = ClientConnection("wss://hello:iloveyou@example.com/") - request = client.connect().request + request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): client = ClientConnection("wss://example.com/", origin="https://example.com") - request = client.connect().request + request = client.connect() self.assertEqual(request.headers["Origin"], "https://example.com") @@ -98,13 +96,13 @@ def test_extensions(self): client = ClientConnection( "wss://example.com/", extensions=[ClientOpExtensionFactory()] ) - request = client.connect().request + request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): client = ClientConnection("wss://example.com/", subprotocols=["chat"]) - request = client.connect().request + request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") @@ -118,7 +116,7 @@ def test_extra_headers(self): client = ClientConnection( "wss://example.com/", extra_headers=extra_headers ) - request = client.connect().request + request = client.connect() self.assertEqual(request.headers["X-Spam"], "Eggs") @@ -126,7 +124,7 @@ def test_extra_headers_overrides_user_agent(self): client = ClientConnection( "wss://example.com/", extra_headers={"User-Agent": "Other"} ) - request = client.connect().request + request = client.connect() self.assertEqual(request.headers["User-Agent"], "Other") @@ -136,7 +134,7 @@ def test_receive_accept(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("ws://example.com/test") client.connect() - [accept], bytes_to_send = client.receive_data( + [response], bytes_to_send = client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" f"Upgrade: websocket\r\n" @@ -147,15 +145,15 @@ def test_receive_accept(self): f"\r\n" ).encode(), ) - self.assertIsInstance(accept, Accept) - self.assertEqual(bytes_to_send, b"") + self.assertIsInstance(response, Response) + self.assertEqual(bytes_to_send, []) self.assertEqual(client.state, OPEN) def test_receive_reject(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("ws://example.com/test") client.connect() - [reject], bytes_to_send = client.receive_data( + [response], bytes_to_send = client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" @@ -167,15 +165,15 @@ def test_receive_reject(self): f"Sorry folks.\n" ).encode(), ) - self.assertIsInstance(reject, Reject) - self.assertEqual(bytes_to_send, b"") + self.assertIsInstance(response, Response) + self.assertEqual(bytes_to_send, []) self.assertEqual(client.state, CONNECTING) def test_accept_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("ws://example.com/test") client.connect() - [accept], _bytes_to_send = client.receive_data( + [response], _bytes_to_send = client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" f"Upgrade: websocket\r\n" @@ -186,10 +184,10 @@ def test_accept_response(self): f"\r\n" ).encode(), ) - self.assertEqual(accept.response.status_code, 101) - self.assertEqual(accept.response.reason_phrase, "Switching Protocols") + self.assertEqual(response.status_code, 101) + self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual( - accept.response.headers, + response.headers, Headers( { "Upgrade": "websocket", @@ -200,13 +198,13 @@ def test_accept_response(self): } ), ) - self.assertIsNone(accept.response.body) + self.assertIsNone(response.body) def test_reject_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("ws://example.com/test") client.connect() - [reject], _bytes_to_send = client.receive_data( + [response], _bytes_to_send = client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" @@ -218,10 +216,10 @@ def test_reject_response(self): f"Sorry folks.\n" ).encode(), ) - self.assertEqual(reject.response.status_code, 404) - self.assertEqual(reject.response.reason_phrase, "Not Found") + self.assertEqual(response.status_code, 404) + self.assertEqual(response.reason_phrase, "Not Found") self.assertEqual( - reject.response.headers, + response.headers, Headers( { "Date": DATE, @@ -232,10 +230,10 @@ def test_reject_response(self): } ), ) - self.assertEqual(reject.response.body, b"Sorry folks.\n") + self.assertEqual(response.body, b"Sorry folks.\n") def make_accept_response(self, client): - request = client.connect().request + request = client.connect() return Response( status_code=101, reason_phrase="Switching Protocols", @@ -253,19 +251,19 @@ def make_accept_response(self, client): def test_basic(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) def test_missing_connection(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) del response.headers["Connection"] - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): @@ -273,22 +271,22 @@ def test_invalid_connection(self): response = self.make_accept_response(client) del response.headers["Connection"] response.headers["Connection"] = "close" - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) del response.headers["Upgrade"] - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): @@ -296,33 +294,33 @@ def test_invalid_upgrade(self): response = self.make_accept_response(client) del response.headers["Upgrade"] response.headers["Upgrade"] = "h2c" - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_accept(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") def test_multiple_accept(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) response.headers["Sec-WebSocket-Accept"] = ACCEPT - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Accept header: " @@ -334,11 +332,11 @@ def test_invalid_accept(self): response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] response.headers["Sec-WebSocket-Accept"] = ACCEPT - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), f"invalid Sec-WebSocket-Accept header: {ACCEPT}" ) @@ -346,9 +344,9 @@ def test_invalid_accept(self): def test_no_extensions(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, []) def test_no_extension(self): @@ -357,9 +355,9 @@ def test_no_extension(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [OpExtension()]) def test_extension(self): @@ -368,20 +366,20 @@ def test_extension(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [Rsv2Extension()]) def test_unexpected_extension(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "no extensions supported") def test_unsupported_extension(self): @@ -390,11 +388,11 @@ def test_unsupported_extension(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "Unsupported extension: name = x-op, params = [('op', None)]", @@ -406,9 +404,9 @@ def test_supported_extension_parameters(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): @@ -417,11 +415,11 @@ def test_unsupported_extension_parameters(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "Unsupported extension: name = x-op, params = [('op', 'that')]", @@ -437,9 +435,9 @@ def test_multiple_supported_extension_parameters(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [OpExtension("that")]) def test_multiple_extensions(self): @@ -450,9 +448,9 @@ def test_multiple_extensions(self): response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): @@ -463,45 +461,45 @@ def test_multiple_extensions_order(self): response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertIsNone(client.subprotocol) def test_no_subprotocol(self): client = ClientConnection("wss://example.com/", subprotocols=["chat"]) response = self.make_accept_response(client) - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertIsNone(client.subprotocol) def test_subprotocol(self): client = ClientConnection("wss://example.com/", subprotocols=["chat"]) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.subprotocol, "chat") def test_unexpected_subprotocol(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "no subprotocols supported") def test_multiple_subprotocols(self): @@ -511,11 +509,11 @@ def test_multiple_subprotocols(self): response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "superchat" response.headers["Sec-WebSocket-Protocol"] = "chat" - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "multiple subprotocols: superchat, chat" ) @@ -526,9 +524,9 @@ def test_supported_subprotocol(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" - [accept], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(accept, Accept) + self.assertEqual(client.state, OPEN) self.assertEqual(client.subprotocol, "chat") def test_unsupported_subprotocol(self): @@ -537,9 +535,9 @@ def test_unsupported_subprotocol(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "otherchat" - [reject], _bytes_to_send = client.receive_data(response.serialize()) + [response], _bytes_to_send = client.receive_data(response.serialize()) - self.assertIsInstance(reject, Reject) + self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") diff --git a/tests/test_server.py b/tests/test_server.py index 1d094a86d..8b00cec11 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,7 +4,6 @@ from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers -from websockets.events import Accept, Connect, Reject from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade from websockets.http import USER_AGENT from websockets.http11 import Request, Response @@ -23,7 +22,7 @@ class ConnectTests(unittest.TestCase): def test_receive_connect(self): server = ServerConnection() - [connect], bytes_to_send = server.receive_data( + [request], bytes_to_send = server.receive_data( ( f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" @@ -35,12 +34,12 @@ def test_receive_connect(self): f"\r\n" ).encode(), ) - self.assertIsInstance(connect, Connect) - self.assertEqual(bytes_to_send, b"") + self.assertIsInstance(request, Request) + self.assertEqual(bytes_to_send, []) def test_connect_request(self): server = ServerConnection() - [connect], bytes_to_send = server.receive_data( + [request], bytes_to_send = server.receive_data( ( f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" @@ -52,9 +51,9 @@ def test_connect_request(self): f"\r\n" ).encode(), ) - self.assertEqual(connect.request.path, "/test") + self.assertEqual(request.path, "/test") self.assertEqual( - connect.request.headers, + request.headers, Headers( { "Host": "example.com", @@ -69,7 +68,7 @@ def test_connect_request(self): class AcceptRejectTests(unittest.TestCase): - def make_connect_request(self): + def make_request(self): return Request( path="/test", headers=Headers( @@ -87,9 +86,9 @@ def make_connect_request(self): def test_send_accept(self): server = ServerConnection() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - accept = server.accept(Connect(self.make_connect_request())) - self.assertIsInstance(accept, Accept) - bytes_to_send = server.send(accept) + response = server.accept(self.make_request()) + self.assertIsInstance(response, Response) + bytes_to_send = server.send_response(response) self.assertEqual( bytes_to_send, ( @@ -107,9 +106,9 @@ def test_send_accept(self): def test_send_reject(self): server = ServerConnection() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - reject = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") - self.assertIsInstance(reject, Reject) - bytes_to_send = server.send(reject) + response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + self.assertIsInstance(response, Response) + bytes_to_send = server.send_response(response) self.assertEqual( bytes_to_send, ( @@ -128,12 +127,12 @@ def test_send_reject(self): def test_accept_response(self): server = ServerConnection() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - accept = server.accept(Connect(self.make_connect_request())) - self.assertIsInstance(accept.response, Response) - self.assertEqual(accept.response.status_code, 101) - self.assertEqual(accept.response.reason_phrase, "Switching Protocols") + response = server.accept(self.make_request()) + self.assertIsInstance(response, Response) + self.assertEqual(response.status_code, 101) + self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual( - accept.response.headers, + response.headers, Headers( { "Upgrade": "websocket", @@ -144,17 +143,17 @@ def test_accept_response(self): } ), ) - self.assertIsNone(accept.response.body) + self.assertIsNone(response.body) def test_reject_response(self): server = ServerConnection() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - reject = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") - self.assertIsInstance(reject.response, Response) - self.assertEqual(reject.response.status_code, 404) - self.assertEqual(reject.response.reason_phrase, "Not Found") + response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + self.assertIsInstance(response, Response) + self.assertEqual(response.status_code, 404) + self.assertEqual(response.reason_phrase, "Not Found") self.assertEqual( - reject.response.headers, + response.headers, Headers( { "Date": DATE, @@ -165,106 +164,99 @@ def test_reject_response(self): } ), ) - self.assertEqual(reject.response.body, b"Sorry folks.\n") + self.assertEqual(response.body, b"Sorry folks.\n") def test_basic(self): server = ServerConnection() - request = self.make_connect_request() - accept = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(accept, Accept) + self.assertEqual(response.status_code, 101) def test_unexpected_exception(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() with unittest.mock.patch( "websockets.server.ServerConnection.process_request", side_effect=Exception("BOOM"), ): - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 500) + self.assertEqual(response.status_code, 500) with self.assertRaises(Exception) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "BOOM") def test_missing_connection(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Connection"] - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 426) - self.assertEqual(reject.response.headers["Upgrade"], "websocket") + self.assertEqual(response.status_code, 426) + self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Connection"] request.headers["Connection"] = "close" - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 426) - self.assertEqual(reject.response.headers["Upgrade"], "websocket") + self.assertEqual(response.status_code, 426) + self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Upgrade"] - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 426) - self.assertEqual(reject.response.headers["Upgrade"], "websocket") + self.assertEqual(response.status_code, 426) + self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Upgrade"] request.headers["Upgrade"] = "h2c" - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 426) - self.assertEqual(reject.response.headers["Upgrade"], "websocket") + self.assertEqual(response.status_code, 426) + self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_key(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Sec-WebSocket-Key"] - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 400) + self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") def test_multiple_key(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Key"] = KEY - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 400) + self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: " @@ -273,58 +265,54 @@ def test_multiple_key(self): def test_invalid_key(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Sec-WebSocket-Key"] request.headers["Sec-WebSocket-Key"] = "not Base64 data!" - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 400) + self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" ) def test_truncated_key(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Sec-WebSocket-Key"] request.headers["Sec-WebSocket-Key"] = KEY[ :16 ] # 12 bytes instead of 16, Base64-encoded - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 400) + self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" ) def test_missing_version(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Sec-WebSocket-Version"] - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 400) + self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") def test_multiple_version(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Version"] = "11" - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 400) + self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: " @@ -333,49 +321,46 @@ def test_multiple_version(self): def test_invalid_version(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() del request.headers["Sec-WebSocket-Version"] request.headers["Sec-WebSocket-Version"] = "11" - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 400) + self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: 11" ) def test_no_origin(self): server = ServerConnection(origins=["https://example.com"]) - request = self.make_connect_request() - reject = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 403) + self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise reject.exception + raise response.exception self.assertEqual(str(raised.exception), "missing Origin header") def test_origin(self): server = ServerConnection(origins=["https://example.com"]) - request = self.make_connect_request() + request = self.make_request() request.headers["Origin"] = "https://example.com" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) + self.assertEqual(response.status_code, 101) self.assertEqual(server.origin, "https://example.com") def test_unexpected_origin(self): server = ServerConnection(origins=["https://example.com"]) - request = self.make_connect_request() + request = self.make_request() request.headers["Origin"] = "https://other.example.com" - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 403) + self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "invalid Origin header: https://other.example.com" ) @@ -384,17 +369,16 @@ def test_multiple_origin(self): server = ServerConnection( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_connect_request() + request = self.make_request() request.headers["Origin"] = "https://example.com" request.headers["Origin"] = "https://other.example.com" - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) # This is prohibited by the HTTP specification, so the return code is # 400 Bad Request rather than 403 Forbidden. - self.assertEqual(reject.response.status_code, 400) + self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "invalid Origin header: more than one Origin header found", @@ -404,107 +388,102 @@ def test_supported_origin(self): server = ServerConnection( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_connect_request() + request = self.make_request() request.headers["Origin"] = "https://other.example.com" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) + self.assertEqual(response.status_code, 101) self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): server = ServerConnection( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_connect_request() + request = self.make_request() request.headers["Origin"] = "https://original.example.com" - reject = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(reject, Reject) - self.assertEqual(reject.response.status_code, 403) + self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise reject.exception + raise response.exception self.assertEqual( str(raised.exception), "invalid Origin header: https://original.example.com" ) def test_no_origin_accepted(self): server = ServerConnection(origins=[None]) - request = self.make_connect_request() - accept = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(accept, Accept) + self.assertEqual(response.status_code, 101) self.assertIsNone(server.origin) def test_no_extensions(self): server = ServerConnection() - request = self.make_connect_request() - accept = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_no_extension(self): server = ServerConnection(extensions=[ServerOpExtensionFactory()]) - request = self.make_connect_request() - accept = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_extension(self): server = ServerConnection(extensions=[ServerOpExtensionFactory()]) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertEqual( - accept.response.headers["Sec-WebSocket-Extensions"], "x-op; op" - ) + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op") self.assertEqual(server.extensions, [OpExtension()]) def test_unexpected_extension(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_unsupported_extension(self): server = ServerConnection(extensions=[ServerRsv2ExtensionFactory()]) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_supported_extension_parameters(self): server = ServerConnection(extensions=[ServerOpExtensionFactory("this")]) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertEqual( - accept.response.headers["Sec-WebSocket-Extensions"], "x-op; op=this" - ) + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=this") self.assertEqual(server.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): server = ServerConnection(extensions=[ServerOpExtensionFactory("this")]) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Extensions", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_multiple_supported_extension_parameters(self): @@ -514,28 +493,26 @@ def test_multiple_supported_extension_parameters(self): ServerOpExtensionFactory("that"), ] ) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertEqual( - accept.response.headers["Sec-WebSocket-Extensions"], "x-op; op=that" - ) + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=that") self.assertEqual(server.extensions, [OpExtension("that")]) def test_multiple_extensions(self): server = ServerConnection( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) + self.assertEqual(response.status_code, 101) self.assertEqual( - accept.response.headers["Sec-WebSocket-Extensions"], "x-op; op, x-rsv2" + response.headers["Sec-WebSocket-Extensions"], "x-op; op, x-rsv2" ) self.assertEqual(server.extensions, [OpExtension(), Rsv2Extension()]) @@ -543,84 +520,84 @@ def test_multiple_extensions_order(self): server = ServerConnection( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" request.headers["Sec-WebSocket-Extensions"] = "x-op; op" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) + self.assertEqual(response.status_code, 101) self.assertEqual( - accept.response.headers["Sec-WebSocket-Extensions"], "x-rsv2, x-op; op" + response.headers["Sec-WebSocket-Extensions"], "x-rsv2, x-op; op" ) self.assertEqual(server.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): server = ServerConnection() - request = self.make_connect_request() - accept = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Protocol", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_no_subprotocol(self): server = ServerConnection(subprotocols=["chat"]) - request = self.make_connect_request() - accept = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Protocol", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_subprotocol(self): server = ServerConnection(subprotocols=["chat"]) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertEqual(accept.response.headers["Sec-WebSocket-Protocol"], "chat") + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_unexpected_subprotocol(self): server = ServerConnection() - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Protocol", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_multiple_subprotocols(self): server = ServerConnection(subprotocols=["superchat", "chat"]) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "superchat" request.headers["Sec-WebSocket-Protocol"] = "chat" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertEqual(accept.response.headers["Sec-WebSocket-Protocol"], "superchat") + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "superchat") self.assertEqual(server.subprotocol, "superchat") def test_supported_subprotocol(self): server = ServerConnection(subprotocols=["superchat", "chat"]) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertEqual(accept.response.headers["Sec-WebSocket-Protocol"], "chat") + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_unsupported_subprotocol(self): server = ServerConnection(subprotocols=["superchat", "chat"]) - request = self.make_connect_request() + request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" - accept = server.accept(Connect(request)) + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertNotIn("Sec-WebSocket-Protocol", accept.response.headers) + self.assertEqual(response.status_code, 101) + self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_extra_headers(self): @@ -634,16 +611,16 @@ def test_extra_headers(self): ]: with self.subTest(extra_headers=extra_headers): server = ServerConnection(extra_headers=extra_headers) - request = self.make_connect_request() - accept = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertEqual(accept.response.headers["X-Spam"], "Eggs") + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["X-Spam"], "Eggs") def test_extra_headers_overrides_server(self): server = ServerConnection(extra_headers={"Server": "Other"}) - request = self.make_connect_request() - accept = server.accept(Connect(request)) + request = self.make_request() + response = server.accept(request) - self.assertIsInstance(accept, Accept) - self.assertEqual(accept.response.headers["Server"], "Other") + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Server"], "Other") From f9177126eb6a6266c58345714ba75fdffd428802 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 11:50:50 +0200 Subject: [PATCH 0704/1539] Change Sans I/O model to handle exceptions. In the new model, receive_data returns nothing and raises an exception on errors. Events received and bytes to send are obtained through other method calls. --- src/websockets/client.py | 16 +++--- src/websockets/connection.py | 48 ++++++++++++++--- src/websockets/server.py | 23 ++++++--- tests/test_client.py | 99 +++++++++++++++++++++++------------- tests/test_server.py | 27 +++++----- 5 files changed, 143 insertions(+), 70 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 50203f27c..d6250c7e9 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,6 +1,6 @@ import collections import logging -from typing import Generator, List, Optional, Sequence +from typing import Any, Generator, List, Optional, Sequence from .asyncio_client import WebSocketClientProtocol, connect, unix_connect from .connection import CLIENT, CONNECTING, OPEN, Connection @@ -47,9 +47,6 @@ class ClientConnection(Connection): - - side = CLIENT - def __init__( self, uri: str, @@ -57,8 +54,9 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + **kwargs: Any, ): - super().__init__(state=CONNECTING) + super().__init__(side=CLIENT, state=CONNECTING, **kwargs) self.wsuri = parse_uri(uri) self.origin = origin self.available_extensions = extensions @@ -271,15 +269,15 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: return subprotocol - def send_request(self, request: Request) -> bytes: + def send_request(self, request: Request) -> None: """ - Convert a WebSocket handshake request to bytes to send to the server. + Send a WebSocket handshake request to the server. """ logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) logger.debug("%s > %r", self.side, request.headers) - return request.serialize() + self.writes.append(request.serialize()) def parse(self) -> Generator[None, None, None]: response = yield from Response.parse( @@ -292,7 +290,7 @@ def parse(self) -> Generator[None, None, None]: response = response._replace(exception=exc) logger.debug("Invalid handshake", exc_info=True) else: - self.state = OPEN + self.set_state(OPEN) finally: self.events.append(response) yield from super().parse() diff --git a/src/websockets/connection.py b/src/websockets/connection.py index ac9aedd6b..616f2b3c2 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,5 +1,5 @@ import enum -from typing import Generator, List, Tuple, Union +from typing import Any, Generator, List, Tuple, Union from .exceptions import InvalidState from .frames import Frame @@ -41,22 +41,27 @@ class Connection: side: Side - def __init__(self, state: State = OPEN) -> None: + def __init__(self, side: Side, state: State = OPEN, **kwargs: Any) -> None: + self.side = side self.state = state self.reader = StreamReader() self.events: List[Event] = [] + self.writes: List[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine + def set_state(self, state: State) -> None: + self.state = state + # Public APIs for receiving data and producing events - def receive_data(self, data: bytes) -> Tuple[List[Event], List[bytes]]: + def receive_data(self, data: bytes) -> None: self.reader.feed_data(data) - return self.receive() + self.step_parser() - def receive_eof(self) -> Tuple[List[Event], List[bytes]]: + def receive_eof(self) -> None: self.reader.feed_eof() - return self.receive() + self.step_parser() # Public APIs for receiving events and producing data @@ -72,6 +77,34 @@ def send_frame(self, frame: Frame) -> bytes: ) raise NotImplementedError # not implemented yet + # Public API for getting incoming events after receiving data. + + def events_received(self) -> List[Event]: + """ + Return events read from the connection. + + Call this method immediately after calling any of the ``receive_*()`` + methods and process the events. + + """ + events, self.events = self.events, [] + return events + + # Public API for getting outgoing data after receiving data or sending events. + + def bytes_to_send(self) -> List[bytes]: + """ + Return data to write to the connection. + + Call this method immediately after calling any of the ``receive_*()`` + or ``send_*()`` methods and write the data to the connection. + + The empty bytestring signals the end of the data stream. + + """ + writes, self.writes = self.writes, [] + return writes + # Private APIs def receive(self) -> Tuple[List[Event], List[bytes]]: @@ -83,5 +116,8 @@ def receive(self) -> Tuple[List[Event], List[bytes]]: events, self.events = self.events, [] return events, [] + def step_parser(self) -> None: + next(self.parser) + def parse(self) -> Generator[None, None, None]: yield # not implemented yet diff --git a/src/websockets/server.py b/src/websockets/server.py index 095d9a17d..73156b33f 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,7 +4,17 @@ import email.utils import http import logging -from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast +from typing import ( + Any, + Callable, + Generator, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) from .asyncio_server import WebSocketServer, WebSocketServerProtocol, serve, unix_serve from .connection import CONNECTING, OPEN, SERVER, Connection @@ -61,8 +71,9 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, + **kwargs: Any, ): - super().__init__(state=CONNECTING) + super().__init__(SERVER, CONNECTING, **kwargs) self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols @@ -403,13 +414,13 @@ def reject( headers.setdefault("Connection", "close") return Response(status.value, status.phrase, headers, body) - def send_response(self, response: Response) -> bytes: + def send_response(self, response: Response) -> None: """ - Convert a WebSocket handshake response to bytes to send to the client. + Send a WebSocket handshake response to the client. """ if response.status_code == 101: - self.state = OPEN + self.set_state(OPEN) logger.debug( "%s > HTTP/1.1 %d %s", @@ -421,7 +432,7 @@ def send_response(self, response: Response) -> bytes: if response.body is not None: logger.debug("%s > body (%d bytes)", self.side, len(response.body)) - return response.serialize() + self.writes.append(response.serialize()) def parse(self) -> Generator[None, None, None]: request = yield from Request.parse(self.reader.read_line) diff --git a/tests/test_client.py b/tests/test_client.py index eef8eb13e..7a78ee09b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -25,10 +25,10 @@ def test_send_connect(self): client = ClientConnection("wss://example.com/test") request = client.connect() self.assertIsInstance(request, Request) - bytes_to_send = client.send_request(request) + client.send_request(request) self.assertEqual( - bytes_to_send, - ( + client.bytes_to_send(), + [ f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" f"Upgrade: websocket\r\n" @@ -36,8 +36,8 @@ def test_send_connect(self): f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" f"User-Agent: {USER_AGENT}\r\n" - f"\r\n" - ).encode(), + f"\r\n".encode() + ], ) def test_connect_request(self): @@ -134,7 +134,7 @@ def test_receive_accept(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("ws://example.com/test") client.connect() - [response], bytes_to_send = client.receive_data( + client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" f"Upgrade: websocket\r\n" @@ -145,15 +145,15 @@ def test_receive_accept(self): f"\r\n" ).encode(), ) + [response] = client.events_received() self.assertIsInstance(response, Response) - self.assertEqual(bytes_to_send, []) self.assertEqual(client.state, OPEN) def test_receive_reject(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("ws://example.com/test") client.connect() - [response], bytes_to_send = client.receive_data( + client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" @@ -165,15 +165,15 @@ def test_receive_reject(self): f"Sorry folks.\n" ).encode(), ) + [response] = client.events_received() self.assertIsInstance(response, Response) - self.assertEqual(bytes_to_send, []) self.assertEqual(client.state, CONNECTING) def test_accept_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("ws://example.com/test") client.connect() - [response], _bytes_to_send = client.receive_data( + client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" f"Upgrade: websocket\r\n" @@ -184,6 +184,7 @@ def test_accept_response(self): f"\r\n" ).encode(), ) + [response] = client.events_received() self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual( @@ -204,7 +205,7 @@ def test_reject_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): client = ClientConnection("ws://example.com/test") client.connect() - [response], _bytes_to_send = client.receive_data( + client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" @@ -216,6 +217,7 @@ def test_reject_response(self): f"Sorry folks.\n" ).encode(), ) + [response] = client.events_received() self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") self.assertEqual( @@ -251,7 +253,8 @@ def make_accept_response(self, client): def test_basic(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) @@ -259,7 +262,8 @@ def test_missing_connection(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) del response.headers["Connection"] - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: @@ -271,7 +275,8 @@ def test_invalid_connection(self): response = self.make_accept_response(client) del response.headers["Connection"] response.headers["Connection"] = "close" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: @@ -282,7 +287,8 @@ def test_missing_upgrade(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) del response.headers["Upgrade"] - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: @@ -294,7 +300,8 @@ def test_invalid_upgrade(self): response = self.make_accept_response(client) del response.headers["Upgrade"] response.headers["Upgrade"] = "h2c" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: @@ -305,7 +312,8 @@ def test_missing_accept(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: @@ -316,7 +324,8 @@ def test_multiple_accept(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) response.headers["Sec-WebSocket-Accept"] = ACCEPT - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: @@ -332,7 +341,8 @@ def test_invalid_accept(self): response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] response.headers["Sec-WebSocket-Accept"] = ACCEPT - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: @@ -344,7 +354,8 @@ def test_invalid_accept(self): def test_no_extensions(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, []) @@ -355,7 +366,8 @@ def test_no_extension(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [OpExtension()]) @@ -366,7 +378,8 @@ def test_extension(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [Rsv2Extension()]) @@ -375,7 +388,8 @@ def test_unexpected_extension(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: @@ -388,7 +402,8 @@ def test_unsupported_extension(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: @@ -404,7 +419,8 @@ def test_supported_extension_parameters(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [OpExtension("this")]) @@ -415,7 +431,8 @@ def test_unsupported_extension_parameters(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: @@ -435,7 +452,8 @@ def test_multiple_supported_extension_parameters(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [OpExtension("that")]) @@ -448,7 +466,8 @@ def test_multiple_extensions(self): response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) @@ -461,7 +480,8 @@ def test_multiple_extensions_order(self): response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) @@ -469,7 +489,8 @@ def test_multiple_extensions_order(self): def test_no_subprotocols(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertIsNone(client.subprotocol) @@ -477,7 +498,8 @@ def test_no_subprotocols(self): def test_no_subprotocol(self): client = ClientConnection("wss://example.com/", subprotocols=["chat"]) response = self.make_accept_response(client) - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertIsNone(client.subprotocol) @@ -486,7 +508,8 @@ def test_subprotocol(self): client = ClientConnection("wss://example.com/", subprotocols=["chat"]) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.subprotocol, "chat") @@ -495,7 +518,8 @@ def test_unexpected_subprotocol(self): client = ClientConnection("wss://example.com/") response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: @@ -509,7 +533,8 @@ def test_multiple_subprotocols(self): response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "superchat" response.headers["Sec-WebSocket-Protocol"] = "chat" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: @@ -524,7 +549,8 @@ def test_supported_subprotocol(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, OPEN) self.assertEqual(client.subprotocol, "chat") @@ -535,7 +561,8 @@ def test_unsupported_subprotocol(self): ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "otherchat" - [response], _bytes_to_send = client.receive_data(response.serialize()) + client.receive_data(response.serialize()) + [response] = client.events_received() self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: diff --git a/tests/test_server.py b/tests/test_server.py index 8b00cec11..a180b08e2 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -22,7 +22,7 @@ class ConnectTests(unittest.TestCase): def test_receive_connect(self): server = ServerConnection() - [request], bytes_to_send = server.receive_data( + server.receive_data( ( f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" @@ -34,12 +34,12 @@ def test_receive_connect(self): f"\r\n" ).encode(), ) + [request] = server.events_received() self.assertIsInstance(request, Request) - self.assertEqual(bytes_to_send, []) def test_connect_request(self): server = ServerConnection() - [request], bytes_to_send = server.receive_data( + server.receive_data( ( f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" @@ -51,6 +51,7 @@ def test_connect_request(self): f"\r\n" ).encode(), ) + [request] = server.events_received() self.assertEqual(request.path, "/test") self.assertEqual( request.headers, @@ -88,18 +89,18 @@ def test_send_accept(self): with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.accept(self.make_request()) self.assertIsInstance(response, Response) - bytes_to_send = server.send_response(response) + server.send_response(response) self.assertEqual( - bytes_to_send, - ( + server.bytes_to_send(), + [ f"HTTP/1.1 101 Switching Protocols\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" f"Date: {DATE}\r\n" f"Server: {USER_AGENT}\r\n" - f"\r\n" - ).encode(), + f"\r\n".encode() + ], ) self.assertEqual(server.state, OPEN) @@ -108,10 +109,10 @@ def test_send_reject(self): with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") self.assertIsInstance(response, Response) - bytes_to_send = server.send_response(response) + server.send_response(response) self.assertEqual( - bytes_to_send, - ( + server.bytes_to_send(), + [ f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" f"Server: {USER_AGENT}\r\n" @@ -119,8 +120,8 @@ def test_send_reject(self): f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" f"\r\n" - f"Sorry folks.\n" - ).encode(), + f"Sorry folks.\n".encode() + ], ) self.assertEqual(server.state, CONNECTING) From 207407404d2a1bfd95da040f3948892cbf17c950 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 13:30:59 +0200 Subject: [PATCH 0705/1539] Implement Sans-I/O data transfer. --- setup.cfg | 2 +- src/websockets/client.py | 6 +- src/websockets/connection.py | 337 +++- src/websockets/exceptions.py | 2 +- .../extensions/permessage_deflate.py | 4 +- src/websockets/frames.py | 41 +- src/websockets/framing.py | 12 +- src/websockets/protocol.py | 3 +- src/websockets/server.py | 16 +- tests/extensions/test_permessage_deflate.py | 2 +- tests/test_connection.py | 1418 +++++++++++++++++ tests/test_frames.py | 115 +- 12 files changed, 1847 insertions(+), 111 deletions(-) create mode 100644 tests/test_connection.py diff --git a/setup.cfg b/setup.cfg index 02e70cdf5..5448b0f9b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,7 +5,7 @@ python-tag = py36.py37 license_file = LICENSE [flake8] -ignore = E731,F403,F405,W503 +ignore = E203,E731,F403,F405,W503 max-line-length = 88 [isort] diff --git a/src/websockets/client.py b/src/websockets/client.py index d6250c7e9..3f9777b94 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,6 +1,6 @@ import collections import logging -from typing import Any, Generator, List, Optional, Sequence +from typing import Generator, List, Optional, Sequence from .asyncio_client import WebSocketClientProtocol, connect, unix_connect from .connection import CLIENT, CONNECTING, OPEN, Connection @@ -54,9 +54,9 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, - **kwargs: Any, + max_size: Optional[int] = 2 ** 20, ): - super().__init__(side=CLIENT, state=CONNECTING, **kwargs) + super().__init__(side=CLIENT, state=CONNECTING, max_size=max_size) self.wsuri = parse_uri(uri) self.origin = origin self.available_extensions = extensions diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 616f2b3c2..ac30802db 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,14 +1,33 @@ import enum -from typing import Any, Generator, List, Tuple, Union - -from .exceptions import InvalidState -from .frames import Frame +import logging +from typing import Generator, List, Optional, Union + +from .exceptions import InvalidState, PayloadTooBig, ProtocolError +from .extensions.base import Extension +from .frames import ( + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Frame, + parse_close, + serialize_close, +) from .http11 import Request, Response from .streams import StreamReader +from .typing import Origin, Subprotocol -__all__ = ["Connection"] +__all__ = [ + "Connection", + "Side", + "State", + "SEND_EOF", +] +logger = logging.getLogger(__name__) Event = Union[Request, Response, Frame] @@ -37,45 +56,159 @@ class State(enum.IntEnum): CLOSED = State.CLOSED -class Connection: +# Sentinel to signal that the connection should be closed. - side: Side +SEND_EOF = b"" - def __init__(self, side: Side, state: State = OPEN, **kwargs: Any) -> None: + +class Connection: + def __init__( + self, side: Side, state: State = OPEN, max_size: Optional[int] = 2 ** 20, + ) -> None: + # Connection side. CLIENT or SERVER. self.side = side + + # Connnection state. CONNECTING and CLOSED states are handled in subclasses. + logger.debug("%s - initial state: %s", self.side, state.name) self.state = state + + # Maximum size of incoming messages in bytes. + self.max_size = max_size + + # Current size of incoming message in bytes. Only set while reading a + # fragmented message i.e. a data frames with the FIN bit not set. + self.cur_size: Optional[int] = None + + # True while sending a fragmented message i.e. a data frames with the + # FIN bit not set. + self.expect_continuation_frame = False + + # WebSocket protocol parameters. + self.origin: Optional[Origin] = None + self.extensions: List[Extension] = [] + self.subprotocol: Optional[Subprotocol] = None + + # Connection state isn't enough to tell if a close frame was received: + # when this side closes the connection, state is CLOSING as soon as a + # close frame is sent, before a close frame is received. + self.close_frame_received = False + + # Close code and reason. Set when receiving a close frame or when the + # TCP connection drops. + self.close_code: int + self.close_reason: str + + # Track if send_eof() was called. + self.eof_sent = False + + # Parser state. self.reader = StreamReader() self.events: List[Event] = [] self.writes: List[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine + self.parser_exc: Optional[Exception] = None def set_state(self, state: State) -> None: + logger.debug( + "%s - state change: %s > %s", self.side, self.state.name, state.name + ) self.state = state - # Public APIs for receiving data and producing events + # Public APIs for receiving data. def receive_data(self, data: bytes) -> None: + """ + Receive data from the connection. + + After calling this method: + + - You must call :meth:`bytes_to_send` and send this data. + - You should call :meth:`events_received` and process these events. + + """ self.reader.feed_data(data) self.step_parser() def receive_eof(self) -> None: + """ + Receive the end of the data stream from the connection. + + After calling this method: + + - You must call :meth:`bytes_to_send` and send this data. + - You shouldn't call :meth:`events_received` as it won't + return any new events. + + """ self.reader.feed_eof() self.step_parser() - # Public APIs for receiving events and producing data + # Public APIs for sending events. - def send_frame(self, frame: Frame) -> bytes: + def send_continuation(self, data: bytes, fin: bool) -> None: """ - Convert a WebSocket handshake response to bytes to send. + Send a continuation frame. """ - # Defensive assertion for protocol compliance. - if self.state != OPEN: - raise InvalidState( - f"Cannot write to a WebSocket in the {self.state.name} state" - ) - raise NotImplementedError # not implemented yet + if not self.expect_continuation_frame: + raise ProtocolError("unexpected continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(fin, OP_CONT, data)) + + def send_text(self, data: bytes, fin: bool = True) -> None: + """ + Send a text frame. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(fin, OP_TEXT, data)) + + def send_binary(self, data: bytes, fin: bool = True) -> None: + """ + Send a binary frame. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(fin, OP_BINARY, data)) + + def send_close(self, code: Optional[int] = None, reason: str = "") -> None: + """ + Send a connection close frame. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + if code is None: + if reason != "": + raise ValueError("cannot send a reason without a code") + data = b"" + else: + data = serialize_close(code, reason) + self.send_frame(Frame(True, OP_CLOSE, data)) + # send_frame() guarantees that self.state is OPEN at this point. + # 7.1.3. The WebSocket Closing Handshake is Started + self.set_state(CLOSING) + if self.side is SERVER: + self.send_eof() + + def send_ping(self, data: bytes) -> None: + """ + Send a ping frame. + + """ + self.send_frame(Frame(True, OP_PING, data)) + + def send_pong(self, data: bytes) -> None: + """ + Send a pong frame. + + """ + self.send_frame(Frame(True, OP_PONG, data)) # Public API for getting incoming events after receiving data. @@ -105,19 +238,169 @@ def bytes_to_send(self) -> List[bytes]: writes, self.writes = self.writes, [] return writes - # Private APIs + # Private APIs for receiving data. - def receive(self) -> Tuple[List[Event], List[bytes]]: + def fail_connection(self, code: int = 1006, reason: str = "") -> None: + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because of + # an error reading from or writing to the network. + if code != 1006 and self.state is OPEN: + self.send_frame(Frame(True, OP_CLOSE, serialize_close(code, reason))) + self.set_state(CLOSING) + if not self.eof_sent: + self.send_eof() + + def step_parser(self) -> None: # Run parser until more data is needed or EOF try: next(self.parser) except StopIteration: - pass - events, self.events = self.events, [] - return events, [] - - def step_parser(self) -> None: - next(self.parser) + # This happens if receive_data() or receive_eof() is called after + # the parser raised an exception. (It cannot happen after reaching + # EOF because receive_data() or receive_eof() would fail earlier.) + assert self.parser_exc is not None + raise RuntimeError( + "cannot receive data or EOF after an error" + ) from self.parser_exc + except ProtocolError as exc: + self.fail_connection(1002, str(exc)) + self.parser_exc = exc + raise + except EOFError as exc: + self.fail_connection(1006, str(exc)) + self.parser_exc = exc + raise + except UnicodeDecodeError as exc: + self.fail_connection(1007, f"{exc.reason} at position {exc.start}") + self.parser_exc = exc + raise + except PayloadTooBig as exc: + self.fail_connection(1009, str(exc)) + self.parser_exc = exc + raise + except Exception as exc: + logger.exception("unexpected exception in parser") + # Don't include exception details, which may be security-sensitive. + self.fail_connection(1011) + self.parser_exc = exc + raise def parse(self) -> Generator[None, None, None]: - yield # not implemented yet + while True: + eof = yield from self.reader.at_eof() + if eof: + if self.close_frame_received: + if not self.eof_sent: + self.send_eof() + yield + # Once the reader reaches EOF, its feed_data/eof() methods + # raise an error, so our receive_data/eof() methods never + # call step_parser(), so the generator shouldn't resume + # executing until it's garbage collected. + raise AssertionError( + "parser shouldn't step after EOF" + ) # pragma: no cover + else: + raise EOFError("unexpected end of stream") + + if self.max_size is None: + max_size = None + elif self.cur_size is None: + max_size = self.max_size + else: + max_size = self.max_size - self.cur_size + + frame = yield from Frame.parse( + self.reader.read_exact, + mask=self.side is SERVER, + max_size=max_size, + extensions=self.extensions, + ) + + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: + # 5.5.1 Close: "The application MUST NOT send any more data + # frames after sending a Close frame." + if self.close_frame_received: + raise ProtocolError("data frame after close frame") + + if self.cur_size is not None: + raise ProtocolError("expected a continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size = len(frame.data) + + elif frame.opcode is OP_CONT: + # 5.5.1 Close: "The application MUST NOT send any more data + # frames after sending a Close frame." + if self.close_frame_received: + raise ProtocolError("data frame after close frame") + + if self.cur_size is None: + raise ProtocolError("unexpected continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size += len(frame.data) + + elif frame.opcode is OP_PING: + # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST + # send a Pong frame in response, unless it already received a + # Close frame." + if not self.close_frame_received: + pong_frame = Frame(True, OP_PONG, frame.data) + self.send_frame(pong_frame) + + elif frame.opcode is OP_PONG: + # 5.5.3 Pong: "A response to an unsolicited Pong frame is not + # expected." + pass + + elif frame.opcode is OP_CLOSE: + self.close_frame_received = True + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_code, self.close_reason = parse_close(frame.data) + + if self.cur_size is not None: + raise ProtocolError("incomplete fragmented message") + # 5.5.1 Close: "If an endpoint receives a Close frame and did + # not previously send a Close frame, the endpoint MUST send a + # Close frame in response. (When sending a Close frame in + # response, the endpoint typically echos the status code it + # received.)" + if self.state is OPEN: + # Echo the original data instead of re-serializing it with + # serialize_close() because that fails when the close frame + # is empty and parse_close() synthetizes a 1005 close code. + # The rest is identical to send_close(). + self.send_frame(Frame(True, OP_CLOSE, frame.data)) + self.set_state(CLOSING) + if self.side is SERVER: + self.send_eof() + + else: # pragma: no cover + # This can't happen because Frame.parse() validates opcodes. + raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") + + self.events.append(frame) + + # Private APIs for sending events. + + def send_frame(self, frame: Frame) -> None: + # Defensive assertion for protocol compliance. + if self.state is not OPEN: + raise InvalidState( + f"cannot write to a WebSocket in the {self.state.name} state" + ) + + logger.debug("%s > %r", self.side, frame) + self.writes.append( + frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) + ) + + def send_eof(self) -> None: + assert not self.eof_sent + self.eof_sent = True + logger.debug("%s > EOF", self.side) + self.writes.append(SEND_EOF) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index e593f1adc..c60a3e10e 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -358,7 +358,7 @@ class PayloadTooBig(WebSocketException): class ProtocolError(WebSocketException): """ - Raised when the other side breaks the protocol. + Raised when a frame breaks the protocol. """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index f1adf8bb6..184183061 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -128,9 +128,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: max_length = 0 if max_size is None else max_size data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: - raise PayloadTooBig( - f"Uncompressed payload length exceeds size limit (? > {max_size} bytes)" - ) + raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 56dcf6171..2ff9dbd91 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -3,6 +3,7 @@ """ +import enum import io import secrets import struct @@ -19,14 +20,15 @@ __all__ = [ - "DATA_OPCODES", - "CTRL_OPCODES", + "Opcode", "OP_CONT", "OP_TEXT", "OP_BINARY", "OP_CLOSE", "OP_PING", "OP_PONG", + "DATA_OPCODES", + "CTRL_OPCODES", "Frame", "prepare_data", "prepare_ctrl", @@ -34,8 +36,21 @@ "serialize_close", ] -DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02 -CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG = 0x08, 0x09, 0x0A + +class Opcode(enum.IntEnum): + CONT, TEXT, BINARY = 0x00, 0x01, 0x02 + CLOSE, PING, PONG = 0x08, 0x09, 0x0A + + +OP_CONT = Opcode.CONT +OP_TEXT = Opcode.TEXT +OP_BINARY = Opcode.BINARY +OP_CLOSE = Opcode.CLOSE +OP_PING = Opcode.PING +OP_PONG = Opcode.PONG + +DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY +CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG # Close code that are allowed in a close frame. # Using a list optimizes `code in EXTERNAL_CLOSE_CODES`. @@ -62,7 +77,7 @@ class Frame(NamedTuple): """ fin: bool - opcode: int + opcode: Opcode data: bytes rsv1: bool = False rsv2: bool = False @@ -103,7 +118,11 @@ def parse( rsv1 = True if head1 & 0b01000000 else False rsv2 = True if head1 & 0b00100000 else False rsv3 = True if head1 & 0b00010000 else False - opcode = head1 & 0b00001111 + + try: + opcode = Opcode(head1 & 0b00001111) + except ValueError as exc: + raise ProtocolError("invalid opcode") from exc if (True if head2 & 0b10000000 else False) != mask: raise ProtocolError("incorrect masking") @@ -116,9 +135,7 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig( - f"payload length exceeds size limit ({length} > {max_size} bytes)" - ) + raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") if mask: mask_bytes = yield from read_exact(4) @@ -209,15 +226,11 @@ def check(self) -> None: if self.rsv1 or self.rsv2 or self.rsv3: raise ProtocolError("reserved bits must be 0") - if self.opcode in DATA_OPCODES: - return - elif self.opcode in CTRL_OPCODES: + if self.opcode in CTRL_OPCODES: if len(self.data) > 125: raise ProtocolError("control frame too long") if not self.fin: raise ProtocolError("fragmented control frame") - else: - raise ProtocolError(f"invalid opcode: {self.opcode}") def prepare_data(data: Data) -> Tuple[int, bytes]: diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 221afad6f..b2996d788 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -15,7 +15,7 @@ from typing import Any, Awaitable, Callable, Optional, Sequence from .exceptions import PayloadTooBig, ProtocolError -from .frames import Frame as NewFrame +from .frames import Frame as NewFrame, Opcode try: @@ -64,7 +64,11 @@ async def read( rsv1 = True if head1 & 0b01000000 else False rsv2 = True if head1 & 0b00100000 else False rsv3 = True if head1 & 0b00010000 else False - opcode = head1 & 0b00001111 + + try: + opcode = Opcode(head1 & 0b00001111) + except ValueError as exc: + raise ProtocolError("invalid opcode") from exc if (True if head2 & 0b10000000 else False) != mask: raise ProtocolError("incorrect masking") @@ -77,9 +81,7 @@ async def read( data = await reader(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig( - f"payload length exceeds size limit ({length} > {max_size} bytes)" - ) + raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") if mask: mask_bits = await reader(4) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 58c4569d0..2e5d95e06 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -47,6 +47,7 @@ OP_PING, OP_PONG, OP_TEXT, + Opcode, parse_close, prepare_ctrl, prepare_data, @@ -1071,7 +1072,7 @@ async def write_frame( f"Cannot write to a WebSocket in the {self.state.name} state" ) - frame = Frame(fin, opcode, data) + frame = Frame(fin, Opcode(opcode), data) logger.debug("%s > %r", self.side, frame) frame.write( self.transport.write, mask=self.is_client, extensions=self.extensions diff --git a/src/websockets/server.py b/src/websockets/server.py index 73156b33f..1b03eabee 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,17 +4,7 @@ import email.utils import http import logging -from typing import ( - Any, - Callable, - Generator, - List, - Optional, - Sequence, - Tuple, - Union, - cast, -) +from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast from .asyncio_server import WebSocketServer, WebSocketServerProtocol, serve, unix_serve from .connection import CONNECTING, OPEN, SERVER, Connection @@ -71,9 +61,9 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, - **kwargs: Any, + max_size: Optional[int] = 2 ** 20, ): - super().__init__(SERVER, CONNECTING, **kwargs) + super().__init__(side=SERVER, state=CONNECTING, max_size=max_size) self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index e1193e672..f9fca1999 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -243,7 +243,7 @@ def test_compress_settings(self): ), ) - # Frames aren't decoded beyond max_length. + # Frames aren't decoded beyond max_size. def test_decompress_max_size(self): frame = Frame(True, OP_TEXT, ("a" * 20).encode("utf-8")) diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 000000000..5c0f7302f --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,1418 @@ +import unittest.mock + +from websockets.connection import * +from websockets.exceptions import InvalidState, PayloadTooBig, ProtocolError +from websockets.frames import ( + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Frame, + serialize_close, +) + +from .extensions.utils import Rsv2Extension +from .test_frames import FramesTestCase + + +class ConnectionTestCase(FramesTestCase): + def assertFrameSent(self, connection, frame, eof=False): + """ + Outgoing data for ``connection`` contains the given frame. + + ``frame`` may be ``None`` if no frame is expected. + + When ``eof`` is ``True``, the end of the stream is also expected. + + """ + frames_sent = [ + None + if write is SEND_EOF + else self.parse( + write, + mask=connection.side is Side.CLIENT, + extensions=connection.extensions, + ) + for write in connection.bytes_to_send() + ] + frames_expected = [] if frame is None else [frame] + if eof: + frames_expected += [None] + self.assertEqual(frames_sent, frames_expected) + + def assertFrameReceived(self, connection, frame): + """ + Incoming data for ``connection`` contains the given frame. + + ``frame`` may be ``None`` if no frame is expected. + + """ + frames_received = connection.events_received() + frames_expected = [] if frame is None else [frame] + self.assertEqual(frames_received, frames_expected) + + def assertConnectionClosing(self, connection, code=None, reason=""): + """ + Incoming data caused the "Start the WebSocket Closing Handshake" process. + + """ + close_frame = Frame( + True, OP_CLOSE, b"" if code is None else serialize_close(code, reason), + ) + # A close frame was received. + self.assertFrameReceived(connection, close_frame) + # A close frame and possibly the end of stream were sent. + self.assertFrameSent( + connection, close_frame, eof=connection.side is Side.SERVER + ) + + def assertConnectionFailing(self, connection, code=None, reason=""): + """ + Incoming data caused the "Fail the WebSocket Connection" process. + + """ + close_frame = Frame( + True, OP_CLOSE, b"" if code is None else serialize_close(code, reason), + ) + # No frame was received. + self.assertFrameReceived(connection, None) + # A close frame and the end of stream were sent. + self.assertFrameSent(connection, close_frame, eof=True) + + +class MaskingTests(ConnectionTestCase): + """ + Test frame masking. + + 5.1. Overview + + """ + + unmasked_text_frame_date = b"\x81\x04Spam" + masked_text_frame_data = b"\x81\x84\x00\xff\x00\xff\x53\x8f\x61\x92" + + def test_client_sends_masked_frame(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\xff\x00\xff"): + client.send_text(b"Spam", True) + self.assertEqual(client.bytes_to_send(), [self.masked_text_frame_data]) + + def test_server_sends_unmasked_frame(self): + server = Connection(Side.SERVER) + server.send_text(b"Spam", True) + self.assertEqual(server.bytes_to_send(), [self.unmasked_text_frame_date]) + + def test_client_receives_unmasked_frame(self): + client = Connection(Side.CLIENT) + client.receive_data(self.unmasked_text_frame_date) + self.assertFrameReceived( + client, Frame(True, OP_TEXT, b"Spam"), + ) + + def test_server_receives_masked_frame(self): + server = Connection(Side.SERVER) + server.receive_data(self.masked_text_frame_data) + self.assertFrameReceived( + server, Frame(True, OP_TEXT, b"Spam"), + ) + + def test_client_receives_masked_frame(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(self.masked_text_frame_data) + self.assertEqual(str(raised.exception), "incorrect masking") + self.assertConnectionFailing(client, 1002, "incorrect masking") + + def test_server_receives_unmasked_frame(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(self.unmasked_text_frame_date) + self.assertEqual(str(raised.exception), "incorrect masking") + self.assertConnectionFailing(server, 1002, "incorrect masking") + + +class ContinuationTests(ConnectionTestCase): + """ + Test continuation frames without text or binary frames. + + """ + + def test_client_sends_unexpected_continuation(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_server_sends_unexpected_continuation(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError) as raised: + server.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_client_receives_unexpected_continuation(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x00\x00") + self.assertEqual(str(raised.exception), "unexpected continuation frame") + self.assertConnectionFailing(client, 1002, "unexpected continuation frame") + + def test_server_receives_unexpected_continuation(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x00\x80\x00\x00\x00\x00") + self.assertEqual(str(raised.exception), "unexpected continuation frame") + self.assertConnectionFailing(server, 1002, "unexpected continuation frame") + + def test_client_sends_continuation_after_sending_close(self): + client = Connection(Side.CLIENT) + # Since it isn't possible to send a close frame in a fragmented + # message (see test_client_send_close_in_fragmented_message), in fact, + # this is the same test as test_client_sends_unexpected_continuation. + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(ProtocolError) as raised: + client.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_server_sends_continuation_after_sending_close(self): + # Since it isn't possible to send a close frame in a fragmented + # message (see test_server_send_close_in_fragmented_message), in fact, + # this is the same test as test_server_sends_unexpected_continuation. + server = Connection(Side.SERVER) + server.send_close(1000) + self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + with self.assertRaises(ProtocolError) as raised: + server.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_client_receives_continuation_after_receiving_close(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x00\x00") + self.assertEqual(str(raised.exception), "data frame after close frame") + + def test_server_receives_continuation_after_receiving_close(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x00\x80\x00\xff\x00\xff") + self.assertEqual(str(raised.exception), "data frame after close frame") + + +class TextTests(ConnectionTestCase): + """ + Test text frames and continuation frames. + + """ + + def test_client_sends_text(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_text("😀".encode()) + self.assertEqual( + client.bytes_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] + ) + + def test_server_sends_text(self): + server = Connection(Side.SERVER) + server.send_text("😀".encode()) + self.assertEqual(server.bytes_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) + + def test_client_receives_text(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertFrameReceived( + client, Frame(True, OP_TEXT, "😀".encode()), + ) + + def test_server_receives_text(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertFrameReceived( + server, Frame(True, OP_TEXT, "😀".encode()), + ) + + def test_client_receives_text_over_size_limit(self): + client = Connection(Side.CLIENT, max_size=3) + with self.assertRaises(PayloadTooBig) as raised: + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + + def test_server_receives_text_over_size_limit(self): + server = Connection(Side.SERVER, max_size=3) + with self.assertRaises(PayloadTooBig) as raised: + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + + def test_client_receives_text_without_size_limit(self): + client = Connection(Side.CLIENT, max_size=None) + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertFrameReceived( + client, Frame(True, OP_TEXT, "😀".encode()), + ) + + def test_server_receives_text_without_size_limit(self): + server = Connection(Side.SERVER, max_size=None) + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertFrameReceived( + server, Frame(True, OP_TEXT, "😀".encode()), + ) + + def test_client_sends_fragmented_text(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_text("😀".encode()[:2], fin=False) + self.assertEqual(client.bytes_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation("😀😀".encode()[2:6], fin=False) + self.assertEqual( + client.bytes_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] + ) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation("😀".encode()[2:], fin=True) + self.assertEqual(client.bytes_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) + + def test_server_sends_fragmented_text(self): + server = Connection(Side.SERVER) + server.send_text("😀".encode()[:2], fin=False) + self.assertEqual(server.bytes_to_send(), [b"\x01\x02\xf0\x9f"]) + server.send_continuation("😀😀".encode()[2:6], fin=False) + self.assertEqual(server.bytes_to_send(), [b"\x00\x04\x98\x80\xf0\x9f"]) + server.send_continuation("😀".encode()[2:], fin=True) + self.assertEqual(server.bytes_to_send(), [b"\x80\x02\x98\x80"]) + + def test_client_receives_fragmented_text(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, Frame(False, OP_TEXT, "😀".encode()[:2]), + ) + client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") + self.assertFrameReceived( + client, Frame(False, OP_CONT, "😀😀".encode()[2:6]), + ) + client.receive_data(b"\x80\x02\x98\x80") + self.assertFrameReceived( + client, Frame(True, OP_CONT, "😀".encode()[2:]), + ) + + def test_server_receives_fragmented_text(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, Frame(False, OP_TEXT, "😀".encode()[:2]), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") + self.assertFrameReceived( + server, Frame(False, OP_CONT, "😀😀".encode()[2:6]), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertFrameReceived( + server, Frame(True, OP_CONT, "😀".encode()[2:]), + ) + + def test_client_receives_fragmented_text_over_size_limit(self): + client = Connection(Side.CLIENT, max_size=3) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, Frame(False, OP_TEXT, "😀".encode()[:2]), + ) + with self.assertRaises(PayloadTooBig) as raised: + client.receive_data(b"\x80\x02\x98\x80") + self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + + def test_server_receives_fragmented_text_over_size_limit(self): + server = Connection(Side.SERVER, max_size=3) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, Frame(False, OP_TEXT, "😀".encode()[:2]), + ) + with self.assertRaises(PayloadTooBig) as raised: + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + + def test_client_receives_fragmented_text_without_size_limit(self): + client = Connection(Side.CLIENT, max_size=None) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, Frame(False, OP_TEXT, "😀".encode()[:2]), + ) + client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") + self.assertFrameReceived( + client, Frame(False, OP_CONT, "😀😀".encode()[2:6]), + ) + client.receive_data(b"\x80\x02\x98\x80") + self.assertFrameReceived( + client, Frame(True, OP_CONT, "😀".encode()[2:]), + ) + + def test_server_receives_fragmented_text_without_size_limit(self): + server = Connection(Side.SERVER, max_size=None) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, Frame(False, OP_TEXT, "😀".encode()[:2]), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") + self.assertFrameReceived( + server, Frame(False, OP_CONT, "😀😀".encode()[2:6]), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertFrameReceived( + server, Frame(True, OP_CONT, "😀".encode()[2:]), + ) + + def test_client_sends_unexpected_text(self): + client = Connection(Side.CLIENT) + client.send_text(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + client.send_text(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_server_sends_unexpected_text(self): + server = Connection(Side.SERVER) + server.send_text(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + server.send_text(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receives_unexpected_text(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x01\x00") + self.assertFrameReceived( + client, Frame(False, OP_TEXT, b""), + ) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x01\x00") + self.assertEqual(str(raised.exception), "expected a continuation frame") + self.assertConnectionFailing(client, 1002, "expected a continuation frame") + + def test_server_receives_unexpected_text(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x01\x80\x00\x00\x00\x00") + self.assertFrameReceived( + server, Frame(False, OP_TEXT, b""), + ) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x01\x80\x00\x00\x00\x00") + self.assertEqual(str(raised.exception), "expected a continuation frame") + self.assertConnectionFailing(server, 1002, "expected a continuation frame") + + def test_client_sends_text_after_sending_close(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState): + client.send_text(b"") + + def test_server_sends_text_after_sending_close(self): + server = Connection(Side.SERVER) + server.send_close(1000) + self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + with self.assertRaises(InvalidState): + server.send_text(b"") + + def test_client_receives_text_after_receiving_close(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x81\x00") + self.assertEqual(str(raised.exception), "data frame after close frame") + + def test_server_receives_text_after_receiving_close(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x81\x80\x00\xff\x00\xff") + self.assertEqual(str(raised.exception), "data frame after close frame") + + +class BinaryTests(ConnectionTestCase): + """ + Test binary frames and continuation frames. + + """ + + def test_client_sends_binary(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_binary(b"\x01\x02\xfe\xff") + self.assertEqual( + client.bytes_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] + ) + + def test_server_sends_binary(self): + server = Connection(Side.SERVER) + server.send_binary(b"\x01\x02\xfe\xff") + self.assertEqual(server.bytes_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) + + def test_client_receives_binary(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x82\x04\x01\x02\xfe\xff") + self.assertFrameReceived( + client, Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), + ) + + def test_server_receives_binary(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") + self.assertFrameReceived( + server, Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), + ) + + def test_client_receives_binary_over_size_limit(self): + client = Connection(Side.CLIENT, max_size=3) + with self.assertRaises(PayloadTooBig) as raised: + client.receive_data(b"\x82\x04\x01\x02\xfe\xff") + self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + + def test_server_receives_binary_over_size_limit(self): + server = Connection(Side.SERVER, max_size=3) + with self.assertRaises(PayloadTooBig) as raised: + server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") + self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + + def test_client_sends_fragmented_binary(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_binary(b"\x01\x02", fin=False) + self.assertEqual(client.bytes_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation(b"\xee\xff\x01\x02", fin=False) + self.assertEqual( + client.bytes_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] + ) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation(b"\xee\xff", fin=True) + self.assertEqual(client.bytes_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) + + def test_server_sends_fragmented_binary(self): + server = Connection(Side.SERVER) + server.send_binary(b"\x01\x02", fin=False) + self.assertEqual(server.bytes_to_send(), [b"\x02\x02\x01\x02"]) + server.send_continuation(b"\xee\xff\x01\x02", fin=False) + self.assertEqual(server.bytes_to_send(), [b"\x00\x04\xee\xff\x01\x02"]) + server.send_continuation(b"\xee\xff", fin=True) + self.assertEqual(server.bytes_to_send(), [b"\x80\x02\xee\xff"]) + + def test_client_receives_fragmented_binary(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x02\x02\x01\x02") + self.assertFrameReceived( + client, Frame(False, OP_BINARY, b"\x01\x02"), + ) + client.receive_data(b"\x00\x04\xfe\xff\x01\x02") + self.assertFrameReceived( + client, Frame(False, OP_CONT, b"\xfe\xff\x01\x02"), + ) + client.receive_data(b"\x80\x02\xfe\xff") + self.assertFrameReceived( + client, Frame(True, OP_CONT, b"\xfe\xff"), + ) + + def test_server_receives_fragmented_binary(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") + self.assertFrameReceived( + server, Frame(False, OP_BINARY, b"\x01\x02"), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") + self.assertFrameReceived( + server, Frame(False, OP_CONT, b"\xee\xff\x01\x02"), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") + self.assertFrameReceived( + server, Frame(True, OP_CONT, b"\xfe\xff"), + ) + + def test_client_receives_fragmented_binary_over_size_limit(self): + client = Connection(Side.CLIENT, max_size=3) + client.receive_data(b"\x02\x02\x01\x02") + self.assertFrameReceived( + client, Frame(False, OP_BINARY, b"\x01\x02"), + ) + with self.assertRaises(PayloadTooBig) as raised: + client.receive_data(b"\x80\x02\xfe\xff") + self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + + def test_server_receives_fragmented_binary_over_size_limit(self): + server = Connection(Side.SERVER, max_size=3) + server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") + self.assertFrameReceived( + server, Frame(False, OP_BINARY, b"\x01\x02"), + ) + with self.assertRaises(PayloadTooBig) as raised: + server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") + self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + + def test_client_sends_unexpected_binary(self): + client = Connection(Side.CLIENT) + client.send_binary(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + client.send_binary(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_server_sends_unexpected_binary(self): + server = Connection(Side.SERVER) + server.send_binary(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + server.send_binary(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receives_unexpected_binary(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x02\x00") + self.assertFrameReceived( + client, Frame(False, OP_BINARY, b""), + ) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x02\x00") + self.assertEqual(str(raised.exception), "expected a continuation frame") + self.assertConnectionFailing(client, 1002, "expected a continuation frame") + + def test_server_receives_unexpected_binary(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x02\x80\x00\x00\x00\x00") + self.assertFrameReceived( + server, Frame(False, OP_BINARY, b""), + ) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x02\x80\x00\x00\x00\x00") + self.assertEqual(str(raised.exception), "expected a continuation frame") + self.assertConnectionFailing(server, 1002, "expected a continuation frame") + + def test_client_sends_binary_after_sending_close(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState): + client.send_binary(b"") + + def test_server_sends_binary_after_sending_close(self): + server = Connection(Side.SERVER) + server.send_close(1000) + self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + with self.assertRaises(InvalidState): + server.send_binary(b"") + + def test_client_receives_binary_after_receiving_close(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x82\x00") + self.assertEqual(str(raised.exception), "data frame after close frame") + + def test_server_receives_binary_after_receiving_close(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x82\x80\x00\xff\x00\xff") + self.assertEqual(str(raised.exception), "data frame after close frame") + + +class CloseTests(ConnectionTestCase): + """ + Test close frames. See 5.5.1. Close in RFC 6544. + + """ + + def test_client_sends_close(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.send_close() + self.assertEqual(client.bytes_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, State.CLOSING) + + def test_server_sends_close(self): + server = Connection(Side.SERVER) + server.send_close() + self.assertEqual(server.bytes_to_send(), [b"\x88\x00", b""]) + self.assertIs(server.state, State.CLOSING) + + def test_client_receives_close(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.receive_data(b"\x88\x00") + self.assertEqual(client.events_received(), [Frame(True, OP_CLOSE, b"")]) + self.assertEqual(client.bytes_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, State.CLOSING) + + def test_server_receives_close(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertEqual(server.events_received(), [Frame(True, OP_CLOSE, b"")]) + self.assertEqual(server.bytes_to_send(), [b"\x88\x00", b""]) + self.assertIs(server.state, State.CLOSING) + + def test_client_sends_close_then_receives_close(self): + # Client-initiated close handshake on the client side. + client = Connection(Side.CLIENT) + + client.send_close() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, Frame(True, OP_CLOSE, b"")) + + client.receive_data(b"\x88\x00") + self.assertFrameReceived(client, Frame(True, OP_CLOSE, b"")) + self.assertFrameSent(client, None) + + client.receive_eof() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None, eof=True) + + def test_server_sends_close_then_receives_close(self): + # Server-initiated close handshake on the server side. + server = Connection(Side.SERVER) + + server.send_close() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, Frame(True, OP_CLOSE, b""), eof=True) + + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertFrameReceived(server, Frame(True, OP_CLOSE, b"")) + self.assertFrameSent(server, None) + + server.receive_eof() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + def test_client_receives_close_then_sends_close(self): + # Server-initiated close handshake on the client side. + client = Connection(Side.CLIENT) + + client.receive_data(b"\x88\x00") + self.assertFrameReceived(client, Frame(True, OP_CLOSE, b"")) + self.assertFrameSent(client, Frame(True, OP_CLOSE, b"")) + + client.receive_eof() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_close_then_sends_close(self): + # Client-initiated close handshake on the server side. + server = Connection(Side.SERVER) + + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertFrameReceived(server, Frame(True, OP_CLOSE, b"")) + self.assertFrameSent(server, Frame(True, OP_CLOSE, b""), eof=True) + + server.receive_eof() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + def test_client_sends_close_with_code(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertIs(client.state, State.CLOSING) + + def test_server_sends_close_with_code(self): + server = Connection(Side.SERVER) + server.send_close(1000) + self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertIs(server.state, State.CLOSING) + + def test_client_receives_close_with_code(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000, "") + self.assertIs(client.state, State.CLOSING) + + def test_server_receives_close_with_code(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001, "") + self.assertIs(server.state, State.CLOSING) + + def test_client_sends_close_with_code_and_reason(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001, "going away") + self.assertEqual( + client.bytes_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] + ) + self.assertIs(client.state, State.CLOSING) + + def test_server_sends_close_with_code_and_reason(self): + server = Connection(Side.SERVER) + server.send_close(1000, "OK") + self.assertEqual(server.bytes_to_send(), [b"\x88\x04\x03\xe8OK", b""]) + self.assertIs(server.state, State.CLOSING) + + def test_client_receives_close_with_code_and_reason(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x04\x03\xe8OK") + self.assertConnectionClosing(client, 1000, "OK") + self.assertIs(client.state, State.CLOSING) + + def test_server_receives_close_with_code_and_reason(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") + self.assertConnectionClosing(server, 1001, "going away") + self.assertIs(server.state, State.CLOSING) + + def test_client_sends_close_with_reason_only(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ValueError) as raised: + client.send_close(reason="going away") + self.assertEqual(str(raised.exception), "cannot send a reason without a code") + + def test_server_sends_close_with_reason_only(self): + server = Connection(Side.SERVER) + with self.assertRaises(ValueError) as raised: + server.send_close(reason="OK") + self.assertEqual(str(raised.exception), "cannot send a reason without a code") + + def test_client_receives_close_with_truncated_code(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x88\x01\x03") + self.assertEqual(str(raised.exception), "close frame too short") + self.assertConnectionFailing(client, 1002, "close frame too short") + self.assertIs(client.state, State.CLOSING) + + def test_server_receives_close_with_truncated_code(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") + self.assertEqual(str(raised.exception), "close frame too short") + self.assertConnectionFailing(server, 1002, "close frame too short") + self.assertIs(server.state, State.CLOSING) + + def test_client_receives_close_with_non_utf8_reason(self): + client = Connection(Side.CLIENT) + with self.assertRaises(UnicodeDecodeError) as raised: + client.receive_data(b"\x88\x04\x03\xe8\xff\xff") + self.assertEqual( + str(raised.exception), + "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", + ) + self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") + self.assertIs(client.state, State.CLOSING) + + def test_server_receives_close_with_non_utf8_reason(self): + server = Connection(Side.SERVER) + with self.assertRaises(UnicodeDecodeError) as raised: + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") + self.assertEqual( + str(raised.exception), + "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", + ) + self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") + self.assertIs(server.state, State.CLOSING) + + +class PingTests(ConnectionTestCase): + """ + Test ping. See 5.5.2. Ping in RFC 6544. + + """ + + def test_client_sends_ping(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"") + self.assertEqual(client.bytes_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) + + def test_server_sends_ping(self): + server = Connection(Side.SERVER) + server.send_ping(b"") + self.assertEqual(server.bytes_to_send(), [b"\x89\x00"]) + + def test_client_receives_ping(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x89\x00") + self.assertFrameReceived( + client, Frame(True, OP_PING, b""), + ) + self.assertFrameSent( + client, Frame(True, OP_PONG, b""), + ) + + def test_server_receives_ping(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x89\x80\x00\x44\x88\xcc") + self.assertFrameReceived( + server, Frame(True, OP_PING, b""), + ) + self.assertFrameSent( + server, Frame(True, OP_PONG, b""), + ) + + def test_client_sends_ping_with_data(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"\x22\x66\xaa\xee") + self.assertEqual( + client.bytes_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] + ) + + def test_server_sends_ping_with_data(self): + server = Connection(Side.SERVER) + server.send_ping(b"\x22\x66\xaa\xee") + self.assertEqual(server.bytes_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) + + def test_client_receives_ping_with_data(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + self.assertFrameReceived( + client, Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + ) + self.assertFrameSent( + client, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_server_receives_ping_with_data(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived( + server, Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + ) + self.assertFrameSent( + server, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_client_sends_fragmented_ping_frame(self): + client = Connection(Side.CLIENT) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + client.send_frame(Frame(False, OP_PING, b"")) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_server_sends_fragmented_ping_frame(self): + server = Connection(Side.SERVER) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + server.send_frame(Frame(False, OP_PING, b"")) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_client_receives_fragmented_ping_frame(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x09\x00") + self.assertEqual(str(raised.exception), "fragmented control frame") + self.assertConnectionFailing(client, 1002, "fragmented control frame") + + def test_server_receives_fragmented_ping_frame(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") + self.assertEqual(str(raised.exception), "fragmented control frame") + self.assertConnectionFailing(server, 1002, "fragmented control frame") + + def test_client_sends_ping_after_sending_close(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + # The spec says: "An endpoint MAY send a Ping frame any time (...) + # before the connection is closed" but websockets doesn't support + # sending a Ping frame after a Close frame. + with self.assertRaises(InvalidState) as raised: + client.send_ping(b"") + self.assertEqual( + str(raised.exception), "cannot write to a WebSocket in the CLOSING state" + ) + + def test_server_sends_ping_after_sending_close(self): + server = Connection(Side.SERVER) + server.send_close(1000) + self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + # The spec says: "An endpoint MAY send a Ping frame any time (...) + # before the connection is closed" but websockets doesn't support + # sending a Ping frame after a Close frame. + with self.assertRaises(InvalidState) as raised: + server.send_ping(b"") + self.assertEqual( + str(raised.exception), "cannot write to a WebSocket in the CLOSING state" + ) + + def test_client_receives_ping_after_receiving_close(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + self.assertFrameReceived( + client, Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + ) + self.assertFrameSent(client, None) + + def test_server_receives_ping_after_receiving_close(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived( + server, Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + ) + self.assertFrameSent(server, None) + + +class PongTests(ConnectionTestCase): + """ + Test pong frames. See 5.5.3. Pong in RFC 6544. + + """ + + def test_client_sends_pong(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_pong(b"") + self.assertEqual(client.bytes_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) + + def test_server_sends_pong(self): + server = Connection(Side.SERVER) + server.send_pong(b"") + self.assertEqual(server.bytes_to_send(), [b"\x8a\x00"]) + + def test_client_receives_pong(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x8a\x00") + self.assertFrameReceived( + client, Frame(True, OP_PONG, b""), + ) + + def test_server_receives_pong(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") + self.assertFrameReceived( + server, Frame(True, OP_PONG, b""), + ) + + def test_client_sends_pong_with_data(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_pong(b"\x22\x66\xaa\xee") + self.assertEqual( + client.bytes_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] + ) + + def test_server_sends_pong_with_data(self): + server = Connection(Side.SERVER) + server.send_pong(b"\x22\x66\xaa\xee") + self.assertEqual(server.bytes_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) + + def test_client_receives_pong_with_data(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + self.assertFrameReceived( + client, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_server_receives_pong_with_data(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived( + server, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_client_sends_fragmented_pong_frame(self): + client = Connection(Side.CLIENT) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + client.send_frame(Frame(False, OP_PONG, b"")) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_server_sends_fragmented_pong_frame(self): + server = Connection(Side.SERVER) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + server.send_frame(Frame(False, OP_PONG, b"")) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_client_receives_fragmented_pong_frame(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x0a\x00") + self.assertEqual(str(raised.exception), "fragmented control frame") + self.assertConnectionFailing(client, 1002, "fragmented control frame") + + def test_server_receives_fragmented_pong_frame(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") + self.assertEqual(str(raised.exception), "fragmented control frame") + self.assertConnectionFailing(server, 1002, "fragmented control frame") + + def test_client_sends_pong_after_sending_close(self): + client = Connection(Side.CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + # websockets doesn't support sending a Pong frame after a Close frame. + with self.assertRaises(InvalidState): + client.send_pong(b"") + + def test_server_sends_pong_after_sending_close(self): + server = Connection(Side.SERVER) + server.send_close(1000) + self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + # websockets doesn't support sending a Pong frame after a Close frame. + with self.assertRaises(InvalidState): + server.send_pong(b"") + + def test_client_receives_pong_after_receiving_close(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + self.assertFrameReceived( + client, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_server_receives_pong_after_receiving_close(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived( + server, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + ) + + +class FragmentationTests(ConnectionTestCase): + """ + Test message fragmentation. + + See 5.4. Fragmentation in RFC 6544. + + """ + + def test_client_send_ping_pong_in_fragmented_message(self): + client = Connection(Side.CLIENT) + client.send_text(b"Spam", fin=False) + self.assertFrameSent(client, Frame(False, OP_TEXT, b"Spam")) + client.send_ping(b"Ping") + self.assertFrameSent(client, Frame(True, OP_PING, b"Ping")) + client.send_continuation(b"Ham", fin=False) + self.assertFrameSent(client, Frame(False, OP_CONT, b"Ham")) + client.send_pong(b"Pong") + self.assertFrameSent(client, Frame(True, OP_PONG, b"Pong")) + client.send_continuation(b"Eggs", fin=True) + self.assertFrameSent(client, Frame(True, OP_CONT, b"Eggs")) + + def test_server_send_ping_pong_in_fragmented_message(self): + server = Connection(Side.SERVER) + server.send_text(b"Spam", fin=False) + self.assertFrameSent(server, Frame(False, OP_TEXT, b"Spam")) + server.send_ping(b"Ping") + self.assertFrameSent(server, Frame(True, OP_PING, b"Ping")) + server.send_continuation(b"Ham", fin=False) + self.assertFrameSent(server, Frame(False, OP_CONT, b"Ham")) + server.send_pong(b"Pong") + self.assertFrameSent(server, Frame(True, OP_PONG, b"Pong")) + server.send_continuation(b"Eggs", fin=True) + self.assertFrameSent(server, Frame(True, OP_CONT, b"Eggs")) + + def test_client_receive_ping_pong_in_fragmented_message(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x01\x04Spam") + self.assertFrameReceived( + client, Frame(False, OP_TEXT, b"Spam"), + ) + client.receive_data(b"\x89\x04Ping") + self.assertFrameReceived( + client, Frame(True, OP_PING, b"Ping"), + ) + self.assertFrameSent( + client, Frame(True, OP_PONG, b"Ping"), + ) + client.receive_data(b"\x00\x03Ham") + self.assertFrameReceived( + client, Frame(False, OP_CONT, b"Ham"), + ) + client.receive_data(b"\x8a\x04Pong") + self.assertFrameReceived( + client, Frame(True, OP_PONG, b"Pong"), + ) + client.receive_data(b"\x80\x04Eggs") + self.assertFrameReceived( + client, Frame(True, OP_CONT, b"Eggs"), + ) + + def test_server_receive_ping_pong_in_fragmented_message(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") + self.assertFrameReceived( + server, Frame(False, OP_TEXT, b"Spam"), + ) + server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") + self.assertFrameReceived( + server, Frame(True, OP_PING, b"Ping"), + ) + self.assertFrameSent( + server, Frame(True, OP_PONG, b"Ping"), + ) + server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") + self.assertFrameReceived( + server, Frame(False, OP_CONT, b"Ham"), + ) + server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") + self.assertFrameReceived( + server, Frame(True, OP_PONG, b"Pong"), + ) + server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") + self.assertFrameReceived( + server, Frame(True, OP_CONT, b"Eggs"), + ) + + def test_client_send_close_in_fragmented_message(self): + client = Connection(Side.CLIENT) + client.send_text(b"Spam", fin=False) + self.assertFrameSent(client, Frame(False, OP_TEXT, b"Spam")) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + with self.assertRaises(ProtocolError) as raised: + client.send_close(1001) + self.assertEqual(str(raised.exception), "expected a continuation frame") + client.send_continuation(b"Eggs", fin=True) + + def test_server_send_close_in_fragmented_message(self): + server = Connection(Side.CLIENT) + server.send_text(b"Spam", fin=False) + self.assertFrameSent(server, Frame(False, OP_TEXT, b"Spam")) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + with self.assertRaises(ProtocolError) as raised: + server.send_close(1000) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receive_close_in_fragmented_message(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x01\x04Spam") + self.assertFrameReceived( + client, Frame(False, OP_TEXT, b"Spam"), + ) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\x88\x02\x03\xe8") + self.assertEqual(str(raised.exception), "incomplete fragmented message") + self.assertConnectionFailing(client, 1002, "incomplete fragmented message") + + def test_server_receive_close_in_fragmented_message(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") + self.assertFrameReceived( + server, Frame(False, OP_TEXT, b"Spam"), + ) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertEqual(str(raised.exception), "incomplete fragmented message") + self.assertConnectionFailing(server, 1002, "incomplete fragmented message") + + +class EOFTests(ConnectionTestCase): + """ + Test connection termination. + + """ + + def test_client_receives_eof(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() # does not raise an exception + + def test_server_receives_eof(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() # does not raise an exception + + def test_client_receives_eof_between_frames(self): + client = Connection(Side.CLIENT) + with self.assertRaises(EOFError) as raised: + client.receive_eof() + self.assertEqual(str(raised.exception), "unexpected end of stream") + + def test_server_receives_eof_between_frames(self): + server = Connection(Side.SERVER) + with self.assertRaises(EOFError) as raised: + server.receive_eof() + self.assertEqual(str(raised.exception), "unexpected end of stream") + + def test_client_receives_eof_inside_frame(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x81") + with self.assertRaises(EOFError) as raised: + client.receive_eof() + self.assertEqual( + str(raised.exception), "stream ends after 1 bytes, expected 2 bytes" + ) + + def test_server_receives_eof_inside_frame(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x81") + with self.assertRaises(EOFError) as raised: + server.receive_eof() + self.assertEqual( + str(raised.exception), "stream ends after 1 bytes, expected 2 bytes" + ) + + def test_client_receives_data_after_exception(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\xff\xff") + self.assertEqual(str(raised.exception), "invalid opcode") + with self.assertRaises(RuntimeError) as raised: + client.receive_data(b"\x00\x00") + self.assertEqual( + str(raised.exception), "cannot receive data or EOF after an error" + ) + + def test_server_receives_data_after_exception(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\xff\xff") + self.assertEqual(str(raised.exception), "invalid opcode") + with self.assertRaises(RuntimeError) as raised: + server.receive_data(b"\x00\x00") + self.assertEqual( + str(raised.exception), "cannot receive data or EOF after an error" + ) + + def test_client_receives_eof_after_exception(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.receive_data(b"\xff\xff") + self.assertEqual(str(raised.exception), "invalid opcode") + with self.assertRaises(RuntimeError) as raised: + client.receive_eof() + self.assertEqual( + str(raised.exception), "cannot receive data or EOF after an error" + ) + + def test_server_receives_eof_after_exception(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError) as raised: + server.receive_data(b"\xff\xff") + self.assertEqual(str(raised.exception), "invalid opcode") + with self.assertRaises(RuntimeError) as raised: + server.receive_eof() + self.assertEqual( + str(raised.exception), "cannot receive data or EOF after an error" + ) + + def test_client_receives_data_after_eof(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() + with self.assertRaises(EOFError) as raised: + client.receive_data(b"\x88\x00") + self.assertEqual(str(raised.exception), "stream ended") + + def test_server_receives_data_after_eof(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() + with self.assertRaises(EOFError) as raised: + server.receive_data(b"\x88\x80\x00\x00\x00\x00") + self.assertEqual(str(raised.exception), "stream ended") + + def test_client_receives_eof_after_eof(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() + with self.assertRaises(EOFError) as raised: + client.receive_eof() + self.assertEqual(str(raised.exception), "stream ended") + + def test_server_receives_eof_after_eof(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() + with self.assertRaises(EOFError) as raised: + server.receive_eof() + self.assertEqual(str(raised.exception), "stream ended") + + +class ErrorTests(ConnectionTestCase): + """ + Test other error cases. + + """ + + def test_client_hits_internal_error_reading_frame(self): + client = Connection(Side.CLIENT) + # This isn't supposed to happen, so we're simulating it. + with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + with self.assertRaises(RuntimeError) as raised: + client.receive_data(b"\x81\x00") + self.assertEqual(str(raised.exception), "BOOM") + self.assertConnectionFailing(client, 1011, "") + + def test_server_hits_internal_error_reading_frame(self): + server = Connection(Side.SERVER) + # This isn't supposed to happen, so we're simulating it. + with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + with self.assertRaises(RuntimeError) as raised: + server.receive_data(b"\x81\x80\x00\x00\x00\x00") + self.assertEqual(str(raised.exception), "BOOM") + self.assertConnectionFailing(server, 1011, "") + + +class ExtensionsTests(ConnectionTestCase): + """ + Test how extensions affect frames. + + """ + + def test_client_extension_encodes_frame(self): + client = Connection(Side.CLIENT) + client.extensions = [Rsv2Extension()] + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"") + self.assertEqual(client.bytes_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) + + def test_server_extension_encodes_frame(self): + server = Connection(Side.SERVER) + server.extensions = [Rsv2Extension()] + server.send_ping(b"") + self.assertEqual(server.bytes_to_send(), [b"\xa9\x00"]) + + def test_client_extension_decodes_frame(self): + client = Connection(Side.CLIENT) + client.extensions = [Rsv2Extension()] + client.receive_data(b"\xaa\x00") + self.assertEqual(client.events_received(), [Frame(True, OP_PONG, b"")]) + + def test_server_extension_decodes_frame(self): + server = Connection(Side.SERVER) + server.extensions = [Rsv2Extension()] + server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") + self.assertEqual(server.events_received(), [Frame(True, OP_PONG, b"")]) diff --git a/tests/test_frames.py b/tests/test_frames.py index 37a73b2df..514fe7c54 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -9,8 +9,15 @@ from .utils import GeneratorTestCase -class FrameTests(GeneratorTestCase): - def parse(self, data, mask=False, max_size=None, extensions=None): +class FramesTestCase(GeneratorTestCase): + def enforce_mask(self, mask): + return unittest.mock.patch("secrets.token_bytes", return_value=mask) + + def parse(self, data, mask, max_size=None, extensions=None): + """ + Parse a frame from a bytestring. + + """ reader = StreamReader() reader.feed_data(data) reader.feed_eof() @@ -19,117 +26,134 @@ def parse(self, data, mask=False, max_size=None, extensions=None): ) return self.assertGeneratorReturns(parser) - def round_trip(self, data, frame, mask=False, extensions=None): + def assertFrameData(self, frame, data, mask, extensions=None): + """ + Serializing frame yields data. Parsing data yields frame. + + """ + # Compare frames first, because test failures are easier to read, + # especially when mask = True. parsed = self.parse(data, mask=mask, extensions=extensions) self.assertEqual(parsed, frame) # Make masking deterministic by reusing the same "random" mask. # This has an effect only when mask is True. mask_bytes = data[2:6] if mask else b"" - with unittest.mock.patch("secrets.token_bytes", return_value=mask_bytes): - serialized = parsed.serialize(mask=mask, extensions=extensions) + with self.enforce_mask(mask_bytes): + serialized = frame.serialize(mask=mask, extensions=extensions) self.assertEqual(serialized, data) - def test_text(self): - self.round_trip(b"\x81\x04Spam", Frame(True, OP_TEXT, b"Spam")) + +class FrameTests(FramesTestCase): + def test_text_unmasked(self): + self.assertFrameData( + Frame(True, OP_TEXT, b"Spam"), b"\x81\x04Spam", mask=False, + ) def test_text_masked(self): - self.round_trip( - b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", + self.assertFrameData( Frame(True, OP_TEXT, b"Spam"), + b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", mask=True, ) - def test_binary(self): - self.round_trip(b"\x82\x04Eggs", Frame(True, OP_BINARY, b"Eggs")) + def test_binary_unmasked(self): + self.assertFrameData( + Frame(True, OP_BINARY, b"Eggs"), b"\x82\x04Eggs", mask=False, + ) def test_binary_masked(self): - self.round_trip( - b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", + self.assertFrameData( Frame(True, OP_BINARY, b"Eggs"), + b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", mask=True, ) - def test_non_ascii_text(self): - self.round_trip( - b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode("utf-8")) + def test_non_ascii_text_unmasked(self): + self.assertFrameData( + Frame(True, OP_TEXT, "café".encode("utf-8")), + b"\x81\x05caf\xc3\xa9", + mask=False, ) def test_non_ascii_text_masked(self): - self.round_trip( - b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", + self.assertFrameData( Frame(True, OP_TEXT, "café".encode("utf-8")), + b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", mask=True, ) def test_close(self): - self.round_trip(b"\x88\x00", Frame(True, OP_CLOSE, b"")) + self.assertFrameData(Frame(True, OP_CLOSE, b""), b"\x88\x00", mask=False) def test_ping(self): - self.round_trip(b"\x89\x04ping", Frame(True, OP_PING, b"ping")) + self.assertFrameData(Frame(True, OP_PING, b"ping"), b"\x89\x04ping", mask=False) def test_pong(self): - self.round_trip(b"\x8a\x04pong", Frame(True, OP_PONG, b"pong")) + self.assertFrameData(Frame(True, OP_PONG, b"pong"), b"\x8a\x04pong", mask=False) def test_long(self): - self.round_trip( - b"\x82\x7e\x00\x7e" + 126 * b"a", Frame(True, OP_BINARY, 126 * b"a") + self.assertFrameData( + Frame(True, OP_BINARY, 126 * b"a"), + b"\x82\x7e\x00\x7e" + 126 * b"a", + mask=False, ) def test_very_long(self): - self.round_trip( - b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", + self.assertFrameData( Frame(True, OP_BINARY, 65536 * b"a"), + b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", + mask=False, ) def test_payload_too_big(self): with self.assertRaises(PayloadTooBig): - self.parse(b"\x82\x7e\x04\x01" + 1025 * b"a", max_size=1024) + self.parse(b"\x82\x7e\x04\x01" + 1025 * b"a", mask=False, max_size=1024) def test_bad_reserved_bits(self): for data in [b"\xc0\x00", b"\xa0\x00", b"\x90\x00"]: with self.subTest(data=data): with self.assertRaises(ProtocolError): - self.parse(data) + self.parse(data, mask=False) def test_good_opcode(self): for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0B)): data = bytes([0x80 | opcode, 0]) with self.subTest(data=data): - self.parse(data) # does not raise an exception + self.parse(data, mask=False) # does not raise an exception def test_bad_opcode(self): for opcode in list(range(0x03, 0x08)) + list(range(0x0B, 0x10)): data = bytes([0x80 | opcode, 0]) with self.subTest(data=data): with self.assertRaises(ProtocolError): - self.parse(data) + self.parse(data, mask=False) def test_mask_flag(self): # Mask flag correctly set. self.parse(b"\x80\x80\x00\x00\x00\x00", mask=True) # Mask flag incorrectly unset. with self.assertRaises(ProtocolError): - self.parse(b"\x80\x80\x00\x00\x00\x00") + self.parse(b"\x80\x80\x00\x00\x00\x00", mask=False) # Mask flag correctly unset. - self.parse(b"\x80\x00") + self.parse(b"\x80\x00", mask=False) # Mask flag incorrectly set. with self.assertRaises(ProtocolError): self.parse(b"\x80\x00", mask=True) def test_control_frame_max_length(self): # At maximum allowed length. - self.parse(b"\x88\x7e\x00\x7d" + 125 * b"a") + self.parse(b"\x88\x7e\x00\x7d" + 125 * b"a", mask=False) # Above maximum allowed length. with self.assertRaises(ProtocolError): - self.parse(b"\x88\x7e\x00\x7e" + 126 * b"a") + self.parse(b"\x88\x7e\x00\x7e" + 126 * b"a", mask=False) def test_fragmented_control_frame(self): # Fin bit correctly set. - self.parse(b"\x88\x00") + self.parse(b"\x88\x00", mask=False) # Fin bit incorrectly unset. with self.assertRaises(ProtocolError): - self.parse(b"\x08\x00") + self.parse(b"\x08\x00", mask=False) def test_extensions(self): class Rot13: @@ -145,8 +169,11 @@ def encode(frame): def decode(frame, *, max_size=None): return Rot13.encode(frame) - self.round_trip( - b"\x81\x05uryyb", Frame(True, OP_TEXT, b"hello"), extensions=[Rot13()] + self.assertFrameData( + Frame(True, OP_TEXT, b"hello"), + b"\x81\x05uryyb", + mask=False, + extensions=[Rot13()], ) @@ -205,15 +232,19 @@ def test_prepare_ctrl_none(self): class ParseAndSerializeCloseTests(unittest.TestCase): - def round_trip(self, data, code, reason): - parsed = parse_close(data) - self.assertEqual(parsed, (code, reason)) + def assertCloseData(self, code, reason, data): + """ + Serializing code / reason yields data. Parsing data yields code / reason. + + """ serialized = serialize_close(code, reason) self.assertEqual(serialized, data) + parsed = parse_close(data) + self.assertEqual(parsed, (code, reason)) def test_parse_close_and_serialize_close(self): - self.round_trip(b"\x03\xe8", 1000, "") - self.round_trip(b"\x03\xe8OK", 1000, "OK") + self.assertCloseData(1000, "", b"\x03\xe8") + self.assertCloseData(1000, "OK", b"\x03\xe8OK") def test_parse_close_empty(self): self.assertEqual(parse_close(b""), (1005, "")) From fad4c57d4d84cb884bd30ebe44e07ace4d5f4cfb Mon Sep 17 00:00:00 2001 From: akgnah <1024@setq.me> Date: Mon, 2 Mar 2020 13:01:51 +0800 Subject: [PATCH 0706/1539] fix typo in example/counter.py --- example/counter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/counter.py b/example/counter.py index dbbbe5935..239ec203a 100755 --- a/example/counter.py +++ b/example/counter.py @@ -58,7 +58,7 @@ async def counter(websocket, path): STATE["value"] += 1 await notify_state() else: - logging.error("unsupported event: {}", data) + logging.error("unsupported event: %s", data) finally: await unregister(websocket) From 458c4d67faaaf52359f713aafc3eda26afb1de3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20L=C3=89VEIL?= Date: Thu, 9 Apr 2020 01:09:36 +0200 Subject: [PATCH 0707/1539] support request lines of 4107 bytes fix #743 avoid sending a `HTTP 400` response when popular browsers send a request with cookies maxing up the user-agent limit --- src/websockets/http11.py | 2 +- src/websockets/http_legacy.py | 2 +- tests/test_http11.py | 4 ++-- tests/test_http_legacy.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 58ee09253..693a20e54 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -6,7 +6,7 @@ MAX_HEADERS = 256 -MAX_LINE = 4096 +MAX_LINE = 4107 def d(value: bytes) -> str: diff --git a/src/websockets/http_legacy.py b/src/websockets/http_legacy.py index 3630d3593..0bc548b31 100644 --- a/src/websockets/http_legacy.py +++ b/src/websockets/http_legacy.py @@ -9,7 +9,7 @@ __all__ = ["read_request", "read_response"] MAX_HEADERS = 256 -MAX_LINE = 4096 +MAX_LINE = 4107 def d(value: bytes) -> str: diff --git a/tests/test_http11.py b/tests/test_http11.py index 4574cf97e..87be6e486 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -260,8 +260,8 @@ def test_parse_too_long_value(self): next(self.parse_headers()) def test_parse_too_long_line(self): - # Header line contains 5 + 4090 + 2 = 4097 bytes. - self.reader.feed_data(b"foo: " + b"a" * 4090 + b"\r\n\r\n") + # Header line contains 5 + 4101 + 2 = 4108 bytes. + self.reader.feed_data(b"foo: " + b"a" * 4101 + b"\r\n\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) diff --git a/tests/test_http_legacy.py b/tests/test_http_legacy.py index 3b43a6274..667aff52a 100644 --- a/tests/test_http_legacy.py +++ b/tests/test_http_legacy.py @@ -124,8 +124,8 @@ async def test_headers_limit(self): await read_headers(self.stream) async def test_line_limit(self): - # Header line contains 5 + 4090 + 2 = 4097 bytes. - self.stream.feed_data(b"foo: " + b"a" * 4090 + b"\r\n\r\n") + # Header line contains 5 + 4101 + 2 = 4108 bytes. + self.stream.feed_data(b"foo: " + b"a" * 4101 + b"\r\n\r\n") with self.assertRaises(SecurityError): await read_headers(self.stream) From f056c1cfb8ef417180bf337308aa73e49c9469b4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 21:10:56 +0200 Subject: [PATCH 0708/1539] Adjust max header size (again). See #743 for the rationale. --- src/websockets/http11.py | 2 +- src/websockets/http_legacy.py | 2 +- tests/test_http11.py | 4 ++-- tests/test_http_legacy.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 693a20e54..0754ddabb 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -6,7 +6,7 @@ MAX_HEADERS = 256 -MAX_LINE = 4107 +MAX_LINE = 4110 def d(value: bytes) -> str: diff --git a/src/websockets/http_legacy.py b/src/websockets/http_legacy.py index 0bc548b31..5afe5f898 100644 --- a/src/websockets/http_legacy.py +++ b/src/websockets/http_legacy.py @@ -9,7 +9,7 @@ __all__ = ["read_request", "read_response"] MAX_HEADERS = 256 -MAX_LINE = 4107 +MAX_LINE = 4110 def d(value: bytes) -> str: diff --git a/tests/test_http11.py b/tests/test_http11.py index 87be6e486..9e4d70620 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -260,8 +260,8 @@ def test_parse_too_long_value(self): next(self.parse_headers()) def test_parse_too_long_line(self): - # Header line contains 5 + 4101 + 2 = 4108 bytes. - self.reader.feed_data(b"foo: " + b"a" * 4101 + b"\r\n\r\n") + # Header line contains 5 + 4104 + 2 = 4111 bytes. + self.reader.feed_data(b"foo: " + b"a" * 4104 + b"\r\n\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) diff --git a/tests/test_http_legacy.py b/tests/test_http_legacy.py index 667aff52a..e4c75315e 100644 --- a/tests/test_http_legacy.py +++ b/tests/test_http_legacy.py @@ -124,8 +124,8 @@ async def test_headers_limit(self): await read_headers(self.stream) async def test_line_limit(self): - # Header line contains 5 + 4101 + 2 = 4108 bytes. - self.stream.feed_data(b"foo: " + b"a" * 4101 + b"\r\n\r\n") + # Header line contains 5 + 4104 + 2 = 4111 bytes. + self.stream.feed_data(b"foo: " + b"a" * 4104 + b"\r\n\r\n") with self.assertRaises(SecurityError): await read_headers(self.stream) From 639b993a236107f22d529cde488d1e1eb6645228 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 21:38:15 +0200 Subject: [PATCH 0709/1539] Create correct Host header for IPv6. Fix #802. --- src/websockets/asyncio_client.py | 7 ++----- src/websockets/client.py | 9 ++++----- src/websockets/http.py | 26 +++++++++++++++++++++++++- tests/test_http.py | 29 +++++++++++++++++++++++++++-- 4 files changed, 58 insertions(+), 13 deletions(-) diff --git a/src/websockets/asyncio_client.py b/src/websockets/asyncio_client.py index f95dae060..e01a641cb 100644 --- a/src/websockets/asyncio_client.py +++ b/src/websockets/asyncio_client.py @@ -31,7 +31,7 @@ parse_extension, parse_subprotocol, ) -from .http import USER_AGENT +from .http import USER_AGENT, build_host from .http_legacy import read_response from .protocol import WebSocketCommonProtocol from .typing import ExtensionHeader, Origin, Subprotocol @@ -251,10 +251,7 @@ async def handshake( """ request_headers = Headers() - if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover - request_headers["Host"] = wsuri.host - else: - request_headers["Host"] = f"{wsuri.host}:{wsuri.port}" + request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure) if wsuri.user_info: request_headers["Authorization"] = build_authorization_basic( diff --git a/src/websockets/client.py b/src/websockets/client.py index 3f9777b94..a7bfcc4ee 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -23,7 +23,7 @@ parse_subprotocol, parse_upgrade, ) -from .http import USER_AGENT +from .http import USER_AGENT, build_host from .http11 import Request, Response from .typing import ( ConnectionOption, @@ -71,10 +71,9 @@ def connect(self) -> Request: """ headers = Headers() - if self.wsuri.port == (443 if self.wsuri.secure else 80): - headers["Host"] = self.wsuri.host - else: - headers["Host"] = f"{self.wsuri.host}:{self.wsuri.port}" + headers["Host"] = build_host( + self.wsuri.host, self.wsuri.port, self.wsuri.secure + ) if self.wsuri.user_info: headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info) diff --git a/src/websockets/http.py b/src/websockets/http.py index 850b9beaa..ed3fe48d0 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,4 +1,5 @@ import asyncio +import ipaddress import sys import warnings from typing import Tuple @@ -9,13 +10,36 @@ from .version import version as websockets_version -__all__ = ["USER_AGENT"] +__all__ = ["USER_AGENT", "build_host"] PYTHON_VERSION = "{}.{}".format(*sys.version_info) USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}" +def build_host(host: str, port: int, secure: bool) -> str: + """ + Build a ``Host`` header. + + """ + # https://tools.ietf.org/html/rfc3986#section-3.2.2 + # IPv6 addresses must be enclosed in brackets. + try: + address = ipaddress.ip_address(host) + except ValueError: + # host is a hostname + pass + else: + # host is an IP address + if address.version == 6: + host = f"[{host}]" + + if port != (443 if secure else 80): + host = f"{host}:{port}" + + return host + + # Backwards compatibility with previously documented public APIs diff --git a/tests/test_http.py b/tests/test_http.py index 322650354..ca7c1c0a4 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,2 +1,27 @@ -# Check that the legacy http module imports without an exception. -from websockets.http import * # noqa +import unittest + +from websockets.http import * + + +class HTTPTests(unittest.TestCase): + def test_build_host(self): + for (host, port, secure), result in [ + (("localhost", 80, False), "localhost"), + (("localhost", 8000, False), "localhost:8000"), + (("localhost", 443, True), "localhost"), + (("localhost", 8443, True), "localhost:8443"), + (("example.com", 80, False), "example.com"), + (("example.com", 8000, False), "example.com:8000"), + (("example.com", 443, True), "example.com"), + (("example.com", 8443, True), "example.com:8443"), + (("127.0.0.1", 80, False), "127.0.0.1"), + (("127.0.0.1", 8000, False), "127.0.0.1:8000"), + (("127.0.0.1", 443, True), "127.0.0.1"), + (("127.0.0.1", 8443, True), "127.0.0.1:8443"), + (("::1", 80, False), "[::1]"), + (("::1", 8000, False), "[::1]:8000"), + (("::1", 443, True), "[::1]"), + (("::1", 8443, True), "[::1]:8443"), + ]: + with self.subTest(host=host, port=port, secure=secure): + self.assertEqual(build_host(host, port, secure), result) From 6466e238f4809e81579f70460563fa0d00b7905a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 21:49:05 +0200 Subject: [PATCH 0710/1539] Raise a good error when sending a dict. This must be a common mistake. Fix #734. --- src/websockets/protocol.py | 10 ++++++++++ tests/test_protocol.py | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2e5d95e06..92ce8e305 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -25,6 +25,7 @@ Dict, Iterable, List, + Mapping, Optional, Union, cast, @@ -548,6 +549,10 @@ async def send( :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. + :meth:`send` rejects dict-like objects because this is often an error. + If you wish to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`. + Canceling :meth:`send` is discouraged. Instead, you should close the connection with :meth:`close`. Indeed, there only two situations where :meth:`send` yields control to the event loop: @@ -576,6 +581,11 @@ async def send( opcode, data = prepare_data(message) await self.write_frame(True, opcode, data) + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + # Fragmented message -- regular iterator. elif isinstance(message, Iterable): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 3054600e1..432c31ef5 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -584,6 +584,11 @@ def test_send_binary_from_non_contiguous_memoryview(self): self.loop.run_until_complete(self.protocol.send(memoryview(b"tteeaa")[::2])) self.assertOneFrameSent(True, OP_BINARY, b"tea") + def test_send_dict(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send({"not": "encoded"})) + self.assertNoFrameSent() + def test_send_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.send(42)) From 97ae02b4560516f577b265ef222fff5fb3e950b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 22:05:31 +0200 Subject: [PATCH 0711/1539] Document pitfall. Fix #335. --- docs/faq.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/faq.rst b/docs/faq.rst index cea3f5358..5e6439055 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -204,6 +204,13 @@ There are several reasons why long-lived connections may be lost: If you're facing a reproducible issue, :ref:`enable debug logs ` to see when and how connections are closed. +Why do I get the error: ``module 'websockets' has no attribute '...'``? +....................................................................... + +Often, this is because you created a script called ``websockets.py`` in your +current working directory. Then ``import websockets`` imports this module +instead of the websockets library. + Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ From 0a1195eed14eddb3f27929ef49af4024814c3f37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jul 2020 22:48:28 +0200 Subject: [PATCH 0712/1539] Type create_protocol arguments as callables. Fix #764. --- src/websockets/asyncio_client.py | 4 ++-- src/websockets/asyncio_server.py | 2 +- src/websockets/auth.py | 13 +++++++++---- tests/test_auth.py | 22 +++++++++++++++++++++- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/websockets/asyncio_client.py b/src/websockets/asyncio_client.py index e01a641cb..efa29b69a 100644 --- a/src/websockets/asyncio_client.py +++ b/src/websockets/asyncio_client.py @@ -9,7 +9,7 @@ import logging import warnings from types import TracebackType -from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast +from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, Type, cast from .datastructures import Headers, HeadersLike from .exceptions import ( @@ -373,7 +373,7 @@ def __init__( uri: str, *, path: Optional[str] = None, - create_protocol: Optional[Type[WebSocketClientProtocol]] = None, + create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, diff --git a/src/websockets/asyncio_server.py b/src/websockets/asyncio_server.py index 89ddf6c7d..fe61c7ddc 100644 --- a/src/websockets/asyncio_server.py +++ b/src/websockets/asyncio_server.py @@ -850,7 +850,7 @@ def __init__( port: Optional[int] = None, *, path: Optional[str] = None, - create_protocol: Optional[Type[WebSocketServerProtocol]] = None, + create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 03e8536c5..c1b7a0b1a 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -7,7 +7,7 @@ import functools import http -from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast from .asyncio_server import HTTPResponse, WebSocketServerProtocol from .datastructures import Headers @@ -90,9 +90,7 @@ def basic_auth_protocol_factory( realm: str, credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, - create_protocol: Type[ - BasicAuthWebSocketServerProtocol - ] = BasicAuthWebSocketServerProtocol, + create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None, ) -> Callable[[Any], BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. @@ -155,6 +153,13 @@ async def check_credentials(username: str, password: str) -> bool: else: raise TypeError(f"invalid credentials argument: {credentials}") + if create_protocol is None: + # Not sure why mypy cannot figure this out. + create_protocol = cast( + Callable[[Any], BasicAuthWebSocketServerProtocol], + BasicAuthWebSocketServerProtocol, + ) + return functools.partial( create_protocol, realm=realm, check_credentials=check_credentials ) diff --git a/tests/test_auth.py b/tests/test_auth.py index c693c9f45..68642389e 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -19,6 +19,12 @@ def test_is_not_credentials(self): self.assertFalse(is_credentials("username")) +class CustomWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): + async def process_request(self, path, request_headers): + type(self).used = True + return await super().process_request(path, request_headers) + + class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): create_protocol = basic_auth_protocol_factory( @@ -73,7 +79,7 @@ async def check_credentials(username, password): return password == "iloveyou" create_protocol_check_credentials = basic_auth_protocol_factory( - realm="auth-tests", check_credentials=check_credentials + realm="auth-tests", check_credentials=check_credentials, ) @with_server(create_protocol=create_protocol_check_credentials) @@ -82,6 +88,20 @@ def test_basic_auth_check_credentials(self): self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) + create_protocol_custom_protocol = basic_auth_protocol_factory( + realm="auth-tests", + credentials=[("hello", "iloveyou")], + create_protocol=CustomWebSocketServerProtocol, + ) + + @with_server(create_protocol=create_protocol_custom_protocol) + @with_client(user_info=("hello", "iloveyou")) + def test_basic_auth_custom_protocol(self): + self.assertTrue(CustomWebSocketServerProtocol.used) + del CustomWebSocketServerProtocol.used + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + @with_server(create_protocol=create_protocol) def test_basic_auth_missing_credentials(self): with self.assertRaises(InvalidStatusCode) as raised: From cb91aa1575066f6624944cb75bb41d68a45d1b45 Mon Sep 17 00:00:00 2001 From: Janakarajan Natarajan Date: Tue, 18 Aug 2020 22:52:03 +0000 Subject: [PATCH 0713/1539] Add aarch64 wheel build --- .travis.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.travis.yml b/.travis.yml index 26e1de60e..e31c9ea0b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,6 +13,13 @@ matrix: python: "3.7" services: - docker + - language: python + dist: xenial + sudo: required + python: "3.7" + arch: arm64 + services: + - docker - os: osx osx_image: xcode8.3 From c39268c4867e41d11c20f7859583761d52a04012 Mon Sep 17 00:00:00 2001 From: Ram Rachum Date: Mon, 27 Jul 2020 14:06:08 +0300 Subject: [PATCH 0714/1539] Fix exception causes in handshake_legacy.py --- src/websockets/handshake_legacy.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/websockets/handshake_legacy.py b/src/websockets/handshake_legacy.py index 7e6acc77d..d34ca5f7f 100644 --- a/src/websockets/handshake_legacy.py +++ b/src/websockets/handshake_legacy.py @@ -91,28 +91,28 @@ def check_request(headers: Headers) -> str: try: s_w_key = headers["Sec-WebSocket-Key"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Key") - except MultipleValuesError: + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Key") from exc + except MultipleValuesError as exc: raise InvalidHeader( "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" - ) + ) from exc try: raw_key = base64.b64decode(s_w_key.encode(), validate=True) - except binascii.Error: - raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) + except binascii.Error as exc: + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc if len(raw_key) != 16: raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) try: s_w_version = headers["Sec-WebSocket-Version"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Version") - except MultipleValuesError: + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Version") from exc + except MultipleValuesError as exc: raise InvalidHeader( "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found" - ) + ) from exc if s_w_version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) @@ -168,12 +168,12 @@ def check_response(headers: Headers, key: str) -> None: try: s_w_accept = headers["Sec-WebSocket-Accept"] - except KeyError: - raise InvalidHeader("Sec-WebSocket-Accept") - except MultipleValuesError: + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Accept") from exc + except MultipleValuesError as exc: raise InvalidHeader( "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found" - ) + ) from exc if s_w_accept != accept(key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) From 69cf86724dc2a86f7e57f6393dd322a249dbee17 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 13:29:09 +0100 Subject: [PATCH 0715/1539] Move question to the FAQ. It was written in the cheatsheet before there was a FAQ. --- docs/cheatsheet.rst | 22 ---------------------- docs/faq.rst | 20 ++++++++++++++++++++ 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index f897326a6..4b95c9eea 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -85,25 +85,3 @@ in particular. Fortunately Python's official documentation provides advice to .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html -Passing additional arguments to the connection handler ------------------------------------------------------- - -When writing a server, if you need to pass additional arguments to the -connection handler, you can bind them with :func:`functools.partial`:: - - import asyncio - import functools - import websockets - - async def handler(websocket, path, extra_argument): - ... - - bound_handler = functools.partial(handler, extra_argument='spam') - start_server = websockets.serve(bound_handler, '127.0.0.1', 8765) - - asyncio.get_event_loop().run_until_complete(start_server) - asyncio.get_event_loop().run_forever() - -Another way to achieve this result is to define the ``handler`` coroutine in -a scope where the ``extra_argument`` variable exists instead of injecting it -through an argument. diff --git a/docs/faq.rst b/docs/faq.rst index 5e6439055..5748521f0 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -56,6 +56,26 @@ See also Python's documentation about `running blocking code`_. .. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code +How can I pass additional arguments to the connection handler? +.............................................................. + +You can bind additional arguments to the connection handler with +:func:`functools.partial`:: + + import asyncio + import functools + import websockets + + async def handler(websocket, path, extra_argument): + ... + + bound_handler = functools.partial(handler, extra_argument='spam') + start_server = websockets.serve(bound_handler, ...) + +Another way to achieve this result is to define the ``handler`` coroutine in +a scope where the ``extra_argument`` variable exists instead of injecting it +through an argument. + How do I get access HTTP headers, for example cookies? ...................................................... From a64136c869c527808c337b13e6dace43ad9d674e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 13:31:38 +0100 Subject: [PATCH 0716/1539] Remove unfinished sentence. --- docs/faq.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/faq.rst b/docs/faq.rst index 5748521f0..cd0033734 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -85,8 +85,6 @@ To access HTTP headers during the WebSocket handshake, you can override async def process_request(self, path, request_headers): cookies = request_header["Cookie"] -See - Once the connection is established, they're available in :attr:`~protocol.WebSocketServerProtocol.request_headers`:: From b331e6c9c3d2cfd3d768aa81e396a9e2f977cf88 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 13:37:53 +0100 Subject: [PATCH 0717/1539] Document how to pass arguments to protocol factory. Fix #851. --- docs/faq.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/faq.rst b/docs/faq.rst index cd0033734..4a083e2d0 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -222,6 +222,26 @@ There are several reasons why long-lived connections may be lost: If you're facing a reproducible issue, :ref:`enable debug logs ` to see when and how connections are closed. +How can I pass additional arguments to a custom protocol subclass? +.................................................................. + +You can bind additional arguments to the protocol factory with +:func:`functools.partial`:: + + import asyncio + import functools + import websockets + + class MyServerProtocol(websockets.WebSocketServerProtocol): + def __init__(self, extra_argument, *args, **kwargs): + super().__init__(*args, **kwargs) + # do something with extra_argument + + create_protocol = functools.partial(MyServerProtocol, extra_argument='spam') + start_server = websockets.serve(..., create_protocol=create_protocol) + +This example was for a server. The same pattern applies on a client. + Why do I get the error: ``module 'websockets' has no attribute '...'``? ....................................................................... From 988572074edbde4dce1e49573e9dca05498bb159 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 13:50:45 +0100 Subject: [PATCH 0718/1539] Brag with # stargazers. Fix #844. --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index 064c657bf..0c00b96fb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -111,6 +111,7 @@ 'logo': 'websockets.svg', 'description': 'A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.', 'github_button': True, + 'github_type': 'star', 'github_user': 'aaugustin', 'github_repo': 'websockets', 'tidelift_url': 'https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs', From e6d5da9b94167d875e2fb3936e44665fe0f562bc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 16:25:33 +0100 Subject: [PATCH 0719/1539] Include "broadcast" as a search term. Fix #841. --- docs/intro.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/intro.rst b/docs/intro.rst index 8be700239..8aaaeddca 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -180,7 +180,7 @@ unregister them when they disconnect. # Register. connected.add(websocket) try: - # Implement logic here. + # Broadcast a message to all connected clients. await asyncio.wait([ws.send("Hello!") for ws in connected]) await asyncio.sleep(10) finally: From f6e03bbd1f0e1affdda16488e46ae488ab0ccfcb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 17:38:30 +0100 Subject: [PATCH 0720/1539] Run new version of black. --- src/websockets/client.py | 2 +- src/websockets/connection.py | 5 +- src/websockets/exceptions.py | 5 +- src/websockets/server.py | 3 +- tests/test_auth.py | 3 +- tests/test_connection.py | 200 +++++++++++++++++++++++------------ tests/test_frames.py | 49 +++++++-- tests/test_http11.py | 60 ++++++++--- 8 files changed, 232 insertions(+), 95 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index a7bfcc4ee..b7e407a45 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -280,7 +280,7 @@ def send_request(self, request: Request) -> None: def parse(self) -> Generator[None, None, None]: response = yield from Response.parse( - self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, + self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof ) assert self.state == CONNECTING try: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index ac30802db..a98d0b1e7 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -63,7 +63,10 @@ class State(enum.IntEnum): class Connection: def __init__( - self, side: Side, state: State = OPEN, max_size: Optional[int] = 2 ** 20, + self, + side: Side, + state: State = OPEN, + max_size: Optional[int] = 2 ** 20, ) -> None: # Connection side. CLIENT or SERVER. self.side = side diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index c60a3e10e..84c27692c 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -302,7 +302,10 @@ class AbortHandshake(InvalidHandshake): """ def __init__( - self, status: http.HTTPStatus, headers: HeadersLike, body: bytes = b"" + self, + status: http.HTTPStatus, + headers: HeadersLike, + body: bytes = b"", ) -> None: self.status = status self.headers = Headers(headers) diff --git a/src/websockets/server.py b/src/websockets/server.py index 1b03eabee..c2c818ce9 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -242,7 +242,8 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: return origin def process_extensions( - self, headers: Headers, + self, + headers: Headers, ) -> Tuple[Optional[str], List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. diff --git a/tests/test_auth.py b/tests/test_auth.py index 68642389e..ce23f913d 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -79,7 +79,8 @@ async def check_credentials(username, password): return password == "iloveyou" create_protocol_check_credentials = basic_auth_protocol_factory( - realm="auth-tests", check_credentials=check_credentials, + realm="auth-tests", + check_credentials=check_credentials, ) @with_server(create_protocol=create_protocol_check_credentials) diff --git a/tests/test_connection.py b/tests/test_connection.py index 5c0f7302f..d47147d64 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -59,7 +59,9 @@ def assertConnectionClosing(self, connection, code=None, reason=""): """ close_frame = Frame( - True, OP_CLOSE, b"" if code is None else serialize_close(code, reason), + True, + OP_CLOSE, + b"" if code is None else serialize_close(code, reason), ) # A close frame was received. self.assertFrameReceived(connection, close_frame) @@ -74,7 +76,9 @@ def assertConnectionFailing(self, connection, code=None, reason=""): """ close_frame = Frame( - True, OP_CLOSE, b"" if code is None else serialize_close(code, reason), + True, + OP_CLOSE, + b"" if code is None else serialize_close(code, reason), ) # No frame was received. self.assertFrameReceived(connection, None) @@ -108,14 +112,16 @@ def test_client_receives_unmasked_frame(self): client = Connection(Side.CLIENT) client.receive_data(self.unmasked_text_frame_date) self.assertFrameReceived( - client, Frame(True, OP_TEXT, b"Spam"), + client, + Frame(True, OP_TEXT, b"Spam"), ) def test_server_receives_masked_frame(self): server = Connection(Side.SERVER) server.receive_data(self.masked_text_frame_data) self.assertFrameReceived( - server, Frame(True, OP_TEXT, b"Spam"), + server, + Frame(True, OP_TEXT, b"Spam"), ) def test_client_receives_masked_frame(self): @@ -228,14 +234,16 @@ def test_client_receives_text(self): client = Connection(Side.CLIENT) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( - client, Frame(True, OP_TEXT, "😀".encode()), + client, + Frame(True, OP_TEXT, "😀".encode()), ) def test_server_receives_text(self): server = Connection(Side.SERVER) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( - server, Frame(True, OP_TEXT, "😀".encode()), + server, + Frame(True, OP_TEXT, "😀".encode()), ) def test_client_receives_text_over_size_limit(self): @@ -256,14 +264,16 @@ def test_client_receives_text_without_size_limit(self): client = Connection(Side.CLIENT, max_size=None) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( - client, Frame(True, OP_TEXT, "😀".encode()), + client, + Frame(True, OP_TEXT, "😀".encode()), ) def test_server_receives_text_without_size_limit(self): server = Connection(Side.SERVER, max_size=None) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( - server, Frame(True, OP_TEXT, "😀".encode()), + server, + Frame(True, OP_TEXT, "😀".encode()), ) def test_client_sends_fragmented_text(self): @@ -293,37 +303,44 @@ def test_client_receives_fragmented_text(self): client = Connection(Side.CLIENT) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( - client, Frame(False, OP_TEXT, "😀".encode()[:2]), + client, + Frame(False, OP_TEXT, "😀".encode()[:2]), ) client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") self.assertFrameReceived( - client, Frame(False, OP_CONT, "😀😀".encode()[2:6]), + client, + Frame(False, OP_CONT, "😀😀".encode()[2:6]), ) client.receive_data(b"\x80\x02\x98\x80") self.assertFrameReceived( - client, Frame(True, OP_CONT, "😀".encode()[2:]), + client, + Frame(True, OP_CONT, "😀".encode()[2:]), ) def test_server_receives_fragmented_text(self): server = Connection(Side.SERVER) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( - server, Frame(False, OP_TEXT, "😀".encode()[:2]), + server, + Frame(False, OP_TEXT, "😀".encode()[:2]), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") self.assertFrameReceived( - server, Frame(False, OP_CONT, "😀😀".encode()[2:6]), + server, + Frame(False, OP_CONT, "😀😀".encode()[2:6]), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertFrameReceived( - server, Frame(True, OP_CONT, "😀".encode()[2:]), + server, + Frame(True, OP_CONT, "😀".encode()[2:]), ) def test_client_receives_fragmented_text_over_size_limit(self): client = Connection(Side.CLIENT, max_size=3) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( - client, Frame(False, OP_TEXT, "😀".encode()[:2]), + client, + Frame(False, OP_TEXT, "😀".encode()[:2]), ) with self.assertRaises(PayloadTooBig) as raised: client.receive_data(b"\x80\x02\x98\x80") @@ -334,7 +351,8 @@ def test_server_receives_fragmented_text_over_size_limit(self): server = Connection(Side.SERVER, max_size=3) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( - server, Frame(False, OP_TEXT, "😀".encode()[:2]), + server, + Frame(False, OP_TEXT, "😀".encode()[:2]), ) with self.assertRaises(PayloadTooBig) as raised: server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") @@ -345,30 +363,36 @@ def test_client_receives_fragmented_text_without_size_limit(self): client = Connection(Side.CLIENT, max_size=None) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( - client, Frame(False, OP_TEXT, "😀".encode()[:2]), + client, + Frame(False, OP_TEXT, "😀".encode()[:2]), ) client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") self.assertFrameReceived( - client, Frame(False, OP_CONT, "😀😀".encode()[2:6]), + client, + Frame(False, OP_CONT, "😀😀".encode()[2:6]), ) client.receive_data(b"\x80\x02\x98\x80") self.assertFrameReceived( - client, Frame(True, OP_CONT, "😀".encode()[2:]), + client, + Frame(True, OP_CONT, "😀".encode()[2:]), ) def test_server_receives_fragmented_text_without_size_limit(self): server = Connection(Side.SERVER, max_size=None) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( - server, Frame(False, OP_TEXT, "😀".encode()[:2]), + server, + Frame(False, OP_TEXT, "😀".encode()[:2]), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") self.assertFrameReceived( - server, Frame(False, OP_CONT, "😀😀".encode()[2:6]), + server, + Frame(False, OP_CONT, "😀😀".encode()[2:6]), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertFrameReceived( - server, Frame(True, OP_CONT, "😀".encode()[2:]), + server, + Frame(True, OP_CONT, "😀".encode()[2:]), ) def test_client_sends_unexpected_text(self): @@ -389,7 +413,8 @@ def test_client_receives_unexpected_text(self): client = Connection(Side.CLIENT) client.receive_data(b"\x01\x00") self.assertFrameReceived( - client, Frame(False, OP_TEXT, b""), + client, + Frame(False, OP_TEXT, b""), ) with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\x01\x00") @@ -400,7 +425,8 @@ def test_server_receives_unexpected_text(self): server = Connection(Side.SERVER) server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertFrameReceived( - server, Frame(False, OP_TEXT, b""), + server, + Frame(False, OP_TEXT, b""), ) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\x01\x80\x00\x00\x00\x00") @@ -462,14 +488,16 @@ def test_client_receives_binary(self): client = Connection(Side.CLIENT) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertFrameReceived( - client, Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), + client, + Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), ) def test_server_receives_binary(self): server = Connection(Side.SERVER) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertFrameReceived( - server, Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), + server, + Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), ) def test_client_receives_binary_over_size_limit(self): @@ -513,37 +541,44 @@ def test_client_receives_fragmented_binary(self): client = Connection(Side.CLIENT) client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( - client, Frame(False, OP_BINARY, b"\x01\x02"), + client, + Frame(False, OP_BINARY, b"\x01\x02"), ) client.receive_data(b"\x00\x04\xfe\xff\x01\x02") self.assertFrameReceived( - client, Frame(False, OP_CONT, b"\xfe\xff\x01\x02"), + client, + Frame(False, OP_CONT, b"\xfe\xff\x01\x02"), ) client.receive_data(b"\x80\x02\xfe\xff") self.assertFrameReceived( - client, Frame(True, OP_CONT, b"\xfe\xff"), + client, + Frame(True, OP_CONT, b"\xfe\xff"), ) def test_server_receives_fragmented_binary(self): server = Connection(Side.SERVER) server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( - server, Frame(False, OP_BINARY, b"\x01\x02"), + server, + Frame(False, OP_BINARY, b"\x01\x02"), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") self.assertFrameReceived( - server, Frame(False, OP_CONT, b"\xee\xff\x01\x02"), + server, + Frame(False, OP_CONT, b"\xee\xff\x01\x02"), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertFrameReceived( - server, Frame(True, OP_CONT, b"\xfe\xff"), + server, + Frame(True, OP_CONT, b"\xfe\xff"), ) def test_client_receives_fragmented_binary_over_size_limit(self): client = Connection(Side.CLIENT, max_size=3) client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( - client, Frame(False, OP_BINARY, b"\x01\x02"), + client, + Frame(False, OP_BINARY, b"\x01\x02"), ) with self.assertRaises(PayloadTooBig) as raised: client.receive_data(b"\x80\x02\xfe\xff") @@ -554,7 +589,8 @@ def test_server_receives_fragmented_binary_over_size_limit(self): server = Connection(Side.SERVER, max_size=3) server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( - server, Frame(False, OP_BINARY, b"\x01\x02"), + server, + Frame(False, OP_BINARY, b"\x01\x02"), ) with self.assertRaises(PayloadTooBig) as raised: server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") @@ -579,7 +615,8 @@ def test_client_receives_unexpected_binary(self): client = Connection(Side.CLIENT) client.receive_data(b"\x02\x00") self.assertFrameReceived( - client, Frame(False, OP_BINARY, b""), + client, + Frame(False, OP_BINARY, b""), ) with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\x02\x00") @@ -590,7 +627,8 @@ def test_server_receives_unexpected_binary(self): server = Connection(Side.SERVER) server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertFrameReceived( - server, Frame(False, OP_BINARY, b""), + server, + Frame(False, OP_BINARY, b""), ) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\x02\x80\x00\x00\x00\x00") @@ -843,20 +881,24 @@ def test_client_receives_ping(self): client = Connection(Side.CLIENT) client.receive_data(b"\x89\x00") self.assertFrameReceived( - client, Frame(True, OP_PING, b""), + client, + Frame(True, OP_PING, b""), ) self.assertFrameSent( - client, Frame(True, OP_PONG, b""), + client, + Frame(True, OP_PONG, b""), ) def test_server_receives_ping(self): server = Connection(Side.SERVER) server.receive_data(b"\x89\x80\x00\x44\x88\xcc") self.assertFrameReceived( - server, Frame(True, OP_PING, b""), + server, + Frame(True, OP_PING, b""), ) self.assertFrameSent( - server, Frame(True, OP_PONG, b""), + server, + Frame(True, OP_PONG, b""), ) def test_client_sends_ping_with_data(self): @@ -876,20 +918,24 @@ def test_client_receives_ping_with_data(self): client = Connection(Side.CLIENT) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( - client, Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + client, + Frame(True, OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent( - client, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + client, + Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_ping_with_data(self): server = Connection(Side.SERVER) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( - server, Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + server, + Frame(True, OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent( - server, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + server, + Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), ) def test_client_sends_fragmented_ping_frame(self): @@ -953,7 +999,8 @@ def test_client_receives_ping_after_receiving_close(self): self.assertConnectionClosing(client, 1000) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( - client, Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + client, + Frame(True, OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent(client, None) @@ -963,7 +1010,8 @@ def test_server_receives_ping_after_receiving_close(self): self.assertConnectionClosing(server, 1001) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( - server, Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + server, + Frame(True, OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent(server, None) @@ -989,14 +1037,16 @@ def test_client_receives_pong(self): client = Connection(Side.CLIENT) client.receive_data(b"\x8a\x00") self.assertFrameReceived( - client, Frame(True, OP_PONG, b""), + client, + Frame(True, OP_PONG, b""), ) def test_server_receives_pong(self): server = Connection(Side.SERVER) server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") self.assertFrameReceived( - server, Frame(True, OP_PONG, b""), + server, + Frame(True, OP_PONG, b""), ) def test_client_sends_pong_with_data(self): @@ -1016,14 +1066,16 @@ def test_client_receives_pong_with_data(self): client = Connection(Side.CLIENT) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( - client, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + client, + Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_pong_with_data(self): server = Connection(Side.SERVER) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( - server, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + server, + Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), ) def test_client_sends_fragmented_pong_frame(self): @@ -1077,7 +1129,8 @@ def test_client_receives_pong_after_receiving_close(self): self.assertConnectionClosing(client, 1000) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( - client, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + client, + Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_pong_after_receiving_close(self): @@ -1086,7 +1139,8 @@ def test_server_receives_pong_after_receiving_close(self): self.assertConnectionClosing(server, 1001) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( - server, Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + server, + Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), ) @@ -1128,52 +1182,64 @@ def test_client_receive_ping_pong_in_fragmented_message(self): client = Connection(Side.CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( - client, Frame(False, OP_TEXT, b"Spam"), + client, + Frame(False, OP_TEXT, b"Spam"), ) client.receive_data(b"\x89\x04Ping") self.assertFrameReceived( - client, Frame(True, OP_PING, b"Ping"), + client, + Frame(True, OP_PING, b"Ping"), ) self.assertFrameSent( - client, Frame(True, OP_PONG, b"Ping"), + client, + Frame(True, OP_PONG, b"Ping"), ) client.receive_data(b"\x00\x03Ham") self.assertFrameReceived( - client, Frame(False, OP_CONT, b"Ham"), + client, + Frame(False, OP_CONT, b"Ham"), ) client.receive_data(b"\x8a\x04Pong") self.assertFrameReceived( - client, Frame(True, OP_PONG, b"Pong"), + client, + Frame(True, OP_PONG, b"Pong"), ) client.receive_data(b"\x80\x04Eggs") self.assertFrameReceived( - client, Frame(True, OP_CONT, b"Eggs"), + client, + Frame(True, OP_CONT, b"Eggs"), ) def test_server_receive_ping_pong_in_fragmented_message(self): server = Connection(Side.SERVER) server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( - server, Frame(False, OP_TEXT, b"Spam"), + server, + Frame(False, OP_TEXT, b"Spam"), ) server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") self.assertFrameReceived( - server, Frame(True, OP_PING, b"Ping"), + server, + Frame(True, OP_PING, b"Ping"), ) self.assertFrameSent( - server, Frame(True, OP_PONG, b"Ping"), + server, + Frame(True, OP_PONG, b"Ping"), ) server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") self.assertFrameReceived( - server, Frame(False, OP_CONT, b"Ham"), + server, + Frame(False, OP_CONT, b"Ham"), ) server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") self.assertFrameReceived( - server, Frame(True, OP_PONG, b"Pong"), + server, + Frame(True, OP_PONG, b"Pong"), ) server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") self.assertFrameReceived( - server, Frame(True, OP_CONT, b"Eggs"), + server, + Frame(True, OP_CONT, b"Eggs"), ) def test_client_send_close_in_fragmented_message(self): @@ -1205,7 +1271,8 @@ def test_client_receive_close_in_fragmented_message(self): client = Connection(Side.CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( - client, Frame(False, OP_TEXT, b"Spam"), + client, + Frame(False, OP_TEXT, b"Spam"), ) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the @@ -1220,7 +1287,8 @@ def test_server_receive_close_in_fragmented_message(self): server = Connection(Side.SERVER) server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( - server, Frame(False, OP_TEXT, b"Spam"), + server, + Frame(False, OP_TEXT, b"Spam"), ) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the diff --git a/tests/test_frames.py b/tests/test_frames.py index 514fe7c54..4d10c6ef2 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -22,7 +22,7 @@ def parse(self, data, mask, max_size=None, extensions=None): reader.feed_data(data) reader.feed_eof() parser = Frame.parse( - reader.read_exact, mask=mask, max_size=max_size, extensions=extensions, + reader.read_exact, mask=mask, max_size=max_size, extensions=extensions ) return self.assertGeneratorReturns(parser) @@ -47,7 +47,9 @@ def assertFrameData(self, frame, data, mask, extensions=None): class FrameTests(FramesTestCase): def test_text_unmasked(self): self.assertFrameData( - Frame(True, OP_TEXT, b"Spam"), b"\x81\x04Spam", mask=False, + Frame(True, OP_TEXT, b"Spam"), + b"\x81\x04Spam", + mask=False, ) def test_text_masked(self): @@ -59,7 +61,9 @@ def test_text_masked(self): def test_binary_unmasked(self): self.assertFrameData( - Frame(True, OP_BINARY, b"Eggs"), b"\x82\x04Eggs", mask=False, + Frame(True, OP_BINARY, b"Eggs"), + b"\x82\x04Eggs", + mask=False, ) def test_binary_masked(self): @@ -84,13 +88,25 @@ def test_non_ascii_text_masked(self): ) def test_close(self): - self.assertFrameData(Frame(True, OP_CLOSE, b""), b"\x88\x00", mask=False) + self.assertFrameData( + Frame(True, OP_CLOSE, b""), + b"\x88\x00", + mask=False, + ) def test_ping(self): - self.assertFrameData(Frame(True, OP_PING, b"ping"), b"\x89\x04ping", mask=False) + self.assertFrameData( + Frame(True, OP_PING, b"ping"), + b"\x89\x04ping", + mask=False, + ) def test_pong(self): - self.assertFrameData(Frame(True, OP_PONG, b"pong"), b"\x8a\x04pong", mask=False) + self.assertFrameData( + Frame(True, OP_PONG, b"pong"), + b"\x8a\x04pong", + mask=False, + ) def test_long(self): self.assertFrameData( @@ -179,23 +195,34 @@ def decode(frame, *, max_size=None): class PrepareDataTests(unittest.TestCase): def test_prepare_data_str(self): - self.assertEqual(prepare_data("café"), (OP_TEXT, b"caf\xc3\xa9")) + self.assertEqual( + prepare_data("café"), + (OP_TEXT, b"caf\xc3\xa9"), + ) def test_prepare_data_bytes(self): - self.assertEqual(prepare_data(b"tea"), (OP_BINARY, b"tea")) + self.assertEqual( + prepare_data(b"tea"), + (OP_BINARY, b"tea"), + ) def test_prepare_data_bytearray(self): self.assertEqual( - prepare_data(bytearray(b"tea")), (OP_BINARY, bytearray(b"tea")) + prepare_data(bytearray(b"tea")), + (OP_BINARY, bytearray(b"tea")), ) def test_prepare_data_memoryview(self): self.assertEqual( - prepare_data(memoryview(b"tea")), (OP_BINARY, memoryview(b"tea")) + prepare_data(memoryview(b"tea")), + (OP_BINARY, memoryview(b"tea")), ) def test_prepare_data_non_contiguous_memoryview(self): - self.assertEqual(prepare_data(memoryview(b"tteeaa")[::2]), (OP_BINARY, b"tea")) + self.assertEqual( + prepare_data(memoryview(b"tteeaa")[::2]), + (OP_BINARY, b"tea"), + ) def test_prepare_data_list(self): with self.assertRaises(TypeError): diff --git a/tests/test_http11.py b/tests/test_http11.py index 9e4d70620..1cca2053f 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -37,32 +37,45 @@ def test_parse_empty(self): with self.assertRaises(EOFError) as raised: next(self.parse()) self.assertEqual( - str(raised.exception), "connection closed while reading HTTP request line" + str(raised.exception), + "connection closed while reading HTTP request line", ) def test_parse_invalid_request_line(self): self.reader.feed_data(b"GET /\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "invalid HTTP request line: GET /") + self.assertEqual( + str(raised.exception), + "invalid HTTP request line: GET /", + ) def test_parse_unsupported_method(self): self.reader.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "unsupported HTTP method: OPTIONS") + self.assertEqual( + str(raised.exception), + "unsupported HTTP method: OPTIONS", + ) def test_parse_unsupported_version(self): self.reader.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "unsupported HTTP version: HTTP/1.0") + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: HTTP/1.0", + ) def test_parse_invalid_header(self): self.reader.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "invalid HTTP header line: Oops") + self.assertEqual( + str(raised.exception), + "invalid HTTP header line: Oops", + ) def test_serialize(self): # Example from the protocol overview in RFC 6455 @@ -101,7 +114,7 @@ def setUp(self): def parse(self): return Response.parse( - self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, + self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof ) def test_parse(self): @@ -132,37 +145,55 @@ def test_parse_invalid_status_line(self): self.reader.feed_data(b"Hello!\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "invalid HTTP status line: Hello!") + self.assertEqual( + str(raised.exception), + "invalid HTTP status line: Hello!", + ) def test_parse_unsupported_version(self): self.reader.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "unsupported HTTP version: HTTP/1.0") + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: HTTP/1.0", + ) def test_parse_invalid_status(self): self.reader.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "invalid HTTP status code: OMG") + self.assertEqual( + str(raised.exception), + "invalid HTTP status code: OMG", + ) def test_parse_unsupported_status(self): self.reader.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "unsupported HTTP status code: 007") + self.assertEqual( + str(raised.exception), + "unsupported HTTP status code: 007", + ) def test_parse_invalid_reason(self): self.reader.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "invalid HTTP reason phrase: \x7f") + self.assertEqual( + str(raised.exception), + "invalid HTTP reason phrase: \x7f", + ) def test_parse_invalid_header(self): self.reader.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "invalid HTTP header line: Oops") + self.assertEqual( + str(raised.exception), + "invalid HTTP header line: Oops", + ) def test_parse_body_with_content_length(self): self.reader.feed_data( @@ -183,7 +214,10 @@ def test_parse_body_with_transfer_encoding(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: next(self.parse()) - self.assertEqual(str(raised.exception), "transfer codings aren't supported") + self.assertEqual( + str(raised.exception), + "transfer codings aren't supported", + ) def test_parse_body_no_content(self): self.reader.feed_data(b"HTTP/1.1 204 No Content\r\n\r\n") From 5bce4c1c5e59c8c3f5ec45de1c94f9047126b885 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 17:44:11 +0100 Subject: [PATCH 0721/1539] Support IRIs in addition to URIs. Fix #832. --- docs/changelog.rst | 2 ++ src/websockets/uri.py | 18 ++++++++++++++++++ tests/test_uri.py | 5 +++++ 3 files changed, 25 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 68ec6f80c..4c0eb7d2c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -24,6 +24,8 @@ Changelog Aliases provide backwards compatibility for all previously public APIs. +* Added support for IRIs in addition to URIs. + 8.1 ... diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 6669e5668..ce21b445b 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -49,6 +49,10 @@ class WebSocketURI(NamedTuple): WebSocketURI.user_info.__doc__ = "" +# All characters from the gen-delims and sub-delims sets in RFC 3987. +DELIMS = ":/?#[]@!$&'()*+,;=" + + def parse_uri(uri: str) -> WebSocketURI: """ Parse and validate a WebSocket URI. @@ -78,4 +82,18 @@ def parse_uri(uri: str) -> WebSocketURI: if parsed.password is None: raise InvalidURI(uri) user_info = (parsed.username, parsed.password) + + try: + uri.encode("ascii") + except UnicodeEncodeError: + # Input contains non-ASCII characters. + # It must be an IRI. Convert it to a URI. + host = host.encode("idna").decode() + resource_name = urllib.parse.quote(resource_name, safe=DELIMS) + if user_info is not None: + user_info = ( + urllib.parse.quote(user_info[0], safe=DELIMS), + urllib.parse.quote(user_info[1], safe=DELIMS), + ) + return WebSocketURI(secure, host, port, resource_name, user_info) diff --git a/tests/test_uri.py b/tests/test_uri.py index e41860b8e..9eeb8431d 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -10,6 +10,11 @@ ("ws://localhost/path?query", (False, "localhost", 80, "/path?query", None)), ("WS://LOCALHOST/PATH?QUERY", (False, "localhost", 80, "/PATH?QUERY", None)), ("ws://user:pass@localhost/", (False, "localhost", 80, "/", ("user", "pass"))), + ("ws://høst/", (False, "xn--hst-0na", 80, "/", None)), + ( + "ws://üser:påss@høst/πass", + (False, "xn--hst-0na", 80, "/%CF%80ass", ("%C3%BCser", "p%C3%A5ss")), + ), ] INVALID_URIS = [ From 72d32619650eace78a4d7e797de9369fbee10ada Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 17:54:07 +0100 Subject: [PATCH 0722/1539] Improve detection of broken connections. Refs #810. --- src/websockets/protocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 92ce8e305..39b578aba 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -877,10 +877,11 @@ async def transfer_data(self) -> None: self.transfer_data_exc = exc self.fail_connection(1002) - except (ConnectionError, EOFError) as exc: + except (ConnectionError, TimeoutError, EOFError) as exc: # Reading data with self.reader.readexactly may raise: # - most subclasses of ConnectionError if the TCP connection # breaks, is reset, or is aborted; + # - TimeoutError if the TCP connection times out; # - IncompleteReadError, a subclass of EOFError, if fewer # bytes are available than requested. self.transfer_data_exc = exc From 8061b03b803fb1ce2c7dfcf7bf3cd48f41d34b83 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 18:18:54 +0100 Subject: [PATCH 0723/1539] Remove loop argument to asyncio.Queue. Prepare compatibility with Python 3.10. Fix #801. --- src/websockets/__main__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 1a720498d..5013ca04f 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -176,8 +176,13 @@ def main() -> None: # Create an event loop that will run in a background thread. loop = asyncio.new_event_loop() + # Due to zealous removal of the loop parameter in the Queue constructor, + # we need a factory coroutine to run in the freshly created event loop. + async def queue_factory() -> asyncio.Queue[str]: + return asyncio.Queue() + # Create a queue of user inputs. There's no need to limit its size. - inputs: asyncio.Queue[str] = asyncio.Queue(loop=loop) + inputs: asyncio.Queue[str] = loop.run_until_complete(queue_factory()) # Create a stop condition when receiving SIGINT or SIGTERM. stop: asyncio.Future[None] = loop.create_future() From 867a00e5bafa1c8ad412eef06a5b09bac40694dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 18:32:38 +0100 Subject: [PATCH 0724/1539] Eliminate ResourceWarning. --- src/websockets/__main__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 5013ca04f..bce3e4bbb 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -206,6 +206,10 @@ async def queue_factory() -> asyncio.Queue[str]: # Wait for the event loop to terminate. thread.join() + # For reasons unclear, even though the loop is closed in the thread, + # it still thinks it's running here. + loop.close() + if __name__ == "__main__": main() From 32c9036ac5eee02e5167f93474b22e9cddbc78bd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 18:33:03 +0100 Subject: [PATCH 0725/1539] Mask expected deprecation warning. --- src/websockets/protocol.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 39b578aba..677d50f2c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -54,7 +54,14 @@ prepare_data, serialize_close, ) -from .framing import Frame + + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "websockets.framing is deprecated", DeprecationWarning + ) + from .framing import Frame + from .typing import Data, Subprotocol From 07775cfaa07b2fb2e31622af03a4fa62820482fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 18:56:03 +0100 Subject: [PATCH 0726/1539] Mark code for removal. Refs #803. --- src/websockets/asyncio_client.py | 2 ++ src/websockets/asyncio_server.py | 2 ++ src/websockets/protocol.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/websockets/asyncio_client.py b/src/websockets/asyncio_client.py index efa29b69a..43e3c1cd2 100644 --- a/src/websockets/asyncio_client.py +++ b/src/websockets/asyncio_client.py @@ -101,6 +101,8 @@ async def read_http_response(self) -> Tuple[int, Headers]: """ try: status_code, reason, headers = await read_response(self.reader) + # Remove this branch when dropping support for Python < 3.8 + # because CancelledError no longer inherits Exception. except asyncio.CancelledError: # pragma: no cover raise except Exception as exc: diff --git a/src/websockets/asyncio_server.py b/src/websockets/asyncio_server.py index fe61c7ddc..b4f7fbc92 100644 --- a/src/websockets/asyncio_server.py +++ b/src/websockets/asyncio_server.py @@ -135,6 +135,8 @@ async def handler(self) -> None: available_subprotocols=self.available_subprotocols, extra_headers=self.extra_headers, ) + # Remove this branch when dropping support for Python < 3.8 + # because CancelledError no longer inherits Exception. except asyncio.CancelledError: # pragma: no cover raise except ConnectionError: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 677d50f2c..ba4fc1d3c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -1169,6 +1169,8 @@ async def keepalive_ping(self) -> None: self.fail_connection(1011) break + # Remove this branch when dropping support for Python < 3.8 + # because CancelledError no longer inherits Exception. except asyncio.CancelledError: raise From a58540d681fc858fc43fcfaf7a6be33f177446a7 Mon Sep 17 00:00:00 2001 From: konichuvak Date: Thu, 27 Aug 2020 16:26:46 -0400 Subject: [PATCH 0727/1539] Adds 1012-1014 close codes. Also replac. `list` with a `set` for faster close code lookups. --- src/websockets/exceptions.py | 4 ++++ src/websockets/frames.py | 17 +++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 84c27692c..bdadae05e 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -68,6 +68,7 @@ class WebSocketException(Exception): """ +# See https://www.iana.org/assignments/websocket/websocket.xhtml CLOSE_CODES = { 1000: "OK", 1001: "going away", @@ -81,6 +82,9 @@ class WebSocketException(Exception): 1009: "message too big", 1010: "extension required", 1011: "unexpected error", + 1012: "service restart", + 1013: "try again later", + 1014: "bad gateway", 1015: "TLS failure [internal]", } diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 2ff9dbd91..74223c0e8 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -53,8 +53,21 @@ class Opcode(enum.IntEnum): CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG # Close code that are allowed in a close frame. -# Using a list optimizes `code in EXTERNAL_CLOSE_CODES`. -EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011] +# Using a set optimizes `code in EXTERNAL_CLOSE_CODES`. +EXTERNAL_CLOSE_CODES = { + 1000, + 1001, + 1002, + 1003, + 1007, + 1008, + 1009, + 1010, + 1011, + 1012, + 1013, + 1014, +} # Consider converting to a dataclass when dropping support for Python < 3.7. From 189671d990a3ecf2d8bf5c7e0c4d97abc9167c20 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 19:06:18 +0100 Subject: [PATCH 0728/1539] Add changelog for previous commit. --- docs/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 4c0eb7d2c..c131f0528 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -31,6 +31,8 @@ Changelog * Added compatibility with Python 3.8. +* Added close codes 1012, 1013, and 1014. + 8.0.2 ..... From b39f62a066bde151b7551a0d445705481e247e9b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Nov 2020 20:19:47 +0100 Subject: [PATCH 0729/1539] Log exceptions consistently. This was the only use of the exception method (vs. exc_info=True). --- src/websockets/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index a98d0b1e7..4a75bede9 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -282,7 +282,7 @@ def step_parser(self) -> None: self.parser_exc = exc raise except Exception as exc: - logger.exception("unexpected exception in parser") + logger.error("unexpected exception in parser", exc_info=True) # Don't include exception details, which may be security-sensitive. self.fail_connection(1011) self.parser_exc = exc From 984da0efa69c0fe3518f1bb81d43775f5ef66902 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Nov 2020 13:21:31 +0100 Subject: [PATCH 0730/1539] Rename bytes_to_send to data_to_send. Since this function doesn't return bytes, but an iterable of bytes, the name was confusing. --- src/websockets/connection.py | 6 +-- tests/test_client.py | 2 +- tests/test_connection.py | 94 ++++++++++++++++++------------------ tests/test_server.py | 4 +- 4 files changed, 53 insertions(+), 53 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4a75bede9..aeb774f00 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -126,7 +126,7 @@ def receive_data(self, data: bytes) -> None: After calling this method: - - You must call :meth:`bytes_to_send` and send this data. + - You must call :meth:`data_to_send` and send this data. - You should call :meth:`events_received` and process these events. """ @@ -139,7 +139,7 @@ def receive_eof(self) -> None: After calling this method: - - You must call :meth:`bytes_to_send` and send this data. + - You must call :meth:`data_to_send` and send this data. - You shouldn't call :meth:`events_received` as it won't return any new events. @@ -228,7 +228,7 @@ def events_received(self) -> List[Event]: # Public API for getting outgoing data after receiving data or sending events. - def bytes_to_send(self) -> List[bytes]: + def data_to_send(self) -> List[bytes]: """ Return data to write to the connection. diff --git a/tests/test_client.py b/tests/test_client.py index 7a78ee09b..747594bf3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -27,7 +27,7 @@ def test_send_connect(self): self.assertIsInstance(request, Request) client.send_request(request) self.assertEqual( - client.bytes_to_send(), + client.data_to_send(), [ f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" diff --git a/tests/test_connection.py b/tests/test_connection.py index d47147d64..3e39a3f9e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -35,7 +35,7 @@ def assertFrameSent(self, connection, frame, eof=False): mask=connection.side is Side.CLIENT, extensions=connection.extensions, ) - for write in connection.bytes_to_send() + for write in connection.data_to_send() ] frames_expected = [] if frame is None else [frame] if eof: @@ -101,12 +101,12 @@ def test_client_sends_masked_frame(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\xff\x00\xff"): client.send_text(b"Spam", True) - self.assertEqual(client.bytes_to_send(), [self.masked_text_frame_data]) + self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) def test_server_sends_unmasked_frame(self): server = Connection(Side.SERVER) server.send_text(b"Spam", True) - self.assertEqual(server.bytes_to_send(), [self.unmasked_text_frame_date]) + self.assertEqual(server.data_to_send(), [self.unmasked_text_frame_date]) def test_client_receives_unmasked_frame(self): client = Connection(Side.CLIENT) @@ -178,7 +178,7 @@ def test_client_sends_continuation_after_sending_close(self): # this is the same test as test_client_sends_unexpected_continuation. with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) - self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(ProtocolError) as raised: client.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") @@ -189,7 +189,7 @@ def test_server_sends_continuation_after_sending_close(self): # this is the same test as test_server_sends_unexpected_continuation. server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") @@ -222,13 +222,13 @@ def test_client_sends_text(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_text("😀".encode()) self.assertEqual( - client.bytes_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] + client.data_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] ) def test_server_sends_text(self): server = Connection(Side.SERVER) server.send_text("😀".encode()) - self.assertEqual(server.bytes_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) + self.assertEqual(server.data_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) def test_client_receives_text(self): client = Connection(Side.CLIENT) @@ -280,24 +280,24 @@ def test_client_sends_fragmented_text(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_text("😀".encode()[:2], fin=False) - self.assertEqual(client.bytes_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) + self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_continuation("😀😀".encode()[2:6], fin=False) self.assertEqual( - client.bytes_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] + client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] ) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_continuation("😀".encode()[2:], fin=True) - self.assertEqual(client.bytes_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) + self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) def test_server_sends_fragmented_text(self): server = Connection(Side.SERVER) server.send_text("😀".encode()[:2], fin=False) - self.assertEqual(server.bytes_to_send(), [b"\x01\x02\xf0\x9f"]) + self.assertEqual(server.data_to_send(), [b"\x01\x02\xf0\x9f"]) server.send_continuation("😀😀".encode()[2:6], fin=False) - self.assertEqual(server.bytes_to_send(), [b"\x00\x04\x98\x80\xf0\x9f"]) + self.assertEqual(server.data_to_send(), [b"\x00\x04\x98\x80\xf0\x9f"]) server.send_continuation("😀".encode()[2:], fin=True) - self.assertEqual(server.bytes_to_send(), [b"\x80\x02\x98\x80"]) + self.assertEqual(server.data_to_send(), [b"\x80\x02\x98\x80"]) def test_client_receives_fragmented_text(self): client = Connection(Side.CLIENT) @@ -437,14 +437,14 @@ def test_client_sends_text_after_sending_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) - self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState): client.send_text(b"") def test_server_sends_text_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) with self.assertRaises(InvalidState): server.send_text(b"") @@ -476,13 +476,13 @@ def test_client_sends_binary(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02\xfe\xff") self.assertEqual( - client.bytes_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] + client.data_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] ) def test_server_sends_binary(self): server = Connection(Side.SERVER) server.send_binary(b"\x01\x02\xfe\xff") - self.assertEqual(server.bytes_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) + self.assertEqual(server.data_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) def test_client_receives_binary(self): client = Connection(Side.CLIENT) @@ -518,24 +518,24 @@ def test_client_sends_fragmented_binary(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02", fin=False) - self.assertEqual(client.bytes_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) + self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_continuation(b"\xee\xff\x01\x02", fin=False) self.assertEqual( - client.bytes_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] + client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] ) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_continuation(b"\xee\xff", fin=True) - self.assertEqual(client.bytes_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) + self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) def test_server_sends_fragmented_binary(self): server = Connection(Side.SERVER) server.send_binary(b"\x01\x02", fin=False) - self.assertEqual(server.bytes_to_send(), [b"\x02\x02\x01\x02"]) + self.assertEqual(server.data_to_send(), [b"\x02\x02\x01\x02"]) server.send_continuation(b"\xee\xff\x01\x02", fin=False) - self.assertEqual(server.bytes_to_send(), [b"\x00\x04\xee\xff\x01\x02"]) + self.assertEqual(server.data_to_send(), [b"\x00\x04\xee\xff\x01\x02"]) server.send_continuation(b"\xee\xff", fin=True) - self.assertEqual(server.bytes_to_send(), [b"\x80\x02\xee\xff"]) + self.assertEqual(server.data_to_send(), [b"\x80\x02\xee\xff"]) def test_client_receives_fragmented_binary(self): client = Connection(Side.CLIENT) @@ -639,14 +639,14 @@ def test_client_sends_binary_after_sending_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) - self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState): client.send_binary(b"") def test_server_sends_binary_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) with self.assertRaises(InvalidState): server.send_binary(b"") @@ -677,13 +677,13 @@ def test_client_sends_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): client.send_close() - self.assertEqual(client.bytes_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, State.CLOSING) def test_server_sends_close(self): server = Connection(Side.SERVER) server.send_close() - self.assertEqual(server.bytes_to_send(), [b"\x88\x00", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close(self): @@ -691,14 +691,14 @@ def test_client_receives_close(self): with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): client.receive_data(b"\x88\x00") self.assertEqual(client.events_received(), [Frame(True, OP_CLOSE, b"")]) - self.assertEqual(client.bytes_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, State.CLOSING) def test_server_receives_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertEqual(server.events_received(), [Frame(True, OP_CLOSE, b"")]) - self.assertEqual(server.bytes_to_send(), [b"\x88\x00", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) self.assertIs(server.state, State.CLOSING) def test_client_sends_close_then_receives_close(self): @@ -761,13 +761,13 @@ def test_client_sends_close_with_code(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) - self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) self.assertIs(client.state, State.CLOSING) def test_server_sends_close_with_code(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close_with_code(self): @@ -787,14 +787,14 @@ def test_client_sends_close_with_code_and_reason(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001, "going away") self.assertEqual( - client.bytes_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] + client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] ) self.assertIs(client.state, State.CLOSING) def test_server_sends_close_with_code_and_reason(self): server = Connection(Side.SERVER) server.send_close(1000, "OK") - self.assertEqual(server.bytes_to_send(), [b"\x88\x04\x03\xe8OK", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK", b""]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close_with_code_and_reason(self): @@ -870,12 +870,12 @@ def test_client_sends_ping(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") - self.assertEqual(client.bytes_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) + self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping(self): server = Connection(Side.SERVER) server.send_ping(b"") - self.assertEqual(server.bytes_to_send(), [b"\x89\x00"]) + self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping(self): client = Connection(Side.CLIENT) @@ -906,13 +906,13 @@ def test_client_sends_ping_with_data(self): with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"\x22\x66\xaa\xee") self.assertEqual( - client.bytes_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] + client.data_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] ) def test_server_sends_ping_with_data(self): server = Connection(Side.SERVER) server.send_ping(b"\x22\x66\xaa\xee") - self.assertEqual(server.bytes_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) + self.assertEqual(server.data_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) def test_client_receives_ping_with_data(self): client = Connection(Side.CLIENT) @@ -970,7 +970,7 @@ def test_client_sends_ping_after_sending_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) - self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) # before the connection is closed" but websockets doesn't support # sending a Ping frame after a Close frame. @@ -983,7 +983,7 @@ def test_client_sends_ping_after_sending_close(self): def test_server_sends_ping_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) # The spec says: "An endpoint MAY send a Ping frame any time (...) # before the connection is closed" but websockets doesn't support # sending a Ping frame after a Close frame. @@ -1026,12 +1026,12 @@ def test_client_sends_pong(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"") - self.assertEqual(client.bytes_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) + self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong(self): server = Connection(Side.SERVER) server.send_pong(b"") - self.assertEqual(server.bytes_to_send(), [b"\x8a\x00"]) + self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong(self): client = Connection(Side.CLIENT) @@ -1054,13 +1054,13 @@ def test_client_sends_pong_with_data(self): with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"\x22\x66\xaa\xee") self.assertEqual( - client.bytes_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] + client.data_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] ) def test_server_sends_pong_with_data(self): server = Connection(Side.SERVER) server.send_pong(b"\x22\x66\xaa\xee") - self.assertEqual(server.bytes_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) + self.assertEqual(server.data_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) def test_client_receives_pong_with_data(self): client = Connection(Side.CLIENT) @@ -1110,7 +1110,7 @@ def test_client_sends_pong_after_sending_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) - self.assertEqual(client.bytes_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) # websockets doesn't support sending a Pong frame after a Close frame. with self.assertRaises(InvalidState): client.send_pong(b"") @@ -1118,7 +1118,7 @@ def test_client_sends_pong_after_sending_close(self): def test_server_sends_pong_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.bytes_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) # websockets doesn't support sending a Pong frame after a Close frame. with self.assertRaises(InvalidState): server.send_pong(b"") @@ -1465,13 +1465,13 @@ def test_client_extension_encodes_frame(self): client.extensions = [Rsv2Extension()] with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") - self.assertEqual(client.bytes_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) + self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) def test_server_extension_encodes_frame(self): server = Connection(Side.SERVER) server.extensions = [Rsv2Extension()] server.send_ping(b"") - self.assertEqual(server.bytes_to_send(), [b"\xa9\x00"]) + self.assertEqual(server.data_to_send(), [b"\xa9\x00"]) def test_client_extension_decodes_frame(self): client = Connection(Side.CLIENT) diff --git a/tests/test_server.py b/tests/test_server.py index a180b08e2..ad56a37bc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -91,7 +91,7 @@ def test_send_accept(self): self.assertIsInstance(response, Response) server.send_response(response) self.assertEqual( - server.bytes_to_send(), + server.data_to_send(), [ f"HTTP/1.1 101 Switching Protocols\r\n" f"Upgrade: websocket\r\n" @@ -111,7 +111,7 @@ def test_send_reject(self): self.assertIsInstance(response, Response) server.send_response(response) self.assertEqual( - server.bytes_to_send(), + server.data_to_send(), [ f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" From 44a5453612e7020d1305355c74c3d08ee4db4e91 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Nov 2020 15:16:46 +0100 Subject: [PATCH 0731/1539] Extract logic for auto-configuring compression. --- src/websockets/asyncio_client.py | 12 +-- src/websockets/asyncio_server.py | 10 +- .../extensions/permessage_deflate.py | 45 +++++++++ tests/extensions/test_base.py | 36 +++++++ tests/extensions/test_permessage_deflate.py | 94 +++++++++++++++++++ tests/test_asyncio_client_server.py | 66 +------------ 6 files changed, 184 insertions(+), 79 deletions(-) diff --git a/src/websockets/asyncio_client.py b/src/websockets/asyncio_client.py index 43e3c1cd2..d22ba764a 100644 --- a/src/websockets/asyncio_client.py +++ b/src/websockets/asyncio_client.py @@ -22,7 +22,7 @@ SecurityError, ) from .extensions.base import ClientExtensionFactory, Extension -from .extensions.permessage_deflate import ClientPerMessageDeflateFactory +from .extensions.permessage_deflate import enable_client_permessage_deflate from .handshake_legacy import build_request, check_response from .headers import ( build_authorization_basic, @@ -425,15 +425,7 @@ def __init__( ) if compression == "deflate": - if extensions is None: - extensions = [] - if not any( - extension_factory.name == ClientPerMessageDeflateFactory.name - for extension_factory in extensions - ): - extensions = list(extensions) + [ - ClientPerMessageDeflateFactory(client_max_window_bits=True) - ] + extensions = enable_client_permessage_deflate(extensions) elif compression is not None: raise ValueError(f"unsupported compression: {compression}") diff --git a/src/websockets/asyncio_server.py b/src/websockets/asyncio_server.py index b4f7fbc92..79ceddf4b 100644 --- a/src/websockets/asyncio_server.py +++ b/src/websockets/asyncio_server.py @@ -39,7 +39,7 @@ NegotiationError, ) from .extensions.base import Extension, ServerExtensionFactory -from .extensions.permessage_deflate import ServerPerMessageDeflateFactory +from .extensions.permessage_deflate import enable_server_permessage_deflate from .handshake_legacy import build_response, check_request from .headers import build_extension, parse_extension, parse_subprotocol from .http import USER_AGENT @@ -903,13 +903,7 @@ def __init__( secure = kwargs.get("ssl") is not None if compression == "deflate": - if extensions is None: - extensions = [] - if not any( - ext_factory.name == ServerPerMessageDeflateFactory.name - for ext_factory in extensions - ): - extensions = list(extensions) + [ServerPerMessageDeflateFactory()] + extensions = enable_server_permessage_deflate(extensions) elif compression is not None: raise ValueError(f"unsupported compression: {compression}") diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 184183061..9a3fc4ba5 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -22,7 +22,9 @@ __all__ = [ "PerMessageDeflate", "ClientPerMessageDeflateFactory", + "enable_client_permessage_deflate", "ServerPerMessageDeflateFactory", + "enable_server_permessage_deflate", ] _EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff" @@ -424,6 +426,29 @@ def process_response_params( ) +def enable_client_permessage_deflate( + extensions: Optional[Sequence[ClientExtensionFactory]], +) -> Sequence[ClientExtensionFactory]: + """ + Enable Per-Message Deflate with default settings in client extensions. + + If the extension is already present, perhaps with non-default settings, + the configuration isn't changed. + + + """ + if extensions is None: + extensions = [] + if not any( + extension_factory.name == ClientPerMessageDeflateFactory.name + for extension_factory in extensions + ): + extensions = list(extensions) + [ + ClientPerMessageDeflateFactory(client_max_window_bits=True) + ] + return extensions + + class ServerPerMessageDeflateFactory(ServerExtensionFactory): """ Server-side extension factory for the Per-Message Deflate extension. @@ -584,3 +609,23 @@ def process_request_params( self.compress_settings, ), ) + + +def enable_server_permessage_deflate( + extensions: Optional[Sequence[ServerExtensionFactory]], +) -> Sequence[ServerExtensionFactory]: + """ + Enable Per-Message Deflate with default settings in server extensions. + + If the extension is already present, perhaps with non-default settings, + the configuration isn't changed. + + """ + if extensions is None: + extensions = [] + if not any( + ext_factory.name == ServerPerMessageDeflateFactory.name + for ext_factory in extensions + ): + extensions = list(extensions) + [ServerPerMessageDeflateFactory()] + return extensions diff --git a/tests/extensions/test_base.py b/tests/extensions/test_base.py index ba8657b65..0daa34211 100644 --- a/tests/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,4 +1,40 @@ +from websockets.exceptions import NegotiationError from websockets.extensions.base import * # noqa # Abstract classes don't provide any behavior to test. + + +class ClientNoOpExtensionFactory: + name = "x-no-op" + + def get_request_params(self): + return [] + + def process_response_params(self, params, accepted_extensions): + if params: + raise NegotiationError() + return NoOpExtension() + + +class ServerNoOpExtensionFactory: + name = "x-no-op" + + def __init__(self, params=None): + self.params = params or [] + + def process_request_params(self, params, accepted_extensions): + return self.params, NoOpExtension() + + +class NoOpExtension: + name = "x-no-op" + + def __repr__(self): + return "NoOpExtension()" + + def decode(self, frame, *, max_size=None): + return frame + + def encode(self, frame): + return frame diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index f9fca1999..328861e58 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -20,6 +20,8 @@ serialize_close, ) +from .test_base import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory + class ExtensionTestsMixin: def assertExtensionEqual(self, extension1, extension2): @@ -500,6 +502,52 @@ def test_process_response_params_deduplication(self): [], [PerMessageDeflate(False, False, 15, 15)] ) + def test_enable_client_permessage_deflate(self): + for extensions, ( + expected_len, + expected_position, + expected_compress_settings, + ) in [ + ( + None, + (1, 0, None), + ), + ( + [], + (1, 0, None), + ), + ( + [ClientNoOpExtensionFactory()], + (2, 1, None), + ), + ( + [ClientPerMessageDeflateFactory(compress_settings={"level": 1})], + (1, 0, {"level": 1}), + ), + ( + [ + ClientPerMessageDeflateFactory(compress_settings={"level": 1}), + ClientNoOpExtensionFactory(), + ], + (2, 0, {"level": 1}), + ), + ( + [ + ClientNoOpExtensionFactory(), + ClientPerMessageDeflateFactory(compress_settings={"level": 1}), + ], + (2, 1, {"level": 1}), + ), + ]: + with self.subTest(extensions=extensions): + extensions = enable_client_permessage_deflate(extensions) + self.assertEqual(len(extensions), expected_len) + extension = extensions[expected_position] + self.assertIsInstance(extension, ClientPerMessageDeflateFactory) + self.assertEqual( + extension.compress_settings, expected_compress_settings + ) + class ServerPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): def test_name(self): @@ -790,3 +838,49 @@ def test_process_response_params_deduplication(self): factory.process_request_params( [], [PerMessageDeflate(False, False, 15, 15)] ) + + def test_enable_server_permessage_deflate(self): + for extensions, ( + expected_len, + expected_position, + expected_compress_settings, + ) in [ + ( + None, + (1, 0, None), + ), + ( + [], + (1, 0, None), + ), + ( + [ServerNoOpExtensionFactory()], + (2, 1, None), + ), + ( + [ServerPerMessageDeflateFactory(compress_settings={"level": 1})], + (1, 0, {"level": 1}), + ), + ( + [ + ServerPerMessageDeflateFactory(compress_settings={"level": 1}), + ServerNoOpExtensionFactory(), + ], + (2, 0, {"level": 1}), + ), + ( + [ + ServerNoOpExtensionFactory(), + ServerPerMessageDeflateFactory(compress_settings={"level": 1}), + ], + (2, 1, {"level": 1}), + ), + ]: + with self.subTest(extensions=extensions): + extensions = enable_server_permessage_deflate(extensions) + self.assertEqual(len(extensions), expected_len) + extension = extensions[expected_position] + self.assertIsInstance(extension, ServerPerMessageDeflateFactory) + self.assertEqual( + extension.compress_settings, expected_compress_settings + ) diff --git a/tests/test_asyncio_client_server.py b/tests/test_asyncio_client_server.py index cff76d1f2..76c29334e 100644 --- a/tests/test_asyncio_client_server.py +++ b/tests/test_asyncio_client_server.py @@ -34,6 +34,11 @@ from websockets.protocol import State from websockets.uri import parse_uri +from .extensions.test_base import ( + ClientNoOpExtensionFactory, + NoOpExtension, + ServerNoOpExtensionFactory, +) from .test_protocol import MS from .utils import AsyncioTestCase @@ -188,41 +193,6 @@ class BarClientProtocol(WebSocketClientProtocol): pass -class ClientNoOpExtensionFactory: - name = "x-no-op" - - def get_request_params(self): - return [] - - def process_response_params(self, params, accepted_extensions): - if params: - raise NegotiationError() - return NoOpExtension() - - -class ServerNoOpExtensionFactory: - name = "x-no-op" - - def __init__(self, params=None): - self.params = params or [] - - def process_request_params(self, params, accepted_extensions): - return self.params, NoOpExtension() - - -class NoOpExtension: - name = "x-no-op" - - def __repr__(self): - return "NoOpExtension()" - - def decode(self, frame, *, max_size=None): - return frame - - def encode(self, frame): - return frame - - class ClientServerTestsMixin: secure = False @@ -974,32 +944,6 @@ def test_compression_deflate(self): repr([PerMessageDeflate(False, False, 15, 15)]), ) - @with_server( - extensions=[ - ServerPerMessageDeflateFactory( - client_no_context_takeover=True, server_max_window_bits=10 - ) - ], - compression="deflate", # overridden by explicit config - ) - @with_client( - "/extensions", - extensions=[ - ClientPerMessageDeflateFactory( - server_no_context_takeover=True, client_max_window_bits=12 - ) - ], - compression="deflate", # overridden by explicit config - ) - def test_compression_deflate_and_explicit_config(self): - server_extensions = self.loop.run_until_complete(self.client.recv()) - self.assertEqual( - server_extensions, repr([PerMessageDeflate(True, True, 12, 10)]) - ) - self.assertEqual( - repr(self.client.extensions), repr([PerMessageDeflate(True, True, 10, 12)]) - ) - def test_compression_unsupported_server(self): with self.assertRaises(ValueError): self.start_server(compression="xz") From 3f36975b197f1250258055d403d2061f70013278 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Nov 2020 15:18:44 +0100 Subject: [PATCH 0732/1539] Name asyncio protocol consistently. This isn't comparable to ws_server on the server side. --- src/websockets/asyncio_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/asyncio_client.py b/src/websockets/asyncio_client.py index d22ba764a..3f406170a 100644 --- a/src/websockets/asyncio_client.py +++ b/src/websockets/asyncio_client.py @@ -517,7 +517,7 @@ async def __aexit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - await self.ws_client.close() + await self.protocol.close() # await connect(...) @@ -546,7 +546,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: await protocol.wait_closed() raise else: - self.ws_client = protocol + self.protocol = protocol return protocol except RedirectHandshake as exc: self.handle_redirect(exc.uri) From 32b95fb0dd2cfc07d38df45dcf7f0ebf05008424 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Nov 2020 15:19:11 +0100 Subject: [PATCH 0733/1539] Name pong waiter consistently. --- src/websockets/protocol.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index ba4fc1d3c..1552fb060 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -291,7 +291,7 @@ def __init__( # Protect sending fragmented messages. self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None - # Mapping of ping IDs to waiters, in chronological order. + # Mapping of ping IDs to pong waiters, in chronological order. self.pings: Dict[bytes, asyncio.Future[None]] = {} # Task running the data transfer. @@ -736,15 +736,15 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: """ Send a ping. - Return a :class:`~asyncio.Future` which will be completed when the - corresponding pong is received and which you may ignore if you don't - want to wait. + Return a :class:`~asyncio.Future` that will be completed when the + corresponding pong is received. You can ignore it if you don't intend + to wait. A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point:: pong_waiter = await ws.ping() - await pong_waiter # only if you want to wait for the pong + await pong_waiter # only if you want to wait for the pong By default, the ping contains four random bytes. This payload may be overridden with the optional ``data`` argument which must be a string @@ -1155,12 +1155,12 @@ async def keepalive_ping(self) -> None: # ping() raises ConnectionClosed if the connection is lost, # when connection_lost() calls abort_pings(). - ping_waiter = await self.ping() + pong_waiter = await self.ping() if self.ping_timeout is not None: try: await asyncio.wait_for( - ping_waiter, + pong_waiter, self.ping_timeout, loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) From 165d0c69548e4c9d02624bcbb6eb565bb4c0c136 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Nov 2020 19:15:43 +0100 Subject: [PATCH 0734/1539] Move asyncio-based APIs to a legacy subpackage. Clean up deprecations in the process. --- docs/api.rst | 8 +- docs/changelog.rst | 136 +- docs/cheatsheet.rst | 36 +- docs/deployment.rst | 6 +- docs/design.rst | 178 +- docs/extensions.rst | 5 +- docs/faq.rst | 8 +- docs/intro.rst | 2 +- setup.py | 2 +- src/websockets/__init__.py | 10 +- src/websockets/__main__.py | 2 +- src/websockets/auth.py | 163 +- src/websockets/client.py | 11 +- src/websockets/exceptions.py | 2 +- src/websockets/framing.py | 135 +- src/websockets/handshake.py | 8 +- src/websockets/http.py | 4 +- src/websockets/legacy/__init__.py | 0 src/websockets/legacy/auth.py | 165 ++ .../{asyncio_client.py => legacy/client.py} | 22 +- src/websockets/legacy/framing.py | 135 ++ .../handshake.py} | 12 +- .../{http_legacy.py => legacy/http.py} | 4 +- src/websockets/legacy/protocol.py | 1459 ++++++++++++++++ .../{asyncio_server.py => legacy/server.py} | 32 +- src/websockets/protocol.py | 1466 +--------------- src/websockets/server.py | 15 +- tests/__init__.py | 10 - tests/legacy/__init__.py | 0 tests/legacy/test_auth.py | 160 ++ .../test_client_server.py} | 39 +- tests/legacy/test_framing.py | 171 ++ .../test_handshake.py} | 2 +- .../test_http.py} | 4 +- tests/legacy/test_protocol.py | 1489 ++++++++++++++++ tests/legacy/utils.py | 93 + tests/test_auth.py | 162 +- tests/test_exports.py | 6 +- tests/test_framing.py | 174 +- tests/test_protocol.py | 1491 +---------------- tests/utils.py | 92 - 41 files changed, 3964 insertions(+), 3955 deletions(-) create mode 100644 src/websockets/legacy/__init__.py create mode 100644 src/websockets/legacy/auth.py rename src/websockets/{asyncio_client.py => legacy/client.py} (97%) create mode 100644 src/websockets/legacy/framing.py rename src/websockets/{handshake_legacy.py => legacy/handshake.py} (93%) rename src/websockets/{http_legacy.py => legacy/http.py} (98%) create mode 100644 src/websockets/legacy/protocol.py rename src/websockets/{asyncio_server.py => legacy/server.py} (97%) create mode 100644 tests/legacy/__init__.py create mode 100644 tests/legacy/test_auth.py rename tests/{test_asyncio_client_server.py => legacy/test_client_server.py} (97%) create mode 100644 tests/legacy/test_framing.py rename tests/{test_handshake_legacy.py => legacy/test_handshake.py} (99%) rename tests/{test_http_legacy.py => legacy/test_http.py} (98%) create mode 100644 tests/legacy/test_protocol.py create mode 100644 tests/legacy/utils.py diff --git a/docs/api.rst b/docs/api.rst index b4bddaf38..c73cf59d3 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -29,7 +29,7 @@ High-level Server ...... -.. automodule:: websockets.server +.. automodule:: websockets.legacy.server .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) :async: @@ -53,7 +53,7 @@ Server Client ...... -.. automodule:: websockets.client +.. automodule:: websockets.legacy.client .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) :async: @@ -68,7 +68,7 @@ Client Shared ...... -.. automodule:: websockets.protocol +.. automodule:: websockets.legacy.protocol .. autoclass:: WebSocketCommonProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) @@ -107,7 +107,7 @@ Per-Message Deflate Extension HTTP Basic Auth ............... -.. automodule:: websockets.auth +.. automodule:: websockets.legacy.auth .. autofunction:: basic_auth_protocol_factory diff --git a/docs/changelog.rst b/docs/changelog.rst index c131f0528..291ec6938 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,17 +10,23 @@ Changelog .. note:: - **Version 9.0 moves or deprecates several low-level APIs.** + **Version 9.0 moves or deprecates several APIs.** * Import :class:`~datastructures.Headers` and :exc:`~datastructures.MultipleValuesError` from :mod:`websockets.datastructures` instead of :mod:`websockets.http`. + * :mod:`websockets.client`, :mod:`websockets.server,` + :mod:`websockets.protocol`, and :mod:`websockets.auth` were moved to + :mod:`websockets.legacy.client`, :mod:`websockets.legacy.server`, + :mod:`websockets.legacy.protocol`, and :mod:`websockets.legacy.auth` + respectively. + * :mod:`websockets.handshake` is deprecated. * :mod:`websockets.http` is deprecated. - * :mod:`websocket.framing` is deprecated. + * :mod:`websockets.framing` is deprecated. Aliases provide backwards compatibility for all previously public APIs. @@ -37,7 +43,7 @@ Changelog ..... * Restored the ability to pass a socket with the ``sock`` parameter of - :func:`~server.serve`. + :func:`~legacy.server.serve`. * Removed an incorrect assertion when a connection drops. @@ -60,9 +66,9 @@ Changelog Previously, it could be a function or a coroutine. - If you're passing a ``process_request`` argument to :func:`~server.serve` - or :class:`~server.WebSocketServerProtocol`, or if you're overriding - :meth:`~protocol.WebSocketServerProtocol.process_request` in a subclass, + If you're passing a ``process_request`` argument to :func:`~legacy.server.serve` + or :class:`~legacy.server.WebSocketServerProtocol`, or if you're overriding + :meth:`~legacy.server.WebSocketServerProtocol.process_request` in a subclass, define it with ``async def`` instead of ``def``. For backwards compatibility, functions are still mostly supported, but @@ -78,10 +84,10 @@ Changelog .. note:: **Version 8.0 deprecates the** ``host`` **,** ``port`` **, and** ``secure`` - **attributes of** :class:`~protocol.WebSocketCommonProtocol`. + **attributes of** :class:`~legacy.protocol.WebSocketCommonProtocol`. - Use :attr:`~protocol.WebSocketCommonProtocol.local_address` in servers and - :attr:`~protocol.WebSocketCommonProtocol.remote_address` in clients + Use :attr:`~legacy.protocol.WebSocketCommonProtocol.local_address` in servers and + :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` in clients instead of ``host`` and ``port``. .. note:: @@ -98,9 +104,9 @@ Changelog Also: -* :meth:`~protocol.WebSocketCommonProtocol.send`, - :meth:`~protocol.WebSocketCommonProtocol.ping`, and - :meth:`~protocol.WebSocketCommonProtocol.pong` support bytes-like types +* :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, + :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support bytes-like types :class:`bytearray` and :class:`memoryview` in addition to :class:`bytes`. * Added :exc:`~exceptions.ConnectionClosedOK` and @@ -108,18 +114,18 @@ Also: :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth +* Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth on the server side. -* :func:`~client.connect` handles redirects from the server during the +* :func:`~legacy.client.connect` handles redirects from the server during the handshake. -* :func:`~client.connect` supports overriding ``host`` and ``port``. +* :func:`~legacy.client.connect` supports overriding ``host`` and ``port``. -* Added :func:`~client.unix_connect` for connecting to Unix sockets. +* Added :func:`~legacy.client.unix_connect` for connecting to Unix sockets. * Improved support for sending fragmented messages by accepting asynchronous - iterators in :meth:`~protocol.WebSocketCommonProtocol.send`. + iterators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. * Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` exceptions in keepalive ping task. If you were using ``ping_timeout=None`` @@ -150,7 +156,7 @@ Also: .. warning:: **Version 7.0 renames the** ``timeout`` **argument of** - :func:`~server.serve()` **and** :func:`~client.connect` **to** + :func:`~legacy.server.serve()` **and** :func:`~legacy.client.connect` **to** ``close_timeout`` **.** This prevents confusion with ``ping_timeout``. @@ -160,7 +166,7 @@ Also: .. warning:: **Version 7.0 changes how a server terminates connections when it's - closed with** :meth:`~server.WebSocketServer.close` **.** + closed with** :meth:`~legacy.server.WebSocketServer.close` **.** Previously, connections handlers were canceled. Now, connections are closed with close code 1001 (going away). From the perspective of the @@ -177,7 +183,7 @@ Also: .. note:: - **Version 7.0 changes how a** :meth:`~protocol.WebSocketCommonProtocol.ping` + **Version 7.0 changes how a** :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` **that hasn't received a pong yet behaves when the connection is closed.** The ping — as in ``ping = await websocket.ping()`` — used to be canceled @@ -188,7 +194,7 @@ Also: .. note:: **Version 7.0 raises a** :exc:`RuntimeError` **exception if two coroutines - call** :meth:`~protocol.WebSocketCommonProtocol.recv` **concurrently.** + call** :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` **concurrently.** Concurrent calls lead to non-deterministic behavior because there are no guarantees about which coroutine will receive which message. @@ -197,17 +203,17 @@ Also: * ``websockets`` sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame. See - :class:`~protocol.WebSocketCommonProtocol` for details. + :class:`~legacy.protocol.WebSocketCommonProtocol` for details. * Added ``process_request`` and ``select_subprotocol`` arguments to - :func:`~server.serve` and :class:`~server.WebSocketServerProtocol` to - customize :meth:`~server.WebSocketServerProtocol.process_request` and - :meth:`~server.WebSocketServerProtocol.select_subprotocol` without - subclassing :class:`~server.WebSocketServerProtocol`. + :func:`~legacy.server.serve` and :class:`~legacy.server.WebSocketServerProtocol` to + customize :meth:`~legacy.server.WebSocketServerProtocol.process_request` and + :meth:`~legacy.server.WebSocketServerProtocol.select_subprotocol` without + subclassing :class:`~legacy.server.WebSocketServerProtocol`. * Added support for sending fragmented messages. -* Added the :meth:`~protocol.WebSocketCommonProtocol.wait_closed` method to +* Added the :meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed` method to protocols. * Added an interactive client: ``python -m websockets ``. @@ -215,7 +221,7 @@ Also: * Changed the ``origins`` argument to represent the lack of an origin with ``None`` rather than ``''``. -* Fixed a data loss bug in :meth:`~protocol.WebSocketCommonProtocol.recv`: +* Fixed a data loss bug in :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: canceling it at the wrong time could result in messages being dropped. * Improved handling of multiple HTTP headers with the same name. @@ -230,18 +236,18 @@ Also: **Version 6.0 introduces the** :class:`~http.Headers` **class for managing HTTP headers and changes several public APIs:** - * :meth:`~server.WebSocketServerProtocol.process_request` now receives a + * :meth:`~legacy.server.WebSocketServerProtocol.process_request` now receives a :class:`~http.Headers` instead of a :class:`~http.client.HTTPMessage` in the ``request_headers`` argument. - * The :attr:`~protocol.WebSocketCommonProtocol.request_headers` and - :attr:`~protocol.WebSocketCommonProtocol.response_headers` attributes of - :class:`~protocol.WebSocketCommonProtocol` are :class:`~http.Headers` + * The :attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and + :attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` attributes of + :class:`~legacy.protocol.WebSocketCommonProtocol` are :class:`~http.Headers` instead of :class:`~http.client.HTTPMessage`. - * The :attr:`~protocol.WebSocketCommonProtocol.raw_request_headers` and - :attr:`~protocol.WebSocketCommonProtocol.raw_response_headers` - attributes of :class:`~protocol.WebSocketCommonProtocol` are removed. + * The :attr:`~legacy.protocol.WebSocketCommonProtocol.raw_request_headers` and + :attr:`~legacy.protocol.WebSocketCommonProtocol.raw_response_headers` + attributes of :class:`~legacy.protocol.WebSocketCommonProtocol` are removed. Use :meth:`~http.Headers.raw_items` instead. * Functions defined in the :mod:`~handshake` module now receive @@ -265,7 +271,7 @@ Also: ..... * Fixed a regression in the 5.0 release that broke some invocations of - :func:`~server.serve()` and :func:`~client.connect`. + :func:`~legacy.server.serve()` and :func:`~legacy.client.connect`. 5.0 ... @@ -290,7 +296,7 @@ Also: Also: -* :func:`~client.connect` performs HTTP Basic Auth when the URI contains +* :func:`~legacy.client.connect` performs HTTP Basic Auth when the URI contains credentials. * Iterating on incoming messages no longer raises an exception when the @@ -299,13 +305,13 @@ Also: * A plain HTTP request now receives a 426 Upgrade Required response and doesn't log a stack trace. -* :func:`~server.unix_serve` can be used as an asynchronous context manager on +* :func:`~legacy.server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. -* Added the :attr:`~protocol.WebSocketCommonProtocol.closed` property to +* Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property to protocols. -* If a :meth:`~protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, +* If a :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, it's canceled when the connection is closed. * Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. @@ -346,7 +352,7 @@ Also: Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~server.serve()` or :func:`~client.connect`. + :func:`~legacy.server.serve()` or :func:`~legacy.client.connect`. .. warning:: @@ -360,13 +366,13 @@ Also: Also: -* :class:`~protocol.WebSocketCommonProtocol` instances can be used as +* :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. -* Added :func:`~server.unix_serve` for listening on Unix sockets. +* Added :func:`~legacy.server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~server.WebSocketServer.sockets` attribute to the return - value of :func:`~server.serve`. +* Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the return + value of :func:`~legacy.server.serve`. * Reorganized and extended documentation. @@ -384,15 +390,15 @@ Also: 3.4 ... -* Renamed :func:`~server.serve()` and :func:`~client.connect`'s ``klass`` +* Renamed :func:`~legacy.server.serve()` and :func:`~legacy.client.connect`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. -* :func:`~server.serve` can be used as an asynchronous context manager on +* :func:`~legacy.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with - :meth:`~server.WebSocketServerProtocol.process_request`. + :meth:`~legacy.server.WebSocketServerProtocol.process_request`. * Made read and write buffer sizes configurable. @@ -400,10 +406,10 @@ Also: * Added an optional C extension to speed up low-level operations. -* An invalid response status code during :func:`~client.connect` now raises +* An invalid response status code during :func:`~legacy.client.connect` now raises :class:`~exceptions.InvalidStatusCode` with a ``code`` attribute. -* Providing a ``sock`` argument to :func:`~client.connect` no longer +* Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer crashes. 3.3 @@ -419,7 +425,7 @@ Also: ... * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~client.connect()` and :func:`~server.serve`. + :func:`~legacy.client.connect()` and :func:`~legacy.server.serve`. * Made server shutdown more robust. @@ -436,11 +442,11 @@ Also: .. warning:: **Version 3.0 introduces a backwards-incompatible change in the** - :meth:`~protocol.WebSocketCommonProtocol.recv` **API.** + :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` **API.** **If you're upgrading from 2.x or earlier, please read this carefully.** - :meth:`~protocol.WebSocketCommonProtocol.recv` used to return ``None`` + :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` used to return ``None`` when the connection was closed. This required checking the return value of every call:: @@ -459,20 +465,20 @@ Also: In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to - :func:`~server.serve`, :func:`~client.connect`, - :class:`~server.WebSocketServerProtocol`, or - :class:`~client.WebSocketClientProtocol`. ``legacy_recv`` isn't documented + :func:`~legacy.server.serve`, :func:`~legacy.client.connect`, + :class:`~legacy.server.WebSocketServerProtocol`, or + :class:`~legacy.client.WebSocketClientProtocol`. ``legacy_recv`` isn't documented in their signatures but isn't scheduled for deprecation either. Also: -* :func:`~client.connect` can be used as an asynchronous context manager on +* :func:`~legacy.client.connect` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Updated documentation with ``await`` and ``async`` syntax from Python 3.5. -* :meth:`~protocol.WebSocketCommonProtocol.ping` and - :meth:`~protocol.WebSocketCommonProtocol.pong` support data passed as +* :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` and + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support data passed as :class:`str` in addition to :class:`bytes`. * Worked around an asyncio bug affecting connection termination under load. @@ -511,7 +517,7 @@ Also: * Returned a 403 status code instead of 400 when the request Origin isn't allowed. -* Canceling :meth:`~protocol.WebSocketCommonProtocol.recv` no longer drops +* Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. * Clarified that the closing handshake can be initiated by the client. @@ -529,8 +535,8 @@ Also: * Supported non-default event loop. -* Added ``loop`` argument to :func:`~client.connect` and - :func:`~server.serve`. +* Added ``loop`` argument to :func:`~legacy.client.connect` and + :func:`~legacy.server.serve`. 2.3 ... @@ -557,9 +563,9 @@ Also: .. warning:: **Version 2.0 introduces a backwards-incompatible change in the** - :meth:`~protocol.WebSocketCommonProtocol.send`, - :meth:`~protocol.WebSocketCommonProtocol.ping`, and - :meth:`~protocol.WebSocketCommonProtocol.pong` **APIs.** + :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, + :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` **APIs.** **If you're upgrading from 1.x or earlier, please read this carefully.** diff --git a/docs/cheatsheet.rst b/docs/cheatsheet.rst index 4b95c9eea..a71f08d74 100644 --- a/docs/cheatsheet.rst +++ b/docs/cheatsheet.rst @@ -9,24 +9,24 @@ Server * Write a coroutine that handles a single connection. It receives a WebSocket protocol instance and the URI path in argument. - * Call :meth:`~protocol.WebSocketCommonProtocol.recv` and - :meth:`~protocol.WebSocketCommonProtocol.send` to receive and send + * Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and + :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to receive and send messages at any time. - * When :meth:`~protocol.WebSocketCommonProtocol.recv` or - :meth:`~protocol.WebSocketCommonProtocol.send` raises + * When :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` or + :meth:`~legacy.protocol.WebSocketCommonProtocol.send` raises :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started other :class:`asyncio.Task`, terminate them before exiting. - * If you aren't awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, - consider awaiting :meth:`~protocol.WebSocketCommonProtocol.wait_closed` + * If you aren't awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, + consider awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed` to detect quickly when the connection is closed. - * You may :meth:`~protocol.WebSocketCommonProtocol.ping` or - :meth:`~protocol.WebSocketCommonProtocol.pong` if you wish but it isn't + * You may :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` or + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't needed in general. -* Create a server with :func:`~server.serve` which is similar to asyncio's +* Create a server with :func:`~legacy.server.serve` which is similar to asyncio's :meth:`~asyncio.AbstractEventLoop.create_server`. You can also use it as an asynchronous context manager. @@ -35,30 +35,30 @@ Server handler exits normally or with an exception. * For advanced customization, you may subclass - :class:`~server.WebSocketServerProtocol` and pass either this subclass or + :class:`~legacy.server.WebSocketServerProtocol` and pass either this subclass or a factory function as the ``create_protocol`` argument. Client ------ -* Create a client with :func:`~client.connect` which is similar to asyncio's +* Create a client with :func:`~legacy.client.connect` which is similar to asyncio's :meth:`~asyncio.BaseEventLoop.create_connection`. You can also use it as an asynchronous context manager. * For advanced customization, you may subclass - :class:`~server.WebSocketClientProtocol` and pass either this subclass or + :class:`~legacy.server.WebSocketClientProtocol` and pass either this subclass or a factory function as the ``create_protocol`` argument. -* Call :meth:`~protocol.WebSocketCommonProtocol.recv` and - :meth:`~protocol.WebSocketCommonProtocol.send` to receive and send messages +* Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and + :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to receive and send messages at any time. -* You may :meth:`~protocol.WebSocketCommonProtocol.ping` or - :meth:`~protocol.WebSocketCommonProtocol.pong` if you wish but it isn't +* You may :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` or + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't needed in general. -* If you aren't using :func:`~client.connect` as a context manager, call - :meth:`~protocol.WebSocketCommonProtocol.close` to terminate the connection. +* If you aren't using :func:`~legacy.client.connect` as a context manager, call + :meth:`~legacy.protocol.WebSocketCommonProtocol.close` to terminate the connection. .. _debugging: diff --git a/docs/deployment.rst b/docs/deployment.rst index 5b05afff1..ed025094d 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -24,7 +24,7 @@ Graceful shutdown You may want to close connections gracefully when shutting down the server, perhaps after executing some cleanup logic. There are two ways to achieve this -with the object returned by :func:`~server.serve`: +with the object returned by :func:`~legacy.server.serve`: - using it as a asynchronous context manager, or - calling its ``close()`` method, then waiting for its ``wait_closed()`` @@ -132,7 +132,7 @@ Under high load, if a server receives more messages than it can process, bufferbloat can result in excessive memory use. By default ``websockets`` has generous limits. It is strongly recommended to -adapt them to your application. When you call :func:`~server.serve`: +adapt them to your application. When you call :func:`~legacy.server.serve`: - Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of messages your application generates. @@ -155,7 +155,7 @@ The author of ``websockets`` doesn't think that's a good idea, due to the widely different operational characteristics of HTTP and WebSocket. ``websockets`` provide minimal support for responding to HTTP requests with -the :meth:`~server.WebSocketServerProtocol.process_request` hook. Typical +the :meth:`~legacy.server.WebSocketServerProtocol.process_request` hook. Typical use cases include health checks. Here's an example: .. literalinclude:: ../example/health_check_server.py diff --git a/docs/design.rst b/docs/design.rst index 74279b87f..f2718370d 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -32,20 +32,20 @@ WebSocket connections go through a trivial state machine: Transitions happen in the following places: - ``CONNECTING -> OPEN``: in - :meth:`~protocol.WebSocketCommonProtocol.connection_open` which runs when + :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` which runs when the :ref:`opening handshake ` completes and the WebSocket connection is established — not to be confused with :meth:`~asyncio.Protocol.connection_made` which runs when the TCP connection is established; - ``OPEN -> CLOSING``: in - :meth:`~protocol.WebSocketCommonProtocol.write_frame` immediately before + :meth:`~legacy.protocol.WebSocketCommonProtocol.write_frame` immediately before sending a close frame; since receiving a close frame triggers sending a close frame, this does the right thing regardless of which side started the :ref:`closing handshake `; also in - :meth:`~protocol.WebSocketCommonProtocol.fail_connection` which duplicates + :meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` which duplicates a few lines of code from ``write_close_frame()`` and ``write_frame()``; - ``* -> CLOSED``: in - :meth:`~protocol.WebSocketCommonProtocol.connection_lost` which is always + :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_lost` which is always called exactly once when the TCP connection is closed. Coroutines @@ -58,36 +58,36 @@ connection lifecycle on the client side. :target: _images/lifecycle.svg The lifecycle is identical on the server side, except inversion of control -makes the equivalent of :meth:`~client.connect` implicit. +makes the equivalent of :meth:`~legacy.client.connect` implicit. Coroutines shown in green are called by the application. Multiple coroutines may interact with the WebSocket connection concurrently. Coroutines shown in gray manage the connection. When the opening handshake -succeeds, :meth:`~protocol.WebSocketCommonProtocol.connection_open` starts +succeeds, :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` starts two tasks: -- :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs - :meth:`~protocol.WebSocketCommonProtocol.transfer_data` which handles - incoming data and lets :meth:`~protocol.WebSocketCommonProtocol.recv` +- :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` runs + :meth:`~legacy.protocol.WebSocketCommonProtocol.transfer_data` which handles + incoming data and lets :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` consume it. It may be canceled to terminate the connection. It never exits with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data transfer ` below. -- :attr:`~protocol.WebSocketCommonProtocol.keepalive_ping_task` runs - :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping +- :attr:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task` runs + :meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping frames at regular intervals and ensures that corresponding Pong frames are received. It is canceled when the connection terminates. It never exits with an exception other than :exc:`~asyncio.CancelledError`. -- :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs - :meth:`~protocol.WebSocketCommonProtocol.close_connection` which waits for +- :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` runs + :meth:`~legacy.protocol.WebSocketCommonProtocol.close_connection` which waits for the data transfer to terminate, then takes care of closing the TCP connection. It must not be canceled. It never exits with an exception. See :ref:`connection termination ` below. -Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection` starts -the same :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` when +Besides, :meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` starts +the same :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` when the opening handshake fails, in order to close the TCP connection. Splitting the responsibilities between two tasks makes it easier to guarantee @@ -99,11 +99,11 @@ that ``websockets`` can terminate connections: regardless of whether the connection terminates normally or abnormally. -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` completes when no +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` completes when no more data will be received on the connection. Under normal circumstances, it exits after exchanging close frames. -:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` completes when +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` completes when the TCP connection is closed. @@ -113,7 +113,7 @@ Opening handshake ----------------- ``websockets`` performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~client.connect` executes it before +connection. On the client side, :meth:`~legacy.client.connect` executes it before returning the protocol to the caller. On the server side, it's executed before passing the protocol to the ``ws_handler`` coroutine handling the connection. @@ -122,26 +122,26 @@ request and the server replies with an HTTP Switching Protocols response — ``websockets`` aims at keeping the implementation of both sides consistent with one another. -On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: +On the client side, :meth:`~legacy.client.WebSocketClientProtocol.handshake`: - builds a HTTP request based on the ``uri`` and parameters passed to - :meth:`~client.connect`; + :meth:`~legacy.client.connect`; - writes the HTTP request to the network; - reads a HTTP response from the network; - checks the HTTP response, validates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - moves to the ``OPEN`` state. -On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: +On the server side, :meth:`~legacy.server.WebSocketServerProtocol.handshake`: - reads a HTTP request from the network; -- calls :meth:`~server.WebSocketServerProtocol.process_request` which may +- calls :meth:`~legacy.server.WebSocketServerProtocol.process_request` which may abort the WebSocket handshake and return a HTTP response instead; this hook only makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds a HTTP response based on the above and parameters passed to - :meth:`~server.serve`; + :meth:`~legacy.server.serve`; - writes the HTTP response to the network; - moves to the ``OPEN`` state; - returns the ``path`` part of the ``uri``. @@ -177,16 +177,16 @@ differences between a server and a client: These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented -in the same class, :class:`~protocol.WebSocketCommonProtocol`. +in the same class, :class:`~legacy.protocol.WebSocketCommonProtocol`. .. _data framing: https://tools.ietf.org/html/rfc6455#section-5 .. _sending and receiving data: https://tools.ietf.org/html/rfc6455#section-6 .. _closing the connection: https://tools.ietf.org/html/rfc6455#section-7 -The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which +The :attr:`~legacy.protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the -:attr:`~server.WebSocketServerProtocol` and -:attr:`~client.WebSocketClientProtocol` classes. +:attr:`~legacy.server.WebSocketServerProtocol` and +:attr:`~legacy.client.WebSocketClientProtocol` classes. Data flow ......... @@ -210,11 +210,11 @@ The left side of the diagram shows how ``websockets`` receives data. Incoming data is written to a :class:`~asyncio.StreamReader` in order to implement flow control and provide backpressure on the TCP connection. -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which is started +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`, which is started when the WebSocket connection is established, processes this data. When it receives data frames, it reassembles fragments and puts the resulting -messages in the :attr:`~protocol.WebSocketCommonProtocol.messages` queue. +messages in the :attr:`~legacy.protocol.WebSocketCommonProtocol.messages` queue. When it encounters a control frame: @@ -226,11 +226,11 @@ When it encounters a control frame: Running this process in a task guarantees that control frames are processed promptly. Without such a task, ``websockets`` would depend on the application to drive the connection by having exactly one coroutine awaiting -:meth:`~protocol.WebSocketCommonProtocol.recv` at any time. While this +:meth:`~legacy.protocol.WebSocketCommonProtocol.recv` at any time. While this happens naturally in many use cases, it cannot be relied upon. -Then :meth:`~protocol.WebSocketCommonProtocol.recv` fetches the next message -from the :attr:`~protocol.WebSocketCommonProtocol.messages` queue, with some +Then :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` fetches the next message +from the :attr:`~legacy.protocol.WebSocketCommonProtocol.messages` queue, with some complexity added for handling backpressure and termination correctly. Sending data @@ -238,18 +238,18 @@ Sending data The right side of the diagram shows how ``websockets`` sends data. -:meth:`~protocol.WebSocketCommonProtocol.send` writes one or several data +:meth:`~legacy.protocol.WebSocketCommonProtocol.send` writes one or several data frames containing the message. While sending a fragmented message, concurrent -calls to :meth:`~protocol.WebSocketCommonProtocol.send` are put on hold until +calls to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` are put on hold until all fragments are sent. This makes concurrent calls safe. -:meth:`~protocol.WebSocketCommonProtocol.ping` writes a ping frame and +:meth:`~legacy.protocol.WebSocketCommonProtocol.ping` writes a ping frame and yields a :class:`~asyncio.Future` which will be completed when a matching pong frame is received. -:meth:`~protocol.WebSocketCommonProtocol.pong` writes a pong frame. +:meth:`~legacy.protocol.WebSocketCommonProtocol.pong` writes a pong frame. -:meth:`~protocol.WebSocketCommonProtocol.close` writes a close frame and +:meth:`~legacy.protocol.WebSocketCommonProtocol.close` writes a close frame and waits for the TCP connection to terminate. Outgoing data is written to a :class:`~asyncio.StreamWriter` in order to @@ -261,17 +261,17 @@ Closing handshake ................. When the other side of the connection initiates the closing handshake, -:meth:`~protocol.WebSocketCommonProtocol.read_message` receives a close +:meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives a close frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a close frame, and returns ``None``, causing -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. When this side of the connection initiates the closing handshake with -:meth:`~protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` +:meth:`~legacy.protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` state and sends a close frame. When the other side sends a close frame, -:meth:`~protocol.WebSocketCommonProtocol.read_message` receives it in the +:meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives it in the ``CLOSING`` state and returns ``None``, also causing -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close timeout, ``websockets`` :ref:`fails the connection `. @@ -288,31 +288,31 @@ Then ``websockets`` terminates the TCP connection. Connection termination ---------------------- -:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which is +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which is started when the WebSocket connection is established, is responsible for eventually closing the TCP connection. -First :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` waits -for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, +First :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` waits +for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, which may happen as a result of: - a successful closing handshake: as explained above, this exits the infinite - loop in :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; + loop in :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`; - a timeout while waiting for the closing handshake to complete: this cancels - :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; + :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`; - a protocol error, including connection errors: depending on the exception, - :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the + :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the connection ` with a suitable code and exits. -:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` is separate -from :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to make it +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` is separate +from :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to make it easier to implement the timeout on the closing handshake. Canceling -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk -of canceling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk +of canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` and failing to close the TCP connection, thus leaking resources. -Then :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` cancels -:attr:`~protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no +Then :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` cancels +:attr:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no protocol compliance responsibilities. Terminating it to avoid leaking it is the only concern. @@ -334,11 +334,11 @@ If the opening handshake doesn't complete successfully, ``websockets`` fails the connection by closing the TCP connection. Once the opening handshake has completed, ``websockets`` fails the connection -by canceling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +by canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and sending a close frame if appropriate. -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking -:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which closes +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which closes the TCP connection. @@ -414,45 +414,45 @@ happen on the client side. On the server side, the opening handshake is managed by ``websockets`` and nothing results in a cancellation. Once the WebSocket connection is established, internal tasks -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and -:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` mustn't get accidentally canceled if a coroutine that awaits them is canceled. In other words, they must be shielded from cancellation. -:meth:`~protocol.WebSocketCommonProtocol.recv` waits for the next message in -the queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` +:meth:`~legacy.protocol.WebSocketCommonProtocol.recv` waits for the next message in +the queue or for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, whichever comes first. It relies on :func:`~asyncio.wait` for waiting on two futures in parallel. As a consequence, even though it's waiting on a :class:`~asyncio.Future` signaling the next message and on -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't propagate cancellation to them. -:meth:`~protocol.WebSocketCommonProtocol.ensure_open` is called by -:meth:`~protocol.WebSocketCommonProtocol.send`, -:meth:`~protocol.WebSocketCommonProtocol.ping`, and -:meth:`~protocol.WebSocketCommonProtocol.pong`. When the connection state is +:meth:`~legacy.protocol.WebSocketCommonProtocol.ensure_open` is called by +:meth:`~legacy.protocol.WebSocketCommonProtocol.send`, +:meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and +:meth:`~legacy.protocol.WebSocketCommonProtocol.pong`. When the connection state is ``CLOSING``, it waits for -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to prevent cancellation. -:meth:`~protocol.WebSocketCommonProtocol.close` waits for the data transfer +:meth:`~legacy.protocol.WebSocketCommonProtocol.close` waits for the data transfer task to terminate with :func:`~asyncio.wait_for`. If it's canceled or if the -timeout elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` +timeout elapses, :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` is canceled, which is correct at this point. -:meth:`~protocol.WebSocketCommonProtocol.close` then waits for -:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it +:meth:`~legacy.protocol.WebSocketCommonProtocol.close` then waits for +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` but shields it to prevent cancellation. -:meth:`~protocol.WebSocketCommonProtocol.close` and -:func:`~protocol.WebSocketCommonProtocol.fail_connection` are the only -places where :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` may +:meth:`~legacy.protocol.WebSocketCommonProtocol.close` and +:func:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` are the only +places where :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` may be canceled. -:attr:`~protocol.WebSocketCommonProtocol.close_connnection_task` starts by -waiting for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`. It +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connnection_task` starts by +waiting for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`. It catches :exc:`~asyncio.CancelledError` to prevent a cancellation of -:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` from propagating -to :attr:`~protocol.WebSocketCommonProtocol.close_connnection_task`. +:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` from propagating +to :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connnection_task`. .. _backpressure: @@ -519,47 +519,47 @@ For each connection, the receiving side contains these buffers: - OS buffers: tuning them is an advanced optimization. - :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64 KiB. You can set another limit by passing a ``read_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve`. + :func:`~legacy.client.connect()` or :func:`~legacy.server.serve`. - Incoming messages :class:`~collections.deque`: its size depends both on the size and the number of messages it contains. By default the maximum UTF-8 encoded size is 1 MiB and the maximum number is 32. In the worst case, after UTF-8 decoding, a single message could take up to 4 MiB of memory and the overall memory consumption could reach 128 MiB. You should adjust these limits by setting the ``max_size`` and ``max_queue`` keyword arguments of - :func:`~client.connect()` or :func:`~server.serve` according to your + :func:`~legacy.client.connect()` or :func:`~legacy.server.serve` according to your application's requirements. For each connection, the sending side contains these buffers: - :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64 KiB. You can set another limit by passing a ``write_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve`. + :func:`~legacy.client.connect()` or :func:`~legacy.server.serve`. - OS buffers: tuning them is an advanced optimization. Concurrency ----------- -Awaiting any combination of :meth:`~protocol.WebSocketCommonProtocol.recv`, -:meth:`~protocol.WebSocketCommonProtocol.send`, -:meth:`~protocol.WebSocketCommonProtocol.close` -:meth:`~protocol.WebSocketCommonProtocol.ping`, or -:meth:`~protocol.WebSocketCommonProtocol.pong` concurrently is safe, including +Awaiting any combination of :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, +:meth:`~legacy.protocol.WebSocketCommonProtocol.send`, +:meth:`~legacy.protocol.WebSocketCommonProtocol.close` +:meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, or +:meth:`~legacy.protocol.WebSocketCommonProtocol.pong` concurrently is safe, including multiple calls to the same method, with one exception and one limitation. * **Only one coroutine can receive messages at a time.** This constraint avoids non-deterministic behavior (and simplifies the implementation). If a - coroutine is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, + coroutine is awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, awaiting it again in another coroutine raises :exc:`RuntimeError`. * **Sending a fragmented message forces serialization.** Indeed, the WebSocket protocol doesn't support multiplexing messages. If a coroutine is awaiting - :meth:`~protocol.WebSocketCommonProtocol.send` to send a fragmented message, + :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to send a fragmented message, awaiting it again in another coroutine waits until the first call completes. This will be transparent in many cases. It may be a concern if the fragmented message is generated slowly by an asynchronous iterator. Receiving frames is independent from sending frames. This isolates -:meth:`~protocol.WebSocketCommonProtocol.recv`, which receives frames, from +:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, which receives frames, from the other methods, which send frames. While the connection is open, each frame is sent with a single write. Combined diff --git a/docs/extensions.rst b/docs/extensions.rst index 400034090..dea91219e 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -14,8 +14,9 @@ Per-Message Deflate, specified in :rfc:`7692`. Per-Message Deflate ------------------- -:func:`~server.serve()` and :func:`~client.connect` enable the Per-Message -Deflate extension by default. You can disable this with ``compression=None``. +:func:`~legacy.server.serve()` and :func:`~legacy.client.connect` enable the +Per-Message Deflate extension by default. You can disable this with +``compression=None``. You can also configure the Per-Message Deflate extension explicitly if you want to customize its parameters. diff --git a/docs/faq.rst b/docs/faq.rst index 4a083e2d0..eee14dda8 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -80,13 +80,13 @@ How do I get access HTTP headers, for example cookies? ...................................................... To access HTTP headers during the WebSocket handshake, you can override -:attr:`~server.WebSocketServerProtocol.process_request`:: +:attr:`~legacy.server.WebSocketServerProtocol.process_request`:: async def process_request(self, path, request_headers): cookies = request_header["Cookie"] Once the connection is established, they're available in -:attr:`~protocol.WebSocketServerProtocol.request_headers`:: +:attr:`~legacy.protocol.WebSocketServerProtocol.request_headers`:: async def handler(websocket, path): cookies = websocket.request_headers["Cookie"] @@ -94,7 +94,7 @@ Once the connection is established, they're available in How do I get the IP address of the client connecting to my server? .................................................................. -It's available in :attr:`~protocol.WebSocketCommonProtocol.remote_address`:: +It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: async def handler(websocket, path): remote_ip = websocket.remote_address[0] @@ -121,7 +121,7 @@ Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the -:attr:`~server.WebSocketServerProtocol.process_request` hook. +:attr:`~legacy.server.WebSocketServerProtocol.process_request` hook. If you need more, pick a HTTP server and run it separately. Client side diff --git a/docs/intro.rst b/docs/intro.rst index 8aaaeddca..c77139cab 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -143,7 +143,7 @@ For getting messages from a ``producer`` coroutine and sending them:: In this example, ``producer`` represents your business logic for generating messages to send on the WebSocket connection. -:meth:`~protocol.WebSocketCommonProtocol.send` raises a +:meth:`~legacy.protocol.WebSocketCommonProtocol.send` raises a :exc:`~exceptions.ConnectionClosed` exception when the client disconnects, which breaks out of the ``while True`` loop. diff --git a/setup.py b/setup.py index f35819247..85d899cb4 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ if sys.version_info[:3] < (3, 6, 1): raise Exception("websockets requires Python >= 3.6.1.") -packages = ['websockets', 'websockets/extensions'] +packages = ['websockets', 'websockets/legacy', 'websockets/extensions'] ext_modules = [ setuptools.Extension( diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index c4accaca1..0242e7942 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,11 +1,13 @@ # This relies on each of the submodules having an __all__ variable. -from .auth import * # noqa -from .client import * # noqa +from .client import * from .datastructures import * # noqa from .exceptions import * # noqa -from .protocol import * # noqa -from .server import * # noqa +from .legacy.auth import * # noqa +from .legacy.client import * # noqa +from .legacy.protocol import * # noqa +from .legacy.server import * # noqa +from .server import * from .typing import * # noqa from .uri import * # noqa from .version import version as __version__ # noqa diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index bce3e4bbb..d44e34e74 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -6,8 +6,8 @@ import threading from typing import Any, Set -from .client import connect from .exceptions import ConnectionClosed, format_close +from .legacy.client import connect if sys.platform == "win32": diff --git a/src/websockets/auth.py b/src/websockets/auth.py index c1b7a0b1a..c8839c401 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,165 +1,4 @@ -""" -:mod:`websockets.auth` provides HTTP Basic Authentication according to -:rfc:`7235` and :rfc:`7617`. - -""" - - -import functools -import http -from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast - -from .asyncio_server import HTTPResponse, WebSocketServerProtocol -from .datastructures import Headers -from .exceptions import InvalidHeader -from .headers import build_www_authenticate_basic, parse_authorization_basic +from .legacy.auth import BasicAuthWebSocketServerProtocol, basic_auth_protocol_factory __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] - -Credentials = Tuple[str, str] - - -def is_credentials(value: Any) -> bool: - try: - username, password = value - except (TypeError, ValueError): - return False - else: - return isinstance(username, str) and isinstance(password, str) - - -class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): - """ - WebSocket server protocol that enforces HTTP Basic Auth. - - """ - - def __init__( - self, - *args: Any, - realm: str, - check_credentials: Callable[[str, str], Awaitable[bool]], - **kwargs: Any, - ) -> None: - self.realm = realm - self.check_credentials = check_credentials - super().__init__(*args, **kwargs) - - async def process_request( - self, path: str, request_headers: Headers - ) -> Optional[HTTPResponse]: - """ - Check HTTP Basic Auth and return a HTTP 401 or 403 response if needed. - - If authentication succeeds, the username of the authenticated user is - stored in the ``username`` attribute. - - """ - try: - authorization = request_headers["Authorization"] - except KeyError: - return ( - http.HTTPStatus.UNAUTHORIZED, - [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], - b"Missing credentials\n", - ) - - try: - username, password = parse_authorization_basic(authorization) - except InvalidHeader: - return ( - http.HTTPStatus.UNAUTHORIZED, - [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], - b"Unsupported credentials\n", - ) - - if not await self.check_credentials(username, password): - return ( - http.HTTPStatus.UNAUTHORIZED, - [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], - b"Invalid credentials\n", - ) - - self.username = username - - return await super().process_request(path, request_headers) - - -def basic_auth_protocol_factory( - realm: str, - credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, - create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None, -) -> Callable[[Any], BasicAuthWebSocketServerProtocol]: - """ - Protocol factory that enforces HTTP Basic Auth. - - ``basic_auth_protocol_factory`` is designed to integrate with - :func:`~websockets.server.serve` like this:: - - websockets.serve( - ..., - create_protocol=websockets.basic_auth_protocol_factory( - realm="my dev server", - credentials=("hello", "iloveyou"), - ) - ) - - ``realm`` indicates the scope of protection. It should contain only ASCII - characters because the encoding of non-ASCII characters is undefined. - Refer to section 2.2 of :rfc:`7235` for details. - - ``credentials`` defines hard coded authorized credentials. It can be a - ``(username, password)`` pair or a list of such pairs. - - ``check_credentials`` defines a coroutine that checks whether credentials - are authorized. This coroutine receives ``username`` and ``password`` - arguments and returns a :class:`bool`. - - One of ``credentials`` or ``check_credentials`` must be provided but not - both. - - By default, ``basic_auth_protocol_factory`` creates a factory for building - :class:`BasicAuthWebSocketServerProtocol` instances. You can override this - with the ``create_protocol`` parameter. - - :param realm: scope of protection - :param credentials: hard coded credentials - :param check_credentials: coroutine that verifies credentials - :raises TypeError: if the credentials argument has the wrong type - - """ - if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") - - if credentials is not None: - if is_credentials(credentials): - - async def check_credentials(username: str, password: str) -> bool: - return (username, password) == credentials - - elif isinstance(credentials, Iterable): - credentials_list = list(credentials) - if all(is_credentials(item) for item in credentials_list): - credentials_dict = dict(credentials_list) - - async def check_credentials(username: str, password: str) -> bool: - return credentials_dict.get(username) == password - - else: - raise TypeError(f"invalid credentials argument: {credentials}") - - else: - raise TypeError(f"invalid credentials argument: {credentials}") - - if create_protocol is None: - # Not sure why mypy cannot figure this out. - create_protocol = cast( - Callable[[Any], BasicAuthWebSocketServerProtocol], - BasicAuthWebSocketServerProtocol, - ) - - return functools.partial( - create_protocol, realm=realm, check_credentials=check_credentials - ) diff --git a/src/websockets/client.py b/src/websockets/client.py index b7e407a45..8cababed5 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -2,7 +2,6 @@ import logging from typing import Generator, List, Optional, Sequence -from .asyncio_client import WebSocketClientProtocol, connect, unix_connect from .connection import CLIENT, CONNECTING, OPEN, Connection from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( @@ -25,6 +24,7 @@ ) from .http import USER_AGENT, build_host from .http11 import Request, Response +from .legacy.client import WebSocketClientProtocol, connect, unix_connect # noqa from .typing import ( ConnectionOption, ExtensionHeader, @@ -36,12 +36,7 @@ from .utils import accept_key, generate_key -__all__ = [ - "connect", - "unix_connect", - "ClientConnection", - "WebSocketClientProtocol", -] +__all__ = ["ClientConnection"] logger = logging.getLogger(__name__) @@ -64,7 +59,7 @@ def __init__( self.extra_headers = extra_headers self.key = generate_key() - def connect(self) -> Request: + def connect(self) -> Request: # noqa: F811 """ Create a WebSocket handshake request event to send to the server. diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index bdadae05e..e0860c743 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -301,7 +301,7 @@ class AbortHandshake(InvalidHandshake): This exception is an implementation detail. - The public API is :meth:`~server.WebSocketServerProtocol.process_request`. + The public API is :meth:`~legacy.server.WebSocketServerProtocol.process_request`. """ diff --git a/src/websockets/framing.py b/src/websockets/framing.py index b2996d788..2dadb5610 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -1,139 +1,6 @@ -""" -:mod:`websockets.framing` reads and writes WebSocket frames. - -It deals with a single frame at a time. Anything that depends on the sequence -of frames is implemented in :mod:`websockets.protocol`. - -See `section 5 of RFC 6455`_. - -.. _section 5 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-5 - -""" - -import struct import warnings -from typing import Any, Awaitable, Callable, Optional, Sequence - -from .exceptions import PayloadTooBig, ProtocolError -from .frames import Frame as NewFrame, Opcode - -try: - from .speedups import apply_mask -except ImportError: # pragma: no cover - from .utils import apply_mask +from .legacy.framing import * # noqa warnings.warn("websockets.framing is deprecated", DeprecationWarning) - - -class Frame(NewFrame): - @classmethod - async def read( - cls, - reader: Callable[[int], Awaitable[bytes]], - *, - mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, - ) -> "Frame": - """ - Read a WebSocket frame. - - :param reader: coroutine that reads exactly the requested number of - bytes, unless the end of file is reached - :param mask: whether the frame should be masked i.e. whether the read - happens on the server side - :param max_size: maximum payload size in bytes - :param extensions: list of classes with a ``decode()`` method that - transforms the frame and return a new frame; extensions are applied - in reverse order - :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds - ``max_size`` - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values - - """ - - # Read the header. - data = await reader(2) - head1, head2 = struct.unpack("!BB", data) - - # While not Pythonic, this is marginally faster than calling bool(). - fin = True if head1 & 0b10000000 else False - rsv1 = True if head1 & 0b01000000 else False - rsv2 = True if head1 & 0b00100000 else False - rsv3 = True if head1 & 0b00010000 else False - - try: - opcode = Opcode(head1 & 0b00001111) - except ValueError as exc: - raise ProtocolError("invalid opcode") from exc - - if (True if head2 & 0b10000000 else False) != mask: - raise ProtocolError("incorrect masking") - - length = head2 & 0b01111111 - if length == 126: - data = await reader(2) - (length,) = struct.unpack("!H", data) - elif length == 127: - data = await reader(8) - (length,) = struct.unpack("!Q", data) - if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") - if mask: - mask_bits = await reader(4) - - # Read the data. - data = await reader(length) - if mask: - data = apply_mask(data, mask_bits) - - frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) - - if extensions is None: - extensions = [] - for extension in reversed(extensions): - frame = cls(*extension.decode(frame, max_size=max_size)) - - frame.check() - - return frame - - def write( - self, - write: Callable[[bytes], Any], - *, - mask: bool, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, - ) -> None: - """ - Write a WebSocket frame. - - :param frame: frame to write - :param write: function that writes bytes - :param mask: whether the frame should be masked i.e. whether the write - happens on the client side - :param extensions: list of classes with an ``encode()`` method that - transform the frame and return a new frame; extensions are applied - in order - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values - - """ - # The frame is written in a single call to write in order to prevent - # TCP fragmentation. See #68 for details. This also makes it safe to - # send frames concurrently from multiple coroutines. - write(self.serialize(mask=mask, extensions=extensions)) - - -# Backwards compatibility with previously documented public APIs -from .frames import parse_close # isort:skip # noqa -from .frames import prepare_ctrl as encode_data # isort:skip # noqa -from .frames import prepare_data # isort:skip # noqa -from .frames import serialize_close # isort:skip # noqa - - -# at the bottom to allow circular import, because Extension depends on Frame -import websockets.extensions.base # isort:skip # noqa diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index 3ff6c005d..cc4010d41 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -13,7 +13,7 @@ def build_request(headers: Headers) -> str: # pragma: no cover warnings.warn( "websockets.handshake.build_request is deprecated", DeprecationWarning ) - from .handshake_legacy import build_request + from .legacy.handshake import build_request return build_request(headers) @@ -22,7 +22,7 @@ def check_request(headers: Headers) -> str: # pragma: no cover warnings.warn( "websockets.handshake.check_request is deprecated", DeprecationWarning ) - from .handshake_legacy import check_request + from .legacy.handshake import check_request return check_request(headers) @@ -31,7 +31,7 @@ def build_response(headers: Headers, key: str) -> None: # pragma: no cover warnings.warn( "websockets.handshake.build_response is deprecated", DeprecationWarning ) - from .handshake_legacy import build_response + from .legacy.handshake import build_response return build_response(headers, key) @@ -40,6 +40,6 @@ def check_response(headers: Headers, key: str) -> None: # pragma: no cover warnings.warn( "websockets.handshake.check_response is deprecated", DeprecationWarning ) - from .handshake_legacy import check_response + from .legacy.handshake import check_response return check_response(headers, key) diff --git a/src/websockets/http.py b/src/websockets/http.py index ed3fe48d0..b05b78455 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -47,7 +47,7 @@ async def read_request( stream: asyncio.StreamReader, ) -> Tuple[str, Headers]: # pragma: no cover warnings.warn("websockets.http.read_request is deprecated", DeprecationWarning) - from .http_legacy import read_request + from .legacy.http import read_request return await read_request(stream) @@ -56,6 +56,6 @@ async def read_response( stream: asyncio.StreamReader, ) -> Tuple[int, str, Headers]: # pragma: no cover warnings.warn("websockets.http.read_response is deprecated", DeprecationWarning) - from .http_legacy import read_response + from .legacy.http import read_response return await read_response(stream) diff --git a/src/websockets/legacy/__init__.py b/src/websockets/legacy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py new file mode 100644 index 000000000..8cb60429a --- /dev/null +++ b/src/websockets/legacy/auth.py @@ -0,0 +1,165 @@ +""" +:mod:`websockets.legacy.auth` provides HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +""" + + +import functools +import http +from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast + +from ..datastructures import Headers +from ..exceptions import InvalidHeader +from ..headers import build_www_authenticate_basic, parse_authorization_basic +from .server import HTTPResponse, WebSocketServerProtocol + + +__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] + +Credentials = Tuple[str, str] + + +def is_credentials(value: Any) -> bool: + try: + username, password = value + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): + """ + WebSocket server protocol that enforces HTTP Basic Auth. + + """ + + def __init__( + self, + *args: Any, + realm: str, + check_credentials: Callable[[str, str], Awaitable[bool]], + **kwargs: Any, + ) -> None: + self.realm = realm + self.check_credentials = check_credentials + super().__init__(*args, **kwargs) + + async def process_request( + self, path: str, request_headers: Headers + ) -> Optional[HTTPResponse]: + """ + Check HTTP Basic Auth and return a HTTP 401 or 403 response if needed. + + If authentication succeeds, the username of the authenticated user is + stored in the ``username`` attribute. + + """ + try: + authorization = request_headers["Authorization"] + except KeyError: + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Missing credentials\n", + ) + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Unsupported credentials\n", + ) + + if not await self.check_credentials(username, password): + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Invalid credentials\n", + ) + + self.username = username + + return await super().process_request(path, request_headers) + + +def basic_auth_protocol_factory( + realm: str, + credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, + check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, + create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None, +) -> Callable[[Any], BasicAuthWebSocketServerProtocol]: + """ + Protocol factory that enforces HTTP Basic Auth. + + ``basic_auth_protocol_factory`` is designed to integrate with + :func:`~websockets.legacy.server.serve` like this:: + + websockets.serve( + ..., + create_protocol=websockets.basic_auth_protocol_factory( + realm="my dev server", + credentials=("hello", "iloveyou"), + ) + ) + + ``realm`` indicates the scope of protection. It should contain only ASCII + characters because the encoding of non-ASCII characters is undefined. + Refer to section 2.2 of :rfc:`7235` for details. + + ``credentials`` defines hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + + ``check_credentials`` defines a coroutine that checks whether credentials + are authorized. This coroutine receives ``username`` and ``password`` + arguments and returns a :class:`bool`. + + One of ``credentials`` or ``check_credentials`` must be provided but not + both. + + By default, ``basic_auth_protocol_factory`` creates a factory for building + :class:`BasicAuthWebSocketServerProtocol` instances. You can override this + with the ``create_protocol`` parameter. + + :param realm: scope of protection + :param credentials: hard coded credentials + :param check_credentials: coroutine that verifies credentials + :raises TypeError: if the credentials argument has the wrong type + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + + async def check_credentials(username: str, password: str) -> bool: + return (username, password) == credentials + + elif isinstance(credentials, Iterable): + credentials_list = list(credentials) + if all(is_credentials(item) for item in credentials_list): + credentials_dict = dict(credentials_list) + + async def check_credentials(username: str, password: str) -> bool: + return credentials_dict.get(username) == password + + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + if create_protocol is None: + # Not sure why mypy cannot figure this out. + create_protocol = cast( + Callable[[Any], BasicAuthWebSocketServerProtocol], + BasicAuthWebSocketServerProtocol, + ) + + return functools.partial( + create_protocol, realm=realm, check_credentials=check_credentials + ) diff --git a/src/websockets/asyncio_client.py b/src/websockets/legacy/client.py similarity index 97% rename from src/websockets/asyncio_client.py rename to src/websockets/legacy/client.py index 3f406170a..27f6e8209 100644 --- a/src/websockets/asyncio_client.py +++ b/src/websockets/legacy/client.py @@ -1,5 +1,5 @@ """ -:mod:`websockets.client` defines the WebSocket client APIs. +:mod:`websockets.legacy.client` defines the WebSocket client APIs. """ @@ -11,8 +11,8 @@ from types import TracebackType from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, Type, cast -from .datastructures import Headers, HeadersLike -from .exceptions import ( +from ..datastructures import Headers, HeadersLike +from ..exceptions import ( InvalidHandshake, InvalidHeader, InvalidMessage, @@ -21,21 +21,21 @@ RedirectHandshake, SecurityError, ) -from .extensions.base import ClientExtensionFactory, Extension -from .extensions.permessage_deflate import enable_client_permessage_deflate -from .handshake_legacy import build_request, check_response -from .headers import ( +from ..extensions.base import ClientExtensionFactory, Extension +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import ( build_authorization_basic, build_extension, build_subprotocol, parse_extension, parse_subprotocol, ) -from .http import USER_AGENT, build_host -from .http_legacy import read_response +from ..http import USER_AGENT, build_host +from ..typing import ExtensionHeader, Origin, Subprotocol +from ..uri import WebSocketURI, parse_uri +from .handshake import build_request, check_response +from .http import read_response from .protocol import WebSocketCommonProtocol -from .typing import ExtensionHeader, Origin, Subprotocol -from .uri import WebSocketURI, parse_uri __all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py new file mode 100644 index 000000000..e41c295dd --- /dev/null +++ b/src/websockets/legacy/framing.py @@ -0,0 +1,135 @@ +""" +:mod:`websockets.legacy.framing` reads and writes WebSocket frames. + +It deals with a single frame at a time. Anything that depends on the sequence +of frames is implemented in :mod:`websockets.legacy.protocol`. + +See `section 5 of RFC 6455`_. + +.. _section 5 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-5 + +""" + +import struct +from typing import Any, Awaitable, Callable, Optional, Sequence + +from ..exceptions import PayloadTooBig, ProtocolError +from ..frames import Frame as NewFrame, Opcode + + +try: + from ..speedups import apply_mask +except ImportError: # pragma: no cover + from ..utils import apply_mask + + +class Frame(NewFrame): + @classmethod + async def read( + cls, + reader: Callable[[int], Awaitable[bytes]], + *, + mask: bool, + max_size: Optional[int] = None, + extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + ) -> "Frame": + """ + Read a WebSocket frame. + + :param reader: coroutine that reads exactly the requested number of + bytes, unless the end of file is reached + :param mask: whether the frame should be masked i.e. whether the read + happens on the server side + :param max_size: maximum payload size in bytes + :param extensions: list of classes with a ``decode()`` method that + transforms the frame and return a new frame; extensions are applied + in reverse order + :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds + ``max_size`` + :raises ~websockets.exceptions.ProtocolError: if the frame + contains incorrect values + + """ + + # Read the header. + data = await reader(2) + head1, head2 = struct.unpack("!BB", data) + + # While not Pythonic, this is marginally faster than calling bool(). + fin = True if head1 & 0b10000000 else False + rsv1 = True if head1 & 0b01000000 else False + rsv2 = True if head1 & 0b00100000 else False + rsv3 = True if head1 & 0b00010000 else False + + try: + opcode = Opcode(head1 & 0b00001111) + except ValueError as exc: + raise ProtocolError("invalid opcode") from exc + + if (True if head2 & 0b10000000 else False) != mask: + raise ProtocolError("incorrect masking") + + length = head2 & 0b01111111 + if length == 126: + data = await reader(2) + (length,) = struct.unpack("!H", data) + elif length == 127: + data = await reader(8) + (length,) = struct.unpack("!Q", data) + if max_size is not None and length > max_size: + raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + if mask: + mask_bits = await reader(4) + + # Read the data. + data = await reader(length) + if mask: + data = apply_mask(data, mask_bits) + + frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) + + if extensions is None: + extensions = [] + for extension in reversed(extensions): + frame = cls(*extension.decode(frame, max_size=max_size)) + + frame.check() + + return frame + + def write( + self, + write: Callable[[bytes], Any], + *, + mask: bool, + extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + ) -> None: + """ + Write a WebSocket frame. + + :param frame: frame to write + :param write: function that writes bytes + :param mask: whether the frame should be masked i.e. whether the write + happens on the client side + :param extensions: list of classes with an ``encode()`` method that + transform the frame and return a new frame; extensions are applied + in order + :raises ~websockets.exceptions.ProtocolError: if the frame + contains incorrect values + + """ + # The frame is written in a single call to write in order to prevent + # TCP fragmentation. See #68 for details. This also makes it safe to + # send frames concurrently from multiple coroutines. + write(self.serialize(mask=mask, extensions=extensions)) + + +# Backwards compatibility with previously documented public APIs +from ..frames import parse_close # isort:skip # noqa +from ..frames import prepare_ctrl as encode_data # isort:skip # noqa +from ..frames import prepare_data # isort:skip # noqa +from ..frames import serialize_close # isort:skip # noqa + + +# at the bottom to allow circular import, because Extension depends on Frame +import websockets.extensions.base # isort:skip # noqa diff --git a/src/websockets/handshake_legacy.py b/src/websockets/legacy/handshake.py similarity index 93% rename from src/websockets/handshake_legacy.py rename to src/websockets/legacy/handshake.py index d34ca5f7f..44da72d21 100644 --- a/src/websockets/handshake_legacy.py +++ b/src/websockets/legacy/handshake.py @@ -1,5 +1,5 @@ """ -:mod:`websockets.handshake` provides helpers for the WebSocket handshake. +:mod:`websockets.legacy.handshake` provides helpers for the WebSocket handshake. See `section 4 of RFC 6455`_. @@ -29,11 +29,11 @@ import binascii from typing import List -from .datastructures import Headers, MultipleValuesError -from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade -from .headers import parse_connection, parse_upgrade -from .typing import ConnectionOption, UpgradeProtocol -from .utils import accept_key as accept, generate_key +from ..datastructures import Headers, MultipleValuesError +from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade +from ..headers import parse_connection, parse_upgrade +from ..typing import ConnectionOption, UpgradeProtocol +from ..utils import accept_key as accept, generate_key __all__ = ["build_request", "check_request", "build_response", "check_response"] diff --git a/src/websockets/http_legacy.py b/src/websockets/legacy/http.py similarity index 98% rename from src/websockets/http_legacy.py rename to src/websockets/legacy/http.py index 5afe5f898..c18e08e8d 100644 --- a/src/websockets/http_legacy.py +++ b/src/websockets/legacy/http.py @@ -2,8 +2,8 @@ import re from typing import Tuple -from .datastructures import Headers -from .exceptions import SecurityError +from ..datastructures import Headers +from ..exceptions import SecurityError __all__ = ["read_request", "read_response"] diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py new file mode 100644 index 000000000..e4592b8a0 --- /dev/null +++ b/src/websockets/legacy/protocol.py @@ -0,0 +1,1459 @@ +""" +:mod:`websockets.legacy.protocol` handles WebSocket control and data frames. + +See `sections 4 to 8 of RFC 6455`_. + +.. _sections 4 to 8 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 + +""" + +import asyncio +import codecs +import collections +import enum +import logging +import random +import struct +import sys +import warnings +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Deque, + Dict, + Iterable, + List, + Mapping, + Optional, + Union, + cast, +) + +from ..datastructures import Headers +from ..exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from ..extensions.base import Extension +from ..frames import ( + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Opcode, + parse_close, + prepare_ctrl, + prepare_data, + serialize_close, +) +from ..typing import Data, Subprotocol +from .framing import Frame + + +__all__ = ["WebSocketCommonProtocol"] + +logger = logging.getLogger(__name__) + + +# A WebSocket connection goes through the following four states, in order: + + +class State(enum.IntEnum): + CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + +# In order to ensure consistency, the code always checks the current value of +# WebSocketCommonProtocol.state before assigning a new value and never yields +# between the check and the assignment. + + +class WebSocketCommonProtocol(asyncio.Protocol): + """ + :class:`~asyncio.Protocol` subclass implementing the data transfer phase. + + Once the WebSocket connection is established, during the data transfer + phase, the protocol is almost symmetrical between the server side and the + client side. :class:`WebSocketCommonProtocol` implements logic that's + shared between servers and clients.. + + Subclasses such as + :class:`~websockets.legacy.server.WebSocketServerProtocol` and + :class:`~websockets.legacy.client.WebSocketClientProtocol` implement the + opening handshake, which is different between servers and clients. + + :class:`WebSocketCommonProtocol` performs four functions: + + * It runs a task that stores incoming data frames in a queue and makes + them available with the :meth:`recv` coroutine. + * It sends outgoing data frames with the :meth:`send` coroutine. + * It deals with control frames automatically. + * It performs the closing handshake. + + :class:`WebSocketCommonProtocol` supports asynchronous iteration:: + + async for message in websocket: + await process(message) + + The iterator yields incoming messages. It exits normally when the + connection is closed with the close code 1000 (OK) or 1001 (going away). + It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception + when the connection is closed with any other code. + + Once the connection is open, a `Ping frame`_ is sent every + ``ping_interval`` seconds. This serves as a keepalive. It helps keeping + the connection open, especially in the presence of proxies with short + timeouts on inactive connections. Set ``ping_interval`` to ``None`` to + disable this behavior. + + .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 + + If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` + seconds, the connection is considered unusable and is closed with + code 1011. This ensures that the remote endpoint remains responsive. Set + ``ping_timeout`` to ``None`` to disable this behavior. + + .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 + + The ``close_timeout`` parameter defines a maximum wait time in seconds for + completing the closing handshake and terminating the TCP connection. + :meth:`close` completes in at most ``4 * close_timeout`` on the server + side and ``5 * close_timeout`` on the client side. + + ``close_timeout`` needs to be a parameter of the protocol because + ``websockets`` usually calls :meth:`close` implicitly: + + - on the server side, when the connection handler terminates, + - on the client side, when exiting the context manager for the connection. + + To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. + + The ``max_size`` parameter enforces the maximum size for incoming messages + in bytes. The default value is 1 MiB. ``None`` disables the limit. If a + message larger than the maximum size is received, :meth:`recv` will + raise :exc:`~websockets.exceptions.ConnectionClosedError` and the + connection will be closed with code 1009. + + The ``max_queue`` parameter sets the maximum length of the queue that + holds incoming messages. The default value is ``32``. ``None`` disables + the limit. Messages are added to an in-memory queue when they're received; + then :meth:`recv` pops from that queue. In order to prevent excessive + memory consumption when messages are received faster than they can be + processed, the queue must be bounded. If the queue fills up, the protocol + stops processing incoming data until :meth:`recv` is called. In this + situation, various receive buffers (at least in ``asyncio`` and in the OS) + will fill up, then the TCP receive window will shrink, slowing down + transmission to avoid packet loss. + + Since Python can use up to 4 bytes of memory to represent a single + character, each connection may use up to ``4 * max_size * max_queue`` + bytes of memory to store incoming messages. By default, this is 128 MiB. + You may want to lower the limits, depending on your application's + requirements. + + The ``read_limit`` argument sets the high-water limit of the buffer for + incoming bytes. The low-water limit is half the high-water limit. The + default value is 64 KiB, half of asyncio's default (based on the current + implementation of :class:`~asyncio.StreamReader`). + + The ``write_limit`` argument sets the high-water limit of the buffer for + outgoing bytes. The low-water limit is a quarter of the high-water limit. + The default value is 64 KiB, equal to asyncio's default (based on the + current implementation of ``FlowControlMixin``). + + As soon as the HTTP request and response in the opening handshake are + processed: + + * the request path is available in the :attr:`path` attribute; + * the request and response HTTP headers are available in the + :attr:`request_headers` and :attr:`response_headers` attributes, + which are :class:`~websockets.http.Headers` instances. + + If a subprotocol was negotiated, it's available in the :attr:`subprotocol` + attribute. + + Once the connection is closed, the code is available in the + :attr:`close_code` attribute and the reason in :attr:`close_reason`. + + All these attributes must be treated as read-only. + + """ + + # There are only two differences between the client-side and server-side + # behavior: masking the payload and closing the underlying TCP connection. + # Set is_client = True/False and side = "client"/"server" to pick a side. + is_client: bool + side: str = "undefined" + + def __init__( + self, + *, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, + loop: Optional[asyncio.AbstractEventLoop] = None, + # The following arguments are kept only for backwards compatibility. + host: Optional[str] = None, + port: Optional[int] = None, + secure: Optional[bool] = None, + legacy_recv: bool = False, + timeout: Optional[float] = None, + ) -> None: + # Backwards compatibility: close_timeout used to be called timeout. + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_size = max_size + self.max_queue = max_queue + self.read_limit = read_limit + self.write_limit = write_limit + + if loop is None: + loop = asyncio.get_event_loop() + self.loop = loop + + self._host = host + self._port = port + self._secure = secure + self.legacy_recv = legacy_recv + + # Configure read buffer limits. The high-water limit is defined by + # ``self.read_limit``. The ``limit`` argument controls the line length + # limit and half the buffer limit of :class:`~asyncio.StreamReader`. + # That's why it must be set to half of ``self.read_limit``. + self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) + + # Copied from asyncio.FlowControlMixin + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None + + self._drain_lock = asyncio.Lock( + loop=loop if sys.version_info[:2] < (3, 8) else None + ) + + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. + # Subclasses implement the opening handshake and, on success, execute + # :meth:`connection_open` to change the state to OPEN. + self.state = State.CONNECTING + logger.debug("%s - state = CONNECTING", self.side) + + # HTTP protocol parameters. + self.path: str + self.request_headers: Headers + self.response_headers: Headers + + # WebSocket protocol parameters. + self.extensions: List[Extension] = [] + self.subprotocol: Optional[Subprotocol] = None + + # The close code and reason are set when receiving a close frame or + # losing the TCP connection. + self.close_code: int + self.close_reason: str + + # Completed when the connection state becomes CLOSED. Translates the + # :meth:`connection_lost` callback to a :class:`~asyncio.Future` + # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are + # translated by ``self.stream_reader``). + self.connection_lost_waiter: asyncio.Future[None] = loop.create_future() + + # Queue of received messages. + self.messages: Deque[Data] = collections.deque() + self._pop_message_waiter: Optional[asyncio.Future[None]] = None + self._put_message_waiter: Optional[asyncio.Future[None]] = None + + # Protect sending fragmented messages. + self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pings: Dict[bytes, asyncio.Future[None]] = {} + + # Task running the data transfer. + self.transfer_data_task: asyncio.Task[None] + + # Exception that occurred during data transfer, if any. + self.transfer_data_exc: Optional[BaseException] = None + + # Task sending keepalive pings. + self.keepalive_ping_task: asyncio.Task[None] + + # Task closing the TCP connection. + self.close_connection_task: asyncio.Task[None] + + # Copied from asyncio.FlowControlMixin + async def _drain_helper(self) -> None: # pragma: no cover + if self.connection_lost_waiter.done(): + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self.loop.create_future() + self._drain_waiter = waiter + await waiter + + # Copied from asyncio.StreamWriter + async def _drain(self) -> None: # pragma: no cover + if self.reader is not None: + exc = self.reader.exception() + if exc is not None: + raise exc + if self.transport is not None: + if self.transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); yield from drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await asyncio.sleep( + 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) + await self._drain_helper() + + def connection_open(self) -> None: + """ + Callback when the WebSocket opening handshake completes. + + Enter the OPEN state and start the data transfer phase. + + """ + # 4.1. The WebSocket Connection is Established. + assert self.state is State.CONNECTING + self.state = State.OPEN + logger.debug("%s - state = OPEN", self.side) + # Start the task that receives incoming WebSocket messages. + self.transfer_data_task = self.loop.create_task(self.transfer_data()) + # Start the task that sends pings at regular intervals. + self.keepalive_ping_task = self.loop.create_task(self.keepalive_ping()) + # Start the task that eventually closes the TCP connection. + self.close_connection_task = self.loop.create_task(self.close_connection()) + + @property + def host(self) -> Optional[str]: + alternative = "remote_address" if self.is_client else "local_address" + warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) + return self._host + + @property + def port(self) -> Optional[int]: + alternative = "remote_address" if self.is_client else "local_address" + warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) + return self._port + + @property + def secure(self) -> Optional[bool]: + warnings.warn("don't use secure", DeprecationWarning) + return self._secure + + # Public API + + @property + def local_address(self) -> Any: + """ + Local address of the connection as a ``(host, port)`` tuple. + + When the connection isn't open, ``local_address`` is ``None``. + + """ + try: + transport = self.transport + except AttributeError: + return None + else: + return transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection as a ``(host, port)`` tuple. + + When the connection isn't open, ``remote_address`` is ``None``. + + """ + try: + transport = self.transport + except AttributeError: + return None + else: + return transport.get_extra_info("peername") + + @property + def open(self) -> bool: + """ + ``True`` when the connection is usable. + + It may be used to detect disconnections. However, this approach is + discouraged per the EAFP_ principle. + + When ``open`` is ``False``, using the connection raises a + :exc:`~websockets.exceptions.ConnectionClosed` exception. + + .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp + + """ + return self.state is State.OPEN and not self.transfer_data_task.done() + + @property + def closed(self) -> bool: + """ + ``True`` once the connection is closed. + + Be aware that both :attr:`open` and :attr:`closed` are ``False`` during + the opening and closing sequences. + + """ + return self.state is State.CLOSED + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + This is identical to :attr:`closed`, except it can be awaited. + + This can make it easier to handle connection termination, regardless + of its cause, in tasks that interact with the WebSocket connection. + + """ + await asyncio.shield(self.connection_lost_waiter) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on received messages. + + Exit normally when the connection is closed with code 1000 or 1001. + + Raise an exception in other cases. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + async def recv(self) -> Data: + """ + Receive the next message. + + Return a :class:`str` for a text frame and :class:`bytes` for a binary + frame. + + When the end of the message stream is reached, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + .. versionchanged:: 3.0 + + :meth:`recv` used to return ``None`` instead. Refer to the + changelog for details. + + Canceling :meth:`recv` is safe. There's no risk of losing the next + message. The next invocation of :meth:`recv` will return it. This + makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.wait_for`. + + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises RuntimeError: if two coroutines call :meth:`recv` concurrently + + """ + if self._pop_message_waiter is not None: + raise RuntimeError( + "cannot call recv while another coroutine " + "is already waiting for the next message" + ) + + # Don't await self.ensure_open() here: + # - messages could be available in the queue even if the connection + # is closed; + # - messages could be received before the closing frame even if the + # connection is closing. + + # Wait until there's a message in the queue (if necessary) or the + # connection is closed. + while len(self.messages) <= 0: + pop_message_waiter: asyncio.Future[None] = self.loop.create_future() + self._pop_message_waiter = pop_message_waiter + try: + # If asyncio.wait() is canceled, it doesn't cancel + # pop_message_waiter and self.transfer_data_task. + await asyncio.wait( + [pop_message_waiter, self.transfer_data_task], + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + return_when=asyncio.FIRST_COMPLETED, + ) + finally: + self._pop_message_waiter = None + + # If asyncio.wait(...) exited because self.transfer_data_task + # completed before receiving a new message, raise a suitable + # exception (or return None if legacy_recv is enabled). + if not pop_message_waiter.done(): + if self.legacy_recv: + return None # type: ignore + else: + # Wait until the connection is closed to raise + # ConnectionClosed with the correct code and reason. + await self.ensure_open() + + # Pop a message from the queue. + message = self.messages.popleft() + + # Notify transfer_data(). + if self._put_message_waiter is not None: + self._put_message_waiter.set_result(None) + self._put_message_waiter = None + + return message + + async def send( + self, message: Union[Data, Iterable[Data], AsyncIterable[Data]] + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. + + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects. In that case the message + is fragmented. Each item is treated as a message fragment and sent in + its own frame. All items must be of the same type, or else + :meth:`send` will raise a :exc:`TypeError` and the connection will be + closed. + + :meth:`send` rejects dict-like objects because this is often an error. + If you wish to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`. + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there only two situations where + :meth:`send` yields control to the event loop: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator. Stopping in the middle of + a fragmented message will cause a protocol error. Closing the + connection has the same effect. + + :raises TypeError: for unsupported inputs + + """ + await self.ensure_open() + + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self._fragmented_message_waiter is not None: + await asyncio.shield(self._fragmented_message_waiter) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, (str, bytes, bytearray, memoryview)): + opcode, data = prepare_data(message) + await self.write_frame(True, opcode, data) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + + # Work around https://github.com/python/mypy/issues/6227 + message = cast(Iterable[Data], message) + + iter_message = iter(message) + try: + message_chunk = next(iter_message) + except StopIteration: + return + opcode, data = prepare_data(message_chunk) + + self._fragmented_message_waiter = asyncio.Future() + try: + # First fragment. + await self.write_frame(False, opcode, data) + + # Other fragments. + for message_chunk in iter_message: + confirm_opcode, data = prepare_data(message_chunk) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) + + # Final fragment. + await self.write_frame(True, OP_CONT, b"") + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(1011) + raise + + finally: + self._fragmented_message_waiter.set_result(None) + self._fragmented_message_waiter = None + + # Fragmented message -- asynchronous iterator + + elif isinstance(message, AsyncIterable): + # aiter_message = aiter(message) without aiter + # https://github.com/python/mypy/issues/5738 + aiter_message = type(message).__aiter__(message) # type: ignore + try: + # message_chunk = anext(aiter_message) without anext + # https://github.com/python/mypy/issues/5738 + message_chunk = await type(aiter_message).__anext__( # type: ignore + aiter_message + ) + except StopAsyncIteration: + return + opcode, data = prepare_data(message_chunk) + + self._fragmented_message_waiter = asyncio.Future() + try: + # First fragment. + await self.write_frame(False, opcode, data) + + # Other fragments. + # https://github.com/python/mypy/issues/5738 + async for message_chunk in aiter_message: # type: ignore + confirm_opcode, data = prepare_data(message_chunk) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) + + # Final fragment. + await self.write_frame(True, OP_CONT, b"") + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(1011) + raise + + finally: + self._fragmented_message_waiter.set_result(None) + self._fragmented_message_waiter = None + + else: + raise TypeError("data must be bytes, str, or iterable") + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. As a consequence, there's no need + to await :meth:`wait_closed`; :meth:`close` already does it. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Wrapping :func:`close` in :func:`~asyncio.create_task` is safe, given + that errors during connection termination aren't particularly useful. + + Canceling :meth:`close` is discouraged. If it takes too long, you can + set a shorter ``close_timeout``. If you don't want to wait, let the + Python process exit, then the OS will close the TCP connection. + + :param code: WebSocket close code + :param reason: WebSocket close reason + + """ + try: + await asyncio.wait_for( + self.write_close_frame(serialize_close(code, reason)), + self.close_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + except asyncio.TimeoutError: + # If the close frame cannot be sent because the send buffers + # are full, the closing handshake won't complete anyway. + # Fail the connection to shut down faster. + self.fail_connection() + + # If no close frame is received within the timeout, wait_for() cancels + # the data transfer task and raises TimeoutError. + + # If close() is called multiple times concurrently and one of these + # calls hits the timeout, the data transfer task will be cancelled. + # Other calls will receive a CancelledError here. + + try: + # If close() is canceled during the wait, self.transfer_data_task + # is canceled before the timeout elapses. + await asyncio.wait_for( + self.transfer_data_task, + self.close_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + # Wait for the close connection task to close the TCP connection. + await asyncio.shield(self.close_connection_task) + + async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + """ + Send a ping. + + Return a :class:`~asyncio.Future` that will be completed when the + corresponding pong is received. You can ignore it if you don't intend + to wait. + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point:: + + pong_waiter = await ws.ping() + await pong_waiter # only if you want to wait for the pong + + By default, the ping contains four random bytes. This payload may be + overridden with the optional ``data`` argument which must be a string + (which will be encoded to UTF-8) or a bytes-like object. + + Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return + immediately, it means the write buffer is full. If you don't want to + wait, you should close the connection. + + Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no + effect. + + """ + await self.ensure_open() + + if data is not None: + data = prepare_ctrl(data) + + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise ValueError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + self.pings[data] = self.loop.create_future() + + await self.write_frame(True, OP_PING, data) + + return asyncio.shield(self.pings[data]) + + async def pong(self, data: Data = b"") -> None: + """ + Send a pong. + + An unsolicited pong may serve as a unidirectional heartbeat. + + The payload may be set with the optional ``data`` argument which must + be a string (which will be encoded to UTF-8) or a bytes-like object. + + Canceling :meth:`pong` is discouraged for the same reason as + :meth:`ping`. + + """ + await self.ensure_open() + + data = prepare_ctrl(data) + + await self.write_frame(True, OP_PONG, data) + + # Private methods - no guarantees. + + def connection_closed_exc(self) -> ConnectionClosed: + exception: ConnectionClosed + if self.close_code == 1000 or self.close_code == 1001: + exception = ConnectionClosedOK(self.close_code, self.close_reason) + else: + exception = ConnectionClosedError(self.close_code, self.close_reason) + # Chain to the exception that terminated data transfer, if any. + exception.__cause__ = self.transfer_data_exc + return exception + + async def ensure_open(self) -> None: + """ + Check that the WebSocket connection is open. + + Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't. + + """ + # Handle cases from most common to least common for performance. + if self.state is State.OPEN: + # If self.transfer_data_task exited without a closing handshake, + # self.close_connection_task may be closing the connection, going + # straight from OPEN to CLOSED. + if self.transfer_data_task.done(): + await asyncio.shield(self.close_connection_task) + raise self.connection_closed_exc() + else: + return + + if self.state is State.CLOSED: + raise self.connection_closed_exc() + + if self.state is State.CLOSING: + # If we started the closing handshake, wait for its completion to + # get the proper close code and reason. self.close_connection_task + # will complete within 4 or 5 * close_timeout after close(). The + # CLOSING state also occurs when failing the connection. In that + # case self.close_connection_task will complete even faster. + await asyncio.shield(self.close_connection_task) + raise self.connection_closed_exc() + + # Control may only reach this point in buggy third-party subclasses. + assert self.state is State.CONNECTING + raise InvalidState("WebSocket connection isn't established yet") + + async def transfer_data(self) -> None: + """ + Read incoming messages and put them in a queue. + + This coroutine runs in a task until the closing handshake is started. + + """ + try: + while True: + message = await self.read_message() + + # Exit the loop when receiving a close frame. + if message is None: + break + + # Wait until there's room in the queue (if necessary). + if self.max_queue is not None: + while len(self.messages) >= self.max_queue: + self._put_message_waiter = self.loop.create_future() + try: + await asyncio.shield(self._put_message_waiter) + finally: + self._put_message_waiter = None + + # Put the message in the queue. + self.messages.append(message) + + # Notify recv(). + if self._pop_message_waiter is not None: + self._pop_message_waiter.set_result(None) + self._pop_message_waiter = None + + except asyncio.CancelledError as exc: + self.transfer_data_exc = exc + # If fail_connection() cancels this task, avoid logging the error + # twice and failing the connection again. + raise + + except ProtocolError as exc: + self.transfer_data_exc = exc + self.fail_connection(1002) + + except (ConnectionError, TimeoutError, EOFError) as exc: + # Reading data with self.reader.readexactly may raise: + # - most subclasses of ConnectionError if the TCP connection + # breaks, is reset, or is aborted; + # - TimeoutError if the TCP connection times out; + # - IncompleteReadError, a subclass of EOFError, if fewer + # bytes are available than requested. + self.transfer_data_exc = exc + self.fail_connection(1006) + + except UnicodeDecodeError as exc: + self.transfer_data_exc = exc + self.fail_connection(1007) + + except PayloadTooBig as exc: + self.transfer_data_exc = exc + self.fail_connection(1009) + + except Exception as exc: + # This shouldn't happen often because exceptions expected under + # regular circumstances are handled above. If it does, consider + # catching and handling more exceptions. + logger.error("Error in data transfer", exc_info=True) + + self.transfer_data_exc = exc + self.fail_connection(1011) + + async def read_message(self) -> Optional[Data]: + """ + Read a single message from the connection. + + Re-assemble data frames if the message is fragmented. + + Return ``None`` when the closing handshake is started. + + """ + frame = await self.read_data_frame(max_size=self.max_size) + + # A close frame was received. + if frame is None: + return None + + if frame.opcode == OP_TEXT: + text = True + elif frame.opcode == OP_BINARY: + text = False + else: # frame.opcode == OP_CONT + raise ProtocolError("unexpected opcode") + + # Shortcut for the common case - no fragmentation + if frame.fin: + return frame.data.decode("utf-8") if text else frame.data + + # 5.4. Fragmentation + chunks: List[Data] = [] + max_size = self.max_size + if text: + decoder_factory = codecs.getincrementaldecoder("utf-8") + decoder = decoder_factory(errors="strict") + if max_size is None: + + def append(frame: Frame) -> None: + nonlocal chunks + chunks.append(decoder.decode(frame.data, frame.fin)) + + else: + + def append(frame: Frame) -> None: + nonlocal chunks, max_size + chunks.append(decoder.decode(frame.data, frame.fin)) + assert isinstance(max_size, int) + max_size -= len(frame.data) + + else: + if max_size is None: + + def append(frame: Frame) -> None: + nonlocal chunks + chunks.append(frame.data) + + else: + + def append(frame: Frame) -> None: + nonlocal chunks, max_size + chunks.append(frame.data) + assert isinstance(max_size, int) + max_size -= len(frame.data) + + append(frame) + + while not frame.fin: + frame = await self.read_data_frame(max_size=max_size) + if frame is None: + raise ProtocolError("incomplete fragmented message") + if frame.opcode != OP_CONT: + raise ProtocolError("unexpected opcode") + append(frame) + + # mypy cannot figure out that chunks have the proper type. + return ("" if text else b"").join(chunks) # type: ignore + + async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: + """ + Read a single data frame from the connection. + + Process control frames received before the next data frame. + + Return ``None`` if a close frame is encountered before any data frame. + + """ + # 6.2. Receiving Data + while True: + frame = await self.read_frame(max_size) + + # 5.5. Control Frames + if frame.opcode == OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_code, self.close_reason = parse_close(frame.data) + try: + # Echo the original data instead of re-serializing it with + # serialize_close() because that fails when the close frame + # is empty and parse_close() synthetizes a 1005 close code. + await self.write_close_frame(frame.data) + except ConnectionClosed: + # It doesn't really matter if the connection was closed + # before we could send back a close frame. + pass + return None + + elif frame.opcode == OP_PING: + # Answer pings. + ping_hex = frame.data.hex() or "[empty]" + logger.debug( + "%s - received ping, sending pong: %s", self.side, ping_hex + ) + await self.pong(frame.data) + + elif frame.opcode == OP_PONG: + # Acknowledge pings on solicited pongs. + if frame.data in self.pings: + logger.debug( + "%s - received solicited pong: %s", + self.side, + frame.data.hex() or "[empty]", + ) + # Acknowledge all pings up to the one matching this pong. + ping_id = None + ping_ids = [] + for ping_id, ping in self.pings.items(): + ping_ids.append(ping_id) + if not ping.done(): + ping.set_result(None) + if ping_id == frame.data: + break + else: # pragma: no cover + assert False, "ping_id is in self.pings" + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + ping_ids = ping_ids[:-1] + if ping_ids: + pings_hex = ", ".join( + ping_id.hex() or "[empty]" for ping_id in ping_ids + ) + plural = "s" if len(ping_ids) > 1 else "" + logger.debug( + "%s - acknowledged previous ping%s: %s", + self.side, + plural, + pings_hex, + ) + else: + logger.debug( + "%s - received unsolicited pong: %s", + self.side, + frame.data.hex() or "[empty]", + ) + + # 5.6. Data Frames + else: + return frame + + async def read_frame(self, max_size: Optional[int]) -> Frame: + """ + Read a single frame from the connection. + + """ + frame = await Frame.read( + self.reader.readexactly, + mask=not self.is_client, + max_size=max_size, + extensions=self.extensions, + ) + logger.debug("%s < %r", self.side, frame) + return frame + + async def write_frame( + self, fin: bool, opcode: int, data: bytes, *, _expected_state: int = State.OPEN + ) -> None: + # Defensive assertion for protocol compliance. + if self.state is not _expected_state: # pragma: no cover + raise InvalidState( + f"Cannot write to a WebSocket in the {self.state.name} state" + ) + + frame = Frame(fin, Opcode(opcode), data) + logger.debug("%s > %r", self.side, frame) + frame.write( + self.transport.write, mask=self.is_client, extensions=self.extensions + ) + + try: + # drain() cannot be called concurrently by multiple coroutines: + # http://bugs.python.org/issue29930. Remove this lock when no + # version of Python where this bugs exists is supported anymore. + async with self._drain_lock: + # Handle flow control automatically. + await self._drain() + except ConnectionError: + # Terminate the connection if the socket died. + self.fail_connection() + # Wait until the connection is closed to raise ConnectionClosed + # with the correct code and reason. + await self.ensure_open() + + async def write_close_frame(self, data: bytes = b"") -> None: + """ + Write a close frame if and only if the connection state is OPEN. + + This dedicated coroutine must be used for writing close frames to + ensure that at most one close frame is sent on a given connection. + + """ + # Test and set the connection state before sending the close frame to + # avoid sending two frames in case of concurrent calls. + if self.state is State.OPEN: + # 7.1.3. The WebSocket Closing Handshake is Started + self.state = State.CLOSING + logger.debug("%s - state = CLOSING", self.side) + + # 7.1.2. Start the WebSocket Closing Handshake + await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) + + async def keepalive_ping(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + This coroutine exits when the connection terminates and one of the + following happens: + + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`close_connection` cancels :attr:`keepalive_ping_task`. + + """ + if self.ping_interval is None: + return + + try: + while True: + await asyncio.sleep( + self.ping_interval, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + + # ping() raises CancelledError if the connection is closed, + # when close_connection() cancels self.keepalive_ping_task. + + # ping() raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). + + pong_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + await asyncio.wait_for( + pong_waiter, + self.ping_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + except asyncio.TimeoutError: + logger.debug("%s ! timed out waiting for pong", self.side) + self.fail_connection(1011) + break + + # Remove this branch when dropping support for Python < 3.8 + # because CancelledError no longer inherits Exception. + except asyncio.CancelledError: + raise + + except ConnectionClosed: + pass + + except Exception: + logger.warning("Unexpected exception in keepalive ping task", exc_info=True) + + async def close_connection(self) -> None: + """ + 7.1.1. Close the WebSocket Connection + + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. + + """ + try: + # Wait for the data transfer phase to complete. + if hasattr(self, "transfer_data_task"): + try: + await self.transfer_data_task + except asyncio.CancelledError: + pass + + # Cancel the keepalive ping task. + if hasattr(self, "keepalive_ping_task"): + self.keepalive_ping_task.cancel() + + # A client should wait for a TCP close from the server. + if self.is_client and hasattr(self, "transfer_data_task"): + if await self.wait_for_connection_lost(): + # Coverage marks this line as a partially executed branch. + # I supect a bug in coverage. Ignore it for now. + return # pragma: no cover + logger.debug("%s ! timed out waiting for TCP close", self.side) + + # Half-close the TCP connection if possible (when there's no TLS). + if self.transport.can_write_eof(): + logger.debug("%s x half-closing TCP connection", self.side) + self.transport.write_eof() + + if await self.wait_for_connection_lost(): + # Coverage marks this line as a partially executed branch. + # I supect a bug in coverage. Ignore it for now. + return # pragma: no cover + logger.debug("%s ! timed out waiting for TCP close", self.side) + + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is canceled (for example). + + # If connection_lost() was called, the TCP connection is closed. + # However, if TLS is enabled, the transport still needs closing. + # Else asyncio complains: ResourceWarning: unclosed transport. + if self.connection_lost_waiter.done() and self.transport.is_closing(): + return + + # Close the TCP connection. Buffers are flushed asynchronously. + logger.debug("%s x closing TCP connection", self.side) + self.transport.close() + + if await self.wait_for_connection_lost(): + return + logger.debug("%s ! timed out waiting for TCP close", self.side) + + # Abort the TCP connection. Buffers are discarded. + logger.debug("%s x aborting TCP connection", self.side) + self.transport.abort() + + # connection_lost() is called quickly after aborting. + # Coverage marks this line as a partially executed branch. + # I supect a bug in coverage. Ignore it for now. + await self.wait_for_connection_lost() # pragma: no cover + + async def wait_for_connection_lost(self) -> bool: + """ + Wait until the TCP connection is closed or ``self.close_timeout`` elapses. + + Return ``True`` if the connection is closed and ``False`` otherwise. + + """ + if not self.connection_lost_waiter.done(): + try: + await asyncio.wait_for( + asyncio.shield(self.connection_lost_waiter), + self.close_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + except asyncio.TimeoutError: + pass + # Re-check self.connection_lost_waiter.done() synchronously because + # connection_lost() could run between the moment the timeout occurs + # and the moment this coroutine resumes running. + return self.connection_lost_waiter.done() + + def fail_connection(self, code: int = 1006, reason: str = "") -> None: + """ + 7.1.7. Fail the WebSocket Connection + + This requires: + + 1. Stopping all processing of incoming data, which means cancelling + :attr:`transfer_data_task`. The close code will be 1006 unless a + close frame was received earlier. + + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + + 3. Closing the connection. :meth:`close_connection` takes care of + this once :attr:`transfer_data_task` exits after being canceled. + + (The specification describes these steps in the opposite order.) + + """ + logger.debug( + "%s ! failing %s WebSocket connection with code %d", + self.side, + self.state.name, + code, + ) + + # Cancel transfer_data_task if the opening handshake succeeded. + # cancel() is idempotent and ignored if the task is done already. + if hasattr(self, "transfer_data_task"): + self.transfer_data_task.cancel() + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because of + # an error reading from or writing to the network. + # Don't send a close frame if the connection is broken. + if code != 1006 and self.state is State.OPEN: + + frame_data = serialize_close(code, reason) + + # Write the close frame without draining the write buffer. + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not drainig the write buffer is acceptable in this context. + + # This duplicates a few lines of code from write_close_frame() + # and write_frame(). + + self.state = State.CLOSING + logger.debug("%s - state = CLOSING", self.side) + + frame = Frame(True, OP_CLOSE, frame_data) + logger.debug("%s > %r", self.side, frame) + frame.write( + self.transport.write, mask=self.is_client, extensions=self.extensions + ) + + # Start close_connection_task if the opening handshake didn't succeed. + if not hasattr(self, "close_connection_task"): + self.close_connection_task = self.loop.create_task(self.close_connection()) + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending keepalive pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.state is State.CLOSED + exc = self.connection_closed_exc() + + for ping in self.pings.values(): + ping.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + ping.cancel() + + if self.pings: + pings_hex = ", ".join(ping_id.hex() or "[empty]" for ping_id in self.pings) + plural = "s" if len(self.pings) > 1 else "" + logger.debug( + "%s - aborted pending ping%s: %s", self.side, plural, pings_hex + ) + + # asyncio.Protocol methods + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + Configure write buffer limits. + + The high-water limit is defined by ``self.write_limit``. + + The low-water limit currently defaults to ``self.write_limit // 4`` in + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should + be all right for reasonable use cases of this library. + + This is the earliest point where we can get hold of the transport, + which means it's the best point for configuring it. + + """ + logger.debug("%s - event = connection_made(%s)", self.side, transport) + + transport = cast(asyncio.Transport, transport) + transport.set_write_buffer_limits(self.write_limit) + self.transport = transport + + # Copied from asyncio.StreamReaderProtocol + self.reader.set_transport(transport) + + def connection_lost(self, exc: Optional[Exception]) -> None: + """ + 7.1.4. The WebSocket Connection is Closed. + + """ + logger.debug("%s - event = connection_lost(%s)", self.side, exc) + self.state = State.CLOSED + logger.debug("%s - state = CLOSED", self.side) + if not hasattr(self, "close_code"): + self.close_code = 1006 + if not hasattr(self, "close_reason"): + self.close_reason = "" + logger.debug( + "%s x code = %d, reason = %s", + self.side, + self.close_code, + self.close_reason or "[no reason]", + ) + self.abort_pings() + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + + if True: # pragma: no cover + + # Copied from asyncio.StreamReaderProtocol + if self.reader is not None: + if exc is None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) + + # Copied from asyncio.FlowControlMixin + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + def pause_writing(self) -> None: # pragma: no cover + assert not self._paused + self._paused = True + + def resume_writing(self) -> None: # pragma: no cover + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def data_received(self, data: bytes) -> None: + logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data)) + self.reader.feed_data(data) + + def eof_received(self) -> None: + """ + Close the transport after receiving EOF. + + The WebSocket protocol has its own closing handshake: endpoints close + the TCP or TLS connection after sending and receiving a close frame. + + As a consequence, they never need to write after receiving EOF, so + there's no reason to keep the transport open by returning ``True``. + + Besides, that doesn't work on TLS connections. + + """ + logger.debug("%s - event = eof_received()", self.side) + self.reader.feed_eof() diff --git a/src/websockets/asyncio_server.py b/src/websockets/legacy/server.py similarity index 97% rename from src/websockets/asyncio_server.py rename to src/websockets/legacy/server.py index 79ceddf4b..4dea9459d 100644 --- a/src/websockets/asyncio_server.py +++ b/src/websockets/legacy/server.py @@ -1,5 +1,5 @@ """ -:mod:`websockets.server` defines the WebSocket server APIs. +:mod:`websockets.legacy.server` defines the WebSocket server APIs. """ @@ -28,8 +28,8 @@ cast, ) -from .datastructures import Headers, HeadersLike, MultipleValuesError -from .exceptions import ( +from ..datastructures import Headers, HeadersLike, MultipleValuesError +from ..exceptions import ( AbortHandshake, InvalidHandshake, InvalidHeader, @@ -38,14 +38,14 @@ InvalidUpgrade, NegotiationError, ) -from .extensions.base import Extension, ServerExtensionFactory -from .extensions.permessage_deflate import enable_server_permessage_deflate -from .handshake_legacy import build_response, check_request -from .headers import build_extension, parse_extension, parse_subprotocol -from .http import USER_AGENT -from .http_legacy import read_request +from ..extensions.base import Extension, ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import build_extension, parse_extension, parse_subprotocol +from ..http import USER_AGENT +from ..typing import ExtensionHeader, Origin, Subprotocol +from .handshake import build_response, check_request +from .http import read_request from .protocol import WebSocketCommonProtocol -from .typing import ExtensionHeader, Origin, Subprotocol __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] @@ -598,7 +598,7 @@ async def handshake( class WebSocketServer: """ - WebSocket server returned by :func:`~websockets.server.serve`. + WebSocket server returned by :func:`serve`. This class provides the same interface as :class:`~asyncio.AbstractServer`, namely the @@ -770,9 +770,9 @@ class Serve: performs the closing handshake and closes the connection. Awaiting :func:`serve` yields a :class:`WebSocketServer`. This instance - provides :meth:`~websockets.server.WebSocketServer.close` and - :meth:`~websockets.server.WebSocketServer.wait_closed` methods for - terminating the server and cleaning up its resources. + provides :meth:`~WebSocketServer.close` and + :meth:`~WebSocketServer.wait_closed` methods for terminating the server + and cleaning up its resources. When a server is closed with :meth:`~WebSocketServer.close`, it closes all connections with close code 1001 (going away). Connections handlers, which @@ -835,11 +835,11 @@ class Serve: :meth:`~WebSocketServerProtocol.select_subprotocol` for details Since there's no useful way to propagate exceptions triggered in handlers, - they're sent to the ``'websockets.asyncio_server'`` logger instead. + they're sent to the ``'websockets.legacy.server'`` logger instead. Debugging is much easier if you configure logging to print them:: import logging - logger = logging.getLogger("websockets.asyncio_server") + logger = logging.getLogger("websockets.legacy.server") logger.setLevel(logging.ERROR) logger.addHandler(logging.StreamHandler()) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1552fb060..287f92a57 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -1,1465 +1 @@ -""" -:mod:`websockets.protocol` handles WebSocket control and data frames. - -See `sections 4 to 8 of RFC 6455`_. - -.. _sections 4 to 8 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 - -""" - -import asyncio -import codecs -import collections -import enum -import logging -import random -import struct -import sys -import warnings -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Deque, - Dict, - Iterable, - List, - Mapping, - Optional, - Union, - cast, -) - -from .datastructures import Headers -from .exceptions import ( - ConnectionClosed, - ConnectionClosedError, - ConnectionClosedOK, - InvalidState, - PayloadTooBig, - ProtocolError, -) -from .extensions.base import Extension -from .frames import ( - OP_BINARY, - OP_CLOSE, - OP_CONT, - OP_PING, - OP_PONG, - OP_TEXT, - Opcode, - parse_close, - prepare_ctrl, - prepare_data, - serialize_close, -) - - -with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "websockets.framing is deprecated", DeprecationWarning - ) - from .framing import Frame - -from .typing import Data, Subprotocol - - -__all__ = ["WebSocketCommonProtocol"] - -logger = logging.getLogger(__name__) - - -# A WebSocket connection goes through the following four states, in order: - - -class State(enum.IntEnum): - CONNECTING, OPEN, CLOSING, CLOSED = range(4) - - -# In order to ensure consistency, the code always checks the current value of -# WebSocketCommonProtocol.state before assigning a new value and never yields -# between the check and the assignment. - - -class WebSocketCommonProtocol(asyncio.Protocol): - """ - :class:`~asyncio.Protocol` subclass implementing the data transfer phase. - - Once the WebSocket connection is established, during the data transfer - phase, the protocol is almost symmetrical between the server side and the - client side. :class:`WebSocketCommonProtocol` implements logic that's - shared between servers and clients.. - - Subclasses such as :class:`~websockets.server.WebSocketServerProtocol` and - :class:`~websockets.client.WebSocketClientProtocol` implement the opening - handshake, which is different between servers and clients. - - :class:`WebSocketCommonProtocol` performs four functions: - - * It runs a task that stores incoming data frames in a queue and makes - them available with the :meth:`recv` coroutine. - * It sends outgoing data frames with the :meth:`send` coroutine. - * It deals with control frames automatically. - * It performs the closing handshake. - - :class:`WebSocketCommonProtocol` supports asynchronous iteration:: - - async for message in websocket: - await process(message) - - The iterator yields incoming messages. It exits normally when the - connection is closed with the close code 1000 (OK) or 1001 (going away). - It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other code. - - Once the connection is open, a `Ping frame`_ is sent every - ``ping_interval`` seconds. This serves as a keepalive. It helps keeping - the connection open, especially in the presence of proxies with short - timeouts on inactive connections. Set ``ping_interval`` to ``None`` to - disable this behavior. - - .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 - - If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with - code 1011. This ensures that the remote endpoint remains responsive. Set - ``ping_timeout`` to ``None`` to disable this behavior. - - .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 - - The ``close_timeout`` parameter defines a maximum wait time in seconds for - completing the closing handshake and terminating the TCP connection. - :meth:`close` completes in at most ``4 * close_timeout`` on the server - side and ``5 * close_timeout`` on the client side. - - ``close_timeout`` needs to be a parameter of the protocol because - ``websockets`` usually calls :meth:`close` implicitly: - - - on the server side, when the connection handler terminates, - - on the client side, when exiting the context manager for the connection. - - To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. - - The ``max_size`` parameter enforces the maximum size for incoming messages - in bytes. The default value is 1 MiB. ``None`` disables the limit. If a - message larger than the maximum size is received, :meth:`recv` will - raise :exc:`~websockets.exceptions.ConnectionClosedError` and the - connection will be closed with code 1009. - - The ``max_queue`` parameter sets the maximum length of the queue that - holds incoming messages. The default value is ``32``. ``None`` disables - the limit. Messages are added to an in-memory queue when they're received; - then :meth:`recv` pops from that queue. In order to prevent excessive - memory consumption when messages are received faster than they can be - processed, the queue must be bounded. If the queue fills up, the protocol - stops processing incoming data until :meth:`recv` is called. In this - situation, various receive buffers (at least in ``asyncio`` and in the OS) - will fill up, then the TCP receive window will shrink, slowing down - transmission to avoid packet loss. - - Since Python can use up to 4 bytes of memory to represent a single - character, each connection may use up to ``4 * max_size * max_queue`` - bytes of memory to store incoming messages. By default, this is 128 MiB. - You may want to lower the limits, depending on your application's - requirements. - - The ``read_limit`` argument sets the high-water limit of the buffer for - incoming bytes. The low-water limit is half the high-water limit. The - default value is 64 KiB, half of asyncio's default (based on the current - implementation of :class:`~asyncio.StreamReader`). - - The ``write_limit`` argument sets the high-water limit of the buffer for - outgoing bytes. The low-water limit is a quarter of the high-water limit. - The default value is 64 KiB, equal to asyncio's default (based on the - current implementation of ``FlowControlMixin``). - - As soon as the HTTP request and response in the opening handshake are - processed: - - * the request path is available in the :attr:`path` attribute; - * the request and response HTTP headers are available in the - :attr:`request_headers` and :attr:`response_headers` attributes, - which are :class:`~websockets.http.Headers` instances. - - If a subprotocol was negotiated, it's available in the :attr:`subprotocol` - attribute. - - Once the connection is closed, the code is available in the - :attr:`close_code` attribute and the reason in :attr:`close_reason`. - - All these attributes must be treated as read-only. - - """ - - # There are only two differences between the client-side and server-side - # behavior: masking the payload and closing the underlying TCP connection. - # Set is_client = True/False and side = "client"/"server" to pick a side. - is_client: bool - side: str = "undefined" - - def __init__( - self, - *, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, - # The following arguments are kept only for backwards compatibility. - host: Optional[str] = None, - port: Optional[int] = None, - secure: Optional[bool] = None, - legacy_recv: bool = False, - timeout: Optional[float] = None, - ) -> None: - # Backwards compatibility: close_timeout used to be called timeout. - if timeout is None: - timeout = 10 - else: - warnings.warn("rename timeout to close_timeout", DeprecationWarning) - # If both are specified, timeout is ignored. - if close_timeout is None: - close_timeout = timeout - - self.ping_interval = ping_interval - self.ping_timeout = ping_timeout - self.close_timeout = close_timeout - self.max_size = max_size - self.max_queue = max_queue - self.read_limit = read_limit - self.write_limit = write_limit - - if loop is None: - loop = asyncio.get_event_loop() - self.loop = loop - - self._host = host - self._port = port - self._secure = secure - self.legacy_recv = legacy_recv - - # Configure read buffer limits. The high-water limit is defined by - # ``self.read_limit``. The ``limit`` argument controls the line length - # limit and half the buffer limit of :class:`~asyncio.StreamReader`. - # That's why it must be set to half of ``self.read_limit``. - self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - - # Copied from asyncio.FlowControlMixin - self._paused = False - self._drain_waiter: Optional[asyncio.Future[None]] = None - - self._drain_lock = asyncio.Lock( - loop=loop if sys.version_info[:2] < (3, 8) else None - ) - - # This class implements the data transfer and closing handshake, which - # are shared between the client-side and the server-side. - # Subclasses implement the opening handshake and, on success, execute - # :meth:`connection_open` to change the state to OPEN. - self.state = State.CONNECTING - logger.debug("%s - state = CONNECTING", self.side) - - # HTTP protocol parameters. - self.path: str - self.request_headers: Headers - self.response_headers: Headers - - # WebSocket protocol parameters. - self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None - - # The close code and reason are set when receiving a close frame or - # losing the TCP connection. - self.close_code: int - self.close_reason: str - - # Completed when the connection state becomes CLOSED. Translates the - # :meth:`connection_lost` callback to a :class:`~asyncio.Future` - # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are - # translated by ``self.stream_reader``). - self.connection_lost_waiter: asyncio.Future[None] = loop.create_future() - - # Queue of received messages. - self.messages: Deque[Data] = collections.deque() - self._pop_message_waiter: Optional[asyncio.Future[None]] = None - self._put_message_waiter: Optional[asyncio.Future[None]] = None - - # Protect sending fragmented messages. - self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None - - # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, asyncio.Future[None]] = {} - - # Task running the data transfer. - self.transfer_data_task: asyncio.Task[None] - - # Exception that occurred during data transfer, if any. - self.transfer_data_exc: Optional[BaseException] = None - - # Task sending keepalive pings. - self.keepalive_ping_task: asyncio.Task[None] - - # Task closing the TCP connection. - self.close_connection_task: asyncio.Task[None] - - # Copied from asyncio.FlowControlMixin - async def _drain_helper(self) -> None: # pragma: no cover - if self.connection_lost_waiter.done(): - raise ConnectionResetError("Connection lost") - if not self._paused: - return - waiter = self._drain_waiter - assert waiter is None or waiter.cancelled() - waiter = self.loop.create_future() - self._drain_waiter = waiter - await waiter - - # Copied from asyncio.StreamWriter - async def _drain(self) -> None: # pragma: no cover - if self.reader is not None: - exc = self.reader.exception() - if exc is not None: - raise exc - if self.transport is not None: - if self.transport.is_closing(): - # Yield to the event loop so connection_lost() may be - # called. Without this, _drain_helper() would return - # immediately, and code that calls - # write(...); yield from drain() - # in a loop would never call connection_lost(), so it - # would not see an error when the socket is closed. - await asyncio.sleep( - 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) - await self._drain_helper() - - def connection_open(self) -> None: - """ - Callback when the WebSocket opening handshake completes. - - Enter the OPEN state and start the data transfer phase. - - """ - # 4.1. The WebSocket Connection is Established. - assert self.state is State.CONNECTING - self.state = State.OPEN - logger.debug("%s - state = OPEN", self.side) - # Start the task that receives incoming WebSocket messages. - self.transfer_data_task = self.loop.create_task(self.transfer_data()) - # Start the task that sends pings at regular intervals. - self.keepalive_ping_task = self.loop.create_task(self.keepalive_ping()) - # Start the task that eventually closes the TCP connection. - self.close_connection_task = self.loop.create_task(self.close_connection()) - - @property - def host(self) -> Optional[str]: - alternative = "remote_address" if self.is_client else "local_address" - warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) - return self._host - - @property - def port(self) -> Optional[int]: - alternative = "remote_address" if self.is_client else "local_address" - warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) - return self._port - - @property - def secure(self) -> Optional[bool]: - warnings.warn("don't use secure", DeprecationWarning) - return self._secure - - # Public API - - @property - def local_address(self) -> Any: - """ - Local address of the connection as a ``(host, port)`` tuple. - - When the connection isn't open, ``local_address`` is ``None``. - - """ - try: - transport = self.transport - except AttributeError: - return None - else: - return transport.get_extra_info("sockname") - - @property - def remote_address(self) -> Any: - """ - Remote address of the connection as a ``(host, port)`` tuple. - - When the connection isn't open, ``remote_address`` is ``None``. - - """ - try: - transport = self.transport - except AttributeError: - return None - else: - return transport.get_extra_info("peername") - - @property - def open(self) -> bool: - """ - ``True`` when the connection is usable. - - It may be used to detect disconnections. However, this approach is - discouraged per the EAFP_ principle. - - When ``open`` is ``False``, using the connection raises a - :exc:`~websockets.exceptions.ConnectionClosed` exception. - - .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp - - """ - return self.state is State.OPEN and not self.transfer_data_task.done() - - @property - def closed(self) -> bool: - """ - ``True`` once the connection is closed. - - Be aware that both :attr:`open` and :attr:`closed` are ``False`` during - the opening and closing sequences. - - """ - return self.state is State.CLOSED - - async def wait_closed(self) -> None: - """ - Wait until the connection is closed. - - This is identical to :attr:`closed`, except it can be awaited. - - This can make it easier to handle connection termination, regardless - of its cause, in tasks that interact with the WebSocket connection. - - """ - await asyncio.shield(self.connection_lost_waiter) - - async def __aiter__(self) -> AsyncIterator[Data]: - """ - Iterate on received messages. - - Exit normally when the connection is closed with code 1000 or 1001. - - Raise an exception in other cases. - - """ - try: - while True: - yield await self.recv() - except ConnectionClosedOK: - return - - async def recv(self) -> Data: - """ - Receive the next message. - - Return a :class:`str` for a text frame and :class:`bytes` for a binary - frame. - - When the end of the message stream is reached, :meth:`recv` raises - :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it - raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal - connection closure and - :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol - error or a network failure. - - .. versionchanged:: 3.0 - - :meth:`recv` used to return ``None`` instead. Refer to the - changelog for details. - - Canceling :meth:`recv` is safe. There's no risk of losing the next - message. The next invocation of :meth:`recv` will return it. This - makes it possible to enforce a timeout by wrapping :meth:`recv` in - :func:`~asyncio.wait_for`. - - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed - :raises RuntimeError: if two coroutines call :meth:`recv` concurrently - - """ - if self._pop_message_waiter is not None: - raise RuntimeError( - "cannot call recv while another coroutine " - "is already waiting for the next message" - ) - - # Don't await self.ensure_open() here: - # - messages could be available in the queue even if the connection - # is closed; - # - messages could be received before the closing frame even if the - # connection is closing. - - # Wait until there's a message in the queue (if necessary) or the - # connection is closed. - while len(self.messages) <= 0: - pop_message_waiter: asyncio.Future[None] = self.loop.create_future() - self._pop_message_waiter = pop_message_waiter - try: - # If asyncio.wait() is canceled, it doesn't cancel - # pop_message_waiter and self.transfer_data_task. - await asyncio.wait( - [pop_message_waiter, self.transfer_data_task], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, - return_when=asyncio.FIRST_COMPLETED, - ) - finally: - self._pop_message_waiter = None - - # If asyncio.wait(...) exited because self.transfer_data_task - # completed before receiving a new message, raise a suitable - # exception (or return None if legacy_recv is enabled). - if not pop_message_waiter.done(): - if self.legacy_recv: - return None # type: ignore - else: - # Wait until the connection is closed to raise - # ConnectionClosed with the correct code and reason. - await self.ensure_open() - - # Pop a message from the queue. - message = self.messages.popleft() - - # Notify transfer_data(). - if self._put_message_waiter is not None: - self._put_message_waiter.set_result(None) - self._put_message_waiter = None - - return message - - async def send( - self, message: Union[Data, Iterable[Data], AsyncIterable[Data]] - ) -> None: - """ - Send a message. - - A string (:class:`str`) is sent as a `Text frame`_. A bytestring or - bytes-like object (:class:`bytes`, :class:`bytearray`, or - :class:`memoryview`) is sent as a `Binary frame`_. - - .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 - .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 - - :meth:`send` also accepts an iterable or an asynchronous iterable of - strings, bytestrings, or bytes-like objects. In that case the message - is fragmented. Each item is treated as a message fragment and sent in - its own frame. All items must be of the same type, or else - :meth:`send` will raise a :exc:`TypeError` and the connection will be - closed. - - :meth:`send` rejects dict-like objects because this is often an error. - If you wish to send the keys of a dict-like object as fragments, call - its :meth:`~dict.keys` method and pass the result to :meth:`send`. - - Canceling :meth:`send` is discouraged. Instead, you should close the - connection with :meth:`close`. Indeed, there only two situations where - :meth:`send` yields control to the event loop: - - 1. The write buffer is full. If you don't want to wait until enough - data is sent, your only alternative is to close the connection. - :meth:`close` will likely time out then abort the TCP connection. - 2. ``message`` is an asynchronous iterator. Stopping in the middle of - a fragmented message will cause a protocol error. Closing the - connection has the same effect. - - :raises TypeError: for unsupported inputs - - """ - await self.ensure_open() - - # While sending a fragmented message, prevent sending other messages - # until all fragments are sent. - while self._fragmented_message_waiter is not None: - await asyncio.shield(self._fragmented_message_waiter) - - # Unfragmented message -- this case must be handled first because - # strings and bytes-like objects are iterable. - - if isinstance(message, (str, bytes, bytearray, memoryview)): - opcode, data = prepare_data(message) - await self.write_frame(True, opcode, data) - - # Catch a common mistake -- passing a dict to send(). - - elif isinstance(message, Mapping): - raise TypeError("data is a dict-like object") - - # Fragmented message -- regular iterator. - - elif isinstance(message, Iterable): - - # Work around https://github.com/python/mypy/issues/6227 - message = cast(Iterable[Data], message) - - iter_message = iter(message) - try: - message_chunk = next(iter_message) - except StopIteration: - return - opcode, data = prepare_data(message_chunk) - - self._fragmented_message_waiter = asyncio.Future() - try: - # First fragment. - await self.write_frame(False, opcode, data) - - # Other fragments. - for message_chunk in iter_message: - confirm_opcode, data = prepare_data(message_chunk) - if confirm_opcode != opcode: - raise TypeError("data contains inconsistent types") - await self.write_frame(False, OP_CONT, data) - - # Final fragment. - await self.write_frame(True, OP_CONT, b"") - - except Exception: - # We're half-way through a fragmented message and we can't - # complete it. This makes the connection unusable. - self.fail_connection(1011) - raise - - finally: - self._fragmented_message_waiter.set_result(None) - self._fragmented_message_waiter = None - - # Fragmented message -- asynchronous iterator - - elif isinstance(message, AsyncIterable): - # aiter_message = aiter(message) without aiter - # https://github.com/python/mypy/issues/5738 - aiter_message = type(message).__aiter__(message) # type: ignore - try: - # message_chunk = anext(aiter_message) without anext - # https://github.com/python/mypy/issues/5738 - message_chunk = await type(aiter_message).__anext__( # type: ignore - aiter_message - ) - except StopAsyncIteration: - return - opcode, data = prepare_data(message_chunk) - - self._fragmented_message_waiter = asyncio.Future() - try: - # First fragment. - await self.write_frame(False, opcode, data) - - # Other fragments. - # https://github.com/python/mypy/issues/5738 - async for message_chunk in aiter_message: # type: ignore - confirm_opcode, data = prepare_data(message_chunk) - if confirm_opcode != opcode: - raise TypeError("data contains inconsistent types") - await self.write_frame(False, OP_CONT, data) - - # Final fragment. - await self.write_frame(True, OP_CONT, b"") - - except Exception: - # We're half-way through a fragmented message and we can't - # complete it. This makes the connection unusable. - self.fail_connection(1011) - raise - - finally: - self._fragmented_message_waiter.set_result(None) - self._fragmented_message_waiter = None - - else: - raise TypeError("data must be bytes, str, or iterable") - - async def close(self, code: int = 1000, reason: str = "") -> None: - """ - Perform the closing handshake. - - :meth:`close` waits for the other end to complete the handshake and - for the TCP connection to terminate. As a consequence, there's no need - to await :meth:`wait_closed`; :meth:`close` already does it. - - :meth:`close` is idempotent: it doesn't do anything once the - connection is closed. - - Wrapping :func:`close` in :func:`~asyncio.create_task` is safe, given - that errors during connection termination aren't particularly useful. - - Canceling :meth:`close` is discouraged. If it takes too long, you can - set a shorter ``close_timeout``. If you don't want to wait, let the - Python process exit, then the OS will close the TCP connection. - - :param code: WebSocket close code - :param reason: WebSocket close reason - - """ - try: - await asyncio.wait_for( - self.write_close_frame(serialize_close(code, reason)), - self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, - ) - except asyncio.TimeoutError: - # If the close frame cannot be sent because the send buffers - # are full, the closing handshake won't complete anyway. - # Fail the connection to shut down faster. - self.fail_connection() - - # If no close frame is received within the timeout, wait_for() cancels - # the data transfer task and raises TimeoutError. - - # If close() is called multiple times concurrently and one of these - # calls hits the timeout, the data transfer task will be cancelled. - # Other calls will receive a CancelledError here. - - try: - # If close() is canceled during the wait, self.transfer_data_task - # is canceled before the timeout elapses. - await asyncio.wait_for( - self.transfer_data_task, - self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, - ) - except (asyncio.TimeoutError, asyncio.CancelledError): - pass - - # Wait for the close connection task to close the TCP connection. - await asyncio.shield(self.close_connection_task) - - async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: - """ - Send a ping. - - Return a :class:`~asyncio.Future` that will be completed when the - corresponding pong is received. You can ignore it if you don't intend - to wait. - - A ping may serve as a keepalive or as a check that the remote endpoint - received all messages up to this point:: - - pong_waiter = await ws.ping() - await pong_waiter # only if you want to wait for the pong - - By default, the ping contains four random bytes. This payload may be - overridden with the optional ``data`` argument which must be a string - (which will be encoded to UTF-8) or a bytes-like object. - - Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return - immediately, it means the write buffer is full. If you don't want to - wait, you should close the connection. - - Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no - effect. - - """ - await self.ensure_open() - - if data is not None: - data = prepare_ctrl(data) - - # Protect against duplicates if a payload is explicitly set. - if data in self.pings: - raise ValueError("already waiting for a pong with the same data") - - # Generate a unique random payload otherwise. - while data is None or data in self.pings: - data = struct.pack("!I", random.getrandbits(32)) - - self.pings[data] = self.loop.create_future() - - await self.write_frame(True, OP_PING, data) - - return asyncio.shield(self.pings[data]) - - async def pong(self, data: Data = b"") -> None: - """ - Send a pong. - - An unsolicited pong may serve as a unidirectional heartbeat. - - The payload may be set with the optional ``data`` argument which must - be a string (which will be encoded to UTF-8) or a bytes-like object. - - Canceling :meth:`pong` is discouraged for the same reason as - :meth:`ping`. - - """ - await self.ensure_open() - - data = prepare_ctrl(data) - - await self.write_frame(True, OP_PONG, data) - - # Private methods - no guarantees. - - def connection_closed_exc(self) -> ConnectionClosed: - exception: ConnectionClosed - if self.close_code == 1000 or self.close_code == 1001: - exception = ConnectionClosedOK(self.close_code, self.close_reason) - else: - exception = ConnectionClosedError(self.close_code, self.close_reason) - # Chain to the exception that terminated data transfer, if any. - exception.__cause__ = self.transfer_data_exc - return exception - - async def ensure_open(self) -> None: - """ - Check that the WebSocket connection is open. - - Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't. - - """ - # Handle cases from most common to least common for performance. - if self.state is State.OPEN: - # If self.transfer_data_task exited without a closing handshake, - # self.close_connection_task may be closing the connection, going - # straight from OPEN to CLOSED. - if self.transfer_data_task.done(): - await asyncio.shield(self.close_connection_task) - raise self.connection_closed_exc() - else: - return - - if self.state is State.CLOSED: - raise self.connection_closed_exc() - - if self.state is State.CLOSING: - # If we started the closing handshake, wait for its completion to - # get the proper close code and reason. self.close_connection_task - # will complete within 4 or 5 * close_timeout after close(). The - # CLOSING state also occurs when failing the connection. In that - # case self.close_connection_task will complete even faster. - await asyncio.shield(self.close_connection_task) - raise self.connection_closed_exc() - - # Control may only reach this point in buggy third-party subclasses. - assert self.state is State.CONNECTING - raise InvalidState("WebSocket connection isn't established yet") - - async def transfer_data(self) -> None: - """ - Read incoming messages and put them in a queue. - - This coroutine runs in a task until the closing handshake is started. - - """ - try: - while True: - message = await self.read_message() - - # Exit the loop when receiving a close frame. - if message is None: - break - - # Wait until there's room in the queue (if necessary). - if self.max_queue is not None: - while len(self.messages) >= self.max_queue: - self._put_message_waiter = self.loop.create_future() - try: - await asyncio.shield(self._put_message_waiter) - finally: - self._put_message_waiter = None - - # Put the message in the queue. - self.messages.append(message) - - # Notify recv(). - if self._pop_message_waiter is not None: - self._pop_message_waiter.set_result(None) - self._pop_message_waiter = None - - except asyncio.CancelledError as exc: - self.transfer_data_exc = exc - # If fail_connection() cancels this task, avoid logging the error - # twice and failing the connection again. - raise - - except ProtocolError as exc: - self.transfer_data_exc = exc - self.fail_connection(1002) - - except (ConnectionError, TimeoutError, EOFError) as exc: - # Reading data with self.reader.readexactly may raise: - # - most subclasses of ConnectionError if the TCP connection - # breaks, is reset, or is aborted; - # - TimeoutError if the TCP connection times out; - # - IncompleteReadError, a subclass of EOFError, if fewer - # bytes are available than requested. - self.transfer_data_exc = exc - self.fail_connection(1006) - - except UnicodeDecodeError as exc: - self.transfer_data_exc = exc - self.fail_connection(1007) - - except PayloadTooBig as exc: - self.transfer_data_exc = exc - self.fail_connection(1009) - - except Exception as exc: - # This shouldn't happen often because exceptions expected under - # regular circumstances are handled above. If it does, consider - # catching and handling more exceptions. - logger.error("Error in data transfer", exc_info=True) - - self.transfer_data_exc = exc - self.fail_connection(1011) - - async def read_message(self) -> Optional[Data]: - """ - Read a single message from the connection. - - Re-assemble data frames if the message is fragmented. - - Return ``None`` when the closing handshake is started. - - """ - frame = await self.read_data_frame(max_size=self.max_size) - - # A close frame was received. - if frame is None: - return None - - if frame.opcode == OP_TEXT: - text = True - elif frame.opcode == OP_BINARY: - text = False - else: # frame.opcode == OP_CONT - raise ProtocolError("unexpected opcode") - - # Shortcut for the common case - no fragmentation - if frame.fin: - return frame.data.decode("utf-8") if text else frame.data - - # 5.4. Fragmentation - chunks: List[Data] = [] - max_size = self.max_size - if text: - decoder_factory = codecs.getincrementaldecoder("utf-8") - decoder = decoder_factory(errors="strict") - if max_size is None: - - def append(frame: Frame) -> None: - nonlocal chunks - chunks.append(decoder.decode(frame.data, frame.fin)) - - else: - - def append(frame: Frame) -> None: - nonlocal chunks, max_size - chunks.append(decoder.decode(frame.data, frame.fin)) - assert isinstance(max_size, int) - max_size -= len(frame.data) - - else: - if max_size is None: - - def append(frame: Frame) -> None: - nonlocal chunks - chunks.append(frame.data) - - else: - - def append(frame: Frame) -> None: - nonlocal chunks, max_size - chunks.append(frame.data) - assert isinstance(max_size, int) - max_size -= len(frame.data) - - append(frame) - - while not frame.fin: - frame = await self.read_data_frame(max_size=max_size) - if frame is None: - raise ProtocolError("incomplete fragmented message") - if frame.opcode != OP_CONT: - raise ProtocolError("unexpected opcode") - append(frame) - - # mypy cannot figure out that chunks have the proper type. - return ("" if text else b"").join(chunks) # type: ignore - - async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: - """ - Read a single data frame from the connection. - - Process control frames received before the next data frame. - - Return ``None`` if a close frame is encountered before any data frame. - - """ - # 6.2. Receiving Data - while True: - frame = await self.read_frame(max_size) - - # 5.5. Control Frames - if frame.opcode == OP_CLOSE: - # 7.1.5. The WebSocket Connection Close Code - # 7.1.6. The WebSocket Connection Close Reason - self.close_code, self.close_reason = parse_close(frame.data) - try: - # Echo the original data instead of re-serializing it with - # serialize_close() because that fails when the close frame - # is empty and parse_close() synthetizes a 1005 close code. - await self.write_close_frame(frame.data) - except ConnectionClosed: - # It doesn't really matter if the connection was closed - # before we could send back a close frame. - pass - return None - - elif frame.opcode == OP_PING: - # Answer pings. - ping_hex = frame.data.hex() or "[empty]" - logger.debug( - "%s - received ping, sending pong: %s", self.side, ping_hex - ) - await self.pong(frame.data) - - elif frame.opcode == OP_PONG: - # Acknowledge pings on solicited pongs. - if frame.data in self.pings: - logger.debug( - "%s - received solicited pong: %s", - self.side, - frame.data.hex() or "[empty]", - ) - # Acknowledge all pings up to the one matching this pong. - ping_id = None - ping_ids = [] - for ping_id, ping in self.pings.items(): - ping_ids.append(ping_id) - if not ping.done(): - ping.set_result(None) - if ping_id == frame.data: - break - else: # pragma: no cover - assert False, "ping_id is in self.pings" - # Remove acknowledged pings from self.pings. - for ping_id in ping_ids: - del self.pings[ping_id] - ping_ids = ping_ids[:-1] - if ping_ids: - pings_hex = ", ".join( - ping_id.hex() or "[empty]" for ping_id in ping_ids - ) - plural = "s" if len(ping_ids) > 1 else "" - logger.debug( - "%s - acknowledged previous ping%s: %s", - self.side, - plural, - pings_hex, - ) - else: - logger.debug( - "%s - received unsolicited pong: %s", - self.side, - frame.data.hex() or "[empty]", - ) - - # 5.6. Data Frames - else: - return frame - - async def read_frame(self, max_size: Optional[int]) -> Frame: - """ - Read a single frame from the connection. - - """ - frame = await Frame.read( - self.reader.readexactly, - mask=not self.is_client, - max_size=max_size, - extensions=self.extensions, - ) - logger.debug("%s < %r", self.side, frame) - return frame - - async def write_frame( - self, fin: bool, opcode: int, data: bytes, *, _expected_state: int = State.OPEN - ) -> None: - # Defensive assertion for protocol compliance. - if self.state is not _expected_state: # pragma: no cover - raise InvalidState( - f"Cannot write to a WebSocket in the {self.state.name} state" - ) - - frame = Frame(fin, Opcode(opcode), data) - logger.debug("%s > %r", self.side, frame) - frame.write( - self.transport.write, mask=self.is_client, extensions=self.extensions - ) - - try: - # drain() cannot be called concurrently by multiple coroutines: - # http://bugs.python.org/issue29930. Remove this lock when no - # version of Python where this bugs exists is supported anymore. - async with self._drain_lock: - # Handle flow control automatically. - await self._drain() - except ConnectionError: - # Terminate the connection if the socket died. - self.fail_connection() - # Wait until the connection is closed to raise ConnectionClosed - # with the correct code and reason. - await self.ensure_open() - - async def write_close_frame(self, data: bytes = b"") -> None: - """ - Write a close frame if and only if the connection state is OPEN. - - This dedicated coroutine must be used for writing close frames to - ensure that at most one close frame is sent on a given connection. - - """ - # Test and set the connection state before sending the close frame to - # avoid sending two frames in case of concurrent calls. - if self.state is State.OPEN: - # 7.1.3. The WebSocket Closing Handshake is Started - self.state = State.CLOSING - logger.debug("%s - state = CLOSING", self.side) - - # 7.1.2. Start the WebSocket Closing Handshake - await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) - - async def keepalive_ping(self) -> None: - """ - Send a Ping frame and wait for a Pong frame at regular intervals. - - This coroutine exits when the connection terminates and one of the - following happens: - - - :meth:`ping` raises :exc:`ConnectionClosed`, or - - :meth:`close_connection` cancels :attr:`keepalive_ping_task`. - - """ - if self.ping_interval is None: - return - - try: - while True: - await asyncio.sleep( - self.ping_interval, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, - ) - - # ping() raises CancelledError if the connection is closed, - # when close_connection() cancels self.keepalive_ping_task. - - # ping() raises ConnectionClosed if the connection is lost, - # when connection_lost() calls abort_pings(). - - pong_waiter = await self.ping() - - if self.ping_timeout is not None: - try: - await asyncio.wait_for( - pong_waiter, - self.ping_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, - ) - except asyncio.TimeoutError: - logger.debug("%s ! timed out waiting for pong", self.side) - self.fail_connection(1011) - break - - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: - raise - - except ConnectionClosed: - pass - - except Exception: - logger.warning("Unexpected exception in keepalive ping task", exc_info=True) - - async def close_connection(self) -> None: - """ - 7.1.1. Close the WebSocket Connection - - When the opening handshake succeeds, :meth:`connection_open` starts - this coroutine in a task. It waits for the data transfer phase to - complete then it closes the TCP connection cleanly. - - When the opening handshake fails, :meth:`fail_connection` does the - same. There's no data transfer phase in that case. - - """ - try: - # Wait for the data transfer phase to complete. - if hasattr(self, "transfer_data_task"): - try: - await self.transfer_data_task - except asyncio.CancelledError: - pass - - # Cancel the keepalive ping task. - if hasattr(self, "keepalive_ping_task"): - self.keepalive_ping_task.cancel() - - # A client should wait for a TCP close from the server. - if self.is_client and hasattr(self, "transfer_data_task"): - if await self.wait_for_connection_lost(): - # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. - return # pragma: no cover - logger.debug("%s ! timed out waiting for TCP close", self.side) - - # Half-close the TCP connection if possible (when there's no TLS). - if self.transport.can_write_eof(): - logger.debug("%s x half-closing TCP connection", self.side) - self.transport.write_eof() - - if await self.wait_for_connection_lost(): - # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. - return # pragma: no cover - logger.debug("%s ! timed out waiting for TCP close", self.side) - - finally: - # The try/finally ensures that the transport never remains open, - # even if this coroutine is canceled (for example). - - # If connection_lost() was called, the TCP connection is closed. - # However, if TLS is enabled, the transport still needs closing. - # Else asyncio complains: ResourceWarning: unclosed transport. - if self.connection_lost_waiter.done() and self.transport.is_closing(): - return - - # Close the TCP connection. Buffers are flushed asynchronously. - logger.debug("%s x closing TCP connection", self.side) - self.transport.close() - - if await self.wait_for_connection_lost(): - return - logger.debug("%s ! timed out waiting for TCP close", self.side) - - # Abort the TCP connection. Buffers are discarded. - logger.debug("%s x aborting TCP connection", self.side) - self.transport.abort() - - # connection_lost() is called quickly after aborting. - # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. - await self.wait_for_connection_lost() # pragma: no cover - - async def wait_for_connection_lost(self) -> bool: - """ - Wait until the TCP connection is closed or ``self.close_timeout`` elapses. - - Return ``True`` if the connection is closed and ``False`` otherwise. - - """ - if not self.connection_lost_waiter.done(): - try: - await asyncio.wait_for( - asyncio.shield(self.connection_lost_waiter), - self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, - ) - except asyncio.TimeoutError: - pass - # Re-check self.connection_lost_waiter.done() synchronously because - # connection_lost() could run between the moment the timeout occurs - # and the moment this coroutine resumes running. - return self.connection_lost_waiter.done() - - def fail_connection(self, code: int = 1006, reason: str = "") -> None: - """ - 7.1.7. Fail the WebSocket Connection - - This requires: - - 1. Stopping all processing of incoming data, which means cancelling - :attr:`transfer_data_task`. The close code will be 1006 unless a - close frame was received earlier. - - 2. Sending a close frame with an appropriate code if the opening - handshake succeeded and the other side is likely to process it. - - 3. Closing the connection. :meth:`close_connection` takes care of - this once :attr:`transfer_data_task` exits after being canceled. - - (The specification describes these steps in the opposite order.) - - """ - logger.debug( - "%s ! failing %s WebSocket connection with code %d", - self.side, - self.state.name, - code, - ) - - # Cancel transfer_data_task if the opening handshake succeeded. - # cancel() is idempotent and ignored if the task is done already. - if hasattr(self, "transfer_data_task"): - self.transfer_data_task.cancel() - - # Send a close frame when the state is OPEN (a close frame was already - # sent if it's CLOSING), except when failing the connection because of - # an error reading from or writing to the network. - # Don't send a close frame if the connection is broken. - if code != 1006 and self.state is State.OPEN: - - frame_data = serialize_close(code, reason) - - # Write the close frame without draining the write buffer. - - # Keeping fail_connection() synchronous guarantees it can't - # get stuck and simplifies the implementation of the callers. - # Not drainig the write buffer is acceptable in this context. - - # This duplicates a few lines of code from write_close_frame() - # and write_frame(). - - self.state = State.CLOSING - logger.debug("%s - state = CLOSING", self.side) - - frame = Frame(True, OP_CLOSE, frame_data) - logger.debug("%s > %r", self.side, frame) - frame.write( - self.transport.write, mask=self.is_client, extensions=self.extensions - ) - - # Start close_connection_task if the opening handshake didn't succeed. - if not hasattr(self, "close_connection_task"): - self.close_connection_task = self.loop.create_task(self.close_connection()) - - def abort_pings(self) -> None: - """ - Raise ConnectionClosed in pending keepalive pings. - - They'll never receive a pong once the connection is closed. - - """ - assert self.state is State.CLOSED - exc = self.connection_closed_exc() - - for ping in self.pings.values(): - ping.set_exception(exc) - # If the exception is never retrieved, it will be logged when ping - # is garbage-collected. This is confusing for users. - # Given that ping is done (with an exception), canceling it does - # nothing, but it prevents logging the exception. - ping.cancel() - - if self.pings: - pings_hex = ", ".join(ping_id.hex() or "[empty]" for ping_id in self.pings) - plural = "s" if len(self.pings) > 1 else "" - logger.debug( - "%s - aborted pending ping%s: %s", self.side, plural, pings_hex - ) - - # asyncio.Protocol methods - - def connection_made(self, transport: asyncio.BaseTransport) -> None: - """ - Configure write buffer limits. - - The high-water limit is defined by ``self.write_limit``. - - The low-water limit currently defaults to ``self.write_limit // 4`` in - :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should - be all right for reasonable use cases of this library. - - This is the earliest point where we can get hold of the transport, - which means it's the best point for configuring it. - - """ - logger.debug("%s - event = connection_made(%s)", self.side, transport) - - transport = cast(asyncio.Transport, transport) - transport.set_write_buffer_limits(self.write_limit) - self.transport = transport - - # Copied from asyncio.StreamReaderProtocol - self.reader.set_transport(transport) - - def connection_lost(self, exc: Optional[Exception]) -> None: - """ - 7.1.4. The WebSocket Connection is Closed. - - """ - logger.debug("%s - event = connection_lost(%s)", self.side, exc) - self.state = State.CLOSED - logger.debug("%s - state = CLOSED", self.side) - if not hasattr(self, "close_code"): - self.close_code = 1006 - if not hasattr(self, "close_reason"): - self.close_reason = "" - logger.debug( - "%s x code = %d, reason = %s", - self.side, - self.close_code, - self.close_reason or "[no reason]", - ) - self.abort_pings() - # If self.connection_lost_waiter isn't pending, that's a bug, because: - # - it's set only here in connection_lost() which is called only once; - # - it must never be canceled. - self.connection_lost_waiter.set_result(None) - - if True: # pragma: no cover - - # Copied from asyncio.StreamReaderProtocol - if self.reader is not None: - if exc is None: - self.reader.feed_eof() - else: - self.reader.set_exception(exc) - - # Copied from asyncio.FlowControlMixin - # Wake up the writer if currently paused. - if not self._paused: - return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) - - def pause_writing(self) -> None: # pragma: no cover - assert not self._paused - self._paused = True - - def resume_writing(self) -> None: # pragma: no cover - assert self._paused - self._paused = False - - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None - if not waiter.done(): - waiter.set_result(None) - - def data_received(self, data: bytes) -> None: - logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data)) - self.reader.feed_data(data) - - def eof_received(self) -> None: - """ - Close the transport after receiving EOF. - - The WebSocket protocol has its own closing handshake: endpoints close - the TCP or TLS connection after sending and receiving a close frame. - - As a consequence, they never need to write after receiving EOF, so - there's no reason to keep the transport open by returning ``True``. - - Besides, that doesn't work on TLS connections. - - """ - logger.debug("%s - event = eof_received()", self.side) - self.reader.feed_eof() +from .legacy.protocol import * # noqa diff --git a/src/websockets/server.py b/src/websockets/server.py index c2c818ce9..bd527be74 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -6,7 +6,6 @@ import logging from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast -from .asyncio_server import WebSocketServer, WebSocketServerProtocol, serve, unix_serve from .connection import CONNECTING, OPEN, SERVER, Connection from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( @@ -27,6 +26,12 @@ ) from .http import USER_AGENT from .http11 import Request, Response +from .legacy.server import ( # noqa + WebSocketServer, + WebSocketServerProtocol, + serve, + unix_serve, +) from .typing import ( ConnectionOption, ExtensionHeader, @@ -37,13 +42,7 @@ from .utils import accept_key -__all__ = [ - "serve", - "unix_serve", - "ServerConnection", - "WebSocketServerProtocol", - "WebSocketServer", -] +__all__ = ["ServerConnection"] logger = logging.getLogger(__name__) diff --git a/tests/__init__.py b/tests/__init__.py index 76c869f50..dd78609f5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,15 +1,5 @@ import logging -import warnings # Avoid displaying stack traces at the ERROR logging level. logging.basicConfig(level=logging.CRITICAL) - - -# Ignore deprecation warnings while refactoring is in progress -warnings.filterwarnings( - action="ignore", - message=r"websockets\.framing is deprecated", - category=DeprecationWarning, - module="websockets.framing", -) diff --git a/tests/legacy/__init__.py b/tests/legacy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py new file mode 100644 index 000000000..bb8c6a6eb --- /dev/null +++ b/tests/legacy/test_auth.py @@ -0,0 +1,160 @@ +import unittest +import urllib.error + +from websockets.exceptions import InvalidStatusCode +from websockets.headers import build_authorization_basic +from websockets.legacy.auth import * +from websockets.legacy.auth import is_credentials + +from .test_client_server import ClientServerTestsMixin, with_client, with_server +from .utils import AsyncioTestCase + + +class AuthTests(unittest.TestCase): + def test_is_credentials(self): + self.assertTrue(is_credentials(("username", "password"))) + + def test_is_not_credentials(self): + self.assertFalse(is_credentials(None)) + self.assertFalse(is_credentials("username")) + + +class CustomWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): + async def process_request(self, path, request_headers): + type(self).used = True + return await super().process_request(path, request_headers) + + +class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): + + create_protocol = basic_auth_protocol_factory( + realm="auth-tests", credentials=("hello", "iloveyou") + ) + + @with_server(create_protocol=create_protocol) + @with_client(user_info=("hello", "iloveyou")) + def test_basic_auth(self): + req_headers = self.client.request_headers + resp_headers = self.client.response_headers + self.assertEqual(req_headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") + self.assertNotIn("WWW-Authenticate", resp_headers) + + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + def test_basic_auth_server_no_credentials(self): + with self.assertRaises(TypeError) as raised: + basic_auth_protocol_factory(realm="auth-tests", credentials=None) + self.assertEqual( + str(raised.exception), "provide either credentials or check_credentials" + ) + + def test_basic_auth_server_bad_credentials(self): + with self.assertRaises(TypeError) as raised: + basic_auth_protocol_factory(realm="auth-tests", credentials=42) + self.assertEqual(str(raised.exception), "invalid credentials argument: 42") + + create_protocol_multiple_credentials = basic_auth_protocol_factory( + realm="auth-tests", + credentials=[("hello", "iloveyou"), ("goodbye", "stillloveu")], + ) + + @with_server(create_protocol=create_protocol_multiple_credentials) + @with_client(user_info=("hello", "iloveyou")) + def test_basic_auth_server_multiple_credentials(self): + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + def test_basic_auth_bad_multiple_credentials(self): + with self.assertRaises(TypeError) as raised: + basic_auth_protocol_factory( + realm="auth-tests", credentials=[("hello", "iloveyou"), 42] + ) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [('hello', 'iloveyou'), 42]", + ) + + async def check_credentials(username, password): + return password == "iloveyou" + + create_protocol_check_credentials = basic_auth_protocol_factory( + realm="auth-tests", + check_credentials=check_credentials, + ) + + @with_server(create_protocol=create_protocol_check_credentials) + @with_client(user_info=("hello", "iloveyou")) + def test_basic_auth_check_credentials(self): + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + create_protocol_custom_protocol = basic_auth_protocol_factory( + realm="auth-tests", + credentials=[("hello", "iloveyou")], + create_protocol=CustomWebSocketServerProtocol, + ) + + @with_server(create_protocol=create_protocol_custom_protocol) + @with_client(user_info=("hello", "iloveyou")) + def test_basic_auth_custom_protocol(self): + self.assertTrue(CustomWebSocketServerProtocol.used) + del CustomWebSocketServerProtocol.used + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_missing_credentials(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client() + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_missing_credentials_details(self): + with self.assertRaises(urllib.error.HTTPError) as raised: + self.loop.run_until_complete(self.make_http_request()) + self.assertEqual(raised.exception.code, 401) + self.assertEqual( + raised.exception.headers["WWW-Authenticate"], + 'Basic realm="auth-tests", charset="UTF-8"', + ) + self.assertEqual(raised.exception.read().decode(), "Missing credentials\n") + + @with_server(create_protocol=create_protocol) + def test_basic_auth_unsupported_credentials(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(extra_headers={"Authorization": "Digest ..."}) + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_unsupported_credentials_details(self): + with self.assertRaises(urllib.error.HTTPError) as raised: + self.loop.run_until_complete( + self.make_http_request(headers={"Authorization": "Digest ..."}) + ) + self.assertEqual(raised.exception.code, 401) + self.assertEqual( + raised.exception.headers["WWW-Authenticate"], + 'Basic realm="auth-tests", charset="UTF-8"', + ) + self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n") + + @with_server(create_protocol=create_protocol) + def test_basic_auth_invalid_credentials(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(user_info=("hello", "ihateyou")) + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_invalid_credentials_details(self): + with self.assertRaises(urllib.error.HTTPError) as raised: + authorization = build_authorization_basic("hello", "ihateyou") + self.loop.run_until_complete( + self.make_http_request(headers={"Authorization": authorization}) + ) + self.assertEqual(raised.exception.code, 401) + self.assertEqual( + raised.exception.headers["WWW-Authenticate"], + 'Basic realm="auth-tests", charset="UTF-8"', + ) + self.assertEqual(raised.exception.read().decode(), "Invalid credentials\n") diff --git a/tests/test_asyncio_client_server.py b/tests/legacy/test_client_server.py similarity index 97% rename from tests/test_asyncio_client_server.py rename to tests/legacy/test_client_server.py index 76c29334e..499ea1d59 100644 --- a/tests/test_asyncio_client_server.py +++ b/tests/legacy/test_client_server.py @@ -13,8 +13,6 @@ import urllib.request import warnings -from websockets.asyncio_client import * -from websockets.asyncio_server import * from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -28,19 +26,20 @@ PerMessageDeflate, ServerPerMessageDeflateFactory, ) -from websockets.handshake_legacy import build_response from websockets.http import USER_AGENT -from websockets.http_legacy import read_response -from websockets.protocol import State +from websockets.legacy.client import * +from websockets.legacy.handshake import build_response +from websockets.legacy.http import read_response +from websockets.legacy.protocol import State +from websockets.legacy.server import * from websockets.uri import parse_uri -from .extensions.test_base import ( +from ..extensions.test_base import ( ClientNoOpExtensionFactory, NoOpExtension, ServerNoOpExtensionFactory, ) -from .test_protocol import MS -from .utils import AsyncioTestCase +from .utils import MS, AsyncioTestCase # Generate TLS certificate with: @@ -49,7 +48,7 @@ # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt -testcert = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) +testcert = bytes(pathlib.Path(__file__).parent.with_name("test_localhost.pem")) async def handler(ws, path): @@ -1016,7 +1015,7 @@ def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.asyncio_server.read_request") + @unittest.mock.patch("websockets.legacy.server.read_request") def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") @@ -1024,7 +1023,7 @@ def test_server_receives_malformed_request(self, _read_request): self.start_client() @with_server() - @unittest.mock.patch("websockets.asyncio_client.read_response") + @unittest.mock.patch("websockets.legacy.client.read_response") def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") @@ -1033,7 +1032,7 @@ def test_client_receives_malformed_response(self, _read_response): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.asyncio_client.build_request") + @unittest.mock.patch("websockets.legacy.client.build_request") def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(headers): return "42" @@ -1044,7 +1043,7 @@ def wrong_build_request(headers): self.start_client() @with_server() - @unittest.mock.patch("websockets.asyncio_server.build_response") + @unittest.mock.patch("websockets.legacy.server.build_response") def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(headers, key): return build_response(headers, "42") @@ -1055,7 +1054,7 @@ def wrong_build_response(headers, key): self.start_client() @with_server() - @unittest.mock.patch("websockets.asyncio_client.read_response") + @unittest.mock.patch("websockets.legacy.client.read_response") def test_server_does_not_switch_protocols(self, _read_response): async def wrong_read_response(stream): status_code, reason, headers = await read_response(stream) @@ -1069,7 +1068,7 @@ async def wrong_read_response(stream): @with_server() @unittest.mock.patch( - "websockets.asyncio_server.WebSocketServerProtocol.process_request" + "websockets.legacy.server.WebSocketServerProtocol.process_request" ) def test_server_error_in_handshake(self, _process_request): _process_request.side_effect = Exception("process_request crashed") @@ -1078,7 +1077,7 @@ def test_server_error_in_handshake(self, _process_request): self.start_client() @with_server() - @unittest.mock.patch("websockets.asyncio_server.WebSocketServerProtocol.send") + @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.send") def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") @@ -1091,7 +1090,7 @@ def test_server_handler_crashes(self, send): self.assertEqual(self.client.close_code, 1011) @with_server() - @unittest.mock.patch("websockets.asyncio_server.WebSocketServerProtocol.close") + @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.close") def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") @@ -1164,10 +1163,10 @@ def test_invalid_status_error_during_client_connect(self): @with_server() @unittest.mock.patch( - "websockets.server.WebSocketServerProtocol.write_http_response" + "websockets.legacy.server.WebSocketServerProtocol.write_http_response" ) @unittest.mock.patch( - "websockets.asyncio_server.WebSocketServerProtocol.read_http_request" + "websockets.legacy.server.WebSocketServerProtocol.read_http_request" ) def test_connection_error_during_opening_handshake( self, _read_http_request, _write_http_response @@ -1186,7 +1185,7 @@ def test_connection_error_during_opening_handshake( _write_http_response.assert_not_called() @with_server() - @unittest.mock.patch("websockets.asyncio_server.WebSocketServerProtocol.close") + @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.close") def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py new file mode 100644 index 000000000..ac870c79e --- /dev/null +++ b/tests/legacy/test_framing.py @@ -0,0 +1,171 @@ +import asyncio +import codecs +import unittest +import unittest.mock +import warnings + +from websockets.exceptions import PayloadTooBig, ProtocolError +from websockets.frames import OP_BINARY, OP_CLOSE, OP_PING, OP_PONG, OP_TEXT +from websockets.legacy.framing import * + +from .utils import AsyncioTestCase + + +class FramingTests(AsyncioTestCase): + def decode(self, message, mask=False, max_size=None, extensions=None): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(message) + stream.feed_eof() + with warnings.catch_warnings(record=True): + frame = self.loop.run_until_complete( + Frame.read( + stream.readexactly, + mask=mask, + max_size=max_size, + extensions=extensions, + ) + ) + # Make sure all the data was consumed. + self.assertTrue(stream.at_eof()) + return frame + + def encode(self, frame, mask=False, extensions=None): + write = unittest.mock.Mock() + with warnings.catch_warnings(record=True): + frame.write(write, mask=mask, extensions=extensions) + # Ensure the entire frame is sent with a single call to write(). + # Multiple calls cause TCP fragmentation and degrade performance. + self.assertEqual(write.call_count, 1) + # The frame data is the single positional argument of that call. + self.assertEqual(len(write.call_args[0]), 1) + self.assertEqual(len(write.call_args[1]), 0) + return write.call_args[0][0] + + def round_trip(self, message, expected, mask=False, extensions=None): + decoded = self.decode(message, mask, extensions=extensions) + self.assertEqual(decoded, expected) + encoded = self.encode(decoded, mask, extensions=extensions) + if mask: # non-deterministic encoding + decoded = self.decode(encoded, mask, extensions=extensions) + self.assertEqual(decoded, expected) + else: # deterministic encoding + self.assertEqual(encoded, message) + + def test_text(self): + self.round_trip(b"\x81\x04Spam", Frame(True, OP_TEXT, b"Spam")) + + def test_text_masked(self): + self.round_trip( + b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", + Frame(True, OP_TEXT, b"Spam"), + mask=True, + ) + + def test_binary(self): + self.round_trip(b"\x82\x04Eggs", Frame(True, OP_BINARY, b"Eggs")) + + def test_binary_masked(self): + self.round_trip( + b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", + Frame(True, OP_BINARY, b"Eggs"), + mask=True, + ) + + def test_non_ascii_text(self): + self.round_trip( + b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode("utf-8")) + ) + + def test_non_ascii_text_masked(self): + self.round_trip( + b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", + Frame(True, OP_TEXT, "café".encode("utf-8")), + mask=True, + ) + + def test_close(self): + self.round_trip(b"\x88\x00", Frame(True, OP_CLOSE, b"")) + + def test_ping(self): + self.round_trip(b"\x89\x04ping", Frame(True, OP_PING, b"ping")) + + def test_pong(self): + self.round_trip(b"\x8a\x04pong", Frame(True, OP_PONG, b"pong")) + + def test_long(self): + self.round_trip( + b"\x82\x7e\x00\x7e" + 126 * b"a", Frame(True, OP_BINARY, 126 * b"a") + ) + + def test_very_long(self): + self.round_trip( + b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", + Frame(True, OP_BINARY, 65536 * b"a"), + ) + + def test_payload_too_big(self): + with self.assertRaises(PayloadTooBig): + self.decode(b"\x82\x7e\x04\x01" + 1025 * b"a", max_size=1024) + + def test_bad_reserved_bits(self): + for encoded in [b"\xc0\x00", b"\xa0\x00", b"\x90\x00"]: + with self.subTest(encoded=encoded): + with self.assertRaises(ProtocolError): + self.decode(encoded) + + def test_good_opcode(self): + for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0B)): + encoded = bytes([0x80 | opcode, 0]) + with self.subTest(encoded=encoded): + self.decode(encoded) # does not raise an exception + + def test_bad_opcode(self): + for opcode in list(range(0x03, 0x08)) + list(range(0x0B, 0x10)): + encoded = bytes([0x80 | opcode, 0]) + with self.subTest(encoded=encoded): + with self.assertRaises(ProtocolError): + self.decode(encoded) + + def test_mask_flag(self): + # Mask flag correctly set. + self.decode(b"\x80\x80\x00\x00\x00\x00", mask=True) + # Mask flag incorrectly unset. + with self.assertRaises(ProtocolError): + self.decode(b"\x80\x80\x00\x00\x00\x00") + # Mask flag correctly unset. + self.decode(b"\x80\x00") + # Mask flag incorrectly set. + with self.assertRaises(ProtocolError): + self.decode(b"\x80\x00", mask=True) + + def test_control_frame_max_length(self): + # At maximum allowed length. + self.decode(b"\x88\x7e\x00\x7d" + 125 * b"a") + # Above maximum allowed length. + with self.assertRaises(ProtocolError): + self.decode(b"\x88\x7e\x00\x7e" + 126 * b"a") + + def test_fragmented_control_frame(self): + # Fin bit correctly set. + self.decode(b"\x88\x00") + # Fin bit incorrectly unset. + with self.assertRaises(ProtocolError): + self.decode(b"\x08\x00") + + def test_extensions(self): + class Rot13: + @staticmethod + def encode(frame): + assert frame.opcode == OP_TEXT + text = frame.data.decode() + data = codecs.encode(text, "rot13").encode() + return frame._replace(data=data) + + # This extensions is symmetrical. + @staticmethod + def decode(frame, *, max_size=None): + return Rot13.encode(frame) + + self.round_trip( + b"\x81\x05uryyb", Frame(True, OP_TEXT, b"hello"), extensions=[Rot13()] + ) diff --git a/tests/test_handshake_legacy.py b/tests/legacy/test_handshake.py similarity index 99% rename from tests/test_handshake_legacy.py rename to tests/legacy/test_handshake.py index c34b94e41..661ae64fc 100644 --- a/tests/test_handshake_legacy.py +++ b/tests/legacy/test_handshake.py @@ -8,7 +8,7 @@ InvalidHeaderValue, InvalidUpgrade, ) -from websockets.handshake_legacy import * +from websockets.legacy.handshake import * from websockets.utils import accept_key diff --git a/tests/test_http_legacy.py b/tests/legacy/test_http.py similarity index 98% rename from tests/test_http_legacy.py rename to tests/legacy/test_http.py index e4c75315e..5c9adc97f 100644 --- a/tests/test_http_legacy.py +++ b/tests/legacy/test_http.py @@ -1,8 +1,8 @@ import asyncio from websockets.exceptions import SecurityError -from websockets.http_legacy import * -from websockets.http_legacy import read_headers +from websockets.legacy.http import * +from websockets.legacy.http import read_headers from .utils import AsyncioTestCase diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py new file mode 100644 index 000000000..218d05376 --- /dev/null +++ b/tests/legacy/test_protocol.py @@ -0,0 +1,1489 @@ +import asyncio +import contextlib +import sys +import unittest +import unittest.mock +import warnings + +from websockets.exceptions import ConnectionClosed, InvalidState +from websockets.frames import ( + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + serialize_close, +) +from websockets.legacy.framing import Frame +from websockets.legacy.protocol import State, WebSocketCommonProtocol + +from .utils import MS, AsyncioTestCase + + +async def async_iterable(iterable): + for item in iterable: + yield item + + +class TransportMock(unittest.mock.Mock): + """ + Transport mock to control the protocol's inputs and outputs in tests. + + It calls the protocol's connection_made and connection_lost methods like + actual transports. + + It also calls the protocol's connection_open method to bypass the + WebSocket handshake. + + To simulate incoming data, tests call the protocol's data_received and + eof_received methods directly. + + They could also pause_writing and resume_writing to test flow control. + + """ + + # This should happen in __init__ but overriding Mock.__init__ is hard. + def setup_mock(self, loop, protocol): + self.loop = loop + self.protocol = protocol + self._eof = False + self._closing = False + # Simulate a successful TCP handshake. + self.protocol.connection_made(self) + # Simulate a successful WebSocket handshake. + self.protocol.connection_open() + + def can_write_eof(self): + return True + + def write_eof(self): + # When the protocol half-closes the TCP connection, it expects the + # other end to close it. Simulate that. + if not self._eof: + self.loop.call_soon(self.close) + self._eof = True + + def close(self): + # Simulate how actual transports drop the connection. + if not self._closing: + self.loop.call_soon(self.protocol.connection_lost, None) + self._closing = True + + def abort(self): + # Change this to an `if` if tests call abort() multiple times. + assert self.protocol.state is not State.CLOSED + self.loop.call_soon(self.protocol.connection_lost, None) + + +class CommonTests: + """ + Mixin that defines most tests but doesn't inherit unittest.TestCase. + + Tests are run by the ServerTests and ClientTests subclasses. + + """ + + def setUp(self): + super().setUp() + # Disable pings to make it easier to test what frames are sent exactly. + self.protocol = WebSocketCommonProtocol(ping_interval=None) + self.transport = TransportMock() + self.transport.setup_mock(self.loop, self.protocol) + + def tearDown(self): + self.transport.close() + self.loop.run_until_complete(self.protocol.close()) + super().tearDown() + + # Utilities for writing tests. + + def make_drain_slow(self, delay=MS): + # Process connection_made in order to initialize self.protocol.transport. + self.run_loop_once() + + original_drain = self.protocol._drain + + async def delayed_drain(): + await asyncio.sleep( + delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) + await original_drain() + + self.protocol._drain = delayed_drain + + close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) + local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) + remote_close = Frame(True, OP_CLOSE, serialize_close(1000, "remote")) + + def receive_frame(self, frame): + """ + Make the protocol receive a frame. + + """ + write = self.protocol.data_received + mask = not self.protocol.is_client + frame.write(write, mask=mask) + + def receive_eof(self): + """ + Make the protocol receive the end of the data stream. + + Since ``WebSocketCommonProtocol.eof_received`` returns ``None``, an + actual transport would close itself after calling it. This function + emulates that behavior. + + """ + self.protocol.eof_received() + self.loop.call_soon(self.transport.close) + + def receive_eof_if_client(self): + """ + Like receive_eof, but only if this is the client side. + + Since the server is supposed to initiate the termination of the TCP + connection, this method helps making tests work for both sides. + + """ + if self.protocol.is_client: + self.receive_eof() + + def close_connection(self, code=1000, reason="close"): + """ + Execute a closing handshake. + + This puts the connection in the CLOSED state. + + """ + close_frame_data = serialize_close(code, reason) + # Prepare the response to the closing handshake from the remote side. + self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) + self.receive_eof_if_client() + # Trigger the closing handshake from the local side and complete it. + self.loop.run_until_complete(self.protocol.close(code, reason)) + # Empty the outgoing data stream so we can make assertions later on. + self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) + + assert self.protocol.state is State.CLOSED + + def half_close_connection_local(self, code=1000, reason="close"): + """ + Start a closing handshake but do not complete it. + + The main difference with `close_connection` is that the connection is + left in the CLOSING state until the event loop runs again. + + The current implementation returns a task that must be awaited or + canceled, else asyncio complains about destroying a pending task. + + """ + close_frame_data = serialize_close(code, reason) + # Trigger the closing handshake from the local endpoint. + close_task = self.loop.create_task(self.protocol.close(code, reason)) + self.run_loop_once() # wait_for executes + self.run_loop_once() # write_frame executes + # Empty the outgoing data stream so we can make assertions later on. + self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) + + assert self.protocol.state is State.CLOSING + + # Complete the closing sequence at 1ms intervals so the test can run + # at each point even it goes back to the event loop several times. + self.loop.call_later( + MS, self.receive_frame, Frame(True, OP_CLOSE, close_frame_data) + ) + self.loop.call_later(2 * MS, self.receive_eof_if_client) + + # This task must be awaited or canceled by the caller. + return close_task + + def half_close_connection_remote(self, code=1000, reason="close"): + """ + Receive a closing handshake but do not complete it. + + The main difference with `close_connection` is that the connection is + left in the CLOSING state until the event loop runs again. + + """ + # On the server side, websockets completes the closing handshake and + # closes the TCP connection immediately. Yield to the event loop after + # sending the close frame to run the test while the connection is in + # the CLOSING state. + if not self.protocol.is_client: + self.make_drain_slow() + + close_frame_data = serialize_close(code, reason) + # Trigger the closing handshake from the remote endpoint. + self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) + self.run_loop_once() # read_frame executes + # Empty the outgoing data stream so we can make assertions later on. + self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) + + assert self.protocol.state is State.CLOSING + + # Complete the closing sequence at 1ms intervals so the test can run + # at each point even it goes back to the event loop several times. + self.loop.call_later(2 * MS, self.receive_eof_if_client) + + def process_invalid_frames(self): + """ + Make the protocol fail quickly after simulating invalid data. + + To achieve this, this function triggers the protocol's eof_received, + which interrupts pending reads waiting for more data. + + """ + self.run_loop_once() + self.receive_eof() + self.loop.run_until_complete(self.protocol.close_connection_task) + + def sent_frames(self): + """ + Read all frames sent to the transport. + + """ + stream = asyncio.StreamReader(loop=self.loop) + + for (data,), kw in self.transport.write.call_args_list: + stream.feed_data(data) + self.transport.write.call_args_list = [] + stream.feed_eof() + + frames = [] + while not stream.at_eof(): + frames.append( + self.loop.run_until_complete( + Frame.read(stream.readexactly, mask=self.protocol.is_client) + ) + ) + return frames + + def last_sent_frame(self): + """ + Read the last frame sent to the transport. + + This method assumes that at most one frame was sent. It raises an + AssertionError otherwise. + + """ + frames = self.sent_frames() + if frames: + assert len(frames) == 1 + return frames[0] + + def assertFramesSent(self, *frames): + self.assertEqual(self.sent_frames(), [Frame(*args) for args in frames]) + + def assertOneFrameSent(self, *args): + self.assertEqual(self.last_sent_frame(), Frame(*args)) + + def assertNoFrameSent(self): + self.assertIsNone(self.last_sent_frame()) + + def assertConnectionClosed(self, code, message): + # The following line guarantees that connection_lost was called. + self.assertEqual(self.protocol.state, State.CLOSED) + # A close frame was received. + self.assertEqual(self.protocol.close_code, code) + self.assertEqual(self.protocol.close_reason, message) + + def assertConnectionFailed(self, code, message): + # The following line guarantees that connection_lost was called. + self.assertEqual(self.protocol.state, State.CLOSED) + # No close frame was received. + self.assertEqual(self.protocol.close_code, 1006) + self.assertEqual(self.protocol.close_reason, "") + # A close frame was sent -- unless the connection was already lost. + if code == 1006: + self.assertNoFrameSent() + else: + self.assertOneFrameSent(True, OP_CLOSE, serialize_close(code, message)) + + @contextlib.contextmanager + def assertCompletesWithin(self, min_time, max_time): + t0 = self.loop.time() + yield + t1 = self.loop.time() + dt = t1 - t0 + self.assertGreaterEqual(dt, min_time, f"Too fast: {dt} < {min_time}") + self.assertLess(dt, max_time, f"Too slow: {dt} >= {max_time}") + + # Test constructor. + + def test_timeout_backwards_compatibility(self): + with warnings.catch_warnings(record=True) as recorded_warnings: + protocol = WebSocketCommonProtocol(timeout=5) + + self.assertEqual(protocol.close_timeout, 5) + + self.assertEqual(len(recorded_warnings), 1) + warning = recorded_warnings[0].message + self.assertEqual(str(warning), "rename timeout to close_timeout") + self.assertEqual(type(warning), DeprecationWarning) + + # Test public attributes. + + def test_local_address(self): + get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) + self.transport.get_extra_info = get_extra_info + + self.assertEqual(self.protocol.local_address, ("host", 4312)) + get_extra_info.assert_called_with("sockname") + + def test_local_address_before_connection(self): + # Emulate the situation before connection_open() runs. + _transport = self.protocol.transport + del self.protocol.transport + try: + self.assertEqual(self.protocol.local_address, None) + finally: + self.protocol.transport = _transport + + def test_remote_address(self): + get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) + self.transport.get_extra_info = get_extra_info + + self.assertEqual(self.protocol.remote_address, ("host", 4312)) + get_extra_info.assert_called_with("peername") + + def test_remote_address_before_connection(self): + # Emulate the situation before connection_open() runs. + _transport = self.protocol.transport + del self.protocol.transport + try: + self.assertEqual(self.protocol.remote_address, None) + finally: + self.protocol.transport = _transport + + def test_open(self): + self.assertTrue(self.protocol.open) + self.close_connection() + self.assertFalse(self.protocol.open) + + def test_closed(self): + self.assertFalse(self.protocol.closed) + self.close_connection() + self.assertTrue(self.protocol.closed) + + def test_wait_closed(self): + wait_closed = self.loop.create_task(self.protocol.wait_closed()) + self.assertFalse(wait_closed.done()) + self.close_connection() + self.assertTrue(wait_closed.done()) + + # Test the recv coroutine. + + def test_recv_text(self): + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, "café") + + def test_recv_binary(self): + self.receive_frame(Frame(True, OP_BINARY, b"tea")) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, b"tea") + + def test_recv_on_closing_connection_local(self): + close_task = self.half_close_connection_local() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + + self.loop.run_until_complete(close_task) # cleanup + + def test_recv_on_closing_connection_remote(self): + self.half_close_connection_remote() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + + def test_recv_on_closed_connection(self): + self.close_connection() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + + def test_recv_protocol_error(self): + self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8"))) + self.process_invalid_frames() + self.assertConnectionFailed(1002, "") + + def test_recv_unicode_error(self): + self.receive_frame(Frame(True, OP_TEXT, "café".encode("latin-1"))) + self.process_invalid_frames() + self.assertConnectionFailed(1007, "") + + def test_recv_text_payload_too_big(self): + self.protocol.max_size = 1024 + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) + self.process_invalid_frames() + self.assertConnectionFailed(1009, "") + + def test_recv_binary_payload_too_big(self): + self.protocol.max_size = 1024 + self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342)) + self.process_invalid_frames() + self.assertConnectionFailed(1009, "") + + def test_recv_text_no_max_size(self): + self.protocol.max_size = None # for test coverage + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, "café" * 205) + + def test_recv_binary_no_max_size(self): + self.protocol.max_size = None # for test coverage + self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342)) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, b"tea" * 342) + + def test_recv_queue_empty(self): + recv = self.loop.create_task(self.protocol.recv()) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete( + asyncio.wait_for(asyncio.shield(recv), timeout=MS) + ) + + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + data = self.loop.run_until_complete(recv) + self.assertEqual(data, "café") + + def test_recv_queue_full(self): + self.protocol.max_queue = 2 + # Test internals because it's hard to verify buffers from the outside. + self.assertEqual(list(self.protocol.messages), []) + + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), ["café"]) + + self.receive_frame(Frame(True, OP_BINARY, b"tea")) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), ["café", b"tea"]) + + self.receive_frame(Frame(True, OP_BINARY, b"milk")) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), ["café", b"tea"]) + + self.loop.run_until_complete(self.protocol.recv()) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), [b"tea", b"milk"]) + + self.loop.run_until_complete(self.protocol.recv()) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), [b"milk"]) + + self.loop.run_until_complete(self.protocol.recv()) + self.run_loop_once() + self.assertEqual(list(self.protocol.messages), []) + + def test_recv_queue_no_limit(self): + self.protocol.max_queue = None + + for _ in range(100): + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.run_loop_once() + + # Incoming message queue can contain at least 100 messages. + self.assertEqual(list(self.protocol.messages), ["café"] * 100) + + for _ in range(100): + self.loop.run_until_complete(self.protocol.recv()) + + self.assertEqual(list(self.protocol.messages), []) + + def test_recv_other_error(self): + async def read_message(): + raise Exception("BOOM") + + self.protocol.read_message = read_message + self.process_invalid_frames() + self.assertConnectionFailed(1011, "") + + def test_recv_canceled(self): + recv = self.loop.create_task(self.protocol.recv()) + self.loop.call_soon(recv.cancel) + + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(recv) + + # The next frame doesn't disappear in a vacuum (it used to). + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, "café") + + def test_recv_canceled_race_condition(self): + recv = self.loop.create_task( + asyncio.wait_for(self.protocol.recv(), timeout=0.000_001) + ) + self.loop.call_soon( + self.receive_frame, Frame(True, OP_TEXT, "café".encode("utf-8")) + ) + + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(recv) + + # The previous frame doesn't disappear in a vacuum (it used to). + self.receive_frame(Frame(True, OP_TEXT, "tea".encode("utf-8"))) + data = self.loop.run_until_complete(self.protocol.recv()) + # If we're getting "tea" there, it means "café" was swallowed (ha, ha). + self.assertEqual(data, "café") + + def test_recv_when_transfer_data_cancelled(self): + # Clog incoming queue. + self.protocol.max_queue = 1 + self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_BINARY, b"tea")) + self.run_loop_once() + + # Flow control kicks in (check with an implementation detail). + self.assertFalse(self.protocol._put_message_waiter.done()) + + # Schedule recv(). + recv = self.loop.create_task(self.protocol.recv()) + + # Cancel transfer_data_task (again, implementation detail). + self.protocol.fail_connection() + self.run_loop_once() + self.assertTrue(self.protocol.transfer_data_task.cancelled()) + + # recv() completes properly. + self.assertEqual(self.loop.run_until_complete(recv), "café") + + def test_recv_prevents_concurrent_calls(self): + recv = self.loop.create_task(self.protocol.recv()) + + with self.assertRaises(RuntimeError) as raised: + self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already waiting for the next message", + ) + recv.cancel() + + # Test the send coroutine. + + def test_send_text(self): + self.loop.run_until_complete(self.protocol.send("café")) + self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + + def test_send_binary(self): + self.loop.run_until_complete(self.protocol.send(b"tea")) + self.assertOneFrameSent(True, OP_BINARY, b"tea") + + def test_send_binary_from_bytearray(self): + self.loop.run_until_complete(self.protocol.send(bytearray(b"tea"))) + self.assertOneFrameSent(True, OP_BINARY, b"tea") + + def test_send_binary_from_memoryview(self): + self.loop.run_until_complete(self.protocol.send(memoryview(b"tea"))) + self.assertOneFrameSent(True, OP_BINARY, b"tea") + + def test_send_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete(self.protocol.send(memoryview(b"tteeaa")[::2])) + self.assertOneFrameSent(True, OP_BINARY, b"tea") + + def test_send_dict(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send({"not": "encoded"})) + self.assertNoFrameSent() + + def test_send_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send(42)) + self.assertNoFrameSent() + + def test_send_iterable_text(self): + self.loop.run_until_complete(self.protocol.send(["ca", "fé"])) + self.assertFramesSent( + (False, OP_TEXT, "ca".encode("utf-8")), + (False, OP_CONT, "fé".encode("utf-8")), + (True, OP_CONT, "".encode("utf-8")), + ) + + def test_send_iterable_binary(self): + self.loop.run_until_complete(self.protocol.send([b"te", b"a"])) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_iterable_binary_from_bytearray(self): + self.loop.run_until_complete( + self.protocol.send([bytearray(b"te"), bytearray(b"a")]) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_iterable_binary_from_memoryview(self): + self.loop.run_until_complete( + self.protocol.send([memoryview(b"te"), memoryview(b"a")]) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_iterable_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete( + self.protocol.send([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_empty_iterable(self): + self.loop.run_until_complete(self.protocol.send([])) + self.assertNoFrameSent() + + def test_send_iterable_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send([42])) + self.assertNoFrameSent() + + def test_send_iterable_mixed_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) + self.assertFramesSent( + (False, OP_TEXT, "café".encode("utf-8")), + (True, OP_CLOSE, serialize_close(1011, "")), + ) + + def test_send_iterable_prevents_concurrent_send(self): + self.make_drain_slow(2 * MS) + + async def send_iterable(): + await self.protocol.send(["ca", "fé"]) + + async def send_concurrent(): + await asyncio.sleep(MS) + await self.protocol.send(b"tea") + + self.loop.run_until_complete(asyncio.gather(send_iterable(), send_concurrent())) + self.assertFramesSent( + (False, OP_TEXT, "ca".encode("utf-8")), + (False, OP_CONT, "fé".encode("utf-8")), + (True, OP_CONT, "".encode("utf-8")), + (True, OP_BINARY, b"tea"), + ) + + def test_send_async_iterable_text(self): + self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"]))) + self.assertFramesSent( + (False, OP_TEXT, "ca".encode("utf-8")), + (False, OP_CONT, "fé".encode("utf-8")), + (True, OP_CONT, "".encode("utf-8")), + ) + + def test_send_async_iterable_binary(self): + self.loop.run_until_complete(self.protocol.send(async_iterable([b"te", b"a"]))) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_async_iterable_binary_from_bytearray(self): + self.loop.run_until_complete( + self.protocol.send(async_iterable([bytearray(b"te"), bytearray(b"a")])) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_async_iterable_binary_from_memoryview(self): + self.loop.run_until_complete( + self.protocol.send(async_iterable([memoryview(b"te"), memoryview(b"a")])) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_async_iterable_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete( + self.protocol.send( + async_iterable([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) + ) + ) + self.assertFramesSent( + (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") + ) + + def test_send_empty_async_iterable(self): + self.loop.run_until_complete(self.protocol.send(async_iterable([]))) + self.assertNoFrameSent() + + def test_send_async_iterable_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.send(async_iterable([42]))) + self.assertNoFrameSent() + + def test_send_async_iterable_mixed_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete( + self.protocol.send(async_iterable(["café", b"tea"])) + ) + self.assertFramesSent( + (False, OP_TEXT, "café".encode("utf-8")), + (True, OP_CLOSE, serialize_close(1011, "")), + ) + + def test_send_async_iterable_prevents_concurrent_send(self): + self.make_drain_slow(2 * MS) + + async def send_async_iterable(): + await self.protocol.send(async_iterable(["ca", "fé"])) + + async def send_concurrent(): + await asyncio.sleep(MS) + await self.protocol.send(b"tea") + + self.loop.run_until_complete( + asyncio.gather(send_async_iterable(), send_concurrent()) + ) + self.assertFramesSent( + (False, OP_TEXT, "ca".encode("utf-8")), + (False, OP_CONT, "fé".encode("utf-8")), + (True, OP_CONT, "".encode("utf-8")), + (True, OP_BINARY, b"tea"), + ) + + def test_send_on_closing_connection_local(self): + close_task = self.half_close_connection_local() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.send("foobar")) + + self.assertNoFrameSent() + + self.loop.run_until_complete(close_task) # cleanup + + def test_send_on_closing_connection_remote(self): + self.half_close_connection_remote() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.send("foobar")) + + self.assertNoFrameSent() + + def test_send_on_closed_connection(self): + self.close_connection() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.send("foobar")) + + self.assertNoFrameSent() + + # Test the ping coroutine. + + def test_ping_default(self): + self.loop.run_until_complete(self.protocol.ping()) + # With our testing tools, it's more convenient to extract the expected + # ping data from the library's internals than from the frame sent. + ping_data = next(iter(self.protocol.pings)) + self.assertIsInstance(ping_data, bytes) + self.assertEqual(len(ping_data), 4) + self.assertOneFrameSent(True, OP_PING, ping_data) + + def test_ping_text(self): + self.loop.run_until_complete(self.protocol.ping("café")) + self.assertOneFrameSent(True, OP_PING, "café".encode("utf-8")) + + def test_ping_binary(self): + self.loop.run_until_complete(self.protocol.ping(b"tea")) + self.assertOneFrameSent(True, OP_PING, b"tea") + + def test_ping_binary_from_bytearray(self): + self.loop.run_until_complete(self.protocol.ping(bytearray(b"tea"))) + self.assertOneFrameSent(True, OP_PING, b"tea") + + def test_ping_binary_from_memoryview(self): + self.loop.run_until_complete(self.protocol.ping(memoryview(b"tea"))) + self.assertOneFrameSent(True, OP_PING, b"tea") + + def test_ping_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete(self.protocol.ping(memoryview(b"tteeaa")[::2])) + self.assertOneFrameSent(True, OP_PING, b"tea") + + def test_ping_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.ping(42)) + self.assertNoFrameSent() + + def test_ping_on_closing_connection_local(self): + close_task = self.half_close_connection_local() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.ping()) + + self.assertNoFrameSent() + + self.loop.run_until_complete(close_task) # cleanup + + def test_ping_on_closing_connection_remote(self): + self.half_close_connection_remote() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.ping()) + + self.assertNoFrameSent() + + def test_ping_on_closed_connection(self): + self.close_connection() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.ping()) + + self.assertNoFrameSent() + + # Test the pong coroutine. + + def test_pong_default(self): + self.loop.run_until_complete(self.protocol.pong()) + self.assertOneFrameSent(True, OP_PONG, b"") + + def test_pong_text(self): + self.loop.run_until_complete(self.protocol.pong("café")) + self.assertOneFrameSent(True, OP_PONG, "café".encode("utf-8")) + + def test_pong_binary(self): + self.loop.run_until_complete(self.protocol.pong(b"tea")) + self.assertOneFrameSent(True, OP_PONG, b"tea") + + def test_pong_binary_from_bytearray(self): + self.loop.run_until_complete(self.protocol.pong(bytearray(b"tea"))) + self.assertOneFrameSent(True, OP_PONG, b"tea") + + def test_pong_binary_from_memoryview(self): + self.loop.run_until_complete(self.protocol.pong(memoryview(b"tea"))) + self.assertOneFrameSent(True, OP_PONG, b"tea") + + def test_pong_binary_from_non_contiguous_memoryview(self): + self.loop.run_until_complete(self.protocol.pong(memoryview(b"tteeaa")[::2])) + self.assertOneFrameSent(True, OP_PONG, b"tea") + + def test_pong_type_error(self): + with self.assertRaises(TypeError): + self.loop.run_until_complete(self.protocol.pong(42)) + self.assertNoFrameSent() + + def test_pong_on_closing_connection_local(self): + close_task = self.half_close_connection_local() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.pong()) + + self.assertNoFrameSent() + + self.loop.run_until_complete(close_task) # cleanup + + def test_pong_on_closing_connection_remote(self): + self.half_close_connection_remote() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.pong()) + + self.assertNoFrameSent() + + def test_pong_on_closed_connection(self): + self.close_connection() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.pong()) + + self.assertNoFrameSent() + + # Test the protocol's logic for acknowledging pings with pongs. + + def test_answer_ping(self): + self.receive_frame(Frame(True, OP_PING, b"test")) + self.run_loop_once() + self.assertOneFrameSent(True, OP_PONG, b"test") + + def test_ignore_pong(self): + self.receive_frame(Frame(True, OP_PONG, b"test")) + self.run_loop_once() + self.assertNoFrameSent() + + def test_acknowledge_ping(self): + ping = self.loop.run_until_complete(self.protocol.ping()) + self.assertFalse(ping.done()) + ping_frame = self.last_sent_frame() + pong_frame = Frame(True, OP_PONG, ping_frame.data) + self.receive_frame(pong_frame) + self.run_loop_once() + self.run_loop_once() + self.assertTrue(ping.done()) + + def test_abort_ping(self): + ping = self.loop.run_until_complete(self.protocol.ping()) + # Remove the frame from the buffer, else close_connection() complains. + self.last_sent_frame() + self.assertFalse(ping.done()) + self.close_connection() + self.assertTrue(ping.done()) + self.assertIsInstance(ping.exception(), ConnectionClosed) + + def test_abort_ping_does_not_log_exception_if_not_retreived(self): + self.loop.run_until_complete(self.protocol.ping()) + # Get the internal Future, which isn't directly returned by ping(). + (ping,) = self.protocol.pings.values() + # Remove the frame from the buffer, else close_connection() complains. + self.last_sent_frame() + self.close_connection() + # Check a private attribute, for lack of a better solution. + self.assertFalse(ping._log_traceback) + + def test_acknowledge_previous_pings(self): + pings = [ + (self.loop.run_until_complete(self.protocol.ping()), self.last_sent_frame()) + for i in range(3) + ] + # Unsolicited pong doesn't acknowledge pings + self.receive_frame(Frame(True, OP_PONG, b"")) + self.run_loop_once() + self.run_loop_once() + self.assertFalse(pings[0][0].done()) + self.assertFalse(pings[1][0].done()) + self.assertFalse(pings[2][0].done()) + # Pong acknowledges all previous pings + self.receive_frame(Frame(True, OP_PONG, pings[1][1].data)) + self.run_loop_once() + self.run_loop_once() + self.assertTrue(pings[0][0].done()) + self.assertTrue(pings[1][0].done()) + self.assertFalse(pings[2][0].done()) + + def test_acknowledge_aborted_ping(self): + ping = self.loop.run_until_complete(self.protocol.ping()) + ping_frame = self.last_sent_frame() + # Clog incoming queue. This lets connection_lost() abort pending pings + # with a ConnectionClosed exception before transfer_data_task + # terminates and close_connection cancels keepalive_ping_task. + self.protocol.max_queue = 1 + self.receive_frame(Frame(True, OP_TEXT, b"1")) + self.receive_frame(Frame(True, OP_TEXT, b"2")) + # Add pong frame to the queue. + pong_frame = Frame(True, OP_PONG, ping_frame.data) + self.receive_frame(pong_frame) + # Connection drops. + self.receive_eof() + self.loop.run_until_complete(self.protocol.wait_closed()) + # Ping receives a ConnectionClosed exception. + with self.assertRaises(ConnectionClosed): + ping.result() + + # transfer_data doesn't crash, which would be logged. + with self.assertNoLogs(): + # Unclog incoming queue. + self.loop.run_until_complete(self.protocol.recv()) + self.loop.run_until_complete(self.protocol.recv()) + + def test_canceled_ping(self): + ping = self.loop.run_until_complete(self.protocol.ping()) + ping_frame = self.last_sent_frame() + ping.cancel() + pong_frame = Frame(True, OP_PONG, ping_frame.data) + self.receive_frame(pong_frame) + self.run_loop_once() + self.run_loop_once() + self.assertTrue(ping.cancelled()) + + def test_duplicate_ping(self): + self.loop.run_until_complete(self.protocol.ping(b"foobar")) + self.assertOneFrameSent(True, OP_PING, b"foobar") + with self.assertRaises(ValueError): + self.loop.run_until_complete(self.protocol.ping(b"foobar")) + self.assertNoFrameSent() + + # Test the protocol's logic for rebuilding fragmented messages. + + def test_fragmented_text(self): + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, "café") + + def test_fragmented_binary(self): + self.receive_frame(Frame(False, OP_BINARY, b"t")) + self.receive_frame(Frame(False, OP_CONT, b"e")) + self.receive_frame(Frame(True, OP_CONT, b"a")) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, b"tea") + + def test_fragmented_text_payload_too_big(self): + self.protocol.max_size = 1024 + self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) + self.process_invalid_frames() + self.assertConnectionFailed(1009, "") + + def test_fragmented_binary_payload_too_big(self): + self.protocol.max_size = 1024 + self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171)) + self.receive_frame(Frame(True, OP_CONT, b"tea" * 171)) + self.process_invalid_frames() + self.assertConnectionFailed(1009, "") + + def test_fragmented_text_no_max_size(self): + self.protocol.max_size = None # for test coverage + self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, "café" * 205) + + def test_fragmented_binary_no_max_size(self): + self.protocol.max_size = None # for test coverage + self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171)) + self.receive_frame(Frame(True, OP_CONT, b"tea" * 171)) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, b"tea" * 342) + + def test_control_frame_within_fragmented_text(self): + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(True, OP_PING, b"")) + self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) + data = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(data, "café") + self.assertOneFrameSent(True, OP_PONG, b"") + + def test_unterminated_fragmented_text(self): + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + # Missing the second part of the fragmented frame. + self.receive_frame(Frame(True, OP_BINARY, b"tea")) + self.process_invalid_frames() + self.assertConnectionFailed(1002, "") + + def test_close_handshake_in_fragmented_text(self): + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CLOSE, b"")) + self.process_invalid_frames() + # The RFC may have overlooked this case: it says that control frames + # can be interjected in the middle of a fragmented message and that a + # close frame must be echoed. Even though there's an unterminated + # message, technically, the closing handshake was successful. + self.assertConnectionClosed(1005, "") + + def test_connection_close_in_fragmented_text(self): + self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.process_invalid_frames() + self.assertConnectionFailed(1006, "") + + # Test miscellaneous code paths to ensure full coverage. + + def test_connection_lost(self): + # Test calling connection_lost without going through close_connection. + self.protocol.connection_lost(None) + + self.assertConnectionFailed(1006, "") + + def test_ensure_open_before_opening_handshake(self): + # Simulate a bug by forcibly reverting the protocol state. + self.protocol.state = State.CONNECTING + + with self.assertRaises(InvalidState): + self.loop.run_until_complete(self.protocol.ensure_open()) + + def test_ensure_open_during_unclean_close(self): + # Process connection_made in order to start transfer_data_task. + self.run_loop_once() + + # Ensure the test terminates quickly. + self.loop.call_later(MS, self.receive_eof_if_client) + + # Simulate the case when close() times out sending a close frame. + self.protocol.fail_connection() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.ensure_open()) + + def test_legacy_recv(self): + # By default legacy_recv in disabled. + self.assertEqual(self.protocol.legacy_recv, False) + + self.close_connection() + + # Enable legacy_recv. + self.protocol.legacy_recv = True + + # Now recv() returns None instead of raising ConnectionClosed. + self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) + + def test_connection_closed_attributes(self): + self.close_connection() + + with self.assertRaises(ConnectionClosed) as context: + self.loop.run_until_complete(self.protocol.recv()) + + connection_closed_exc = context.exception + self.assertEqual(connection_closed_exc.code, 1000) + self.assertEqual(connection_closed_exc.reason, "close") + + # Test the protocol logic for sending keepalive pings. + + def restart_protocol_with_keepalive_ping( + self, ping_interval=3 * MS, ping_timeout=3 * MS + ): + initial_protocol = self.protocol + # copied from tearDown + self.transport.close() + self.loop.run_until_complete(self.protocol.close()) + # copied from setUp, but enables keepalive pings + self.protocol = WebSocketCommonProtocol( + ping_interval=ping_interval, ping_timeout=ping_timeout + ) + self.transport = TransportMock() + self.transport.setup_mock(self.loop, self.protocol) + self.protocol.is_client = initial_protocol.is_client + self.protocol.side = initial_protocol.side + + def test_keepalive_ping(self): + self.restart_protocol_with_keepalive_ping() + + # Ping is sent at 3ms and acknowledged at 4ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + (ping_1,) = tuple(self.protocol.pings) + self.assertOneFrameSent(True, OP_PING, ping_1) + self.receive_frame(Frame(True, OP_PONG, ping_1)) + + # Next ping is sent at 7ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + (ping_2,) = tuple(self.protocol.pings) + self.assertOneFrameSent(True, OP_PING, ping_2) + + # The keepalive ping task goes on. + self.assertFalse(self.protocol.keepalive_ping_task.done()) + + def test_keepalive_ping_not_acknowledged_closes_connection(self): + self.restart_protocol_with_keepalive_ping() + + # Ping is sent at 3ms and not acknowleged. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + (ping_1,) = tuple(self.protocol.pings) + self.assertOneFrameSent(True, OP_PING, ping_1) + + # Connection is closed at 6ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1011, "")) + + # The keepalive ping task is complete. + self.assertEqual(self.protocol.keepalive_ping_task.result(), None) + + def test_keepalive_ping_stops_when_connection_closing(self): + self.restart_protocol_with_keepalive_ping() + close_task = self.half_close_connection_local() + + # No ping sent at 3ms because the closing handshake is in progress. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + self.assertNoFrameSent() + + # The keepalive ping task terminated. + self.assertTrue(self.protocol.keepalive_ping_task.cancelled()) + + self.loop.run_until_complete(close_task) # cleanup + + def test_keepalive_ping_stops_when_connection_closed(self): + self.restart_protocol_with_keepalive_ping() + self.close_connection() + + # The keepalive ping task terminated. + self.assertTrue(self.protocol.keepalive_ping_task.cancelled()) + + def test_keepalive_ping_does_not_crash_when_connection_lost(self): + self.restart_protocol_with_keepalive_ping() + # Clog incoming queue. This lets connection_lost() abort pending pings + # with a ConnectionClosed exception before transfer_data_task + # terminates and close_connection cancels keepalive_ping_task. + self.protocol.max_queue = 1 + self.receive_frame(Frame(True, OP_TEXT, b"1")) + self.receive_frame(Frame(True, OP_TEXT, b"2")) + # Ping is sent at 3ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + (ping_waiter,) = tuple(self.protocol.pings.values()) + # Connection drops. + self.receive_eof() + self.loop.run_until_complete(self.protocol.wait_closed()) + + # The ping waiter receives a ConnectionClosed exception. + with self.assertRaises(ConnectionClosed): + ping_waiter.result() + # The keepalive ping task terminated properly. + self.assertIsNone(self.protocol.keepalive_ping_task.result()) + + # Unclog incoming queue to terminate the test quickly. + self.loop.run_until_complete(self.protocol.recv()) + self.loop.run_until_complete(self.protocol.recv()) + + def test_keepalive_ping_with_no_ping_interval(self): + self.restart_protocol_with_keepalive_ping(ping_interval=None) + + # No ping is sent at 3ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + self.assertNoFrameSent() + + def test_keepalive_ping_with_no_ping_timeout(self): + self.restart_protocol_with_keepalive_ping(ping_timeout=None) + + # Ping is sent at 3ms and not acknowleged. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + (ping_1,) = tuple(self.protocol.pings) + self.assertOneFrameSent(True, OP_PING, ping_1) + + # Next ping is sent at 7ms anyway. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + ping_1_again, ping_2 = tuple(self.protocol.pings) + self.assertEqual(ping_1, ping_1_again) + self.assertOneFrameSent(True, OP_PING, ping_2) + + # The keepalive ping task goes on. + self.assertFalse(self.protocol.keepalive_ping_task.done()) + + def test_keepalive_ping_unexpected_error(self): + self.restart_protocol_with_keepalive_ping() + + async def ping(): + raise Exception("BOOM") + + self.protocol.ping = ping + + # The keepalive ping task fails when sending a ping at 3ms. + self.loop.run_until_complete(asyncio.sleep(4 * MS)) + + # The keepalive ping task is complete. + # It logs and swallows the exception. + self.assertEqual(self.protocol.keepalive_ping_task.result(), None) + + # Test the protocol logic for closing the connection. + + def test_local_close(self): + # Emulate how the remote endpoint answers the closing handshake. + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(MS, self.receive_eof_if_client) + + # Run the closing handshake. + self.loop.run_until_complete(self.protocol.close(reason="close")) + + self.assertConnectionClosed(1000, "close") + self.assertOneFrameSent(*self.close_frame) + + # Closing the connection again is a no-op. + self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) + + self.assertConnectionClosed(1000, "close") + self.assertNoFrameSent() + + def test_remote_close(self): + # Emulate how the remote endpoint initiates the closing handshake. + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(MS, self.receive_eof_if_client) + + # Wait for some data in order to process the handshake. + # After recv() raises ConnectionClosed, the connection is closed. + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.protocol.recv()) + + self.assertConnectionClosed(1000, "close") + self.assertOneFrameSent(*self.close_frame) + + # Closing the connection again is a no-op. + self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) + + self.assertConnectionClosed(1000, "close") + self.assertNoFrameSent() + + def test_remote_close_and_connection_lost(self): + self.make_drain_slow() + # Drop the connection right after receiving a close frame, + # which prevents echoing the close frame properly. + self.receive_frame(self.close_frame) + self.receive_eof() + + with self.assertNoLogs(): + self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) + + self.assertConnectionClosed(1000, "close") + self.assertOneFrameSent(*self.close_frame) + + def test_simultaneous_close(self): + # Receive the incoming close frame right after self.protocol.close() + # starts executing. This reproduces the error described in: + # https://github.com/aaugustin/websockets/issues/339 + self.loop.call_soon(self.receive_frame, self.remote_close) + self.loop.call_soon(self.receive_eof_if_client) + + self.loop.run_until_complete(self.protocol.close(reason="local")) + + self.assertConnectionClosed(1000, "remote") + # The current implementation sends a close frame in response to the + # close frame received from the remote end. It skips the close frame + # that should be sent as a result of calling close(). + self.assertOneFrameSent(*self.remote_close) + + def test_close_preserves_incoming_frames(self): + self.receive_frame(Frame(True, OP_TEXT, b"hello")) + + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(MS, self.receive_eof_if_client) + self.loop.run_until_complete(self.protocol.close(reason="close")) + + self.assertConnectionClosed(1000, "close") + self.assertOneFrameSent(*self.close_frame) + + next_message = self.loop.run_until_complete(self.protocol.recv()) + self.assertEqual(next_message, "hello") + + def test_close_protocol_error(self): + invalid_close_frame = Frame(True, OP_CLOSE, b"\x00") + self.receive_frame(invalid_close_frame) + self.receive_eof_if_client() + self.run_loop_once() + self.loop.run_until_complete(self.protocol.close(reason="close")) + + self.assertConnectionFailed(1002, "") + + def test_close_connection_lost(self): + self.receive_eof() + self.run_loop_once() + self.loop.run_until_complete(self.protocol.close(reason="close")) + + self.assertConnectionFailed(1006, "") + + def test_local_close_during_recv(self): + recv = self.loop.create_task(self.protocol.recv()) + + self.loop.call_later(MS, self.receive_frame, self.close_frame) + self.loop.call_later(MS, self.receive_eof_if_client) + + self.loop.run_until_complete(self.protocol.close(reason="close")) + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(recv) + + self.assertConnectionClosed(1000, "close") + + # There is no test_remote_close_during_recv because it would be identical + # to test_remote_close. + + def test_remote_close_during_send(self): + self.make_drain_slow() + send = self.loop.create_task(self.protocol.send("hello")) + + self.receive_frame(self.close_frame) + self.receive_eof() + + with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(send) + + self.assertConnectionClosed(1000, "close") + + # There is no test_local_close_during_send because this cannot really + # happen, considering that writes are serialized. + + +class ServerTests(CommonTests, AsyncioTestCase): + def setUp(self): + super().setUp() + self.protocol.is_client = False + self.protocol.side = "server" + + def test_local_close_send_close_frame_timeout(self): + self.protocol.close_timeout = 10 * MS + self.make_drain_slow(50 * MS) + # If we can't send a close frame, time out in 10ms. + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(9 * MS, 19 * MS): + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1006, "") + + def test_local_close_receive_close_frame_timeout(self): + self.protocol.close_timeout = 10 * MS + # If the client doesn't send a close frame, time out in 10ms. + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(9 * MS, 19 * MS): + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1006, "") + + def test_local_close_connection_lost_timeout_after_write_eof(self): + self.protocol.close_timeout = 10 * MS + # If the client doesn't close its side of the TCP connection after we + # half-close our side with write_eof(), time out in 10ms. + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(9 * MS, 19 * MS): + # HACK: disable write_eof => other end drops connection emulation. + self.transport._eof = True + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1000, "close") + + def test_local_close_connection_lost_timeout_after_close(self): + self.protocol.close_timeout = 10 * MS + # If the client doesn't close its side of the TCP connection after we + # half-close our side with write_eof() and close it with close(), time + # out in 20ms. + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(19 * MS, 29 * MS): + # HACK: disable write_eof => other end drops connection emulation. + self.transport._eof = True + # HACK: disable close => other end drops connection emulation. + self.transport._closing = True + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1000, "close") + + +class ClientTests(CommonTests, AsyncioTestCase): + def setUp(self): + super().setUp() + self.protocol.is_client = True + self.protocol.side = "client" + + def test_local_close_send_close_frame_timeout(self): + self.protocol.close_timeout = 10 * MS + self.make_drain_slow(50 * MS) + # If we can't send a close frame, time out in 20ms. + # - 10ms waiting for sending a close frame + # - 10ms waiting for receiving a half-close + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(19 * MS, 29 * MS): + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1006, "") + + def test_local_close_receive_close_frame_timeout(self): + self.protocol.close_timeout = 10 * MS + # If the server doesn't send a close frame, time out in 20ms: + # - 10ms waiting for receiving a close frame + # - 10ms waiting for receiving a half-close + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(19 * MS, 29 * MS): + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1006, "") + + def test_local_close_connection_lost_timeout_after_write_eof(self): + self.protocol.close_timeout = 10 * MS + # If the server doesn't half-close its side of the TCP connection + # after we send a close frame, time out in 20ms: + # - 10ms waiting for receiving a half-close + # - 10ms waiting for receiving a close after write_eof + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(19 * MS, 29 * MS): + # HACK: disable write_eof => other end drops connection emulation. + self.transport._eof = True + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1000, "close") + + def test_local_close_connection_lost_timeout_after_close(self): + self.protocol.close_timeout = 10 * MS + # If the client doesn't close its side of the TCP connection after we + # half-close our side with write_eof() and close it with close(), time + # out in 20ms. + # - 10ms waiting for receiving a half-close + # - 10ms waiting for receiving a close after write_eof + # - 10ms waiting for receiving a close after close + # Check the timing within -1/+9ms for robustness. + with self.assertCompletesWithin(29 * MS, 39 * MS): + # HACK: disable write_eof => other end drops connection emulation. + self.transport._eof = True + # HACK: disable close => other end drops connection emulation. + self.transport._closing = True + self.receive_frame(self.close_frame) + self.loop.run_until_complete(self.protocol.close(reason="close")) + self.assertConnectionClosed(1000, "close") diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py new file mode 100644 index 000000000..983a91edf --- /dev/null +++ b/tests/legacy/utils.py @@ -0,0 +1,93 @@ +import asyncio +import contextlib +import functools +import logging +import os +import time +import unittest + + +class AsyncioTestCase(unittest.TestCase): + """ + Base class for tests that sets up an isolated event loop for each test. + + """ + + def __init_subclass__(cls, **kwargs): + """ + Convert test coroutines to test functions. + + This supports asychronous tests transparently. + + """ + super().__init_subclass__(**kwargs) + for name in unittest.defaultTestLoader.getTestCaseNames(cls): + test = getattr(cls, name) + if asyncio.iscoroutinefunction(test): + setattr(cls, name, cls.convert_async_to_sync(test)) + + @staticmethod + def convert_async_to_sync(test): + """ + Convert a test coroutine to a test function. + + """ + + @functools.wraps(test) + def test_func(self, *args, **kwargs): + return self.loop.run_until_complete(test(self, *args, **kwargs)) + + return test_func + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def run_loop_once(self): + # Process callbacks scheduled with call_soon by appending a callback + # to stop the event loop then running it until it hits that callback. + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + + @contextlib.contextmanager + def assertNoLogs(self, logger="websockets", level=logging.ERROR): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + + def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): + """ + Check recorded deprecation warnings match a list of expected messages. + + """ + self.assertEqual(len(recorded_warnings), len(expected_warnings)) + for recorded, expected in zip(recorded_warnings, expected_warnings): + actual = recorded.message + self.assertEqual(str(actual), expected) + self.assertEqual(type(actual), DeprecationWarning) + + +# Unit for timeouts. May be increased on slow machines by setting the +# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. +MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) + +# asyncio's debug mode has a 10x performance penalty for this test suite. +if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover + MS *= 10 + +# Ensure that timeouts are larger than the clock's resolution (for Windows). +MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) diff --git a/tests/test_auth.py b/tests/test_auth.py index ce23f913d..01ca207c7 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,160 +1,2 @@ -import unittest -import urllib.error - -from websockets.auth import * -from websockets.auth import is_credentials -from websockets.exceptions import InvalidStatusCode -from websockets.headers import build_authorization_basic - -from .test_asyncio_client_server import ClientServerTestsMixin, with_client, with_server -from .utils import AsyncioTestCase - - -class AuthTests(unittest.TestCase): - def test_is_credentials(self): - self.assertTrue(is_credentials(("username", "password"))) - - def test_is_not_credentials(self): - self.assertFalse(is_credentials(None)) - self.assertFalse(is_credentials("username")) - - -class CustomWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): - async def process_request(self, path, request_headers): - type(self).used = True - return await super().process_request(path, request_headers) - - -class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): - - create_protocol = basic_auth_protocol_factory( - realm="auth-tests", credentials=("hello", "iloveyou") - ) - - @with_server(create_protocol=create_protocol) - @with_client(user_info=("hello", "iloveyou")) - def test_basic_auth(self): - req_headers = self.client.request_headers - resp_headers = self.client.response_headers - self.assertEqual(req_headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") - self.assertNotIn("WWW-Authenticate", resp_headers) - - self.loop.run_until_complete(self.client.send("Hello!")) - self.loop.run_until_complete(self.client.recv()) - - def test_basic_auth_server_no_credentials(self): - with self.assertRaises(TypeError) as raised: - basic_auth_protocol_factory(realm="auth-tests", credentials=None) - self.assertEqual( - str(raised.exception), "provide either credentials or check_credentials" - ) - - def test_basic_auth_server_bad_credentials(self): - with self.assertRaises(TypeError) as raised: - basic_auth_protocol_factory(realm="auth-tests", credentials=42) - self.assertEqual(str(raised.exception), "invalid credentials argument: 42") - - create_protocol_multiple_credentials = basic_auth_protocol_factory( - realm="auth-tests", - credentials=[("hello", "iloveyou"), ("goodbye", "stillloveu")], - ) - - @with_server(create_protocol=create_protocol_multiple_credentials) - @with_client(user_info=("hello", "iloveyou")) - def test_basic_auth_server_multiple_credentials(self): - self.loop.run_until_complete(self.client.send("Hello!")) - self.loop.run_until_complete(self.client.recv()) - - def test_basic_auth_bad_multiple_credentials(self): - with self.assertRaises(TypeError) as raised: - basic_auth_protocol_factory( - realm="auth-tests", credentials=[("hello", "iloveyou"), 42] - ) - self.assertEqual( - str(raised.exception), - "invalid credentials argument: [('hello', 'iloveyou'), 42]", - ) - - async def check_credentials(username, password): - return password == "iloveyou" - - create_protocol_check_credentials = basic_auth_protocol_factory( - realm="auth-tests", - check_credentials=check_credentials, - ) - - @with_server(create_protocol=create_protocol_check_credentials) - @with_client(user_info=("hello", "iloveyou")) - def test_basic_auth_check_credentials(self): - self.loop.run_until_complete(self.client.send("Hello!")) - self.loop.run_until_complete(self.client.recv()) - - create_protocol_custom_protocol = basic_auth_protocol_factory( - realm="auth-tests", - credentials=[("hello", "iloveyou")], - create_protocol=CustomWebSocketServerProtocol, - ) - - @with_server(create_protocol=create_protocol_custom_protocol) - @with_client(user_info=("hello", "iloveyou")) - def test_basic_auth_custom_protocol(self): - self.assertTrue(CustomWebSocketServerProtocol.used) - del CustomWebSocketServerProtocol.used - self.loop.run_until_complete(self.client.send("Hello!")) - self.loop.run_until_complete(self.client.recv()) - - @with_server(create_protocol=create_protocol) - def test_basic_auth_missing_credentials(self): - with self.assertRaises(InvalidStatusCode) as raised: - self.start_client() - self.assertEqual(raised.exception.status_code, 401) - - @with_server(create_protocol=create_protocol) - def test_basic_auth_missing_credentials_details(self): - with self.assertRaises(urllib.error.HTTPError) as raised: - self.loop.run_until_complete(self.make_http_request()) - self.assertEqual(raised.exception.code, 401) - self.assertEqual( - raised.exception.headers["WWW-Authenticate"], - 'Basic realm="auth-tests", charset="UTF-8"', - ) - self.assertEqual(raised.exception.read().decode(), "Missing credentials\n") - - @with_server(create_protocol=create_protocol) - def test_basic_auth_unsupported_credentials(self): - with self.assertRaises(InvalidStatusCode) as raised: - self.start_client(extra_headers={"Authorization": "Digest ..."}) - self.assertEqual(raised.exception.status_code, 401) - - @with_server(create_protocol=create_protocol) - def test_basic_auth_unsupported_credentials_details(self): - with self.assertRaises(urllib.error.HTTPError) as raised: - self.loop.run_until_complete( - self.make_http_request(headers={"Authorization": "Digest ..."}) - ) - self.assertEqual(raised.exception.code, 401) - self.assertEqual( - raised.exception.headers["WWW-Authenticate"], - 'Basic realm="auth-tests", charset="UTF-8"', - ) - self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n") - - @with_server(create_protocol=create_protocol) - def test_basic_auth_invalid_credentials(self): - with self.assertRaises(InvalidStatusCode) as raised: - self.start_client(user_info=("hello", "ihateyou")) - self.assertEqual(raised.exception.status_code, 401) - - @with_server(create_protocol=create_protocol) - def test_basic_auth_invalid_credentials_details(self): - with self.assertRaises(urllib.error.HTTPError) as raised: - authorization = build_authorization_basic("hello", "ihateyou") - self.loop.run_until_complete( - self.make_http_request(headers={"Authorization": authorization}) - ) - self.assertEqual(raised.exception.code, 401) - self.assertEqual( - raised.exception.headers["WWW-Authenticate"], - 'Basic realm="auth-tests", charset="UTF-8"', - ) - self.assertEqual(raised.exception.read().decode(), "Invalid credentials\n") +# Check that the legacy auth module imports without an exception. +from websockets.auth import * # noqa diff --git a/tests/test_exports.py b/tests/test_exports.py index 7fcbc80e3..8e4330304 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -4,10 +4,12 @@ combined_exports = ( - websockets.auth.__all__ + websockets.legacy.auth.__all__ + + websockets.legacy.client.__all__ + + websockets.legacy.protocol.__all__ + + websockets.legacy.server.__all__ + websockets.client.__all__ + websockets.exceptions.__all__ - + websockets.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ + websockets.uri.__all__ diff --git a/tests/test_framing.py b/tests/test_framing.py index 231cbf718..d6fa6352a 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -1,171 +1,9 @@ -import asyncio -import codecs -import unittest -import unittest.mock import warnings -from websockets.exceptions import PayloadTooBig, ProtocolError -from websockets.frames import OP_BINARY, OP_CLOSE, OP_PING, OP_PONG, OP_TEXT -from websockets.framing import * -from .utils import AsyncioTestCase - - -class FramingTests(AsyncioTestCase): - def decode(self, message, mask=False, max_size=None, extensions=None): - stream = asyncio.StreamReader(loop=self.loop) - stream.feed_data(message) - stream.feed_eof() - with warnings.catch_warnings(record=True): - frame = self.loop.run_until_complete( - Frame.read( - stream.readexactly, - mask=mask, - max_size=max_size, - extensions=extensions, - ) - ) - # Make sure all the data was consumed. - self.assertTrue(stream.at_eof()) - return frame - - def encode(self, frame, mask=False, extensions=None): - write = unittest.mock.Mock() - with warnings.catch_warnings(record=True): - frame.write(write, mask=mask, extensions=extensions) - # Ensure the entire frame is sent with a single call to write(). - # Multiple calls cause TCP fragmentation and degrade performance. - self.assertEqual(write.call_count, 1) - # The frame data is the single positional argument of that call. - self.assertEqual(len(write.call_args[0]), 1) - self.assertEqual(len(write.call_args[1]), 0) - return write.call_args[0][0] - - def round_trip(self, message, expected, mask=False, extensions=None): - decoded = self.decode(message, mask, extensions=extensions) - self.assertEqual(decoded, expected) - encoded = self.encode(decoded, mask, extensions=extensions) - if mask: # non-deterministic encoding - decoded = self.decode(encoded, mask, extensions=extensions) - self.assertEqual(decoded, expected) - else: # deterministic encoding - self.assertEqual(encoded, message) - - def test_text(self): - self.round_trip(b"\x81\x04Spam", Frame(True, OP_TEXT, b"Spam")) - - def test_text_masked(self): - self.round_trip( - b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", - Frame(True, OP_TEXT, b"Spam"), - mask=True, - ) - - def test_binary(self): - self.round_trip(b"\x82\x04Eggs", Frame(True, OP_BINARY, b"Eggs")) - - def test_binary_masked(self): - self.round_trip( - b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", - Frame(True, OP_BINARY, b"Eggs"), - mask=True, - ) - - def test_non_ascii_text(self): - self.round_trip( - b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode("utf-8")) - ) - - def test_non_ascii_text_masked(self): - self.round_trip( - b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", - Frame(True, OP_TEXT, "café".encode("utf-8")), - mask=True, - ) - - def test_close(self): - self.round_trip(b"\x88\x00", Frame(True, OP_CLOSE, b"")) - - def test_ping(self): - self.round_trip(b"\x89\x04ping", Frame(True, OP_PING, b"ping")) - - def test_pong(self): - self.round_trip(b"\x8a\x04pong", Frame(True, OP_PONG, b"pong")) - - def test_long(self): - self.round_trip( - b"\x82\x7e\x00\x7e" + 126 * b"a", Frame(True, OP_BINARY, 126 * b"a") - ) - - def test_very_long(self): - self.round_trip( - b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", - Frame(True, OP_BINARY, 65536 * b"a"), - ) - - def test_payload_too_big(self): - with self.assertRaises(PayloadTooBig): - self.decode(b"\x82\x7e\x04\x01" + 1025 * b"a", max_size=1024) - - def test_bad_reserved_bits(self): - for encoded in [b"\xc0\x00", b"\xa0\x00", b"\x90\x00"]: - with self.subTest(encoded=encoded): - with self.assertRaises(ProtocolError): - self.decode(encoded) - - def test_good_opcode(self): - for opcode in list(range(0x00, 0x03)) + list(range(0x08, 0x0B)): - encoded = bytes([0x80 | opcode, 0]) - with self.subTest(encoded=encoded): - self.decode(encoded) # does not raise an exception - - def test_bad_opcode(self): - for opcode in list(range(0x03, 0x08)) + list(range(0x0B, 0x10)): - encoded = bytes([0x80 | opcode, 0]) - with self.subTest(encoded=encoded): - with self.assertRaises(ProtocolError): - self.decode(encoded) - - def test_mask_flag(self): - # Mask flag correctly set. - self.decode(b"\x80\x80\x00\x00\x00\x00", mask=True) - # Mask flag incorrectly unset. - with self.assertRaises(ProtocolError): - self.decode(b"\x80\x80\x00\x00\x00\x00") - # Mask flag correctly unset. - self.decode(b"\x80\x00") - # Mask flag incorrectly set. - with self.assertRaises(ProtocolError): - self.decode(b"\x80\x00", mask=True) - - def test_control_frame_max_length(self): - # At maximum allowed length. - self.decode(b"\x88\x7e\x00\x7d" + 125 * b"a") - # Above maximum allowed length. - with self.assertRaises(ProtocolError): - self.decode(b"\x88\x7e\x00\x7e" + 126 * b"a") - - def test_fragmented_control_frame(self): - # Fin bit correctly set. - self.decode(b"\x88\x00") - # Fin bit incorrectly unset. - with self.assertRaises(ProtocolError): - self.decode(b"\x08\x00") - - def test_extensions(self): - class Rot13: - @staticmethod - def encode(frame): - assert frame.opcode == OP_TEXT - text = frame.data.decode() - data = codecs.encode(text, "rot13").encode() - return frame._replace(data=data) - - # This extensions is symmetrical. - @staticmethod - def decode(frame, *, max_size=None): - return Rot13.encode(frame) - - self.round_trip( - b"\x81\x05uryyb", Frame(True, OP_TEXT, b"hello"), extensions=[Rot13()] - ) +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "websockets.framing is deprecated", DeprecationWarning + ) + # Check that the legacy framing module imports without an exception. + from websockets.framing import * # noqa diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 432c31ef5..f896fcae4 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,1489 +1,2 @@ -import asyncio -import contextlib -import sys -import unittest -import unittest.mock -import warnings - -from websockets.exceptions import ConnectionClosed, InvalidState -from websockets.frames import ( - OP_BINARY, - OP_CLOSE, - OP_CONT, - OP_PING, - OP_PONG, - OP_TEXT, - serialize_close, -) -from websockets.framing import Frame -from websockets.protocol import State, WebSocketCommonProtocol - -from .utils import MS, AsyncioTestCase - - -async def async_iterable(iterable): - for item in iterable: - yield item - - -class TransportMock(unittest.mock.Mock): - """ - Transport mock to control the protocol's inputs and outputs in tests. - - It calls the protocol's connection_made and connection_lost methods like - actual transports. - - It also calls the protocol's connection_open method to bypass the - WebSocket handshake. - - To simulate incoming data, tests call the protocol's data_received and - eof_received methods directly. - - They could also pause_writing and resume_writing to test flow control. - - """ - - # This should happen in __init__ but overriding Mock.__init__ is hard. - def setup_mock(self, loop, protocol): - self.loop = loop - self.protocol = protocol - self._eof = False - self._closing = False - # Simulate a successful TCP handshake. - self.protocol.connection_made(self) - # Simulate a successful WebSocket handshake. - self.protocol.connection_open() - - def can_write_eof(self): - return True - - def write_eof(self): - # When the protocol half-closes the TCP connection, it expects the - # other end to close it. Simulate that. - if not self._eof: - self.loop.call_soon(self.close) - self._eof = True - - def close(self): - # Simulate how actual transports drop the connection. - if not self._closing: - self.loop.call_soon(self.protocol.connection_lost, None) - self._closing = True - - def abort(self): - # Change this to an `if` if tests call abort() multiple times. - assert self.protocol.state is not State.CLOSED - self.loop.call_soon(self.protocol.connection_lost, None) - - -class CommonTests: - """ - Mixin that defines most tests but doesn't inherit unittest.TestCase. - - Tests are run by the ServerTests and ClientTests subclasses. - - """ - - def setUp(self): - super().setUp() - # Disable pings to make it easier to test what frames are sent exactly. - self.protocol = WebSocketCommonProtocol(ping_interval=None) - self.transport = TransportMock() - self.transport.setup_mock(self.loop, self.protocol) - - def tearDown(self): - self.transport.close() - self.loop.run_until_complete(self.protocol.close()) - super().tearDown() - - # Utilities for writing tests. - - def make_drain_slow(self, delay=MS): - # Process connection_made in order to initialize self.protocol.transport. - self.run_loop_once() - - original_drain = self.protocol._drain - - async def delayed_drain(): - await asyncio.sleep( - delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) - await original_drain() - - self.protocol._drain = delayed_drain - - close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) - local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) - remote_close = Frame(True, OP_CLOSE, serialize_close(1000, "remote")) - - def receive_frame(self, frame): - """ - Make the protocol receive a frame. - - """ - write = self.protocol.data_received - mask = not self.protocol.is_client - frame.write(write, mask=mask) - - def receive_eof(self): - """ - Make the protocol receive the end of the data stream. - - Since ``WebSocketCommonProtocol.eof_received`` returns ``None``, an - actual transport would close itself after calling it. This function - emulates that behavior. - - """ - self.protocol.eof_received() - self.loop.call_soon(self.transport.close) - - def receive_eof_if_client(self): - """ - Like receive_eof, but only if this is the client side. - - Since the server is supposed to initiate the termination of the TCP - connection, this method helps making tests work for both sides. - - """ - if self.protocol.is_client: - self.receive_eof() - - def close_connection(self, code=1000, reason="close"): - """ - Execute a closing handshake. - - This puts the connection in the CLOSED state. - - """ - close_frame_data = serialize_close(code, reason) - # Prepare the response to the closing handshake from the remote side. - self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) - self.receive_eof_if_client() - # Trigger the closing handshake from the local side and complete it. - self.loop.run_until_complete(self.protocol.close(code, reason)) - # Empty the outgoing data stream so we can make assertions later on. - self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) - - assert self.protocol.state is State.CLOSED - - def half_close_connection_local(self, code=1000, reason="close"): - """ - Start a closing handshake but do not complete it. - - The main difference with `close_connection` is that the connection is - left in the CLOSING state until the event loop runs again. - - The current implementation returns a task that must be awaited or - canceled, else asyncio complains about destroying a pending task. - - """ - close_frame_data = serialize_close(code, reason) - # Trigger the closing handshake from the local endpoint. - close_task = self.loop.create_task(self.protocol.close(code, reason)) - self.run_loop_once() # wait_for executes - self.run_loop_once() # write_frame executes - # Empty the outgoing data stream so we can make assertions later on. - self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) - - assert self.protocol.state is State.CLOSING - - # Complete the closing sequence at 1ms intervals so the test can run - # at each point even it goes back to the event loop several times. - self.loop.call_later( - MS, self.receive_frame, Frame(True, OP_CLOSE, close_frame_data) - ) - self.loop.call_later(2 * MS, self.receive_eof_if_client) - - # This task must be awaited or canceled by the caller. - return close_task - - def half_close_connection_remote(self, code=1000, reason="close"): - """ - Receive a closing handshake but do not complete it. - - The main difference with `close_connection` is that the connection is - left in the CLOSING state until the event loop runs again. - - """ - # On the server side, websockets completes the closing handshake and - # closes the TCP connection immediately. Yield to the event loop after - # sending the close frame to run the test while the connection is in - # the CLOSING state. - if not self.protocol.is_client: - self.make_drain_slow() - - close_frame_data = serialize_close(code, reason) - # Trigger the closing handshake from the remote endpoint. - self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) - self.run_loop_once() # read_frame executes - # Empty the outgoing data stream so we can make assertions later on. - self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) - - assert self.protocol.state is State.CLOSING - - # Complete the closing sequence at 1ms intervals so the test can run - # at each point even it goes back to the event loop several times. - self.loop.call_later(2 * MS, self.receive_eof_if_client) - - def process_invalid_frames(self): - """ - Make the protocol fail quickly after simulating invalid data. - - To achieve this, this function triggers the protocol's eof_received, - which interrupts pending reads waiting for more data. - - """ - self.run_loop_once() - self.receive_eof() - self.loop.run_until_complete(self.protocol.close_connection_task) - - def sent_frames(self): - """ - Read all frames sent to the transport. - - """ - stream = asyncio.StreamReader(loop=self.loop) - - for (data,), kw in self.transport.write.call_args_list: - stream.feed_data(data) - self.transport.write.call_args_list = [] - stream.feed_eof() - - frames = [] - while not stream.at_eof(): - frames.append( - self.loop.run_until_complete( - Frame.read(stream.readexactly, mask=self.protocol.is_client) - ) - ) - return frames - - def last_sent_frame(self): - """ - Read the last frame sent to the transport. - - This method assumes that at most one frame was sent. It raises an - AssertionError otherwise. - - """ - frames = self.sent_frames() - if frames: - assert len(frames) == 1 - return frames[0] - - def assertFramesSent(self, *frames): - self.assertEqual(self.sent_frames(), [Frame(*args) for args in frames]) - - def assertOneFrameSent(self, *args): - self.assertEqual(self.last_sent_frame(), Frame(*args)) - - def assertNoFrameSent(self): - self.assertIsNone(self.last_sent_frame()) - - def assertConnectionClosed(self, code, message): - # The following line guarantees that connection_lost was called. - self.assertEqual(self.protocol.state, State.CLOSED) - # A close frame was received. - self.assertEqual(self.protocol.close_code, code) - self.assertEqual(self.protocol.close_reason, message) - - def assertConnectionFailed(self, code, message): - # The following line guarantees that connection_lost was called. - self.assertEqual(self.protocol.state, State.CLOSED) - # No close frame was received. - self.assertEqual(self.protocol.close_code, 1006) - self.assertEqual(self.protocol.close_reason, "") - # A close frame was sent -- unless the connection was already lost. - if code == 1006: - self.assertNoFrameSent() - else: - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(code, message)) - - @contextlib.contextmanager - def assertCompletesWithin(self, min_time, max_time): - t0 = self.loop.time() - yield - t1 = self.loop.time() - dt = t1 - t0 - self.assertGreaterEqual(dt, min_time, f"Too fast: {dt} < {min_time}") - self.assertLess(dt, max_time, f"Too slow: {dt} >= {max_time}") - - # Test constructor. - - def test_timeout_backwards_compatibility(self): - with warnings.catch_warnings(record=True) as recorded_warnings: - protocol = WebSocketCommonProtocol(timeout=5) - - self.assertEqual(protocol.close_timeout, 5) - - self.assertEqual(len(recorded_warnings), 1) - warning = recorded_warnings[0].message - self.assertEqual(str(warning), "rename timeout to close_timeout") - self.assertEqual(type(warning), DeprecationWarning) - - # Test public attributes. - - def test_local_address(self): - get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) - self.transport.get_extra_info = get_extra_info - - self.assertEqual(self.protocol.local_address, ("host", 4312)) - get_extra_info.assert_called_with("sockname") - - def test_local_address_before_connection(self): - # Emulate the situation before connection_open() runs. - _transport = self.protocol.transport - del self.protocol.transport - try: - self.assertEqual(self.protocol.local_address, None) - finally: - self.protocol.transport = _transport - - def test_remote_address(self): - get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) - self.transport.get_extra_info = get_extra_info - - self.assertEqual(self.protocol.remote_address, ("host", 4312)) - get_extra_info.assert_called_with("peername") - - def test_remote_address_before_connection(self): - # Emulate the situation before connection_open() runs. - _transport = self.protocol.transport - del self.protocol.transport - try: - self.assertEqual(self.protocol.remote_address, None) - finally: - self.protocol.transport = _transport - - def test_open(self): - self.assertTrue(self.protocol.open) - self.close_connection() - self.assertFalse(self.protocol.open) - - def test_closed(self): - self.assertFalse(self.protocol.closed) - self.close_connection() - self.assertTrue(self.protocol.closed) - - def test_wait_closed(self): - wait_closed = self.loop.create_task(self.protocol.wait_closed()) - self.assertFalse(wait_closed.done()) - self.close_connection() - self.assertTrue(wait_closed.done()) - - # Test the recv coroutine. - - def test_recv_text(self): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, "café") - - def test_recv_binary(self): - self.receive_frame(Frame(True, OP_BINARY, b"tea")) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, b"tea") - - def test_recv_on_closing_connection_local(self): - close_task = self.half_close_connection_local() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - self.loop.run_until_complete(close_task) # cleanup - - def test_recv_on_closing_connection_remote(self): - self.half_close_connection_remote() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - def test_recv_on_closed_connection(self): - self.close_connection() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - def test_recv_protocol_error(self): - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8"))) - self.process_invalid_frames() - self.assertConnectionFailed(1002, "") - - def test_recv_unicode_error(self): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("latin-1"))) - self.process_invalid_frames() - self.assertConnectionFailed(1007, "") - - def test_recv_text_payload_too_big(self): - self.protocol.max_size = 1024 - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) - self.process_invalid_frames() - self.assertConnectionFailed(1009, "") - - def test_recv_binary_payload_too_big(self): - self.protocol.max_size = 1024 - self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342)) - self.process_invalid_frames() - self.assertConnectionFailed(1009, "") - - def test_recv_text_no_max_size(self): - self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, "café" * 205) - - def test_recv_binary_no_max_size(self): - self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342)) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, b"tea" * 342) - - def test_recv_queue_empty(self): - recv = self.loop.create_task(self.protocol.recv()) - with self.assertRaises(asyncio.TimeoutError): - self.loop.run_until_complete( - asyncio.wait_for(asyncio.shield(recv), timeout=MS) - ) - - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) - data = self.loop.run_until_complete(recv) - self.assertEqual(data, "café") - - def test_recv_queue_full(self): - self.protocol.max_queue = 2 - # Test internals because it's hard to verify buffers from the outside. - self.assertEqual(list(self.protocol.messages), []) - - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) - self.run_loop_once() - self.assertEqual(list(self.protocol.messages), ["café"]) - - self.receive_frame(Frame(True, OP_BINARY, b"tea")) - self.run_loop_once() - self.assertEqual(list(self.protocol.messages), ["café", b"tea"]) - - self.receive_frame(Frame(True, OP_BINARY, b"milk")) - self.run_loop_once() - self.assertEqual(list(self.protocol.messages), ["café", b"tea"]) - - self.loop.run_until_complete(self.protocol.recv()) - self.run_loop_once() - self.assertEqual(list(self.protocol.messages), [b"tea", b"milk"]) - - self.loop.run_until_complete(self.protocol.recv()) - self.run_loop_once() - self.assertEqual(list(self.protocol.messages), [b"milk"]) - - self.loop.run_until_complete(self.protocol.recv()) - self.run_loop_once() - self.assertEqual(list(self.protocol.messages), []) - - def test_recv_queue_no_limit(self): - self.protocol.max_queue = None - - for _ in range(100): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) - self.run_loop_once() - - # Incoming message queue can contain at least 100 messages. - self.assertEqual(list(self.protocol.messages), ["café"] * 100) - - for _ in range(100): - self.loop.run_until_complete(self.protocol.recv()) - - self.assertEqual(list(self.protocol.messages), []) - - def test_recv_other_error(self): - async def read_message(): - raise Exception("BOOM") - - self.protocol.read_message = read_message - self.process_invalid_frames() - self.assertConnectionFailed(1011, "") - - def test_recv_canceled(self): - recv = self.loop.create_task(self.protocol.recv()) - self.loop.call_soon(recv.cancel) - - with self.assertRaises(asyncio.CancelledError): - self.loop.run_until_complete(recv) - - # The next frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, "café") - - def test_recv_canceled_race_condition(self): - recv = self.loop.create_task( - asyncio.wait_for(self.protocol.recv(), timeout=0.000_001) - ) - self.loop.call_soon( - self.receive_frame, Frame(True, OP_TEXT, "café".encode("utf-8")) - ) - - with self.assertRaises(asyncio.TimeoutError): - self.loop.run_until_complete(recv) - - # The previous frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, "tea".encode("utf-8"))) - data = self.loop.run_until_complete(self.protocol.recv()) - # If we're getting "tea" there, it means "café" was swallowed (ha, ha). - self.assertEqual(data, "café") - - def test_recv_when_transfer_data_cancelled(self): - # Clog incoming queue. - self.protocol.max_queue = 1 - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) - self.receive_frame(Frame(True, OP_BINARY, b"tea")) - self.run_loop_once() - - # Flow control kicks in (check with an implementation detail). - self.assertFalse(self.protocol._put_message_waiter.done()) - - # Schedule recv(). - recv = self.loop.create_task(self.protocol.recv()) - - # Cancel transfer_data_task (again, implementation detail). - self.protocol.fail_connection() - self.run_loop_once() - self.assertTrue(self.protocol.transfer_data_task.cancelled()) - - # recv() completes properly. - self.assertEqual(self.loop.run_until_complete(recv), "café") - - def test_recv_prevents_concurrent_calls(self): - recv = self.loop.create_task(self.protocol.recv()) - - with self.assertRaises(RuntimeError) as raised: - self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual( - str(raised.exception), - "cannot call recv while another coroutine " - "is already waiting for the next message", - ) - recv.cancel() - - # Test the send coroutine. - - def test_send_text(self): - self.loop.run_until_complete(self.protocol.send("café")) - self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) - - def test_send_binary(self): - self.loop.run_until_complete(self.protocol.send(b"tea")) - self.assertOneFrameSent(True, OP_BINARY, b"tea") - - def test_send_binary_from_bytearray(self): - self.loop.run_until_complete(self.protocol.send(bytearray(b"tea"))) - self.assertOneFrameSent(True, OP_BINARY, b"tea") - - def test_send_binary_from_memoryview(self): - self.loop.run_until_complete(self.protocol.send(memoryview(b"tea"))) - self.assertOneFrameSent(True, OP_BINARY, b"tea") - - def test_send_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.send(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_BINARY, b"tea") - - def test_send_dict(self): - with self.assertRaises(TypeError): - self.loop.run_until_complete(self.protocol.send({"not": "encoded"})) - self.assertNoFrameSent() - - def test_send_type_error(self): - with self.assertRaises(TypeError): - self.loop.run_until_complete(self.protocol.send(42)) - self.assertNoFrameSent() - - def test_send_iterable_text(self): - self.loop.run_until_complete(self.protocol.send(["ca", "fé"])) - self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), - ) - - def test_send_iterable_binary(self): - self.loop.run_until_complete(self.protocol.send([b"te", b"a"])) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - - def test_send_iterable_binary_from_bytearray(self): - self.loop.run_until_complete( - self.protocol.send([bytearray(b"te"), bytearray(b"a")]) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - - def test_send_iterable_binary_from_memoryview(self): - self.loop.run_until_complete( - self.protocol.send([memoryview(b"te"), memoryview(b"a")]) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - - def test_send_iterable_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete( - self.protocol.send([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - - def test_send_empty_iterable(self): - self.loop.run_until_complete(self.protocol.send([])) - self.assertNoFrameSent() - - def test_send_iterable_type_error(self): - with self.assertRaises(TypeError): - self.loop.run_until_complete(self.protocol.send([42])) - self.assertNoFrameSent() - - def test_send_iterable_mixed_type_error(self): - with self.assertRaises(TypeError): - self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) - self.assertFramesSent( - (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, serialize_close(1011, "")), - ) - - def test_send_iterable_prevents_concurrent_send(self): - self.make_drain_slow(2 * MS) - - async def send_iterable(): - await self.protocol.send(["ca", "fé"]) - - async def send_concurrent(): - await asyncio.sleep(MS) - await self.protocol.send(b"tea") - - self.loop.run_until_complete(asyncio.gather(send_iterable(), send_concurrent())) - self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), - (True, OP_BINARY, b"tea"), - ) - - def test_send_async_iterable_text(self): - self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"]))) - self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), - ) - - def test_send_async_iterable_binary(self): - self.loop.run_until_complete(self.protocol.send(async_iterable([b"te", b"a"]))) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - - def test_send_async_iterable_binary_from_bytearray(self): - self.loop.run_until_complete( - self.protocol.send(async_iterable([bytearray(b"te"), bytearray(b"a")])) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - - def test_send_async_iterable_binary_from_memoryview(self): - self.loop.run_until_complete( - self.protocol.send(async_iterable([memoryview(b"te"), memoryview(b"a")])) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - - def test_send_async_iterable_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete( - self.protocol.send( - async_iterable([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) - ) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - - def test_send_empty_async_iterable(self): - self.loop.run_until_complete(self.protocol.send(async_iterable([]))) - self.assertNoFrameSent() - - def test_send_async_iterable_type_error(self): - with self.assertRaises(TypeError): - self.loop.run_until_complete(self.protocol.send(async_iterable([42]))) - self.assertNoFrameSent() - - def test_send_async_iterable_mixed_type_error(self): - with self.assertRaises(TypeError): - self.loop.run_until_complete( - self.protocol.send(async_iterable(["café", b"tea"])) - ) - self.assertFramesSent( - (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, serialize_close(1011, "")), - ) - - def test_send_async_iterable_prevents_concurrent_send(self): - self.make_drain_slow(2 * MS) - - async def send_async_iterable(): - await self.protocol.send(async_iterable(["ca", "fé"])) - - async def send_concurrent(): - await asyncio.sleep(MS) - await self.protocol.send(b"tea") - - self.loop.run_until_complete( - asyncio.gather(send_async_iterable(), send_concurrent()) - ) - self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), - (True, OP_BINARY, b"tea"), - ) - - def test_send_on_closing_connection_local(self): - close_task = self.half_close_connection_local() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.send("foobar")) - - self.assertNoFrameSent() - - self.loop.run_until_complete(close_task) # cleanup - - def test_send_on_closing_connection_remote(self): - self.half_close_connection_remote() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.send("foobar")) - - self.assertNoFrameSent() - - def test_send_on_closed_connection(self): - self.close_connection() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.send("foobar")) - - self.assertNoFrameSent() - - # Test the ping coroutine. - - def test_ping_default(self): - self.loop.run_until_complete(self.protocol.ping()) - # With our testing tools, it's more convenient to extract the expected - # ping data from the library's internals than from the frame sent. - ping_data = next(iter(self.protocol.pings)) - self.assertIsInstance(ping_data, bytes) - self.assertEqual(len(ping_data), 4) - self.assertOneFrameSent(True, OP_PING, ping_data) - - def test_ping_text(self): - self.loop.run_until_complete(self.protocol.ping("café")) - self.assertOneFrameSent(True, OP_PING, "café".encode("utf-8")) - - def test_ping_binary(self): - self.loop.run_until_complete(self.protocol.ping(b"tea")) - self.assertOneFrameSent(True, OP_PING, b"tea") - - def test_ping_binary_from_bytearray(self): - self.loop.run_until_complete(self.protocol.ping(bytearray(b"tea"))) - self.assertOneFrameSent(True, OP_PING, b"tea") - - def test_ping_binary_from_memoryview(self): - self.loop.run_until_complete(self.protocol.ping(memoryview(b"tea"))) - self.assertOneFrameSent(True, OP_PING, b"tea") - - def test_ping_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.ping(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_PING, b"tea") - - def test_ping_type_error(self): - with self.assertRaises(TypeError): - self.loop.run_until_complete(self.protocol.ping(42)) - self.assertNoFrameSent() - - def test_ping_on_closing_connection_local(self): - close_task = self.half_close_connection_local() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.ping()) - - self.assertNoFrameSent() - - self.loop.run_until_complete(close_task) # cleanup - - def test_ping_on_closing_connection_remote(self): - self.half_close_connection_remote() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.ping()) - - self.assertNoFrameSent() - - def test_ping_on_closed_connection(self): - self.close_connection() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.ping()) - - self.assertNoFrameSent() - - # Test the pong coroutine. - - def test_pong_default(self): - self.loop.run_until_complete(self.protocol.pong()) - self.assertOneFrameSent(True, OP_PONG, b"") - - def test_pong_text(self): - self.loop.run_until_complete(self.protocol.pong("café")) - self.assertOneFrameSent(True, OP_PONG, "café".encode("utf-8")) - - def test_pong_binary(self): - self.loop.run_until_complete(self.protocol.pong(b"tea")) - self.assertOneFrameSent(True, OP_PONG, b"tea") - - def test_pong_binary_from_bytearray(self): - self.loop.run_until_complete(self.protocol.pong(bytearray(b"tea"))) - self.assertOneFrameSent(True, OP_PONG, b"tea") - - def test_pong_binary_from_memoryview(self): - self.loop.run_until_complete(self.protocol.pong(memoryview(b"tea"))) - self.assertOneFrameSent(True, OP_PONG, b"tea") - - def test_pong_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.pong(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_PONG, b"tea") - - def test_pong_type_error(self): - with self.assertRaises(TypeError): - self.loop.run_until_complete(self.protocol.pong(42)) - self.assertNoFrameSent() - - def test_pong_on_closing_connection_local(self): - close_task = self.half_close_connection_local() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.pong()) - - self.assertNoFrameSent() - - self.loop.run_until_complete(close_task) # cleanup - - def test_pong_on_closing_connection_remote(self): - self.half_close_connection_remote() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.pong()) - - self.assertNoFrameSent() - - def test_pong_on_closed_connection(self): - self.close_connection() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.pong()) - - self.assertNoFrameSent() - - # Test the protocol's logic for acknowledging pings with pongs. - - def test_answer_ping(self): - self.receive_frame(Frame(True, OP_PING, b"test")) - self.run_loop_once() - self.assertOneFrameSent(True, OP_PONG, b"test") - - def test_ignore_pong(self): - self.receive_frame(Frame(True, OP_PONG, b"test")) - self.run_loop_once() - self.assertNoFrameSent() - - def test_acknowledge_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) - self.assertFalse(ping.done()) - ping_frame = self.last_sent_frame() - pong_frame = Frame(True, OP_PONG, ping_frame.data) - self.receive_frame(pong_frame) - self.run_loop_once() - self.run_loop_once() - self.assertTrue(ping.done()) - - def test_abort_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) - # Remove the frame from the buffer, else close_connection() complains. - self.last_sent_frame() - self.assertFalse(ping.done()) - self.close_connection() - self.assertTrue(ping.done()) - self.assertIsInstance(ping.exception(), ConnectionClosed) - - def test_abort_ping_does_not_log_exception_if_not_retreived(self): - self.loop.run_until_complete(self.protocol.ping()) - # Get the internal Future, which isn't directly returned by ping(). - (ping,) = self.protocol.pings.values() - # Remove the frame from the buffer, else close_connection() complains. - self.last_sent_frame() - self.close_connection() - # Check a private attribute, for lack of a better solution. - self.assertFalse(ping._log_traceback) - - def test_acknowledge_previous_pings(self): - pings = [ - (self.loop.run_until_complete(self.protocol.ping()), self.last_sent_frame()) - for i in range(3) - ] - # Unsolicited pong doesn't acknowledge pings - self.receive_frame(Frame(True, OP_PONG, b"")) - self.run_loop_once() - self.run_loop_once() - self.assertFalse(pings[0][0].done()) - self.assertFalse(pings[1][0].done()) - self.assertFalse(pings[2][0].done()) - # Pong acknowledges all previous pings - self.receive_frame(Frame(True, OP_PONG, pings[1][1].data)) - self.run_loop_once() - self.run_loop_once() - self.assertTrue(pings[0][0].done()) - self.assertTrue(pings[1][0].done()) - self.assertFalse(pings[2][0].done()) - - def test_acknowledge_aborted_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) - ping_frame = self.last_sent_frame() - # Clog incoming queue. This lets connection_lost() abort pending pings - # with a ConnectionClosed exception before transfer_data_task - # terminates and close_connection cancels keepalive_ping_task. - self.protocol.max_queue = 1 - self.receive_frame(Frame(True, OP_TEXT, b"1")) - self.receive_frame(Frame(True, OP_TEXT, b"2")) - # Add pong frame to the queue. - pong_frame = Frame(True, OP_PONG, ping_frame.data) - self.receive_frame(pong_frame) - # Connection drops. - self.receive_eof() - self.loop.run_until_complete(self.protocol.wait_closed()) - # Ping receives a ConnectionClosed exception. - with self.assertRaises(ConnectionClosed): - ping.result() - - # transfer_data doesn't crash, which would be logged. - with self.assertNoLogs(): - # Unclog incoming queue. - self.loop.run_until_complete(self.protocol.recv()) - self.loop.run_until_complete(self.protocol.recv()) - - def test_canceled_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) - ping_frame = self.last_sent_frame() - ping.cancel() - pong_frame = Frame(True, OP_PONG, ping_frame.data) - self.receive_frame(pong_frame) - self.run_loop_once() - self.run_loop_once() - self.assertTrue(ping.cancelled()) - - def test_duplicate_ping(self): - self.loop.run_until_complete(self.protocol.ping(b"foobar")) - self.assertOneFrameSent(True, OP_PING, b"foobar") - with self.assertRaises(ValueError): - self.loop.run_until_complete(self.protocol.ping(b"foobar")) - self.assertNoFrameSent() - - # Test the protocol's logic for rebuilding fragmented messages. - - def test_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) - self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, "café") - - def test_fragmented_binary(self): - self.receive_frame(Frame(False, OP_BINARY, b"t")) - self.receive_frame(Frame(False, OP_CONT, b"e")) - self.receive_frame(Frame(True, OP_CONT, b"a")) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, b"tea") - - def test_fragmented_text_payload_too_big(self): - self.protocol.max_size = 1024 - self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) - self.process_invalid_frames() - self.assertConnectionFailed(1009, "") - - def test_fragmented_binary_payload_too_big(self): - self.protocol.max_size = 1024 - self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171)) - self.receive_frame(Frame(True, OP_CONT, b"tea" * 171)) - self.process_invalid_frames() - self.assertConnectionFailed(1009, "") - - def test_fragmented_text_no_max_size(self): - self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, "café" * 205) - - def test_fragmented_binary_no_max_size(self): - self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171)) - self.receive_frame(Frame(True, OP_CONT, b"tea" * 171)) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, b"tea" * 342) - - def test_control_frame_within_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) - self.receive_frame(Frame(True, OP_PING, b"")) - self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) - data = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(data, "café") - self.assertOneFrameSent(True, OP_PONG, b"") - - def test_unterminated_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) - # Missing the second part of the fragmented frame. - self.receive_frame(Frame(True, OP_BINARY, b"tea")) - self.process_invalid_frames() - self.assertConnectionFailed(1002, "") - - def test_close_handshake_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) - self.receive_frame(Frame(True, OP_CLOSE, b"")) - self.process_invalid_frames() - # The RFC may have overlooked this case: it says that control frames - # can be interjected in the middle of a fragmented message and that a - # close frame must be echoed. Even though there's an unterminated - # message, technically, the closing handshake was successful. - self.assertConnectionClosed(1005, "") - - def test_connection_close_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) - self.process_invalid_frames() - self.assertConnectionFailed(1006, "") - - # Test miscellaneous code paths to ensure full coverage. - - def test_connection_lost(self): - # Test calling connection_lost without going through close_connection. - self.protocol.connection_lost(None) - - self.assertConnectionFailed(1006, "") - - def test_ensure_open_before_opening_handshake(self): - # Simulate a bug by forcibly reverting the protocol state. - self.protocol.state = State.CONNECTING - - with self.assertRaises(InvalidState): - self.loop.run_until_complete(self.protocol.ensure_open()) - - def test_ensure_open_during_unclean_close(self): - # Process connection_made in order to start transfer_data_task. - self.run_loop_once() - - # Ensure the test terminates quickly. - self.loop.call_later(MS, self.receive_eof_if_client) - - # Simulate the case when close() times out sending a close frame. - self.protocol.fail_connection() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.ensure_open()) - - def test_legacy_recv(self): - # By default legacy_recv in disabled. - self.assertEqual(self.protocol.legacy_recv, False) - - self.close_connection() - - # Enable legacy_recv. - self.protocol.legacy_recv = True - - # Now recv() returns None instead of raising ConnectionClosed. - self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) - - def test_connection_closed_attributes(self): - self.close_connection() - - with self.assertRaises(ConnectionClosed) as context: - self.loop.run_until_complete(self.protocol.recv()) - - connection_closed_exc = context.exception - self.assertEqual(connection_closed_exc.code, 1000) - self.assertEqual(connection_closed_exc.reason, "close") - - # Test the protocol logic for sending keepalive pings. - - def restart_protocol_with_keepalive_ping( - self, ping_interval=3 * MS, ping_timeout=3 * MS - ): - initial_protocol = self.protocol - # copied from tearDown - self.transport.close() - self.loop.run_until_complete(self.protocol.close()) - # copied from setUp, but enables keepalive pings - self.protocol = WebSocketCommonProtocol( - ping_interval=ping_interval, ping_timeout=ping_timeout - ) - self.transport = TransportMock() - self.transport.setup_mock(self.loop, self.protocol) - self.protocol.is_client = initial_protocol.is_client - self.protocol.side = initial_protocol.side - - def test_keepalive_ping(self): - self.restart_protocol_with_keepalive_ping() - - # Ping is sent at 3ms and acknowledged at 4ms. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - (ping_1,) = tuple(self.protocol.pings) - self.assertOneFrameSent(True, OP_PING, ping_1) - self.receive_frame(Frame(True, OP_PONG, ping_1)) - - # Next ping is sent at 7ms. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - (ping_2,) = tuple(self.protocol.pings) - self.assertOneFrameSent(True, OP_PING, ping_2) - - # The keepalive ping task goes on. - self.assertFalse(self.protocol.keepalive_ping_task.done()) - - def test_keepalive_ping_not_acknowledged_closes_connection(self): - self.restart_protocol_with_keepalive_ping() - - # Ping is sent at 3ms and not acknowleged. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - (ping_1,) = tuple(self.protocol.pings) - self.assertOneFrameSent(True, OP_PING, ping_1) - - # Connection is closed at 6ms. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1011, "")) - - # The keepalive ping task is complete. - self.assertEqual(self.protocol.keepalive_ping_task.result(), None) - - def test_keepalive_ping_stops_when_connection_closing(self): - self.restart_protocol_with_keepalive_ping() - close_task = self.half_close_connection_local() - - # No ping sent at 3ms because the closing handshake is in progress. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - self.assertNoFrameSent() - - # The keepalive ping task terminated. - self.assertTrue(self.protocol.keepalive_ping_task.cancelled()) - - self.loop.run_until_complete(close_task) # cleanup - - def test_keepalive_ping_stops_when_connection_closed(self): - self.restart_protocol_with_keepalive_ping() - self.close_connection() - - # The keepalive ping task terminated. - self.assertTrue(self.protocol.keepalive_ping_task.cancelled()) - - def test_keepalive_ping_does_not_crash_when_connection_lost(self): - self.restart_protocol_with_keepalive_ping() - # Clog incoming queue. This lets connection_lost() abort pending pings - # with a ConnectionClosed exception before transfer_data_task - # terminates and close_connection cancels keepalive_ping_task. - self.protocol.max_queue = 1 - self.receive_frame(Frame(True, OP_TEXT, b"1")) - self.receive_frame(Frame(True, OP_TEXT, b"2")) - # Ping is sent at 3ms. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - (ping_waiter,) = tuple(self.protocol.pings.values()) - # Connection drops. - self.receive_eof() - self.loop.run_until_complete(self.protocol.wait_closed()) - - # The ping waiter receives a ConnectionClosed exception. - with self.assertRaises(ConnectionClosed): - ping_waiter.result() - # The keepalive ping task terminated properly. - self.assertIsNone(self.protocol.keepalive_ping_task.result()) - - # Unclog incoming queue to terminate the test quickly. - self.loop.run_until_complete(self.protocol.recv()) - self.loop.run_until_complete(self.protocol.recv()) - - def test_keepalive_ping_with_no_ping_interval(self): - self.restart_protocol_with_keepalive_ping(ping_interval=None) - - # No ping is sent at 3ms. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - self.assertNoFrameSent() - - def test_keepalive_ping_with_no_ping_timeout(self): - self.restart_protocol_with_keepalive_ping(ping_timeout=None) - - # Ping is sent at 3ms and not acknowleged. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - (ping_1,) = tuple(self.protocol.pings) - self.assertOneFrameSent(True, OP_PING, ping_1) - - # Next ping is sent at 7ms anyway. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - ping_1_again, ping_2 = tuple(self.protocol.pings) - self.assertEqual(ping_1, ping_1_again) - self.assertOneFrameSent(True, OP_PING, ping_2) - - # The keepalive ping task goes on. - self.assertFalse(self.protocol.keepalive_ping_task.done()) - - def test_keepalive_ping_unexpected_error(self): - self.restart_protocol_with_keepalive_ping() - - async def ping(): - raise Exception("BOOM") - - self.protocol.ping = ping - - # The keepalive ping task fails when sending a ping at 3ms. - self.loop.run_until_complete(asyncio.sleep(4 * MS)) - - # The keepalive ping task is complete. - # It logs and swallows the exception. - self.assertEqual(self.protocol.keepalive_ping_task.result(), None) - - # Test the protocol logic for closing the connection. - - def test_local_close(self): - # Emulate how the remote endpoint answers the closing handshake. - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.call_later(MS, self.receive_eof_if_client) - - # Run the closing handshake. - self.loop.run_until_complete(self.protocol.close(reason="close")) - - self.assertConnectionClosed(1000, "close") - self.assertOneFrameSent(*self.close_frame) - - # Closing the connection again is a no-op. - self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - - self.assertConnectionClosed(1000, "close") - self.assertNoFrameSent() - - def test_remote_close(self): - # Emulate how the remote endpoint initiates the closing handshake. - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.call_later(MS, self.receive_eof_if_client) - - # Wait for some data in order to process the handshake. - # After recv() raises ConnectionClosed, the connection is closed. - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(self.protocol.recv()) - - self.assertConnectionClosed(1000, "close") - self.assertOneFrameSent(*self.close_frame) - - # Closing the connection again is a no-op. - self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - - self.assertConnectionClosed(1000, "close") - self.assertNoFrameSent() - - def test_remote_close_and_connection_lost(self): - self.make_drain_slow() - # Drop the connection right after receiving a close frame, - # which prevents echoing the close frame properly. - self.receive_frame(self.close_frame) - self.receive_eof() - - with self.assertNoLogs(): - self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - - self.assertConnectionClosed(1000, "close") - self.assertOneFrameSent(*self.close_frame) - - def test_simultaneous_close(self): - # Receive the incoming close frame right after self.protocol.close() - # starts executing. This reproduces the error described in: - # https://github.com/aaugustin/websockets/issues/339 - self.loop.call_soon(self.receive_frame, self.remote_close) - self.loop.call_soon(self.receive_eof_if_client) - - self.loop.run_until_complete(self.protocol.close(reason="local")) - - self.assertConnectionClosed(1000, "remote") - # The current implementation sends a close frame in response to the - # close frame received from the remote end. It skips the close frame - # that should be sent as a result of calling close(). - self.assertOneFrameSent(*self.remote_close) - - def test_close_preserves_incoming_frames(self): - self.receive_frame(Frame(True, OP_TEXT, b"hello")) - - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.call_later(MS, self.receive_eof_if_client) - self.loop.run_until_complete(self.protocol.close(reason="close")) - - self.assertConnectionClosed(1000, "close") - self.assertOneFrameSent(*self.close_frame) - - next_message = self.loop.run_until_complete(self.protocol.recv()) - self.assertEqual(next_message, "hello") - - def test_close_protocol_error(self): - invalid_close_frame = Frame(True, OP_CLOSE, b"\x00") - self.receive_frame(invalid_close_frame) - self.receive_eof_if_client() - self.run_loop_once() - self.loop.run_until_complete(self.protocol.close(reason="close")) - - self.assertConnectionFailed(1002, "") - - def test_close_connection_lost(self): - self.receive_eof() - self.run_loop_once() - self.loop.run_until_complete(self.protocol.close(reason="close")) - - self.assertConnectionFailed(1006, "") - - def test_local_close_during_recv(self): - recv = self.loop.create_task(self.protocol.recv()) - - self.loop.call_later(MS, self.receive_frame, self.close_frame) - self.loop.call_later(MS, self.receive_eof_if_client) - - self.loop.run_until_complete(self.protocol.close(reason="close")) - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(recv) - - self.assertConnectionClosed(1000, "close") - - # There is no test_remote_close_during_recv because it would be identical - # to test_remote_close. - - def test_remote_close_during_send(self): - self.make_drain_slow() - send = self.loop.create_task(self.protocol.send("hello")) - - self.receive_frame(self.close_frame) - self.receive_eof() - - with self.assertRaises(ConnectionClosed): - self.loop.run_until_complete(send) - - self.assertConnectionClosed(1000, "close") - - # There is no test_local_close_during_send because this cannot really - # happen, considering that writes are serialized. - - -class ServerTests(CommonTests, AsyncioTestCase): - def setUp(self): - super().setUp() - self.protocol.is_client = False - self.protocol.side = "server" - - def test_local_close_send_close_frame_timeout(self): - self.protocol.close_timeout = 10 * MS - self.make_drain_slow(50 * MS) - # If we can't send a close frame, time out in 10ms. - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(9 * MS, 19 * MS): - self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") - - def test_local_close_receive_close_frame_timeout(self): - self.protocol.close_timeout = 10 * MS - # If the client doesn't send a close frame, time out in 10ms. - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(9 * MS, 19 * MS): - self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") - - def test_local_close_connection_lost_timeout_after_write_eof(self): - self.protocol.close_timeout = 10 * MS - # If the client doesn't close its side of the TCP connection after we - # half-close our side with write_eof(), time out in 10ms. - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(9 * MS, 19 * MS): - # HACK: disable write_eof => other end drops connection emulation. - self.transport._eof = True - self.receive_frame(self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") - - def test_local_close_connection_lost_timeout_after_close(self): - self.protocol.close_timeout = 10 * MS - # If the client doesn't close its side of the TCP connection after we - # half-close our side with write_eof() and close it with close(), time - # out in 20ms. - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(19 * MS, 29 * MS): - # HACK: disable write_eof => other end drops connection emulation. - self.transport._eof = True - # HACK: disable close => other end drops connection emulation. - self.transport._closing = True - self.receive_frame(self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") - - -class ClientTests(CommonTests, AsyncioTestCase): - def setUp(self): - super().setUp() - self.protocol.is_client = True - self.protocol.side = "client" - - def test_local_close_send_close_frame_timeout(self): - self.protocol.close_timeout = 10 * MS - self.make_drain_slow(50 * MS) - # If we can't send a close frame, time out in 20ms. - # - 10ms waiting for sending a close frame - # - 10ms waiting for receiving a half-close - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(19 * MS, 29 * MS): - self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") - - def test_local_close_receive_close_frame_timeout(self): - self.protocol.close_timeout = 10 * MS - # If the server doesn't send a close frame, time out in 20ms: - # - 10ms waiting for receiving a close frame - # - 10ms waiting for receiving a half-close - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(19 * MS, 29 * MS): - self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") - - def test_local_close_connection_lost_timeout_after_write_eof(self): - self.protocol.close_timeout = 10 * MS - # If the server doesn't half-close its side of the TCP connection - # after we send a close frame, time out in 20ms: - # - 10ms waiting for receiving a half-close - # - 10ms waiting for receiving a close after write_eof - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(19 * MS, 29 * MS): - # HACK: disable write_eof => other end drops connection emulation. - self.transport._eof = True - self.receive_frame(self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") - - def test_local_close_connection_lost_timeout_after_close(self): - self.protocol.close_timeout = 10 * MS - # If the client doesn't close its side of the TCP connection after we - # half-close our side with write_eof() and close it with close(), time - # out in 20ms. - # - 10ms waiting for receiving a half-close - # - 10ms waiting for receiving a close after write_eof - # - 10ms waiting for receiving a close after close - # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(29 * MS, 39 * MS): - # HACK: disable write_eof => other end drops connection emulation. - self.transport._eof = True - # HACK: disable close => other end drops connection emulation. - self.transport._closing = True - self.receive_frame(self.close_frame) - self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") +# Check that the legacy protocol module imports without an exception. +from websockets.protocol import * # noqa diff --git a/tests/utils.py b/tests/utils.py index 790d25687..ac891a0fd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,4 @@ -import asyncio -import contextlib import email.utils -import functools -import logging -import os -import time import unittest @@ -27,89 +21,3 @@ def assertGeneratorReturns(self, gen): with self.assertRaises(StopIteration) as raised: next(gen) return raised.exception.value - - -class AsyncioTestCase(unittest.TestCase): - """ - Base class for tests that sets up an isolated event loop for each test. - - """ - - def __init_subclass__(cls, **kwargs): - """ - Convert test coroutines to test functions. - - This supports asychronous tests transparently. - - """ - super().__init_subclass__(**kwargs) - for name in unittest.defaultTestLoader.getTestCaseNames(cls): - test = getattr(cls, name) - if asyncio.iscoroutinefunction(test): - setattr(cls, name, cls.convert_async_to_sync(test)) - - @staticmethod - def convert_async_to_sync(test): - """ - Convert a test coroutine to a test function. - - """ - - @functools.wraps(test) - def test_func(self, *args, **kwargs): - return self.loop.run_until_complete(test(self, *args, **kwargs)) - - return test_func - - def setUp(self): - super().setUp() - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - super().tearDown() - - def run_loop_once(self): - # Process callbacks scheduled with call_soon by appending a callback - # to stop the event loop then running it until it hits that callback. - self.loop.call_soon(self.loop.stop) - self.loop.run_forever() - - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - - def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): - """ - Check recorded deprecation warnings match a list of expected messages. - - """ - self.assertEqual(len(recorded_warnings), len(expected_warnings)) - for recorded, expected in zip(recorded_warnings, expected_warnings): - actual = recorded.message - self.assertEqual(str(actual), expected) - self.assertEqual(type(actual), DeprecationWarning) - - -# Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) - -# asyncio's debug mode has a 10x performance penalty for this test suite. -if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover - MS *= 10 - -# Ensure that timeouts are larger than the clock's resolution (for Windows). -MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From 9a99229c671711d6274d3914244694e106966268 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Nov 2020 15:45:41 +0100 Subject: [PATCH 0735/1539] Explain backwards-compatibility & versioning policies. --- docs/changelog.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 291ec6938..2d2e7ca08 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,6 +3,23 @@ Changelog .. currentmodule:: websockets +Backwards-compatibility policy +.............................. + +``websockets`` is intended for production use. Therefore, stability is a goal. + +``websockets`` also aims at providing the best API for WebSocket in Python. + +While we value stability, we value progress more. When an improvement requires +changing the API, we make the change and document it below. + +When possible with reasonable effort, we preserve backwards-compatibility for +five years after the release that introduced the change. + +When a release contains backwards-incompatible API changes, the major version +is increased, else the minor version is increased. Patch versions are only for +fixing regressions shortly after a release. + 9.0 ... From 9c14a2f981af2da3517564ea7396ea06e19114d3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Nov 2020 18:02:12 +0100 Subject: [PATCH 0736/1539] Review and update changelog. * Add missing items for 9.0 release. * Re-assess infos / warnings. * Add release dates. --- docs/changelog.rst | 276 ++++++++++++++++++++++++++++----------------- 1 file changed, 174 insertions(+), 102 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 2d2e7ca08..8d255fdfd 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -29,36 +29,54 @@ fixing regressions shortly after a release. **Version 9.0 moves or deprecates several APIs.** - * Import :class:`~datastructures.Headers` and - :exc:`~datastructures.MultipleValuesError` from - :mod:`websockets.datastructures` instead of :mod:`websockets.http`. + * :class:`~datastructures.Headers` and + :exc:`~datastructures.MultipleValuesError` were moved from + ``websockets.http`` to :mod:`websockets.datastructures`. - * :mod:`websockets.client`, :mod:`websockets.server,` - :mod:`websockets.protocol`, and :mod:`websockets.auth` were moved to - :mod:`websockets.legacy.client`, :mod:`websockets.legacy.server`, - :mod:`websockets.legacy.protocol`, and :mod:`websockets.legacy.auth` - respectively. + * ``websockets.client``, ``websockets.server``, ``websockets.protocol``, + and ``websockets.auth`` were moved to :mod:`websockets.legacy.client`, + :mod:`websockets.legacy.server`, :mod:`websockets.legacy.protocol`, and + :mod:`websockets.legacy.auth` respectively. - * :mod:`websockets.handshake` is deprecated. + * ``websockets.handshake`` is deprecated. - * :mod:`websockets.http` is deprecated. + * ``websockets.http`` is deprecated. - * :mod:`websockets.framing` is deprecated. + * ``websockets.framing`` is deprecated. Aliases provide backwards compatibility for all previously public APIs. +* Added compatibility with Python 3.9. + * Added support for IRIs in addition to URIs. +* Added close codes 1012, 1013, and 1014. + +* Raised an error when passing a :class:`dict` to + :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. + +* Fixed ``Host`` header sent when connecting to an IPv6 address. + +* Aligned maximum cookie size with popular web browsers. + +* Ensured cancellation always propagates, even on Python versions where + :exc:`~asyncio.CancelledError` inherits :exc:`Exception`. + +* Improved error reporting. + + 8.1 ... -* Added compatibility with Python 3.8. +*November 1, 2019* -* Added close codes 1012, 1013, and 1014. +* Added compatibility with Python 3.8. 8.0.2 ..... +*July 31, 2019* + * Restored the ability to pass a socket with the ``sock`` parameter of :func:`~legacy.server.serve`. @@ -67,12 +85,16 @@ fixing regressions shortly after a release. 8.0.1 ..... +*July 21, 2019* + * Restored the ability to import ``WebSocketProtocolError`` from ``websockets``. 8.0 ... +*July 7, 2019* + .. warning:: **Version 8.0 drops compatibility with Python 3.4 and 3.5.** @@ -83,7 +105,8 @@ fixing regressions shortly after a release. Previously, it could be a function or a coroutine. - If you're passing a ``process_request`` argument to :func:`~legacy.server.serve` + If you're passing a ``process_request`` argument to + :func:`~legacy.server.serve` or :class:`~legacy.server.WebSocketServerProtocol`, or if you're overriding :meth:`~legacy.server.WebSocketServerProtocol.process_request` in a subclass, define it with ``async def`` instead of ``def``. @@ -103,36 +126,38 @@ fixing regressions shortly after a release. **Version 8.0 deprecates the** ``host`` **,** ``port`` **, and** ``secure`` **attributes of** :class:`~legacy.protocol.WebSocketCommonProtocol`. - Use :attr:`~legacy.protocol.WebSocketCommonProtocol.local_address` in servers and + Use :attr:`~legacy.protocol.WebSocketCommonProtocol.local_address` in + servers and :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` in clients instead of ``host`` and ``port``. .. note:: **Version 8.0 renames the** ``WebSocketProtocolError`` **exception** - to :exc:`ProtocolError` **.** + to :exc:`~exceptions.ProtocolError` **.** A ``WebSocketProtocolError`` alias provides backwards compatibility. .. note:: **Version 8.0 adds the reason phrase to the return type of the low-level - API** :func:`~http.read_response` **.** + API** ``read_response()`` **.** Also: * :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support bytes-like types - :class:`bytearray` and :class:`memoryview` in addition to :class:`bytes`. + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support bytes-like + types :class:`bytearray` and :class:`memoryview` in addition to + :class:`bytes`. * Added :exc:`~exceptions.ConnectionClosedOK` and :exc:`~exceptions.ConnectionClosedError` subclasses of :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth - on the server side. +* Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP + Basic Auth on the server side. * :func:`~legacy.client.connect` handles redirects from the server during the handshake. @@ -148,8 +173,9 @@ Also: exceptions in keepalive ping task. If you were using ``ping_timeout=None`` as a workaround, you can remove it. -* Changed :meth:`WebSocketServer.close() ` to - perform a proper closing handshake instead of failing the connection. +* Changed :meth:`WebSocketServer.close() + ` to perform a proper closing handshake + instead of failing the connection. * Avoided a crash when a ``extra_headers`` callable returns ``None``. @@ -170,20 +196,20 @@ Also: 7.0 ... -.. warning:: +*November 1, 2018* - **Version 7.0 renames the** ``timeout`` **argument of** - :func:`~legacy.server.serve()` **and** :func:`~legacy.client.connect` **to** - ``close_timeout`` **.** +.. warning:: - This prevents confusion with ``ping_timeout``. + ``websockets`` **now sends Ping frames at regular intervals and closes the + connection if it doesn't receive a matching Pong frame.** - For backwards compatibility, ``timeout`` is still supported. + See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. .. warning:: - **Version 7.0 changes how a server terminates connections when it's - closed with** :meth:`~legacy.server.WebSocketServer.close` **.** + **Version 7.0 changes how a server terminates connections when it's closed + with** :meth:`WebSocketServer.close() + ` **.** Previously, connections handlers were canceled. Now, connections are closed with close code 1001 (going away). From the perspective of the @@ -200,8 +226,19 @@ Also: .. note:: - **Version 7.0 changes how a** :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` - **that hasn't received a pong yet behaves when the connection is closed.** + **Version 7.0 renames the** ``timeout`` **argument of** + :func:`~legacy.server.serve` **and** :func:`~legacy.client.connect` **to** + ``close_timeout`` **.** + + This prevents confusion with ``ping_timeout``. + + For backwards compatibility, ``timeout`` is still supported. + +.. note:: + + **Version 7.0 changes how a** + :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` **that hasn't + received a pong yet behaves when the connection is closed.** The ping — as in ``ping = await websocket.ping()`` — used to be canceled when the connection is closed, so that ``await ping`` raised @@ -211,34 +248,33 @@ Also: .. note:: **Version 7.0 raises a** :exc:`RuntimeError` **exception if two coroutines - call** :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` **concurrently.** + call** :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` + **concurrently.** Concurrent calls lead to non-deterministic behavior because there are no guarantees about which coroutine will receive which message. Also: -* ``websockets`` sends Ping frames at regular intervals and closes the - connection if it doesn't receive a matching Pong frame. See - :class:`~legacy.protocol.WebSocketCommonProtocol` for details. - * Added ``process_request`` and ``select_subprotocol`` arguments to - :func:`~legacy.server.serve` and :class:`~legacy.server.WebSocketServerProtocol` to - customize :meth:`~legacy.server.WebSocketServerProtocol.process_request` and + :func:`~legacy.server.serve` and + :class:`~legacy.server.WebSocketServerProtocol` to customize + :meth:`~legacy.server.WebSocketServerProtocol.process_request` and :meth:`~legacy.server.WebSocketServerProtocol.select_subprotocol` without subclassing :class:`~legacy.server.WebSocketServerProtocol`. * Added support for sending fragmented messages. -* Added the :meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed` method to - protocols. +* Added the :meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed` + method to protocols. * Added an interactive client: ``python -m websockets ``. * Changed the ``origins`` argument to represent the lack of an origin with ``None`` rather than ``''``. -* Fixed a data loss bug in :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: +* Fixed a data loss bug in + :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: canceling it at the wrong time could result in messages being dropped. * Improved handling of multiple HTTP headers with the same name. @@ -248,36 +284,37 @@ Also: 6.0 ... +*July 16, 2018* + .. warning:: - **Version 6.0 introduces the** :class:`~http.Headers` **class for managing - HTTP headers and changes several public APIs:** + **Version 6.0 introduces the** :class:`~datastructures.Headers` **class + for managing HTTP headers and changes several public APIs:** - * :meth:`~legacy.server.WebSocketServerProtocol.process_request` now receives a - :class:`~http.Headers` instead of a :class:`~http.client.HTTPMessage` in - the ``request_headers`` argument. + * :meth:`~legacy.server.WebSocketServerProtocol.process_request` now + receives a :class:`~datastructures.Headers` instead of a + ``http.client.HTTPMessage`` in the ``request_headers`` argument. - * The :attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and - :attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` attributes of - :class:`~legacy.protocol.WebSocketCommonProtocol` are :class:`~http.Headers` - instead of :class:`~http.client.HTTPMessage`. + * The ``request_headers`` and ``response_headers`` attributes of + :class:`~legacy.protocol.WebSocketCommonProtocol` are + :class:`~datastructures.Headers` instead of ``http.client.HTTPMessage``. - * The :attr:`~legacy.protocol.WebSocketCommonProtocol.raw_request_headers` and - :attr:`~legacy.protocol.WebSocketCommonProtocol.raw_response_headers` - attributes of :class:`~legacy.protocol.WebSocketCommonProtocol` are removed. - Use :meth:`~http.Headers.raw_items` instead. + * The ``raw_request_headers`` and ``raw_response_headers`` attributes of + :class:`~legacy.protocol.WebSocketCommonProtocol` are removed. Use + :meth:`~datastructures.Headers.raw_items` instead. - * Functions defined in the :mod:`~handshake` module now receive - :class:`~http.Headers` in argument instead of ``get_header`` or - ``set_header`` functions. This affects libraries that rely on + * Functions defined in the ``handshake`` module now receive + :class:`~datastructures.Headers` in argument instead of ``get_header`` + or ``set_header`` functions. This affects libraries that rely on low-level APIs. - * Functions defined in the :mod:`~http` module now return HTTP headers as - :class:`~http.Headers` instead of lists of ``(name, value)`` pairs. + * Functions defined in the ``http`` module now return HTTP headers as + :class:`~datastructures.Headers` instead of lists of ``(name, value)`` + pairs. - Since :class:`~http.Headers` and :class:`~http.client.HTTPMessage` provide - similar APIs, this change won't affect most of the code dealing with HTTP - headers. + Since :class:`~datastructures.Headers` and ``http.client.HTTPMessage`` + provide similar APIs, this change won't affect most of the code dealing + with HTTP headers. Also: @@ -287,12 +324,16 @@ Also: 5.0.1 ..... -* Fixed a regression in the 5.0 release that broke some invocations of - :func:`~legacy.server.serve()` and :func:`~legacy.client.connect`. +*May 24, 2018* + +* Fixed a regression in 5.0 that broke some invocations of + :func:`~legacy.server.serve` and :func:`~legacy.client.connect`. 5.0 ... +*May 22, 2018* + .. note:: **Version 5.0 fixes a security issue introduced in version 4.0.** @@ -308,8 +349,8 @@ Also: **Version 5.0 adds a** ``user_info`` **field to the return value of** :func:`~uri.parse_uri` **and** :class:`~uri.WebSocketURI` **.** - If you're unpacking :class:`~exceptions.WebSocketURI` into four variables, - adjust your code to account for that fifth field. + If you're unpacking :class:`~uri.WebSocketURI` into four variables, adjust + your code to account for that fifth field. Also: @@ -322,14 +363,14 @@ Also: * A plain HTTP request now receives a 426 Upgrade Required response and doesn't log a stack trace. -* :func:`~legacy.server.unix_serve` can be used as an asynchronous context manager on - Python ≥ 3.5.1. +* :func:`~legacy.server.unix_serve` can be used as an asynchronous context + manager on Python ≥ 3.5.1. -* Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property to - protocols. +* Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property + to protocols. -* If a :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` doesn't receive a pong, - it's canceled when the connection is closed. +* If a :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` doesn't receive a + pong, it's canceled when the connection is closed. * Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. @@ -355,13 +396,21 @@ Also: 4.0.1 ..... +*November 2, 2017* + * Fixed issues with the packaging of the 4.0 release. 4.0 ... +*November 2, 2017* + .. warning:: + **Version 4.0 drops compatibility with Python 3.3.** + +.. note:: + **Version 4.0 enables compression with the permessage-deflate extension.** In August 2017, Firefox and Chrome support it, but not Safari and IE. @@ -369,11 +418,7 @@ Also: Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~legacy.server.serve()` or :func:`~legacy.client.connect`. - -.. warning:: - - **Version 4.0 drops compatibility with Python 3.3.** + :func:`~legacy.server.serve` or :func:`~legacy.client.connect`. .. note:: @@ -388,8 +433,8 @@ Also: * Added :func:`~legacy.server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the return - value of :func:`~legacy.server.serve`. +* Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the + return value of :func:`~legacy.server.serve`. * Reorganized and extended documentation. @@ -407,12 +452,14 @@ Also: 3.4 ... -* Renamed :func:`~legacy.server.serve()` and :func:`~legacy.client.connect`'s ``klass`` - argument to ``create_protocol`` to reflect that it can also be a callable. - For backwards compatibility, ``klass`` is still supported. +*August 20, 2017* + +* Renamed :func:`~legacy.server.serve` and :func:`~legacy.client.connect`'s + ``klass`` argument to ``create_protocol`` to reflect that it can also be a + callable. For backwards compatibility, ``klass`` is still supported. -* :func:`~legacy.server.serve` can be used as an asynchronous context manager on - Python ≥ 3.5.1. +* :func:`~legacy.server.serve` can be used as an asynchronous context manager + on Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with :meth:`~legacy.server.WebSocketServerProtocol.process_request`. @@ -423,8 +470,8 @@ Also: * Added an optional C extension to speed up low-level operations. -* An invalid response status code during :func:`~legacy.client.connect` now raises - :class:`~exceptions.InvalidStatusCode` with a ``code`` attribute. +* An invalid response status code during :func:`~legacy.client.connect` now + raises :class:`~exceptions.InvalidStatusCode` with a ``code`` attribute. * Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer crashes. @@ -432,6 +479,8 @@ Also: 3.3 ... +*March 29, 2017* + * Ensured compatibility with Python 3.6. * Reduced noise in logs caused by connection resets. @@ -441,14 +490,18 @@ Also: 3.2 ... +*August 17, 2016* + * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~legacy.client.connect()` and :func:`~legacy.server.serve`. + :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. * Made server shutdown more robust. 3.1 ... +*April 21, 2016* + * Avoided a warning when closing a connection before the opening handshake. * Added flow control for incoming data. @@ -456,6 +509,8 @@ Also: 3.0 ... +*December 25, 2015* + .. warning:: **Version 3.0 introduces a backwards-incompatible change in the** @@ -463,9 +518,9 @@ Also: **If you're upgrading from 2.x or earlier, please read this carefully.** - :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` used to return ``None`` - when the connection was closed. This required checking the return value of - every call:: + :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` used to return + ``None`` when the connection was closed. This required checking the return + value of every call:: message = await websocket.recv() if message is None: @@ -484,13 +539,13 @@ Also: previous behavior can be restored by passing ``legacy_recv=True`` to :func:`~legacy.server.serve`, :func:`~legacy.client.connect`, :class:`~legacy.server.WebSocketServerProtocol`, or - :class:`~legacy.client.WebSocketClientProtocol`. ``legacy_recv`` isn't documented - in their signatures but isn't scheduled for deprecation either. + :class:`~legacy.client.WebSocketClientProtocol`. ``legacy_recv`` isn't + documented in their signatures but isn't scheduled for deprecation either. Also: -* :func:`~legacy.client.connect` can be used as an asynchronous context manager on - Python ≥ 3.5.1. +* :func:`~legacy.client.connect` can be used as an asynchronous context + manager on Python ≥ 3.5.1. * Updated documentation with ``await`` and ``async`` syntax from Python 3.5. @@ -498,7 +553,8 @@ Also: :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support data passed as :class:`str` in addition to :class:`bytes`. -* Worked around an asyncio bug affecting connection termination under load. +* Worked around an :mod:`asyncio` bug affecting connection termination under + load. * Made ``state_name`` attribute on protocols a public API. @@ -507,6 +563,8 @@ Also: 2.7 ... +*November 18, 2015* + * Added compatibility with Python 3.5. * Refreshed documentation. @@ -514,6 +572,8 @@ Also: 2.6 ... +*August 18, 2015* + * Added ``local_address`` and ``remote_address`` attributes on protocols. * Closed open connections with code 1001 when a server shuts down. @@ -523,19 +583,21 @@ Also: 2.5 ... +*July 28, 2015* + * Improved documentation. * Provided access to handshake request and response HTTP headers. * Allowed customizing handshake request and response HTTP headers. -* Supported running on a non-default event loop. +* Added support for running on a non-default event loop. * Returned a 403 status code instead of 400 when the request Origin isn't allowed. -* Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer drops - the next message. +* Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer + drops the next message. * Clarified that the closing handshake can be initiated by the client. @@ -548,9 +610,9 @@ Also: 2.4 ... -* Added support for subprotocols. +*January 31, 2015* -* Supported non-default event loop. +* Added support for subprotocols. * Added ``loop`` argument to :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. @@ -558,16 +620,22 @@ Also: 2.3 ... +*November 3, 2014* + * Improved compliance of close codes. 2.2 ... +*July 28, 2014* + * Added support for limiting message size. 2.1 ... +*April 26, 2014* + * Added ``host``, ``port`` and ``secure`` attributes on protocols. * Added support for providing and checking Origin_. @@ -577,6 +645,8 @@ Also: 2.0 ... +*February 16, 2014* + .. warning:: **Version 2.0 introduces a backwards-incompatible change in the** @@ -603,4 +673,6 @@ Also: 1.0 ... +*November 14, 2013* + * Initial public release. From 94256f4f41ef024f7f511a573763bd755f5f1b46 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Nov 2020 18:02:27 +0100 Subject: [PATCH 0737/1539] Update word list for spell check. --- docs/spelling_wordlist.txt | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1eacc491d..dd3500b73 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -5,14 +5,21 @@ awaitable aymeric backpressure Backpressure +balancer +balancers Bitcoin +bottlenecked bufferbloat Bufferbloat bugfix bytestring bytestrings changelog +coroutine +coroutines +cryptocurrencies cryptocurrency +Ctrl daemonize fractalideas iterable @@ -20,18 +27,25 @@ keepalive KiB lifecycle Lifecycle +lookups MiB nginx +parsers permessage pong pongs Pythonic serializers +Subclasses +subclasses subclassing subprotocol subprotocols +Tidelift TLS +tox Unparse +unregister uple username websocket From 42f0e2c0b8e994c33b792208adff32bea1cdff4f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 29 Nov 2020 22:01:42 +0100 Subject: [PATCH 0738/1539] Add helper to manage aliases and deprecations. This may save a little bit of CPU and memory by avoiding unnecessary imports too, especially as the library grows. --- src/websockets/__init__.py | 72 +++++++++++++++++++++++----- src/websockets/auth.py | 4 -- src/websockets/client.py | 11 ++++- src/websockets/framing.py | 6 --- src/websockets/handshake.py | 45 ------------------ src/websockets/http.py | 45 +++++++----------- src/websockets/imports.py | 95 +++++++++++++++++++++++++++++++++++++ src/websockets/protocol.py | 1 - src/websockets/server.py | 18 ++++--- tests/test_auth.py | 2 - tests/test_exports.py | 9 ++++ tests/test_framing.py | 9 ---- tests/test_handshake.py | 2 - tests/test_imports.py | 53 +++++++++++++++++++++ tests/test_protocol.py | 2 - 15 files changed, 256 insertions(+), 118 deletions(-) delete mode 100644 src/websockets/auth.py delete mode 100644 src/websockets/framing.py delete mode 100644 src/websockets/handshake.py create mode 100644 src/websockets/imports.py delete mode 100644 src/websockets/protocol.py delete mode 100644 tests/test_auth.py delete mode 100644 tests/test_framing.py delete mode 100644 tests/test_handshake.py create mode 100644 tests/test_imports.py delete mode 100644 tests/test_protocol.py diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 0242e7942..580a3960f 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,19 +1,8 @@ -# This relies on each of the submodules having an __all__ variable. - -from .client import * -from .datastructures import * # noqa -from .exceptions import * # noqa -from .legacy.auth import * # noqa -from .legacy.client import * # noqa -from .legacy.protocol import * # noqa -from .legacy.server import * # noqa -from .server import * -from .typing import * # noqa -from .uri import * # noqa +from .imports import lazy_import from .version import version as __version__ # noqa -__all__ = [ +__all__ = [ # noqa "AbortHandshake", "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", @@ -58,3 +47,60 @@ "WebSocketServerProtocol", "WebSocketURI", ] + +lazy_import( + globals(), + aliases={ + "auth": ".legacy", + "basic_auth_protocol_factory": ".legacy.auth", + "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "ClientConnection": ".client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + "WebSocketClientProtocol": ".legacy.client", + "Headers": ".datastructures", + "MultipleValuesError": ".datastructures", + "WebSocketException": ".exceptions", + "ConnectionClosed": ".exceptions", + "ConnectionClosedError": ".exceptions", + "ConnectionClosedOK": ".exceptions", + "InvalidHandshake": ".exceptions", + "SecurityError": ".exceptions", + "InvalidMessage": ".exceptions", + "InvalidHeader": ".exceptions", + "InvalidHeaderFormat": ".exceptions", + "InvalidHeaderValue": ".exceptions", + "InvalidOrigin": ".exceptions", + "InvalidUpgrade": ".exceptions", + "InvalidStatusCode": ".exceptions", + "NegotiationError": ".exceptions", + "DuplicateParameter": ".exceptions", + "InvalidParameterName": ".exceptions", + "InvalidParameterValue": ".exceptions", + "AbortHandshake": ".exceptions", + "RedirectHandshake": ".exceptions", + "InvalidState": ".exceptions", + "InvalidURI": ".exceptions", + "PayloadTooBig": ".exceptions", + "ProtocolError": ".exceptions", + "WebSocketProtocolError": ".exceptions", + "protocol": ".legacy", + "WebSocketCommonProtocol": ".legacy.protocol", + "ServerConnection": ".server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "WebSocketServer": ".legacy.server", + "Data": ".typing", + "Origin": ".typing", + "ExtensionHeader": ".typing", + "ExtensionParameter": ".typing", + "Subprotocol": ".typing", + "parse_uri": ".uri", + "WebSocketURI": ".uri", + }, + deprecated_aliases={ + "framing": ".legacy", + "handshake": ".legacy", + }, +) diff --git a/src/websockets/auth.py b/src/websockets/auth.py deleted file mode 100644 index c8839c401..000000000 --- a/src/websockets/auth.py +++ /dev/null @@ -1,4 +0,0 @@ -from .legacy.auth import BasicAuthWebSocketServerProtocol, basic_auth_protocol_factory - - -__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] diff --git a/src/websockets/client.py b/src/websockets/client.py index 8cababed5..91dd1662e 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -24,7 +24,7 @@ ) from .http import USER_AGENT, build_host from .http11 import Request, Response -from .legacy.client import WebSocketClientProtocol, connect, unix_connect # noqa +from .imports import lazy_import from .typing import ( ConnectionOption, ExtensionHeader, @@ -36,6 +36,15 @@ from .utils import accept_key, generate_key +lazy_import( + globals(), + aliases={ + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + "WebSocketClientProtocol": ".legacy.client", + }, +) + __all__ = ["ClientConnection"] logger = logging.getLogger(__name__) diff --git a/src/websockets/framing.py b/src/websockets/framing.py deleted file mode 100644 index 2dadb5610..000000000 --- a/src/websockets/framing.py +++ /dev/null @@ -1,6 +0,0 @@ -import warnings - -from .legacy.framing import * # noqa - - -warnings.warn("websockets.framing is deprecated", DeprecationWarning) diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py deleted file mode 100644 index cc4010d41..000000000 --- a/src/websockets/handshake.py +++ /dev/null @@ -1,45 +0,0 @@ -import warnings - -from .datastructures import Headers - - -__all__ = ["build_request", "check_request", "build_response", "check_response"] - - -# Backwards compatibility with previously documented public APIs - - -def build_request(headers: Headers) -> str: # pragma: no cover - warnings.warn( - "websockets.handshake.build_request is deprecated", DeprecationWarning - ) - from .legacy.handshake import build_request - - return build_request(headers) - - -def check_request(headers: Headers) -> str: # pragma: no cover - warnings.warn( - "websockets.handshake.check_request is deprecated", DeprecationWarning - ) - from .legacy.handshake import check_request - - return check_request(headers) - - -def build_response(headers: Headers, key: str) -> None: # pragma: no cover - warnings.warn( - "websockets.handshake.build_response is deprecated", DeprecationWarning - ) - from .legacy.handshake import build_response - - return build_response(headers, key) - - -def check_response(headers: Headers, key: str) -> None: # pragma: no cover - warnings.warn( - "websockets.handshake.check_response is deprecated", DeprecationWarning - ) - from .legacy.handshake import check_response - - return check_response(headers, key) diff --git a/src/websockets/http.py b/src/websockets/http.py index b05b78455..9092836c2 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,15 +1,27 @@ -import asyncio import ipaddress import sys -import warnings -from typing import Tuple -# For backwards compatibility: -# Headers and MultipleValuesError used to be defined in this module -from .datastructures import Headers, MultipleValuesError # noqa +from .imports import lazy_import from .version import version as websockets_version +# For backwards compatibility: + + +lazy_import( + globals(), + # Headers and MultipleValuesError used to be defined in this module. + aliases={ + "Headers": ".datastructures", + "MultipleValuesError": ".datastructures", + }, + deprecated_aliases={ + "read_request": ".legacy.http", + "read_response": ".legacy.http", + }, +) + + __all__ = ["USER_AGENT", "build_host"] @@ -38,24 +50,3 @@ def build_host(host: str, port: int, secure: bool) -> str: host = f"{host}:{port}" return host - - -# Backwards compatibility with previously documented public APIs - - -async def read_request( - stream: asyncio.StreamReader, -) -> Tuple[str, Headers]: # pragma: no cover - warnings.warn("websockets.http.read_request is deprecated", DeprecationWarning) - from .legacy.http import read_request - - return await read_request(stream) - - -async def read_response( - stream: asyncio.StreamReader, -) -> Tuple[int, str, Headers]: # pragma: no cover - warnings.warn("websockets.http.read_response is deprecated", DeprecationWarning) - from .legacy.http import read_response - - return await read_response(stream) diff --git a/src/websockets/imports.py b/src/websockets/imports.py new file mode 100644 index 000000000..9a4cfd98a --- /dev/null +++ b/src/websockets/imports.py @@ -0,0 +1,95 @@ +import importlib +import sys +import warnings +from typing import Any, Dict, Iterable, Optional + + +__all__ = ["lazy_import"] + + +def lazy_import( + namespace: Dict[str, Any], + aliases: Optional[Dict[str, str]] = None, + deprecated_aliases: Optional[Dict[str, str]] = None, +) -> None: + """ + Provide lazy, module-level imports. + + Typical use:: + + __getattr__, __dir__ = lazy_import( + globals(), + aliases={ + "": "", + ... + }, + deprecated_aliases={ + ..., + } + ) + + This function defines __getattr__ and __dir__ per PEP 562. + + On Python 3.6 and earlier, it falls back to non-lazy imports and doesn't + raise deprecation warnings. + + """ + if aliases is None: + aliases = {} + if deprecated_aliases is None: + deprecated_aliases = {} + + namespace_set = set(namespace) + aliases_set = set(aliases) + deprecated_aliases_set = set(deprecated_aliases) + + assert not namespace_set & aliases_set, "namespace conflict" + assert not namespace_set & deprecated_aliases_set, "namespace conflict" + assert not aliases_set & deprecated_aliases_set, "namespace conflict" + + package = namespace["__name__"] + + if sys.version_info[:2] >= (3, 7): + + def __getattr__(name: str) -> Any: + assert aliases is not None # mypy cannot figure this out + try: + source = aliases[name] + except KeyError: + pass + else: + module = importlib.import_module(source, package) + return getattr(module, name) + + assert deprecated_aliases is not None # mypy cannot figure this out + try: + source = deprecated_aliases[name] + except KeyError: + pass + else: + warnings.warn( + f"{package}.{name} is deprecated", + DeprecationWarning, + stacklevel=2, + ) + module = importlib.import_module(source, package) + return getattr(module, name) + + raise AttributeError(f"module {package!r} has no attribute {name!r}") + + namespace["__getattr__"] = __getattr__ + + def __dir__() -> Iterable[str]: + return sorted(namespace_set | aliases_set | deprecated_aliases_set) + + namespace["__dir__"] = __dir__ + + else: # pragma: no cover + + for name, source in aliases.items(): + module = importlib.import_module(source, package) + namespace[name] = getattr(module, name) + + for name, source in deprecated_aliases.items(): + module = importlib.import_module(source, package) + namespace[name] = getattr(module, name) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py deleted file mode 100644 index 287f92a57..000000000 --- a/src/websockets/protocol.py +++ /dev/null @@ -1 +0,0 @@ -from .legacy.protocol import * # noqa diff --git a/src/websockets/server.py b/src/websockets/server.py index bd527be74..67ab83031 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -26,12 +26,7 @@ ) from .http import USER_AGENT from .http11 import Request, Response -from .legacy.server import ( # noqa - WebSocketServer, - WebSocketServerProtocol, - serve, - unix_serve, -) +from .imports import lazy_import from .typing import ( ConnectionOption, ExtensionHeader, @@ -42,6 +37,17 @@ from .utils import accept_key +lazy_import( + globals(), + aliases={ + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "WebSocketServer": ".legacy.server", + }, +) + + __all__ = ["ServerConnection"] logger = logging.getLogger(__name__) diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 01ca207c7..000000000 --- a/tests/test_auth.py +++ /dev/null @@ -1,2 +0,0 @@ -# Check that the legacy auth module imports without an exception. -from websockets.auth import * # noqa diff --git a/tests/test_exports.py b/tests/test_exports.py index 8e4330304..568c50c54 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -1,6 +1,15 @@ import unittest import websockets +import websockets.client +import websockets.exceptions +import websockets.legacy.auth +import websockets.legacy.client +import websockets.legacy.protocol +import websockets.legacy.server +import websockets.server +import websockets.typing +import websockets.uri combined_exports = ( diff --git a/tests/test_framing.py b/tests/test_framing.py deleted file mode 100644 index d6fa6352a..000000000 --- a/tests/test_framing.py +++ /dev/null @@ -1,9 +0,0 @@ -import warnings - - -with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "websockets.framing is deprecated", DeprecationWarning - ) - # Check that the legacy framing module imports without an exception. - from websockets.framing import * # noqa diff --git a/tests/test_handshake.py b/tests/test_handshake.py deleted file mode 100644 index 8c35c9714..000000000 --- a/tests/test_handshake.py +++ /dev/null @@ -1,2 +0,0 @@ -# Check that the legacy handshake module imports without an exception. -from websockets.handshake import * # noqa diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 000000000..113564e9f --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,53 @@ +import types +import unittest +import warnings + +from websockets.imports import * + + +foo = object() + +bar = object() + + +class ImportsTests(unittest.TestCase): + def test_get_alias(self): + mod = types.ModuleType("tests.test_imports.test_alias") + lazy_import(vars(mod), aliases={"foo": ".."}) + + self.assertEqual(mod.foo, foo) + + def test_get_deprecated_alias(self): + mod = types.ModuleType("tests.test_imports.test_alias") + lazy_import(vars(mod), deprecated_aliases={"bar": ".."}) + + with warnings.catch_warnings(record=True) as recorded_warnings: + self.assertEqual(mod.bar, bar) + + self.assertEqual(len(recorded_warnings), 1) + warning = recorded_warnings[0].message + self.assertEqual( + str(warning), "tests.test_imports.test_alias.bar is deprecated" + ) + self.assertEqual(type(warning), DeprecationWarning) + + def test_dir(self): + mod = types.ModuleType("tests.test_imports.test_alias") + lazy_import(vars(mod), aliases={"foo": ".."}, deprecated_aliases={"bar": ".."}) + + self.assertEqual( + [item for item in dir(mod) if not item[:2] == item[-2:] == "__"], + ["bar", "foo"], + ) + + def test_attribute_error(self): + mod = types.ModuleType("tests.test_imports.test_alias") + lazy_import(vars(mod)) + + with self.assertRaises(AttributeError) as raised: + mod.foo + + self.assertEqual( + str(raised.exception), + "module 'tests.test_imports.test_alias' has no attribute 'foo'", + ) diff --git a/tests/test_protocol.py b/tests/test_protocol.py deleted file mode 100644 index f896fcae4..000000000 --- a/tests/test_protocol.py +++ /dev/null @@ -1,2 +0,0 @@ -# Check that the legacy protocol module imports without an exception. -from websockets.protocol import * # noqa From 965f8ec77347adaaf23c82eef693c9882269b46c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Nov 2020 21:46:25 +0100 Subject: [PATCH 0739/1539] Fix lazy imports of objects on Python 3.6. --- src/websockets/imports.py | 34 +++++++++++++++++++++++++--------- tests/test_imports.py | 39 +++++++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/src/websockets/imports.py b/src/websockets/imports.py index 9a4cfd98a..efd3eabf3 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,4 +1,3 @@ -import importlib import sys import warnings from typing import Any, Dict, Iterable, Optional @@ -7,6 +6,27 @@ __all__ = ["lazy_import"] +def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: + """ + Import from in . + + There are two cases: + + - is an object defined in + - is a submodule of source + + Neither __import__ nor importlib.import_module does exactly this. + __import__ is closer to the intended behavior. + + """ + level = 0 + while source[level] == ".": + level += 1 + assert level < len(source), "importing from parent isn't supported" + module = __import__(source[level:], namespace, None, [name], level) + return getattr(module, name) + + def lazy_import( namespace: Dict[str, Any], aliases: Optional[Dict[str, str]] = None, @@ -58,8 +78,7 @@ def __getattr__(name: str) -> Any: except KeyError: pass else: - module = importlib.import_module(source, package) - return getattr(module, name) + return import_name(name, source, namespace) assert deprecated_aliases is not None # mypy cannot figure this out try: @@ -72,8 +91,7 @@ def __getattr__(name: str) -> Any: DeprecationWarning, stacklevel=2, ) - module = importlib.import_module(source, package) - return getattr(module, name) + return import_name(name, source, namespace) raise AttributeError(f"module {package!r} has no attribute {name!r}") @@ -87,9 +105,7 @@ def __dir__() -> Iterable[str]: else: # pragma: no cover for name, source in aliases.items(): - module = importlib.import_module(source, package) - namespace[name] = getattr(module, name) + namespace[name] = import_name(name, source, namespace) for name, source in deprecated_aliases.items(): - module = importlib.import_module(source, package) - namespace[name] = getattr(module, name) + namespace[name] = import_name(name, source, namespace) diff --git a/tests/test_imports.py b/tests/test_imports.py index 113564e9f..d84808902 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,3 +1,4 @@ +import sys import types import unittest import warnings @@ -11,18 +12,30 @@ class ImportsTests(unittest.TestCase): + def setUp(self): + self.mod = types.ModuleType("tests.test_imports.test_alias") + self.mod.__package__ = self.mod.__name__ + def test_get_alias(self): - mod = types.ModuleType("tests.test_imports.test_alias") - lazy_import(vars(mod), aliases={"foo": ".."}) + lazy_import( + vars(self.mod), + aliases={"foo": "...test_imports"}, + ) - self.assertEqual(mod.foo, foo) + self.assertEqual(self.mod.foo, foo) def test_get_deprecated_alias(self): - mod = types.ModuleType("tests.test_imports.test_alias") - lazy_import(vars(mod), deprecated_aliases={"bar": ".."}) + lazy_import( + vars(self.mod), + deprecated_aliases={"bar": "...test_imports"}, + ) with warnings.catch_warnings(record=True) as recorded_warnings: - self.assertEqual(mod.bar, bar) + self.assertEqual(self.mod.bar, bar) + + # No warnings raised on pre-PEP 526 Python. + if sys.version_info[:2] < (3, 7): # pragma: no cover + return self.assertEqual(len(recorded_warnings), 1) warning = recorded_warnings[0].message @@ -32,20 +45,22 @@ def test_get_deprecated_alias(self): self.assertEqual(type(warning), DeprecationWarning) def test_dir(self): - mod = types.ModuleType("tests.test_imports.test_alias") - lazy_import(vars(mod), aliases={"foo": ".."}, deprecated_aliases={"bar": ".."}) + lazy_import( + vars(self.mod), + aliases={"foo": "...test_imports"}, + deprecated_aliases={"bar": "...test_imports"}, + ) self.assertEqual( - [item for item in dir(mod) if not item[:2] == item[-2:] == "__"], + [item for item in dir(self.mod) if not item[:2] == item[-2:] == "__"], ["bar", "foo"], ) def test_attribute_error(self): - mod = types.ModuleType("tests.test_imports.test_alias") - lazy_import(vars(mod)) + lazy_import(vars(self.mod)) with self.assertRaises(AttributeError) as raised: - mod.foo + self.mod.foo self.assertEqual( str(raised.exception), From ecf64e7a56ee85e10a812139a4aee09e736aa241 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Nov 2020 22:36:21 +0100 Subject: [PATCH 0740/1539] Handle non-contiguous memoryviews in C extension. This avoids the special-case in Python code. --- src/websockets/frames.py | 11 ++------ src/websockets/speedups.c | 51 ++++++++++++++++++----------------- tests/legacy/test_protocol.py | 30 --------------------- tests/test_frames.py | 9 ------- tests/test_utils.py | 24 +++-------------- 5 files changed, 32 insertions(+), 93 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 74223c0e8..71783e176 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -263,13 +263,8 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: """ if isinstance(data, str): return OP_TEXT, data.encode("utf-8") - elif isinstance(data, (bytes, bytearray)): + elif isinstance(data, (bytes, bytearray, memoryview)): return OP_BINARY, data - elif isinstance(data, memoryview): - if data.c_contiguous: - return OP_BINARY, data - else: - return OP_BINARY, data.tobytes() else: raise TypeError("data must be bytes-like or str") @@ -290,10 +285,8 @@ def prepare_ctrl(data: Data) -> bytes: """ if isinstance(data, str): return data.encode("utf-8") - elif isinstance(data, (bytes, bytearray)): + elif isinstance(data, (bytes, bytearray, memoryview)): return bytes(data) - elif isinstance(data, memoryview): - return data.tobytes() else: raise TypeError("data must be bytes-like or str") diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index ede181e5d..fc328e528 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -13,39 +13,35 @@ static const Py_ssize_t MASK_LEN = 4; /* Similar to PyBytes_AsStringAndSize, but accepts more types */ static int -_PyBytesLike_AsStringAndSize(PyObject *obj, char **buffer, Py_ssize_t *length) +_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length) { - // This supports bytes, bytearrays, and C-contiguous memoryview objects, - // which are the most useful data structures for handling byte streams. - // websockets.framing.prepare_data() returns only values of these types. - // Any object implementing the buffer protocol could be supported, however - // that would require allocation or copying memory, which is expensive. + // This supports bytes, bytearrays, and memoryview objects, + // which are common data structures for handling byte streams. + // websockets.framing.prepare_data() returns only these types. + // If *tmp isn't NULL, the caller gets a new reference. if (PyBytes_Check(obj)) { + *tmp = NULL; *buffer = PyBytes_AS_STRING(obj); *length = PyBytes_GET_SIZE(obj); } else if (PyByteArray_Check(obj)) { + *tmp = NULL; *buffer = PyByteArray_AS_STRING(obj); *length = PyByteArray_GET_SIZE(obj); } else if (PyMemoryView_Check(obj)) { - Py_buffer *mv_buf; - mv_buf = PyMemoryView_GET_BUFFER(obj); - if (PyBuffer_IsContiguous(mv_buf, 'C')) - { - *buffer = mv_buf->buf; - *length = mv_buf->len; - } - else + *tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C'); + if (*tmp == NULL) { - PyErr_Format( - PyExc_TypeError, - "expected a contiguous memoryview"); return -1; } + Py_buffer *mv_buf; + mv_buf = PyMemoryView_GET_BUFFER(*tmp); + *buffer = mv_buf->buf; + *length = mv_buf->len; } else { @@ -74,15 +70,17 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // A pointer to a char * + length will be extracted from the data and mask // arguments, possibly via a Py_buffer. + PyObject *input_tmp = NULL; char *input; Py_ssize_t input_len; + PyObject *mask_tmp = NULL; char *mask; Py_ssize_t mask_len; // Initialize a PyBytesObject then get a pointer to the underlying char * // in order to avoid an extra memory copy in PyBytes_FromStringAndSize. - PyObject *result; + PyObject *result = NULL; char *output; // Other variables. @@ -94,23 +92,23 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) if (!PyArg_ParseTupleAndKeywords( args, kwds, "OO", kwlist, &input_obj, &mask_obj)) { - return NULL; + goto exit; } - if (_PyBytesLike_AsStringAndSize(input_obj, &input, &input_len) == -1) + if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1) { - return NULL; + goto exit; } - if (_PyBytesLike_AsStringAndSize(mask_obj, &mask, &mask_len) == -1) + if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1) { - return NULL; + goto exit; } if (mask_len != MASK_LEN) { PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes"); - return NULL; + goto exit; } // Create output. @@ -118,7 +116,7 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) result = PyBytes_FromStringAndSize(NULL, input_len); if (result == NULL) { - return NULL; + goto exit; } // Since we juste created result, we don't need error checks. @@ -172,6 +170,9 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) output[i] = input[i] ^ mask[i & (MASK_LEN - 1)]; } +exit: + Py_XDECREF(input_tmp); + Py_XDECREF(mask_tmp); return result; } diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 218d05376..a89bcc88b 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -580,10 +580,6 @@ def test_send_binary_from_memoryview(self): self.loop.run_until_complete(self.protocol.send(memoryview(b"tea"))) self.assertOneFrameSent(True, OP_BINARY, b"tea") - def test_send_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.send(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_BINARY, b"tea") - def test_send_dict(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.send({"not": "encoded"})) @@ -624,14 +620,6 @@ def test_send_iterable_binary_from_memoryview(self): (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") ) - def test_send_iterable_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete( - self.protocol.send([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - def test_send_empty_iterable(self): self.loop.run_until_complete(self.protocol.send([])) self.assertNoFrameSent() @@ -697,16 +685,6 @@ def test_send_async_iterable_binary_from_memoryview(self): (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") ) - def test_send_async_iterable_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete( - self.protocol.send( - async_iterable([memoryview(b"ttee")[::2], memoryview(b"aa")[::2]]) - ) - ) - self.assertFramesSent( - (False, OP_BINARY, b"te"), (False, OP_CONT, b"a"), (True, OP_CONT, b"") - ) - def test_send_empty_async_iterable(self): self.loop.run_until_complete(self.protocol.send(async_iterable([]))) self.assertNoFrameSent() @@ -799,10 +777,6 @@ def test_ping_binary_from_memoryview(self): self.loop.run_until_complete(self.protocol.ping(memoryview(b"tea"))) self.assertOneFrameSent(True, OP_PING, b"tea") - def test_ping_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.ping(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_PING, b"tea") - def test_ping_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.ping(42)) @@ -856,10 +830,6 @@ def test_pong_binary_from_memoryview(self): self.loop.run_until_complete(self.protocol.pong(memoryview(b"tea"))) self.assertOneFrameSent(True, OP_PONG, b"tea") - def test_pong_binary_from_non_contiguous_memoryview(self): - self.loop.run_until_complete(self.protocol.pong(memoryview(b"tteeaa")[::2])) - self.assertOneFrameSent(True, OP_PONG, b"tea") - def test_pong_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.pong(42)) diff --git a/tests/test_frames.py b/tests/test_frames.py index 4d10c6ef2..13a712322 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -218,12 +218,6 @@ def test_prepare_data_memoryview(self): (OP_BINARY, memoryview(b"tea")), ) - def test_prepare_data_non_contiguous_memoryview(self): - self.assertEqual( - prepare_data(memoryview(b"tteeaa")[::2]), - (OP_BINARY, b"tea"), - ) - def test_prepare_data_list(self): with self.assertRaises(TypeError): prepare_data([]) @@ -246,9 +240,6 @@ def test_prepare_ctrl_bytearray(self): def test_prepare_ctrl_memoryview(self): self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea") - def test_prepare_ctrl_non_contiguous_memoryview(self): - self.assertEqual(prepare_ctrl(memoryview(b"tteeaa")[::2]), b"tea") - def test_prepare_ctrl_list(self): with self.assertRaises(TypeError): prepare_ctrl([]) diff --git a/tests/test_utils.py b/tests/test_utils.py index b490c2409..a9ea8dcbd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -43,21 +43,18 @@ def test_apply_mask(self): self.assertEqual(result, data_out) def test_apply_mask_memoryview(self): - for data_type, mask_type in self.apply_mask_type_combos: + for mask_type in [bytes, bytearray]: for data_in, mask, data_out in self.apply_mask_test_values: - data_in, mask = data_type(data_in), mask_type(mask) - data_in, mask = memoryview(data_in), memoryview(mask) + data_in, mask = memoryview(data_in), mask_type(mask) with self.subTest(data_in=data_in, mask=mask): result = self.apply_mask(data_in, mask) self.assertEqual(result, data_out) def test_apply_mask_non_contiguous_memoryview(self): - for data_type, mask_type in self.apply_mask_type_combos: + for mask_type in [bytes, bytearray]: for data_in, mask, data_out in self.apply_mask_test_values: - data_in, mask = data_type(data_in), mask_type(mask) - data_in, mask = memoryview(data_in), memoryview(mask) - data_in, mask = data_in[::-1], mask[::-1] + data_in, mask = memoryview(data_in)[::-1], mask_type(mask)[::-1] data_out = data_out[::-1] with self.subTest(data_in=data_in, mask=mask): @@ -92,16 +89,3 @@ class SpeedupsTests(ApplyMaskTests): @staticmethod def apply_mask(*args, **kwargs): return c_apply_mask(*args, **kwargs) - - def test_apply_mask_non_contiguous_memoryview(self): - for data_type, mask_type in self.apply_mask_type_combos: - for data_in, mask, data_out in self.apply_mask_test_values: - data_in, mask = data_type(data_in), mask_type(mask) - data_in, mask = memoryview(data_in), memoryview(mask) - data_in, mask = data_in[::-1], mask[::-1] - data_out = data_out[::-1] - - with self.subTest(data_in=data_in, mask=mask): - # The C extension only supports contiguous memoryviews. - with self.assertRaises(TypeError): - self.apply_mask(data_in, mask) From 6167b5d8d8f7ec7d96f925089813503ee53b2983 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 11 Dec 2020 22:02:12 +0100 Subject: [PATCH 0741/1539] Clarify there's no guarantee to yield control. Fix #865. --- src/websockets/legacy/protocol.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index e4592b8a0..aa1b156c6 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -555,15 +555,15 @@ async def send( its :meth:`~dict.keys` method and pass the result to :meth:`send`. Canceling :meth:`send` is discouraged. Instead, you should close the - connection with :meth:`close`. Indeed, there only two situations where - :meth:`send` yields control to the event loop: + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop: 1. The write buffer is full. If you don't want to wait until enough data is sent, your only alternative is to close the connection. :meth:`close` will likely time out then abort the TCP connection. - 2. ``message`` is an asynchronous iterator. Stopping in the middle of - a fragmented message will cause a protocol error. Closing the - connection has the same effect. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error. Closing the connection has the same effect. :raises TypeError: for unsupported inputs From dccba0efb3bcb554fad85d72b4f6aa392626caac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Dec 2020 10:57:15 +0100 Subject: [PATCH 0742/1539] Fix sending fragmented, compressed messages. Fix #866. --- docs/changelog.rst | 2 ++ .../extensions/permessage_deflate.py | 26 +++++++++++-------- tests/extensions/test_permessage_deflate.py | 6 ++--- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 8d255fdfd..e8a41b53c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -55,6 +55,8 @@ fixing regressions shortly after a release. * Raised an error when passing a :class:`dict` to :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. +* Fixed sending fragmented, compressed messages. + * Fixed ``Host`` header sent when connecting to an IPv6 address. * Aligned maximum cookie size with popular web browsers. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 9a3fc4ba5..4f520af38 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -100,7 +100,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: return frame # Handle continuation data frames: - # - skip if the initial data frame wasn't encoded + # - skip if the message isn't encoded # - reset "decode continuation data" flag if it's a final frame if frame.opcode == OP_CONT: if not self.decode_cont_data: @@ -109,21 +109,23 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: self.decode_cont_data = False # Handle text and binary data frames: - # - skip if the frame isn't encoded + # - skip if the message isn't encoded + # - unset the rsv1 flag on the first frame of a compressed message # - set "decode continuation data" flag if it's a non-final frame else: if not frame.rsv1: return frame - if not frame.fin: # frame.rsv1 is True at this point + frame = frame._replace(rsv1=False) + if not frame.fin: self.decode_cont_data = True # Re-initialize per-message decoder. if self.remote_no_context_takeover: self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) - # Uncompress compressed frames. Protect against zip bombs by - # preventing zlib from decompressing more than max_length bytes - # (except when the limit is disabled with max_size = None). + # Uncompress data. Protect against zip bombs by preventing zlib from + # decompressing more than max_length bytes (except when the limit is + # disabled with max_size = None). data = frame.data if frame.fin: data += _EMPTY_UNCOMPRESSED_BLOCK @@ -136,7 +138,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: if frame.fin and self.remote_no_context_takeover: del self.decoder - return frame._replace(data=data, rsv1=False) + return frame._replace(data=data) def encode(self, frame: Frame) -> Frame: """ @@ -147,17 +149,19 @@ def encode(self, frame: Frame) -> Frame: if frame.opcode in CTRL_OPCODES: return frame - # Since we always encode and never fragment messages, there's no logic - # similar to decode() here at this time. + # Since we always encode messages, there's no "encode continuation + # data" flag similar to "decode continuation data" at this time. if frame.opcode != OP_CONT: + # Set the rsv1 flag on the first frame of a compressed message. + frame = frame._replace(rsv1=True) # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( wbits=-self.local_max_window_bits, **self.compress_settings ) - # Compress data frames. + # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): data = data[:-4] @@ -166,7 +170,7 @@ def encode(self, frame: Frame) -> Frame: if frame.fin and self.local_no_context_takeover: del self.encoder - return frame._replace(data=data, rsv1=True) + return frame._replace(data=data) def _build_parameters( diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 328861e58..7fc4c1c3a 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -113,10 +113,10 @@ def test_encode_decode_fragmented_text_frame(self): frame1._replace(rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff"), ) self.assertEqual( - enc_frame2, frame2._replace(rsv1=True, data=b"RPS\x00\x00\x00\x00\xff\xff") + enc_frame2, frame2._replace(data=b"RPS\x00\x00\x00\x00\xff\xff") ) self.assertEqual( - enc_frame3, frame3._replace(rsv1=True, data=b"J.\xca\xcf,.N\xcc+)\x06\x00") + enc_frame3, frame3._replace(data=b"J.\xca\xcf,.N\xcc+)\x06\x00") ) dec_frame1 = self.extension.decode(enc_frame1) @@ -138,7 +138,7 @@ def test_encode_decode_fragmented_binary_frame(self): enc_frame1, frame1._replace(rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff") ) self.assertEqual( - enc_frame2, frame2._replace(rsv1=True, data=b"*\xc9\xccM\x05\x00") + enc_frame2, frame2._replace(data=b"*\xc9\xccM\x05\x00") ) dec_frame1 = self.extension.decode(enc_frame1) From 97a601454e193d1f30d3069d8015d086a5b83aa2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Jan 2021 18:21:24 +0100 Subject: [PATCH 0743/1539] Support serve() with existing Unix socket. Fix #878. --- docs/changelog.rst | 2 ++ src/websockets/legacy/server.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index e8a41b53c..de4483b17 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -59,6 +59,8 @@ fixing regressions shortly after a release. * Fixed ``Host`` header sent when connecting to an IPv6 address. +* Fixed starting a Unix server listening on an existing socket. + * Aligned maximum cookie size with popular web browsers. * Ensured cancellation always propagates, even on Python versions where diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4dea9459d..42e0d6cf0 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -875,6 +875,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + unix: bool = False, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -931,16 +932,16 @@ def __init__( select_subprotocol=select_subprotocol, ) - if path is None: - create_server = functools.partial( - loop.create_server, factory, host, port, **kwargs - ) - else: + if unix: # unix_serve(path) must not specify host and port parameters. assert host is None and port is None create_server = functools.partial( loop.create_unix_server, factory, path, **kwargs ) + else: + create_server = functools.partial( + loop.create_server, factory, host, port, **kwargs + ) # This is a coroutine function. self._create_server = create_server @@ -981,7 +982,7 @@ async def __await_impl__(self) -> WebSocketServer: def unix_serve( ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], - path: str, + path: Optional[str] = None, **kwargs: Any, ) -> Serve: """ @@ -997,4 +998,4 @@ def unix_serve( :param path: file system path to the Unix socket """ - return serve(ws_handler, path=path, **kwargs) + return serve(ws_handler, path=path, unix=True, **kwargs) From aa93c4ceca90a1798f86b2fc2b110a42f308d721 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Jan 2021 18:34:06 +0100 Subject: [PATCH 0744/1539] Make black happy. --- tests/extensions/test_permessage_deflate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 7fc4c1c3a..908cd91a4 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -135,10 +135,12 @@ def test_encode_decode_fragmented_binary_frame(self): enc_frame2 = self.extension.encode(frame2) self.assertEqual( - enc_frame1, frame1._replace(rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff") + enc_frame1, + frame1._replace(rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff"), ) self.assertEqual( - enc_frame2, frame2._replace(data=b"*\xc9\xccM\x05\x00") + enc_frame2, + frame2._replace(data=b"*\xc9\xccM\x05\x00"), ) dec_frame1 = self.extension.decode(enc_frame1) From dda3dfa992ddf6045be48c34143e4c1656dff9d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Apr 2021 19:31:39 +0200 Subject: [PATCH 0745/1539] Document how to run on Heroku. Fix #929. --- docs/heroku.rst | 153 +++++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + docs/spelling_wordlist.txt | 2 + 3 files changed, 156 insertions(+) create mode 100644 docs/heroku.rst diff --git a/docs/heroku.rst b/docs/heroku.rst new file mode 100644 index 000000000..31c4b3f19 --- /dev/null +++ b/docs/heroku.rst @@ -0,0 +1,153 @@ +Deploying to Heroku +=================== + +This guide describes how to deploy a websockets server to Heroku_. We're going +to deploy a very simple app. The process would be identical for a more +realistic app. + +.. _Heroku: https://www.heroku.com/ + +Create application +------------------ + +Deploying to Heroku requires a git repository. Let's initialize one: + +.. code:: console + + $ mkdir websockets-echo + $ cd websockets-echo + $ git init . + Initialized empty Git repository in websockets-echo/.git/ + $ git commit --allow-empty -m "Initial commit." + [master (root-commit) 1e7947d] Initial commit. + +Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if +you haven't done that yet. + +.. _set-up instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up + +Then, create a Heroku app — if you follow these instructions step-by-step, +you'll have to pick a different name because I'm already using +``websockets-echo`` on Heroku: + +.. code:: console + + $ $ heroku create websockets-echo + Creating ⬢ websockets-echo... done + https://websockets-echo.herokuapp.com/ | https://git.heroku.com/websockets-echo.git + +Here's the implementation of the app, an echo server. Save it in a file called +``app.py``: + +.. code:: python + + #!/usr/bin/env python + + import asyncio + import os + + import websockets + + async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + + start_server = websockets.serve(echo, "", int(os.environ["PORT"])) + + asyncio.get_event_loop().run_until_complete(start_server) + asyncio.get_event_loop().run_forever() + +The server relies on the ``$PORT`` environment variable to tell on which port +it will listen, according to Heroku's conventions. + +Configure deployment +-------------------- + +In order to build the app, Heroku needs to know that it depends on websockets. +Create a ``requirements.txt`` file containing this line: + +.. code:: + + websockets + +Heroku also needs to know how to run the app. Create a ``Procfile`` with this +content: + +.. code:: + + web: python app.py + +Confirm that you created the correct files and commit them to git: + +.. code:: console + + $ ls + Procfile app.py requirements.txt + $ git add . + $ git commit -m "Deploy echo server to Heroku." + [master 8418c62] Deploy echo server to Heroku. +  3 files changed, 19 insertions(+) +  create mode 100644 Procfile +  create mode 100644 app.py +  create mode 100644 requirements.txt + +Deploy +------ + +Our app is ready. Let's deploy it! + +.. code:: console + + $ git push heroku master + + ... lots of output... + + remote: -----> Launching... + remote: Released v3 + remote: https://websockets-echo.herokuapp.com/ deployed to Heroku + remote: + remote: Verifying deploy... done. + To https://git.heroku.com/websockets-echo.git +  * [new branch] master -> master + +Validate deployment +------------------- + +Of course we'd like to confirm that our application is running as expected! + +Since it's a WebSocket server, we need a WebSocket client, such as the +interactive client that comes with websockets. + +If you're currently building a websockets server, perhaps you're already in a +virtualenv where websockets is installed. If not, you can install it in a new +virtualenv as follows: + +.. code:: console + + $ python -m venv websockets-client + $ . websockets-client/bin/activate + $ pip install websockets + +Connect the interactive client — using the name of your Heroku app instead of +``websockets-echo``: + +.. code:: console + + $ python -m websockets wss://websockets-echo.herokuapp.com/ + Connected to wss://websockets-echo.herokuapp.com/. + > + +Great! Our app is running! + +In this example, I used a secure connection (``wss://``). It worked because +Heroku served a valid TLS certificate for ``websockets-echo.herokuapp.com``. +An insecure connection (``ws://``) would also work. + +Once you're connected, you can send any message and the server will echo it, +then press Ctrl-D to terminate the connection: + +.. code:: console + + > Hello! + < Hello! + Connection closed: code = 1000 (OK), no reason. diff --git a/docs/index.rst b/docs/index.rst index 1b2f85f0a..90262ba9a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -62,6 +62,7 @@ These guides will help you build and deploy a ``websockets`` application. cheatsheet deployment extensions + heroku Reference --------- diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index dd3500b73..5e0a254c7 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -22,6 +22,7 @@ cryptocurrency Ctrl daemonize fractalideas +IPv iterable keepalive KiB @@ -48,6 +49,7 @@ Unparse unregister uple username +virtualenv websocket WebSocket websockets From 93f78884ffcaf71a60d4ad20eabb603224453fa2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Apr 2021 19:37:01 +0200 Subject: [PATCH 0746/1539] Bump year. --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index b2962adba..119b29ef3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013-2019 Aymeric Augustin and contributors. +Copyright (c) 2013-2021 Aymeric Augustin and contributors. All rights reserved. Redistribution and use in source and binary forms, with or without From 6b9e821183f8b42984e49313a7a3f5ccdd6fa8fc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 22 Apr 2021 09:02:14 +0200 Subject: [PATCH 0747/1539] Clarify backwards-compatibility policy. --- docs/changelog.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index de4483b17..f3bc3a297 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -11,7 +11,7 @@ Backwards-compatibility policy ``websockets`` also aims at providing the best API for WebSocket in Python. While we value stability, we value progress more. When an improvement requires -changing the API, we make the change and document it below. +changing a public API, we make the change and document it in this changelog. When possible with reasonable effort, we preserve backwards-compatibility for five years after the release that introduced the change. @@ -20,6 +20,9 @@ When a release contains backwards-incompatible API changes, the major version is increased, else the minor version is increased. Patch versions are only for fixing regressions shortly after a release. +Only documented APIs are public. Undocumented APIs are considered private. +They may change at any time. + 9.0 ... From c2c8bffcf5e8cae8a648c06e4cf64943550be216 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 22 Apr 2021 09:10:26 +0200 Subject: [PATCH 0748/1539] Improve explanation of ongoing refactoring. --- docs/changelog.rst | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index f3bc3a297..4b1843713 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,6 +3,8 @@ Changelog .. currentmodule:: websockets +.. _backwards-compatibility policy: + Backwards-compatibility policy .............................. @@ -32,22 +34,24 @@ They may change at any time. **Version 9.0 moves or deprecates several APIs.** + Aliases provide backwards compatibility for all previously public APIs. + * :class:`~datastructures.Headers` and :exc:`~datastructures.MultipleValuesError` were moved from - ``websockets.http`` to :mod:`websockets.datastructures`. - - * ``websockets.client``, ``websockets.server``, ``websockets.protocol``, - and ``websockets.auth`` were moved to :mod:`websockets.legacy.client`, - :mod:`websockets.legacy.server`, :mod:`websockets.legacy.protocol`, and - :mod:`websockets.legacy.auth` respectively. - - * ``websockets.handshake`` is deprecated. - - * ``websockets.http`` is deprecated. - - * ``websockets.framing`` is deprecated. - - Aliases provide backwards compatibility for all previously public APIs. + ``websockets.http`` to :mod:`websockets.datastructures`. If you're using + them, you should adjust the import path. + + * The ``client``, ``server``, ``protocol``, and ``auth`` modules were + moved from the ``websockets`` package to ``websockets.legacy`` + sub-package, as part of an upcoming refactoring. Despite the name, + they're still fully supported. The refactoring should be a transparent + upgrade for most uses when it's available. The legacy implementation + will be preserved according to the `backwards-compatibility policy`_. + + * The ``handshake``, ``http``, and ``framing`` modules in the + ``websockets`` package are deprecated. These modules provided low-level + APIs for reuse by other WebSocket implementations, but that never + happened and keeping these APIs public prevents improvements. * Added compatibility with Python 3.9. From ce1f4a071cc6651ff8bcf89f0919721aa9ca4574 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 22 Apr 2021 09:37:17 +0200 Subject: [PATCH 0749/1539] Deprecate headers and uri as well. They aren't involved in any public API any more. --- docs/changelog.rst | 9 +++++---- src/websockets/__init__.py | 4 ++-- src/websockets/headers.py | 3 --- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 4b1843713..9b2fa4441 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -48,10 +48,11 @@ They may change at any time. upgrade for most uses when it's available. The legacy implementation will be preserved according to the `backwards-compatibility policy`_. - * The ``handshake``, ``http``, and ``framing`` modules in the - ``websockets`` package are deprecated. These modules provided low-level - APIs for reuse by other WebSocket implementations, but that never - happened and keeping these APIs public prevents improvements. + * The ``framing``, ``handshake``, ``headers``, ``http``, and ``uri`` + modules in the ``websockets`` package are deprecated. These modules + provided low-level APIs for reuse by other WebSocket implementations, + but that never happened. Keeping these APIs public makes it more + difficult to improve websockets for no actual benefit. * Added compatibility with Python 3.9. diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 580a3960f..65d9fb913 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -96,11 +96,11 @@ "ExtensionHeader": ".typing", "ExtensionParameter": ".typing", "Subprotocol": ".typing", - "parse_uri": ".uri", - "WebSocketURI": ".uri", }, deprecated_aliases={ "framing": ".legacy", "handshake": ".legacy", + "parse_uri": ".uri", + "WebSocketURI": ".uri", }, ) diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 256c66bb1..6779c9c04 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -2,9 +2,6 @@ :mod:`websockets.headers` provides parsers and serializers for HTTP headers used in WebSocket handshake messages. -These APIs cannot be imported from :mod:`websockets`. They must be imported -from :mod:`websockets.headers`. - """ import base64 From bb40530d4051dd1dbc0522e4d9e3e72cc7e25436 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 22 Apr 2021 09:40:27 +0200 Subject: [PATCH 0750/1539] Remove deprecated modules from API documentation. --- docs/api.rst | 40 ++++++++-------------------------------- 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index c73cf59d3..2adc0dde4 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -23,11 +23,8 @@ For convenience, public APIs can be imported directly from the :mod:`websockets` package, unless noted otherwise. Anything that isn't listed in this document is a private API. -High-level ----------- - Server -...... +------ .. automodule:: websockets.legacy.server @@ -51,7 +48,7 @@ Server .. automethod:: select_subprotocol Client -...... +------ .. automodule:: websockets.legacy.client @@ -66,7 +63,7 @@ Client .. automethod:: handshake Shared -...... +------ .. automodule:: websockets.legacy.protocol @@ -88,7 +85,7 @@ Shared .. autoattribute:: closed Types -..... +----- .. automodule:: websockets.typing @@ -96,7 +93,7 @@ Types Per-Message Deflate Extension -............................. +----------------------------- .. automodule:: websockets.extensions.permessage_deflate @@ -105,7 +102,7 @@ Per-Message Deflate Extension .. autoclass:: ClientPerMessageDeflateFactory HTTP Basic Auth -............... +--------------- .. automodule:: websockets.legacy.auth @@ -116,34 +113,13 @@ HTTP Basic Auth .. automethod:: process_request Data structures -............... +--------------- .. automodule:: websockets.datastructures :members: Exceptions -.......... +---------- .. automodule:: websockets.exceptions :members: - -Low-level ---------- - -Data transfer -............. - -.. automodule:: websockets.framing - :members: - -URI parser -.......... - -.. automodule:: websockets.uri - :members: - -Utilities -......... - -.. automodule:: websockets.headers - :members: From c0002603eb39a9a85f89a0c83337ce398aeea7de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Apr 2021 21:39:57 +0200 Subject: [PATCH 0751/1539] Make HeadersLike a public API. Refs #845, #854. --- src/websockets/datastructures.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index f70d92ad7..c8e17fa98 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,5 +1,5 @@ """ -This module defines a data structure for manipulating HTTP headers. +:mod:`websockets.datastructures` defines a class for manipulating HTTP headers. """ @@ -16,7 +16,7 @@ ) -__all__ = ["Headers", "MultipleValuesError"] +__all__ = ["Headers", "HeadersLike", "MultipleValuesError"] class MultipleValuesError(LookupError): @@ -63,7 +63,7 @@ class Headers(MutableMapping[str, str]): As long as no header occurs multiple times, :class:`Headers` behaves like :class:`dict`, except keys are lower-cased to provide case-insensitivity. - Two methods support support manipulating multiple values explicitly: + Two methods support manipulating multiple values explicitly: - :meth:`get_all` returns a list of all values for a header; - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs. @@ -157,3 +157,9 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] +HeadersLike__doc__ = """Types accepted wherever :class:`Headers` is expected""" +# Remove try / except when dropping support for Python < 3.7 +try: + HeadersLike.__doc__ = HeadersLike__doc__ +except AttributeError: # pragma: no cover + pass From fa295a75fd0fcf53906d7aa0fe4fdcc8c7d81cd2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Apr 2021 21:41:39 +0200 Subject: [PATCH 0752/1539] Rewrite extensions guide. --- docs/deployment.rst | 2 ++ docs/extensions.rst | 79 ++++++++++++++++++++++++++------------------- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/docs/deployment.rst b/docs/deployment.rst index ed025094d..2331af936 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -66,6 +66,8 @@ Memory usage of a single connection is the sum of: Baseline ........ +.. _compression-settings: + Compression settings are the main factor affecting the baseline amount of memory used by each connection. diff --git a/docs/extensions.rst b/docs/extensions.rst index dea91219e..151a7e297 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -1,12 +1,12 @@ Extensions ========== -.. currentmodule:: websockets +.. currentmodule:: websockets.extensions The WebSocket protocol supports extensions_. -At the time of writing, there's only one `registered extension`_, WebSocket -Per-Message Deflate, specified in :rfc:`7692`. +At the time of writing, there's only one `registered extension`_ with a public +specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. .. _extensions: https://tools.ietf.org/html/rfc6455#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name @@ -14,24 +14,31 @@ Per-Message Deflate, specified in :rfc:`7692`. Per-Message Deflate ------------------- -:func:`~legacy.server.serve()` and :func:`~legacy.client.connect` enable the -Per-Message Deflate extension by default. You can disable this with -``compression=None``. +:func:`~websockets.legacy.client.connect` and +:func:`~websockets.legacy.server.serve` enable the Per-Message Deflate +extension by default. + +If you want to disable it, set ``compression=None``:: + + import websockets + + websockets.connect(..., compression=None) + + websockets.serve(..., compression=None) -You can also configure the Per-Message Deflate extension explicitly if you -want to customize its parameters. .. _per-message-deflate-configuration-example: -Here's an example on the server side:: +You can also configure the Per-Message Deflate extension explicitly if you +want to customize compression settings:: import websockets from websockets.extensions import permessage_deflate - websockets.serve( + websockets.connect( ..., extensions=[ - permessage_deflate.ServerPerMessageDeflateFactory( + permessage_deflate.ClientPerMessageDeflateFactory( server_max_window_bits=11, client_max_window_bits=11, compress_settings={'memLevel': 4}, @@ -39,15 +46,10 @@ Here's an example on the server side:: ], ) -Here's an example on the client side:: - - import websockets - from websockets.extensions import permessage_deflate - - websockets.connect( + websockets.serve( ..., extensions=[ - permessage_deflate.ClientPerMessageDeflateFactory( + permessage_deflate.ServerPerMessageDeflateFactory( server_max_window_bits=11, client_max_window_bits=11, compress_settings={'memLevel': 4}, @@ -55,34 +57,43 @@ Here's an example on the client side:: ], ) +The window bits and memory level values chosen in these examples reduce memory +usage. You can read more about :ref:`optimizing compression settings +`. + Refer to the API documentation of -:class:`~extensions.permessage_deflate.ServerPerMessageDeflateFactory` and -:class:`~extensions.permessage_deflate.ClientPerMessageDeflateFactory` for -details. +:class:`~permessage_deflate.ClientPerMessageDeflateFactory` and +:class:`~permessage_deflate.ServerPerMessageDeflateFactory` for details. Writing an extension -------------------- During the opening handshake, WebSocket clients and servers negotiate which extensions will be used with which parameters. Then each frame is processed by -extensions before it's sent and after it's received. +extensions before being sent or after being received. + +As a consequence, writing an extension requires implementing several classes: + +* Extension Factory: it negotiates parameters and instantiates the extension. -As a consequence writing an extension requires implementing several classes: + Clients and servers require separate extension factories with distinct APIs. -1. Extension Factory: it negotiates parameters and instantiates the extension. - Clients and servers require separate extension factories with distinct APIs. + Extension factories are the public API of an extension. -2. Extension: it decodes incoming frames and encodes outgoing frames. If the - extension is symmetrical, clients and servers can use the same class. +* Extension: it decodes incoming frames and encodes outgoing frames. + + If the extension is symmetrical, clients and servers can use the same + class. + + Extensions are initialized by extension factories, so they don't need to be + part of the public API of an extension. ``websockets`` provides abstract base classes for extension factories and -extensions. +extensions. See the API documentation for details on their methods: + +* :class:`~base.ClientExtensionFactory` and + :class:`~base.ServerExtensionFactory` for extension factories, -.. autoclass:: websockets.extensions.base.ServerExtensionFactory - :members: +* :class:`~base.Extension` for extensions. -.. autoclass:: websockets.extensions.base.ClientExtensionFactory - :members: -.. autoclass:: websockets.extensions.base.Extension - :members: From 835d16dfadd912766df99dac21e82c151eb1bda7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 10:47:38 +0200 Subject: [PATCH 0753/1539] Add example of client shutdown. Fix #933. --- docs/deployment.rst | 2 +- docs/faq.rst | 11 +++++++++++ example/shutdown_client.py | 19 +++++++++++++++++++ example/{shutdown.py => shutdown_server.py} | 0 4 files changed, 31 insertions(+), 1 deletion(-) create mode 100755 example/shutdown_client.py rename example/{shutdown.py => shutdown_server.py} (100%) diff --git a/docs/deployment.rst b/docs/deployment.rst index 2331af936..8baa8836c 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -34,7 +34,7 @@ On Unix systems, shutdown is usually triggered by sending a signal. Here's a full example for handling SIGTERM on Unix: -.. literalinclude:: ../example/shutdown.py +.. literalinclude:: ../example/shutdown_server.py :emphasize-lines: 13,17-19 This example is easily adapted to handle other signals. If you override the diff --git a/docs/faq.rst b/docs/faq.rst index eee14dda8..ff91105b4 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -142,6 +142,17 @@ See `issue 414`_. .. _issue 414: https://github.com/aaugustin/websockets/issues/414 +How do I stop a client that is continuously processing messages? +................................................................ + +You can close the connection. + +Here's an example that terminates cleanly when it receives SIGTERM on Unix: + +.. literalinclude:: ../example/shutdown_client.py + :emphasize-lines: 10-13 + + How do I disable TLS/SSL certificate verification? .................................................. diff --git a/example/shutdown_client.py b/example/shutdown_client.py new file mode 100755 index 000000000..f21c0f6fa --- /dev/null +++ b/example/shutdown_client.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +import asyncio +import signal +import websockets + +async def client(): + uri = "ws://localhost:8765" + async with websockets.connect(uri) as websocket: + # Close the connection when receiving SIGTERM. + loop = asyncio.get_event_loop() + loop.add_signal_handler( + signal.SIGTERM, loop.create_task, websocket.close()) + + # Process messages received on the connection. + async for message in websocket: + ... + +asyncio.get_event_loop().run_until_complete(client()) diff --git a/example/shutdown.py b/example/shutdown_server.py similarity index 100% rename from example/shutdown.py rename to example/shutdown_server.py From cf2453625a023868bfe760dc438a500e3ebcb931 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 21:45:32 +0200 Subject: [PATCH 0754/1539] Clean up signature of Protocol classes. --- src/websockets/legacy/client.py | 10 ++++++---- src/websockets/legacy/server.py | 13 +++++++------ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 27f6e8209..1c0ecf62f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -374,7 +374,6 @@ def __init__( self, uri: str, *, - path: Optional[str] = None, create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -384,9 +383,6 @@ def __init__( read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, - legacy_recv: bool = False, - klass: Optional[Type[WebSocketClientProtocol]] = None, - timeout: Optional[float] = None, compression: Optional[str] = "deflate", origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, @@ -395,6 +391,7 @@ def __init__( **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. + timeout: Optional[float] = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -404,6 +401,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. + klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: @@ -412,6 +410,9 @@ def __init__( if create_protocol is None: create_protocol = klass + # Backwards compatibility: recv() used to return None on closed connections + legacy_recv: bool = kwargs.pop("legacy_recv", False) + if loop is None: loop = asyncio.get_event_loop() @@ -449,6 +450,7 @@ def __init__( extra_headers=extra_headers, ) + path: Optional[str] = kwargs.pop("path", None) if path is None: host: Optional[str] port: Optional[int] diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 42e0d6cf0..b7eed52b0 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -851,7 +851,6 @@ def __init__( host: Optional[Union[str, Sequence[str]]] = None, port: Optional[int] = None, *, - path: Optional[str] = None, create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -861,9 +860,6 @@ def __init__( read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, loop: Optional[asyncio.AbstractEventLoop] = None, - legacy_recv: bool = False, - klass: Optional[Type[WebSocketServerProtocol]] = None, - timeout: Optional[float] = None, compression: Optional[str] = "deflate", origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -875,10 +871,10 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, - unix: bool = False, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. + timeout: Optional[float] = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -888,6 +884,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. + klass: Optional[Type[WebSocketServerProtocol]] = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: @@ -896,6 +893,9 @@ def __init__( if create_protocol is None: create_protocol = klass + # Backwards compatibility: recv() used to return None on closed connections + legacy_recv: bool = kwargs.pop("legacy_recv", False) + if loop is None: loop = asyncio.get_event_loop() @@ -932,7 +932,8 @@ def __init__( select_subprotocol=select_subprotocol, ) - if unix: + if kwargs.pop("unix", False): + path: Optional[str] = kwargs.pop("path", None) # unix_serve(path) must not specify host and port parameters. assert host is None and port is None create_server = functools.partial( From 9c818367b2177aae6c90c3a5c4fad26e540c81bc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 21:51:07 +0200 Subject: [PATCH 0755/1539] Support existing Unix sockets in unix_connect. The same fix was made for the server side, but not the client side. --- docs/changelog.rst | 2 +- src/websockets/legacy/client.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9b2fa4441..91ea23dc9 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -67,7 +67,7 @@ They may change at any time. * Fixed ``Host`` header sent when connecting to an IPv6 address. -* Fixed starting a Unix server listening on an existing socket. +* Fixed creating a client or a server with an existing Unix socket. * Aligned maximum cookie size with popular web browsers. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 1c0ecf62f..219c3c9bc 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -450,8 +450,12 @@ def __init__( extra_headers=extra_headers, ) - path: Optional[str] = kwargs.pop("path", None) - if path is None: + if kwargs.pop("unix", False): + path: Optional[str] = kwargs.pop("path", None) + create_connection = functools.partial( + loop.create_unix_connection, factory, path, **kwargs + ) + else: host: Optional[str] port: Optional[int] if kwargs.get("sock") is None: @@ -465,10 +469,6 @@ def __init__( create_connection = functools.partial( loop.create_connection, factory, host, port, **kwargs ) - else: - create_connection = functools.partial( - loop.create_unix_connection, factory, path, **kwargs - ) # This is a coroutine function. self._create_connection = create_connection @@ -563,7 +563,9 @@ async def __await_impl__(self) -> WebSocketClientProtocol: connect = Connect -def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect: +def unix_connect( + path: Optional[str], uri: str = "ws://localhost/", **kwargs: Any +) -> Connect: """ Similar to :func:`connect`, but for connecting to a Unix socket. @@ -578,4 +580,4 @@ def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Conn :param uri: WebSocket URI """ - return connect(uri=uri, path=path, **kwargs) + return connect(uri=uri, path=path, unix=True, **kwargs) From 9223d7d72ab11824988442847d4c02d7524a61c1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 22:15:58 +0200 Subject: [PATCH 0756/1539] Restore backwards-compatibility for logger names. --- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/protocol.py | 2 +- src/websockets/legacy/server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 219c3c9bc..4000375fb 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -40,7 +40,7 @@ __all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] -logger = logging.getLogger(__name__) +logger = logging.getLogger("websockets.server") class WebSocketClientProtocol(WebSocketCommonProtocol): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index aa1b156c6..84af7b626 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -60,7 +60,7 @@ __all__ = ["WebSocketCommonProtocol"] -logger = logging.getLogger(__name__) +logger = logging.getLogger("websockets.protocol") # A WebSocket connection goes through the following four states, in order: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index b7eed52b0..8e5f97a66 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -50,7 +50,7 @@ __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] -logger = logging.getLogger(__name__) +logger = logging.getLogger("websockets.server") HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] From fcb3a4c31838b797ff609d2fdb89db7f37c527ff Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 22:32:20 +0200 Subject: [PATCH 0757/1539] Remove backwards-compatibility from docs after 5 years. --- src/websockets/legacy/protocol.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 84af7b626..56c4d5f6a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -466,11 +466,6 @@ async def recv(self) -> Data: :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol error or a network failure. - .. versionchanged:: 3.0 - - :meth:`recv` used to return ``None`` instead. Refer to the - changelog for details. - Canceling :meth:`recv` is safe. There's no risk of losing the next message. The next invocation of :meth:`recv` will return it. This makes it possible to enforce a timeout by wrapping :meth:`recv` in From d82a7a9de7cebec22bdcdf763ba4cb7ea75bdb76 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 21:45:51 +0200 Subject: [PATCH 0758/1539] Revamp API documentation. --- docs/api.rst | 125 ----------------------- docs/api/client.rst | 74 ++++++++++++++ docs/api/extensions.rst | 26 +++++ docs/api/index.rst | 50 ++++++++++ docs/api/server.rst | 105 +++++++++++++++++++ docs/api/utilities.rst | 20 ++++ docs/design.rst | 6 +- docs/index.rst | 2 +- docs/spelling_wordlist.txt | 1 + src/websockets/legacy/auth.py | 3 - src/websockets/legacy/client.py | 122 +++++++++++++++++++--- src/websockets/legacy/protocol.py | 97 +----------------- src/websockets/legacy/server.py | 161 ++++++++++++++++++++++++------ 13 files changed, 517 insertions(+), 275 deletions(-) delete mode 100644 docs/api.rst create mode 100644 docs/api/client.rst create mode 100644 docs/api/extensions.rst create mode 100644 docs/api/index.rst create mode 100644 docs/api/server.rst create mode 100644 docs/api/utilities.rst diff --git a/docs/api.rst b/docs/api.rst deleted file mode 100644 index 2adc0dde4..000000000 --- a/docs/api.rst +++ /dev/null @@ -1,125 +0,0 @@ -API -=== - -Design ------- - -``websockets`` provides complete client and server implementations, as shown -in the :doc:`getting started guide `. These functions are built on top -of low-level APIs reflecting the two phases of the WebSocket protocol: - -1. An opening handshake, in the form of an HTTP Upgrade request; - -2. Data transfer, as framed messages, ending with a closing handshake. - -The first phase is designed to integrate with existing HTTP software. -``websockets`` provides a minimal implementation to build, parse and validate -HTTP requests and responses. - -The second phase is the core of the WebSocket protocol. ``websockets`` -provides a complete implementation on top of ``asyncio`` with a simple API. - -For convenience, public APIs can be imported directly from the -:mod:`websockets` package, unless noted otherwise. Anything that isn't listed -in this document is a private API. - -Server ------- - -.. automodule:: websockets.legacy.server - - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) - :async: - - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) - :async: - - - .. autoclass:: WebSocketServer - - .. automethod:: close - .. automethod:: wait_closed - .. autoattribute:: sockets - - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) - - .. automethod:: handshake - .. automethod:: process_request - .. automethod:: select_subprotocol - -Client ------- - -.. automodule:: websockets.legacy.client - - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) - :async: - - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) - :async: - - .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) - - .. automethod:: handshake - -Shared ------- - -.. automodule:: websockets.legacy.protocol - - .. autoclass:: WebSocketCommonProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None) - - .. automethod:: close - .. automethod:: wait_closed - - .. automethod:: recv - .. automethod:: send - - .. automethod:: ping - .. automethod:: pong - - .. autoattribute:: local_address - .. autoattribute:: remote_address - - .. autoattribute:: open - .. autoattribute:: closed - -Types ------ - -.. automodule:: websockets.typing - - .. autodata:: Data - - -Per-Message Deflate Extension ------------------------------ - -.. automodule:: websockets.extensions.permessage_deflate - - .. autoclass:: ServerPerMessageDeflateFactory - - .. autoclass:: ClientPerMessageDeflateFactory - -HTTP Basic Auth ---------------- - -.. automodule:: websockets.legacy.auth - - .. autofunction:: basic_auth_protocol_factory - - .. autoclass:: BasicAuthWebSocketServerProtocol - - .. automethod:: process_request - -Data structures ---------------- - -.. automodule:: websockets.datastructures - :members: - -Exceptions ----------- - -.. automodule:: websockets.exceptions - :members: diff --git a/docs/api/client.rst b/docs/api/client.rst new file mode 100644 index 000000000..f969227a9 --- /dev/null +++ b/docs/api/client.rst @@ -0,0 +1,74 @@ +Client +====== + +.. automodule:: websockets.legacy.client + + Opening a connection + -------------------- + + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + :async: + + .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + :async: + + Using a connection + ------------------ + + .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) + + .. autoattribute:: local_address + + .. autoattribute:: remote_address + + .. autoattribute:: open + + .. autoattribute:: closed + + .. attribute:: path + + Path of the HTTP request. + + Available once the connection is open. + + .. attribute:: request_headers + + HTTP request headers as a :class:`~websockets.http.Headers` instance. + + Available once the connection is open. + + .. attribute:: response_headers + + HTTP response headers as a :class:`~websockets.http.Headers` instance. + + Available once the connection is open. + + .. attribute:: subprotocol + + Subprotocol, if one was negotiated. + + Available once the connection is open. + + .. attribute:: close_code + + WebSocket close code. + + Available once the connection is closed. + + .. attribute:: close_reason + + WebSocket close reason. + + Available once the connection is closed. + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: ping + + .. automethod:: pong + + .. automethod:: close + + .. automethod:: wait_closed diff --git a/docs/api/extensions.rst b/docs/api/extensions.rst new file mode 100644 index 000000000..635c5c426 --- /dev/null +++ b/docs/api/extensions.rst @@ -0,0 +1,26 @@ +Extensions +========== + +Per-Message Deflate +------------------- + +.. automodule:: websockets.extensions.permessage_deflate + + .. autoclass:: ClientPerMessageDeflateFactory + + .. autoclass:: ServerPerMessageDeflateFactory + +Abstract classes +---------------- + +.. automodule:: websockets.extensions.base + + .. autoclass:: Extension + :members: + + .. autoclass:: ClientExtensionFactory + :members: + + .. autoclass:: ServerExtensionFactory + :members: + diff --git a/docs/api/index.rst b/docs/api/index.rst new file mode 100644 index 000000000..20bb740b3 --- /dev/null +++ b/docs/api/index.rst @@ -0,0 +1,50 @@ +API +=== + +``websockets`` provides complete client and server implementations, as shown +in the :doc:`getting started guide <../intro>`. + +The process for opening and closing a WebSocket connection depends on which +side you're implementing. + +* On the client side, connecting to a server with :class:`~websockets.connect` + yields a connection object that provides methods for interacting with the + connection. Your code can open a connection, then send or receive messages. + + If you use :class:`~websockets.connect` as an asynchronous context manager, + then websockets closes the connection on exit. If not, then your code is + responsible for closing the connection. + +* On the server side, :class:`~websockets.serve` starts listening for client + connections and yields an server object that supports closing the server. + + Then, when clients connects, the server initializes a connection object and + passes it to a handler coroutine, which is where your code can send or + receive messages. This pattern is called `inversion of control`_. It's + common in frameworks implementing servers. + + When the handler coroutine terminates, websockets closes the connection. You + may also close it in the handler coroutine if you'd like. + +.. _inversion of control: https://en.wikipedia.org/wiki/Inversion_of_control + +Once the connection is open, the WebSocket protocol is symmetrical, except for +low-level details that websockets manages under the hood. The same methods are +available on client connections created with :class:`~websockets.connect` and +on server connections passed to the connection handler in the arguments. + +At this point, websockets provides the same API — and uses the same code — for +client and server connections. For convenience, common methods are documented +both in the client API and server API. + +.. toctree:: + :maxdepth: 2 + + client + server + extensions + utilities + +All public APIs can be imported from the :mod:`websockets` package, unless +noted otherwise. Anything that isn't listed in this API documentation is a +private API, with no guarantees of behavior or backwards-compatibility. diff --git a/docs/api/server.rst b/docs/api/server.rst new file mode 100644 index 000000000..16c8f6359 --- /dev/null +++ b/docs/api/server.rst @@ -0,0 +1,105 @@ +Server +====== + +.. automodule:: websockets.legacy.server + + Starting a server + ----------------- + + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + :async: + + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + :async: + + Stopping a server + ----------------- + + .. autoclass:: WebSocketServer + + .. autoattribute:: sockets + + .. automethod:: close + .. automethod:: wait_closed + + Using a connection + ------------------ + + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + + .. autoattribute:: local_address + + .. autoattribute:: remote_address + + .. autoattribute:: open + + .. autoattribute:: closed + + .. attribute:: path + + Path of the HTTP request. + + Available once the connection is open. + + .. attribute:: request_headers + + HTTP request headers as a :class:`~websockets.http.Headers` instance. + + Available once the connection is open. + + .. attribute:: response_headers + + HTTP response headers as a :class:`~websockets.http.Headers` instance. + + Available once the connection is open. + + .. attribute:: subprotocol + + Subprotocol, if one was negotiated. + + Available once the connection is open. + + .. attribute:: close_code + + WebSocket close code. + + Available once the connection is closed. + + .. attribute:: close_reason + + WebSocket close reason. + + Available once the connection is closed. + + .. automethod:: process_request + + .. automethod:: select_subprotocol + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: ping + + .. automethod:: pong + + .. automethod:: close + + .. automethod:: wait_closed + +Basic authentication +-------------------- + +.. automodule:: websockets.legacy.auth + + .. autofunction:: basic_auth_protocol_factory + + .. autoclass:: BasicAuthWebSocketServerProtocol + + .. automethod:: process_request + + .. attribute:: username + + Username of the authenticated user. + + diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst new file mode 100644 index 000000000..198e928b0 --- /dev/null +++ b/docs/api/utilities.rst @@ -0,0 +1,20 @@ +Utilities +========= + +Data structures +--------------- + +.. automodule:: websockets.datastructures + :members: + +Exceptions +---------- + +.. automodule:: websockets.exceptions + :members: + +Types +----- + +.. automodule:: websockets.typing + :members: diff --git a/docs/design.rst b/docs/design.rst index f2718370d..0cabc2e5d 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -13,7 +13,7 @@ wish to understand what happens under the hood. Internals described in this document may change at any time. - Backwards compatibility is only guaranteed for `public APIs `_. + Backwards compatibility is only guaranteed for `public APIs `_. Lifecycle @@ -404,8 +404,8 @@ don't involve inversion of control. Library ....... -Most :doc:`public APIs ` of ``websockets`` are coroutines. They may be -canceled, for example if the user starts a task that calls these coroutines +Most :doc:`public APIs ` of ``websockets`` are coroutines. They may +be canceled, for example if the user starts a task that calls these coroutines and cancels the task later. ``websockets`` must handle this situation. Cancellation during the opening handshake is handled like any other exception: diff --git a/docs/index.rst b/docs/index.rst index 90262ba9a..e121fd930 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -72,7 +72,7 @@ Find all the details you could ask for, and then some. .. toctree:: :maxdepth: 2 - api + api/index Discussions ----------- diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 5e0a254c7..d7c744147 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -21,6 +21,7 @@ cryptocurrencies cryptocurrency Ctrl daemonize +datastructures fractalideas IPv iterable diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8cb60429a..e0beede57 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -52,9 +52,6 @@ async def process_request( """ Check HTTP Basic Auth and return a HTTP 401 or 403 response if needed. - If authentication succeeds, the username of the authenticated user is - stored in the ``username`` attribute. - """ try: authorization = request_headers["Authorization"] diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 4000375fb..1b5bd303f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -47,8 +47,97 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ :class:`~asyncio.Protocol` subclass implementing a WebSocket client. - This class inherits most of its methods from - :class:`~websockets.protocol.WebSocketCommonProtocol`. + :class:`WebSocketClientProtocol`: + + * performs the opening handshake to establish the connection; + * provides :meth:`recv` and :meth:`send` coroutines for receiving and + sending messages; + * deals with control frames automatically; + * performs the closing handshake to terminate the connection. + + :class:`WebSocketClientProtocol` supports asynchronous iteration:: + + async for message in websocket: + await process(message) + + The iterator yields incoming messages. It exits normally when the + connection is closed with the close code 1000 (OK) or 1001 (going away). + It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception + when the connection is closed with any other code. + + Once the connection is open, a `Ping frame`_ is sent every + ``ping_interval`` seconds. This serves as a keepalive. It helps keeping + the connection open, especially in the presence of proxies with short + timeouts on inactive connections. Set ``ping_interval`` to ``None`` to + disable this behavior. + + .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 + + If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` + seconds, the connection is considered unusable and is closed with + code 1011. This ensures that the remote endpoint remains responsive. Set + ``ping_timeout`` to ``None`` to disable this behavior. + + .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 + + The ``close_timeout`` parameter defines a maximum wait time for completing + the closing handshake and terminating the TCP connection. For legacy + reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds. + + ``close_timeout`` needs to be a parameter of the protocol because + websockets usually calls :meth:`close` implicitly upon exit when + :func:`connect` is used as a context manager. + + To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. + + The ``max_size`` parameter enforces the maximum size for incoming messages + in bytes. The default value is 1 MiB. ``None`` disables the limit. If a + message larger than the maximum size is received, :meth:`recv` will + raise :exc:`~websockets.exceptions.ConnectionClosedError` and the + connection will be closed with code 1009. + + The ``max_queue`` parameter sets the maximum length of the queue that + holds incoming messages. The default value is ``32``. ``None`` disables + the limit. Messages are added to an in-memory queue when they're received; + then :meth:`recv` pops from that queue. In order to prevent excessive + memory consumption when messages are received faster than they can be + processed, the queue must be bounded. If the queue fills up, the protocol + stops processing incoming data until :meth:`recv` is called. In this + situation, various receive buffers (at least in :mod:`asyncio` and in the + OS) will fill up, then the TCP receive window will shrink, slowing down + transmission to avoid packet loss. + + Since Python can use up to 4 bytes of memory to represent a single + character, each connection may use up to ``4 * max_size * max_queue`` + bytes of memory to store incoming messages. By default, this is 128 MiB. + You may want to lower the limits, depending on your application's + requirements. + + The ``read_limit`` argument sets the high-water limit of the buffer for + incoming bytes. The low-water limit is half the high-water limit. The + default value is 64 KiB, half of asyncio's default (based on the current + implementation of :class:`~asyncio.StreamReader`). + + The ``write_limit`` argument sets the high-water limit of the buffer for + outgoing bytes. The low-water limit is a quarter of the high-water limit. + The default value is 64 KiB, equal to asyncio's default (based on the + current implementation of ``FlowControlMixin``). + + As soon as the HTTP request and response in the opening handshake are + processed: + + * the request path is available in the :attr:`path` attribute; + * the request and response HTTP headers are available in the + :attr:`request_headers` and :attr:`response_headers` attributes, + which are :class:`~websockets.http.Headers` instances. + + If a subprotocol was negotiated, it's available in the :attr:`subprotocol` + attribute. + + Once the connection is closed, the code is available in the + :attr:`close_code` attribute and the reason in :attr:`close_reason`. + + All attributes must be treated as read-only. """ @@ -318,8 +407,12 @@ class Connect: Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. - :func:`connect` can also be used as a asynchronous context manager. In - that case, the connection is closed when exiting the context. + :func:`connect` can also be used as a asynchronous context manager:: + + async with connect(...) as websocket: + ... + + In that case, the connection is closed when exiting the context. :func:`connect` is a wrapper around the event loop's :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments @@ -336,31 +429,28 @@ class Connect: used in the TLS handshake for secure connections and in the ``Host`` HTTP header. - The ``create_protocol`` parameter allows customizing the - :class:`~asyncio.Protocol` that manages the connection. It should be a - callable or class accepting the same arguments as - :class:`WebSocketClientProtocol` and returning an instance of - :class:`WebSocketClientProtocol` or a subclass. It defaults to - :class:`WebSocketClientProtocol`. + ``create_protocol`` defaults to :class:`WebSocketClientProtocol`. It may + be replaced by a wrapper or a subclass to customize the protocol that + manages the connection. The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is - described in :class:`~websockets.protocol.WebSocketCommonProtocol`. + described in :class:`WebSocketClientProtocol`. :func:`connect` also accepts the following optional arguments: * ``compression`` is a shortcut to configure compression extensions; by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression - * ``origin`` sets the Origin HTTP header + ``None`` to disable compression. + * ``origin`` sets the Origin HTTP header. * ``extensions`` is a list of supported extensions in order of - decreasing preference + decreasing preference. * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference + decreasing preference. * ``extra_headers`` sets additional HTTP request headers; it can be a :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` - pairs + pairs. :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid :raises ~websockets.handshake.InvalidHandshake: if the opening handshake diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 56c4d5f6a..a46e3dc4e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -82,108 +82,13 @@ class WebSocketCommonProtocol(asyncio.Protocol): Once the WebSocket connection is established, during the data transfer phase, the protocol is almost symmetrical between the server side and the client side. :class:`WebSocketCommonProtocol` implements logic that's - shared between servers and clients.. + shared between servers and clients. Subclasses such as :class:`~websockets.legacy.server.WebSocketServerProtocol` and :class:`~websockets.legacy.client.WebSocketClientProtocol` implement the opening handshake, which is different between servers and clients. - :class:`WebSocketCommonProtocol` performs four functions: - - * It runs a task that stores incoming data frames in a queue and makes - them available with the :meth:`recv` coroutine. - * It sends outgoing data frames with the :meth:`send` coroutine. - * It deals with control frames automatically. - * It performs the closing handshake. - - :class:`WebSocketCommonProtocol` supports asynchronous iteration:: - - async for message in websocket: - await process(message) - - The iterator yields incoming messages. It exits normally when the - connection is closed with the close code 1000 (OK) or 1001 (going away). - It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other code. - - Once the connection is open, a `Ping frame`_ is sent every - ``ping_interval`` seconds. This serves as a keepalive. It helps keeping - the connection open, especially in the presence of proxies with short - timeouts on inactive connections. Set ``ping_interval`` to ``None`` to - disable this behavior. - - .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 - - If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with - code 1011. This ensures that the remote endpoint remains responsive. Set - ``ping_timeout`` to ``None`` to disable this behavior. - - .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 - - The ``close_timeout`` parameter defines a maximum wait time in seconds for - completing the closing handshake and terminating the TCP connection. - :meth:`close` completes in at most ``4 * close_timeout`` on the server - side and ``5 * close_timeout`` on the client side. - - ``close_timeout`` needs to be a parameter of the protocol because - ``websockets`` usually calls :meth:`close` implicitly: - - - on the server side, when the connection handler terminates, - - on the client side, when exiting the context manager for the connection. - - To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. - - The ``max_size`` parameter enforces the maximum size for incoming messages - in bytes. The default value is 1 MiB. ``None`` disables the limit. If a - message larger than the maximum size is received, :meth:`recv` will - raise :exc:`~websockets.exceptions.ConnectionClosedError` and the - connection will be closed with code 1009. - - The ``max_queue`` parameter sets the maximum length of the queue that - holds incoming messages. The default value is ``32``. ``None`` disables - the limit. Messages are added to an in-memory queue when they're received; - then :meth:`recv` pops from that queue. In order to prevent excessive - memory consumption when messages are received faster than they can be - processed, the queue must be bounded. If the queue fills up, the protocol - stops processing incoming data until :meth:`recv` is called. In this - situation, various receive buffers (at least in ``asyncio`` and in the OS) - will fill up, then the TCP receive window will shrink, slowing down - transmission to avoid packet loss. - - Since Python can use up to 4 bytes of memory to represent a single - character, each connection may use up to ``4 * max_size * max_queue`` - bytes of memory to store incoming messages. By default, this is 128 MiB. - You may want to lower the limits, depending on your application's - requirements. - - The ``read_limit`` argument sets the high-water limit of the buffer for - incoming bytes. The low-water limit is half the high-water limit. The - default value is 64 KiB, half of asyncio's default (based on the current - implementation of :class:`~asyncio.StreamReader`). - - The ``write_limit`` argument sets the high-water limit of the buffer for - outgoing bytes. The low-water limit is a quarter of the high-water limit. - The default value is 64 KiB, equal to asyncio's default (based on the - current implementation of ``FlowControlMixin``). - - As soon as the HTTP request and response in the opening handshake are - processed: - - * the request path is available in the :attr:`path` attribute; - * the request and response HTTP headers are available in the - :attr:`request_headers` and :attr:`response_headers` attributes, - which are :class:`~websockets.http.Headers` instances. - - If a subprotocol was negotiated, it's available in the :attr:`subprotocol` - attribute. - - Once the connection is closed, the code is available in the - :attr:`close_code` attribute and the reason in :attr:`close_reason`. - - All these attributes must be treated as read-only. - """ # There are only two differences between the client-side and server-side diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 8e5f97a66..e693bbd2f 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -62,11 +62,107 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ :class:`~asyncio.Protocol` subclass implementing a WebSocket server. - This class inherits most of its methods from - :class:`~websockets.protocol.WebSocketCommonProtocol`. - - For the sake of simplicity, it doesn't rely on a full HTTP implementation. - Its support for HTTP responses is very limited. + :class:`WebSocketServerProtocol`: + + * performs the opening handshake to establish the connection; + * provides :meth:`recv` and :meth:`send` coroutines for receiving and + sending messages; + * deals with control frames automatically; + * performs the closing handshake to terminate the connection. + + You may customize the opening handshake by subclassing + :class:`WebSocketServer` and overriding: + + * :meth:`process_request` to intercept the client request before any + processing and, if appropriate, to abort the WebSocket request and + return a HTTP response instead; + * :meth:`select_subprotocol` to select a subprotocol, if the client and + the server have multiple subprotocols in common and the default logic + for choosing one isn't suitable (this is rarely needed). + + :class:`WebSocketServerProtocol` supports asynchronous iteration:: + + async for message in websocket: + await process(message) + + The iterator yields incoming messages. It exits normally when the + connection is closed with the close code 1000 (OK) or 1001 (going away). + It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception + when the connection is closed with any other code. + + Once the connection is open, a `Ping frame`_ is sent every + ``ping_interval`` seconds. This serves as a keepalive. It helps keeping + the connection open, especially in the presence of proxies with short + timeouts on inactive connections. Set ``ping_interval`` to ``None`` to + disable this behavior. + + .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 + + If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` + seconds, the connection is considered unusable and is closed with + code 1011. This ensures that the remote endpoint remains responsive. Set + ``ping_timeout`` to ``None`` to disable this behavior. + + .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 + + The ``close_timeout`` parameter defines a maximum wait time for completing + the closing handshake and terminating the TCP connection. For legacy + reasons, :meth:`close` completes in at most ``4 * close_timeout`` seconds. + + ``close_timeout`` needs to be a parameter of the protocol because + websockets usually calls :meth:`close` implicitly when the connection + handler terminates. + + To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. + + The ``max_size`` parameter enforces the maximum size for incoming messages + in bytes. The default value is 1 MiB. ``None`` disables the limit. If a + message larger than the maximum size is received, :meth:`recv` will + raise :exc:`~websockets.exceptions.ConnectionClosedError` and the + connection will be closed with code 1009. + + The ``max_queue`` parameter sets the maximum length of the queue that + holds incoming messages. The default value is ``32``. ``None`` disables + the limit. Messages are added to an in-memory queue when they're received; + then :meth:`recv` pops from that queue. In order to prevent excessive + memory consumption when messages are received faster than they can be + processed, the queue must be bounded. If the queue fills up, the protocol + stops processing incoming data until :meth:`recv` is called. In this + situation, various receive buffers (at least in :mod:`asyncio` and in the + OS) will fill up, then the TCP receive window will shrink, slowing down + transmission to avoid packet loss. + + Since Python can use up to 4 bytes of memory to represent a single + character, each connection may use up to ``4 * max_size * max_queue`` + bytes of memory to store incoming messages. By default, this is 128 MiB. + You may want to lower the limits, depending on your application's + requirements. + + The ``read_limit`` argument sets the high-water limit of the buffer for + incoming bytes. The low-water limit is half the high-water limit. The + default value is 64 KiB, half of asyncio's default (based on the current + implementation of :class:`~asyncio.StreamReader`). + + The ``write_limit`` argument sets the high-water limit of the buffer for + outgoing bytes. The low-water limit is a quarter of the high-water limit. + The default value is 64 KiB, equal to asyncio's default (based on the + current implementation of ``FlowControlMixin``). + + As soon as the HTTP request and response in the opening handshake are + processed: + + * the request path is available in the :attr:`path` attribute; + * the request and response HTTP headers are available in the + :attr:`request_headers` and :attr:`response_headers` attributes, + which are :class:`~websockets.http.Headers` instances. + + If a subprotocol was negotiated, it's available in the :attr:`subprotocol` + attribute. + + Once the connection is closed, the code is available in the + :attr:`close_code` attribute and the reason in :attr:`close_reason`. + + All attributes must be treated as read-only. """ @@ -487,7 +583,7 @@ def select_subprotocol( Instead of subclassing, it is possible to override this method by passing a ``select_subprotocol`` argument to the :func:`serve` - function or the :class:`WebSocketServerProtocol` constructor + function or the :class:`WebSocketServerProtocol` constructor. :param client_subprotocols: list of subprotocols offered by the client :param server_subprotocols: list of subprotocols available on the server @@ -780,66 +876,69 @@ class Serve: :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their current or next interaction with the WebSocket connection. - :func:`serve` can also be used as an asynchronous context manager. In - this case, the server is shut down when exiting the context. + :func:`serve` can also be used as an asynchronous context manager:: + + stop = asyncio.Future() # set this future to exit the server + + async with serve(...): + await stop + + In this case, the server is shut down when exiting the context. :func:`serve` is a wrapper around the event loop's :meth:`~asyncio.loop.create_server` method. It creates and starts a - :class:`~asyncio.Server` with :meth:`~asyncio.loop.create_server`. Then it - wraps the :class:`~asyncio.Server` in a :class:`WebSocketServer` and + :class:`asyncio.Server` with :meth:`~asyncio.loop.create_server`. Then it + wraps the :class:`asyncio.Server` in a :class:`WebSocketServer` and returns the :class:`WebSocketServer`. - The ``ws_handler`` argument is the WebSocket handler. It must be a - coroutine accepting two arguments: a :class:`WebSocketServerProtocol` and - the request URI. + ``ws_handler`` is the WebSocket handler. It must be a coroutine accepting + two arguments: the WebSocket connection, which is an instance of + :class:`WebSocketServerProtocol`, and the path of the request. The ``host`` and ``port`` arguments, as well as unrecognized keyword - arguments, are passed along to :meth:`~asyncio.loop.create_server`. + arguments, are passed to :meth:`~asyncio.loop.create_server`. For example, you can set the ``ssl`` keyword argument to a :class:`~ssl.SSLContext` to enable TLS. - The ``create_protocol`` parameter allows customizing the - :class:`~asyncio.Protocol` that manages the connection. It should be a - callable or class accepting the same arguments as - :class:`WebSocketServerProtocol` and returning an instance of - :class:`WebSocketServerProtocol` or a subclass. It defaults to - :class:`WebSocketServerProtocol`. + ``create_protocol`` defaults to :class:`WebSocketServerProtocol`. It may + be replaced by a wrapper or a subclass to customize the protocol that + manages the connection. The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is - described in :class:`~websockets.protocol.WebSocketCommonProtocol`. + described in :class:`WebSocketServerProtocol`. :func:`serve` also accepts the following optional arguments: * ``compression`` is a shortcut to configure compression extensions; by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression - * ``origins`` defines acceptable Origin HTTP headers; include ``None`` if - the lack of an origin is acceptable + ``None`` to disable compression. + * ``origins`` defines acceptable Origin HTTP headers; include ``None`` in + the list if the lack of an origin is acceptable. * ``extensions`` is a list of supported extensions in order of - decreasing preference + decreasing preference. * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference + decreasing preference. * ``extra_headers`` sets additional HTTP response headers when the handshake succeeds; it can be a :class:`~websockets.http.Headers` instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, value)`` pairs, or a callable taking the request path and headers in - arguments and returning one of the above + arguments and returning one of the above. * ``process_request`` allows intercepting the HTTP request; it must be a coroutine taking the request path and headers in argument; see - :meth:`~WebSocketServerProtocol.process_request` for details + :meth:`~WebSocketServerProtocol.process_request` for details. * ``select_subprotocol`` allows customizing the logic for selecting a subprotocol; it must be a callable taking the subprotocols offered by the client and available on the server in argument; see - :meth:`~WebSocketServerProtocol.select_subprotocol` for details + :meth:`~WebSocketServerProtocol.select_subprotocol` for details. Since there's no useful way to propagate exceptions triggered in handlers, - they're sent to the ``'websockets.legacy.server'`` logger instead. + they're sent to the ``"websockets.server"`` logger instead. Debugging is much easier if you configure logging to print them:: import logging - logger = logging.getLogger("websockets.legacy.server") + logger = logging.getLogger("websockets.server") logger.setLevel(logging.ERROR) logger.addHandler(logging.StreamHandler()) From 927287380011e4388c11c24d286beef2b877284d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 22:49:58 +0200 Subject: [PATCH 0759/1539] Work around coverage bug. --- src/websockets/legacy/protocol.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index a46e3dc4e..e4c6d63c5 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -549,7 +549,9 @@ async def send( # Other fragments. # https://github.com/python/mypy/issues/5738 - async for message_chunk in aiter_message: # type: ignore + # coverage reports this code as not covered, but it is + # exercised by tests - changing it breaks the tests! + async for message_chunk in aiter_message: # type: ignore # pragma: no cover # noqa confirm_opcode, data = prepare_data(message_chunk) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") From 5ab214b00f38cae3976fce5a315fbfa30762b60d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 22:39:12 +0200 Subject: [PATCH 0760/1539] Bump version number --- docs/changelog.rst | 7 ++++++- docs/conf.py | 4 ++-- src/websockets/version.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 91ea23dc9..2644d3735 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,11 +25,16 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. -9.0 +9.1 ... *In development* +9.0 +... + +*May 1, 2021* + .. note:: **Version 9.0 moves or deprecates several APIs.** diff --git a/docs/conf.py b/docs/conf.py index 0c00b96fb..dad7475f7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,9 +59,9 @@ # built documents. # # The short X.Y version. -version = '8.1' +version = '9.0' # The full version, including alpha/beta/rc tags. -release = '8.1' +release = '9.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/src/websockets/version.py b/src/websockets/version.py index 7377332e1..94d9f2ead 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "8.1" +version = "9.0" From 5d6fcf96cd81680e35cba00ed52cb12bf2c8f544 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 May 2021 22:59:03 +0200 Subject: [PATCH 0761/1539] Python 3.9 is now released. --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 7be85d7f9..8a1441209 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -40,7 +40,7 @@ jobs: - run: tox -e py38 py39: docker: - - image: circleci/python:3.9.0b1 + - image: circleci/python:3.9 steps: # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc From 56be5f71e273fee7a2ef86166838f574b58e3c59 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 15:36:56 +0200 Subject: [PATCH 0762/1539] Build wheels on Python 3.9. --- .appveyor.yml | 2 +- .travis.yml | 2 +- setup.cfg | 2 +- setup.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index d34b15aed..ef17ebba5 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -6,7 +6,7 @@ skip_branch_with_pr: true environment: # websockets only works on Python >= 3.6. - CIBW_BUILD: cp36-* cp37-* cp38-* + CIBW_BUILD: cp36-* cp37-* cp38-* cp39-* CIBW_TEST_COMMAND: python -W default -m unittest WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 diff --git a/.travis.yml b/.travis.yml index e31c9ea0b..f2bfc724e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ env: global: # websockets only works on Python >= 3.6. - - CIBW_BUILD="cp36-* cp37-* cp38-*" + - CIBW_BUILD="cp36-* cp37-* cp38-* cp39-*"" - CIBW_TEST_COMMAND="python3 -W default -m unittest" - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 diff --git a/setup.cfg b/setup.cfg index 5448b0f9b..04b792989 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py36.py37 +python-tag = py36.py37.py38.py39 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index 85d899cb4..5adb8e835 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', ], package_dir = {'': 'src'}, package_data = {'websockets': ['py.typed']}, From cbae1fb00e07a880bc7e9b566249afa474469c0d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 15:36:56 +0200 Subject: [PATCH 0763/1539] Setup GitHub actions. --- .github/workflows/tests.yml | 52 +++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 .github/workflows/tests.yml diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 000000000..eb06ebfea --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,52 @@ +name: Run tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + main: + name: Run code quality checks + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v2 + - name: Install Python 3.x + uses: actions/setup-python@v2 + with: + python-version: 3.x + - name: Install tox + run: pip install tox + - name: Run tests with coverage + run: tox -e coverage + - name: Check code formatting + run: tox -e black + - name: Check code style + run: tox -e flake8 + - name: Check imports ordering + run: tox -e isort + - name: Check types statically + run: tox -e mypy + + matrix: + name: Run tests on Python ${{ matrix.python }} + needs: main + runs-on: ubuntu-latest + strategy: + matrix: + python: [3.6, 3.7, 3.8, 3.9] + steps: + - name: Check out repository + uses: actions/checkout@v2 + - name: Install Python ${{ matrix.python }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + - name: Install tox + run: pip install tox + - name: Run tests + run: tox -e py From 3d55449d5df642d6be401c21afee450edb8c4422 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 17:57:53 +0200 Subject: [PATCH 0764/1539] Drop CircleCI setup. --- .circleci/config.yml | 67 -------------------------------------------- 1 file changed, 67 deletions(-) delete mode 100644 .circleci/config.yml diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index 8a1441209..000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,67 +0,0 @@ -version: 2 - -jobs: - main: - docker: - - image: circleci/python:3.7 - steps: - # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - - checkout - - run: sudo pip install tox codecov - - run: tox -e coverage,black,flake8,isort,mypy - - run: codecov - py36: - docker: - - image: circleci/python:3.6 - steps: - # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - - checkout - - run: sudo pip install tox - - run: tox -e py36 - py37: - docker: - - image: circleci/python:3.7 - steps: - # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - - checkout - - run: sudo pip install tox - - run: tox -e py37 - py38: - docker: - - image: circleci/python:3.8 - steps: - # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - - checkout - - run: sudo pip install tox - - run: tox -e py38 - py39: - docker: - - image: circleci/python:3.9 - steps: - # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. - - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc - - checkout - - run: sudo pip install tox - - run: tox -e py39 - -workflows: - version: 2 - build: - jobs: - - main - - py36: - requires: - - main - - py37: - requires: - - main - - py38: - requires: - - main - - py39: - requires: - - main From c3d7b7f6565bd2a40aa5cdd5d0e44642148d41e2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 17:57:59 +0200 Subject: [PATCH 0765/1539] Change badge in README. --- README.rst | 9 +++------ docs/index.rst | 9 +++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/README.rst b/README.rst index 1e15ba198..bda73c640 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ :width: 480px :alt: websockets -|rtd| |pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |circleci| |codecov| +|rtd| |pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |tests| .. |rtd| image:: https://readthedocs.org/projects/websockets/badge/?version=latest :target: https://websockets.readthedocs.io/ @@ -19,11 +19,8 @@ .. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |circleci| image:: https://img.shields.io/circleci/project/github/aaugustin/websockets.svg - :target: https://circleci.com/gh/aaugustin/websockets - -.. |codecov| image:: https://codecov.io/gh/aaugustin/websockets/branch/master/graph/badge.svg - :target: https://codecov.io/gh/aaugustin/websockets +.. |tests| image:: https://github.com/aaugustin/websockets/workflows/tests/badge.svg?branch=master + :target: https://github.com/aaugustin/websockets/actions?workflow=tests What is ``websockets``? ----------------------- diff --git a/docs/index.rst b/docs/index.rst index e121fd930..5914d7289 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,7 @@ websockets ========== -|pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |circleci| |codecov| +|pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |tests| .. |pypi-v| image:: https://img.shields.io/pypi/v/websockets.svg :target: https://pypi.python.org/pypi/websockets @@ -15,11 +15,8 @@ websockets .. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |circleci| image:: https://img.shields.io/circleci/project/github/aaugustin/websockets.svg - :target: https://circleci.com/gh/aaugustin/websockets - -.. |codecov| image:: https://codecov.io/gh/aaugustin/websockets/branch/master/graph/badge.svg - :target: https://codecov.io/gh/aaugustin/websockets +.. |tests| image:: https://github.com/aaugustin/websockets/workflows/tests/badge.svg?branch=master + :target: https://github.com/aaugustin/websockets/actions?workflow=tests ``websockets`` is a library for building WebSocket servers_ and clients_ in Python with a focus on correctness and simplicity. From a45cc5afe067925759a2644bd9ef9b5346adefa1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 20:16:16 +0200 Subject: [PATCH 0766/1539] Build distributions on GitHub actions. --- .github/workflows/wheels.yml | 71 ++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 .github/workflows/wheels.yml diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml new file mode 100644 index 000000000..7ea97c61f --- /dev/null +++ b/.github/workflows/wheels.yml @@ -0,0 +1,71 @@ +name: Build wheels + +on: + push: + branches: + - main + tags: + - '*' + +jobs: + sdist: + name: Build source distribution + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v2 + - name: Install Python 3.x + uses: actions/setup-python@v2 + with: + python-version: 3.x + - name: Build sdist + run: python setup.py sdist + - name: Save sdist + uses: actions/upload-artifact@v2 + with: + path: dist/*.tar.gz + + wheels: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-20.04, windows-2019, macOS-10.15] + + steps: + - name: Check out repository + uses: actions/checkout@v2 + - name: Make extension build mandatory + run: touch .cibuildwheel + - name: Install Python 3.x + uses: actions/setup-python@v2 + with: + python-version: 3.x + - name: Set up QEMU + if: runner.os == 'Linux' + uses: docker/setup-qemu-action@v1 + with: + platforms: all + - name: Build wheels + uses: joerick/cibuildwheel@v1.11.0 + env: + CIBW_ARCHS_LINUX: auto aarch64 + CIBW_BUILD: cp36-* cp37-* cp38-* cp39-* + - name: Save wheels + uses: actions/upload-artifact@v2 + with: + path: wheelhouse/*.whl + + upload_pypi: + name: Upload to PyPI + needs: [sdist, wheels] + runs-on: ubuntu-latest + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + steps: + - uses: actions/download-artifact@v2 + with: + name: artifact + path: dist + - uses: pypa/gh-action-pypi-publish@master + with: + password: ${{ secrets.PYPI_API_TOKEN }} From b0d211d0f32633977e73f51a1573e6a07319a0b0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 20:23:38 +0200 Subject: [PATCH 0767/1539] Drop Travis CI and Appveyor setup. --- .appveyor.yml | 27 --------------------------- .travis.yml | 43 ------------------------------------------- 2 files changed, 70 deletions(-) delete mode 100644 .appveyor.yml delete mode 100644 .travis.yml diff --git a/.appveyor.yml b/.appveyor.yml deleted file mode 100644 index ef17ebba5..000000000 --- a/.appveyor.yml +++ /dev/null @@ -1,27 +0,0 @@ -branches: - only: - - master - -skip_branch_with_pr: true - -environment: -# websockets only works on Python >= 3.6. - CIBW_BUILD: cp36-* cp37-* cp38-* cp39-* - CIBW_TEST_COMMAND: python -W default -m unittest - WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 100 - -install: -# Ensure python is Python 3. - - set PATH=C:\Python37;%PATH% - - cmd: python -m pip install --upgrade cibuildwheel -# Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). - - cmd: touch .cibuildwheel - -build_script: - - cmd: python -m cibuildwheel --output-dir wheelhouse -# Upload to PyPI on tags - - ps: >- - if ($env:APPVEYOR_REPO_TAG -eq "true") { - Invoke-Expression "python -m pip install twine" - Invoke-Expression "python -m twine upload --skip-existing wheelhouse/*.whl" - } diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index f2bfc724e..000000000 --- a/.travis.yml +++ /dev/null @@ -1,43 +0,0 @@ -env: - global: - # websockets only works on Python >= 3.6. - - CIBW_BUILD="cp36-* cp37-* cp38-* cp39-*"" - - CIBW_TEST_COMMAND="python3 -W default -m unittest" - - WEBSOCKETS_TESTS_TIMEOUT_FACTOR=100 - -matrix: - include: - - language: python - dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069) - sudo: required - python: "3.7" - services: - - docker - - language: python - dist: xenial - sudo: required - python: "3.7" - arch: arm64 - services: - - docker - - os: osx - osx_image: xcode8.3 - -install: -# Python 3 is needed to run cibuildwheel for websockets. - - if [ "${TRAVIS_OS_NAME:-}" == "osx" ]; then - brew update; - brew upgrade python; - fi -# Install cibuildwheel using pip3 to make sure Python 3 is used. - - pip3 install --upgrade cibuildwheel -# Create file '.cibuildwheel' so that extension build is not optional (c.f. setup.py). - - touch .cibuildwheel - -script: - - cibuildwheel --output-dir wheelhouse -# Upload to PyPI on tags - - if [ "${TRAVIS_TAG:-}" != "" ]; then - pip3 install twine; - python3 -m twine upload --skip-existing wheelhouse/*; - fi From fc176f462b6a5ef4f470df415780b09fed5da7c1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 20:49:05 +0200 Subject: [PATCH 0768/1539] Bump version number. --- docs/changelog.rst | 7 +++++++ src/websockets/version.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 2644d3735..1e5f92211 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -30,6 +30,13 @@ They may change at any time. *In development* +9.0.1 +..... + +*May 2, 2021* + +* Fixed issues with the packaging of the 9.0 release. + 9.0 ... diff --git a/src/websockets/version.py b/src/websockets/version.py index 94d9f2ead..23b7f329b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "9.0" +version = "9.0.1" From b6f085e86f0a62e1fe3a38b33256479cbf84a0dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 20:55:50 +0200 Subject: [PATCH 0769/1539] Fix badge. --- README.rst | 4 ++-- docs/index.rst | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index bda73c640..6b83ffed6 100644 --- a/README.rst +++ b/README.rst @@ -19,8 +19,8 @@ .. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://github.com/aaugustin/websockets/workflows/tests/badge.svg?branch=master - :target: https://github.com/aaugustin/websockets/actions?workflow=tests +.. |tests| image:: https://github.com/aaugustin/websockets/actions/workflows/tests.yml/badge.svg + :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml What is ``websockets``? ----------------------- diff --git a/docs/index.rst b/docs/index.rst index 5914d7289..f0c5f8d00 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,8 +15,8 @@ websockets .. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://github.com/aaugustin/websockets/workflows/tests/badge.svg?branch=master - :target: https://github.com/aaugustin/websockets/actions?workflow=tests +.. |tests| image:: https://github.com/aaugustin/websockets/actions/workflows/tests.yml/badge.svg + :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml ``websockets`` is a library for building WebSocket servers_ and clients_ in Python with a focus on correctness and simplicity. From db6f5b50d09ae86606c4e0d9b92079eb02dbff5e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 21:24:13 +0200 Subject: [PATCH 0770/1539] Document public methods that can raise ConnectionClosed. Fix #768. --- src/websockets/legacy/protocol.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index e4c6d63c5..c155c3bd7 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -465,6 +465,8 @@ async def send( Stopping in the middle of a fragmented message will cause a protocol error. Closing the connection has the same effect. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed :raises TypeError: for unsupported inputs """ @@ -653,6 +655,11 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no effect. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises ValueError: if another ping was sent with the same data and + the corresponding pong wasn't received yet + """ await self.ensure_open() @@ -685,6 +692,9 @@ async def pong(self, data: Data = b"") -> None: Canceling :meth:`pong` is discouraged for the same reason as :meth:`ping`. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + """ await self.ensure_open() From e4fe1e0999237c3def055dddfee39c3a09e743a4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 May 2021 07:38:21 +0200 Subject: [PATCH 0771/1539] Give up on a flaky test. Ref #390. --- tests/legacy/test_protocol.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index a89bcc88b..f928322ca 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1389,7 +1389,8 @@ def test_local_close_connection_lost_timeout_after_close(self): # half-close our side with write_eof() and close it with close(), time # out in 20ms. # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(19 * MS, 29 * MS): + # Add another 10ms because this test is flaky and I don't understand. + with self.assertCompletesWithin(19 * MS, 39 * MS): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True # HACK: disable close => other end drops connection emulation. @@ -1444,12 +1445,13 @@ def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof() and close it with close(), time - # out in 20ms. + # out in 30ms. # - 10ms waiting for receiving a half-close # - 10ms waiting for receiving a close after write_eof # - 10ms waiting for receiving a close after close # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(29 * MS, 39 * MS): + # Add another 10ms because this test is flaky and I don't understand. + with self.assertCompletesWithin(29 * MS, 49 * MS): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True # HACK: disable close => other end drops connection emulation. From 4ca3721bd29e1e0efddb89dff17809382af339dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 09:01:01 +0200 Subject: [PATCH 0772/1539] Drop support for Python 3.6. Also fix a few places that were missed when adding 3.9. --- .github/workflows/tests.yml | 2 +- .github/workflows/wheels.yml | 2 +- README.rst | 2 +- docs/changelog.rst | 8 +++- docs/intro.rst | 2 +- setup.cfg | 2 +- setup.py | 7 +-- src/websockets/__main__.py | 6 +-- src/websockets/datastructures.py | 6 +-- src/websockets/imports.py | 74 +++++++++++++------------------- src/websockets/legacy/server.py | 9 +--- src/websockets/typing.py | 20 ++------- tox.ini | 2 +- 13 files changed, 52 insertions(+), 90 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index eb06ebfea..0042c302c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: [3.6, 3.7, 3.8, 3.9] + python: [3.7, 3.8, 3.9] steps: - name: Check out repository uses: actions/checkout@v2 diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 7ea97c61f..249cd36ce 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -50,7 +50,7 @@ jobs: uses: joerick/cibuildwheel@v1.11.0 env: CIBW_ARCHS_LINUX: auto aarch64 - CIBW_BUILD: cp36-* cp37-* cp38-* cp39-* + CIBW_BUILD: cp37-* cp38-* cp39-* - name: Save wheels uses: actions/upload-artifact@v2 with: diff --git a/README.rst b/README.rst index 6b83ffed6..7a163f5b7 100644 --- a/README.rst +++ b/README.rst @@ -125,7 +125,7 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. * If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which - only works on Python 3. ``websockets`` requires Python ≥ 3.6.1. + only works on Python 3. ``websockets`` requires Python ≥ 3.7. What else? ---------- diff --git a/docs/changelog.rst b/docs/changelog.rst index 1e5f92211..35f292e49 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,11 +25,15 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. -9.1 -... +10.0 +.... *In development* +.. warning:: + + **Version 10.0 drops compatibility with Python 3.6.** + 9.0.1 ..... diff --git a/docs/intro.rst b/docs/intro.rst index c77139cab..58d482a09 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -``websockets`` requires Python ≥ 3.6.1. +``websockets`` requires Python ≥ 3.7. You should use the latest version of Python if possible. If you're using an older version, be aware that for each minor version (3.x), only the latest diff --git a/setup.cfg b/setup.cfg index 04b792989..b15a7515c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py36.py37.py38.py39 +python-tag = py37.py38.py39 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index 5adb8e835..d493a12f1 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,5 @@ import pathlib import re -import sys import setuptools @@ -21,9 +20,6 @@ exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) -if sys.version_info[:3] < (3, 6, 1): - raise Exception("websockets requires Python >= 3.6.1.") - packages = ['websockets', 'websockets/legacy', 'websockets/extensions'] ext_modules = [ @@ -51,7 +47,6 @@ 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', @@ -62,6 +57,6 @@ ext_modules=ext_modules, include_package_data=True, zip_safe=False, - python_requires='>=3.6.1', + python_requires='>=3.7', test_loader='unittest:TestLoader', ) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index d44e34e74..530daef85 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -105,8 +105,8 @@ async def run_client( try: while True: - incoming: asyncio.Future[Any] = asyncio.ensure_future(websocket.recv()) - outgoing: asyncio.Future[Any] = asyncio.ensure_future(inputs.get()) + incoming: asyncio.Future[Any] = asyncio.create_task(websocket.recv()) + outgoing: asyncio.Future[Any] = asyncio.create_task(inputs.get()) done: Set[asyncio.Future[Any]] pending: Set[asyncio.Future[Any]] done, pending = await asyncio.wait( @@ -188,7 +188,7 @@ async def queue_factory() -> asyncio.Queue[str]: stop: asyncio.Future[None] = loop.create_future() # Schedule the task that will manage the connection. - asyncio.ensure_future(run_client(args.uri, loop, inputs, stop), loop=loop) + loop.create_task(run_client(args.uri, loop, inputs, stop)) # Start the event loop in a background thread. thread = threading.Thread(target=loop.run_forever) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index c8e17fa98..66f91e9bb 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -158,8 +158,4 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] HeadersLike__doc__ = """Types accepted wherever :class:`Headers` is expected""" -# Remove try / except when dropping support for Python < 3.7 -try: - HeadersLike.__doc__ = HeadersLike__doc__ -except AttributeError: # pragma: no cover - pass +HeadersLike.__doc__ = HeadersLike__doc__ diff --git a/src/websockets/imports.py b/src/websockets/imports.py index efd3eabf3..06917ce1d 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,4 +1,3 @@ -import sys import warnings from typing import Any, Dict, Iterable, Optional @@ -50,9 +49,6 @@ def lazy_import( This function defines __getattr__ and __dir__ per PEP 562. - On Python 3.6 and earlier, it falls back to non-lazy imports and doesn't - raise deprecation warnings. - """ if aliases is None: aliases = {} @@ -69,43 +65,33 @@ def lazy_import( package = namespace["__name__"] - if sys.version_info[:2] >= (3, 7): - - def __getattr__(name: str) -> Any: - assert aliases is not None # mypy cannot figure this out - try: - source = aliases[name] - except KeyError: - pass - else: - return import_name(name, source, namespace) - - assert deprecated_aliases is not None # mypy cannot figure this out - try: - source = deprecated_aliases[name] - except KeyError: - pass - else: - warnings.warn( - f"{package}.{name} is deprecated", - DeprecationWarning, - stacklevel=2, - ) - return import_name(name, source, namespace) - - raise AttributeError(f"module {package!r} has no attribute {name!r}") - - namespace["__getattr__"] = __getattr__ - - def __dir__() -> Iterable[str]: - return sorted(namespace_set | aliases_set | deprecated_aliases_set) - - namespace["__dir__"] = __dir__ - - else: # pragma: no cover - - for name, source in aliases.items(): - namespace[name] = import_name(name, source, namespace) - - for name, source in deprecated_aliases.items(): - namespace[name] = import_name(name, source, namespace) + def __getattr__(name: str) -> Any: + assert aliases is not None # mypy cannot figure this out + try: + source = aliases[name] + except KeyError: + pass + else: + return import_name(name, source, namespace) + + assert deprecated_aliases is not None # mypy cannot figure this out + try: + source = deprecated_aliases[name] + except KeyError: + pass + else: + warnings.warn( + f"{package}.{name} is deprecated", + DeprecationWarning, + stacklevel=2, + ) + return import_name(name, source, namespace) + + raise AttributeError(f"module {package!r} has no attribute {name!r}") + + namespace["__getattr__"] = __getattr__ + + def __dir__() -> Iterable[str]: + return sorted(namespace_set | aliases_set | deprecated_aliases_set) + + namespace["__dir__"] = __dir__ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index e693bbd2f..1daf3a9ad 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -760,12 +760,7 @@ def is_serving(self) -> bool: Tell whether the server is accepting new connections or shutting down. """ - try: - # Python ≥ 3.7 - return self.server.is_serving() - except AttributeError: # pragma: no cover - # Python < 3.7 - return self.server.sockets is not None + return self.server.is_serving() def close(self) -> None: """ @@ -815,7 +810,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [ - asyncio.ensure_future(websocket.close(1001)) + asyncio.create_task(websocket.close(1001)) for websocket in self.websockets ], loop=self.loop if sys.version_info[:2] < (3, 8) else None, diff --git a/src/websockets/typing.py b/src/websockets/typing.py index ca66a8c54..630a9fbe3 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -4,19 +4,13 @@ __all__ = ["Data", "Origin", "ExtensionHeader", "ExtensionParameter", "Subprotocol"] Data = Union[str, bytes] - -Data__doc__ = """ +Data.__doc__ = """ Types supported in a WebSocket message: - :class:`str` for text messages - :class:`bytes` for binary messages """ -# Remove try / except when dropping support for Python < 3.7 -try: - Data.__doc__ = Data__doc__ -except AttributeError: # pragma: no cover - pass Origin = NewType("Origin", str) @@ -28,19 +22,11 @@ ExtensionParameter = Tuple[str, Optional[str]] -ExtensionParameter__doc__ = """Parameter of a WebSocket extension""" -try: - ExtensionParameter.__doc__ = ExtensionParameter__doc__ -except AttributeError: # pragma: no cover - pass +ExtensionParameter.__doc__ = """Parameter of a WebSocket extension""" ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] -ExtensionHeader__doc__ = """Extension in a Sec-WebSocket-Extensions header""" -try: - ExtensionHeader.__doc__ = ExtensionHeader__doc__ -except AttributeError: # pragma: no cover - pass +ExtensionHeader.__doc__ = """Extension in a Sec-WebSocket-Extensions header""" Subprotocol = NewType("Subprotocol", str) diff --git a/tox.ini b/tox.ini index b5488e5b0..e74c979ba 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36,py37,py38,py39,coverage,black,flake8,isort,mypy +envlist = py37,py38,py39,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} From ab6f4382ec1cd122b3a515601d413a8f247ea79e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 09:42:50 +0200 Subject: [PATCH 0773/1539] Remove use of get_event_loop. Refs #916 - get_event_loop is deprecated in Python 3.10. Fix #534. --- README.rst | 15 +++++++-------- compliance/test_client.py | 3 +-- compliance/test_server.py | 14 ++++++++------ docs/faq.rst | 3 ++- docs/heroku.rst | 7 ++++--- example/basic_auth_client.py | 2 +- example/basic_auth_server.py | 17 +++++++++-------- example/client.py | 2 +- example/counter.py | 7 +++---- example/echo.py | 7 ++++--- example/health_check_server.py | 11 ++++++----- example/hello.py | 2 +- example/secure_client.py | 2 +- example/secure_server.py | 11 ++++++----- example/server.py | 7 ++++--- example/show_time.py | 7 ++++--- example/shutdown_client.py | 4 ++-- example/shutdown_server.py | 15 ++++++--------- example/unix_client.py | 2 +- example/unix_server.py | 9 +++++---- performance/mem_client.py | 2 +- performance/mem_server.py | 17 ++++++++--------- 22 files changed, 85 insertions(+), 81 deletions(-) diff --git a/README.rst b/README.rst index 7a163f5b7..5db7d9d0b 100644 --- a/README.rst +++ b/README.rst @@ -45,15 +45,14 @@ Here's how a client sends and receives messages: #!/usr/bin/env python import asyncio - import websockets + from websockets import connect async def hello(uri): - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: await websocket.send("Hello world!") await websocket.recv() - asyncio.get_event_loop().run_until_complete( - hello('ws://localhost:8765')) + asyncio.run(hello('ws://localhost:8765')) And here's an echo server: @@ -62,15 +61,15 @@ And here's an echo server: #!/usr/bin/env python import asyncio - import websockets + from websockets import serve async def echo(websocket, path): async for message in websocket: await websocket.send(message) - asyncio.get_event_loop().run_until_complete( - websockets.serve(echo, 'localhost', 8765)) - asyncio.get_event_loop().run_forever() + async def main(): + async with serve(echo, 'localhost', 8765): + await asyncio.Future() # run forever Does that look good? diff --git a/compliance/test_client.py b/compliance/test_client.py index 5fd0f4b4f..1ed4d711e 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -45,5 +45,4 @@ async def run_tests(server, agent): await update_reports(server, agent) -main = run_tests(SERVER, urllib.parse.quote(AGENT)) -asyncio.get_event_loop().run_until_complete(main) +asyncio.run(run_tests(SERVER, urllib.parse.quote(AGENT))) diff --git a/compliance/test_server.py b/compliance/test_server.py index 8020f68d3..58e357bc7 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -18,10 +18,12 @@ async def echo(ws, path): await ws.send(msg) -start_server = websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1) +async def main(): + with websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): + try: + await asyncio.Future() + except KeyboardInterrupt: + pass -try: - asyncio.get_event_loop().run_until_complete(start_server) - asyncio.get_event_loop().run_forever() -except KeyboardInterrupt: - pass + +asyncio.run(main) diff --git a/docs/faq.rst b/docs/faq.rst index ff91105b4..20b74ba98 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -274,7 +274,8 @@ Can I use ``websockets`` synchronously, without ``async`` / ``await``? ...................................................................... You can convert every asynchronous call to a synchronous call by wrapping it -in ``asyncio.get_event_loop().run_until_complete(...)``. +in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this +is deprecated as of Python 3.10. If this turns out to be impractical, you should use another library. diff --git a/docs/heroku.rst b/docs/heroku.rst index 31c4b3f19..5ef4b120e 100644 --- a/docs/heroku.rst +++ b/docs/heroku.rst @@ -52,10 +52,11 @@ Here's the implementation of the app, an echo server. Save it in a file called async for message in websocket: await websocket.send(message) - start_server = websockets.serve(echo, "", int(os.environ["PORT"])) + async def main(): + async with websockets.serve(echo, "", int(os.environ["PORT"])): + await asyncio.Future() # run forever - asyncio.get_event_loop().run_until_complete(start_server) - asyncio.get_event_loop().run_forever() + asyncio.run(main) The server relies on the ``$PORT`` environment variable to tell on which port it will listen, according to Heroku's conventions. diff --git a/example/basic_auth_client.py b/example/basic_auth_client.py index cc94dbe4b..164732152 100755 --- a/example/basic_auth_client.py +++ b/example/basic_auth_client.py @@ -11,4 +11,4 @@ async def hello(): greeting = await websocket.recv() print(greeting) -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/basic_auth_server.py b/example/basic_auth_server.py index 6740d5798..a11910445 100755 --- a/example/basic_auth_server.py +++ b/example/basic_auth_server.py @@ -9,12 +9,13 @@ async def hello(websocket, path): greeting = f"Hello {websocket.username}!" await websocket.send(greeting) -start_server = websockets.serve( - hello, "localhost", 8765, - create_protocol=websockets.basic_auth_protocol_factory( - realm="example", credentials=("mary", "p@ssw0rd") - ), -) +async def main(): + async with websockets.serve( + hello, "localhost", 8765, + create_protocol=websockets.basic_auth_protocol_factory( + realm="example", credentials=("mary", "p@ssw0rd") + ), + ): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/client.py b/example/client.py index 4f969c478..e39df81f7 100755 --- a/example/client.py +++ b/example/client.py @@ -16,4 +16,4 @@ async def hello(): greeting = await websocket.recv() print(f"< {greeting}") -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/counter.py b/example/counter.py index 239ec203a..81cbdb55c 100755 --- a/example/counter.py +++ b/example/counter.py @@ -63,7 +63,6 @@ async def counter(websocket, path): await unregister(websocket) -start_server = websockets.serve(counter, "localhost", 6789) - -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +async def main(): + async with websockets.serve(counter, "localhost", 6789): + await asyncio.Future() # run forever diff --git a/example/echo.py b/example/echo.py index b7ca38d32..b285f1664 100755 --- a/example/echo.py +++ b/example/echo.py @@ -7,7 +7,8 @@ async def echo(websocket, path): async for message in websocket: await websocket.send(message) -start_server = websockets.serve(echo, "localhost", 8765) +async def main(): + async with websockets.serve(echo, "localhost", 8765): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/health_check_server.py b/example/health_check_server.py index 417063fce..bb861fad3 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -14,9 +14,10 @@ async def echo(websocket, path): async for message in websocket: await websocket.send(message) -start_server = websockets.serve( - echo, "localhost", 8765, process_request=health_check -) +async def main(): + async with websockets.serve( + echo, "localhost", 8765, process_request=health_check + ): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/hello.py b/example/hello.py index 6c9c839d8..96095dd02 100755 --- a/example/hello.py +++ b/example/hello.py @@ -9,4 +9,4 @@ async def hello(): await websocket.send("Hello world!") await websocket.recv() -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/secure_client.py b/example/secure_client.py index 54971b984..455b6492a 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -24,4 +24,4 @@ async def hello(): greeting = await websocket.recv() print(f"< {greeting}") -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/secure_server.py b/example/secure_server.py index 2a00bdb50..cc1cea876 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -20,9 +20,10 @@ async def hello(websocket, path): localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") ssl_context.load_cert_chain(localhost_pem) -start_server = websockets.serve( - hello, "localhost", 8765, ssl=ssl_context -) +async def main(): + async with websockets.serve( + hello, "localhost", 8765, ssl=ssl_context + ): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/server.py b/example/server.py index c8ab69971..3ad6c8226 100755 --- a/example/server.py +++ b/example/server.py @@ -14,7 +14,8 @@ async def hello(websocket, path): await websocket.send(greeting) print(f"> {greeting}") -start_server = websockets.serve(hello, "localhost", 8765) +async def main(): + async with websockets.serve(hello, "localhost", 8765): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/show_time.py b/example/show_time.py index e5d6ac9aa..2f2ad897a 100755 --- a/example/show_time.py +++ b/example/show_time.py @@ -13,7 +13,8 @@ async def time(websocket, path): await websocket.send(now) await asyncio.sleep(random.random() * 3) -start_server = websockets.serve(time, "127.0.0.1", 5678) +async def main(): + async with websockets.serve(time, "localhost", 5678): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/shutdown_client.py b/example/shutdown_client.py index f21c0f6fa..ba1287801 100755 --- a/example/shutdown_client.py +++ b/example/shutdown_client.py @@ -7,8 +7,8 @@ async def client(): uri = "ws://localhost:8765" async with websockets.connect(uri) as websocket: + loop = asyncio.get_running_loop() # Close the connection when receiving SIGTERM. - loop = asyncio.get_event_loop() loop.add_signal_handler( signal.SIGTERM, loop.create_task, websocket.close()) @@ -16,4 +16,4 @@ async def client(): async for message in websocket: ... -asyncio.get_event_loop().run_until_complete(client()) +asyncio.run(client()) diff --git a/example/shutdown_server.py b/example/shutdown_server.py index 86846abe7..5732313cb 100755 --- a/example/shutdown_server.py +++ b/example/shutdown_server.py @@ -8,15 +8,12 @@ async def echo(websocket, path): async for message in websocket: await websocket.send(message) -async def echo_server(stop): +async def server(): + loop = asyncio.get_running_loop() + stop = loop.create_future() + # Set the stop condition when receiving SIGTERM. + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with websockets.serve(echo, "localhost", 8765): await stop -loop = asyncio.get_event_loop() - -# The stop condition is set when receiving SIGTERM. -stop = loop.create_future() -loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - -# Run the server until the stop condition is met. -loop.run_until_complete(echo_server(stop)) +asyncio.run(server()) diff --git a/example/unix_client.py b/example/unix_client.py index 577135b3d..434638c80 100755 --- a/example/unix_client.py +++ b/example/unix_client.py @@ -16,4 +16,4 @@ async def hello(): greeting = await websocket.recv() print(f"< {greeting}") -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/unix_server.py b/example/unix_server.py index a6ec0168a..80bc824cf 100755 --- a/example/unix_server.py +++ b/example/unix_server.py @@ -15,8 +15,9 @@ async def hello(websocket, path): await websocket.send(greeting) print(f"> {greeting}") -socket_path = os.path.join(os.path.dirname(__file__), "socket") -start_server = websockets.unix_serve(hello, socket_path) +async def main(): + socket_path = os.path.join(os.path.dirname(__file__), "socket") + async with websockets.unix_serve(hello, socket_path): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/performance/mem_client.py b/performance/mem_client.py index 890216edf..6eab690d8 100644 --- a/performance/mem_client.py +++ b/performance/mem_client.py @@ -43,7 +43,7 @@ async def mem_client(client): await asyncio.sleep(CLIENTS * INTERVAL) -asyncio.get_event_loop().run_until_complete( +asyncio.run( asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) ) diff --git a/performance/mem_server.py b/performance/mem_server.py index 0a4a29f76..81490a0e7 100644 --- a/performance/mem_server.py +++ b/performance/mem_server.py @@ -31,7 +31,11 @@ async def handler(ws, path): await asyncio.sleep(CLIENTS * INTERVAL) -async def mem_server(stop): +async def mem_server(): + loop = asyncio.get_running_loop() + stop = loop.create_future() + # Set the stop condition when receiving SIGTERM. + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with websockets.serve( handler, "localhost", @@ -44,19 +48,14 @@ async def mem_server(stop): ) ], ): + tracemalloc.start() await stop -loop = asyncio.get_event_loop() +asyncio.run(mem_server()) -stop = loop.create_future() -loop.add_signal_handler(signal.SIGINT, stop.set_result, None) -tracemalloc.start() - -loop.run_until_complete(mem_server(stop)) - -# First connection incurs non-representative setup costs. +# First connection may incur non-representative setup costs. del MEM_SIZE[0] print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB") From 08d8011132ba038b3f6c4d591189b57af4c9f147 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 09:19:16 +0200 Subject: [PATCH 0774/1539] Add support for Python 3.10. --- compliance/test_server.py | 2 +- docs/api/client.rst | 6 +- docs/api/server.rst | 6 +- docs/changelog.rst | 8 + docs/heroku.rst | 2 +- example/basic_auth_server.py | 2 +- example/echo.py | 2 +- example/health_check_server.py | 2 +- example/secure_server.py | 2 +- example/server.py | 2 +- example/show_time.py | 2 +- example/unix_server.py | 2 +- setup.cfg | 2 +- setup.py | 1 + src/websockets/exceptions.py | 2 +- src/websockets/legacy/client.py | 8 +- src/websockets/legacy/compatibility.py | 22 ++ src/websockets/legacy/protocol.py | 30 ++- src/websockets/legacy/server.py | 19 +- tests/legacy/test_client_server.py | 272 ++++++++++++------------- tests/legacy/test_protocol.py | 41 ++-- tests/legacy/utils.py | 11 +- tox.ini | 2 +- 23 files changed, 235 insertions(+), 213 deletions(-) create mode 100644 src/websockets/legacy/compatibility.py diff --git a/compliance/test_server.py b/compliance/test_server.py index 58e357bc7..14ac90fe6 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -26,4 +26,4 @@ async def main(): pass -asyncio.run(main) +asyncio.run(main()) diff --git a/docs/api/client.rst b/docs/api/client.rst index f969227a9..af341b2ba 100644 --- a/docs/api/client.rst +++ b/docs/api/client.rst @@ -6,16 +6,16 @@ Client Opening a connection -------------------- - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) :async: - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) :async: Using a connection ------------------ - .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None) .. autoattribute:: local_address diff --git a/docs/api/server.rst b/docs/api/server.rst index 16c8f6359..849f03fab 100644 --- a/docs/api/server.rst +++ b/docs/api/server.rst @@ -6,10 +6,10 @@ Server Starting a server ----------------- - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) :async: - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) :async: Stopping a server @@ -25,7 +25,7 @@ Server Using a connection ------------------ - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) .. autoattribute:: local_address diff --git a/docs/changelog.rst b/docs/changelog.rst index 35f292e49..70102cade 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -34,6 +34,14 @@ They may change at any time. **Version 10.0 drops compatibility with Python 3.6.** +.. note:: + + **Version 10.0 deprecates the** ``loop`` **parameter from all APIs for the + same reasons the same change was made in Python 3.8. See the release notes + of Python 3.10 for details.** + +* Added compatibility with Python 3.10. + 9.0.1 ..... diff --git a/docs/heroku.rst b/docs/heroku.rst index 5ef4b120e..8af2ebd3d 100644 --- a/docs/heroku.rst +++ b/docs/heroku.rst @@ -56,7 +56,7 @@ Here's the implementation of the app, an echo server. Save it in a file called async with websockets.serve(echo, "", int(os.environ["PORT"])): await asyncio.Future() # run forever - asyncio.run(main) + asyncio.run(main()) The server relies on the ``$PORT`` environment variable to tell on which port it will listen, according to Heroku's conventions. diff --git a/example/basic_auth_server.py b/example/basic_auth_server.py index a11910445..532c5bc51 100755 --- a/example/basic_auth_server.py +++ b/example/basic_auth_server.py @@ -18,4 +18,4 @@ async def main(): ): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/echo.py b/example/echo.py index b285f1664..024f8d8ac 100755 --- a/example/echo.py +++ b/example/echo.py @@ -11,4 +11,4 @@ async def main(): async with websockets.serve(echo, "localhost", 8765): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/health_check_server.py b/example/health_check_server.py index bb861fad3..2565f9c48 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -20,4 +20,4 @@ async def main(): ): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/secure_server.py b/example/secure_server.py index cc1cea876..55b5a4231 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -26,4 +26,4 @@ async def main(): ): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/server.py b/example/server.py index 3ad6c8226..98dbb5acd 100755 --- a/example/server.py +++ b/example/server.py @@ -18,4 +18,4 @@ async def main(): async with websockets.serve(hello, "localhost", 8765): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/show_time.py b/example/show_time.py index 2f2ad897a..8e39f1776 100755 --- a/example/show_time.py +++ b/example/show_time.py @@ -17,4 +17,4 @@ async def main(): async with websockets.serve(time, "localhost", 5678): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/unix_server.py b/example/unix_server.py index 80bc824cf..223d97301 100755 --- a/example/unix_server.py +++ b/example/unix_server.py @@ -20,4 +20,4 @@ async def main(): async with websockets.unix_serve(hello, socket_path): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/setup.cfg b/setup.cfg index b15a7515c..d8877aa2e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py37.py38.py39 +python-tag = py37.py38.py39.py310 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index d493a12f1..b2d07737d 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', ], package_dir = {'': 'src'}, package_data = {'websockets': ['py.typed']}, diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index e0860c743..9ab9d3ebe 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -314,7 +314,7 @@ def __init__( self.status = status self.headers = Headers(headers) self.body = body - message = f"HTTP {status}, {len(self.headers)} headers, {len(body)} bytes" + message = f"HTTP {status:d}, {len(self.headers)} headers, {len(body)} bytes" super().__init__(message) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 1b5bd303f..b77b4e86d 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -33,6 +33,7 @@ from ..http import USER_AGENT, build_host from ..typing import ExtensionHeader, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri +from .compatibility import asyncio_get_running_loop from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol @@ -472,7 +473,6 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, compression: Optional[str] = "deflate", origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, @@ -503,8 +503,12 @@ def __init__( # Backwards compatibility: recv() used to return None on closed connections legacy_recv: bool = kwargs.pop("legacy_recv", False) + # Backwards compatibility: the loop parameter used to be supported. + loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) if loop is None: - loop = asyncio.get_event_loop() + loop = asyncio_get_running_loop() + else: + warnings.warn("remove loop argument", DeprecationWarning) wsuri = parse_uri(uri) if wsuri.secure: diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py new file mode 100644 index 000000000..86f6715fd --- /dev/null +++ b/src/websockets/legacy/compatibility.py @@ -0,0 +1,22 @@ +import asyncio +import sys +from typing import Any, Dict + + +def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: + """ + Helper for the removal of the loop argument in Python 3.10. + + """ + return {"loop": loop} if sys.version_info[:2] < (3, 8) else {} + + +def asyncio_get_running_loop() -> asyncio.AbstractEventLoop: + """ + Helper for the deprecation of get_event_loop in Python 3.10. + + """ + if sys.version_info[:2] < (3, 10): # pragma: no cover + return asyncio.get_event_loop() + else: # pragma: no cover + return asyncio.get_running_loop() diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c155c3bd7..4e8958b60 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -14,7 +14,6 @@ import logging import random import struct -import sys import warnings from typing import ( Any, @@ -55,6 +54,7 @@ serialize_close, ) from ..typing import Data, Subprotocol +from .compatibility import loop_if_py_lt_38 from .framing import Frame @@ -107,13 +107,13 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, # The following arguments are kept only for backwards compatibility. host: Optional[str] = None, port: Optional[int] = None, secure: Optional[bool] = None, - legacy_recv: bool = False, timeout: Optional[float] = None, + legacy_recv: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: @@ -132,8 +132,8 @@ def __init__( self.read_limit = read_limit self.write_limit = write_limit - if loop is None: - loop = asyncio.get_event_loop() + assert loop is not None + # Remove when dropping Python < 3.10 - use get_running_loop instead. self.loop = loop self._host = host @@ -151,9 +151,7 @@ def __init__( self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._drain_lock = asyncio.Lock( - loop=loop if sys.version_info[:2] < (3, 8) else None - ) + self._drain_lock = asyncio.Lock(**loop_if_py_lt_38(loop)) # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -231,9 +229,7 @@ async def _drain(self) -> None: # pragma: no cover # write(...); yield from drain() # in a loop would never call connection_lost(), so it # would not see an error when the socket is closed. - await asyncio.sleep( - 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) + await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) await self._drain_helper() def connection_open(self) -> None: @@ -403,8 +399,8 @@ async def recv(self) -> Data: # pop_message_waiter and self.transfer_data_task. await asyncio.wait( [pop_message_waiter, self.transfer_data_task], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, return_when=asyncio.FIRST_COMPLETED, + **loop_if_py_lt_38(self.loop), ) finally: self._pop_message_waiter = None @@ -601,7 +597,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: await asyncio.wait_for( self.write_close_frame(serialize_close(code, reason)), self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers @@ -622,7 +618,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: await asyncio.wait_for( self.transfer_data_task, self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1052,7 +1048,7 @@ async def keepalive_ping(self) -> None: while True: await asyncio.sleep( self.ping_interval, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) # ping() raises CancelledError if the connection is closed, @@ -1068,7 +1064,7 @@ async def keepalive_ping(self) -> None: await asyncio.wait_for( pong_waiter, self.ping_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: logger.debug("%s ! timed out waiting for pong", self.side) @@ -1168,7 +1164,7 @@ async def wait_for_connection_lost(self) -> bool: await asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: pass diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 1daf3a9ad..5bd7d0f56 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -10,7 +10,6 @@ import http import logging import socket -import sys import warnings from types import TracebackType from typing import ( @@ -43,6 +42,7 @@ from ..headers import build_extension, parse_extension, parse_subprotocol from ..http import USER_AGENT from ..typing import ExtensionHeader, Origin, Subprotocol +from .compatibility import asyncio_get_running_loop, loop_if_py_lt_38 from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol @@ -798,9 +798,7 @@ async def _close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep( - 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) + await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING conections with a HTTP 503 error. @@ -813,7 +811,7 @@ async def _close(self) -> None: asyncio.create_task(websocket.close(1001)) for websocket in self.websockets ], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) # Wait until all connection handlers are complete. @@ -822,7 +820,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) # Tell wait_closed() to return. @@ -953,7 +951,6 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, compression: Optional[str] = "deflate", origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -990,10 +987,14 @@ def __init__( # Backwards compatibility: recv() used to return None on closed connections legacy_recv: bool = kwargs.pop("legacy_recv", False) + # Backwards compatibility: the loop parameter used to be supported. + loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) if loop is None: - loop = asyncio.get_event_loop() + loop = asyncio_get_running_loop() + else: + warnings.warn("remove loop argument", DeprecationWarning) - ws_server = WebSocketServer(loop) + ws_server = WebSocketServer(loop=loop) secure = kwargs.get("ssl") is not None diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 499ea1d59..80c9d56bb 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -6,6 +6,7 @@ import random import socket import ssl +import sys import tempfile import unittest import unittest.mock @@ -51,7 +52,7 @@ testcert = bytes(pathlib.Path(__file__).parent.with_name("test_localhost.pem")) -async def handler(ws, path): +async def default_handler(ws, path): if path == "/deprecated_attributes": await ws.recv() # delay that allows catching warnings await ws.send(repr((ws.host, ws.port, ws.secure))) @@ -84,9 +85,9 @@ def temp_test_server(test, **kwargs): @contextlib.contextmanager def temp_test_redirecting_server( - test, status, include_location=True, force_insecure=False + test, status, include_location=True, force_insecure=False, **kwargs ): - test.start_redirecting_server(status, include_location, force_insecure) + test.start_redirecting_server(status, include_location, force_insecure, **kwargs) try: yield finally: @@ -206,21 +207,36 @@ def server_context(self): return None def start_server(self, deprecation_warnings=None, **kwargs): + handler = kwargs.pop("handler", default_handler) # Disable compression by default in tests. kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) + # Python 3.10 dislikes not having a running event loop + if sys.version_info[:2] >= (3, 10): # pragma: no cover + kwargs.setdefault("loop", self.loop) with warnings.catch_warnings(record=True) as recorded_warnings: start_server = serve(handler, "localhost", 0, **kwargs) self.server = self.loop.run_until_complete(start_server) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings + if sys.version_info[:2] >= (3, 10): # pragma: no cover + expected_warnings += ["remove loop argument"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_redirecting_server( - self, status, include_location=True, force_insecure=False + self, + status, + include_location=True, + force_insecure=False, + deprecation_warnings=None, + **kwargs, ): + # Python 3.10 dislikes not having a running event loop + if sys.version_info[:2] >= (3, 10): # pragma: no cover + kwargs.setdefault("loop", self.loop) + async def process_request(path, headers): server_uri = get_server_uri(self.server, self.secure, path) if force_insecure: @@ -228,16 +244,23 @@ async def process_request(path, headers): headers = {"Location": server_uri} if include_location else [] return status, headers, b"" - start_server = serve( - handler, - "localhost", - 0, - compression=None, - ping_interval=None, - process_request=process_request, - ssl=self.server_context, - ) - self.redirecting_server = self.loop.run_until_complete(start_server) + with warnings.catch_warnings(record=True) as recorded_warnings: + start_server = serve( + default_handler, + "localhost", + 0, + compression=None, + ping_interval=None, + process_request=process_request, + ssl=self.server_context, + **kwargs, + ) + self.redirecting_server = self.loop.run_until_complete(start_server) + + expected_warnings = [] if deprecation_warnings is None else deprecation_warnings + if sys.version_info[:2] >= (3, 10): # pragma: no cover + expected_warnings += ["remove loop argument"] + self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_client( self, resource_name="/", user_info=None, deprecation_warnings=None, **kwargs @@ -246,6 +269,10 @@ def start_client( kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) + # Python 3.10 dislikes not having a running event loop + if sys.version_info[:2] >= (3, 10): # pragma: no cover + kwargs.setdefault("loop", self.loop) + secure = kwargs.get("ssl") is not None try: server_uri = kwargs.pop("uri") @@ -258,6 +285,8 @@ def start_client( self.client = self.loop.run_until_complete(start_client) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings + if sys.version_info[:2] >= (3, 10): # pragma: no cover + expected_warnings += ["remove loop argument"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): @@ -376,7 +405,12 @@ def test_redirect(self): self.assertEqual(reply, "Hello!") def test_infinite_redirect(self): - with temp_test_redirecting_server(self, http.HTTPStatus.FOUND): + with temp_test_redirecting_server( + self, + http.HTTPStatus.FOUND, + loop=self.loop, + deprecation_warnings=["remove loop argument"], + ): self.server = self.redirecting_server with self.assertRaises(InvalidHandshake): with temp_test_client(self): @@ -385,15 +419,23 @@ def test_infinite_redirect(self): @with_server() def test_redirect_missing_location(self): with temp_test_redirecting_server( - self, http.HTTPStatus.FOUND, include_location=False + self, + http.HTTPStatus.FOUND, + include_location=False, + loop=self.loop, + deprecation_warnings=["remove loop argument"], ): with self.assertRaises(InvalidHeader): with temp_test_client(self): self.fail("Did not raise") # pragma: no cover def test_explicit_event_loop(self): - with self.temp_server(loop=self.loop): - with self.temp_client(loop=self.loop): + with self.temp_server( + loop=self.loop, deprecation_warnings=["remove loop argument"] + ): + with self.temp_client( + loop=self.loop, deprecation_warnings=["remove loop argument"] + ): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") @@ -460,12 +502,19 @@ def test_unix_socket(self): path = bytes(pathlib.Path(temp_dir) / "websockets") # Like self.start_server() but with unix_serve(). - unix_server = unix_serve(handler, path) - self.server = self.loop.run_until_complete(unix_server) + with warnings.catch_warnings(record=True) as recorded_warnings: + unix_server = unix_serve(default_handler, path, loop=self.loop) + self.server = self.loop.run_until_complete(unix_server) + self.assertDeprecationWarnings(recorded_warnings, ["remove loop argument"]) + try: # Like self.start_client() but with unix_connect() - unix_client = unix_connect(path) - self.client = self.loop.run_until_complete(unix_client) + with warnings.catch_warnings(record=True) as recorded_warnings: + unix_client = unix_connect(path, loop=self.loop) + self.client = self.loop.run_until_complete(unix_client) + self.assertDeprecationWarnings( + recorded_warnings, ["remove loop argument"] + ) try: self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) @@ -1214,7 +1263,9 @@ class SecureClientServerTests( @with_server() def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): - connect(get_server_uri(self.server, secure=False), ssl=self.client_context) + self.start_client( + uri=get_server_uri(self.server, secure=False), ssl=self.client_context + ) @with_server() def test_redirect_insecure(self): @@ -1226,107 +1277,59 @@ def test_redirect_insecure(self): self.fail("Did not raise") # pragma: no cover -class ClientServerOriginTests(AsyncioTestCase): +class ClientServerOriginTests(ClientServerTestsMixin, AsyncioTestCase): + @with_server(origins=["http://localhost"]) + @with_client(origin="http://localhost") def test_checking_origin_succeeds(self): - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=["http://localhost"]) - ) - client = self.loop.run_until_complete( - connect(get_server_uri(server), origin="http://localhost") - ) - - self.loop.run_until_complete(client.send("Hello!")) - self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") - - self.loop.run_until_complete(client.close()) - server.close() - self.loop.run_until_complete(server.wait_closed()) + self.loop.run_until_complete(self.client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") + @with_server(origins=["http://localhost"]) def test_checking_origin_fails(self): - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=["http://localhost"]) - ) with self.assertRaisesRegex( InvalidHandshake, "server rejected WebSocket connection: HTTP 403" ): - self.loop.run_until_complete( - connect(get_server_uri(server), origin="http://otherhost") - ) - - server.close() - self.loop.run_until_complete(server.wait_closed()) + self.start_client(origin="http://otherhost") + @with_server(origins=["http://localhost"]) def test_checking_origins_fails_with_multiple_headers(self): - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=["http://localhost"]) - ) with self.assertRaisesRegex( InvalidHandshake, "server rejected WebSocket connection: HTTP 400" ): - self.loop.run_until_complete( - connect( - get_server_uri(server), - origin="http://localhost", - extra_headers=[("Origin", "http://otherhost")], - ) + self.start_client( + origin="http://localhost", + extra_headers=[("Origin", "http://otherhost")], ) - server.close() - self.loop.run_until_complete(server.wait_closed()) - + @with_server(origins=[None]) + @with_client() def test_checking_lack_of_origin_succeeds(self): - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=[None]) - ) - client = self.loop.run_until_complete(connect(get_server_uri(server))) - - self.loop.run_until_complete(client.send("Hello!")) - self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") - - self.loop.run_until_complete(client.close()) - server.close() - self.loop.run_until_complete(server.wait_closed()) + self.loop.run_until_complete(self.client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") + @with_server(origins=[""]) + @with_client(deprecation_warnings=["use None instead of '' in origins"]) def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): - with warnings.catch_warnings(record=True) as recorded_warnings: - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=[""]) - ) - client = self.loop.run_until_complete(connect(get_server_uri(server))) - - self.assertDeprecationWarnings( - recorded_warnings, ["use None instead of '' in origins"] - ) - - self.loop.run_until_complete(client.send("Hello!")) - self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") - - self.loop.run_until_complete(client.close()) - server.close() - self.loop.run_until_complete(server.wait_closed()) + self.loop.run_until_complete(self.client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") -class YieldFromTests(AsyncioTestCase): +class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): + @with_server() def test_client(self): - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - # @asyncio.coroutine is deprecated on Python ≥ 3.8 with warnings.catch_warnings(record=True): @asyncio.coroutine def run_client(): # Yield from connect. - client = yield from connect(get_server_uri(server)) + client = yield from connect(get_server_uri(self.server)) self.assertEqual(client.state, State.OPEN) yield from client.close() self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) - server.close() - self.loop.run_until_complete(server.wait_closed()) - def test_server(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 with warnings.catch_warnings(record=True): @@ -1334,7 +1337,7 @@ def test_server(self): @asyncio.coroutine def run_server(): # Yield from serve. - server = yield from serve(handler, "localhost", 0) + server = yield from serve(default_handler, "localhost", 0) self.assertTrue(server.sockets) server.close() yield from server.wait_closed() @@ -1343,27 +1346,22 @@ def run_server(): self.loop.run_until_complete(run_server()) -class AsyncAwaitTests(AsyncioTestCase): +class AsyncAwaitTests(ClientServerTestsMixin, AsyncioTestCase): + @with_server() def test_client(self): - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - async def run_client(): # Await connect. - client = await connect(get_server_uri(server)) + client = await connect(get_server_uri(self.server)) self.assertEqual(client.state, State.OPEN) await client.close() self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) - server.close() - self.loop.run_until_complete(server.wait_closed()) - def test_server(self): async def run_server(): # Await serve. - server = await serve(handler, "localhost", 0) + server = await serve(default_handler, "localhost", 0) self.assertTrue(server.sockets) server.close() await server.wait_closed() @@ -1372,14 +1370,12 @@ async def run_server(): self.loop.run_until_complete(run_server()) -class ContextManagerTests(AsyncioTestCase): +class ContextManagerTests(ClientServerTestsMixin, AsyncioTestCase): + @with_server() def test_client(self): - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - async def run_client(): # Use connect as an asynchronous context manager. - async with connect(get_server_uri(server)) as client: + async with connect(get_server_uri(self.server)) as client: self.assertEqual(client.state, State.OPEN) # Check that exiting the context manager closed the connection. @@ -1387,13 +1383,10 @@ async def run_client(): self.loop.run_until_complete(run_client()) - server.close() - self.loop.run_until_complete(server.wait_closed()) - def test_server(self): async def run_server(): # Use serve as an asynchronous context manager. - async with serve(handler, "localhost", 0) as server: + async with serve(default_handler, "localhost", 0) as server: self.assertTrue(server.sockets) # Check that exiting the context manager closed the server. @@ -1404,7 +1397,7 @@ async def run_server(): @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") def test_unix_server(self): async def run_server(path): - async with unix_serve(handler, path) as server: + async with unix_serve(default_handler, path) as server: self.assertTrue(server.sockets) # Check that exiting the context manager closed the server. @@ -1415,26 +1408,24 @@ async def run_server(path): self.loop.run_until_complete(run_server(path)) -class AsyncIteratorTests(AsyncioTestCase): +class AsyncIteratorTests(ClientServerTestsMixin, AsyncioTestCase): # This is a protocol-level feature, but since it's a high-level API, it is # much easier to exercise at the client or server level. MESSAGES = ["3", "2", "1", "Fire!"] - def test_iterate_on_messages(self): - async def handler(ws, path): - for message in self.MESSAGES: - await ws.send(message) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) + async def echo_handler(ws, path): + for message in AsyncIteratorTests.MESSAGES: + await ws.send(message) + @with_server(handler=echo_handler) + def test_iterate_on_messages(self): messages = [] async def run_client(): nonlocal messages - async with connect(get_server_uri(server)) as ws: + async with connect(get_server_uri(self.server)) as ws: async for message in ws: messages.append(message) @@ -1442,23 +1433,18 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - server.close() - self.loop.run_until_complete(server.wait_closed()) + async def echo_handler_1001(ws, path): + for message in AsyncIteratorTests.MESSAGES: + await ws.send(message) + await ws.close(1001) + @with_server(handler=echo_handler_1001) def test_iterate_on_messages_going_away_exit_ok(self): - async def handler(ws, path): - for message in self.MESSAGES: - await ws.send(message) - await ws.close(1001) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - messages = [] async def run_client(): nonlocal messages - async with connect(get_server_uri(server)) as ws: + async with connect(get_server_uri(self.server)) as ws: async for message in ws: messages.append(message) @@ -1466,23 +1452,18 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - server.close() - self.loop.run_until_complete(server.wait_closed()) + async def echo_handler_1011(ws, path): + for message in AsyncIteratorTests.MESSAGES: + await ws.send(message) + await ws.close(1011) + @with_server(handler=echo_handler_1011) def test_iterate_on_messages_internal_error_exit_not_ok(self): - async def handler(ws, path): - for message in self.MESSAGES: - await ws.send(message) - await ws.close(1011) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - messages = [] async def run_client(): nonlocal messages - async with connect(get_server_uri(server)) as ws: + async with connect(get_server_uri(self.server)) as ws: async for message in ws: messages.append(message) @@ -1490,6 +1471,3 @@ async def run_client(): self.loop.run_until_complete(run_client()) self.assertEqual(messages, self.MESSAGES) - - server.close() - self.loop.run_until_complete(server.wait_closed()) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index f928322ca..58444ce5a 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import sys import unittest import unittest.mock import warnings @@ -15,6 +14,7 @@ OP_TEXT, serialize_close, ) +from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame from websockets.legacy.protocol import State, WebSocketCommonProtocol @@ -87,7 +87,7 @@ class CommonTests: def setUp(self): super().setUp() # Disable pings to make it easier to test what frames are sent exactly. - self.protocol = WebSocketCommonProtocol(ping_interval=None) + self.protocol = WebSocketCommonProtocol(ping_interval=None, loop=self.loop) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) @@ -105,9 +105,7 @@ def make_drain_slow(self, delay=MS): original_drain = self.protocol._drain async def delayed_drain(): - await asyncio.sleep( - delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) + await asyncio.sleep(delay, **loop_if_py_lt_38(self.loop)) await original_drain() self.protocol._drain = delayed_drain @@ -312,14 +310,13 @@ def assertCompletesWithin(self, min_time, max_time): def test_timeout_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: - protocol = WebSocketCommonProtocol(timeout=5) + protocol = WebSocketCommonProtocol(timeout=5, loop=self.loop) self.assertEqual(protocol.close_timeout, 5) - self.assertEqual(len(recorded_warnings), 1) - warning = recorded_warnings[0].message - self.assertEqual(str(warning), "rename timeout to close_timeout") - self.assertEqual(type(warning), DeprecationWarning) + self.assertDeprecationWarnings( + recorded_warnings, ["rename timeout to close_timeout"] + ) # Test public attributes. @@ -647,7 +644,14 @@ async def send_concurrent(): await asyncio.sleep(MS) await self.protocol.send(b"tea") - self.loop.run_until_complete(asyncio.gather(send_iterable(), send_concurrent())) + async def run_concurrently(): + await asyncio.gather( + send_iterable(), + send_concurrent(), + ) + + self.loop.run_until_complete(run_concurrently()) + self.assertFramesSent( (False, OP_TEXT, "ca".encode("utf-8")), (False, OP_CONT, "fé".encode("utf-8")), @@ -714,9 +718,14 @@ async def send_concurrent(): await asyncio.sleep(MS) await self.protocol.send(b"tea") - self.loop.run_until_complete( - asyncio.gather(send_async_iterable(), send_concurrent()) - ) + async def run_concurrently(): + await asyncio.gather( + send_async_iterable(), + send_concurrent(), + ) + + self.loop.run_until_complete(run_concurrently()) + self.assertFramesSent( (False, OP_TEXT, "ca".encode("utf-8")), (False, OP_CONT, "fé".encode("utf-8")), @@ -1098,7 +1107,9 @@ def restart_protocol_with_keepalive_ping( self.loop.run_until_complete(self.protocol.close()) # copied from setUp, but enables keepalive pings self.protocol = WebSocketCommonProtocol( - ping_interval=ping_interval, ping_timeout=ping_timeout + ping_interval=ping_interval, + ping_timeout=ping_timeout, + loop=self.loop, ) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 983a91edf..4d4306232 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -74,11 +74,12 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): Check recorded deprecation warnings match a list of expected messages. """ - self.assertEqual(len(recorded_warnings), len(expected_warnings)) - for recorded, expected in zip(recorded_warnings, expected_warnings): - actual = recorded.message - self.assertEqual(str(actual), expected) - self.assertEqual(type(actual), DeprecationWarning) + for recorded in recorded_warnings: + self.assertEqual(type(recorded.message), DeprecationWarning) + self.assertEqual( + set(str(recorded.message) for recorded in recorded_warnings), + set(expected_warnings), + ) # Unit for timeouts. May be increased on slow machines by setting the diff --git a/tox.ini b/tox.ini index e74c979ba..c243e9880 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py37,py38,py39,coverage,black,flake8,isort,mypy +envlist = py37,py38,py39,py310,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} From eb856f2ba31d1154e18db4ca33b0ab9586ae129c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 May 2021 09:21:18 +0200 Subject: [PATCH 0775/1539] Add missing asyncio.run in example. --- README.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 5db7d9d0b..fa42a9be3 100644 --- a/README.rst +++ b/README.rst @@ -52,7 +52,7 @@ Here's how a client sends and receives messages: await websocket.send("Hello world!") await websocket.recv() - asyncio.run(hello('ws://localhost:8765')) + asyncio.run(hello("ws://localhost:8765")) And here's an echo server: @@ -68,9 +68,11 @@ And here's an echo server: await websocket.send(message) async def main(): - async with serve(echo, 'localhost', 8765): + async with serve(echo, "localhost", 8765): await asyncio.Future() # run forever + asyncio.run(main()) + Does that look good? `Get started with the tutorial! `_ From 9834fca95204c517adedca2478a53058ecf72ae3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 May 2021 21:21:37 +0200 Subject: [PATCH 0776/1539] Use relative imports everywhere, for consistency. Fix #946. --- docs/api/extensions.rst | 2 +- docs/extensions.rst | 7 +++---- src/websockets/extensions/__init__.py | 4 ++++ src/websockets/frames.py | 6 +++--- src/websockets/legacy/framing.py | 6 +++--- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/docs/api/extensions.rst b/docs/api/extensions.rst index 635c5c426..71f015bb2 100644 --- a/docs/api/extensions.rst +++ b/docs/api/extensions.rst @@ -13,7 +13,7 @@ Per-Message Deflate Abstract classes ---------------- -.. automodule:: websockets.extensions.base +.. automodule:: websockets.extensions .. autoclass:: Extension :members: diff --git a/docs/extensions.rst b/docs/extensions.rst index 151a7e297..042ed3d9a 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -91,9 +91,8 @@ As a consequence, writing an extension requires implementing several classes: ``websockets`` provides abstract base classes for extension factories and extensions. See the API documentation for details on their methods: -* :class:`~base.ClientExtensionFactory` and - :class:`~base.ServerExtensionFactory` for extension factories, - -* :class:`~base.Extension` for extensions. +* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for + :extension factories, +* :class:`Extension` for extensions. diff --git a/src/websockets/extensions/__init__.py b/src/websockets/extensions/__init__.py index e69de29bb..02838b98a 100644 --- a/src/websockets/extensions/__init__.py +++ b/src/websockets/extensions/__init__.py @@ -0,0 +1,4 @@ +from .base import * + + +__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 71783e176..6e5ef1b73 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -103,7 +103,7 @@ def parse( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> Generator[None, None, "Frame"]: """ Read a WebSocket frame. @@ -172,7 +172,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> bytes: """ Write a WebSocket frame. @@ -338,4 +338,4 @@ def check_close(code: int) -> None: # at the bottom to allow circular import, because Extension depends on Frame -import websockets.extensions.base # isort:skip # noqa +from . import extensions # isort:skip # noqa diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index e41c295dd..627e6922c 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -31,7 +31,7 @@ async def read( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> "Frame": """ Read a WebSocket frame. @@ -102,7 +102,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> None: """ Write a WebSocket frame. @@ -132,4 +132,4 @@ def write( # at the bottom to allow circular import, because Extension depends on Frame -import websockets.extensions.base # isort:skip # noqa +from .. import extensions # isort:skip # noqa From 60e9531eb978b41deee1db8552b27faa749aa515 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 May 2021 23:49:45 +0200 Subject: [PATCH 0777/1539] Document how to use with Django. Fix #890. --- docs/howto/django.rst | 276 +++++++++++++++++++++++++++++++ docs/index.rst | 1 + docs/spelling_wordlist.txt | 28 ++-- example/django/authentication.py | 53 ++++++ example/django/notifications.py | 70 ++++++++ example/django/signals.py | 23 +++ 6 files changed, 439 insertions(+), 12 deletions(-) create mode 100644 docs/howto/django.rst create mode 100644 example/django/authentication.py create mode 100644 example/django/notifications.py create mode 100644 example/django/signals.py diff --git a/docs/howto/django.rst b/docs/howto/django.rst new file mode 100644 index 000000000..fd170c387 --- /dev/null +++ b/docs/howto/django.rst @@ -0,0 +1,276 @@ +Using websockets with Django +============================ + +If you're looking at adding real-time capabilities to a Django project with +WebSocket, you have two main options. + +1. Using Django Channels_, a project adding WebSocket to Django, among other + features. This approach is fully supported by Django. However, it requires + switching to a new deployment architecture. + +2. Deploying a separate WebSocket server next to your Django project. This + technique is well suited when you need to add a small set of real-time + features — maybe a notification service — to a HTTP application. + +.. _Channels: https://channels.readthedocs.io/en/latest/ + +This guide shows how to implement the second technique with websockets. It +assumes familiarity with Django. + +Authenticating connections +-------------------------- + +Since the websockets server will run outside of Django, we need to connect it +to ``django.contrib.auth``. + +Our clients are running in browser. The `WebSocket API`_ doesn't support +setting `custom headers`_ so our options boil down to: + +* HTTP Basic Auth: this seems technically possible but isn't supported by + Firefox (`bug 1229443`_) so browser support is clearly insufficient. +* Sharing cookies: this is technically possible if there's a common parent + domain between the Django server (e.g. api.example.com) and the websockets + server (e.g. ws.example.com). However, there's a risk to share cookies too + widely (e.g. with anything under .example.com here). For authentication + cookies, this risk seems unacceptable. +* Sending an authentication ticket: Django generates a secure single-use token + with the user ID. The browser includes this token in the WebSocket URI when + it connects to the server in order to authenticate. It could also send the + ticket over the WebSocket connection in the first message, however this is a + bit more difficult to monitor, as you can't detect authentication failures + simply by looking at HTTP response codes. + +.. _custom headers: https://github.com/whatwg/html/issues/3062 +.. _WebSocket API: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API +.. _bug 1229443: https://bugzilla.mozilla.org/show_bug.cgi?id=1229443 + +To generate our authentication tokens, we'll use `django-sesame`_, a small +library designed exactly for this purpose. + + .. _django-sesame: https://github.com/aaugustin/django-sesame + +Add django-sesame to the dependencies of your Django project, install it and configure it in the settings of the project: + +.. code:: python + + AUTHENTICATION_BACKENDS = [ + "django.contrib.auth.backends.ModelBackend", + "sesame.backends.ModelBackend", + ] + +(If your project already uses another authentication backend than the default +``"django.contrib.auth.backends.ModelBackend"``, adjust accordingly.) + +You don't need ``"sesame.middleware.AuthenticationMiddleware"``. It is for +authenticating users in the Django server, while we're authenticating them in +the websockets server. + +We'd like our tokens to be valid for 30 seconds and usable only once. A +shorter lifespan is possible but it would make manual testing difficult. + +Configure django-sesame accordingly in the settings of your Django project: + +.. code:: python + + SESAME_MAX_AGE = 30 + SESAME_ONE_TIME = True + +Now you can generate tokens in a ``django-admin shell`` as follows: + +.. code:: pycon + + >>> from django.contrib.auth import get_user_model + >>> User = get_user_model() + >>> user = User.objects.get(username="") + >>> from sesame.utils import get_token + >>> get_token(user) + '' + +Keep this console open: since tokens are single use, we'll have to generate a +new token every time we want to test our server. + +Let's move on to the websockets server. Add websockets to the dependencies of +your Django project and install it. + +Now here's a way to implement authentication. We're taking advantage of the +:meth:`~websockets.server.WebSocketServerProtocol.process_request` hook to +authenticate requests. If authentication succeeds, we store the user as an +attribute of the connection in order to make it available to the connection +handler. If authentication fails, we return a HTTP 401 Unauthorized error. + +.. literalinclude:: ../../example/django/authentication.py + +Let's unpack this code. + +We're using Django in a `standalone script`_. This requires setting the +``DJANGO_SETTINGS_MODULE`` environment variable and calling ``django.setup()`` +before doing anything with Django. + +.. _standalone script: https://docs.djangoproject.com/en/stable/topics/settings/#calling-django-setup-is-required-for-standalone-django-usage + +We subclass :class:`~websockets.server.WebSocketServerProtocol` and override +:meth:`~websockets.server.WebSocketServerProtocol.process_request`, where: + +* We extract the token from the URL with the ``get_sesame()`` utility function + defined just above. If the token is missing, we return a HTTP 401 error. +* We authenticate the user with ``get_user()``, the API for `authentication + outside views`_. If authentication fails, we return a HTTP 401 error. + +.. _authentication outside views: https://github.com/aaugustin/django-sesame#authentication-outside-views + +When we call an API that makes a database query, we wrap the call in +:func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support +asynchronous I/O. We would block the event loop if we didn't run these calls +in a separate thread. :func:`~asyncio.to_thread` is available since Python +3.9; in earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. + +The connection handler accesses the logged-in user that we stored as an +attribute of the connection object so we can test that authentication works. + +Finally, we start a server with :func:`~websockets.serve`, with the +``create_protocol`` pointing to our subclass of +:class:`~websockets.server.WebSocketServerProtocol`. + +We're ready to test! + +Make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set to the +Python path to your settings module and start the websockets server. If you +saved the server implementation to a file called ``authentication.py``: + +.. code:: console + + $ python authentication.py + +Open a new shell, generate a new token — remember, they're only valid for +30 seconds — and use it to connect to your server: + +.. code:: console + + $ python -m websockets "ws://localhost:8888/?sesame=" + Connected to ws://localhost:8888/?sesame= + < Hello ! + Connection closed: code = 1000 (OK), no reason. + +It works! + +If we try to reuse the same token, the connection is now rejected: + +.. code:: console + + $ python -m websockets "ws://localhost:8888/?sesame=" + Failed to connect to ws://localhost:8888/?sesame=: + server rejected WebSocket connection: HTTP 401. + +You can also test from a browser by generating a new token and running the +following code in the JavaScript console of the browser: + +.. code:: javascript + + webSocket = new WebSocket("ws://localhost:8888/?sesame="); + webSocket.onmessage = (event) => console.log(event.data); + +Streaming events +---------------- + +We can connect and authenticate but our server doesn't do anything useful yet! + +Let's send a message every time any user makes any action in the admin. This +message will be broadcast to all users who can access the model on which the +action was made. This may be used for showing notifications to other users. + +Many use cases for WebSocket with Django follow a similar pattern. + +We need a event bus to enable communications between Django and websockets. +Both sides connect permanently to the bus. Then Django writes events and +websockets reads them. For the sake of simplicity, we'll rely on `Redis +Pub/Sub`_. + +.. _Redis Pub/Sub: https://redis.io/topics/pubsub + +Let's start by writing events. The easiest way to add Redis to a Django +project is by configuring a cache backend with `django-redis`_. This library +manages connections to Redis efficiently, persisting them between requests, +and provides an API to access the Redis connection directly. + +.. _django-redis: https://github.com/jazzband/django-redis + +Install Redis, add django-redis to the dependencies of your Django project, +install it and configure it in the settings of the project: + +.. code:: python + + CACHES = { + "default": { + "BACKEND": "django_redis.cache.RedisCache", + "LOCATION": "redis://127.0.0.1:6379/1", + }, + } + +If you already have a default cache, add a new one with a different name and +change ``get_redis_connection("default")`` in the code below to the same name. + +Add the following code to a module that is imported when your Django project +starts. Typically, you would put it in a ``signals.py`` module, which you +would import in the ``AppConfig.ready()`` method of one of your apps: + +.. literalinclude:: ../../example/django/signals.py + +This code runs every time the admin saves a ``LogEntry`` object to keep track +of a change. It extracts interesting data, serializes it to JSON, and writes +an event to Redis. + +Let's check that it works: + +.. code:: console + + $ redis-cli + 127.0.0.1:6379> SELECT 1 + OK + 127.0.0.1:6379[1]> SUBSCRIBE events + Reading messages... (press Ctrl-C to quit) + 1) "subscribe" + 2) "events" + 3) (integer) 1 + +Leave this command running, start the Django development server and make +changes in the admin: add, modify, or delete objects. You should see +corresponding events published to the ``"events"`` stream. + +Now let's turn to reading events and broadcasting them to connected clients. +We'll reuse our custom ``ServerProtocol`` class for authentication. Then we +need to add several features: + +* Keep track of connected clients so we can broadcast messages. +* Tell which content types the user has permission to view or to change. +* Connect to the message bus and read events. +* Broadcast these events to users who have corresponding permissions. + +Here's a complete implementation. + +.. literalinclude:: ../../example/django/notifications.py + +Since the ``get_content_types()`` function makes a database query, it is +wrapped inside ``asyncio.to_thread()``. It runs once when each WebSocket +connection is open; then its result is cached for the lifetime of the +connection. Indeed, running it for each message would trigger database queries +for all connected users at the same time, which could hurt the database. + +The connection handler merely registers the connection in a global variable, +associated to the list of content types for which events should be sent to +that connection, and waits until the client disconnects. + +The ``process_events()`` function reads events from Redis and broadcasts them +to all connections that should receive them. We don't care much if a sending a +notification fails — this happens when a connection drops between the moment +we iterate on connections and the moment the corresponding message is sent — +so we start a task with for each message and forget about it. Also, this means +we're immediately ready to process the next event, even if it takes time to +send a message to a slow client. + +Since Redis can publish a message to multiple subscribers, multiple instances +of this server can safely run in parallel. + +In theory, given enough servers, this design can scale to a hundred million +clients, since Redis can handle ten thousand servers and each server can +handle ten thousand clients. In practice, you would need a more scalable +message bus before reaching that scale, due to the volume of messages. diff --git a/docs/index.rst b/docs/index.rst index f0c5f8d00..257a806ea 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -59,6 +59,7 @@ These guides will help you build and deploy a ``websockets`` application. cheatsheet deployment extensions + howto/django heroku Reference diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index d7c744147..e168035c1 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1,16 +1,16 @@ +api attr augustin -Auth +auth awaitable aymeric +backend backpressure -Backpressure balancer balancers -Bitcoin +bitcoin bottlenecked bufferbloat -Bufferbloat bugfix bytestring bytestrings @@ -19,38 +19,42 @@ coroutine coroutines cryptocurrencies cryptocurrency -Ctrl +ctrl daemonize datastructures +django fractalideas IPv iterable keepalive KiB lifecycle -Lifecycle lookups MiB nginx +onmessage parsers permessage pong pongs -Pythonic +pythonic +redis +scalable serializers -Subclasses subclasses subclassing subprotocol subprotocols -Tidelift -TLS +tidelift +tls tox -Unparse +unparse unregister uple username virtualenv -websocket WebSocket +websocket websockets +ws +wss diff --git a/example/django/authentication.py b/example/django/authentication.py new file mode 100644 index 000000000..c0a061109 --- /dev/null +++ b/example/django/authentication.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +import asyncio +import http +import urllib.parse + +import django +import websockets + +django.setup() + +from sesame.utils import get_user + + +def get_sesame(path): + """Utility function to extract sesame token from request path.""" + query = urllib.parse.urlparse(path).query + params = urllib.parse.parse_qs(query) + sesame = params.get("sesame", []) + if len(sesame) == 1: + return sesame[0] + + +class ServerProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + """Authenticate users with a django-sesame token.""" + sesame = get_sesame(path) + if sesame is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = await asyncio.to_thread(get_user, sesame) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + +async def handler(websocket, path): + await websocket.send(f"Hello {websocket.user}!") + + +async def main(): + async with websockets.serve( + handler, + "localhost", + 8888, + create_protocol=ServerProtocol, + ): + await asyncio.Future() # run forever + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/django/notifications.py b/example/django/notifications.py new file mode 100644 index 000000000..ad2751d98 --- /dev/null +++ b/example/django/notifications.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +import asyncio +import json + +import aioredis +import django +import websockets + +django.setup() + +from django.contrib.contenttypes.models import ContentType + +# Reuse our custom protocol to authenticate connections +from authentication import ServerProtocol + + +CONNECTIONS = {} + + +def get_content_types(user): + """Return the set of IDs of content types visible by user.""" + # This does only three database queries because Django caches + # all permissions on the first call to user.has_perm(...). + return { + ct.id + for ct in ContentType.objects.all() + if user.has_perm(f"{ct.app_label}.view_{ct.model}") + or user.has_perm(f"{ct.app_label}.change_{ct.model}") + } + + +async def handler(websocket, path): + """Register connection in CONNECTIONS dict, until it's closed.""" + ct_ids = await asyncio.to_thread(get_content_types, websocket.user) + CONNECTIONS[websocket] = {"content_type_ids": ct_ids} + try: + await websocket.wait_closed() + finally: + del CONNECTIONS[websocket] + + +async def process_events(): + """Listen to events in Redis and process them.""" + redis = aioredis.from_url("redis://127.0.0.1:6379/1") + pubsub = redis.pubsub() + await pubsub.subscribe("events") + async for message in pubsub.listen(): + if message["type"] != "message": + continue + payload = message["data"].decode() + # Broadcast event to all users who have permissions to see it. + event = json.loads(payload) + for websocket, connection in CONNECTIONS.items(): + if event["content_type_id"] in connection["content_type_ids"]: + asyncio.create_task(websocket.send(payload)) + + +async def main(): + async with websockets.serve( + handler, + "localhost", + 8888, + create_protocol=ServerProtocol, + ): + await process_events() # runs forever + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/django/signals.py b/example/django/signals.py new file mode 100644 index 000000000..6dc827f72 --- /dev/null +++ b/example/django/signals.py @@ -0,0 +1,23 @@ +import json + +from django.contrib.admin.models import LogEntry +from django.db.models.signals import post_save +from django.dispatch import receiver + +from django_redis import get_redis_connection + + +@receiver(post_save, sender=LogEntry) +def publish_event(instance, **kwargs): + event = { + "model": instance.content_type.name, + "object": instance.object_repr, + "message": instance.get_change_message(), + "timestamp": instance.action_time.isoformat(), + "user": str(instance.user), + "content_type_id": instance.content_type_id, + "object_id": instance.object_id, + } + connection = get_redis_connection("default") + payload = json.dumps(event) + connection.publish("events", payload) From cceff406270c47037357a5adab9665bc05b3ab15 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 May 2021 08:19:38 +0200 Subject: [PATCH 0778/1539] Add FAQ about connection handler terminating early. Ref #948. --- docs/faq.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/faq.rst b/docs/faq.rst index 20b74ba98..8173f7b61 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -14,6 +14,23 @@ FAQ Server side ----------- +Why does my server close the connection prematurely? +.................................................... + +Your connection handler exits prematurely. Wait for the work to be finished +before returning. + +For example, if your handler has a structure similar to:: + + async def handler(websocket, path): + ... + asyncio.create_task(do_some_work()) + +change it to:: + + async def handler(websocket, path): + await do_some_work() + Why does the server close the connection after processing one message? ...................................................................... From aa79b6c057c5f11c85b0ac0f932e0abf59e8033c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 May 2021 08:22:03 +0200 Subject: [PATCH 0779/1539] Add FAQ about context manager terminating early. Ref #947. --- docs/faq.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/faq.rst b/docs/faq.rst index 8173f7b61..5f70de43a 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -144,6 +144,22 @@ If you need more, pick a HTTP server and run it separately. Client side ----------- +Why does my client close the connection prematurely? +.................................................... + +You're exiting the context manager prematurely. Wait for the work to be +finished before exiting. + +For example, if your code has a structure similar to:: + + async with connect(...) as websocket: + asyncio.create_task(do_some_work()) + +change it to:: + + async with connect(...) as websocket: + await do_some_work() + How do I close a connection properly? ..................................... From 088c59bb895f24e46e6606d702376c9e7a229d29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 May 2021 08:26:15 +0200 Subject: [PATCH 0780/1539] Update FAQ answer on Python 2. --- docs/faq.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/faq.rst b/docs/faq.rst index 5f70de43a..53da0f004 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -346,6 +346,8 @@ Is there a Python 2 version? No, there isn't. -websockets builds upon asyncio which requires Python 3. +Python 2 reached end of life on January 1st, 2020. + +Before that date, websockets required asyncio and therefore Python 3. From 217ac2d19174c6f01d9524648eb4058985f72754 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 13 May 2021 22:37:31 +0200 Subject: [PATCH 0781/1539] Fix broken link. Fix #953. --- docs/design.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/design.rst b/docs/design.rst index 0cabc2e5d..61b42b528 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -13,7 +13,7 @@ wish to understand what happens under the hood. Internals described in this document may change at any time. - Backwards compatibility is only guaranteed for `public APIs `_. + Backwards compatibility is only guaranteed for :doc:`public APIs `. Lifecycle From 70fadbf97c5a117ca13f6c8f4f111ba5025f3c94 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 13 May 2021 22:41:46 +0200 Subject: [PATCH 0782/1539] Restore compatibility with Python < 3.9. Fix #951. --- docs/changelog.rst | 7 +++++++ src/websockets/__main__.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 1e5f92211..fb40aee2a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -30,6 +30,13 @@ They may change at any time. *In development* +9.0.2 +..... + +*In development* + +* Restored compatibility of ``python -m websockets`` with Python < 3.9. + 9.0.1 ..... diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index d44e34e74..fb126997a 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -178,11 +178,11 @@ def main() -> None: # Due to zealous removal of the loop parameter in the Queue constructor, # we need a factory coroutine to run in the freshly created event loop. - async def queue_factory() -> asyncio.Queue[str]: + async def queue_factory() -> "asyncio.Queue[str]": return asyncio.Queue() # Create a queue of user inputs. There's no need to limit its size. - inputs: asyncio.Queue[str] = loop.run_until_complete(queue_factory()) + inputs: "asyncio.Queue[str]" = loop.run_until_complete(queue_factory()) # Create a stop condition when receiving SIGINT or SIGTERM. stop: asyncio.Future[None] = loop.create_future() From e44e085e030d186c7bb9822becfbb5423aefe971 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 May 2021 21:21:37 +0200 Subject: [PATCH 0783/1539] Use relative imports everywhere, for consistency. Fix #946. --- docs/api/extensions.rst | 2 +- docs/extensions.rst | 7 +++---- src/websockets/extensions/__init__.py | 4 ++++ src/websockets/frames.py | 6 +++--- src/websockets/legacy/framing.py | 6 +++--- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/docs/api/extensions.rst b/docs/api/extensions.rst index 635c5c426..71f015bb2 100644 --- a/docs/api/extensions.rst +++ b/docs/api/extensions.rst @@ -13,7 +13,7 @@ Per-Message Deflate Abstract classes ---------------- -.. automodule:: websockets.extensions.base +.. automodule:: websockets.extensions .. autoclass:: Extension :members: diff --git a/docs/extensions.rst b/docs/extensions.rst index 151a7e297..042ed3d9a 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -91,9 +91,8 @@ As a consequence, writing an extension requires implementing several classes: ``websockets`` provides abstract base classes for extension factories and extensions. See the API documentation for details on their methods: -* :class:`~base.ClientExtensionFactory` and - :class:`~base.ServerExtensionFactory` for extension factories, - -* :class:`~base.Extension` for extensions. +* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for + :extension factories, +* :class:`Extension` for extensions. diff --git a/src/websockets/extensions/__init__.py b/src/websockets/extensions/__init__.py index e69de29bb..02838b98a 100644 --- a/src/websockets/extensions/__init__.py +++ b/src/websockets/extensions/__init__.py @@ -0,0 +1,4 @@ +from .base import * + + +__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 71783e176..6e5ef1b73 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -103,7 +103,7 @@ def parse( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> Generator[None, None, "Frame"]: """ Read a WebSocket frame. @@ -172,7 +172,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> bytes: """ Write a WebSocket frame. @@ -338,4 +338,4 @@ def check_close(code: int) -> None: # at the bottom to allow circular import, because Extension depends on Frame -import websockets.extensions.base # isort:skip # noqa +from . import extensions # isort:skip # noqa diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index e41c295dd..627e6922c 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -31,7 +31,7 @@ async def read( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> "Frame": """ Read a WebSocket frame. @@ -102,7 +102,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> None: """ Write a WebSocket frame. @@ -132,4 +132,4 @@ def write( # at the bottom to allow circular import, because Extension depends on Frame -import websockets.extensions.base # isort:skip # noqa +from .. import extensions # isort:skip # noqa From b99c4fe390a22cc846ce550a29f2c9841e99660d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 May 2021 21:44:51 +0200 Subject: [PATCH 0784/1539] Restore real imports for compatibility with mypy. Fix #940. --- docs/api/client.rst | 2 +- docs/api/index.rst | 9 +++++++-- docs/api/server.rst | 4 ++-- docs/changelog.rst | 16 ++++++++++++++++ docs/extensions.rst | 5 ++--- src/websockets/auth.py | 2 ++ src/websockets/client.py | 12 +++--------- src/websockets/server.py | 12 ++---------- 8 files changed, 35 insertions(+), 27 deletions(-) create mode 100644 src/websockets/auth.py diff --git a/docs/api/client.rst b/docs/api/client.rst index f969227a9..db8cbc914 100644 --- a/docs/api/client.rst +++ b/docs/api/client.rst @@ -1,7 +1,7 @@ Client ====== -.. automodule:: websockets.legacy.client +.. automodule:: websockets.client Opening a connection -------------------- diff --git a/docs/api/index.rst b/docs/api/index.rst index 20bb740b3..0a616cbce 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -46,5 +46,10 @@ both in the client API and server API. utilities All public APIs can be imported from the :mod:`websockets` package, unless -noted otherwise. Anything that isn't listed in this API documentation is a -private API, with no guarantees of behavior or backwards-compatibility. +noted otherwise. This convenience feature is incompatible with static code +analysis tools such as mypy_, though. + +.. _mypy: https://github.com/python/mypy + +Anything that isn't listed in this API documentation is a private API. There's +no guarantees of behavior or backwards-compatibility for private APIs. diff --git a/docs/api/server.rst b/docs/api/server.rst index 16c8f6359..9e7b801a9 100644 --- a/docs/api/server.rst +++ b/docs/api/server.rst @@ -1,7 +1,7 @@ Server ====== -.. automodule:: websockets.legacy.server +.. automodule:: websockets.server Starting a server ----------------- @@ -90,7 +90,7 @@ Server Basic authentication -------------------- -.. automodule:: websockets.legacy.auth +.. automodule:: websockets.auth .. autofunction:: basic_auth_protocol_factory diff --git a/docs/changelog.rst b/docs/changelog.rst index fb40aee2a..218bbec3d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -37,6 +37,8 @@ They may change at any time. * Restored compatibility of ``python -m websockets`` with Python < 3.9. +* Restored compatibility with mypy. + 9.0.1 ..... @@ -73,6 +75,20 @@ They may change at any time. but that never happened. Keeping these APIs public makes it more difficult to improve websockets for no actual benefit. +.. note:: + + **Version 9.0 may require changes if you use static code analysis tools.** + + Convenience imports from the ``websockets`` module are performed lazily. + While this is supported by Python, static code analysis tools such as mypy + are unable to understand the behavior. + + If you depend on such tools, use the real import path, which can be found + in the API documentation:: + + from websockets.client import connect + from websockets.server import serve + * Added compatibility with Python 3.9. * Added support for IRIs in addition to URIs. diff --git a/docs/extensions.rst b/docs/extensions.rst index 042ed3d9a..f5e2f497f 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -14,9 +14,8 @@ specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. Per-Message Deflate ------------------- -:func:`~websockets.legacy.client.connect` and -:func:`~websockets.legacy.server.serve` enable the Per-Message Deflate -extension by default. +:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable +the Per-Message Deflate extension by default. If you want to disable it, set ``compression=None``:: diff --git a/src/websockets/auth.py b/src/websockets/auth.py new file mode 100644 index 000000000..f97c1feb0 --- /dev/null +++ b/src/websockets/auth.py @@ -0,0 +1,2 @@ +# See #940 for why lazy_import isn't used here for backwards compatibility. +from .legacy.auth import * # noqa diff --git a/src/websockets/client.py b/src/websockets/client.py index 91dd1662e..0ddf19f00 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -24,7 +24,6 @@ ) from .http import USER_AGENT, build_host from .http11 import Request, Response -from .imports import lazy_import from .typing import ( ConnectionOption, ExtensionHeader, @@ -36,14 +35,9 @@ from .utils import accept_key, generate_key -lazy_import( - globals(), - aliases={ - "connect": ".legacy.client", - "unix_connect": ".legacy.client", - "WebSocketClientProtocol": ".legacy.client", - }, -) +# See #940 for why lazy_import isn't used here for backwards compatibility. +from .legacy.client import * # isort:skip # noqa + __all__ = ["ClientConnection"] diff --git a/src/websockets/server.py b/src/websockets/server.py index 67ab83031..f57d36b70 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -26,7 +26,6 @@ ) from .http import USER_AGENT from .http11 import Request, Response -from .imports import lazy_import from .typing import ( ConnectionOption, ExtensionHeader, @@ -37,15 +36,8 @@ from .utils import accept_key -lazy_import( - globals(), - aliases={ - "serve": ".legacy.server", - "unix_serve": ".legacy.server", - "WebSocketServerProtocol": ".legacy.server", - "WebSocketServer": ".legacy.server", - }, -) +# See #940 for why lazy_import isn't used here for backwards compatibility. +from .legacy.server import * # isort:skip # noqa __all__ = ["ServerConnection"] From 0713dbf2d37a8c2c071d8479a6768dd3d3c7dacf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 14 May 2021 07:54:58 +0200 Subject: [PATCH 0785/1539] Add test coverage. --- tests/test_auth.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 tests/test_auth.py diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 000000000..d5a8bd9ad --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1 @@ +from websockets.auth import * # noqa From bc19676a06d410ba35363d4630fb417c24d90b66 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 14 May 2021 07:07:41 +0100 Subject: [PATCH 0786/1539] Add project_urls metadata (#943) --- setup.cfg | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.cfg b/setup.cfg index d8877aa2e..8c7a4984a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,11 @@ python-tag = py37.py38.py39.py310 [metadata] license_file = LICENSE +project_urls = + Changelog = https://websockets.readthedocs.io/en/stable/changelog.html + Documentation = https://websockets.readthedocs.io/ + Funding = https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme + Tracker = https://github.com/aaugustin/websockets/issues [flake8] ignore = E203,E731,F403,F405,W503 From ccfe98e5111a74d602df7a2a195951acf69057f0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 10:42:29 +0200 Subject: [PATCH 0787/1539] Add FAQ on blocking send loops. Refs #867 (and others). --- docs/faq.rst | 53 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/docs/faq.rst b/docs/faq.rst index 53da0f004..abd396cde 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -23,7 +23,6 @@ before returning. For example, if your handler has a structure similar to:: async def handler(websocket, path): - ... asyncio.create_task(do_some_work()) change it to:: @@ -185,7 +184,6 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../example/shutdown_client.py :emphasize-lines: 10-13 - How do I disable TLS/SSL certificate verification? .................................................. @@ -194,27 +192,50 @@ Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. :func:`connect` accepts the same arguments as :meth:`~asyncio.loop.create_connection`. -Both sides ----------- +asyncio usage +------------- How do I do two things in parallel? How do I integrate with another coroutine? .............................................................................. You must start two tasks, which the event loop will run concurrently. You can -achieve this with :func:`asyncio.gather` or :func:`asyncio.wait`. - -This is also part of learning asyncio and not specific to websockets. +achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. Keep track of the tasks and make sure they terminate or you cancel them when the connection terminates. -How do I create channels or topics? -................................... +Why does my program never receives any messages? +................................................ -websockets doesn't have built-in publish / subscribe for these use cases. +Your program runs a coroutine that never yield control to the event loop. The +coroutine that receives messages never gets a chance to run. -Depending on the scale of your service, a simple in-memory implementation may -do the job or you may need an external publish / subscribe component. +Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough +to yield control. Awaiting a coroutine may yield control, but there's no +guarantee that it will. + +For example, ``send()`` only yields control when send buffers are full, which +never happens in most practical cases. + +If you run a loop that contains only synchronous operations and a ``send()`` +call, you must yield control explicitly with :func:`asyncio.sleep`:: + + async def producer(websocket): + message = generate_next_message() + await websocket.send(message) + await asyncio.sleep(0) # yield control to the event loop + +:func:`asyncio.sleep` always suspends the current task, allowing other tasks +to run. This behavior is documented precisely because it isn't expected from +every coroutine. + +See `issue 867`_. + +.. _issue 867: https://github.com/aaugustin/websockets/issues/867 + + +Both sides +---------- What does ``ConnectionClosedError: code = 1006`` mean? ...................................................... @@ -315,6 +336,14 @@ If this turns out to be impractical, you should use another library. Miscellaneous ------------- +How do I create channels or topics? +................................... + +websockets doesn't have built-in publish / subscribe for these use cases. + +Depending on the scale of your service, a simple in-memory implementation may +do the job or you may need an external publish / subscribe component. + How do I set a timeout on ``recv()``? ..................................... From 8900c13d3234c8ae87b0d852e849eaf6bf7cf8b7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 14:19:47 +0200 Subject: [PATCH 0788/1539] Add mypy to dictionary. --- docs/spelling_wordlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index d7c744147..4d8fc1e2d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -31,6 +31,7 @@ lifecycle Lifecycle lookups MiB +mypy nginx parsers permessage From b8517b11f98582d4ed3c0bb0c20c5ecf1c31df47 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 14:18:14 +0200 Subject: [PATCH 0789/1539] Optimize default compression settings. --- .gitignore | 1 + benchmark/compression.py | 163 ++++++++++++++++++ {performance => benchmark}/mem_client.py | 19 +- {performance => benchmark}/mem_server.py | 16 +- docs/changelog.rst | 2 + docs/deployment.rst | 64 ++++--- .../extensions/permessage_deflate.py | 14 +- tests/extensions/test_permessage_deflate.py | 36 ++-- tests/legacy/test_client_server.py | 4 +- 9 files changed, 263 insertions(+), 56 deletions(-) create mode 100644 benchmark/compression.py rename {performance => benchmark}/mem_client.py (77%) rename {performance => benchmark}/mem_server.py (82%) diff --git a/.gitignore b/.gitignore index c23cf5210..ac68ff739 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ .mypy_cache .tox build/ +benchmark/corpus.pkl compliance/reports/ dist/ docs/_build/ diff --git a/benchmark/compression.py b/benchmark/compression.py new file mode 100644 index 000000000..15fb8653e --- /dev/null +++ b/benchmark/compression.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +import getpass +import json +import pickle +import subprocess +import sys +import time +import zlib + + +CORPUS_FILE = "corpus.pkl" + +REPEAT = 10 + +WB, ML = 12, 5 # defaults used as a reference + + +def _corpus(): + OAUTH_TOKEN = getpass.getpass("OAuth Token? ") + COMMIT_API = ( + f'curl -H "Authorization: token {OAUTH_TOKEN}" ' + f"https://api.github.com/repos/aaugustin/websockets/git/commits/:sha" + ) + + commits = [] + + head = subprocess.check_output("git rev-parse HEAD", shell=True).decode().strip() + todo = [head] + seen = set() + + while todo: + sha = todo.pop(0) + commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True) + commits.append(commit) + seen.add(sha) + for parent in json.loads(commit)["parents"]: + sha = parent["sha"] + if sha not in seen and sha not in todo: + todo.append(sha) + time.sleep(1) # rate throttling + + return commits + + +def corpus(): + data = _corpus() + with open(CORPUS_FILE, "wb") as handle: + pickle.dump(data, handle) + + +def _benchmark(data): + size = {} + duration = {} + + for wbits in range(9, 16): + size[wbits] = {} + duration[wbits] = {} + + for memLevel in range(1, 10): + encoder = zlib.compressobj(wbits=-wbits, memLevel=memLevel) + encoded = [] + + t0 = time.perf_counter() + + for _ in range(REPEAT): + for item in data: + if isinstance(item, str): + item = item.encode("utf-8") + # Taken from PerMessageDeflate.encode + item = encoder.compress(item) + encoder.flush(zlib.Z_SYNC_FLUSH) + if item.endswith(b"\x00\x00\xff\xff"): + item = item[:-4] + encoded.append(item) + + t1 = time.perf_counter() + + size[wbits][memLevel] = sum(len(item) for item in encoded) + duration[wbits][memLevel] = (t1 - t0) / REPEAT + + raw_size = sum(len(item) for item in data) + + print("=" * 79) + print("Compression ratio") + print("=" * 79) + print("\t".join(["wb \\ ml"] + [str(memLevel) for memLevel in range(1, 10)])) + for wbits in range(9, 16): + print( + "\t".join( + [str(wbits)] + + [ + f"{100 * (1 - size[wbits][memLevel] / raw_size):.1f}%" + for memLevel in range(1, 10) + ] + ) + ) + print("=" * 79) + print() + + print("=" * 79) + print("CPU time") + print("=" * 79) + print("\t".join(["wb \\ ml"] + [str(memLevel) for memLevel in range(1, 10)])) + for wbits in range(9, 16): + print( + "\t".join( + [str(wbits)] + + [ + f"{1000 * duration[wbits][memLevel]:.1f}ms" + for memLevel in range(1, 10) + ] + ) + ) + print("=" * 79) + print() + + print("=" * 79) + print(f"Size vs. {WB} \\ {ML}") + print("=" * 79) + print("\t".join(["wb \\ ml"] + [str(memLevel) for memLevel in range(1, 10)])) + for wbits in range(9, 16): + print( + "\t".join( + [str(wbits)] + + [ + f"{100 * (size[wbits][memLevel] / size[WB][ML] - 1):.1f}%" + for memLevel in range(1, 10) + ] + ) + ) + print("=" * 79) + print() + + print("=" * 79) + print(f"Time vs. {WB} \\ {ML}") + print("=" * 79) + print("\t".join(["wb \\ ml"] + [str(memLevel) for memLevel in range(1, 10)])) + for wbits in range(9, 16): + print( + "\t".join( + [str(wbits)] + + [ + f"{100 * (duration[wbits][memLevel] / duration[WB][ML] - 1):.1f}%" + for memLevel in range(1, 10) + ] + ) + ) + print("=" * 79) + print() + + +def benchmark(): + with open(CORPUS_FILE, "rb") as handle: + data = pickle.load(handle) + _benchmark(data) + + +try: + run = globals()[sys.argv[1]] +except (KeyError, IndexError): + print(f"Usage: {sys.argv[0]} [corpus|benchmark]") +else: + run() diff --git a/performance/mem_client.py b/benchmark/mem_client.py similarity index 77% rename from performance/mem_client.py rename to benchmark/mem_client.py index 6eab690d8..db68eb995 100644 --- a/performance/mem_client.py +++ b/benchmark/mem_client.py @@ -8,9 +8,11 @@ from websockets.extensions import permessage_deflate -CLIENTS = 10 +CLIENTS = 20 INTERVAL = 1 / 10 # seconds +WB, ML = 12, 5 + MEM_SIZE = [] @@ -24,9 +26,9 @@ async def mem_client(client): "ws://localhost:8765", extensions=[ permessage_deflate.ClientPerMessageDeflateFactory( - server_max_window_bits=10, - client_max_window_bits=10, - compress_settings={"memLevel": 3}, + server_max_window_bits=WB, + client_max_window_bits=WB, + compress_settings={"memLevel": ML}, ) ], ) as ws: @@ -43,9 +45,12 @@ async def mem_client(client): await asyncio.sleep(CLIENTS * INTERVAL) -asyncio.run( - asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) -) +async def mem_clients(): + await asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) + + +asyncio.run(mem_clients()) + # First connection incurs non-representative setup costs. del MEM_SIZE[0] diff --git a/performance/mem_server.py b/benchmark/mem_server.py similarity index 82% rename from performance/mem_server.py rename to benchmark/mem_server.py index 81490a0e7..852796249 100644 --- a/performance/mem_server.py +++ b/benchmark/mem_server.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio +import os import signal import statistics import tracemalloc @@ -9,9 +10,11 @@ from websockets.extensions import permessage_deflate -CLIENTS = 10 +CLIENTS = 20 INTERVAL = 1 / 10 # seconds +WB, ML = 12, 5 + MEM_SIZE = [] @@ -34,17 +37,22 @@ async def handler(ws, path): async def mem_server(): loop = asyncio.get_running_loop() stop = loop.create_future() + # Set the stop condition when receiving SIGTERM. + print("Stop the server with:") + print(f"kill -TERM {os.getpid()}") + print() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + async with websockets.serve( handler, "localhost", 8765, extensions=[ permessage_deflate.ServerPerMessageDeflateFactory( - server_max_window_bits=10, - client_max_window_bits=10, - compress_settings={"memLevel": 3}, + server_max_window_bits=WB, + client_max_window_bits=WB, + compress_settings={"memLevel": ML}, ) ], ): diff --git a/docs/changelog.rst b/docs/changelog.rst index d2d58c57b..434cdea61 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -42,6 +42,8 @@ They may change at any time. * Added compatibility with Python 3.10. +* Optimized default compression settings to reduce memory usage. + 9.0.2 ..... diff --git a/docs/deployment.rst b/docs/deployment.rst index 8baa8836c..aa1af211c 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -6,8 +6,8 @@ Deployment Application server ------------------ -The author of ``websockets`` isn't aware of best practices for deploying -network services based on :mod:`asyncio`, let alone application servers. +The author of websockets isn't aware of best practices for deploying network +services based on :mod:`asyncio`, let alone application servers. You can run a script similar to the :ref:`server example `, inside a supervisor if you deem that useful. @@ -59,7 +59,7 @@ memory usage can become a bottleneck. Memory usage of a single connection is the sum of: -1. the baseline amount of memory ``websockets`` requires for each connection, +1. the baseline amount of memory websockets requires for each connection, 2. the amount of data held in buffers before the application processes it, 3. any additional memory allocated by the application itself. @@ -71,60 +71,78 @@ Baseline Compression settings are the main factor affecting the baseline amount of memory used by each connection. -By default ``websockets`` maximizes compression rate at the expense of memory -usage. If memory usage is an issue, lowering compression settings can help: +If you'd like to customize compression settings, here are the main knobs. - Context Takeover is necessary to get good performance for almost all applications. It should remain enabled. - Window Bits is a trade-off between memory usage and compression rate. - It defaults to 15 and can be lowered. The default value isn't optimal - for small, repetitive messages which are typical of WebSocket servers. + It should be an integer between 9 (lowest memory usage) and 15 (highest + compression rate). Setting it to 8 is possible but triggers a bug in some + versions of zlib. - Memory Level is a trade-off between memory usage and compression speed. - It defaults to 8 and can be lowered. A lower memory level can actually - increase speed thanks to memory locality, even if the CPU does more work! + However, a lower memory level can increase speed thanks to memory locality, + even if the CPU does more work! It should be an integer between 1 (lowest + memory usage) and 9 (highest compression speed in theory, not in practice). -See this :ref:`example ` for how to -configure compression settings. +By default, websockets enables compression with conservative settings that +optimize memory usage at the cost of a slightly worse compression rate: Window +Bits = 12 and Memory Level = 5. This strikes a good balance for small messages +that are typical of WebSocket servers. + +If you'd like to configure different compression settings, see this +:ref:`example `. If you don't set +limits on Window Bits and neither does the remote endpoint, it defaults to the +maximum value of 15. If you don't set Memory Level, it defaults to 8 — more +accurately, to ``zlib.DEF_MEM_LEVEL`` which is 8. Here's how various compression settings affect memory usage of a single -connection on a 64-bit system, as well a benchmark_ of compressed size and +connection on a 64-bit system, as well a benchmark of compressed size and compression time for a corpus of small JSON documents. +-------------+-------------+--------------+--------------+------------------+------------------+ | Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | +=============+=============+==============+==============+==================+==================+ -| *default* | 15 | 8 | 325 KiB | +0% | +0% + +| | 15 | 8 | 322 KiB | -4.0% | +15% + ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 14 | 7 | 178 KiB | -2.6% | +10% | +-------------+-------------+--------------+--------------+------------------+------------------+ -| | 14 | 7 | 181 KiB | +1.5% | -5.3% | +| | 13 | 6 | 106 KiB | -1.4% | +5% | +-------------+-------------+--------------+--------------+------------------+------------------+ -| | 13 | 6 | 110 KiB | +2.8% | -7.5% | +| *default* | 12 | 5 | 70 KiB | = | = | +-------------+-------------+--------------+--------------+------------------+------------------+ -| | 12 | 5 | 73 KiB | +4.4% | -18.9% | +| | 11 | 4 | 52 KiB | +3.7% | -5% | +-------------+-------------+--------------+--------------+------------------+------------------+ -| | 11 | 4 | 55 KiB | +8.5% | -18.8% | +| | 10 | 3 | 43 KiB | +90% | +50% | +-------------+-------------+--------------+--------------+------------------+------------------+ -| *disabled* | N/A | N/A | 22 KiB | N/A | N/A | +| | 9 | 2 | 39 KiB | +160% | +100% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| *disabled* | N/A | N/A | 19 KiB | N/A | N/A | +-------------+-------------+--------------+--------------+------------------+------------------+ *Don't assume this example is representative! Compressed size and compression time depend heavily on the kind of messages exchanged by the application!* -You can run the same benchmark for your application by creating a list of -typical messages and passing it to the ``_benchmark`` function_. +You can adapt the `compression.py`_ benchmark for your application by creating +a list of typical messages and passing it to the ``_benchmark`` function. -.. _benchmark: https://gist.github.com/aaugustin/fbea09ce8b5b30c4e56458eb081fe599 -.. _function: https://gist.github.com/aaugustin/fbea09ce8b5b30c4e56458eb081fe599#file-compression-py-L48-L144 +.. _compression.py: https://github.com/aaugustin/websockets/blob/main/performance/compression.py This `blog post by Ilya Grigorik`_ provides more details about how compression settings affect memory usage and how to optimize them. .. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ -This `experiment by Peter Thorson`_ suggests Window Bits = 11, Memory Level = +This `experiment by Peter Thorson`_ suggests Window Bits = 11 and Memory Level = 4 as a sweet spot for optimizing memory usage. .. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html +websockets defaults to Window Bits = 12 and Memory Level = 5 in order to stay +away from Window Bits = 10 or Memory Level = 3, where performance craters in +the benchmark. This raises doubts on what could happen at Window Bits = 11 and +Memory Level = 4 on a different set of messages. The defaults needs to be safe +for all applications, hence a more conservative choice. + Buffers ....... diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 4f520af38..34cc1f950 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -448,7 +448,11 @@ def enable_client_permessage_deflate( for extension_factory in extensions ): extensions = list(extensions) + [ - ClientPerMessageDeflateFactory(client_max_window_bits=True) + ClientPerMessageDeflateFactory( + server_max_window_bits=12, + client_max_window_bits=12, + compress_settings={"memLevel": 5}, + ) ] return extensions @@ -631,5 +635,11 @@ def enable_server_permessage_deflate( ext_factory.name == ServerPerMessageDeflateFactory.name for ext_factory in extensions ): - extensions = list(extensions) + [ServerPerMessageDeflateFactory()] + extensions = list(extensions) + [ + ServerPerMessageDeflateFactory( + server_max_window_bits=12, + client_max_window_bits=12, + compress_settings={"memLevel": 5}, + ) + ] return extensions diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 908cd91a4..d3c6f0ac6 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -512,33 +512,33 @@ def test_enable_client_permessage_deflate(self): ) in [ ( None, - (1, 0, None), + (1, 0, {"memLevel": 5}), ), ( [], - (1, 0, None), + (1, 0, {"memLevel": 5}), ), ( [ClientNoOpExtensionFactory()], - (2, 1, None), + (2, 1, {"memLevel": 5}), ), ( - [ClientPerMessageDeflateFactory(compress_settings={"level": 1})], - (1, 0, {"level": 1}), + [ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7})], + (1, 0, {"memLevel": 7}), ), ( [ - ClientPerMessageDeflateFactory(compress_settings={"level": 1}), + ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ClientNoOpExtensionFactory(), ], - (2, 0, {"level": 1}), + (2, 0, {"memLevel": 7}), ), ( [ ClientNoOpExtensionFactory(), - ClientPerMessageDeflateFactory(compress_settings={"level": 1}), + ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ], - (2, 1, {"level": 1}), + (2, 1, {"memLevel": 7}), ), ]: with self.subTest(extensions=extensions): @@ -849,33 +849,33 @@ def test_enable_server_permessage_deflate(self): ) in [ ( None, - (1, 0, None), + (1, 0, {"memLevel": 5}), ), ( [], - (1, 0, None), + (1, 0, {"memLevel": 5}), ), ( [ServerNoOpExtensionFactory()], - (2, 1, None), + (2, 1, {"memLevel": 5}), ), ( - [ServerPerMessageDeflateFactory(compress_settings={"level": 1})], - (1, 0, {"level": 1}), + [ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7})], + (1, 0, {"memLevel": 7}), ), ( [ - ServerPerMessageDeflateFactory(compress_settings={"level": 1}), + ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ServerNoOpExtensionFactory(), ], - (2, 0, {"level": 1}), + (2, 0, {"memLevel": 7}), ), ( [ ServerNoOpExtensionFactory(), - ServerPerMessageDeflateFactory(compress_settings={"level": 1}), + ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ], - (2, 1, {"level": 1}), + (2, 1, {"memLevel": 7}), ), ]: with self.subTest(extensions=extensions): diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 80c9d56bb..353e5b370 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -985,11 +985,11 @@ def test_extensions_error_no_extensions(self, _process_extensions): def test_compression_deflate(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual( - server_extensions, repr([PerMessageDeflate(False, False, 15, 15)]) + server_extensions, repr([PerMessageDeflate(False, False, 12, 12)]) ) self.assertEqual( repr(self.client.extensions), - repr([PerMessageDeflate(False, False, 15, 15)]), + repr([PerMessageDeflate(False, False, 12, 12)]), ) def test_compression_unsupported_server(self): From 66ded17c39cd6e5bb408edd9b53b9afad7ab19fc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 17:54:40 +0200 Subject: [PATCH 0790/1539] Main branch is now called main. --- README.rst | 12 ++++++------ docs/contributing.rst | 2 +- docs/heroku.rst | 12 ++++++------ docs/index.rst | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/README.rst b/README.rst index fa42a9be3..cd635835c 100644 --- a/README.rst +++ b/README.rst @@ -28,8 +28,8 @@ What is ``websockets``? ``websockets`` is a library for building WebSocket servers_ and clients_ in Python with a focus on correctness and simplicity. -.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py +.. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py +.. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. @@ -80,7 +80,7 @@ Does that look good? .. raw:: html
- +

websockets for enterprise

Available as part of the Tidelift Subscription

The maintainers of websockets and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. Learn more.

@@ -113,7 +113,7 @@ Docs`_ and see for yourself. .. _Read the Docs: https://websockets.readthedocs.io/ .. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers -.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst +.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/main/compliance/README.rst Why shouldn't I use ``websockets``? ----------------------------------- @@ -145,8 +145,8 @@ For anything else, please open an issue_ or send a `pull request`_. Participants must uphold the `Contributor Covenant code of conduct`_. -.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/master/CODE_OF_CONDUCT.md +.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/main/CODE_OF_CONDUCT.md ``websockets`` is released under the `BSD license`_. -.. _BSD license: https://github.com/aaugustin/websockets/blob/master/LICENSE +.. _BSD license: https://github.com/aaugustin/websockets/blob/main/LICENSE diff --git a/docs/contributing.rst b/docs/contributing.rst index 61c0b979c..59d7451e0 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -10,7 +10,7 @@ This project and everyone participating in it is governed by the `Code of Conduct`_. By participating, you are expected to uphold this code. Please report inappropriate behavior to aymeric DOT augustin AT fractalideas DOT com. -.. _Code of Conduct: https://github.com/aaugustin/websockets/blob/master/CODE_OF_CONDUCT.md +.. _Code of Conduct: https://github.com/aaugustin/websockets/blob/main/CODE_OF_CONDUCT.md *(If I'm the person with the inappropriate behavior, please accept my apologies. I know I can mess up. I can't expect you to tell me, but if you diff --git a/docs/heroku.rst b/docs/heroku.rst index 8af2ebd3d..d23dc64c0 100644 --- a/docs/heroku.rst +++ b/docs/heroku.rst @@ -16,10 +16,10 @@ Deploying to Heroku requires a git repository. Let's initialize one: $ mkdir websockets-echo $ cd websockets-echo - $ git init . + $ git init -b main Initialized empty Git repository in websockets-echo/.git/ $ git commit --allow-empty -m "Initial commit." - [master (root-commit) 1e7947d] Initial commit. + [main (root-commit) 1e7947d] Initial commit. Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if you haven't done that yet. @@ -32,7 +32,7 @@ you'll have to pick a different name because I'm already using .. code:: console - $ $ heroku create websockets-echo + $ heroku create websockets-echo Creating ⬢ websockets-echo... done https://websockets-echo.herokuapp.com/ | https://git.heroku.com/websockets-echo.git @@ -86,7 +86,7 @@ Confirm that you created the correct files and commit them to git: Procfile app.py requirements.txt $ git add . $ git commit -m "Deploy echo server to Heroku." - [master 8418c62] Deploy echo server to Heroku. + [main 8418c62] Deploy echo server to Heroku.  3 files changed, 19 insertions(+)  create mode 100644 Procfile  create mode 100644 app.py @@ -99,7 +99,7 @@ Our app is ready. Let's deploy it! .. code:: console - $ git push heroku master + $ git push heroku main ... lots of output... @@ -109,7 +109,7 @@ Our app is ready. Let's deploy it! remote: remote: Verifying deploy... done. To https://git.heroku.com/websockets-echo.git -  * [new branch] master -> master +  * [new branch] main -> main Validate deployment ------------------- diff --git a/docs/index.rst b/docs/index.rst index 257a806ea..4ec682f69 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,8 +21,8 @@ websockets ``websockets`` is a library for building WebSocket servers_ and clients_ in Python with a focus on correctness and simplicity. -.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py +.. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py +.. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. From a14226afb77b524c2ced7d649ac7420a14992716 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 18:01:23 +0200 Subject: [PATCH 0791/1539] Bump version number. --- docs/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 218bbec3d..1064af736 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -33,7 +33,7 @@ They may change at any time. 9.0.2 ..... -*In development* +*May 15, 2021* * Restored compatibility of ``python -m websockets`` with Python < 3.9. diff --git a/src/websockets/version.py b/src/websockets/version.py index 23b7f329b..02dbe9d3c 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "9.0.1" +version = "9.0.2" From f9371fca175f799ae3f1cc1cb0d5122cfd25d8de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 19:03:11 +0200 Subject: [PATCH 0792/1539] Improve example in intro. * Refactor to increase clarity. * Avoid deprecated usage of asyncio.wait. * Clarify what happens when clients disconnect. * Restore call to main() -- it disappeared. Fix #757. --- example/counter.py | 49 +++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/example/counter.py b/example/counter.py index 81cbdb55c..a9ed61893 100755 --- a/example/counter.py +++ b/example/counter.py @@ -22,47 +22,46 @@ def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) -async def notify_state(): - if USERS: # asyncio.wait doesn't accept an empty list - message = state_event() - await asyncio.wait([user.send(message) for user in USERS]) - - -async def notify_users(): - if USERS: # asyncio.wait doesn't accept an empty list - message = users_event() - await asyncio.wait([user.send(message) for user in USERS]) - - -async def register(websocket): - USERS.add(websocket) - await notify_users() - - -async def unregister(websocket): - USERS.remove(websocket) - await notify_users() +async def broadcast(message): + # asyncio.wait doesn't accept an empty list + if not USERS: + return + # Ignore return value. If a user disconnects before we send + # the message to them, there's nothing we can do anyway. + await asyncio.wait([ + asyncio.create_task(user.send(message)) + for user in USERS + ]) async def counter(websocket, path): - # register(websocket) sends user_event() to websocket - await register(websocket) try: + # Register user + USERS.add(websocket) + await broadcast(users_event()) + # Send current state to user await websocket.send(state_event()) + # Manage state changes async for message in websocket: data = json.loads(message) if data["action"] == "minus": STATE["value"] -= 1 - await notify_state() + await broadcast(state_event()) elif data["action"] == "plus": STATE["value"] += 1 - await notify_state() + await broadcast(state_event()) else: logging.error("unsupported event: %s", data) finally: - await unregister(websocket) + # Unregister user + USERS.remove(websocket) + await broadcast(users_event()) async def main(): async with websockets.serve(counter, "localhost", 6789): await asyncio.Future() # run forever + + +if __name__ == "__main__": + asyncio.run(main()) From 32a135e4b33020eb1a2a45cfa5d90dc2e22cbeb1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 May 2021 08:26:01 +0200 Subject: [PATCH 0793/1539] Generate Github social preview. --- logo/github-social-preview.html | 38 ++++++++++++++++++++++++++++++++ logo/github-social-preview.png | Bin 0 -> 28346 bytes 2 files changed, 38 insertions(+) create mode 100644 logo/github-social-preview.html create mode 100644 logo/github-social-preview.png diff --git a/logo/github-social-preview.html b/logo/github-social-preview.html new file mode 100644 index 000000000..187a183b0 --- /dev/null +++ b/logo/github-social-preview.html @@ -0,0 +1,38 @@ + + + + GitHub social preview + + + +

Take a screenshot of this DOM node to make a PNG.

+

For 2x DPI screens.

+

+

For regular screens.

+

+ diff --git a/logo/github-social-preview.png b/logo/github-social-preview.png new file mode 100644 index 0000000000000000000000000000000000000000..59a51b6e33810bc0bc34b7d2ea6e47b0fd8be78b GIT binary patch literal 28346 zcmb5WcT|&2^e!4e!3Nk-ilEp5QITG4fFPh!A~jf$CPi9k31CA-Km|mlsWc%#=q->$ z5tJqfhENlVv_J?offTrt;CJs?_uRAA@A^mNGc&XI?0NRHpFQu(_xzfvf$$dbEf5Gq z_=@4B8xY7w@FRctCVueO)z=s35Qt~Ql}qO>0{Nzie8LdOHvW$cMMUe+$Bj^&2pF?r zANb%C{2zVjg>U@-HJT3s(F^~deQb9G9{f)~cs~D6KmLb%Md{F|G`_C>4DKO z3u14>obIH$l~*jLtJbhqSY+vj!IlU?1<#1Vy{bEtAM1JZ^X)vW*FVsuBo|lG+n%EI zJ};$*(Za^2O|{C7pY}qx)%q$_SgHirRd$xH`LxD-oA)(B-G_9M}@u5U9PscqM{k(IC_c{;5O8m@*&R_!w19OYA1zO>>_@ew3RIAPh)EHwzbyQwmz;2( z`BM>aa^0fL;L`<4dWAv0?%KBxzKVG(4_Ml}CQ6)>>bxQ@E-C*eKz~@kc9|nsKHn49 zmJoA# zrE16ObWnFB`FtMEJ8(vis1Dx&3ngt(8%pxeyh<6R})2 zfH}G7O>0`6)s!i!T6@--QaX3Z@(zX{63)L3qPIcKylat}6?K+gFn_TKo_Qx1eTPLL z5dT)y3$>B*EFq9<6@`=4WauyM|_`qH?f9CEW1F)DIC4k4yWqN>*H4&Aiau zH6>#pkZF5F<@lBm=p!kl35$dIr32F;*>wNu%2`jZ?fSn+p5hW==Z8rN1z3$N=F{7> zCP}0Tn?Nijt8@03OdGm7#r<>e=JCqkvrWAqb}2C*YUSdRzxL9|V|Q8ef7$&?s&0GM z&*!nGU)fA%JJo)VAsOQX4)PsE^CJh@mlpP`Rp**T!Mb4?tC>;ph!AK%BrIiz7MpUH zQ3i{XhplR8x*M$egW|xY@Q532(91m}Lc}ea{lG#`Ne(uPLqFnsoq0Ly6CV!@cYAz> z3xqnfTo#k=Y&9mPY3lFiFSia|!@}A^t>rq!5NJdn4d!>2)aREhvW*9g z!P=J3ZVY;XY107L)n1zYdy4xo8mob&^rgMuPxHtS5a1{J=+j9B-h=0SDdTz%Q=Qn>%mx?0rgGKz+$vz){3Wtc@}3q(rtjt-wWKX!ep zt&eJNEr}eZ+V@B!*C^i0%2ZtoLkJ{oX&q zL4JeR{t>|asbIOG;dbEnlU5P|euG{hh@@*3a*%NpQ7N_EJUBtHrd4}$Yf&TDtKg&U z$hEt@G*PMLeNyWQ=eU8YZN6V81Zr>?_#ufDDkr1OSGyBW_<;-qp%35TlUY2Psmw<& zkQc#q&TmThl84!>Hmg75|zu~xRRC6>^hp^v_0pJS_stIDg-J*B`!LIy;)mM z5cy{i@1)de?6>59hemN_FbXcrw!e)Op~-BuhC*Cqr|Tf(Blf|2pL)u+ABT))aJ}rshXx%wOlW0$snmw?%NCAg?W8 ziHw4)f(MZPH02de;`0$Jsm}toe7wYd_d2+Bc-9>QUwX_TCH#quk=F@1a1G&)oqZy5 zpO1IcgFgux^n#WlB9caF0}ifILcCas83P*uxXT3&Df_9T%we>SJuFp?s z_D)+tf@%OTW117mSIJAKqykU>qhs2+sv>wj$%2Qm zr)Tb0v1Dh&o^`B_-L>1=R*MoSrWf!_A_sW9!TgLud*U>*{w?z{t4bD6N~j-#;z^dK zmuU9FF1Rs3q|#ntA=akGA=&(pOF`TUkJ-}PgW2M}oPdFkhAt};7)5Ji(MNg!VxQ3g z%|dRgq`}qLMtv9`9d;8LV^ltXDOO|a&@Yy2(|?z%?sfhh!6y^E6~wgCre~T$vP?yV zK&S1g2bd&6?CoH*D^4R8SxQ~1f6raAgS%vxWR>!T5a`}WSWXC3DT8J|sRo6GmFqBv zx`<-j<(S+fnwXrFDxtUFDRN;nx3%y>_@1k*rEz7A(~wKHT*r?U5DmyA%eS+&a;!RZ z43CaKXw473HPGq4AS4(*aS|vG#|NG{&LPlWk+78zXwE#N0F)>VTQ-%LR`IqSSaRc9 z3PjYs*GtpJge*JJ+`OKFBP^%9tD}X}F+FKev3utLYsn0AclgESz9^-{IjRm_S}QkR zhi3CkyW&J~G_!B0t(}v2$o3&F6aRI{E$_hue+!_U+z<()bFF4^m4R~@l zVC-u=5H{iRonCNZvHkZ?9?g0D8;i|eXb(whIAxeFAfWesFJPJNao$~?v?h>^weTf4 zs}|*=&Zek=tX(4X{B7%2<1TxfyKGr&%4=5pYgPz!WCPW?C4Mwyc7V-v z@KWHWoyP-ik!2=sS1IX1N*l`XP20q*CIN;{t8Vu!&6wlIVgcofDC zk%;%jU&XjlG2}*NM-^N-=pxd@9J8f^^l%Dj#!n$oIv|%H6jUE=_W;sFzs*>L1_~W_ zd2dv? zO@O6d4w4c~$1;Rhr0lhuY)s5q%~3vnON287)tB^wt4o9;P)74GVQ+p#`OmkCi|ziPQpCU2yxQ5AXi6 zBZG!d-%DNaqq#*X-0w0 z<{!MUtzaKKG^!CW-8DYW`J)7+2~KuVO7-)m^s)}Gw$Cxmpm6%bFg-{^wV z7_rjUbwB1M%NIVur97~{nE$X&i1^7O{=LS)!S90TQ9yn@7y67o45K zZd>W-`U>$vzmpHrP8b-WkO-^Wr|sX@fy#gg(m7@DBzbkO#;ImZEz0Bo$gWayg6`pl z#67yP)^mS|eG_i{V&uOu&c}$ zu6wFo$(z_bhU~3Mc6=~C)o!IzZMO3PX#^qSq68N4%&Ihf!Q|BOo4qVK?{2p~MbmC`O zTdGMYfq0kl^bS+|Fj{j~?!hF`l^5pPe+Qc**6Fch2EVR6P@;AcJ+&w=dlE)GaG7%* zYeaYAtTs2fFDX#dhdVNL{QjQERcF*5U|Q1S%{DsF`4;ZuVy=dLF4+%CFR&t*NaOZUWzw3lJD0 z27d9HoCVlu&99X+M@aRe71O9IRAs;yL!j;(rtP28{v8xLzp0Y|CGm6| zkVruPpkc>SYp*M02pIQZd>FHGxPK9BF~M#JMH2{}PEh$~nhSKpstgPPBO#5#+ft8h z?Z<3c`{(d0i<^5X05knSH-c7@f-0F(MXYp+h5T|z1U?@z_)1DB{BLt*HNSZ||MdWp zLxql;!va<(qXv953gc_#e3vuVt%v{Y>g;26yr~Naf$qNmII*0}2O9C%G?s>U!ebG( z&;2(*h=Ogf|1Dc6&nnIifi{+dCKE^vc5k9t!1DZ$Hu#>m7Df~J^zmb{coM;A2hb%! z!_5k1CizYN|D<*kr7no(#u%i+$Ibd^VVbk12jPQ>V@F@~VRpW9#`Sdl>HTCU#H-?` zDXhBlK)p*Nk4)POv!{D3U~#b(y7W1^ZfNsVG=9W-+P)D~!MkH)7n|(Y9lBgGB_)Cbmk^-w)0E0Z zHyZSXu5hvsX?XkEj(-WdGE6Vn@OQ9k>c8=K-bw|y(oS0bA`xV$L@2`l05ysxV&|P{ z7nGlqf2XD9@Q*t5yIMKX`R;gS%O!=T%jLNo94zpA~c!s~-Va%mNlwi_!x(2jIF50hpzkrZaH#uSlL| z;g?!DK6SsrnLdRxCXIUeXY9aPj zo7Pp*h{CaY9lyIc_^98qa?f;g9c@Yy<7az8=zx;q9CtmtRZ;?Kzc##Lejm@ZiJHt0 z^qi*yxu?z(3a%53EtmK$Yi}w~-!M_nEz>K5T^|}*Cgcv4ks*;VAY{pvrSQ>M^Cjgg zfzx%ZXF=zPZt0N)(iIn0>^ei2wG~Zn)3y(uU`N{?0FSszOAU|HM zNt~S!N?N$j|FOxF@O1q=?`Px5p*=zoPkk5OCEm{up&4B~mFC?u0=U%Nli zqO&{q>Bsz%sien`3fSqmlyXzv@rPea4oxbytoQSWfBq$v@)1Q7eSUp3EhkI<2qAts zGz!$*^%jN#sQFi^Nh}t3g4Q9mUn(#*xh+-KCT?J$6ZNp@)IzD;+Tx@;oM5=#bl0Zj zszWV9&%UHg{XyOl%wYQxkXIVbZF+Y1aH)sP2SisiR%?DW>!O(F33OS2Uv$iX>@Awx zFvGQJQ;?8D!Alc;8`CkgiG$1A5s?SP)msFkxv_=Afh zikdW%)3f|h*IZR+o}_<5Dsbk27o#AR?c;9ioG4zvfI3;Ph8Gou#5b>B3YuYZ&GyYv z+7a95Hu7SgIgAl0ZMI78|y}L0v@J@Vfwlr|R zRHg5q`||&_^_PnsrYTgCW%!QS!E;M6-*U}k{9Med3td?eNKe^LsaE>;!E2%ZVSeANrfBb zya)}(B2i+20BSQ=%~LAkw`ndIX^}8YHNOz3Krbymf(oB!viv`n3WSx-h5KB&Fr%IX zD?YoFu?|w;TJ5n8w!i7@bfgYN^vrJDylMext^4;Fa-D^&R^8TI?Z=1*usC% z()CHxk*aKC59F2pQ%O~7yn|n!uL^!%--(^Ef=Ri8vR=i>>Anj1Y2X{9fL|JEax{No zJfiV8Ap$Oa;>Ia;M1KERVoToxi*;b%$|Y;u=3cfs0;bJZwd^^3P-?k*)Q1N67j*Uo zl1L7LfnCk3^&0NYGtVT!2kE~+x_r)t{P5KM%)4>LdzbFDiP6Hq0PZX*nCY6{KZK9q zzy0wYs)Cof1_GTR+KVJou1=fXyLIK9pCftUa?_EXY@_?gMn>btlP)|u{)oCUkTPby zM0iku6&(2Ls8sss5pEBQX}7YdN}!bH)cye5JZl+0b4uu;PIAOvhF~J-8}u_-6YtX;F_Bog?J!2C;f*Cy|Lu)i96~v0=VJ@ zvnzWntC@@`)?tH7Xh0?9h$^pd;+whADR_+$5KAN+cQ^5n!o9TiTu}NIACHn* zzt)@)^oP40PoB>-E9SRSn+B#PmFcD_Wi_N;bV3mfP1;vu-BbjMbEb6_6bkNru|ZfbQjHdCCiON9N|3sDc`WGe-5YVOVk&Y14=(ZC%%}Rd?^9;3J*U&U|n~ZhzTw9`l znKzpzsrulUA2t9Dil$76gqj2?1b7XBr#Ba>;7h5Xj|R=lHZ3;&SVP*L#+7=baMh81 zoIqIVQ49)Wuu+yYGbtPQqWa!?v3Z(XYl!c}st}+C7D6$`8dBIu7^^Kc?12}{eyF*P z^~TEk$PF4E-DZ~OdS9wnbRxBC<(1sp#6%Sa1eaTe^#+3Z%vgf#v6V7B73!QrkP82h-IxQ5G(ikqt zzRcdr!S+q?(Pz9abCg6taAhU9ajSYx!FwqfZt3K=vvJ6Nxc(%?jrby@mxyJ4=nu93 zCKVXeP&FGe@srw?DkaxnPP(jj+44|+`x*SEMzlR{@03VDznxC8JW%ZSh$NUw?HX(S%CcHjH#S^XWCQHna zEH5n2*b8U=qH6hF-@LsEaZCR%Ij7wj)8JpNZJ@i1F^b~wJ%VcI^Q z5kPBttXU|xU)i~0i#X4>3D=#grD((wp^B6!aY$p9GQ9wvNi~uPCjq0%6pyrT)bPe@ zsXZW2Df)ZPtq0FuZp{@IH)I8igEBP=84B(ff7NqIW<)XMhvJnzw&9Vi$7k19_uQ18 zjmXcXiS7rD8@jD(dq*wm${~8e1pe*&+NHwvno#$r^XQZ-j78x-&NFeaBK|O(z$ekg zE)S6QkT2N&nzu2Z9R)KS#JykaaM$!ilw! z7dLp;3uG^mLvcTA1G+TWrWF$519(JcefS_Fd8gyI8#LS#heA!WDhO{?iS&)Sp)lOsla6(0k2HaevG?C{^@OaW)Hx5weUCtTql2)CBG;D;DUL9X1GmhiS9s6QgAwk&2f= ze~Q@LB6o5DD*&GoksPCn@y7bacZ2#S3LZ3iEGY+F#!KCVJ3i-KIAXZRQ7@4&tXasW zZ{?N5;E+Gm&HSnKC2((2)1H`k*>CV-;zKgV<`|IEXL0Y#=>h)fY>}%VIN%ahuja$( zZ4|ox=fX7()aEcz(0Js5#-l2X-6U7x{`dbGNjQ4b0`|a+WdjF|gO2 z5yupbv)A{z-6}Z5Vr3iP)-5?$tX5~@_y#{m4g)Vjh2k$7IE+SYmh-YyK7J;v;O8a^ zqZ3&x$a^#m&2*^SX$fvN{Xvb`W`!_Gjhj{Q^I6?EZc=tCcH1i4#Ukc*^F@r+2vfNsWApiW#q*U4izptBsp$aiAOMZ(Se!(S_x{Z+bh*2Ht6p+k87xaZr?FDC}lRh?m<50?R zY(6M=gR}e)z!+)tJs=gIS|Ch-%qe|{s^)A_V+SWaY_>1~MAt0^(GyA)`@o2n&qbTM z$)J#1#oPz<(|aKw1-1tD5+6}OT%;L**SVI|#*O86b8Bh^1R|w249Xu0 z3tdfW5ZI|Pt9Zxik~@p=zFO-d_vMX%9l%rB8CMvFf?AMHd5UOq?|E_r_l(4!XAbHO zx7=WaAQSka9rRMKYe2y1EfD~^5V13teECv~RzC9FI0f7=24wdLEaD|KgHIhRI^kYj z9l~=0$_J4O^xOcV@#+)c6OF=fNu4lro^dH|m~^QU-#`K?5?+}I(NA^~$B*+EPUgB< ziJNx1_#4ZCy_O|bmWq}M3znA2xz_oV@|-M0ZoZ*EQLgCZ$ObKLmPxpV&BegJ*m`MLz)}uclKD;~eGxCQ@XC(I37tyY z6=BBe1th^dx0!swz>*K?_^qit=F~>siz6=Kfe$9_Q0<`tiYZ zUQ%XJ`R;-Ax$eFn&OvPpov;;=*nrTkw6J9p(N1Rc5rjR*0cJ-zHM!DgC=lt}Td`Q@ zHMvNjs+k)I5WOq3XJ335f#`u-X+h@QwWfESa2HOsROk))?pigV7&ALni;>C*|F`1H zYcn>h(B=K!*lR!$-rfcz;V0XGwwyuPHitk=qQimpD;G_|3psbpn^JdLtXGMtX`$_keQgj%zI^|Rk(RHv=HF%}(1_HddxDV2f zgPl!}{Nr_D66OB;%Dsmwh0o!G#*wEY776>Wz*Gfj_%wE}xHlwC<<=wc;?5}7h-M}t zw10BSN(13foH#q;ac=hWD^r$!Uf!aAAF|pH`&H9BZ)&o6GkDwaG{jXOr4@m%18vxj>Q3QQh+ypVPc)G*{TmsgUDH3`UZ1_q_rk5Xw znT_ttsdplOju`14V}$L6K-dENARpUuCqp6AD&NUP*Etgmgg1!PM*tn1C?*PbBW?ZhYMV z#KlnXMiV$&+S^{R0W(;|qM8IFLQKfFm#B+~lxM{p_L$j~wqX}utS#7PeJnmT$f^Sh zJ&Ev#?oB*!XGALSeC1?F@R3POzS>q{t-$lDFCnYMSHFUV=Dxb03ahn z^aLPrW>(uD%^6Z5ftzncFJoIQOaDv;t&0pVzDtffEU6l70^wud8Sw=2;SIonBvSJ; zvNhA|E=NBDos=NxnKu3&OE+a%C1X90O%&6NT2pj6jS&9jW{(s9rc5K+yIf&wi~h#Z zr{{6)PiG`nve5-E7YV{9u!>%?es#dm0^f#NAzfwQ7#GTilp4MvWMdbj^T@_g#=<3Y zB3i1IomeEuS_fz!d~hL4dxstbw{|BkUU;1xe6q|~%nY<`U+q2d?p0S<+9Dz0HMZJF z{2dt*R-S3nBnXiKkC+e&UEFbABZH@YmkGzMnqIN{!4JjhcS4~h;Ji$Wy{@o>nLf8d zH!}NDs5ex=E5zF%69aDQ2>;Z{sjlF}Md{L6kz~9pZ0jT{$`CfN*It+up}U-;4lq;} z6#LJeYFq`79I6O^u_+tcx!?Hr>O(2?B|})nosrd}T`kelPj=vi-wXV=KJ#qXBBA~b zw)!mKou!%Y5;w~fx6PEI&6sYSy~LFd-_tRwaP~ymtqrpgb=v>UTJ7p~CBN$+QV+S- z)-`C(PD9p1FfOt4krNsrafw6O@`Ue%1dp&5*>-I_kL@j()nzwb2M0}A!5Qe8cf+{@ z_!QjI>YeVU(-ioulr4lCsE<2q`#e(V(Q*h0B>72)4<=eLbq_GNbIJB3V#H<8PA81@ z2ran44*B}>7$`b`;7-1F7u80{QlQQ`+*F?e(TBx$`D0I!=AV4~;^g(@2ClmAK;Y$n z>kDuWOeO*O9J{F8jps|HE2LmwFQF!<#>n_xP@LE`4&|N~({O$8g~`{yvk?NOuz`*B z`5vL^IN9NQ7(T^8KeohGD|>Ok#SI{oYoBV3wL4UQcVc(Q$sz{i z`-QBi3j!Kkw+fI08VBB-MaR4*YrU zVUwRxtM%yx38^Xb(t;)4vmLaJN~wR0wgvy=nHzQ6lIhX12>(n#da1vR1lkr%VJVFO zw^zv>eLfGTEA_;^60$IUeZSblkDe11x_5_i9Msp;=lFS07E0@2nrI=k*AIQ6wNbSY zDHi+`-07UjM6Ve+olXlEPAryJp{h%w+^fHWthnE^mjotREsuUNBPE6Y z=wPyfyzi~u`ZuRh`fXRzBOWfs#MQJl!#lA7U31hu?r9-^U#Kt5Z5?VJ*8%Lxg_P3( z`7kIAdVsIdGxGMGOKHF*JwChl`} z1V{(7XKirUTRol4r`}l@DnPt9MY?f8>Q#CcaT1aLwtDh=_sdZWY| z%XHUL$Xs6bdaB)KpYe9P>ynMo?jW|dXwJ%aR0vM7SJ`3?_FC*u(QB3P%MbZ#CTr9F z4Ucbg6NvQhC4xlw^{N5oQ6tb+6xOSnfmojTbE27;rtWG_6;SIVtnYI>diBv8@g2Rc zFlfyYis~W6KxqW4_sE0EsthkGd=Rss8|d6CkYq)w2^`Ul%zueN8YqMqy8%tkt8Kk_ z00aR4gTl#Eh3)vpRJv22E`~FiNK2{-j1<{-6kA(lCw|TLaYfrRcYyQq?VBir=>ykP5;XAXNSSNrC{5_{Q>A72B{vI=wC{R0r7%(yU2*OC#mO|f@Lw0 zz$AY+1ai(>FB^gAfWmVN4z9o*JqG`5Gi3=DD;&ELW(C~hmZz0G74Mtb2wD1gC((nL zX(j${QP?rWiYxF6&p=!T^?0P+ccOS|r-ZOpEmn8Fr99=-3MkFt5EbA>s2LD(>1*RF zrOz&Bq5DkbgzbJzHq$+JaslZJX^56JQ%2`BDc}b@CB_7k_->&AB zk|eshsZL-jc#dawB)<8c=}aUWq*{+l@PpXh-wh!}*wRF(Tmr79*8Kdg>D^ZvYy|5% z3w}vQCcITZPax8di#7jadz^D4sM(5r!7mo2fhgbPxBy&$Kk@5Ty*-+Nt`$tR5n5q* z(P!rKi%{f|S)88(abRGe?@aFx@E7+_{tJ=qc3bY~y^45Dz9l{M`rf@wn-7OjSy-%p zN+YVRjQ*=wmK@9G9)ZiC`_t%aVDo?<<%((@-qvu>Uu{%y|J9%jQnL_CEG?Oz%=( z?qX~C)Fj3lmo$a35g@)Z!D@2N9{!Dw&i1mkqxR`BJ!g_9Fh&A>IFCWnO;}X20ME8S zG}-IV>LjMhJcBlNZgwnh{LP{RXCv>shK>s10(yQ?a0e7DnIJ99<(cPSTAYOYmkC(c z46pK%68{vbTIi1^@6!I|DNC1Tl?Cu7K3}ppUVV0UHNc2`U-JXD$Mg+0X{z?DyW=>% zzJEDM0Ej5PT;cwj#n!hYU{_shCIVu<>FA%xUhbAM`tjLc$SQ`sxk*c?h^XqJ@jb)t zX?(!CcqQ=2BMHp9S5A^U=Y8F(SDG-w2P*u>eZn>9CkzSYSJ<8pY1p1~-nUei9W4_1 zxz=8bSIXw25E6#uR{zu;+KKf(VWs^OsCY_D-k%)_J^-JB$L$0}sUFR&=wI$^WI7c@ zqgwI#=vS9~irY4wkIOHV(v^P#rh78p0~yjqHzaLj&kSgX$#k= z3h|~_-5)GVZ=(O1sy(`Q^#C_C-K3Is0Or0tK?h1SXd2A0&nxtu&Tu_$2arhp16(#0 zfK{Ss0%)n%mB;#j8sDx0FQ10zJ7O#9=w9pHm^+KNXV7#ly6n0PPCZjVzfWV203JTqT37dZH`Dof%LaF=av?!1b!O z3hD&f4s|kK?hBgqL@@SozdgD2dT-J<0=Cs_FzR;Bv4Ed9%szEiu-?22h338FSq`X; zkOpt}U^gcVh-zb=4fOe&zQoFpAtk2VxrZFSSH)YWfNnAZtg$!vHI`4yBI@ z;(@bYCGDt54BtV*`-$4}?Y~22yWWL5w{T9KjN(Crqb5X(>m^`KFcLYM;?Va|t9_k_ zZ!7)#Dis^Z`wXyg{yMoYSsfk?+;5} zW@w4(Dsl(C>L_Vf?o@rkivF-H^&6q8@9K4zVQ6phu@8~0vwf0>7^gE+)y`e$Mc@Gs zV29x%HBjd2`2F>xH=LKOjA2E;Nn5hxf~ZsQV?-v}Ng~8DV;jibaKQtCdQ6*gS#s1O z9{@@_zJnNJcA1Y0nsz)x4u1hnYdFpw2_Xp~!6hMlmzOAj@uUD=y)a>I2X{rGaKQk+ zj$q}TLflqB_$w2PfpEd{BBDN%yZ?_79!`hmK5jrDm!}i@yqS zc?gmbCUQT@?DFuxnSYPG*5GEGi=(DejG6v_|2#~6t+xKpQyPk0W4xJvva5Gz3IsLj zYAz;lSB5+X4&gaP4$<-dEE21V&hGusy&lEOzFbz*6G+tX%Zg0172=u~?^e~7mKKG$ z90dyy63X2jlUV)`$a5j?S6L+RXk?hTxen4z^ajW)AMW<@8ND|(_d@iVc;6KG-Ub8| zo&j!9gzq~yuxSDiLoMz>1T1FrK@K0_?S6|dS-L-)(UR!-wKgY1>2b-ybF!|nC}%wdrkq1uWWUMyyFbN{ni$IDvHYpuSS3f zzoQ%dtHiVmX+?QP0C&WHb;9`^hiJcjt)pj43KoXjp?rS7LpSQhJM1G- zeJ%!G?7>3{yY~w)_Kf+H4KfrVyVbz@wc-Y0lC{udDcH8>nKL}8dCC7o>gpvwSh5W> z27C5x@CcdLx8^oKn3(zvzc_3P3CskY)CT&xXwuxag?JwKNzZ)%GB;-YgzFXhMvqkCTqB?(ElquaNE*cuJCTpuTa zQ^Eym-T2-;-A~gr(B*Xg&|N6`vDU7vp`$4c-1yj-yY-hEbOKUw?(p9P2cc)tMc@-AkK%&RBtI?cmzGPRQI>c?{JbhaN9h)7gXB z)U7tsIy`RC93oIMhhMPLnLw{*DRB$>#Rgc~64k0J3_88nQP$l(@>;&{ARD{0C0)SW zmi4j3Kej~>xX_)I@G0+p1-1&-Pyjg<(j!(*x{O|KQG66yJx(Q)m_h?{5z-ARe4sQ= ziW0@}vWiBn(v%wc4e+ z<8P}fl4ws8;-EaR^TGgm06af25Tfni zLmwi*zYh%L-^aHNSmfJQdTV_G7|8b$A_(da_k({QaGLwE$r1dp$x#nNG~-SJ6aLNI zaQeUH3n7C4-!p+3?o6IPGT{+)k^hrFxeE^_@PJN}06XI}2VeNTKb znSr`N*eV~G;aO%=GL9+GlA28V&v6`yy+i0D2@(Wzk%1LdXGR|ZGH5>NE>M`Qx z&S>^k1nZ0Yi87?AiPyC{B?|_m>#DHXjZWJ?<>A$DN9c zy-j1OY<{yV&Zj7=EOv3EtNr^5r5w>SY8>|~tnkRzxo)!)@Pf5b!m;SJu#v`{?3-2D z&bNEbu5ZDn=N0QGkG;w8mG*26b|o|#glVu7FVDdpUj0bSXz5GA$S!O_&fohHuDJb$ zIA23+xbF|RhvX*sx6gWEH-_8HUR%yKg)DJmQ4cqVXvfqX4L`MNu>aA{ z5uxOb2Q*(n2I_wW^-nxa81HIr^gEq1J=nh)i zx>bwM@`S9qpB(K-@}X-i;=S#+^2t3Fy%_yj?;`ECs)w^*uqC z(Yl7F5bGzp`&P8A`J5hF$V6WCyyJM^W=2-WNT5qnjw12iCmk|koV-w}o2c??*VSLk zN=B5qw?oi0YT~l{J4C&}ktlIK{m^nHgC<)2(_Kdb0Ty5VyOV7wd-LXp^MPKb~cSEw8 zj?XpFu|>zXJ6?L=aqPYG-P?kk>Md^*Z{9Tdy_t8U?GJp-*vmm1{X(7HEZyGV_2){S z7cMz{y{oP)PWa^}aHL7r`qIH#qmRy>j#*->7^O4Y&D>7pQn$Atk8j@`QDEvHasTvT zlZp>?hKh4;&8B@)f4A^8R4SSq=t2kX!P9B6`izKPJ%lpTBC^4`{}le6)#anE)8fIi zth>t1b%Uo~81eGuYOK+>HaVI3jM5#oeHPo&Ri-Yhn202$WE40phaRh2IBNo(JOtkB zn4|rIZMSUg2+V`B@zKBZi-j|mPG@g1cyRjKy}ZU}dHIu06tgu#Sx&)r#}c&RHAR#1 z2khJVgT}|_^k=lQ3#x3wcPL?HE?>3(@o|`Vm!N>y{e%aXn}J=#Ys!0ERxF-8kdH)t zM0osJme;X+v@9olzzWKrAJ)&-jK2G(H5uxOdp4C1et92BNQ+7BeDvD%+*73b&Q@*N zEgNoHb4qsHJlzwnyPc^s^YSP|;YGytozo6{4KsGH`xI5=ZiEJP8dEm&EEYmEvp&e< z?~G53=`Iyx>53=LHS&jQR^N6}aU>faHaJjxYaqanEk2S`x@D8P`gz3oAcuTJJ#g4L z*G5^)aktKd44#v2u9&CwiC_|?nO&Vp%;bNGh&{nS`)~{mQ*S@CBN29GTOEJP>6owl zE&N*B(K+)6WvL$b_^?t6($`%&V&;B^y_25$xhudLk?6-#Zc0_&wr%)OnPW@0*SM57 zJ?{+7z?Y8{LisSlKRtQi0_@9ERfXz@ySC)^AII-8(LN_;?yz=ISJLP2w*$Cq#ynFZ zcwkCW#JufZEh>zbo+G|@ux}{n2|wp>0<GA5^(JS!;e-3i>JcrgtFgR6Z7@lU0d^z^H@N<1m)N?Pb?b}|v@i56XL+C^Q)c@L9 zmJDLn{key1@NF^Xs>aPjNq*J6A}lANU-BO9<9E-DUv=hd7~dy7@l(ZcqOks_Q$>FT zY7l>Fhhq)o^V$odv|phrk~4SvhA=nl06lH>UOS(7LRvB~D`MkvfFpWraNqNx!Jj?j zw_{lR2dxn}@PG3tMmjEXK6{xbZtjjU-+lj4hL4b=Y{ttaPpontcHX*H043&FX zarG^Kt}{C+#wgPZUfV%w(<@rloehj$+r|0(&Q5{xth?fhk+PZ(Ea2s&L!qiD41OPZ zM8)!3R;x40rt`&=xBTB?2iqii+VpWM{jc*z{=Dh{h$_x8u6)spf`-+HVdajE&0a1o z&r1B0*yX>U;qf~<2FR)f#@|~f=cUD(`llpd?}04Z|Jvf|v3v4q%1@8ph0X0#%-{F) z!J~C(&fffVZt`Khf9O`U((iapWK1vaO~9gFn#2CVr^j+al%vzKd#ZNI-qJOBd!y~B zTc~O8Z(E(CaWb;EJgTo6w{4n?_^uEpHq#n8cj(BzPk|SoerhU?!46&8{QP%^=Aqvn zuL>Uy2Uka);6M2xcs9Us<#8P&l|be!bdmYm zt|b>Z^?KtTnBaYLF+O_D)jJz6mE4VZZu~Q?*}CE9iL-Ub`bt`L_;t#rkJCjCYM-uF ztrxgrABY~6^4qZU+gbmaeBsGq_vO}cq1!i3(PK?tpZv46a5>Am@|GdQ?$!y_Z4Rn~ z!cUwwlap;uVg+9o9!R5CnT8V&$Q93a31~L_UyXfbP+URNE-daAT!K5n-GWPSm*DR1 z8eA6F0D<7{?(QxD7J?-NhsA?)H_3bNulv>a{+X(+bE-~vPftHjchAgX4YIw%`L|Sh znhp8DzM}b6_>?T@L&{y`pp_qWP`o|w3qg_;*pv!pR%2Hc=dsa|41DrfpY?awkc0>dWTZ5=$hxm2@xdZMVNe~?fVm-0VfXFj@JQS0K(KsbM(!DlubqP zPB|ygXCDG~8=P>1E*jJJS zUQP(DnS^E9Ja=foHYrr3hp*VqIR~}CAUNq9zuu_-em29=9kHjh%<+(hn9peAcHIYo znbqFq2fr5*8)NG<3wdV_#IryD^X$4_b^G;vG;;A)2e+0s(;xR|{i|nt;yujaLI~d7d$N6%`i4w7970O7|qFR|vX!SnQ(?mf`>N zzzB5RE$^0JRLBT&?|Ljxl46A_nrwG3P>wkI#%e1}s3(_Q<_voe$+kl-Ef=7!&noho zsyJjE7{$<{n@==C2b79?p(FxSdl_@@xHog7>{p)526XL4rCeOieq+j6xQ@IvQn3K` zr(hG5{Gx;oJV0<0?t`FllnDVZ5*s9pH$o-s>5#fS35AcyQKndw>7@6f%+Dc70lvSW zg%1`1bE~x`2YSCE-falE*(5?pT+QYqeza~iP|06;TdUIVl@f;bK;r?0ZIcSItakTo zi_oa=&Z7~tlu2Kfj6f%x_#^?v>|z`n#>N~UsSYG^4(E3T`8fW6J}wbi<;bxl-|${( z;R;z~F`Yo5LOfWQRee<^#^!nHhb>INLaW%rRRLcRxnw+`vi>%>$Jt4w91bLYO}wrL z7pINATqAvpi}C3I+xR>2uUxVwyMi^{+~1pyq#63SwJ+7`5OU^KUwZM`bI6T-)PTAV zlACfg2O9;pB=AC868uK|XM{-UD#x= zj0>N&nnM`8qq8H`9eET~EL3NPbB*|=7MX1c)QU?=z`c;?v3)!+=LlitYn5 zjVai|5@x4##NLfDGJUIai~@7g;)IY#Hu`-HtJ4`4G?uaSzSTRgqcwu9$LrNo6UPL{ zt^z&3yDi|oPFFaqxO(>+?ClUn-cE0k$CK(!=diyTUrWL2z_uN^&1i1l^c?X^)Hv}Q zjT&0is!UaW46f{@wX(^aV%+hZr|Q>HO~L39t>JU!s%0dI7hp{}$yGZtDatZar7W>@ z-Fw*8+%l0wg%BZ1|l8G`;k~`$(2qR6Kq&9Se-d^ zEMqKVrP`l6=20)09%m&nt@mV;N`dH`v|sgt*Mlgm4n?Lxo-6F0SPWFG0>{^iOLDe4 zP*YoJh3gTgjRGKAYqRA%cVagP18rx^6YQ(uW)rx>ujQGmT77Bc0gXzu(zgVPI^1dA+*YwQ1 zGog|n=wj|mYo$9us9Qqr)2c(#T`!orXr7Z!o0VSswh>TfLRIs#IZZdHG?9pGroCDg zGlfCegWU`^iQ)a;#HQm}d2XNIBO{;J?t^s#dFDP6fuY&0V_?Oc6@uX)7Jyy#3g;FZ zRYdxP!!ymzQ2*3q<(?jofCv%cj!8g;bL@PE*h+;`uM-eNl9PY1a~CW1$l0C{iYoI1 z-?;+8V{G}h|L6dTqh>7n$rj7_adGeM+vaZv@LHs8V$ttu7?UdzB0ht9N1Q|qKudDu z+;9X^lgja`npDqcRR)=22B){lK_1{vk? zIl?1>AAN;me&o*!c)Zh|5POtIW%LJwY9VJB)|seSS^#QGF1=bn$6N?8y9sg7*&m0% zlnj;%7Cw$km?ivxQFu&H0?l986Sv5qBfn=N87$~3i{Ous2JZSnDjy=C+$x+fGaElC z**7uuQT-}p;5vRsY{P!eVfYQb(!8jLJVX+*|lNi>HrxK+sK zkA7(e9R{zb4HEG=F>kfW+O?KtQLpigop+_f9ZD~pM4bleiM?UQ#2||QbBFG+^HEA` z{|s4_X68{zU4$k4o{blj!nqY+X!P9_bV>cVypbT)g{Qg!4P^~!Cq`-L{16lI2PK#Y zOg+ItLLw@Fy*p`_p3+{M=Hf&ql)@{9A?|6X+U+vX=v36x9U-&r1{kY4mL&i{-8<_t z-`5Q=yuKGboCx^GtP;FyK}83g zkzP!|62)?+KB$V<+bNQ4Wsh5sei77vG?e>f7$I8z7_!36TF@ii@M$6e)rr+R(w%&B zlE3xYg(GD2ILK^GNiYxaD&ntx$)66DA^iCvr#`fYR z&@}M}02PqmQ=75si|`US_rOGhM*YW~jM@LFR!yi^=>&=mi|(l>^X2N9)DQ$RZqZ#2 zG5TaBQpK@(lqC~r1etlK^u!jy1fb>cqKsVJ+)mH7+wX^N$KnKu&8|Znrf6PZ6Yilg z{82J^`Z|0h0iLilzgouaM#|COCO}Xa+np`R_m(80re- z+yqs-d*c6Wbx~#ff9Kfvpt$5n|ANQktT=3_qs;S8LbS+)Qc7?PvsnVw>04)>``6ah zyT@?ts1NF$PE~P>S`k(Jwz9z%sMw}5z!NPjV@|uFho4>7dJYNJGx4!!N(a(|q=U;fY1?v89y-#TI~dwipsgN=P!FoA)?9UJ(i zz)(InVmyNPd&_6{ulv9YPt&{<{&VLp^8E&3iX=e~hg8Jt_?J_}^8o|xD#fbj<5^;< z;8Uk4+_+)xsRKigZE0N)A@lXL5l2Ul&(-TwpGD^acI z3R2su(X`{v-I|$%RH0n%kyc@#dtdj^`OGn0BJN#ifOUJ3uhK#p5QmWjX-|$Q7*@~D zi|i=&a*LxV7Ru}^m!b3b#<@D=QgpkC+eU9Q({=$4?0lo#oa*R@f<-~!k2`z$yQ_W- zARB$zf@Py%Lh@(4oEyXeCHr^AnPp@8}p1Cy2O&n)bw24s%LIB~`UUPHFM`j(SqpG@>#^ zS=s_%3K3lBgDXTB8cw1R_7Zq+GWzZJBHZJ)E2510f@DFx)u;vSR=Zvhv?WgbkP>c3 ziAAoT{SqFo_CCoP@HqS*EIJPh77?sgVMD{CS2*2gzaF}LlR~ZN)lxf|K#6{Qm)IC) zkmKNok;qY#!A&t{>7e61{BoR5hi%p`%6%3*+$zY`%<=H~?d*ECC&2zI7Vl4{sUX(O zw19rGn=f_0xL7ZHC%4O3xGX;~K<~{quelajuXdQ-re&0h{^LOss;LV}gR5z|fT>fz zsz8%9l@-+jf{Ozdj-?&ZNlz7k2 z&$S9OmO39}klAWYL@vka9E7hr4(DB7N$Ni@ApDuGu$Wd49Aodoc)&+dr8KD4ss_z~ zN)ehSI}W))>cC1y+``}wd8L_oIR&E9xL~5Kt48<;D1Cf9!=BfU&ok^u#8+%0SL?-* zYVThE(ls#22;CsffvmZ^^cGyi{2tEmwyNW@{OuFVSpEYyyHgFlrq^-CQL%QHBdiORC`ZRB7)WdQ?nne z5I*IRILL^aXp}h$7)I#@a7-&RC9pYGyJH)}&wC49T~7zatq%BerqTXgQe4w}z;RWi z#WTJ16ah-d@b(||WWSE|d2;~w=D3QDe16vJN^%PUF`x;iWvSc6Ety}nEra#BomwuZ}_Z6r6Sa9iD9L0xy>r;IONgG zn15cqa=4pih*ATDoCO&W0i0Lg2#*s&IaPdOGZZ0FzX8!bCNNZUNxF zV&U!IwcX<^kO^lX2%9&n%^Yx5jf5p4td6Z?yB zAC*S<7^hVww`dalbp+)rQTy;E8)9uV7H!*r=e5%dFr1#3CUSC``JGgVozNyLV5(CA zh-Mr=BO2kbz%$J}vpXQz%I=O`_)GA^+7fD=zAI7= zV~#Y77*d}=jb_Ioh|?d(`zmc;&rW@0{bhQ^9?KZV=SGMAyriiAjqZmoeA3W+;Tmk0 z>Z+txrY>VjO04b9D(<`4CAXiPe?VaItr(Hm6B%`rLMO5c62z=7t$sTbA% z_(AUKCOPzkDcq;AD7hixc2vf}!l|~PTmts47vxnuHOF4ygjbC5mjUI=zZCkHsRd-Q zQ+4gR&5hs3H4xKgB~CfflTqRh4KA^PmT(XGl~GN*GgIwMzhBTT?Ni`NcDl6O6H81x zH^nxYK!RvQF2Dp4xYGgTV{^@Kj(O%7=-rFusSgyj4uG&9%D$6U#ykB#&%@hI`3%(t z+kdR}K6gcTw^G=}V`D<|T%2L-1pGBhBD3`*WsN~QuvTe1535z3u=h#fM5VpF&@;d; z!0XIfu>G!qm%Lym>`Q&3Bq=kE9ny-6KtXxl%5b3|rnF&Re(8IK*!*3Bp0hSW1nES) zh6=eArj+W6B%Ig|70#{Ly>-gIu2#mjq%p(;kD(rgG36OEl7{{Q-w|<7UK&M6Zfqr> zS8;K;X|%G5Q(SNxBMb>w#W!SK%hFgj4KVjpKph24i`nox)w0J8@QR*o!z#Z z(`$CQ9`w-zb#&sxM!)5F6(!zfhKH7r| z8>=Q9_@==*7IKyF8_{EO6A-=}vEH1ME*e>@?M$Wp zWK&})gnZQbA}b<#3Znc1LB@prQ+81ePv2!9uhLguratg+BxL@FZ|2WlL0 zJL76xQ4!gkTW=WiFt-uq%e+WONIFI@)o0HLyyWo(j-i&iOZ;JF`4;W>Uzx%<8PMX} z2tOIU)R2q17^xlEMJLL9*P{;m84|%bhUp^eakjHBwACn|TM`vHr7pPC02=$!$*?9p zj&n>27PDKzM?ok+`-OI@+#tY7h%7EH0?tF4szDxc;}@Az8L_|ptbcCxfhTyi@cz=Q z5`JoCqQ|+gD;gL~CVT##*&SQS0D1(ikI$etr`_UOKbNF?`pM$Kbp8c`RpyDu7K#U@ z+$!0|fexaY_Oq4svv_BrI?1=*T`#;Pr{*NvzXZ`z9si<{uAt#14$E>-mnFKt_CVAV{I=CA7kx0mzNq1g~$qG2au|EpyV`4`7t8g2ePk)7v$j+ z<;V2x`0QQ0{UHK+;zgh3bL-nAzc;1>?w5X(hZD3h>G8TM$%f2>$QGbrOK9!+r~NVm z1I5!lrv2_|QLCqG-9=R*D7x=h%muHye^YeVGTTgq>okRTK8&7E7jh_((_%FD1RYgP zo(daiGj0hB)yQ$Am|dvFk*VagrWQCehuqm6e%)JP#NKCwvcR^fzW)tN)0~n9mM(% zAa&}qPcjBKn9i17pE} z&rq|5W06ljr}ZnP9H#i`EUT##cq0BZJY36B_00C8H{XKeD|8QzukHt)Srw+_U2AOi z4ztUt`kvO21nGXhPk&=TTSU{#K|}2MaFj|a57sZ^uJ^XGjBgg~r>+BMyp^W!VZx*g zwHMR{KCHL0tYP=Yd9G3ZNt^l@g3^kO*<8QUhtEpfXD=&1*kq**&<%BGDg;MO#NrS} z4=uBr&t^rh9d_IyHKD2`jY)uo?xr|SWLvrRI=gGuy*FgLoT$Oc1j-Dhu3kuRKK z2O0RY)^w>lajU3!84nHrRvXM|>#DN)%~?)C<%IA-@dM4~{1@aunizbZJrC>8f4P)CS0tbJKrWx< zdHH<@<$U?RIo6Vv8a#;pDQc0QTU||v@8<04J;nNG$T8HvXXJOjKYn9FhlBdY?Wj(r zO14Cd#sQh1O?)}gvzj)kLjPbZn-PGT37B%qT^az7+dmR@>{&CD(Jh{I=nQN{9QI7W zd|7wbIX_o2O8jfJJ6ZO+)E5rxtiqd}b)j9JGrAQF24|i*HKvc*JRxE8%&^R^#`Rm! zuJ69j?|LWfX+}ed3fKwLka&+B6`GSCEET)AidOL{U;w(|ulb7wXE#c=kpJSKOP(h> z@cBKr`c@#n_p%L(AHn_cEY*5cWVvVCCc_0-wJ2@sQ|@e*5MXAi=RK<0@+T8tp+^hn zlPyh1d2Hv7qHw+j(Rtvo?UktfCP0woio~>cc{8WtxBsgOece$XnxfL7*Ccw<-&BXM zE(Iro((zP2wVrVOV71(-z^N?6zC=_WO`Thq;;uaUHqmA?=cN)oAvllqJAU+V6h<$d z7Ncjksk3J2fvL1(52CX;o8J9rm-U-2Ctgw}YPhGS%cup1jHTB3fUiP2k?JvrTf*It zQYU2kF%dkc0wvPc(uM~D9Y(w$fn{FdwkN-AWqkY_jq;I)I1#O?kOFEGGt>xK2aJgt z6p79tNH&xGgHmuAUv3MjkTVxCKD1Iq%UPd{+iz=+MbnD&_8)p`FqBQ;=}ika!t(^C zc)lR5NC($u@4p$hn6IggNw3%RD2V0$T_RPOU6NsC>A%dv|%8J*-| z_231;Ap{}lrTnF72-swcdZc%*sOeu=@aiQYt^9)dNBOz z0xzX;h(pxT!$vWgB*z#()g5M(51saff7YK>Z8rMpsMYexHbGu>fq-#+goITMX5gmH zf85Brc;|WGDBa0n${@(fas~nst|89U0+1#I(<0DmWl`eJzF*|=^>8p_xz%$K=AS(% z9r;Qb(IOr^e{n!kMGZ@PxeunW$8G_&Eqm=2+*0}s+)+5R*xEg{t|MTkrH)Spq(ubO z6zlg-zfal}MjD%u(IT4Y;~pr(#xf|0$1ANP++G3`E7aM&H64wpEf0>UeQfF+)R0c} z;v%nzC6#9dYMZiuf7NU!kWo--e3d7B$!81VqrWbW487yY@&nx}zij|{88PP%{Y7lq z(1T^gM2XnZr;z}n&pyw5;^YFFh%3of-NP&mn2M+0B`(N|$y|`1va|M84`Y`t1Vts{ z7RS7zAE?an+j$vJ$9*>4yV#cSX@rBn4Cpi0qDAQCX)d#B+uZ(vBS;y-&Vd|@{}z(3 z;hN>R?A(Fz@6b>D^__z1FaqxeV+Ml=QcX+*fA9;wv&H?4Dr1`2VO|1Dw3tSZ#XFJx zBI2KdCftM%ZjZ6|#D?upg!B!T;3Wj}YoC^W#FNCX%iq7#siP_Lj=^RwZmHI|!r8k1 zxRqo@3U|pMd~Q^jj;nFQ@SCOYao13xt^8qd7Z%c*G0Z8mm7YmpN_bB#o;4U=!*tFVXflyB#j4rLp6JZKB` z4y;$)MwN{pkcP?eCCP&f$8lc%=+`(10_48RRF0XuAvu%$``}&e4)w{2uLo`Y&Nnia zId}FuJqR-j;y5-RGju`tj^CP+#ag~Flhwc6t#i6fd9?&t$lv~+#2EBT(xgbR4>=Dy zU8#rcu7~O)4?5O8VT3fxAwT|H1pdDbfzXhx{(X@5{>K3k@<#?x|LySSK53}`atQu+ wxZv*zga38-AI1CsJ@{{>{rld3DJ>K Date: Sun, 16 May 2021 13:55:52 +0200 Subject: [PATCH 0794/1539] Standardize badges --- README.rst | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index cd635835c..70dec88db 100644 --- a/README.rst +++ b/README.rst @@ -2,26 +2,26 @@ :width: 480px :alt: websockets -|rtd| |pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |tests| +|licence| |version| |pyversions| |wheel| |tests| |docs| -.. |rtd| image:: https://readthedocs.org/projects/websockets/badge/?version=latest - :target: https://websockets.readthedocs.io/ - -.. |pypi-v| image:: https://img.shields.io/pypi/v/websockets.svg +.. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg +.. |version| image:: https://img.shields.io/pypi/v/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-l| image:: https://img.shields.io/pypi/l/websockets.svg +.. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg +.. |wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://github.com/aaugustin/websockets/actions/workflows/tests.yml/badge.svg +.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml +.. |docs| image:: https://img.shields.io/readthedocs/websockets.svg + :target: https://websockets.readthedocs.io/ + What is ``websockets``? ----------------------- From f029e81605a80f37da1aea3c73644b7edca7f7a9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 May 2021 09:16:54 +0200 Subject: [PATCH 0795/1539] Move Heroku docs inside how-to directory. --- docs/howto/django.rst | 4 ++-- docs/{ => howto}/heroku.rst | 25 ++++--------------------- docs/index.rst | 2 +- example/deployment/heroku/app.py | 16 ++++++++++++++++ 4 files changed, 23 insertions(+), 24 deletions(-) rename docs/{ => howto}/heroku.rst (89%) create mode 100644 example/deployment/heroku/app.py diff --git a/docs/howto/django.rst b/docs/howto/django.rst index fd170c387..1b4b4d3b9 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -1,5 +1,5 @@ -Using websockets with Django -============================ +Integrate with Django +===================== If you're looking at adding real-time capabilities to a Django project with WebSocket, you have two main options. diff --git a/docs/heroku.rst b/docs/howto/heroku.rst similarity index 89% rename from docs/heroku.rst rename to docs/howto/heroku.rst index d23dc64c0..876a9160a 100644 --- a/docs/heroku.rst +++ b/docs/howto/heroku.rst @@ -1,9 +1,9 @@ -Deploying to Heroku -=================== +Deploy to Heroku +================ This guide describes how to deploy a websockets server to Heroku_. We're going to deploy a very simple app. The process would be identical for a more -realistic app. +realistic app. It would be similiar on other Platorm as a Service providers. .. _Heroku: https://www.heroku.com/ @@ -39,24 +39,7 @@ you'll have to pick a different name because I'm already using Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: -.. code:: python - - #!/usr/bin/env python - - import asyncio - import os - - import websockets - - async def echo(websocket, path): - async for message in websocket: - await websocket.send(message) - - async def main(): - async with websockets.serve(echo, "", int(os.environ["PORT"])): - await asyncio.Future() # run forever - - asyncio.run(main()) +.. literalinclude:: ../../example/deployment/heroku/app.py The server relies on the ``$PORT`` environment variable to tell on which port it will listen, according to Heroku's conventions. diff --git a/docs/index.rst b/docs/index.rst index 4ec682f69..57d158e8e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,7 +60,7 @@ These guides will help you build and deploy a ``websockets`` application. deployment extensions howto/django - heroku + howto/heroku Reference --------- diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py new file mode 100644 index 000000000..aceb754a3 --- /dev/null +++ b/example/deployment/heroku/app.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import asyncio +import os + +import websockets + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + +async def main(): + async with websockets.serve(echo, "", int(os.environ["PORT"])): + await asyncio.Future() # run forever + +asyncio.run(main()) From ad8ea999391ccd3a7d97edd7a36bd228fdc6c09e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 May 2021 14:31:48 +0200 Subject: [PATCH 0796/1539] Restructure documentation. Lots of small improvements while proof-reading. --- docs/extensions.rst | 97 ------------------------ docs/{ => howto}/cheatsheet.rst | 6 +- docs/howto/compression.rst | 51 +++++++++++++ docs/{ => howto}/deployment.rst | 20 ++--- docs/howto/extensions.rst | 31 ++++++++ docs/{ => howto}/faq.rst | 6 +- docs/howto/heroku.rst | 2 +- docs/howto/index.rst | 35 +++++++++ docs/index.rst | 49 ++++++------ docs/{intro.rst => intro/index.rst} | 42 +++++------ docs/license.rst | 4 - docs/limitations.rst | 10 --- docs/{ => project}/changelog.rst | 35 +++++---- docs/{ => project}/contributing.rst | 20 ++--- docs/project/index.rst | 10 +++ docs/project/license.rst | 4 + docs/{ => project}/tidelift.rst | 2 +- docs/{api => reference}/client.rst | 0 docs/{api => reference}/extensions.rst | 10 +++ docs/{api => reference}/index.rst | 9 ++- docs/reference/limitations.rst | 35 +++++++++ docs/{api => reference}/server.rst | 0 docs/{api => reference}/utilities.rst | 0 docs/{ => topics}/design.rst | 99 +++++++++++++------------ docs/topics/index.rst | 8 ++ docs/{ => topics}/lifecycle.graffle | Bin docs/{ => topics}/lifecycle.svg | 0 docs/{ => topics}/protocol.graffle | Bin docs/{ => topics}/protocol.svg | 0 docs/{ => topics}/security.rst | 18 +++-- example/health_check_server.py | 3 +- example/secure_client.py | 3 +- example/secure_server.py | 3 +- example/shutdown_client.py | 2 +- example/shutdown_server.py | 2 +- 35 files changed, 345 insertions(+), 271 deletions(-) delete mode 100644 docs/extensions.rst rename docs/{ => howto}/cheatsheet.rst (93%) create mode 100644 docs/howto/compression.rst rename docs/{ => howto}/deployment.rst (92%) create mode 100644 docs/howto/extensions.rst rename docs/{ => howto}/faq.rst (98%) create mode 100644 docs/howto/index.rst rename docs/{intro.rst => intro/index.rst} (85%) delete mode 100644 docs/license.rst delete mode 100644 docs/limitations.rst rename docs/{ => project}/changelog.rst (94%) rename docs/{ => project}/contributing.rst (76%) create mode 100644 docs/project/index.rst create mode 100644 docs/project/license.rst rename docs/{ => project}/tidelift.rst (99%) rename docs/{api => reference}/client.rst (100%) rename docs/{api => reference}/extensions.rst (53%) rename docs/{api => reference}/index.rst (92%) create mode 100644 docs/reference/limitations.rst rename docs/{api => reference}/server.rst (100%) rename docs/{api => reference}/utilities.rst (100%) rename docs/{ => topics}/design.rst (87%) create mode 100644 docs/topics/index.rst rename docs/{ => topics}/lifecycle.graffle (100%) rename docs/{ => topics}/lifecycle.svg (100%) rename docs/{ => topics}/protocol.graffle (100%) rename docs/{ => topics}/protocol.svg (100%) rename docs/{ => topics}/security.rst (56%) diff --git a/docs/extensions.rst b/docs/extensions.rst deleted file mode 100644 index f5e2f497f..000000000 --- a/docs/extensions.rst +++ /dev/null @@ -1,97 +0,0 @@ -Extensions -========== - -.. currentmodule:: websockets.extensions - -The WebSocket protocol supports extensions_. - -At the time of writing, there's only one `registered extension`_ with a public -specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. - -.. _extensions: https://tools.ietf.org/html/rfc6455#section-9 -.. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name - -Per-Message Deflate -------------------- - -:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable -the Per-Message Deflate extension by default. - -If you want to disable it, set ``compression=None``:: - - import websockets - - websockets.connect(..., compression=None) - - websockets.serve(..., compression=None) - - -.. _per-message-deflate-configuration-example: - -You can also configure the Per-Message Deflate extension explicitly if you -want to customize compression settings:: - - import websockets - from websockets.extensions import permessage_deflate - - websockets.connect( - ..., - extensions=[ - permessage_deflate.ClientPerMessageDeflateFactory( - server_max_window_bits=11, - client_max_window_bits=11, - compress_settings={'memLevel': 4}, - ), - ], - ) - - websockets.serve( - ..., - extensions=[ - permessage_deflate.ServerPerMessageDeflateFactory( - server_max_window_bits=11, - client_max_window_bits=11, - compress_settings={'memLevel': 4}, - ), - ], - ) - -The window bits and memory level values chosen in these examples reduce memory -usage. You can read more about :ref:`optimizing compression settings -`. - -Refer to the API documentation of -:class:`~permessage_deflate.ClientPerMessageDeflateFactory` and -:class:`~permessage_deflate.ServerPerMessageDeflateFactory` for details. - -Writing an extension --------------------- - -During the opening handshake, WebSocket clients and servers negotiate which -extensions will be used with which parameters. Then each frame is processed by -extensions before being sent or after being received. - -As a consequence, writing an extension requires implementing several classes: - -* Extension Factory: it negotiates parameters and instantiates the extension. - - Clients and servers require separate extension factories with distinct APIs. - - Extension factories are the public API of an extension. - -* Extension: it decodes incoming frames and encodes outgoing frames. - - If the extension is symmetrical, clients and servers can use the same - class. - - Extensions are initialized by extension factories, so they don't need to be - part of the public API of an extension. - -``websockets`` provides abstract base classes for extension factories and -extensions. See the API documentation for details on their methods: - -* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for - :extension factories, -* :class:`Extension` for extensions. - - diff --git a/docs/cheatsheet.rst b/docs/howto/cheatsheet.rst similarity index 93% rename from docs/cheatsheet.rst rename to docs/howto/cheatsheet.rst index a71f08d74..86684c44c 100644 --- a/docs/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -65,7 +65,7 @@ Client Debugging --------- -If you don't understand what ``websockets`` is doing, enable logging:: +If you don't understand what websockets is doing, enable logging:: import logging logger = logging.getLogger('websockets') @@ -79,8 +79,8 @@ The logs contain: * All frames at the ``DEBUG`` level — this can be very verbose If you're new to ``asyncio``, you will certainly encounter issues that are -related to asynchronous programming in general rather than to ``websockets`` -in particular. Fortunately Python's official documentation provides advice to +related to asynchronous programming in general rather than to websockets in +particular. Fortunately Python's official documentation provides advice to `develop with asyncio`_. Check it out: it's invaluable! .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html diff --git a/docs/howto/compression.rst b/docs/howto/compression.rst new file mode 100644 index 000000000..9023cec56 --- /dev/null +++ b/docs/howto/compression.rst @@ -0,0 +1,51 @@ +Compression +=========== + +:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable +compression by default. + +If you want to disable it, set ``compression=None``:: + + import websockets + + websockets.connect(..., compression=None) + + websockets.serve(..., compression=None) + +.. _per-message-deflate-configuration-example: + +You can also configure the Per-Message Deflate extension explicitly if you +want to customize compression settings:: + + import websockets + from websockets.extensions import permessage_deflate + + websockets.connect( + ..., + extensions=[ + permessage_deflate.ClientPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={'memLevel': 4}, + ), + ], + ) + + websockets.serve( + ..., + extensions=[ + permessage_deflate.ServerPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={'memLevel': 4}, + ), + ], + ) + +The window bits and memory level values chosen in these examples reduce memory +usage. You can read more about :ref:`optimizing compression settings +`. + +Refer to the API documentation of +:class:`~permessage_deflate.ClientPerMessageDeflateFactory` and +:class:`~permessage_deflate.ServerPerMessageDeflateFactory` for details. diff --git a/docs/deployment.rst b/docs/howto/deployment.rst similarity index 92% rename from docs/deployment.rst rename to docs/howto/deployment.rst index aa1af211c..35720e8f1 100644 --- a/docs/deployment.rst +++ b/docs/howto/deployment.rst @@ -34,8 +34,8 @@ On Unix systems, shutdown is usually triggered by sending a signal. Here's a full example for handling SIGTERM on Unix: -.. literalinclude:: ../example/shutdown_server.py - :emphasize-lines: 13,17-19 +.. literalinclude:: ../../example/shutdown_server.py + :emphasize-lines: 12-15,17 This example is easily adapted to handle other signals. If you override the default handler for SIGINT, which raises :exc:`KeyboardInterrupt`, be aware @@ -151,8 +151,8 @@ Under normal circumstances, buffers are almost always empty. Under high load, if a server receives more messages than it can process, bufferbloat can result in excessive memory use. -By default ``websockets`` has generous limits. It is strongly recommended to -adapt them to your application. When you call :func:`~legacy.server.serve`: +By default websockets has generous limits. It is strongly recommended to adapt +them to your application. When you call :func:`~legacy.server.serve`: - Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of messages your application generates. @@ -171,12 +171,12 @@ Port sharing The WebSocket protocol is an extension of HTTP/1.1. It can be tempting to serve both HTTP and WebSocket on the same port. -The author of ``websockets`` doesn't think that's a good idea, due to the -widely different operational characteristics of HTTP and WebSocket. +The author of websockets doesn't think that's a good idea, due to the widely +different operational characteristics of HTTP and WebSocket. -``websockets`` provide minimal support for responding to HTTP requests with -the :meth:`~legacy.server.WebSocketServerProtocol.process_request` hook. Typical +websockets provide minimal support for responding to HTTP requests with the +:meth:`~legacy.server.WebSocketServerProtocol.process_request` hook. Typical use cases include health checks. Here's an example: -.. literalinclude:: ../example/health_check_server.py - :emphasize-lines: 9-11,17-19 +.. literalinclude:: ../../example/health_check_server.py + :emphasize-lines: 9-11,20 diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst new file mode 100644 index 000000000..fdaf09f63 --- /dev/null +++ b/docs/howto/extensions.rst @@ -0,0 +1,31 @@ +Writing an extension +==================== + +During the opening handshake, WebSocket clients and servers negotiate which +extensions will be used with which parameters. Then each frame is processed by +extensions before being sent or after being received. + +As a consequence, writing an extension requires implementing several classes: + +* Extension Factory: it negotiates parameters and instantiates the extension. + + Clients and servers require separate extension factories with distinct APIs. + + Extension factories are the public API of an extension. + +* Extension: it decodes incoming frames and encodes outgoing frames. + + If the extension is symmetrical, clients and servers can use the same + class. + + Extensions are initialized by extension factories, so they don't need to be + part of the public API of an extension. + +websockets provides abstract base classes for extension factories and +extensions. See the API documentation for details on their methods: + +* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for + :extension factories, +* :class:`Extension` for extensions. + + diff --git a/docs/faq.rst b/docs/howto/faq.rst similarity index 98% rename from docs/faq.rst rename to docs/howto/faq.rst index abd396cde..0196fd2c0 100644 --- a/docs/faq.rst +++ b/docs/howto/faq.rst @@ -181,7 +181,7 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: -.. literalinclude:: ../example/shutdown_client.py +.. literalinclude:: ../../example/shutdown_client.py :emphasize-lines: 10-13 How do I disable TLS/SSL certificate verification? @@ -324,8 +324,8 @@ coroutines make it easier to manage control flow in concurrent code. If you prefer callback-based APIs, you should use another library. -Can I use ``websockets`` synchronously, without ``async`` / ``await``? -...................................................................... +Can I use websockets synchronously, without ``async`` / ``await``? +.................................................................. You can convert every asynchronous call to a synchronous call by wrapping it in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 876a9160a..c6d81ab1d 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -3,7 +3,7 @@ Deploy to Heroku This guide describes how to deploy a websockets server to Heroku_. We're going to deploy a very simple app. The process would be identical for a more -realistic app. It would be similiar on other Platorm as a Service providers. +realistic app. It would be similar on other Platform as a Service providers. .. _Heroku: https://www.heroku.com/ diff --git a/docs/howto/index.rst b/docs/howto/index.rst new file mode 100644 index 000000000..bcacbc174 --- /dev/null +++ b/docs/howto/index.rst @@ -0,0 +1,35 @@ +How-to guides +============= + +If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. + +.. toctree:: + :maxdepth: 2 + + faq + cheatsheet + +The following guides will help you integrate websockets into a broader system. + +.. toctree:: + :maxdepth: 2 + + django + +The WebSocket protocol makes provisions for extending or specializing its +features, which websockets supports fully. + +.. toctree:: + :maxdepth: 2 + + extensions + +Once your application is ready, learn how to deploy it with a convenient, +optimized, and secure setup. + +.. toctree:: + :maxdepth: 2 + + deployment + compression + heroku diff --git a/docs/index.rst b/docs/index.rst index 57d158e8e..000c19ca7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,25 +1,28 @@ websockets ========== -|pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |tests| +|licence| |version| |pyversions| |wheel| |tests| |docs| -.. |pypi-v| image:: https://img.shields.io/pypi/v/websockets.svg +.. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg +.. |version| image:: https://img.shields.io/pypi/v/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-l| image:: https://img.shields.io/pypi/l/websockets.svg +.. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg +.. |wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://github.com/aaugustin/websockets/actions/workflows/tests.yml/badge.svg +.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml -``websockets`` is a library for building WebSocket servers_ and clients_ in -Python with a focus on correctness and simplicity. +.. |docs| image:: https://img.shields.io/readthedocs/websockets.svg + :target: https://websockets.readthedocs.io/ + +websockets is a library for building WebSocket servers_ and clients_ in Python +with a focus on correctness and simplicity. .. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py .. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py @@ -40,27 +43,22 @@ Do you like it? Let's dive in! Tutorials --------- -If you're new to ``websockets``, this is the place to start. +If you're new to websockets, this is the place to start. .. toctree:: :maxdepth: 2 - intro - faq + intro/index How-to guides ------------- -These guides will help you build and deploy a ``websockets`` application. +These guides will help you build and deploy a websockets application. .. toctree:: :maxdepth: 2 - cheatsheet - deployment - extensions - howto/django - howto/heroku + howto/index Reference --------- @@ -70,19 +68,17 @@ Find all the details you could ask for, and then some. .. toctree:: :maxdepth: 2 - api/index + reference/index -Discussions ------------ +Topics +------ -Get a deeper understanding of how ``websockets`` is built and why. +Get a deeper understanding of how websockets is built and why. .. toctree:: :maxdepth: 2 - design - limitations - security + topics/index Project ------- @@ -92,7 +88,4 @@ This is about websockets-the-project rather than websockets-the-software. .. toctree:: :maxdepth: 2 - changelog - contributing - license - For enterprise + project/index diff --git a/docs/intro.rst b/docs/intro/index.rst similarity index 85% rename from docs/intro.rst rename to docs/intro/index.rst index 58d482a09..49e17b668 100644 --- a/docs/intro.rst +++ b/docs/intro/index.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -``websockets`` requires Python ≥ 3.7. +websockets requires Python ≥ 3.7. You should use the latest version of Python if possible. If you're using an older version, be aware that for each minor version (3.x), only the latest @@ -15,7 +15,7 @@ bugfix release (3.x.y) is officially supported. Installation ------------ -Install ``websockets`` with:: +Install websockets with:: pip install websockets @@ -28,19 +28,19 @@ Here's a WebSocket server example. It reads a name from the client, sends a greeting, and closes the connection. -.. literalinclude:: ../example/server.py - :emphasize-lines: 8,17 +.. literalinclude:: ../../example/server.py + :emphasize-lines: 8,18 .. _client-example: -On the server side, ``websockets`` executes the handler coroutine ``hello`` -once for each WebSocket connection. It closes the connection when the handler +On the server side, websockets executes the handler coroutine ``hello`` once +for each WebSocket connection. It closes the connection when the handler coroutine returns. Here's a corresponding WebSocket client example. -.. literalinclude:: ../example/client.py - :emphasize-lines: 8,10 +.. literalinclude:: ../../example/client.py + :emphasize-lines: 10 Using :func:`connect` as an asynchronous context manager ensures the connection is closed before exiting the ``hello`` coroutine. @@ -60,13 +60,13 @@ Sockets Layer (SSL). WSS requires TLS certificates like HTTPS. Here's how to adapt the server example to provide secure connections. See the documentation of the :mod:`ssl` module for configuring the context securely. -.. literalinclude:: ../example/secure_server.py - :emphasize-lines: 19,23-25 +.. literalinclude:: ../../example/secure_server.py + :emphasize-lines: 19-21,26 Here's how to adapt the client. -.. literalinclude:: ../example/secure_client.py - :emphasize-lines: 10,15-18 +.. literalinclude:: ../../example/secure_client.py + :emphasize-lines: 10-12,18 This client needs a context because the server uses a self-signed certificate. @@ -81,11 +81,11 @@ Here's an example of how to run a WebSocket server and connect from a browser. Run this script in a console: -.. literalinclude:: ../example/show_time.py +.. literalinclude:: ../../example/show_time.py Then open this HTML file in a browser. -.. literalinclude:: ../example/show_time.html +.. literalinclude:: ../../example/show_time.html :language: html Synchronization example @@ -102,11 +102,11 @@ serialized. Run this script in a console: -.. literalinclude:: ../example/counter.py +.. literalinclude:: ../../example/counter.py Then open this HTML file in several browsers. -.. literalinclude:: ../example/counter.html +.. literalinclude:: ../../example/counter.html :language: html Common patterns @@ -147,8 +147,8 @@ messages to send on the WebSocket connection. :exc:`~exceptions.ConnectionClosed` exception when the client disconnects, which breaks out of the ``while True`` loop. -Both -.... +Both sides +.......... You can read and write messages on the same connection by combining the two patterns shown above and running the two tasks in parallel:: @@ -194,16 +194,16 @@ handler may subscribe to some channels on a message broker, for example. That's all! ----------- -The design of the ``websockets`` API was driven by simplicity. +The design of the websockets API was driven by simplicity. You don't have to worry about performing the opening or the closing handshake, answering pings, or any other behavior required by the specification. -``websockets`` handles all this under the hood so you don't have to. +websockets handles all this under the hood so you don't have to. One more thing... ----------------- -``websockets`` provides an interactive client:: +websockets provides an interactive client:: $ python -m websockets wss://echo.websocket.org/ diff --git a/docs/license.rst b/docs/license.rst deleted file mode 100644 index 842d3b07f..000000000 --- a/docs/license.rst +++ /dev/null @@ -1,4 +0,0 @@ -License -------- - -.. literalinclude:: ../LICENSE diff --git a/docs/limitations.rst b/docs/limitations.rst deleted file mode 100644 index bd6d32b2f..000000000 --- a/docs/limitations.rst +++ /dev/null @@ -1,10 +0,0 @@ -Limitations ------------ - -The client doesn't attempt to guarantee that there is no more than one -connection to a given IP address in a CONNECTING state. - -The client doesn't support connecting through a proxy. - -There is no way to fragment outgoing messages. A message is always sent in a -single frame. diff --git a/docs/changelog.rst b/docs/project/changelog.rst similarity index 94% rename from docs/changelog.rst rename to docs/project/changelog.rst index 17f640816..db29a7290 100644 --- a/docs/changelog.rst +++ b/docs/project/changelog.rst @@ -1,5 +1,5 @@ Changelog ---------- +========= .. currentmodule:: websockets @@ -8,9 +8,9 @@ Changelog Backwards-compatibility policy .............................. -``websockets`` is intended for production use. Therefore, stability is a goal. +websockets is intended for production use. Therefore, stability is a goal. -``websockets`` also aims at providing the best API for WebSocket in Python. +websockets also aims at providing the best API for WebSocket in Python. While we value stability, we value progress more. When an improvement requires changing a public API, we make the change and document it in this changelog. @@ -77,25 +77,25 @@ They may change at any time. them, you should adjust the import path. * The ``client``, ``server``, ``protocol``, and ``auth`` modules were - moved from the ``websockets`` package to ``websockets.legacy`` - sub-package, as part of an upcoming refactoring. Despite the name, - they're still fully supported. The refactoring should be a transparent - upgrade for most uses when it's available. The legacy implementation - will be preserved according to the `backwards-compatibility policy`_. + moved from the websockets package to ``websockets.legacy`` sub-package, + as part of an upcoming refactoring. Despite the name, they're still + fully supported. The refactoring should be a transparent upgrade for + most uses when it's available. The legacy implementation will be + preserved according to the `backwards-compatibility policy`_. * The ``framing``, ``handshake``, ``headers``, ``http``, and ``uri`` - modules in the ``websockets`` package are deprecated. These modules - provided low-level APIs for reuse by other WebSocket implementations, - but that never happened. Keeping these APIs public makes it more - difficult to improve websockets for no actual benefit. + modules in the websockets package are deprecated. These modules provided + low-level APIs for reuse by other WebSocket implementations, but that + never happened. Keeping these APIs public makes it more difficult to + improve websockets for no actual benefit. .. note:: **Version 9.0 may require changes if you use static code analysis tools.** - Convenience imports from the ``websockets`` module are performed lazily. - While this is supported by Python, static code analysis tools such as mypy - are unable to understand the behavior. + Convenience imports from the websockets module are performed lazily. While + this is supported by Python, static code analysis tools such as mypy are + unable to understand the behavior. If you depend on such tools, use the real import path, which can be found in the API documentation:: @@ -148,8 +148,7 @@ They may change at any time. *July 21, 2019* -* Restored the ability to import ``WebSocketProtocolError`` from - ``websockets``. +* Restored the ability to import ``WebSocketProtocolError`` from websockets. 8.0 ... @@ -261,7 +260,7 @@ Also: .. warning:: - ``websockets`` **now sends Ping frames at regular intervals and closes the + websockets **now sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame.** See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. diff --git a/docs/contributing.rst b/docs/project/contributing.rst similarity index 76% rename from docs/contributing.rst rename to docs/project/contributing.rst index 59d7451e0..c3e8dfc4c 100644 --- a/docs/contributing.rst +++ b/docs/project/contributing.rst @@ -24,11 +24,12 @@ Bug reports, patches and suggestions are welcome! Please open an issue_ or send a `pull request`_. -Feedback about the documentation is especially valuable — the authors of -``websockets`` feel more confident about writing code than writing docs :-) +Feedback about the documentation is especially valuable, as the primary author +feels more confident about writing code than writing docs :-) If you're wondering why things are done in a certain way, the :doc:`design -document ` provides lots of details about the internals of websockets. +document <../topics/design>` provides lots of details about the internals of +websockets. .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ @@ -41,14 +42,14 @@ places to ask questions, for example Stack Overflow. If you want to ask a question anyway, please make sure that: -- it's a question about ``websockets`` and not about :mod:`asyncio`; -- it isn't answered by the documentation; +- it's a question about websockets and not about :mod:`asyncio`; +- it isn't answered in the documentation; - it wasn't asked already. A good question can be written as a suggestion to improve the documentation. -Bitcoin users -------------- +Cryptocurrency users +-------------------- websockets appears to be quite popular for interfacing with Bitcoin or other cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. @@ -58,7 +59,8 @@ carbon footprint of all cryptocurrencies drops to a non-bullshit level. Please stop heating the planet where my children are supposed to live, thanks. -Since ``websockets`` is released under an open-source license, you can use it -for any purpose you like. However, I won't spend any of my time to help. +Since websockets is released under an open-source license, you can use it for +any purpose you like. However, I won't spend any of my time to help you. I will summarily close issues related to Bitcoin or cryptocurrency in any way. +Thanks for your understanding. diff --git a/docs/project/index.rst b/docs/project/index.rst new file mode 100644 index 000000000..931fbe2d4 --- /dev/null +++ b/docs/project/index.rst @@ -0,0 +1,10 @@ +Project +======= + +.. toctree:: + :maxdepth: 2 + + changelog + contributing + license + For enterprise diff --git a/docs/project/license.rst b/docs/project/license.rst new file mode 100644 index 000000000..426110020 --- /dev/null +++ b/docs/project/license.rst @@ -0,0 +1,4 @@ +License +======= + +.. literalinclude:: ../../LICENSE diff --git a/docs/tidelift.rst b/docs/project/tidelift.rst similarity index 99% rename from docs/tidelift.rst rename to docs/project/tidelift.rst index 43b457aaf..42100fade 100644 --- a/docs/tidelift.rst +++ b/docs/project/tidelift.rst @@ -4,7 +4,7 @@ websockets for enterprise Available as part of the Tidelift Subscription ---------------------------------------------- -.. image:: _static/tidelift.png +.. image:: ../_static/tidelift.png :height: 150px :width: 150px :align: left diff --git a/docs/api/client.rst b/docs/reference/client.rst similarity index 100% rename from docs/api/client.rst rename to docs/reference/client.rst diff --git a/docs/api/extensions.rst b/docs/reference/extensions.rst similarity index 53% rename from docs/api/extensions.rst rename to docs/reference/extensions.rst index 71f015bb2..6ca20dbc8 100644 --- a/docs/api/extensions.rst +++ b/docs/reference/extensions.rst @@ -1,6 +1,16 @@ Extensions ========== +.. currentmodule:: websockets.extensions + +The WebSocket protocol supports extensions_. + +At the time of writing, there's only one `registered extension`_ with a public +specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. + +.. _extensions: https://tools.ietf.org/html/rfc6455#section-9 +.. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name + Per-Message Deflate ------------------- diff --git a/docs/api/index.rst b/docs/reference/index.rst similarity index 92% rename from docs/api/index.rst rename to docs/reference/index.rst index 0a616cbce..9fc4a0092 100644 --- a/docs/api/index.rst +++ b/docs/reference/index.rst @@ -1,8 +1,8 @@ -API -=== +Reference +========= -``websockets`` provides complete client and server implementations, as shown -in the :doc:`getting started guide <../intro>`. +websockets provides complete client and server implementations, as shown in +the :doc:`getting started guide <../intro/index>`. The process for opening and closing a WebSocket connection depends on which side you're implementing. @@ -44,6 +44,7 @@ both in the client API and server API. server extensions utilities + limitations All public APIs can be imported from the :mod:`websockets` package, unless noted otherwise. This convenience feature is incompatible with static code diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst new file mode 100644 index 000000000..505186770 --- /dev/null +++ b/docs/reference/limitations.rst @@ -0,0 +1,35 @@ +Limitations +=========== + +Client +------ + +The client doesn't attempt to guarantee that there is no more than one +connection to a given IP address in a CONNECTING state. This behavior is +mandated by :rfc:`6455`. However, :func:`~websockets.connect()` isn't the +right layer for enforcing this constraint. It's the caller's responsibility. + +The client doesn't support connecting through a HTTP proxy (`issue 364`_) or a +SOCKS proxy (`issue 475`_). + +.. _issue 364: https://github.com/aaugustin/websockets/issues/364 +.. _issue 475: https://github.com/aaugustin/websockets/issues/475 + +Server +------ + +At this time, there are no known limitations affecting only the server. + +Both sides +---------- + +There is no way to control compression of outgoing frames on a per-frame basis +(`issue 538`_). If compression is enabled, all frames are compressed. + +.. _issue 538: https://github.com/aaugustin/websockets/issues/538 + +There is no way to receive each fragment of a fragmented messages as it +arrives (`issue 479`_). websockets always reassembles framented messages +before returning them. + +.. _issue 479: https://github.com/aaugustin/websockets/issues/479 diff --git a/docs/api/server.rst b/docs/reference/server.rst similarity index 100% rename from docs/api/server.rst rename to docs/reference/server.rst diff --git a/docs/api/utilities.rst b/docs/reference/utilities.rst similarity index 100% rename from docs/api/utilities.rst rename to docs/reference/utilities.rst diff --git a/docs/design.rst b/docs/topics/design.rst similarity index 87% rename from docs/design.rst rename to docs/topics/design.rst index 61b42b528..fa2093433 100644 --- a/docs/design.rst +++ b/docs/topics/design.rst @@ -3,8 +3,8 @@ Design .. currentmodule:: websockets -This document describes the design of ``websockets``. It assumes familiarity -with the specification of the WebSocket protocol in :rfc:`6455`. +This document describes the design of websockets. It assumes familiarity with +the specification of the WebSocket protocol in :rfc:`6455`. It's primarily intended at maintainers. It may also be useful for users who wish to understand what happens under the hood. @@ -13,8 +13,8 @@ wish to understand what happens under the hood. Internals described in this document may change at any time. - Backwards compatibility is only guaranteed for :doc:`public APIs `. - + Backwards compatibility is only guaranteed for :doc:`public APIs + <../reference/index>`. Lifecycle --------- @@ -91,7 +91,7 @@ the same :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` the opening handshake fails, in order to close the TCP connection. Splitting the responsibilities between two tasks makes it easier to guarantee -that ``websockets`` can terminate connections: +that websockets can terminate connections: - within a fixed timeout, - without leaking pending tasks, @@ -112,15 +112,16 @@ the TCP connection is closed. Opening handshake ----------------- -``websockets`` performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~legacy.client.connect` executes it before -returning the protocol to the caller. On the server side, it's executed before -passing the protocol to the ``ws_handler`` coroutine handling the connection. +websockets performs the opening handshake when establishing a WebSocket +connection. On the client side, :meth:`~legacy.client.connect` executes it +before returning the protocol to the caller. On the server side, it's executed +before passing the protocol to the ``ws_handler`` coroutine handling the +connection. While the opening handshake is asymmetrical — the client sends an HTTP Upgrade request and the server replies with an HTTP Switching Protocols response — -``websockets`` aims at keeping the implementation of both sides consistent -with one another. +websockets aims at keeping the implementation of both sides consistent with +one another. On the client side, :meth:`~legacy.client.WebSocketClientProtocol.handshake`: @@ -151,8 +152,8 @@ lies in the negotiation of extensions and, to a lesser extent, of the subprotocol. The server knows everything about both sides and decides what the parameters should be for the connection. The client merely applies them. -If anything goes wrong during the opening handshake, ``websockets`` -:ref:`fails the connection `. +If anything goes wrong during the opening handshake, websockets :ref:`fails +the connection `. .. _data-transfer: @@ -192,8 +193,8 @@ Data flow ......... The following diagram shows how data flows between an application built on top -of ``websockets`` and a remote endpoint. It applies regardless of which side -is the server or the client. +of websockets and a remote endpoint. It applies regardless of which side is +the server or the client. .. image:: protocol.svg :target: _images/protocol.svg @@ -205,7 +206,7 @@ termination is discussed in another section below. Receiving data .............. -The left side of the diagram shows how ``websockets`` receives data. +The left side of the diagram shows how websockets receives data. Incoming data is written to a :class:`~asyncio.StreamReader` in order to implement flow control and provide backpressure on the TCP connection. @@ -224,8 +225,8 @@ When it encounters a control frame: unsolicited pong). Running this process in a task guarantees that control frames are processed -promptly. Without such a task, ``websockets`` would depend on the application -to drive the connection by having exactly one coroutine awaiting +promptly. Without such a task, websockets would depend on the application to +drive the connection by having exactly one coroutine awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` at any time. While this happens naturally in many use cases, it cannot be relied upon. @@ -236,7 +237,7 @@ complexity added for handling backpressure and termination correctly. Sending data ............ -The right side of the diagram shows how ``websockets`` sends data. +The right side of the diagram shows how websockets sends data. :meth:`~legacy.protocol.WebSocketCommonProtocol.send` writes one or several data frames containing the message. While sending a fragmented message, concurrent @@ -274,13 +275,13 @@ state and sends a close frame. When the other side sends a close frame, :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close -timeout, ``websockets`` :ref:`fails the connection `. +timeout, websockets :ref:`fails the connection `. The closing handshake can take up to ``2 * close_timeout``: one ``close_timeout`` to write a close frame and one ``close_timeout`` to receive a close frame. -Then ``websockets`` terminates the TCP connection. +Then websockets terminates the TCP connection. .. _connection-termination: @@ -330,12 +331,12 @@ the connection drops regardless of what happens on the network. Connection failure ------------------ -If the opening handshake doesn't complete successfully, ``websockets`` fails -the connection by closing the TCP connection. +If the opening handshake doesn't complete successfully, websockets fails the +connection by closing the TCP connection. -Once the opening handshake has completed, ``websockets`` fails the connection -by canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and -sending a close frame if appropriate. +Once the opening handshake has completed, websockets fails the connection by +canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` +and sending a close frame if appropriate. :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which closes @@ -369,8 +370,8 @@ Cancellation User code ......... -``websockets`` provides a WebSocket application server. It manages connections -and passes them to user-provided connection handlers. This is an *inversion of +websockets provides a WebSocket application server. It manages connections and +passes them to user-provided connection handlers. This is an *inversion of control* scenario: library code calls user code. If a connection drops, the corresponding handler should terminate. If the @@ -387,15 +388,15 @@ interacts with finalization logic. In the example above, what if a handler gets interrupted with :exc:`~asyncio.CancelledError` while it's finalizing the tasks it started, after detecting that the connection dropped? -``websockets`` considers that cancellation may only be triggered by the caller -of a coroutine when it doesn't care about the results of that coroutine -anymore. (Source: `Guido van Rossum `_). Since connection handlers run -arbitrary user code, ``websockets`` has no way of deciding whether that code -is still doing something worth caring about. +arbitrary user code, websockets has no way of deciding whether that code is +still doing something worth caring about. -For these reasons, ``websockets`` never cancels connection handlers. Instead -it expects them to detect when the connection is closed, execute finalization +For these reasons, websockets never cancels connection handlers. Instead it +expects them to detect when the connection is closed, execute finalization logic if needed, and exit. Conversely, cancellation isn't a concern for WebSocket clients because they @@ -404,14 +405,14 @@ don't involve inversion of control. Library ....... -Most :doc:`public APIs ` of ``websockets`` are coroutines. They may -be canceled, for example if the user starts a task that calls these coroutines -and cancels the task later. ``websockets`` must handle this situation. +Most :doc:`public APIs <../reference/index>` of websockets are coroutines. +They may be canceled, for example if the user starts a task that calls these +coroutines and cancels the task later. websockets must handle this situation. Cancellation during the opening handshake is handled like any other exception: the TCP connection is closed and the exception is re-raised. This can only happen on the client side. On the server side, the opening handshake is -managed by ``websockets`` and nothing results in a cancellation. +managed by websockets and nothing results in a cancellation. Once the WebSocket connection is established, internal tasks :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and @@ -473,16 +474,16 @@ The solution to this problem is backpressure. Any part of the server that receives inputs faster than it can process them and send the outputs must propagate that information back to the previous part in the chain. -``websockets`` is designed to make it easy to get backpressure right. +websockets is designed to make it easy to get backpressure right. -For incoming data, ``websockets`` builds upon :class:`~asyncio.StreamReader` -which propagates backpressure to its own buffer and to the TCP stream. Frames -are parsed from the input stream and added to a bounded queue. If the queue -fills up, parsing halts until the application reads a frame. +For incoming data, websockets builds upon :class:`~asyncio.StreamReader` which +propagates backpressure to its own buffer and to the TCP stream. Frames are +parsed from the input stream and added to a bounded queue. If the queue fills +up, parsing halts until the application reads a frame. -For outgoing data, ``websockets`` builds upon :class:`~asyncio.StreamWriter` -which implements flow control. If the output buffers grow too large, it waits -until they're drained. That's why all APIs that write frames are asynchronous. +For outgoing data, websockets builds upon :class:`~asyncio.StreamWriter` which +implements flow control. If the output buffers grow too large, it waits until +they're drained. That's why all APIs that write frames are asynchronous. Of course, it's still possible for an application to create its own unbounded buffers and break the backpressure. Be careful with queues. @@ -510,8 +511,8 @@ capacity — typically because the system is bottlenecked by the output and constantly regulated by backpressure — reducing the size of buffers minimizes negative consequences. -By default ``websockets`` has rather high limits. You can decrease them -according to your application's characteristics. +By default websockets has rather high limits. You can decrease them according +to your application's characteristics. Bufferbloat can happen at every level in the stack where there is a buffer. For each connection, the receiving side contains these buffers: diff --git a/docs/topics/index.rst b/docs/topics/index.rst new file mode 100644 index 000000000..157278f76 --- /dev/null +++ b/docs/topics/index.rst @@ -0,0 +1,8 @@ +Topics +====== + +.. toctree:: + :maxdepth: 2 + + design + security diff --git a/docs/lifecycle.graffle b/docs/topics/lifecycle.graffle similarity index 100% rename from docs/lifecycle.graffle rename to docs/topics/lifecycle.graffle diff --git a/docs/lifecycle.svg b/docs/topics/lifecycle.svg similarity index 100% rename from docs/lifecycle.svg rename to docs/topics/lifecycle.svg diff --git a/docs/protocol.graffle b/docs/topics/protocol.graffle similarity index 100% rename from docs/protocol.graffle rename to docs/topics/protocol.graffle diff --git a/docs/protocol.svg b/docs/topics/protocol.svg similarity index 100% rename from docs/protocol.svg rename to docs/topics/protocol.svg diff --git a/docs/security.rst b/docs/topics/security.rst similarity index 56% rename from docs/security.rst rename to docs/topics/security.rst index e9acf0629..39de08120 100644 --- a/docs/security.rst +++ b/docs/topics/security.rst @@ -17,9 +17,9 @@ Memory use An attacker who can open an arbitrary number of connections will be able to perform a denial of service by memory exhaustion. If you're concerned by denial of service attacks, you must reject suspicious connections - before they reach ``websockets``, typically in a reverse proxy. + before they reach websockets, typically in a reverse proxy. -With the default settings, opening a connection uses 325 KiB of memory. +With the default settings, opening a connection uses 70 KiB of memory. Sending some highly compressed messages could use up to 128 MiB of memory with an amplification factor of 1000 between network traffic and memory use. @@ -30,10 +30,12 @@ improve security in addition to improving performance. Other limits ------------ -``websockets`` implements additional limits on the amount of data it accepts -in order to minimize exposure to security vulnerabilities. +websockets implements additional limits on the amount of data it accepts in +order to minimize exposure to security vulnerabilities. -In the opening handshake, ``websockets`` limits the number of HTTP headers to -256 and the size of an individual header to 4096 bytes. These limits are 10 to -20 times larger than what's expected in standard use cases. They're hard-coded. -If you need to change them, monkey-patch the constants in ``websockets.http``. +In the opening handshake, websockets limits the number of HTTP headers to 256 +and the size of an individual header to 4096 bytes. These limits are 10 to 20 +times larger than what's expected in standard use cases. They're hard-coded. + +If you need to change these limits, you can monkey-patch the constants in +``websockets.http11``. diff --git a/example/health_check_server.py b/example/health_check_server.py index 2565f9c48..c5bb6d5ab 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -16,7 +16,8 @@ async def echo(websocket, path): async def main(): async with websockets.serve( - echo, "localhost", 8765, process_request=health_check + echo, "localhost", 8765, + process_request=health_check, ): await asyncio.Future() # run forever diff --git a/example/secure_client.py b/example/secure_client.py index 455b6492a..2657ba68b 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -14,7 +14,8 @@ async def hello(): uri = "wss://localhost:8765" async with websockets.connect( - uri, ssl=ssl_context + uri, + ssl=ssl_context, ) as websocket: name = input("What's your name? ") diff --git a/example/secure_server.py b/example/secure_server.py index 55b5a4231..e0ef6e53b 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -22,7 +22,8 @@ async def hello(websocket, path): async def main(): async with websockets.serve( - hello, "localhost", 8765, ssl=ssl_context + hello, "localhost", 8765, + ssl=ssl_context, ): await asyncio.Future() # run forever diff --git a/example/shutdown_client.py b/example/shutdown_client.py index ba1287801..539dd0304 100755 --- a/example/shutdown_client.py +++ b/example/shutdown_client.py @@ -7,8 +7,8 @@ async def client(): uri = "ws://localhost:8765" async with websockets.connect(uri) as websocket: - loop = asyncio.get_running_loop() # Close the connection when receiving SIGTERM. + loop = asyncio.get_running_loop() loop.add_signal_handler( signal.SIGTERM, loop.create_task, websocket.close()) diff --git a/example/shutdown_server.py b/example/shutdown_server.py index 5732313cb..1ae44af1e 100755 --- a/example/shutdown_server.py +++ b/example/shutdown_server.py @@ -9,9 +9,9 @@ async def echo(websocket, path): await websocket.send(message) async def server(): + # Set the stop condition when receiving SIGTERM. loop = asyncio.get_running_loop() stop = loop.create_future() - # Set the stop condition when receiving SIGTERM. loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with websockets.serve(echo, "localhost", 8765): await stop From ad2e643b4a7a98afeb0bfe6c0d7b7105f2061779 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 May 2021 18:10:31 +0200 Subject: [PATCH 0797/1539] Add graceful shutdown to Heroku guide. --- docs/howto/heroku.rst | 62 ++++++++++++++++++++++++++------ example/deployment/heroku/app.py | 20 +++++++++-- 2 files changed, 68 insertions(+), 14 deletions(-) diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index c6d81ab1d..6d4346dcd 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -1,9 +1,11 @@ Deploy to Heroku ================ -This guide describes how to deploy a websockets server to Heroku_. We're going -to deploy a very simple app. The process would be identical for a more -realistic app. It would be similar on other Platform as a Service providers. +This guide describes how to deploy a websockets server to Heroku_. The same +principles should apply to other Platform as a Service providers. + +We're going to deploy a very simple app. The process would be identical for a +more realistic app. .. _Heroku: https://www.heroku.com/ @@ -41,8 +43,17 @@ Here's the implementation of the app, an echo server. Save it in a file called .. literalinclude:: ../../example/deployment/heroku/app.py -The server relies on the ``$PORT`` environment variable to tell on which port -it will listen, according to Heroku's conventions. +Heroku expects the server to `listen on a specific port`_, which is provided +in the ``$PORT`` environment variable. The app reads it and passes it to +:func:`~websockets.server.serve`. + +.. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port + +Heroku sends a ``SIGTERM`` signal to all processes when `shutting down a +dyno`_. When the app receives this signal, it closes connections and exits +cleanly. + +.. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown Configure deployment -------------------- @@ -70,7 +81,7 @@ Confirm that you created the correct files and commit them to git: $ git add . $ git commit -m "Deploy echo server to Heroku." [main 8418c62] Deploy echo server to Heroku. -  3 files changed, 19 insertions(+) +  3 files changed, 32 insertions(+)  create mode 100644 Procfile  create mode 100644 app.py  create mode 100644 requirements.txt @@ -78,7 +89,7 @@ Confirm that you created the correct files and commit them to git: Deploy ------ -Our app is ready. Let's deploy it! +The app is ready. Let's deploy it! .. code:: console @@ -87,7 +98,7 @@ Our app is ready. Let's deploy it! ... lots of output... remote: -----> Launching... - remote: Released v3 + remote: Released v1 remote: https://websockets-echo.herokuapp.com/ deployed to Heroku remote: remote: Verifying deploy... done. @@ -97,9 +108,9 @@ Our app is ready. Let's deploy it! Validate deployment ------------------- -Of course we'd like to confirm that our application is running as expected! +Of course you'd like to confirm that your application is running as expected! -Since it's a WebSocket server, we need a WebSocket client, such as the +Since it's a WebSocket server, you need a WebSocket client, such as the interactive client that comes with websockets. If you're currently building a websockets server, perhaps you're already in a @@ -121,7 +132,7 @@ Connect the interactive client — using the name of your Heroku app instead of Connected to wss://websockets-echo.herokuapp.com/. > -Great! Our app is running! +Great! Your app is running! In this example, I used a secure connection (``wss://``). It worked because Heroku served a valid TLS certificate for ``websockets-echo.herokuapp.com``. @@ -135,3 +146,32 @@ then press Ctrl-D to terminate the connection: > Hello! < Hello! Connection closed: code = 1000 (OK), no reason. + +You can also confirm that your application shuts down gracefully. Connect an +interactive client again — remember to replace ``websockets-echo`` with your app: + +.. code:: console + + $ python -m websockets wss://websockets-echo.herokuapp.com/ + Connected to wss://websockets-echo.herokuapp.com/. + > + +In another shell, restart the dyno — again, replace ``websockets-echo`` with your app: + +.. code:: console + + $ heroku dyno:restart -a websockets-echo + Restarting dynos on ⬢ websockets-echo... done + +Go back to the first shell. The connection is closed with code 1001 (going +away). + +.. code:: console + + $ python -m websockets wss://websockets-echo.herokuapp.com/ + Connected to wss://websockets-echo.herokuapp.com/. + Connection closed: code = 1001 (going away), no reason. + +If graceful shutdown wasn't working, the server wouldn't perform a closing +handshake and the connection would be closed with code 1006 (connection closed +abnormally). diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index aceb754a3..ff9ba2775 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -1,16 +1,30 @@ #!/usr/bin/env python import asyncio +import signal import os import websockets + async def echo(websocket, path): async for message in websocket: await websocket.send(message) + async def main(): - async with websockets.serve(echo, "", int(os.environ["PORT"])): - await asyncio.Future() # run forever + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, + host="", + port=int(os.environ["PORT"]), + ): + await stop + -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) From 83d2f1dedcd104ad7172d471a8ebc3d4971257b3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 17 May 2021 22:10:37 +0200 Subject: [PATCH 0798/1539] Document deployment to Kubernetes. Refs #445. --- docs/howto/heroku.rst | 7 +- docs/howto/index.rst | 1 + docs/howto/kubernetes.rst | 213 ++++++++++++++++++ example/deployment/kubernetes/Dockerfile | 7 + example/deployment/kubernetes/app.py | 49 ++++ example/deployment/kubernetes/benchmark.py | 27 +++ example/deployment/kubernetes/deployment.yaml | 35 +++ 7 files changed, 334 insertions(+), 5 deletions(-) create mode 100644 docs/howto/kubernetes.rst create mode 100644 example/deployment/kubernetes/Dockerfile create mode 100755 example/deployment/kubernetes/app.py create mode 100755 example/deployment/kubernetes/benchmark.py create mode 100644 example/deployment/kubernetes/deployment.yaml diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 6d4346dcd..92dee5f27 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -55,8 +55,8 @@ cleanly. .. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown -Configure deployment --------------------- +Deploy application +------------------ In order to build the app, Heroku needs to know that it depends on websockets. Create a ``requirements.txt`` file containing this line: @@ -86,9 +86,6 @@ Confirm that you created the correct files and commit them to git:  create mode 100644 app.py  create mode 100644 requirements.txt -Deploy ------- - The app is ready. Let's deploy it! .. code:: console diff --git a/docs/howto/index.rst b/docs/howto/index.rst index bcacbc174..c5cb4b8f6 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -33,3 +33,4 @@ optimized, and secure setup. deployment compression heroku + kubernetes diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst new file mode 100644 index 000000000..ef5c963d7 --- /dev/null +++ b/docs/howto/kubernetes.rst @@ -0,0 +1,213 @@ +Deploy to Kubernetes +==================== + +This guide describes how to deploy a websockets server to Kubernetes_. It +assumes familiarity with Docker and Kubernetes. + +We're going to deploy a simple app to a local Kubernetes cluster and to ensure +that it scales as expected. + +In a more realistic context, you would follow your organization's practices +for deploying to Kubernetes, but you would apply the same principles as far as +websockets is concerned. + +.. _Kubernetes: https://kubernetes.io/ + +Containerize application +------------------------ + +Here's the app we're going to deploy. Save it in a file called +``app.py``: + +.. literalinclude:: ../../example/deployment/kubernetes/app.py + +This is an echo server with one twist: every message blocks the server for +100ms, which creates artificial starvation of CPU time. This makes it easier +to saturate the server for load testing. + +The app exposes a health check on ``/healthz``. It also provides two other +endpoints for testing purposes: ``/inemuri`` will make the app unresponsive +for 10 seconds and ``/seppuku`` will terminate it. + +The quest for the perfect Python container image is out of scope of this +guide, so we'll go for the simplest possible configuration instead: + +.. literalinclude:: ../../example/deployment/kubernetes/Dockerfile + +After saving this ``Dockerfile``, build the image: + +.. code:: console + + $ docker build -t websockets-test:1.0 . + +Test your image by running: + +.. code:: console + + $ docker run --name run-websockets-test --publish 32080:80 --rm \ + websockets-test:1.0 + +Then, in another shell, in a virtualenv where websockets is installed, connect +to the app and check that it echoes anything you send: + +.. code:: console + + $ python -m websockets ws://localhost:32080/ + Connected to ws://localhost:32080/. + > Hey there! + < Hey there! + > + +Now, in yet another shell, stop the app with: + +.. code:: console + + $ docker kill -s TERM run-websockets-test + +Going to the shell where you connected to the app, you can confirm that it +shut down gracefully: + +.. code:: console + + $ python -m websockets ws://localhost:32080/ + Connected to ws://localhost:32080/. + > Hey there! + < Hey there! + Connection closed: code = 1001 (going away), no reason. + +If it didn't, you'd get code 1006 (connection closed abnormally). + +Deploy application +------------------ + +Configuring Kubernetes is even further beyond the scope of this guide, so +we'll use a basic configuration for testing, with just one Service_ and one +Deployment_: + +.. literalinclude:: ../../example/deployment/kubernetes/deployment.yaml + +For local testing, a service of type NodePort_ is good enough. For deploying +to production, you would configure an Ingress_. + +.. _Service: https://kubernetes.io/docs/concepts/services-networking/service/ +.. _Deployment: https://kubernetes.io/docs/concepts/workloads/controllers/deployment/ +.. _NodePort: https://kubernetes.io/docs/concepts/services-networking/service/#nodeport +.. _Ingress: https://kubernetes.io/docs/concepts/services-networking/ingress/ + +After saving this to a file called ``deployment.yaml``, you can deploy: + +.. code:: console + + $ kubectl apply -f deployment.yaml + service/websockets-test created + deployment.apps/websockets-test created + +Now you have a deployment with one pod running: + +.. code:: console + + $ kubectl get deployment websockets-test + NAME READY UP-TO-DATE AVAILABLE AGE + websockets-test 1/1 1 1 10s + $ kubectl get pods -l app=websockets-test + NAME READY STATUS RESTARTS AGE + websockets-test-86b48f4bb7-nltfh 1/1 Running 0 10s + +You can connect to the service — press Ctrl-D to exit: + +.. code:: console + + $ python -m websockets ws://localhost:32080/ + Connected to ws://localhost:32080/. + Connection closed: code = 1000 (OK), no reason. + +Validate deployment +------------------- + +First, let's ensure the liveness probe works by making the app unresponsive: + +.. code:: console + + $ curl http://localhost:32080/inemuri + Sleeping for 10s + +Since we have only one pod, we know that this pod will go to sleep. + +The liveness probe is configured to run every second. By default, liveness +probes time out after one second and have a threshold of three failures. +Therefore Kubernetes should restart the pod after at most 5 seconds. + +Indeed, after a few seconds, the pod reports a restart: + +.. code:: console + + $ kubectl get pods -l app=websockets-test + NAME READY STATUS RESTARTS AGE + websockets-test-86b48f4bb7-nltfh 1/1 Running 1 42s + +Next, let's take it one step further and crash the app: + +.. code:: console + + $ curl http://localhost:32080/seppuku + Terminating + +The pod reports a second restart: + +.. code:: console + + $ kubectl get pods -l app=websockets-test + NAME READY STATUS RESTARTS AGE + websockets-test-86b48f4bb7-nltfh 1/1 Running 2 72s + +All good — Kubernetes delivers on its promise to keep our app alive! + +Scale deployment +---------------- + +Of course, Kubernetes is for scaling. Let's scale — modestly — to 10 pods: + +.. code:: console + + $ kubectl scale deployment.apps/websockets-test --replicas=10 + deployment.apps/websockets-test scaled + +After a few seconds, we have 10 pods running: + +.. code:: console + + $ kubectl get deployment websockets-test + NAME READY UP-TO-DATE AVAILABLE AGE + websockets-test 10/10 10 10 10m + +Now let's generate load. We'll use this script: + +.. literalinclude:: ../../example/deployment/kubernetes/benchmark.py + +We'll connect 500 clients in parallel, meaning 50 clients per pod, and have +each client send 6 messages. Since the app blocks for 100ms before responding, +if connections are perfectly distributed, we expect a total run time slightly +over 50 * 6 * 0.1 = 30 seconds. + +Let's try it: + +.. code:: console + + $ ulimit -n 512 + $ time python benchmark.py 500 6 + python benchmark.py 500 6 2.40s user 0.51s system 7% cpu 36.471 total + +A total runtime of 36 seconds is in the right ballpark. Repeating this +experiment with other parameters shows roughly consistent results, with the +high variability you'd expect from a quick benchmark without any effort to +stabilize the test setup. + +Finally, we can scale back to one pod. + +.. code:: console + + $ kubectl scale deployment.apps/websockets-test --replicas=1 + deployment.apps/websockets-test scaled + $ kubectl get deployment websockets-test + NAME READY UP-TO-DATE AVAILABLE AGE + websockets-test 1/1 1 1 15m diff --git a/example/deployment/kubernetes/Dockerfile b/example/deployment/kubernetes/Dockerfile new file mode 100644 index 000000000..83ed8722c --- /dev/null +++ b/example/deployment/kubernetes/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.9-alpine + +RUN pip install websockets + +COPY app.py . + +CMD ["python", "app.py"] diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py new file mode 100755 index 000000000..dcc29bd1c --- /dev/null +++ b/example/deployment/kubernetes/app.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +import asyncio +import http +import signal +import sys +import time + +import websockets + + +async def slow_echo(websocket, path): + async for message in websocket: + # Block the event loop! This allows saturating a single asyncio + # process without opening an impractical number of connections. + time.sleep(0.1) # 100ms + await websocket.send(message) + + +async def health_check(path, request_headers): + if path == "/healthz": + return http.HTTPStatus.OK, [], b"OK\n" + if path == "/inemuri": + loop = asyncio.get_running_loop() + loop.call_later(1, time.sleep, 10) + return http.HTTPStatus.OK, [], b"Sleeping for 10s\n" + if path == "/seppuku": + loop = asyncio.get_running_loop() + loop.call_later(1, sys.exit, 69) + return http.HTTPStatus.OK, [], b"Terminating\n" + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + slow_echo, + host="", + port=80, + process_request=health_check, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/kubernetes/benchmark.py b/example/deployment/kubernetes/benchmark.py new file mode 100755 index 000000000..600c47316 --- /dev/null +++ b/example/deployment/kubernetes/benchmark.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python + +import asyncio +import sys +import websockets + + +URI = "ws://localhost:32080" + + +async def run(client_id, messages): + async with websockets.connect(URI) as websocket: + for message_id in range(messages): + await websocket.send("{client_id}:{message_id}") + await websocket.recv() + + +async def benchmark(clients, messages): + await asyncio.wait([ + asyncio.create_task(run(client_id, messages)) + for client_id in range(clients) + ]) + + +if __name__ == "__main__": + clients, messages = int(sys.argv[1]), int(sys.argv[2]) + asyncio.run(benchmark(clients, messages)) diff --git a/example/deployment/kubernetes/deployment.yaml b/example/deployment/kubernetes/deployment.yaml new file mode 100644 index 000000000..ba58dd62b --- /dev/null +++ b/example/deployment/kubernetes/deployment.yaml @@ -0,0 +1,35 @@ +apiVersion: v1 +kind: Service +metadata: + name: websockets-test +spec: + type: NodePort + ports: + - port: 80 + nodePort: 32080 + selector: + app: websockets-test +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: websockets-test +spec: + selector: + matchLabels: + app: websockets-test + template: + metadata: + labels: + app: websockets-test + spec: + containers: + - name: websockets-test + image: websockets-test:1.0 + livenessProbe: + httpGet: + path: /healthz + port: 80 + periodSeconds: 1 + ports: + - containerPort: 80 From af5bfd1b52b70ace251abe2e0ca7541b316ea2c3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 19 May 2021 09:09:58 +0200 Subject: [PATCH 0799/1539] Deployment deployment under Supervisor. Refs #445. --- docs/howto/index.rst | 1 + docs/howto/supervisor.rst | 131 ++++++++++++++++++ example/deployment/supervisor/app.py | 28 ++++ .../deployment/supervisor/supervisord.conf | 7 + 4 files changed, 167 insertions(+) create mode 100644 docs/howto/supervisor.rst create mode 100644 example/deployment/supervisor/app.py create mode 100644 example/deployment/supervisor/supervisord.conf diff --git a/docs/howto/index.rst b/docs/howto/index.rst index c5cb4b8f6..033dfe15d 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -34,3 +34,4 @@ optimized, and secure setup. compression heroku kubernetes + supervisor diff --git a/docs/howto/supervisor.rst b/docs/howto/supervisor.rst new file mode 100644 index 000000000..6d679ca2b --- /dev/null +++ b/docs/howto/supervisor.rst @@ -0,0 +1,131 @@ +Deploy with Supervisor +====================== + +This guide proposes a simple way to deploy a websockets server directly on a +Linux or BSD operating system. + +We'll configure Supervisor_ to run several server processes and to restart +them if needed. + +.. _Supervisor: http://supervisord.org/ + +We'll bind all servers to the same port. The OS will take care of balancing +connections. + +Create and activate a virtualenv: + +.. code:: console + + $ python -m venv supervisor-websockets + $ . supervisor-websockets/bin/activate + +Install websockets and Supervisor: + +.. code:: console + + $ pip install websockets + $ pip install supervisor + +Save this app to a file called ``app.py``: + +.. literalinclude:: ../../example/deployment/supervisor/app.py + +This is an echo server with two features added for the purpose of this guide: + +* It shuts down gracefully when receiving a ``SIGTERM`` signal; +* It enables the ``reuse_port`` option of :meth:`~asyncio.loop.create_server`, + which in turns sets ``SO_REUSEPORT`` on the accept socket. + +Save this Supervisor configuration to ``supervisord.conf``: + +.. literalinclude:: ../../example/deployment/supervisor/supervisord.conf + +This is the minimal configuration required to keep four instances of the app +running, restarting them if they exit. + +Now start Supervisor in the foreground: + +.. code:: console + + $ supervisord -c supervisord.conf -n + INFO Increased RLIMIT_NOFILE limit to 1024 + INFO supervisord started with pid 43596 + INFO spawned: 'websockets-test_00' with pid 43597 + INFO spawned: 'websockets-test_01' with pid 43598 + INFO spawned: 'websockets-test_02' with pid 43599 + INFO spawned: 'websockets-test_03' with pid 43600 + INFO success: websockets-test_00 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + INFO success: websockets-test_01 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + INFO success: websockets-test_02 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + INFO success: websockets-test_03 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + +In another shell, after activating the virtualenv, we can connect to the app — +press Ctrl-D to exit: + +.. code:: console + + $ python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + > Hello! + < Hello! + Connection closed: code = 1000 (OK), no reason. + +Look at the pid of an instance of the app in the logs and terminate it: + +.. code:: console + + $ kill -TERM 43597 + +The logs show that Supervisor restarted this instance: + +.. code:: console + + INFO exited: websockets-test_00 (exit status 0; expected) + INFO spawned: 'websockets-test_00' with pid 43629 + INFO success: websockets-test_00 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + +Now let's check what happens when we shut down Supervisor, but first let's +establish a connection and leave it open: + +.. code:: console + + $ python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + > + +Look at the pid of supervisord itself in the logs and terminate it: + +.. code:: console + + $ kill -TERM 43596 + +The logs show that Supervisor terminated all instances of the app before +exiting: + +.. code:: console + + WARN received SIGTERM indicating exit request + INFO waiting for websockets-test_00, websockets-test_01, websockets-test_02, websockets-test_03 to die + INFO stopped: websockets-test_02 (exit status 0) + INFO stopped: websockets-test_03 (exit status 0) + INFO stopped: websockets-test_01 (exit status 0) + INFO stopped: websockets-test_00 (exit status 0) + +And you can see that the connection to the app was closed gracefully: + +.. code:: console + + $ python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + Connection closed: code = 1001 (going away), no reason. + +In this example, we've been sharing the same virtualenv for supervisor and +websockets. + +In a real deployment, you would likely: + +* Install Supervisor with the package manager of the OS. +* Create a virtualenv dedicated to your application. +* Add ``environment=PATH="path/to/your/virtualenv/bin"`` in the Supervisor + configuration. Then ``python app.py`` runs in that virtualenv. + diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py new file mode 100644 index 000000000..5fa596d0a --- /dev/null +++ b/example/deployment/supervisor/app.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +import asyncio +import signal + +import websockets + + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, "", 8080, + reuse_port=True, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/supervisor/supervisord.conf b/example/deployment/supervisor/supervisord.conf new file mode 100644 index 000000000..76a664d91 --- /dev/null +++ b/example/deployment/supervisor/supervisord.conf @@ -0,0 +1,7 @@ +[supervisord] + +[program:websockets-test] +command = python app.py +process_name = %(program_name)s_%(process_num)02d +numprocs = 4 +autorestart = true From b3e50f341de7df9b0b82787111017787f9e82ce4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 21 May 2021 22:39:41 +0200 Subject: [PATCH 0800/1539] Include all files in Heroku example. --- example/deployment/heroku/Procfile | 1 + example/deployment/heroku/requirements.txt | 1 + 2 files changed, 2 insertions(+) create mode 100644 example/deployment/heroku/Procfile create mode 100644 example/deployment/heroku/requirements.txt diff --git a/example/deployment/heroku/Procfile b/example/deployment/heroku/Procfile new file mode 100644 index 000000000..2e35818f6 --- /dev/null +++ b/example/deployment/heroku/Procfile @@ -0,0 +1 @@ +web: python app.py diff --git a/example/deployment/heroku/requirements.txt b/example/deployment/heroku/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/deployment/heroku/requirements.txt @@ -0,0 +1 @@ +websockets From 559a5724a96e1e4d921ed48d37f49876069e65da Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 21 May 2021 22:40:57 +0200 Subject: [PATCH 0801/1539] Spell check docs. --- docs/reference/limitations.rst | 2 +- docs/spelling_wordlist.txt | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst index 505186770..ecdde23bf 100644 --- a/docs/reference/limitations.rst +++ b/docs/reference/limitations.rst @@ -29,7 +29,7 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _issue 538: https://github.com/aaugustin/websockets/issues/538 There is no way to receive each fragment of a fragmented messages as it -arrives (`issue 479`_). websockets always reassembles framented messages +arrives (`issue 479`_). websockets always reassembles fragmented messages before returning them. .. _issue 479: https://github.com/aaugustin/websockets/issues/479 diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 363413b4b..19c4b7c44 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -23,12 +23,15 @@ ctrl daemonize datastructures django +dyno fractalideas IPv iterable keepalive KiB +Kubernetes lifecycle +liveness lookups MiB mypy @@ -36,16 +39,19 @@ nginx onmessage parsers permessage +pid pong pongs pythonic redis +runtime scalable serializers subclasses subclassing subprotocol subprotocols +supervisord tidelift tls tox From 2ced669fa9ad1976d4bf197f5d966eaae7fafb21 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 20:35:46 +0200 Subject: [PATCH 0802/1539] Make changelog more readable. --- docs/project/changelog.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index db29a7290..f28ca7f0b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -36,9 +36,10 @@ They may change at any time. .. note:: - **Version 10.0 deprecates the** ``loop`` **parameter from all APIs for the - same reasons the same change was made in Python 3.8. See the release notes - of Python 3.10 for details.** + **Version 10.0 deprecates the** ``loop`` **parameter from all APIs.** + + This reflects a decision made in Python 3.8. See the release notes of + Python 3.10 for details. * Added compatibility with Python 3.10. From 4d2cd03bae4558253d51997e6cee9bd87103b236 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 20:36:58 +0200 Subject: [PATCH 0803/1539] Standardize style in Makefile. --- Makefile | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 06832945c..b15cd13c9 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,7 @@ export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src +export PYTHONWARNINGS=default default: coverage style @@ -12,13 +13,13 @@ style: mypy --strict src test: - python -W default -m unittest + python -m unittest coverage: - python -m coverage erase - python -W default -m coverage run -m unittest - python -m coverage html - python -m coverage report --show-missing --fail-under=100 + coverage erase + coverage run -m unittest + coverage html + coverage report --show-missing --fail-under=100 build: python setup.py build_ext --inplace From c0750dac4bbf30138fa825054fb065f0355e4150 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 07:41:41 +0200 Subject: [PATCH 0804/1539] Add low tech icon generator. --- logo/github-social-preview.html | 7 ++++--- logo/icon.html | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 logo/icon.html diff --git a/logo/github-social-preview.html b/logo/github-social-preview.html index 187a183b0..7f2b45bad 100644 --- a/logo/github-social-preview.html +++ b/logo/github-social-preview.html @@ -1,5 +1,5 @@ - + GitHub social preview + + +

Take a screenshot of these DOM nodes to2x make a PNG.

+

8x8 / 16x16 @ 2x

+

16x16 / 32x32 @ 2x

+

32x32 / 32x32 @ 2x

+

32x32 / 64x64 @ 2x

+

64x64 / 128x128 @ 2x

+

128x128 / 256x256 @ 2x

+

256x256 / 512x512 @ 2x

+

512x512 / 1024x1024 @ 2x

+ + From dd6d6bce2d1ac7e50a0a90e44c8073a8a17a2c05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 20 May 2021 22:10:15 +0200 Subject: [PATCH 0805/1539] Make it easier to customize authentication. --- docs/project/changelog.rst | 3 +++ docs/reference/server.rst | 9 +++++++-- src/websockets/legacy/auth.py | 36 ++++++++++++++++++++++++++++------- tests/legacy/test_auth.py | 18 ++++++++++++++++++ 4 files changed, 57 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f28ca7f0b..c2c869c13 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -45,6 +45,9 @@ They may change at any time. * Optimized default compression settings to reduce memory usage. +* Made it easier to customize authentication with + :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. + 9.0.2 ..... diff --git a/docs/reference/server.rst b/docs/reference/server.rst index c1822dda6..1a2dd1c88 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -96,10 +96,15 @@ Basic authentication .. autoclass:: BasicAuthWebSocketServerProtocol - .. automethod:: process_request + .. attribute:: realm + + Scope of protection. + + If provided, it should contain only ASCII characters because the + encoding of non-ASCII characters is undefined. .. attribute:: username Username of the authenticated user. - + .. automethod:: check_credentials diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e0beede57..e7e69cac1 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -35,22 +35,44 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): """ + realm = "" + def __init__( self, *args: Any, - realm: str, - check_credentials: Callable[[str, str], Awaitable[bool]], + realm: Optional[str] = None, + check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, **kwargs: Any, ) -> None: - self.realm = realm - self.check_credentials = check_credentials + if realm is not None: + self.realm = realm # shadow class attribute + self._check_credentials = check_credentials super().__init__(*args, **kwargs) + async def check_credentials(self, username: str, password: str) -> bool: + """ + Check whether credentials are authorized. + + If ``check_credentials`` returns ``True``, the WebSocket handshake + continues. If it returns ``False``, the handshake fails with a HTTP + 401 error. + + This coroutine may be overridden in a subclass, for example to + authenticate against a database or an external service. + + """ + if self._check_credentials is not None: + return await self._check_credentials(username, password) + + return False + async def process_request( - self, path: str, request_headers: Headers + self, + path: str, + request_headers: Headers, ) -> Optional[HTTPResponse]: """ - Check HTTP Basic Auth and return a HTTP 401 or 403 response if needed. + Check HTTP Basic Auth and return a HTTP 401 response if needed. """ try: @@ -84,7 +106,7 @@ async def process_request( def basic_auth_protocol_factory( - realm: str, + realm: Optional[str] = None, credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None, diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index bb8c6a6eb..c4dbd88ad 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -25,6 +25,11 @@ async def process_request(self, path, request_headers): return await super().process_request(path, request_headers) +class CheckWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): + async def check_credentials(self, username, password): + return password == "letmein" + + class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): create_protocol = basic_auth_protocol_factory( @@ -103,6 +108,19 @@ def test_basic_auth_custom_protocol(self): self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) + @with_server(create_protocol=CheckWebSocketServerProtocol) + @with_client(user_info=("hello", "letmein")) + def test_basic_auth_custom_protocol_subclass(self): + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + # CustomWebSocketServerProtocol doesn't override check_credentials + @with_server(create_protocol=CustomWebSocketServerProtocol) + def test_basic_auth_defaults_to_deny_all(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(user_info=("hello", "iloveyou")) + self.assertEqual(raised.exception.status_code, 401) + @with_server(create_protocol=create_protocol) def test_basic_auth_missing_credentials(self): with self.assertRaises(InvalidStatusCode) as raised: From 8ab85e54a764a26529b801e263e5223b1b2921c9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 08:11:46 +0200 Subject: [PATCH 0806/1539] Experiment authentication techniques. --- experiments/authentication/app.py | 225 ++++++++++++++++++ experiments/authentication/cookie.html | 15 ++ experiments/authentication/cookie.js | 36 +++ experiments/authentication/cookie_iframe.html | 9 + experiments/authentication/cookie_iframe.js | 36 +++ experiments/authentication/favicon.ico | Bin 0 -> 5430 bytes experiments/authentication/first_message.html | 14 ++ experiments/authentication/first_message.js | 11 + experiments/authentication/index.html | 12 + experiments/authentication/query_param.html | 14 ++ experiments/authentication/query_param.js | 11 + experiments/authentication/script.js | 51 ++++ experiments/authentication/style.css | 69 ++++++ experiments/authentication/test.html | 15 ++ experiments/authentication/test.js | 6 + experiments/authentication/user_info.html | 14 ++ experiments/authentication/user_info.js | 11 + 17 files changed, 549 insertions(+) create mode 100644 experiments/authentication/app.py create mode 100644 experiments/authentication/cookie.html create mode 100644 experiments/authentication/cookie.js create mode 100644 experiments/authentication/cookie_iframe.html create mode 100644 experiments/authentication/cookie_iframe.js create mode 100644 experiments/authentication/favicon.ico create mode 100644 experiments/authentication/first_message.html create mode 100644 experiments/authentication/first_message.js create mode 100644 experiments/authentication/index.html create mode 100644 experiments/authentication/query_param.html create mode 100644 experiments/authentication/query_param.js create mode 100644 experiments/authentication/script.js create mode 100644 experiments/authentication/style.css create mode 100644 experiments/authentication/test.html create mode 100644 experiments/authentication/test.js create mode 100644 experiments/authentication/user_info.html create mode 100644 experiments/authentication/user_info.js diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py new file mode 100644 index 000000000..c3c9a0557 --- /dev/null +++ b/experiments/authentication/app.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python + +import asyncio +import http +import http.cookies +import pathlib +import signal +import urllib.parse +import uuid + +import websockets + + +# User accounts database + +USERS = {} + + +def create_token(user, lifetime=1): + """Create token for user and delete it once its lifetime is over.""" + token = uuid.uuid4().hex + USERS[token] = user + asyncio.get_running_loop().call_later(lifetime, USERS.pop, token) + return token + + +def get_user(token): + """Find user authenticated by token or return None.""" + return USERS.get(token) + + +# Utilities + + +def get_cookie(raw, key): + cookie = http.cookies.SimpleCookie(raw) + morsel = cookie.get(key) + if morsel is not None: + return morsel.value + + +def get_query_param(path, key): + query = urllib.parse.urlparse(path).query + params = urllib.parse.parse_qs(query) + values = params.get(key, []) + if len(values) == 1: + return values[0] + + +# Main HTTP server + +CONTENT_TYPES = { + ".css": "text/css", + ".html": "text/html; charset=utf-8", + ".ico": "image/x-icon", + ".js": "text/javascript", +} + + +async def serve_html(path, request_headers): + user = get_query_param(path, "user") + path = urllib.parse.urlparse(path).path + if path == "/": + if user is None: + page = "index.html" + else: + page = "test.html" + else: + page = path[1:] + + try: + template = pathlib.Path(__file__).with_name(page) + except ValueError: + pass + else: + if template.is_file(): + headers = {"Content-Type": CONTENT_TYPES[template.suffix]} + body = template.read_bytes() + if user is not None: + token = create_token(user) + body = body.replace(b"TOKEN", token.encode()) + return http.HTTPStatus.OK, headers, body + + return http.HTTPStatus.NOT_FOUND, {}, b"Not found\n" + + +async def noop_handler(websocket, path): + pass + + +# Send credentials as the first message in the WebSocket connection + + +async def first_message_handler(websocket, path): + token = await websocket.recv() + user = get_user(token) + if user is None: + await websocket.close(1011, "authentication failed") + return + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + + +# Add credentials to the WebSocket URI in a query parameter + + +class QueryParamProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + token = get_query_param(path, "token") + if token is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = get_user(token) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + +async def query_param_handler(websocket, path): + user = websocket.user + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + + +# Set a cookie on the domain of the WebSocket URI + + +class CookieProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + if "Upgrade" not in headers: + template = pathlib.Path(__file__).with_name(path[1:]) + headers = {"Content-Type": CONTENT_TYPES[template.suffix]} + body = template.read_bytes() + return http.HTTPStatus.OK, headers, body + + token = get_cookie(headers.get("Cookie", ""), "token") + if token is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = get_user(token) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + +async def cookie_handler(websocket, path): + user = websocket.user + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + + +# Adding credentials to the WebSocket URI in user information + + +class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): + async def check_credentials(self, username, password): + if username != "token": + return False + + user = get_user(password) + if user is None: + return False + + self.user = user + return True + + +async def user_info_handler(websocket, path): + user = websocket.user + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + + +# Start all five servers + + +async def main(): + # Set the stop condition when receiving SIGINT or SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGINT, stop.set_result, None) + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + noop_handler, + host="", + port=8000, + process_request=serve_html, + ), websockets.serve( + first_message_handler, + host="", + port=8001, + ), websockets.serve( + query_param_handler, + host="", + port=8002, + create_protocol=QueryParamProtocol, + ), websockets.serve( + cookie_handler, + host="", + port=8003, + create_protocol=CookieProtocol, + ), websockets.serve( + user_info_handler, + host="", + port=8004, + create_protocol=UserInfoProtocol, + ): + print("Running on http://localhost:8000/") + await stop + print("\rExiting") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/experiments/authentication/cookie.html b/experiments/authentication/cookie.html new file mode 100644 index 000000000..ca17358fd --- /dev/null +++ b/experiments/authentication/cookie.html @@ -0,0 +1,15 @@ + + + + Cookie | WebSocket Authentication + + + +

[??] Cookie

+

[OK] Cookie

+

[KO] Cookie

+ + + + + diff --git a/experiments/authentication/cookie.js b/experiments/authentication/cookie.js new file mode 100644 index 000000000..9c5d5f078 --- /dev/null +++ b/experiments/authentication/cookie.js @@ -0,0 +1,36 @@ +// wait for "load" rather than "DOMContentLoaded" +// to ensure that the iframe has finished loading +window.addEventListener("load", () => { + // create message channel to communicate + // with the iframe + const channel = new MessageChannel(); + const port1 = channel.port1; + + // receive WebSocket events from the iframe + const expected = getExpectedEvents(); + var actual = []; + port1.onmessage = ({ data }) => { + // respond to messages + if (data.type == "message") { + port1.postMessage({ + type: "message", + message: `Goodbye ${data.data.slice(6, -1)}.`, + }); + } + // run tests + actual.push(data); + testStep(expected, actual); + }; + + // send message channel to the iframe + const iframe = document.querySelector("iframe"); + const origin = "http://localhost:8003"; + const ports = [channel.port2]; + iframe.contentWindow.postMessage("", origin, ports); + + // send token to the iframe + port1.postMessage({ + type: "open", + token: token, + }); +}); diff --git a/experiments/authentication/cookie_iframe.html b/experiments/authentication/cookie_iframe.html new file mode 100644 index 000000000..9f49ebb9a --- /dev/null +++ b/experiments/authentication/cookie_iframe.html @@ -0,0 +1,9 @@ + + + + Cookie iframe | WebSocket Authentication + + + + + diff --git a/experiments/authentication/cookie_iframe.js b/experiments/authentication/cookie_iframe.js new file mode 100644 index 000000000..88b5d1f2b --- /dev/null +++ b/experiments/authentication/cookie_iframe.js @@ -0,0 +1,36 @@ +// receive message channel from the parent window +window.addEventListener("message", ({ ports }) => { + const port2 = ports[0]; + var websocket; + port2.onmessage = ({ data }) => { + switch (data.type) { + case "open": + websocket = initWebSocket(data.token, port2); + break; + case "message": + websocket.send(data.message); + break; + case "close": + websocket.close(data.code, data.reason); + break; + } + }; +}); + +// open WebSocket connection and relay events to the parent window +function initWebSocket(token, port2) { + document.cookie = `token=${token}; SameSite=Strict`; + const websocket = new WebSocket("ws://localhost:8003/"); + + websocket.addEventListener("open", ({ type }) => { + port2.postMessage({ type }); + }); + websocket.addEventListener("message", ({ type, data }) => { + port2.postMessage({ type, data }); + }); + websocket.addEventListener("close", ({ type, code, reason, wasClean }) => { + port2.postMessage({ type, code, reason, wasClean }); + }); + + return websocket; +} diff --git a/experiments/authentication/favicon.ico b/experiments/authentication/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..36e855029d705e72d44428bda6e8cb6d3dd317ed GIT binary patch literal 5430 zcmeH~eNdFw6~-?im3Ahb_(wY9G?@uAnl{s!P5}k6EMLo)Am6YM4N*)C21rB%KSG3E zv;hMxiW&u@(MF93wOCQ3rD-sRXbhdyP7o3K#!S*q0u${7l(0J*XQZVahM#RhT&Mw%^BqMk{U@J1~1H2BJEU64r)0!B+s<_@|*2 zqhiBw1#eC^Kd~FBvDYyg*$#wV#Xwjm9t&v2p)n2r@@vI8+WgsgcG9|#V(P|dLI)7j zfdSf?v#0~lgtfyurUjuNe*Ta2+T*8ca~J%;(ZdETkY^gva%XNcqcj2kdN-G|0yeVF}ZFP=4Z!J2R#i0#02+Ir+*zVG@y z*5YS?yJuBTDvF9 z+e4qHUo$?zisgL(ZFDT{$HRZz5|LeB6ok5h0Mx8E{O!%Wpod<2-phQZDi;R}!Uptt z_Sffbz;E>+X1=#0_Q!9(68Y81O(FP$*?76Z67toHt(cKzH39p~78|EMbIvF7dZI9CxZHU=B0 z<9Jap>Rt-q+)$@;bvu`KC6(K3*mRuQ9MaslWomPE8-LktGQPhv{>AfKV-~kmXM9f| z|2D6HQbVN2tW*=RJWA^l;TTP>949?wSa4Hnl)qW|Kbb4f8Fvyy2h!?Rc2a%__sk=;L*m&%aOk`2^Q3=D%J$`lrna32wt&`ua?J=IYOX`)ey| zsDH*}F--J3_a!wE;q=XC^%sGy0H6D|y~p17*k>&_PGWEtANNEx9#Nf`$LIWDXd7VU z-lc0`F21g#6;|0_CyK+_IPED%Vmjd+bq#|NSGn)ee~4%xL9&NlK@0!?spEQZeA9Sw zOg84YE{XKby>e*LH9+>d2-$me4ei*@T2#ST@l)pX-aNEoMD}B9TkIH*)9-fc#{JDF zsS`tE`y^thLD#X6zKIf~ACX*0P0;qeI{2O1>Ze+hd$YWhu%sV8DP0&!?gSFKPsfuu z=^dyc`Wh;U66imQy~wci5ZZc7tYMe4y3>E-zOC46crNM2&=cI_Q|MnbFb{~Qf9lvq zl)zjpqW|G=)`eCR4jO{~dh5&OHAgQLZpnV9ydbWbam~@=o9J4d>8Y0WqGf3>kS6_3 zH=ygFmaRky%tJKeAJJ-p{_kqA)#Yh(*$efw^_AAoNKZthCz1G^u_xP8J>82TYJl`x zPtw0=VBb|il)!vMb4^CrJ8A0?#hMfy%!Slv9#Q+&`23lKphoC2`RmZF?C`@YT|BPQis$fd)qZc3H#T~U(G%W)mfglo{sU;_HV?#Giog+ zWdZKb7(gC1)GY7CGNOdEKE$SWVT5(5)r3};^sUup2XfhuD&bJR@3U{@`dIhp%!9oK z`hCKgZ~0C9j|Y3#aMr|#eN9k{rX#rbA2Ga;A`w_nTG-C ztBDe-vb<91XNIG%>P!!=&(~otdZi|$F7=v_x=f{SqJ-l`K`Vb(5MVvJ!Jze!ng;g7 z>?QBK{=AowD1llEwR+6*IO(WiJm0l|D{Ep{uL)g4S}$^l5>9OnTX~|`Xtnjy{y6%g zO~ax51$Wq&ClMvEKQ9vBhc`yq?ukr~>U(@1IJr6Gc0;i-fhgfV}k+ptmI=o*Qx0@RDGY}eIiR4)*!i%Vr#UY#Q>p{*{nuh%O zw-{rDA=?=E^vn-xf;^a)yc&4S#>DB;t~+h17%G6M9Y7Y%ttpeEt)~ z*kB1&)1jq0b@m6ll1AkWCmP=Q^&;&&DfRHy%i&r*$hok-P}6XevH3fT@AR)C)pTj8 zN-n>!BLP|-wu%PrJgTk5T@A${HyM{c{o!N z{uyKPol0D`gW)%O|F9}G?&6Mwv&}o=zc^nRJKR(eeeKWX(PfP1M^laIn|c0j*0+@b zZDY{i*oZ&{!b(QfqZlSb-*-dLEEhC+xWJOGU};h)wgh3bYC?&N1*5L&qtO{IrNa-n mx{c17;Wp~=fE$Kp5swEk + + + First message | WebSocket Authentication + + + +

[??] First message

+

[OK] First message

+

[KO] First message

+ + + + diff --git a/experiments/authentication/first_message.js b/experiments/authentication/first_message.js new file mode 100644 index 000000000..1acf048ba --- /dev/null +++ b/experiments/authentication/first_message.js @@ -0,0 +1,11 @@ +window.addEventListener("DOMContentLoaded", () => { + const websocket = new WebSocket("ws://localhost:8001/"); + websocket.onopen = () => websocket.send(token); + + websocket.onmessage = ({ data }) => { + // event.data is expected to be "Hello !" + websocket.send(`Goodbye ${data.slice(6, -1)}.`); + }; + + runTest(websocket); +}); diff --git a/experiments/authentication/index.html b/experiments/authentication/index.html new file mode 100644 index 000000000..c37deef27 --- /dev/null +++ b/experiments/authentication/index.html @@ -0,0 +1,12 @@ + + + + WebSocket Authentication + + + +
+ +
+ + diff --git a/experiments/authentication/query_param.html b/experiments/authentication/query_param.html new file mode 100644 index 000000000..27aa454a4 --- /dev/null +++ b/experiments/authentication/query_param.html @@ -0,0 +1,14 @@ + + + + Query parameter | WebSocket Authentication + + + +

[??] Query parameter

+

[OK] Query parameter

+

[KO] Query parameter

+ + + + diff --git a/experiments/authentication/query_param.js b/experiments/authentication/query_param.js new file mode 100644 index 000000000..6a54d0b6c --- /dev/null +++ b/experiments/authentication/query_param.js @@ -0,0 +1,11 @@ +window.addEventListener("DOMContentLoaded", () => { + const uri = `ws://localhost:8002/?token=${token}`; + const websocket = new WebSocket(uri); + + websocket.onmessage = ({ data }) => { + // event.data is expected to be "Hello !" + websocket.send(`Goodbye ${data.slice(6, -1)}.`); + }; + + runTest(websocket); +}); diff --git a/experiments/authentication/script.js b/experiments/authentication/script.js new file mode 100644 index 000000000..ec4e5e670 --- /dev/null +++ b/experiments/authentication/script.js @@ -0,0 +1,51 @@ +var token = window.parent.token; + +function getExpectedEvents() { + return [ + { + type: "open", + }, + { + type: "message", + data: `Hello ${window.parent.user}!`, + }, + { + type: "close", + code: 1000, + reason: "", + wasClean: true, + }, + ]; +} + +function isEqual(expected, actual) { + // good enough for our purposes here! + return JSON.stringify(expected) === JSON.stringify(actual); +} + +function testStep(expected, actual) { + if (isEqual(expected, actual)) { + document.body.className = "ok"; + } else if (isEqual(expected.slice(0, actual.length), actual)) { + document.body.className = "test"; + } else { + document.body.className = "ko"; + } +} + +function runTest(websocket) { + const expected = getExpectedEvents(); + var actual = []; + websocket.addEventListener("open", ({ type }) => { + actual.push({ type }); + testStep(expected, actual); + }); + websocket.addEventListener("message", ({ type, data }) => { + actual.push({ type, data }); + testStep(expected, actual); + }); + websocket.addEventListener("close", ({ type, code, reason, wasClean }) => { + actual.push({ type, code, reason, wasClean }); + testStep(expected, actual); + }); +} diff --git a/experiments/authentication/style.css b/experiments/authentication/style.css new file mode 100644 index 000000000..6e3918cca --- /dev/null +++ b/experiments/authentication/style.css @@ -0,0 +1,69 @@ +/* page layout */ + +body { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + margin: 0; + height: 100vh; +} +div.title, iframe { + width: 100vw; + height: 20vh; + border: none; +} +div.title { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; +} +h1, p { + margin: 0; + width: 24em; +} + +/* text style */ + +h1, input, p { + font-family: monospace; + font-size: 3em; +} +input { + color: #333; + border: 3px solid #999; + padding: 1em; +} +input:focus { + border-color: #333; + outline: none; +} +input::placeholder { + color: #999; + opacity: 1; +} + +/* test results */ + +body.test { + background-color: #666; + color: #fff; +} +body.ok { + background-color: #090; + color: #fff; +} +body.ko { + background-color: #900; + color: #fff; +} +body > p { + display: none; +} +body > p.title, +body.test > p.test, +body.ok > p.ok, +body.ko > p.ko { + display: block; +} diff --git a/experiments/authentication/test.html b/experiments/authentication/test.html new file mode 100644 index 000000000..3883d6a39 --- /dev/null +++ b/experiments/authentication/test.html @@ -0,0 +1,15 @@ + + + + WebSocket Authentication + + + +

WebSocket Authentication

+ + + + + + + diff --git a/experiments/authentication/test.js b/experiments/authentication/test.js new file mode 100644 index 000000000..428830ff3 --- /dev/null +++ b/experiments/authentication/test.js @@ -0,0 +1,6 @@ +// for connecting to WebSocket servers +var token = document.body.dataset.token; + +// for test assertions only +const params = new URLSearchParams(window.location.search); +var user = params.get("user"); diff --git a/experiments/authentication/user_info.html b/experiments/authentication/user_info.html new file mode 100644 index 000000000..7b9c99c73 --- /dev/null +++ b/experiments/authentication/user_info.html @@ -0,0 +1,14 @@ + + + + User information | WebSocket Authentication + + + +

[??] User information

+

[OK] User information

+

[KO] User information

+ + + + diff --git a/experiments/authentication/user_info.js b/experiments/authentication/user_info.js new file mode 100644 index 000000000..1dab2ce4c --- /dev/null +++ b/experiments/authentication/user_info.js @@ -0,0 +1,11 @@ +window.addEventListener("DOMContentLoaded", () => { + const uri = `ws://token:${token}@localhost:8004/`; + const websocket = new WebSocket(uri); + + websocket.onmessage = ({ data }) => { + // event.data is expected to be "Hello !" + websocket.send(`Goodbye ${data.slice(6, -1)}.`); + }; + + runTest(websocket); +}); From f7a62680bc7df0f17efac83109daf3bdc14bc0f5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 15:14:52 +0200 Subject: [PATCH 0807/1539] Simplify cookie-based authentication. --- experiments/authentication/cookie.js | 51 ++++++++------------- experiments/authentication/cookie_iframe.js | 43 ++++------------- 2 files changed, 27 insertions(+), 67 deletions(-) diff --git a/experiments/authentication/cookie.js b/experiments/authentication/cookie.js index 9c5d5f078..2cca34fcb 100644 --- a/experiments/authentication/cookie.js +++ b/experiments/authentication/cookie.js @@ -1,36 +1,23 @@ -// wait for "load" rather than "DOMContentLoaded" -// to ensure that the iframe has finished loading -window.addEventListener("load", () => { - // create message channel to communicate - // with the iframe - const channel = new MessageChannel(); - const port1 = channel.port1; +// send token to iframe +window.addEventListener("DOMContentLoaded", () => { + const iframe = document.querySelector("iframe"); + iframe.addEventListener("load", () => { + iframe.contentWindow.postMessage(token, "http://localhost:8003"); + }); +}); - // receive WebSocket events from the iframe - const expected = getExpectedEvents(); - var actual = []; - port1.onmessage = ({ data }) => { - // respond to messages - if (data.type == "message") { - port1.postMessage({ - type: "message", - message: `Goodbye ${data.data.slice(6, -1)}.`, - }); - } - // run tests - actual.push(data); - testStep(expected, actual); - }; +// once iframe has set cookie, open WebSocket connection +window.addEventListener("message", ({ origin }) => { + if (origin !== "http://localhost:8003") { + return; + } - // send message channel to the iframe - const iframe = document.querySelector("iframe"); - const origin = "http://localhost:8003"; - const ports = [channel.port2]; - iframe.contentWindow.postMessage("", origin, ports); + const websocket = new WebSocket("ws://localhost:8003/"); - // send token to the iframe - port1.postMessage({ - type: "open", - token: token, - }); + websocket.onmessage = ({ data }) => { + // event.data is expected to be "Hello !" + websocket.send(`Goodbye ${data.slice(6, -1)}.`); + }; + + runTest(websocket); }); diff --git a/experiments/authentication/cookie_iframe.js b/experiments/authentication/cookie_iframe.js index 88b5d1f2b..2d2e692e8 100644 --- a/experiments/authentication/cookie_iframe.js +++ b/experiments/authentication/cookie_iframe.js @@ -1,36 +1,9 @@ -// receive message channel from the parent window -window.addEventListener("message", ({ ports }) => { - const port2 = ports[0]; - var websocket; - port2.onmessage = ({ data }) => { - switch (data.type) { - case "open": - websocket = initWebSocket(data.token, port2); - break; - case "message": - websocket.send(data.message); - break; - case "close": - websocket.close(data.code, data.reason); - break; - } - }; -}); - -// open WebSocket connection and relay events to the parent window -function initWebSocket(token, port2) { - document.cookie = `token=${token}; SameSite=Strict`; - const websocket = new WebSocket("ws://localhost:8003/"); +// receive token from the parent window, set cookie and notify parent +window.addEventListener("message", ({ origin, data }) => { + if (origin !== "http://localhost:8000") { + return; + } - websocket.addEventListener("open", ({ type }) => { - port2.postMessage({ type }); - }); - websocket.addEventListener("message", ({ type, data }) => { - port2.postMessage({ type, data }); - }); - websocket.addEventListener("close", ({ type, code, reason, wasClean }) => { - port2.postMessage({ type, code, reason, wasClean }); - }); - - return websocket; -} + document.cookie = `token=${data}; SameSite=Strict`; + window.parent.postMessage("", "http://localhost:8000"); +}); From 920207c6d4d2b31def30970fc47c968806c2b64e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 14:58:00 +0200 Subject: [PATCH 0808/1539] Discuss authentication in docs. --- docs/spelling_wordlist.txt | 2 + docs/topics/authentication.rst | 348 +++++++++++++++++++++++++++++++++ docs/topics/authentication.svg | 63 ++++++ docs/topics/index.rst | 1 + 4 files changed, 414 insertions(+) create mode 100644 docs/topics/authentication.rst create mode 100644 docs/topics/authentication.svg diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 19c4b7c44..030917491 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -25,6 +25,7 @@ datastructures django dyno fractalideas +iframe IPv iterable keepalive @@ -65,3 +66,4 @@ websocket websockets ws wss +www diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst new file mode 100644 index 000000000..44c9c6151 --- /dev/null +++ b/docs/topics/authentication.rst @@ -0,0 +1,348 @@ +Authentication +============== + +The WebSocket protocol was designed for creating web applications that need +bidirectional communication between clients running in browsers and servers. + +In most practical use cases, WebSocket servers need to authenticate clients in +order to route communications appropriately and securely. + +:rfc:`6455` stays elusive when it comes to authentication: + + This protocol doesn't prescribe any particular way that servers can + authenticate clients during the WebSocket handshake. The WebSocket + server can use any client authentication mechanism available to a + generic HTTP server, such as cookies, HTTP authentication, or TLS + authentication. + +None of these three mechanisms works well in practice. Using cookies is +cumbersome, HTTP authentication isn't supported by all mainstream browsers, +and TLS authentication in a browser is an esoteric user experience. + +Fortunately, there are better alternatives! Let's discuss them. + +System design +------------- + +Consider a setup where the WebSocket server is separate from the HTTP server. + +Most servers built with websockets to complement a web application adopt this +design because websockets doesn't aim at supporting HTTP. + +The following diagram illustrates the authentication flow. + +.. image:: authentication.svg + +Assuming the current user is authenticated with the HTTP server (1), the +application needs to obtain credentials from the HTTP server (2) in order to +send them to the WebSocket server (3), who can check them against the database +of user accounts (4). + +Usernames and passwords aren't a good choice of credentials here, if only +because passwords aren't available in clear text in the database. + +Tokens linked to user accounts are a better choice. These tokens must be +impossible to forge by an attacker. For additional security, they can be +short-lived or even single-use. + +Sending credentials +------------------- + +Assume the web application obtained authentication credentials, likely a +token, from the HTTP server. There's four options for passing them to the +WebSocket server. + +1. **Sending credentials as the first message in the WebSocket connection.** + + This is fully reliable and the most secure mechanism in this discussion. It + has two minor downsides: + + * Authentication is performed at the application layer. Ideally, it would + be managed at the protocol layer. + + * Authentication is performed after the WebSocket handshake, making it + impossible to monitor authentication failures with HTTP response codes. + +2. **Adding credentials to the WebSocket URI in a query parameter.** + + This is also fully reliable but less secure. Indeed, it has a major + downside: + + * URIs end up in logs, which leaks credentials. Even if that risk could be + lowered with single-use tokens, it is usually considered unacceptable. + + Authentication is still performed at the application layer but it can + happen before the WebSocket handshake, which improves separation of + concerns and enables responding to authentication failures with HTTP 401. + +3. **Setting a cookie on the domain of the WebSocket URI.** + + Cookies are undoubtedly the most common and hardened mechanism for sending + credentials from a web application to a server. In a HTTP application, + credentials would be a session identifier or a serialized, signed session. + + Unfortunately, when the WebSocket server runs on a different domain from + the web application, this idea bumps into the `Same-Origin Policy`_. For + security reasons, setting a cookie on a different origin is impossible. + + The proper workaround consists in: + + * creating a hidden iframe_ served from the domain of the WebSocket server + * sending the token to the iframe with postMessage_ + * setting the cookie in the iframe + + before opening the WebSocket connection. + + Sharing a parent domain (e.g. example.com) between the HTTP server (e.g. + www.example.com) and the WebSocket server (e.g. ws.example.com) and setting + the cookie on that parent domain would work too. + + However, the cookie would be shared with all subdomains of the parent + domain. For a cookie containing credentials, this is unacceptable. + +.. _Same-Origin Policy: https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy +.. _iframe: https://developer.mozilla.org/en-US/docs/Web/HTML/Element/iframe +.. _postMessage: https://developer.mozilla.org/en-US/docs/Web/API/MessagePort/postMessage + +4. **Adding credentials to the WebSocket URI in user information.** + + Letting the browser perform HTTP Basic Auth is a nice idea in theory. + + In practice it doesn't work due to poor support in browsers. + + As of May 2021: + + * Chrome 90 behaves as expected. + + * Firefox 88 caches credentials too aggressively. + + When connecting again to the same server with new credentials, it reuses + the old credentials, which may be expired, resulting in an HTTP 401. Then + the next connection succeeds. Perhaps errors clear the cache. + + When tokens are short-lived on single-use, this bug produces an + interesting effect: every other WebSocket connection fails. + + * Safari 14 ignores credentials entirely. + +Two other options are off the table: + +1. **Setting a custom HTTP header** + + This would be the most elegant mechanism, solving all issues with the options + discussed above. + + Unfortunately, it doesn't work because the `WebSocket API`_ doesn't support + `setting custom headers`_. + +.. _WebSocket API: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API +.. _setting custom headers: https://github.com/whatwg/html/issues/3062 + +2. **Authenticating with a TLS certificate** + + While this is suggested by the RFC, installing a TLS certificate is too far + from the mainstream experience of browser users. This could make sense in + high security contexts. I hope developers working on such projects don't + take security advice from the documentation of random open source projects. + +Let's experiment! +----------------- + +The `experiments/authentication`_ directory demonstrates these techniques. + +Run the experiment in an environment where websockets is installed: + +.. _experiments/authentication: https://github.com/aaugustin/websockets/tree/main/experiments/authentication + +.. code:: console + + $ python experiments/authentication/app.py + Running on http://localhost:8000/ + +When you browse to the HTTP server at http://localhost:8000/ and you submit a +username, the server creates a token and returns a testing web page. + +This page opens WebSocket connections to four WebSocket servers running on +four different origins. It attempts to authenticate with the token in four +different ways. + +First message +............. + +As soon as the connection is open, the client sends a message containing the +token: + +.. code:: javascript + + const websocket = new WebSocket("ws://.../"); + websocket.onopen = () => websocket.send(token); + + // ... + +At the beginning of the connection handler, the server receives this message +and authenticates the user. If authentication fails, the server closes the +connection: + +.. code:: python + + async def first_message_handler(websocket, path): + token = await websocket.recv() + user = get_user(token) + if user is None: + await websocket.close(1011, "authentication failed") + return + + ... + +Query parameter +............... + +The client adds the token to the WebSocket URI in a query parameter before +opening the connection: + +.. code:: javascript + + const uri = `ws://.../?token=${token}`; + const websocket = new WebSocket(uri); + + // ... + +The server intercepts the HTTP request, extracts the token and authenticates +the user. If authentication fails, it returns a HTTP 401: + +.. code:: python + + class QueryParamProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + token = get_query_parameter(path, "token") + if token is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = get_user(token) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + async def query_param_handler(websocket, path): + user = websocket.user + + ... + +Cookie +...... + +The client sets a cookie containing the token before opening the connection. + +The cookie must be set by an iframe loaded from the same origin as the +WebSocket server. This requires passing the token to this iframe. + +.. code:: javascript + + // in main window + iframe.contentWindow.postMessage(token, "http://..."); + + // in iframe + document.cookie = `token=${data}; SameSite=Strict`; + + // in main window + const websocket = new WebSocket("ws://.../"); + + // ... + +This sequence must be synchronized between the main window and the iframe. +This involves several events. Look at the full implementation for details. + +The server intercepts the HTTP request, extracts the token and authenticates +the user. If authentication fails, it returns a HTTP 401: + +.. code:: python + + class CookieProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + # Serve iframe on non-WebSocket requests + ... + + token = get_cookie(headers.get("Cookie", ""), "token") + if token is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = get_user(token) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + async def cookie_handler(websocket, path): + user = websocket.user + + ... + +User information +................ + +The client adds the token to the WebSocket URI in user information before +opening the connection: + +.. code:: javascript + + const uri = `ws://token:${token}@.../`; + const websocket = new WebSocket(uri); + + // ... + +Since HTTP Basic Auth is designed to accept a username and a password rather +than a token, we send ``token`` as username and the token as password. + +The server intercepts the HTTP request, extracts the token and authenticates +the user. If authentication fails, it returns a HTTP 401: + +.. code:: python + + class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): + async def check_credentials(self, username, password): + if username != "token": + return False + + user = get_user(password) + if user is None: + return False + + self.user = user + return True + + async def user_info_handler(websocket, path): + user = websocket.user + + ... + +Machine-to-machine authentication +--------------------------------- + +When the WebSocket client is a standalone program rather than a script running +in a browser, there are far fewer constraints. HTTP Authentication is the best +solution in this scenario. + +To authenticate a websockets client with HTTP Basic Authentication +(:rfc:`7617`), include the credentials in the URI: + +.. code:: python + + async with websockets.connect( + f"wss://{username}:{password}@example.com", + ) as websocket: + ... + +(You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they +contain unsafe characters.) + +To authenticate a websockets client with HTTP Bearer Authentication +(:rfc:`6750`), add a suitable ``Authorization`` header: + +.. code:: python + + async with websockets.connect( + "wss://example.com", + extra_headers={"Authorization": f"Bearer {token}"} + ) as websocket: + ... diff --git a/docs/topics/authentication.svg b/docs/topics/authentication.svg new file mode 100644 index 000000000..ad2ad0e44 --- /dev/null +++ b/docs/topics/authentication.svg @@ -0,0 +1,63 @@ +HTTPserverWebSocketserverweb appin browseruser accounts(1) authenticate user(2) obtain credentials(3) send credentials(4) authenticate user \ No newline at end of file diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 157278f76..5363de0ce 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -4,5 +4,6 @@ Topics .. toctree:: :maxdepth: 2 + authentication design security From c91b4c2a01bbf8dd41d521a29710470c8b73599b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 18:51:27 +0200 Subject: [PATCH 0809/1539] Use constant-time comparison for passwords. --- docs/project/changelog.rst | 2 ++ src/websockets/legacy/auth.py | 28 +++++++++++++++------------- tests/legacy/test_auth.py | 13 ++++++++++--- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c2c869c13..a44dd0418 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -45,6 +45,8 @@ They may change at any time. * Optimized default compression settings to reduce memory usage. +* Protected against timing attacks on HTTP Basic Auth. + * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e7e69cac1..16016e6fd 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -6,6 +6,7 @@ import functools +import hmac import http from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast @@ -154,24 +155,23 @@ def basic_auth_protocol_factory( if credentials is not None: if is_credentials(credentials): - - async def check_credentials(username: str, password: str) -> bool: - return (username, password) == credentials - + credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): credentials_list = list(credentials) - if all(is_credentials(item) for item in credentials_list): - credentials_dict = dict(credentials_list) - - async def check_credentials(username: str, password: str) -> bool: - return credentials_dict.get(username) == password - - else: + if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") - else: raise TypeError(f"invalid credentials argument: {credentials}") + credentials_dict = dict(credentials_list) + + async def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + if create_protocol is None: # Not sure why mypy cannot figure this out. create_protocol = cast( @@ -180,5 +180,7 @@ async def check_credentials(username: str, password: str) -> bool: ) return functools.partial( - create_protocol, realm=realm, check_credentials=check_credentials + create_protocol, + realm=realm, + check_credentials=check_credentials, ) diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index c4dbd88ad..2b670c31f 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -1,3 +1,4 @@ +import hmac import unittest import urllib.error @@ -27,7 +28,7 @@ async def process_request(self, path, request_headers): class CheckWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): - return password == "letmein" + return hmac.compare_digest(password, "letmein") class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): @@ -81,7 +82,7 @@ def test_basic_auth_bad_multiple_credentials(self): ) async def check_credentials(username, password): - return password == "iloveyou" + return hmac.compare_digest(password, "iloveyou") create_protocol_check_credentials = basic_auth_protocol_factory( realm="auth-tests", @@ -158,7 +159,13 @@ def test_basic_auth_unsupported_credentials_details(self): self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n") @with_server(create_protocol=create_protocol) - def test_basic_auth_invalid_credentials(self): + def test_basic_auth_invalid_username(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(user_info=("goodbye", "iloveyou")) + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_invalid_password(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(user_info=("hello", "ihateyou")) self.assertEqual(raised.exception.status_code, 401) From 4454e7df32ab82b439dec1b2ab64071bc22b0b38 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 21:30:37 +0200 Subject: [PATCH 0810/1539] Normalize code style. --- example/deployment/supervisor/app.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index 5fa596d0a..484566bc8 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -18,7 +18,9 @@ async def main(): loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with websockets.serve( - echo, "", 8080, + echo, + host="", + port=8080, reuse_port=True, ): await stop From 2222beaaa2900847ecf7b3b487ac6106254504d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 21 May 2021 22:38:10 +0200 Subject: [PATCH 0811/1539] Improve Django integration guide. Switch to a simpler and more secure authentication design. --- docs/howto/django.rst | 193 +++++++++++++++++-------------- example/django/authentication.py | 40 ++----- example/django/notifications.py | 21 ++-- 3 files changed, 121 insertions(+), 133 deletions(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 1b4b4d3b9..e776f5f8c 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -17,39 +17,32 @@ WebSocket, you have two main options. This guide shows how to implement the second technique with websockets. It assumes familiarity with Django. -Authenticating connections --------------------------- - -Since the websockets server will run outside of Django, we need to connect it -to ``django.contrib.auth``. - -Our clients are running in browser. The `WebSocket API`_ doesn't support -setting `custom headers`_ so our options boil down to: - -* HTTP Basic Auth: this seems technically possible but isn't supported by - Firefox (`bug 1229443`_) so browser support is clearly insufficient. -* Sharing cookies: this is technically possible if there's a common parent - domain between the Django server (e.g. api.example.com) and the websockets - server (e.g. ws.example.com). However, there's a risk to share cookies too - widely (e.g. with anything under .example.com here). For authentication - cookies, this risk seems unacceptable. -* Sending an authentication ticket: Django generates a secure single-use token - with the user ID. The browser includes this token in the WebSocket URI when - it connects to the server in order to authenticate. It could also send the - ticket over the WebSocket connection in the first message, however this is a - bit more difficult to monitor, as you can't detect authentication failures - simply by looking at HTTP response codes. - -.. _custom headers: https://github.com/whatwg/html/issues/3062 -.. _WebSocket API: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API -.. _bug 1229443: https://bugzilla.mozilla.org/show_bug.cgi?id=1229443 - -To generate our authentication tokens, we'll use `django-sesame`_, a small -library designed exactly for this purpose. +Authenticate connections +------------------------ + +Since the websockets server runs outside of Django, we need to integrate it +with ``django.contrib.auth``. + +We will generate authentication tokens in the Django project. Then we will +send them to the websockets server, where they will authenticate the user. + +Generating a token for the current user and making it available in the browser +is up to you. You could render the token in a template or fetch it with an API +call. + +Refer to the topic guide on :doc:`authentication <../topics/authentication>` +for details on this design. + +Generate tokens +............... + +We want secure, short-lived tokens containing the user ID. We'll rely on +`django-sesame`_, a small library designed exactly for this purpose. .. _django-sesame: https://github.com/aaugustin/django-sesame -Add django-sesame to the dependencies of your Django project, install it and configure it in the settings of the project: +Add django-sesame to the dependencies of your Django project, install it, and +configure it in the settings of the project: .. code:: python @@ -65,15 +58,22 @@ You don't need ``"sesame.middleware.AuthenticationMiddleware"``. It is for authenticating users in the Django server, while we're authenticating them in the websockets server. -We'd like our tokens to be valid for 30 seconds and usable only once. A -shorter lifespan is possible but it would make manual testing difficult. - -Configure django-sesame accordingly in the settings of your Django project: +We'd like our tokens to be valid for 30 seconds. We expect web pages to load +and to establish the WebSocket connection within this delay. Configure +django-sesame accordingly in the settings of your Django project: .. code:: python SESAME_MAX_AGE = 30 - SESAME_ONE_TIME = True + +If you expect your web site to load faster for all clients, a shorter lifespan +is possible. However, in the context of this document, it would make manual +testing more difficult. + +You could also enable single-use tokens. However, this would update the last +login date of the user every time a WebSocket connection is established. This +doesn't seem like a good idea, both in terms of behavior and in terms of +performance. Now you can generate tokens in a ``django-admin shell`` as follows: @@ -86,100 +86,103 @@ Now you can generate tokens in a ``django-admin shell`` as follows: >>> get_token(user) '' -Keep this console open: since tokens are single use, we'll have to generate a -new token every time we want to test our server. +Keep this console open: since tokens expire after 30 seconds, you'll have to +generate a new token every time you want to test connecting to the server. -Let's move on to the websockets server. Add websockets to the dependencies of -your Django project and install it. +Validate tokens +............... -Now here's a way to implement authentication. We're taking advantage of the -:meth:`~websockets.server.WebSocketServerProtocol.process_request` hook to -authenticate requests. If authentication succeeds, we store the user as an -attribute of the connection in order to make it available to the connection -handler. If authentication fails, we return a HTTP 401 Unauthorized error. +Let's move on to the websockets server. + +Add websockets to the dependencies of your Django project and install it. +Indeed, we're going to reuse the environment of the Django project, so we can +call its APIs in the websockets server. + +Now here's how to implement authentication. .. literalinclude:: ../../example/django/authentication.py Let's unpack this code. -We're using Django in a `standalone script`_. This requires setting the -``DJANGO_SETTINGS_MODULE`` environment variable and calling ``django.setup()`` -before doing anything with Django. +We're calling ``django.setup()`` before doing anything with Django because +we're using Django in a `standalone script`_. This assumes that the +``DJANGO_SETTINGS_MODULE`` environment variable is set to the Python path to +your settings module. .. _standalone script: https://docs.djangoproject.com/en/stable/topics/settings/#calling-django-setup-is-required-for-standalone-django-usage -We subclass :class:`~websockets.server.WebSocketServerProtocol` and override -:meth:`~websockets.server.WebSocketServerProtocol.process_request`, where: - -* We extract the token from the URL with the ``get_sesame()`` utility function - defined just above. If the token is missing, we return a HTTP 401 error. -* We authenticate the user with ``get_user()``, the API for `authentication - outside views`_. If authentication fails, we return a HTTP 401 error. +The connection handler reads the first message received from the client, which +is expected to contain a django-sesame token. Then it authenticates the user +with ``get_user()``, the API for `authentication outside views`_. If +authentication fails, it closes the connection and exits. .. _authentication outside views: https://github.com/aaugustin/django-sesame#authentication-outside-views -When we call an API that makes a database query, we wrap the call in -:func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support -asynchronous I/O. We would block the event loop if we didn't run these calls -in a separate thread. :func:`~asyncio.to_thread` is available since Python -3.9; in earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. +When we call an API that makes a database query such as ``get_user()``, we +wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't +support asynchronous I/O. It would block the event loop if it didn't run in a +separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In +earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. -The connection handler accesses the logged-in user that we stored as an -attribute of the connection object so we can test that authentication works. - -Finally, we start a server with :func:`~websockets.serve`, with the -``create_protocol`` pointing to our subclass of -:class:`~websockets.server.WebSocketServerProtocol`. +Finally, we start a server with :func:`~websockets.serve`. We're ready to test! -Make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set to the -Python path to your settings module and start the websockets server. If you -saved the server implementation to a file called ``authentication.py``: +Save this code to a file called ``authentication.py``, make sure the +``DJANGO_SETTINGS_MODULE`` environment variable is set properly, and start the +websockets server: .. code:: console $ python authentication.py -Open a new shell, generate a new token — remember, they're only valid for -30 seconds — and use it to connect to your server: +Generate a new token — remember, they're only valid for 30 seconds — and use +it to connect to your server. Paste your token and press Enter when you get a +prompt: .. code:: console - $ python -m websockets "ws://localhost:8888/?sesame=" - Connected to ws://localhost:8888/?sesame= + $ python -m websockets ws://localhost:8888/ + Connected to ws://localhost:8888/ + > < Hello ! Connection closed: code = 1000 (OK), no reason. It works! -If we try to reuse the same token, the connection is now rejected: +If you enter an expired or invalid token, authentication fails and the server +closes the connection: .. code:: console - $ python -m websockets "ws://localhost:8888/?sesame=" - Failed to connect to ws://localhost:8888/?sesame=: - server rejected WebSocket connection: HTTP 401. + $ python -m websockets ws://localhost:8888/ + Connected to ws://localhost:8888. + > not a token + Connection closed: code = 1011 (unexpected error), reason = authentication failed. You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: .. code:: javascript - webSocket = new WebSocket("ws://localhost:8888/?sesame="); - webSocket.onmessage = (event) => console.log(event.data); + websocket = new WebSocket("ws://localhost:8888/"); + websocket.onopen = (event) => websocket.send(""); + websocket.onmessage = (event) => console.log(event.data); -Streaming events ----------------- +Stream events +------------- We can connect and authenticate but our server doesn't do anything useful yet! -Let's send a message every time any user makes any action in the admin. This +Let's send a message every time a user makes an action in the admin. This message will be broadcast to all users who can access the model on which the action was made. This may be used for showing notifications to other users. Many use cases for WebSocket with Django follow a similar pattern. +Set up event bus +................ + We need a event bus to enable communications between Django and websockets. Both sides connect permanently to the bus. Then Django writes events and websockets reads them. For the sake of simplicity, we'll rely on `Redis @@ -187,15 +190,15 @@ Pub/Sub`_. .. _Redis Pub/Sub: https://redis.io/topics/pubsub -Let's start by writing events. The easiest way to add Redis to a Django -project is by configuring a cache backend with `django-redis`_. This library -manages connections to Redis efficiently, persisting them between requests, -and provides an API to access the Redis connection directly. +The easiest way to add Redis to a Django project is by configuring a cache +backend with `django-redis`_. This library manages connections to Redis +efficiently, persisting them between requests, and provides an API to access +the Redis connection directly. .. _django-redis: https://github.com/jazzband/django-redis Install Redis, add django-redis to the dependencies of your Django project, -install it and configure it in the settings of the project: +install it, and configure it in the settings of the project: .. code:: python @@ -209,6 +212,11 @@ install it and configure it in the settings of the project: If you already have a default cache, add a new one with a different name and change ``get_redis_connection("default")`` in the code below to the same name. +Publish events +.............. + +Now let's write events to the bus. + Add the following code to a module that is imported when your Django project starts. Typically, you would put it in a ``signals.py`` module, which you would import in the ``AppConfig.ready()`` method of one of your apps: @@ -236,9 +244,11 @@ Leave this command running, start the Django development server and make changes in the admin: add, modify, or delete objects. You should see corresponding events published to the ``"events"`` stream. +Broadcast events +................ + Now let's turn to reading events and broadcasting them to connected clients. -We'll reuse our custom ``ServerProtocol`` class for authentication. Then we -need to add several features: +We need to add several features: * Keep track of connected clients so we can broadcast messages. * Tell which content types the user has permission to view or to change. @@ -250,10 +260,10 @@ Here's a complete implementation. .. literalinclude:: ../../example/django/notifications.py Since the ``get_content_types()`` function makes a database query, it is -wrapped inside ``asyncio.to_thread()``. It runs once when each WebSocket +wrapped inside :func:`asyncio.to_thread()`. It runs once when each WebSocket connection is open; then its result is cached for the lifetime of the connection. Indeed, running it for each message would trigger database queries -for all connected users at the same time, which could hurt the database. +for all connected users at the same time, which would hurt the database. The connection handler merely registers the connection in a global variable, associated to the list of content types for which events should be sent to @@ -270,6 +280,9 @@ send a message to a slow client. Since Redis can publish a message to multiple subscribers, multiple instances of this server can safely run in parallel. +Does it scale? +-------------- + In theory, given enough servers, this design can scale to a hundred million clients, since Redis can handle ten thousand servers and each server can handle ten thousand clients. In practice, you would need a more scalable diff --git a/example/django/authentication.py b/example/django/authentication.py index c0a061109..bbb3db02a 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -1,8 +1,6 @@ #!/usr/bin/env python import asyncio -import http -import urllib.parse import django import websockets @@ -12,40 +10,18 @@ from sesame.utils import get_user -def get_sesame(path): - """Utility function to extract sesame token from request path.""" - query = urllib.parse.urlparse(path).query - params = urllib.parse.parse_qs(query) - sesame = params.get("sesame", []) - if len(sesame) == 1: - return sesame[0] - - -class ServerProtocol(websockets.WebSocketServerProtocol): - async def process_request(self, path, headers): - """Authenticate users with a django-sesame token.""" - sesame = get_sesame(path) - if sesame is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = await asyncio.to_thread(get_user, sesame) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user - - async def handler(websocket, path): - await websocket.send(f"Hello {websocket.user}!") + sesame = await websocket.recv() + user = await asyncio.to_thread(get_user, sesame) + if user is None: + await websocket.close(1011, "authentication failed") + return + + await websocket.send(f"Hello {user}!") async def main(): - async with websockets.serve( - handler, - "localhost", - 8888, - create_protocol=ServerProtocol, - ): + async with websockets.serve(handler, "localhost", 8888): await asyncio.Future() # run forever diff --git a/example/django/notifications.py b/example/django/notifications.py index ad2751d98..641643f92 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -10,9 +10,7 @@ django.setup() from django.contrib.contenttypes.models import ContentType - -# Reuse our custom protocol to authenticate connections -from authentication import ServerProtocol +from sesame.utils import get_user CONNECTIONS = {} @@ -31,8 +29,14 @@ def get_content_types(user): async def handler(websocket, path): - """Register connection in CONNECTIONS dict, until it's closed.""" - ct_ids = await asyncio.to_thread(get_content_types, websocket.user) + """Authenticate user and register connection in CONNECTIONS.""" + sesame = await websocket.recv() + user = await asyncio.to_thread(get_user, sesame) + if user is None: + await websocket.close(1011, "authentication failed") + return + + ct_ids = await asyncio.to_thread(get_content_types, user) CONNECTIONS[websocket] = {"content_type_ids": ct_ids} try: await websocket.wait_closed() @@ -57,12 +61,7 @@ async def process_events(): async def main(): - async with websockets.serve( - handler, - "localhost", - 8888, - create_protocol=ServerProtocol, - ): + async with websockets.serve(handler, "localhost", 8888): await process_events() # runs forever From 8acf7ccf39c96acec293a2ec668e299f6f29e46b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 18:33:41 +0200 Subject: [PATCH 0812/1539] Setup issue template Fix #835. --- .github/ISSUE_TEMPLATE/config.yml | 1 + .github/ISSUE_TEMPLATE/issue.md | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/config.yml create mode 100644 .github/ISSUE_TEMPLATE/issue.md diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..3ba13e0ce --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md new file mode 100644 index 000000000..8efabc617 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -0,0 +1,26 @@ +--- +name: Report an issue +about: Let us know about a problem with websockets +title: '' +labels: '' +assignees: '' + +--- + + From 81edc5f4adf6054852d041a222863d350a3d92d7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 24 May 2021 19:20:18 +0200 Subject: [PATCH 0813/1539] Move compression docs to a topic guide. --- .gitignore | 2 +- docs/howto/compression.rst | 51 ------ docs/howto/deployment.rst | 75 +------- docs/howto/index.rst | 1 - docs/topics/compression.rst | 172 ++++++++++++++++++ docs/topics/index.rst | 1 + .../compression/benchmark.py | 8 +- .../compression/client.py | 8 +- .../compression/server.py | 4 +- 9 files changed, 186 insertions(+), 136 deletions(-) delete mode 100644 docs/howto/compression.rst create mode 100644 docs/topics/compression.rst rename benchmark/compression.py => experiments/compression/benchmark.py (97%) rename benchmark/mem_client.py => experiments/compression/client.py (87%) rename benchmark/mem_server.py => experiments/compression/server.py (96%) diff --git a/.gitignore b/.gitignore index ac68ff739..8f9e7dc51 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,8 @@ .mypy_cache .tox build/ -benchmark/corpus.pkl compliance/reports/ +experiments/compression/corpus.pkl dist/ docs/_build/ htmlcov/ diff --git a/docs/howto/compression.rst b/docs/howto/compression.rst deleted file mode 100644 index 9023cec56..000000000 --- a/docs/howto/compression.rst +++ /dev/null @@ -1,51 +0,0 @@ -Compression -=========== - -:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable -compression by default. - -If you want to disable it, set ``compression=None``:: - - import websockets - - websockets.connect(..., compression=None) - - websockets.serve(..., compression=None) - -.. _per-message-deflate-configuration-example: - -You can also configure the Per-Message Deflate extension explicitly if you -want to customize compression settings:: - - import websockets - from websockets.extensions import permessage_deflate - - websockets.connect( - ..., - extensions=[ - permessage_deflate.ClientPerMessageDeflateFactory( - server_max_window_bits=11, - client_max_window_bits=11, - compress_settings={'memLevel': 4}, - ), - ], - ) - - websockets.serve( - ..., - extensions=[ - permessage_deflate.ServerPerMessageDeflateFactory( - server_max_window_bits=11, - client_max_window_bits=11, - compress_settings={'memLevel': 4}, - ), - ], - ) - -The window bits and memory level values chosen in these examples reduce memory -usage. You can read more about :ref:`optimizing compression settings -`. - -Refer to the API documentation of -:class:`~permessage_deflate.ClientPerMessageDeflateFactory` and -:class:`~permessage_deflate.ServerPerMessageDeflateFactory` for details. diff --git a/docs/howto/deployment.rst b/docs/howto/deployment.rst index 35720e8f1..d8264b2cc 100644 --- a/docs/howto/deployment.rst +++ b/docs/howto/deployment.rst @@ -66,82 +66,11 @@ Memory usage of a single connection is the sum of: Baseline ........ -.. _compression-settings: - Compression settings are the main factor affecting the baseline amount of memory used by each connection. -If you'd like to customize compression settings, here are the main knobs. - -- Context Takeover is necessary to get good performance for almost all - applications. It should remain enabled. -- Window Bits is a trade-off between memory usage and compression rate. - It should be an integer between 9 (lowest memory usage) and 15 (highest - compression rate). Setting it to 8 is possible but triggers a bug in some - versions of zlib. -- Memory Level is a trade-off between memory usage and compression speed. - However, a lower memory level can increase speed thanks to memory locality, - even if the CPU does more work! It should be an integer between 1 (lowest - memory usage) and 9 (highest compression speed in theory, not in practice). - -By default, websockets enables compression with conservative settings that -optimize memory usage at the cost of a slightly worse compression rate: Window -Bits = 12 and Memory Level = 5. This strikes a good balance for small messages -that are typical of WebSocket servers. - -If you'd like to configure different compression settings, see this -:ref:`example `. If you don't set -limits on Window Bits and neither does the remote endpoint, it defaults to the -maximum value of 15. If you don't set Memory Level, it defaults to 8 — more -accurately, to ``zlib.DEF_MEM_LEVEL`` which is 8. - -Here's how various compression settings affect memory usage of a single -connection on a 64-bit system, as well a benchmark of compressed size and -compression time for a corpus of small JSON documents. - -+-------------+-------------+--------------+--------------+------------------+------------------+ -| Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | -+=============+=============+==============+==============+==================+==================+ -| | 15 | 8 | 322 KiB | -4.0% | +15% + -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 14 | 7 | 178 KiB | -2.6% | +10% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 13 | 6 | 106 KiB | -1.4% | +5% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| *default* | 12 | 5 | 70 KiB | = | = | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 11 | 4 | 52 KiB | +3.7% | -5% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 10 | 3 | 43 KiB | +90% | +50% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 9 | 2 | 39 KiB | +160% | +100% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| *disabled* | N/A | N/A | 19 KiB | N/A | N/A | -+-------------+-------------+--------------+--------------+------------------+------------------+ - -*Don't assume this example is representative! Compressed size and compression -time depend heavily on the kind of messages exchanged by the application!* - -You can adapt the `compression.py`_ benchmark for your application by creating -a list of typical messages and passing it to the ``_benchmark`` function. - -.. _compression.py: https://github.com/aaugustin/websockets/blob/main/performance/compression.py - -This `blog post by Ilya Grigorik`_ provides more details about how compression -settings affect memory usage and how to optimize them. - -.. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ - -This `experiment by Peter Thorson`_ suggests Window Bits = 11 and Memory Level = -4 as a sweet spot for optimizing memory usage. - -.. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html - -websockets defaults to Window Bits = 12 and Memory Level = 5 in order to stay -away from Window Bits = 10 or Memory Level = 3, where performance craters in -the benchmark. This raises doubts on what could happen at Window Bits = 11 and -Memory Level = 4 on a different set of messages. The defaults needs to be safe -for all applications, hence a more conservative choice. +Read to the topic guide on :doc:`../topics/compression` to learn more about +tuning compression settings. Buffers ....... diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 033dfe15d..b6cff7237 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -31,7 +31,6 @@ optimized, and secure setup. :maxdepth: 2 deployment - compression heroku kubernetes supervisor diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst new file mode 100644 index 000000000..206bfa7b7 --- /dev/null +++ b/docs/topics/compression.rst @@ -0,0 +1,172 @@ +Compression +=========== + +Most WebSocket servers exchange JSON messages because they're convenient to +parse and serialize in a browser. These messages contain text data and tend to +be repetitive. + +This makes the stream of messages highly compressible. Enabling compression +can reduce network traffic by more than 80%. + +There's a standard for compressing messages. :rfc:`7692` defines WebSocket +Per-Message Deflate, a compression extension based on the Deflate_ algorithm. + +.. _Deflate: https://en.wikipedia.org/wiki/Deflate + +Configuring compression +----------------------- + +:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable +compression by default because the reduction in network bandwidth is usually +worth the additional memory and CPU cost. + +If you want to disable compression, set ``compression=None``:: + + import websockets + + websockets.connect(..., compression=None) + + websockets.serve(..., compression=None) + +If you want to customize compression settings, you can enable the Per-Message +Deflate extension explicitly with +:class:`~permessage_deflate.ClientPerMessageDeflateFactory` or +:class:`~permessage_deflate.ServerPerMessageDeflateFactory`:: + + import websockets + from websockets.extensions import permessage_deflate + + websockets.connect( + ..., + extensions=[ + permessage_deflate.ClientPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={"memLevel": 4}, + ), + ], + ) + + websockets.serve( + ..., + extensions=[ + permessage_deflate.ServerPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={"memLevel": 4}, + ), + ], + ) + +The Window Bits and Memory Level values in these examples reduce memory usage at the expense of compression rate. + +Compression settings +-------------------- + +When a client and a server enable the Per-Message Deflate extension, they +negotiate two parameters to guarantee compatibility between compression and +decompression. This affects the trade-off between compression rate and memory +usage for both sides. + +* **Context Takeover** means that the compression context is retained between + messages. In other words, compression is applied to the stream of messages + rather than to each message individually. Context takeover should remain + enabled to get good performance on applications that send a stream of + messages with the same structure, that is, most applications. + +* **Window Bits** controls the size of the compression context. It must be + an integer between 9 (lowest memory usage) and 15 (best compression). + websockets defaults to 12. Setting it to 8 is possible but rejected by some + versions of zlib. + +:mod:`zlib` offers additional parameters for tuning compression. They control +the trade-off between compression rate and CPU and memory usage for the +compression side, transparently for the decompression side. + +* **Memory Level** controls the size of the compression state. It must be an + integer between 1 (lowest memory usage) and 9 (best compression). websockets + defaults to 5. A lower memory level can increase speed thanks to memory + locality. + +* **Compression Level** controls the effort to optimize compression. It must + be an integer between 1 (lowest CPU usage) and 9 (best compression). + +* **Strategy** selects the compression strategy. The best choice depends on + the type of data being compressed. + +Unless mentioned otherwise, websockets uses the defaults of +:func:`zlib.compressobj` for all these settings. + +Tuning compression +------------------ + +By default, websockets enables compression with conservative settings that +optimize memory usage at the cost of a slightly worse compression rate: Window +Bits = 12 and Memory Level = 5. This strikes a good balance for small messages +that are typical of WebSocket servers. + +Here's how various compression settings affect memory usage of a single +connection on a 64-bit system, as well a benchmark of compressed size and +compression time for a corpus of small JSON documents. + ++-------------+-------------+--------------+--------------+------------------+------------------+ +| Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | ++=============+=============+==============+==============+==================+==================+ +| | 15 | 8 | 322 KiB | -4.0% | +15% + ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 14 | 7 | 178 KiB | -2.6% | +10% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 13 | 6 | 106 KiB | -1.4% | +5% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| *default* | 12 | 5 | 70 KiB | = | = | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 11 | 4 | 52 KiB | +3.7% | -5% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 10 | 3 | 43 KiB | +90% | +50% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 9 | 2 | 39 KiB | +160% | +100% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| *disabled* | — | — | 19 KiB | +452% | — | ++-------------+-------------+--------------+--------------+------------------+------------------+ + +Window Bits and Memory Level don't have to move in lockstep. However, other +combinations don't yield significantly better results than those shown above. + +Compressed size and compression time depend heavily on the kind of messages +exchanged by the application so this example may not apply to your use case. + +You can adapt `compression/benchmark.py`_ by creating a list of typical +messages and passing it to the ``_run`` function. + +Window Bits = 11 and Memory Level = 4 looks like the sweet spot in this table. + +websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from +Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts +on what could happen at Window Bits = 11 and Memory Level = 4 on a different +corpus. + +Defaults must be safe for all applications, hence a more conservative choice. + +.. _compression/benchmark.py: https://github.com/aaugustin/websockets/blob/main/experiments/compression/benchmark.py + +The benchmark focuses on compression because it's more expensive than +decompression. Indeed, leaving aside small allocations, theoretical memory +usage is: + +* ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; +* ``1 << windowBits`` for decompression. + +CPU usage is also higher for compression than decompression. + +Further reading +--------------- + +This `blog post by Ilya Grigorik`_ provides more details about how compression +settings affect memory usage and how to optimize them. + +.. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ + +This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory +Level = 4 for optimizing memory usage. + +.. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 5363de0ce..b18434a39 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -5,5 +5,6 @@ Topics :maxdepth: 2 authentication + compression design security diff --git a/benchmark/compression.py b/experiments/compression/benchmark.py similarity index 97% rename from benchmark/compression.py rename to experiments/compression/benchmark.py index 15fb8653e..bdcdd8e95 100644 --- a/benchmark/compression.py +++ b/experiments/compression/benchmark.py @@ -49,7 +49,7 @@ def corpus(): pickle.dump(data, handle) -def _benchmark(data): +def _run(data): size = {} duration = {} @@ -149,15 +149,15 @@ def _benchmark(data): print() -def benchmark(): +def run(): with open(CORPUS_FILE, "rb") as handle: data = pickle.load(handle) - _benchmark(data) + _run(data) try: run = globals()[sys.argv[1]] except (KeyError, IndexError): - print(f"Usage: {sys.argv[0]} [corpus|benchmark]") + print(f"Usage: {sys.argv[0]} [corpus|run]") else: run() diff --git a/benchmark/mem_client.py b/experiments/compression/client.py similarity index 87% rename from benchmark/mem_client.py rename to experiments/compression/client.py index db68eb995..3ee19ddc5 100644 --- a/benchmark/mem_client.py +++ b/experiments/compression/client.py @@ -16,7 +16,7 @@ MEM_SIZE = [] -async def mem_client(client): +async def client(client): # Space out connections to make them sequential. await asyncio.sleep(client * INTERVAL) @@ -45,11 +45,11 @@ async def mem_client(client): await asyncio.sleep(CLIENTS * INTERVAL) -async def mem_clients(): - await asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) +async def clients(): + await asyncio.gather(*[client(client) for client in range(CLIENTS + 1)]) -asyncio.run(mem_clients()) +asyncio.run(clients()) # First connection incurs non-representative setup costs. diff --git a/benchmark/mem_server.py b/experiments/compression/server.py similarity index 96% rename from benchmark/mem_server.py rename to experiments/compression/server.py index 852796249..f7b147006 100644 --- a/benchmark/mem_server.py +++ b/experiments/compression/server.py @@ -34,7 +34,7 @@ async def handler(ws, path): await asyncio.sleep(CLIENTS * INTERVAL) -async def mem_server(): +async def server(): loop = asyncio.get_running_loop() stop = loop.create_future() @@ -60,7 +60,7 @@ async def mem_server(): await stop -asyncio.run(mem_server()) +asyncio.run(server()) # First connection may incur non-representative setup costs. From 773f0b6d542307ff94c83d2f6fb9e28786e63dd8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 24 May 2021 21:46:13 +0200 Subject: [PATCH 0814/1539] Address objection to Django integration. --- docs/howto/django.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index e776f5f8c..001d1cb73 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -169,6 +169,11 @@ following code in the JavaScript console of the browser: websocket.onopen = (event) => websocket.send(""); websocket.onmessage = (event) => console.log(event.data); +If you don't want to import your entire Django project into the websockets +server, you can build a separate Django project with ``django.contrib.auth``, +``django-sesame``, a suitable ``User`` model, and a subset of the settings of +the main project. + Stream events ------------- From b455d30252d8eb66771b63c7b9ca26d37f76e4d9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 May 2021 22:31:32 +0200 Subject: [PATCH 0815/1539] Reorder FAQ. Add question on threads. --- docs/howto/faq.rst | 89 +++++++++++++++++++++++++--------------------- 1 file changed, 49 insertions(+), 40 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 0196fd2c0..23964dcfb 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -131,13 +131,17 @@ websockets takes care of closing the connection when the handler exits. How do I run a HTTP server and WebSocket server on the same port? ................................................................. -This isn't supported. +You don't. + +HTTP and WebSockets have widely different operational characteristics. +Running them on the same server is a bad idea. Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the :attr:`~legacy.server.WebSocketServerProtocol.process_request` hook. + If you need more, pick a HTTP server and run it separately. Client side @@ -287,6 +291,18 @@ There are several reasons why long-lived connections may be lost: If you're facing a reproducible issue, :ref:`enable debug logs ` to see when and how connections are closed. +How do I set a timeout on ``recv()``? +..................................... + +Use :func:`~asyncio.wait_for`:: + + await asyncio.wait_for(websocket.recv(), timeout=10) + +This technique works for most APIs, except for asynchronous context managers. +See `issue 574`_. + +.. _issue 574: https://github.com/aaugustin/websockets/issues/574 + How can I pass additional arguments to a custom protocol subclass? .................................................................. @@ -307,31 +323,19 @@ You can bind additional arguments to the protocol factory with This example was for a server. The same pattern applies on a client. -Why do I get the error: ``module 'websockets' has no attribute '...'``? -....................................................................... - -Often, this is because you created a script called ``websockets.py`` in your -current working directory. Then ``import websockets`` imports this module -instead of the websockets library. - -Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? -............................................................................ - -No, there aren't. +How do I keep idle connections open? +.................................... -websockets provides high-level, coroutine-based APIs. Compared to callbacks, -coroutines make it easier to manage control flow in concurrent code. +websockets sends pings at 20 seconds intervals to keep the connection open. -If you prefer callback-based APIs, you should use another library. +In closes the connection if it doesn't get a pong within 20 seconds. -Can I use websockets synchronously, without ``async`` / ``await``? -.................................................................. +You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. -You can convert every asynchronous call to a synchronous call by wrapping it -in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this -is deprecated as of Python 3.10. +How do I respond to pings? +.......................... -If this turns out to be impractical, you should use another library. +websockets takes care of responding to pings with pongs. Miscellaneous ------------- @@ -344,31 +348,24 @@ websockets doesn't have built-in publish / subscribe for these use cases. Depending on the scale of your service, a simple in-memory implementation may do the job or you may need an external publish / subscribe component. -How do I set a timeout on ``recv()``? -..................................... - -Use :func:`~asyncio.wait_for`:: - - await asyncio.wait_for(websocket.recv(), timeout=10) - -This technique works for most APIs, except for asynchronous context managers. -See `issue 574`_. - -.. _issue 574: https://github.com/aaugustin/websockets/issues/574 +Can I use websockets synchronously, without ``async`` / ``await``? +.................................................................. -How do I keep idle connections open? -.................................... +You can convert every asynchronous call to a synchronous call by wrapping it +in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this +is deprecated as of Python 3.10. -websockets sends pings at 20 seconds intervals to keep the connection open. +If this turns out to be impractical, you should use another library. -In closes the connection if it doesn't get a pong within 20 seconds. +Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? +............................................................................ -You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. +No, there aren't. -How do I respond to pings? -.......................... +websockets provides high-level, coroutine-based APIs. Compared to callbacks, +coroutines make it easier to manage control flow in concurrent code. -websockets takes care of responding to pings with pongs. +If you prefer callback-based APIs, you should use another library. Is there a Python 2 version? ............................ @@ -379,4 +376,16 @@ Python 2 reached end of life on January 1st, 2020. Before that date, websockets required asyncio and therefore Python 3. +Why do I get the error: ``module 'websockets' has no attribute '...'``? +....................................................................... + +Often, this is because you created a script called ``websockets.py`` in your +current working directory. Then ``import websockets`` imports this module +instead of the websockets library. + +I'm having problems with threads +................................ + +You shouldn't use threads. Use tasks instead. +:func:`~asyncio.AbstractEventLoop.call_soon_threadsafe` may help. From 5e69983096359fdf87d26afa7b5143badfe2140e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 May 2021 23:00:21 +0200 Subject: [PATCH 0816/1539] Add deployment doc. Fix #445. --- docs/howto/deployment.rst | 111 -------------------- docs/howto/index.rst | 6 +- docs/spelling_wordlist.txt | 7 ++ docs/topics/deployment.rst | 181 +++++++++++++++++++++++++++++++++ docs/topics/deployment.svg | 63 ++++++++++++ docs/topics/index.rst | 2 + docs/topics/memory.rst | 43 ++++++++ docs/topics/security.rst | 4 +- example/health_check_server.py | 4 +- example/shutdown_server.py | 1 + 10 files changed, 303 insertions(+), 119 deletions(-) delete mode 100644 docs/howto/deployment.rst create mode 100644 docs/topics/deployment.rst create mode 100644 docs/topics/deployment.svg create mode 100644 docs/topics/memory.rst diff --git a/docs/howto/deployment.rst b/docs/howto/deployment.rst deleted file mode 100644 index d8264b2cc..000000000 --- a/docs/howto/deployment.rst +++ /dev/null @@ -1,111 +0,0 @@ -Deployment -========== - -.. currentmodule:: websockets - -Application server ------------------- - -The author of websockets isn't aware of best practices for deploying network -services based on :mod:`asyncio`, let alone application servers. - -You can run a script similar to the :ref:`server example `, -inside a supervisor if you deem that useful. - -You can also add a wrapper to daemonize the process. Third-party libraries -provide solutions for that. - -If you can share knowledge on this topic, please file an issue_. Thanks! - -.. _issue: https://github.com/aaugustin/websockets/issues/new - -Graceful shutdown ------------------ - -You may want to close connections gracefully when shutting down the server, -perhaps after executing some cleanup logic. There are two ways to achieve this -with the object returned by :func:`~legacy.server.serve`: - -- using it as a asynchronous context manager, or -- calling its ``close()`` method, then waiting for its ``wait_closed()`` - method to complete. - -On Unix systems, shutdown is usually triggered by sending a signal. - -Here's a full example for handling SIGTERM on Unix: - -.. literalinclude:: ../../example/shutdown_server.py - :emphasize-lines: 12-15,17 - -This example is easily adapted to handle other signals. If you override the -default handler for SIGINT, which raises :exc:`KeyboardInterrupt`, be aware -that you won't be able to interrupt a program with Ctrl-C anymore when it's -stuck in a loop. - -It's more difficult to achieve the same effect on Windows. Some third-party -projects try to help with this problem. - -If your server doesn't run in the main thread, look at -:func:`~asyncio.AbstractEventLoop.call_soon_threadsafe`. - -Memory usage ------------- - -.. _memory-usage: - -In most cases, memory usage of a WebSocket server is proportional to the -number of open connections. When a server handles thousands of connections, -memory usage can become a bottleneck. - -Memory usage of a single connection is the sum of: - -1. the baseline amount of memory websockets requires for each connection, -2. the amount of data held in buffers before the application processes it, -3. any additional memory allocated by the application itself. - -Baseline -........ - -Compression settings are the main factor affecting the baseline amount of -memory used by each connection. - -Read to the topic guide on :doc:`../topics/compression` to learn more about -tuning compression settings. - -Buffers -....... - -Under normal circumstances, buffers are almost always empty. - -Under high load, if a server receives more messages than it can process, -bufferbloat can result in excessive memory use. - -By default websockets has generous limits. It is strongly recommended to adapt -them to your application. When you call :func:`~legacy.server.serve`: - -- Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of - messages your application generates. -- Set ``max_queue`` (default: 32) to the maximum number of messages your - application expects to receive faster than it can process them. The queue - provides burst tolerance without slowing down the TCP connection. - -Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: -64 KiB) to reduce the size of buffers for incoming and outgoing data. - -The design document provides :ref:`more details about buffers`. - -Port sharing ------------- - -The WebSocket protocol is an extension of HTTP/1.1. It can be tempting to -serve both HTTP and WebSocket on the same port. - -The author of websockets doesn't think that's a good idea, due to the widely -different operational characteristics of HTTP and WebSocket. - -websockets provide minimal support for responding to HTTP requests with the -:meth:`~legacy.server.WebSocketServerProtocol.process_request` hook. Typical -use cases include health checks. Here's an example: - -.. literalinclude:: ../../example/health_check_server.py - :emphasize-lines: 9-11,20 diff --git a/docs/howto/index.rst b/docs/howto/index.rst index b6cff7237..38073a41e 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -24,13 +24,13 @@ features, which websockets supports fully. extensions -Once your application is ready, learn how to deploy it with a convenient, -optimized, and secure setup. +.. _deployment-howto: + +Once your application is ready, learn how to deploy it on various platforms. .. toctree:: :maxdepth: 2 - deployment heroku kubernetes supervisor diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 030917491..b7493bba6 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -2,6 +2,7 @@ api attr augustin auth +autoscaler awaitable aymeric backend @@ -25,13 +26,17 @@ datastructures django dyno fractalideas +gunicorn +hypercorn iframe IPv +istio iterable keepalive KiB Kubernetes lifecycle +linkerd liveness lookups MiB @@ -60,10 +65,12 @@ unparse unregister uple username +uvicorn virtualenv WebSocket websocket websockets ws +wsgi wss www diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst new file mode 100644 index 000000000..a7b4bd3ad --- /dev/null +++ b/docs/topics/deployment.rst @@ -0,0 +1,181 @@ +Deployment +========== + +.. currentmodule:: websockets + +When you deploy your websockets server to production, at a high level, your +architecture will almost certainly look like the following diagram: + +.. image:: deployment.svg + +The basic unit for scaling a websockets server is "one server process". Each +blue box in the diagram represents one server process. + +There's more variation in routing. While the routing layer is shown as one big +box, it is likely to involve several subsystems. + +When you design a deployment, your should consider two questions: + +1. How will I run the appropriate number of server processes? +2. How will I route incoming connections to these processes? + +These questions are strongly related. There's a wide range of acceptable +answers, depending on your goals and your constraints. + +You can find a few concrete examples in the :ref:`deployment how-to guides +`. + +Running server processes +------------------------ + +How many processes do I need? +............................. + +Typically, one server process will manage a few hundreds or thousands +connections, depending on the frequency of messages and the amount of work +they require. + +CPU and memory usage increase with the number of connections to the server. + +Often CPU is the limiting factor. If a server process goes to 100% CPU, then +you reached the limit. How much headroom you want to keep is up to you. + +Once you know how many connections a server process can manage and how many +connections you need to handle, you can calculate how many processes to run. + +You can also automate this calculation by configuring an autoscaler to keep +CPU usage or connection count within acceptable limits. + +Don't scale with threads. Threads doesn't make sense for a server built with +:mod:`asyncio`. + +How do I run processes? +....................... + +Most solutions for running multiple instances of a server process fall into +one of these three buckets: + +1. Running N processes on a platform: + + * a Kubernetes Deployment + + * its equivalent on a Platform as a Service provider + +2. Running N servers: + + * an AWS Auto Scaling group, a GCP Managed instance group, etc. + + * a fixed set of long-lived servers + +3. Running N processes on a server: + + * preferrably via a process manager or supervisor + +Option 1 is easiest of you have access to such a platform. + +Option 2 almost always combines with option 3. + +How do I start a process? +......................... + +Run a Python program that invokes :func:`~serve`. That's it. + +Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're +alternatives to websockets, not complements. + +Don't run a WSGI server such as Gunicorn, Waitress, or mod_wsgi. They aren't +designed to run WebSocket applications. + +Applications servers handle network connections and expose a Python API. You +don't need one because websockets handles network connections directly. + +How do I stop a process? +........................ + +Process managers send the SIGTERM signal to terminate processes. Catch this +signal and exit the server to ensure a graceful shutdown. + +Here's an example: + +.. literalinclude:: ../../example/shutdown_server.py + :emphasize-lines: 12-15,18 + +When exiting the context manager, :func:`~server.serve` closes all connections +with code 1001 (going away). As a consequence: + +* If the connection handler is awaiting + :meth:`~server.WebSocketServerProtocol.recv`, it receives a + :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception + and clean up before exiting. + +* Otherwise, it should be waiting on + :meth:`~server.WebSocketServerProtocol.wait_closed`, so it can receive the + :exc:`~exceptions.ConnectionClosedOK` exception and exit. + +This example is easily adapted to handle other signals. + +If you override the default signal handler for SIGINT, which raises +:exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a +program with Ctrl-C anymore when it's stuck in a loop. + +Routing connections +------------------- + +What does routing involve? +.......................... + +Since the routing layer is directly exposed to the Internet, it should provide +appropriate protection against threats ranging from Internet background noise +to targeted attacks. + +You should always secure WebSocket connections with TLS. Since the routing +layer carries the public domain name, it should terminate TLS connections. + +Finally, it must route connections to the server processes, balancing new +connections across them. + +How do I route connections? +........................... + +Here are typical solutions for load balancing, matched to ways of running +processes: + +1. If you're running on a platform, it comes with a routing layer: + + * a Kubernetes Ingress and Service + + * a service mesh: Istio, Consul, Linkerd, etc. + + * the routing mesh of a Platform as a Service + +2. If you're running N servers, you may load balance with: + + * a cloud load balancer: AWS Elastic Load Balancing, GCP Cloud Load + Balancing, etc. + + * A software load balancer: HAProxy, NGINX, etc. + +3. If you're running N processes on a server, you may load balance with: + + * A software load balancer: HAProxy, NGINX, etc. + + * The operating system — all processes listen on the same port + +You may trust the load balancer to handle encryption and to provide security. +You may add another layer in front of the load balancer for these purposes. + +There are many possibilities. Don't add layers that you don't need, though. + +How do I implement a health check? +.................................. + +Load balancers need a way to check whether server processes are up and running +to avoid routing connections to a non-functional backend. + +websockets provide minimal support for responding to HTTP requests with the +:meth:`~server.WebSocketServerProtocol.process_request` hook. + +Here's an example: + +.. literalinclude:: ../../example/health_check_server.py + :emphasize-lines: 7-9,18 diff --git a/docs/topics/deployment.svg b/docs/topics/deployment.svg new file mode 100644 index 000000000..fbacb18c4 --- /dev/null +++ b/docs/topics/deployment.svg @@ -0,0 +1,63 @@ +Internetwebsocketswebsocketswebsocketsrouting \ No newline at end of file diff --git a/docs/topics/index.rst b/docs/topics/index.rst index b18434a39..41bdc3051 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -4,7 +4,9 @@ Topics .. toctree:: :maxdepth: 2 + deployment authentication compression design + memory security diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst new file mode 100644 index 000000000..e5b9ed9a2 --- /dev/null +++ b/docs/topics/memory.rst @@ -0,0 +1,43 @@ +Memory usage +============ + +In most cases, memory usage of a WebSocket server is proportional to the +number of open connections. When a server handles thousands of connections, +memory usage can become a bottleneck. + +Memory usage of a single connection is the sum of: + +1. the baseline amount of memory websockets requires for each connection, +2. the amount of data held in buffers before the application processes it, +3. any additional memory allocated by the application itself. + +Baseline +-------- + +Compression settings are the main factor affecting the baseline amount of +memory used by each connection. + +Read to the topic guide on :doc:`../topics/compression` to learn more about +tuning compression settings. + +Buffers +------- + +Under normal circumstances, buffers are almost always empty. + +Under high load, if a server receives more messages than it can process, +bufferbloat can result in excessive memory use. + +By default websockets has generous limits. It is strongly recommended to adapt +them to your application. When you call :func:`~legacy.server.serve`: + +- Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of + messages your application generates. +- Set ``max_queue`` (default: 32) to the maximum number of messages your + application expects to receive faster than it can process them. The queue + provides burst tolerance without slowing down the TCP connection. + +Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: +64 KiB) to reduce the size of buffers for incoming and outgoing data. + +The design document provides :ref:`more details about buffers`. diff --git a/docs/topics/security.rst b/docs/topics/security.rst index 39de08120..6c541db06 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -24,8 +24,8 @@ With the default settings, opening a connection uses 70 KiB of memory. Sending some highly compressed messages could use up to 128 MiB of memory with an amplification factor of 1000 between network traffic and memory use. -Configuring a server to :ref:`optimize memory usage ` will -improve security in addition to improving performance. +Configuring a server to :doc:`memory` will improve security in addition to +improving performance. Other limits ------------ diff --git a/example/health_check_server.py b/example/health_check_server.py index c5bb6d5ab..2ca185cde 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -1,13 +1,11 @@ #!/usr/bin/env python -# WS echo server with HTTP endpoint at /health/ - import asyncio import http import websockets async def health_check(path, request_headers): - if path == "/health/": + if path == "/healthz": return http.HTTPStatus.OK, [], b"OK\n" async def echo(websocket, path): diff --git a/example/shutdown_server.py b/example/shutdown_server.py index 1ae44af1e..cabba4014 100755 --- a/example/shutdown_server.py +++ b/example/shutdown_server.py @@ -13,6 +13,7 @@ async def server(): loop = asyncio.get_running_loop() stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + async with websockets.serve(echo, "localhost", 8765): await stop From dfecbd0307a94bc8704f895b66167e8838c222e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 13:23:30 +0200 Subject: [PATCH 0817/1539] Prep for releasing 9.1 with security fix. --- docs/project/changelog.rst | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a44dd0418..dcaa06e9d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -45,11 +45,22 @@ They may change at any time. * Optimized default compression settings to reduce memory usage. -* Protected against timing attacks on HTTP Basic Auth. - * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. +9.1 +... + +*May 27, 2021* + +.. note:: + + **Version 9.1 fixes a security issue introduced in version 8.0.** + + Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords. + + .. _CVE-2018-1000518: https://nvd.nist.gov/vuln/detail/CVE-2018-1000518 + 9.0.2 ..... From 547a26b685d08cac0aa64e5e65f7867ac0ea9bc0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 18:51:27 +0200 Subject: [PATCH 0818/1539] Use constant-time comparison for passwords. Backport of c91b4c2a and dfecbd03. --- docs/changelog.rst | 6 ++++++ src/websockets/legacy/auth.py | 28 +++++++++++++++------------- tests/legacy/test_auth.py | 11 +++++++++-- 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 1064af736..f3e1acf08 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -30,6 +30,12 @@ They may change at any time. *In development* +.. note:: + + **Version 9.1 fixes a security issue introduced in version 8.0.** + + Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords. + 9.0.2 ..... diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e0beede57..80ceff28d 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -6,6 +6,7 @@ import functools +import hmac import http from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast @@ -132,24 +133,23 @@ def basic_auth_protocol_factory( if credentials is not None: if is_credentials(credentials): - - async def check_credentials(username: str, password: str) -> bool: - return (username, password) == credentials - + credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): credentials_list = list(credentials) - if all(is_credentials(item) for item in credentials_list): - credentials_dict = dict(credentials_list) - - async def check_credentials(username: str, password: str) -> bool: - return credentials_dict.get(username) == password - - else: + if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") - else: raise TypeError(f"invalid credentials argument: {credentials}") + credentials_dict = dict(credentials_list) + + async def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + if create_protocol is None: # Not sure why mypy cannot figure this out. create_protocol = cast( @@ -158,5 +158,7 @@ async def check_credentials(username: str, password: str) -> bool: ) return functools.partial( - create_protocol, realm=realm, check_credentials=check_credentials + create_protocol, + realm=realm, + check_credentials=check_credentials, ) diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index bb8c6a6eb..3d8eb90d7 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -1,3 +1,4 @@ +import hmac import unittest import urllib.error @@ -76,7 +77,7 @@ def test_basic_auth_bad_multiple_credentials(self): ) async def check_credentials(username, password): - return password == "iloveyou" + return hmac.compare_digest(password, "iloveyou") create_protocol_check_credentials = basic_auth_protocol_factory( realm="auth-tests", @@ -140,7 +141,13 @@ def test_basic_auth_unsupported_credentials_details(self): self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n") @with_server(create_protocol=create_protocol) - def test_basic_auth_invalid_credentials(self): + def test_basic_auth_invalid_username(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(user_info=("goodbye", "iloveyou")) + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_invalid_password(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(user_info=("hello", "ihateyou")) self.assertEqual(raised.exception.status_code, 401) From d0f328888f3e695aa64d78dcf48af4ece219221b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 13:32:46 +0200 Subject: [PATCH 0819/1539] Bump version number. --- docs/changelog.rst | 2 +- docs/conf.py | 4 ++-- src/websockets/version.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index f3e1acf08..a82008a49 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 9.1 ... -*In development* +*May 27, 2021* .. note:: diff --git a/docs/conf.py b/docs/conf.py index dad7475f7..2246c0287 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,9 +59,9 @@ # built documents. # # The short X.Y version. -version = '9.0' +version = '9.1' # The full version, including alpha/beta/rc tags. -release = '9.0' +release = '9.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/src/websockets/version.py b/src/websockets/version.py index 02dbe9d3c..a7901ef92 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "9.0.2" +version = "9.1" From ac85304980285079ff871cde728964d0acfde569 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 21:27:27 +0200 Subject: [PATCH 0820/1539] Fix link to changelog. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 8c7a4984a..0625798a2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,7 +4,7 @@ python-tag = py37.py38.py39.py310 [metadata] license_file = LICENSE project_urls = - Changelog = https://websockets.readthedocs.io/en/stable/changelog.html + Changelog = https://websockets.readthedocs.io/en/stable/project/changelog.html Documentation = https://websockets.readthedocs.io/ Funding = https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme Tracker = https://github.com/aaugustin/websockets/issues From 17210674d6d2f0987a0cd74d0d6ac37d88d28977 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 22:52:49 +0200 Subject: [PATCH 0821/1539] Add deployment guide for nginx. --- docs/howto/index.rst | 1 + docs/howto/nginx.rst | 84 +++++++++++++++++++++++ example/deployment/nginx/app.py | 29 ++++++++ example/deployment/nginx/nginx.conf | 25 +++++++ example/deployment/nginx/supervisord.conf | 7 ++ 5 files changed, 146 insertions(+) create mode 100644 docs/howto/nginx.rst create mode 100644 example/deployment/nginx/app.py create mode 100644 example/deployment/nginx/nginx.conf create mode 100644 example/deployment/nginx/supervisord.conf diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 38073a41e..eab6cab6e 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -34,3 +34,4 @@ Once your application is ready, learn how to deploy it on various platforms. heroku kubernetes supervisor + nginx diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst new file mode 100644 index 000000000..cb4dc83f2 --- /dev/null +++ b/docs/howto/nginx.rst @@ -0,0 +1,84 @@ +Deploy behind nginx +=================== + +This guide demonstrates a way to load balance connections across multiple +websockets server processes running on the same machine with nginx_. + +We'll run server processes with Supervisor as described in :doc:`this guide +`. + +.. _nginx: https://nginx.org/ + +Run server processes +-------------------- + +Save this app to ``app.py``: + +.. literalinclude:: ../../example/deployment/nginx/app.py + :emphasize-lines: 21,23 + +We'd like to nginx to connect to websockets servers via Unix sockets in order +to avoid the overhead of TCP for communicating between processes running in +the same OS. + +We start the app with :func:`~websockets.server.unix_serve`. Each server +process listens on a different socket thanks to an environment variable set +by Supervisor to a different value. + +Save this configuration to ``supervisord.conf``: + +.. literalinclude:: ../../example/deployment/nginx/supervisord.conf + +This configuration runs four instances of the app. + +Install Supervisor and run it: + +.. code:: console + + $ supervisord -c supervisord.conf -n + +Configure and run nginx +----------------------- + +Here's a simple nginx configuration to load balance connections across four +processes: + +.. literalinclude:: ../../example/deployment/nginx/nginx.conf + +We set ``daemon off`` so we can run nginx in the foreground for testing. + +Then we combine the `WebSocket proxying`_ and `load balancing`_ guides: + +* The WebSocket protocol requires HTTP/1.1. We must set the HTTP protocol + version to 1.1, else nginx defaults to HTTP/1.0 for proxying. + +* The WebSocket handshake involves the ``Connection`` and ``Upgrade`` HTTP + headers. We must pass them to the upstream explicitly, else nginx drops + them because they're hop-by-hop headers. + + We deviate from the `WebSocket proxying`_ guide because its example adds a + ``Connection: Upgrade`` header to every upstream request, even if the + original request didn't contain that header. + +* In the upstream configuration, we set the load balancing method to + ``least_conn`` in order to balance the number of active connections across + servers. This is best for long running connections. + +.. _WebSocket proxying: http://nginx.org/en/docs/http/websocket.html +.. _load balancing: http://nginx.org/en/docs/http/load_balancing.html + +Save the configuration to ``nginx.conf``, install nginx, and run it: + +.. code:: console + + $ nginx -c nginx.conf -p . + +You can confirm that nginx proxies connections properly: + +.. code:: console + + $ PYTHONPATH=src python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + > Hello! + < Hello! + Connection closed: code = 1000 (OK), no reason. diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py new file mode 100644 index 000000000..ad42a8b3e --- /dev/null +++ b/example/deployment/nginx/app.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import asyncio +import os +import signal + +import websockets + + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.unix_serve( + echo, + path=f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock", + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/nginx/nginx.conf b/example/deployment/nginx/nginx.conf new file mode 100644 index 000000000..67aa0086d --- /dev/null +++ b/example/deployment/nginx/nginx.conf @@ -0,0 +1,25 @@ +daemon off; + +events { +} + +http { + server { + listen localhost:8080; + + location / { + proxy_http_version 1.1; + proxy_pass http://websocket; + proxy_set_header Connection $http_connection; + proxy_set_header Upgrade $http_upgrade; + } + } + + upstream websocket { + least_conn; + server unix:websockets-test_00.sock; + server unix:websockets-test_01.sock; + server unix:websockets-test_02.sock; + server unix:websockets-test_03.sock; + } +} diff --git a/example/deployment/nginx/supervisord.conf b/example/deployment/nginx/supervisord.conf new file mode 100644 index 000000000..76a664d91 --- /dev/null +++ b/example/deployment/nginx/supervisord.conf @@ -0,0 +1,7 @@ +[supervisord] + +[program:websockets-test] +command = python app.py +process_name = %(program_name)s_%(process_num)02d +numprocs = 4 +autorestart = true From 2990cf761177073e6d741e1ef81dc1d6d3c5dba8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 23:05:18 +0200 Subject: [PATCH 0822/1539] Add HAProxy deployment guide. Ref #445. --- docs/howto/haproxy.rst | 61 +++++++++++++++++++++ docs/howto/index.rst | 1 + docs/spelling_wordlist.txt | 2 + docs/topics/deployment.rst | 2 +- example/deployment/haproxy/app.py | 30 ++++++++++ example/deployment/haproxy/haproxy.cfg | 17 ++++++ example/deployment/haproxy/supervisord.conf | 7 +++ 7 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 docs/howto/haproxy.rst create mode 100644 example/deployment/haproxy/app.py create mode 100644 example/deployment/haproxy/haproxy.cfg create mode 100644 example/deployment/haproxy/supervisord.conf diff --git a/docs/howto/haproxy.rst b/docs/howto/haproxy.rst new file mode 100644 index 000000000..d520d278a --- /dev/null +++ b/docs/howto/haproxy.rst @@ -0,0 +1,61 @@ +Deploy behind HAProxy +===================== + +This guide demonstrates a way to load balance connections across multiple +websockets server processes running on the same machine with HAProxy_. + +We'll run server processes with Supervisor as described in :doc:`this guide +`. + +.. _HAProxy: https://www.haproxy.org/ + +Run server processes +-------------------- + +Save this app to ``app.py``: + +.. literalinclude:: ../../example/deployment/haproxy/app.py + :emphasize-lines: 24 + +Each server process listens on a different port by extracting an incremental +index from an environment variable set by Supervisor. + +Save this configuration to ``supervisord.conf``: + +.. literalinclude:: ../../example/deployment/haproxy/supervisord.conf + +This configuration runs four instances of the app. + +Install Supervisor and run it: + +.. code:: console + + $ supervisord -c supervisord.conf -n + +Configure and run HAProxy +------------------------- + +Here's a simple HAProxy configuration to load balance connections across four +processes: + +.. literalinclude:: ../../example/deployment/haproxy/haproxy.cfg + +In the backend configuration, we set the load balancing method to +``leastconn`` in order to balance the number of active connections across +servers. This is best for long running connections. + +Save the configuration to ``haproxy.cfg``, install HAProxy, and run it: + +.. code:: console + + $ haproxy -f haproxy.cfg + +You can confirm that HAProxy proxies connections properly: + +.. code:: console + + $ PYTHONPATH=src python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + > Hello! + < Hello! + Connection closed: code = 1000 (OK), no reason. diff --git a/docs/howto/index.rst b/docs/howto/index.rst index eab6cab6e..ac1182705 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -35,3 +35,4 @@ Once your application is ready, learn how to deploy it on various platforms. kubernetes supervisor nginx + haproxy diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b7493bba6..676e13afc 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -27,6 +27,7 @@ django dyno fractalideas gunicorn +haproxy hypercorn iframe IPv @@ -48,6 +49,7 @@ permessage pid pong pongs +proxying pythonic redis runtime diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index a7b4bd3ad..d30c5568e 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -69,7 +69,7 @@ one of these three buckets: 3. Running N processes on a server: - * preferrably via a process manager or supervisor + * preferably via a process manager or supervisor Option 1 is easiest of you have access to such a platform. diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py new file mode 100644 index 000000000..2b24790dd --- /dev/null +++ b/example/deployment/haproxy/app.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +import asyncio +import os +import signal + +import websockets + + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, + host="localhost", + port=8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]), + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/haproxy/haproxy.cfg b/example/deployment/haproxy/haproxy.cfg new file mode 100644 index 000000000..e63727d1c --- /dev/null +++ b/example/deployment/haproxy/haproxy.cfg @@ -0,0 +1,17 @@ +defaults + mode http + timeout connect 10s + timeout client 30s + timeout server 30s + +frontend websocket + bind localhost:8080 + default_backend websocket + +backend websocket + balance leastconn + server websockets-test_00 localhost:8000 + server websockets-test_01 localhost:8001 + server websockets-test_02 localhost:8002 + server websockets-test_03 localhost:8003 + diff --git a/example/deployment/haproxy/supervisord.conf b/example/deployment/haproxy/supervisord.conf new file mode 100644 index 000000000..76a664d91 --- /dev/null +++ b/example/deployment/haproxy/supervisord.conf @@ -0,0 +1,7 @@ +[supervisord] + +[program:websockets-test] +command = python app.py +process_name = %(program_name)s_%(process_num)02d +numprocs = 4 +autorestart = true From 58ac5228dabae9a734e2ee360be36b8db5f2e8c2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 28 May 2021 21:56:17 +0200 Subject: [PATCH 0823/1539] Explain keepalives. Ref #919. --- docs/spelling_wordlist.txt | 1 + docs/topics/index.rst | 1 + docs/topics/timeouts.rst | 66 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 docs/topics/timeouts.rst diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 676e13afc..4fa44fbaa 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -52,6 +52,7 @@ pongs proxying pythonic redis +retransmit runtime scalable serializers diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 41bdc3051..9269c585a 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -7,6 +7,7 @@ Topics deployment authentication compression + timeouts design memory security diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst new file mode 100644 index 000000000..828d22a0b --- /dev/null +++ b/docs/topics/timeouts.rst @@ -0,0 +1,66 @@ +Timeouts +======== + +Since the WebSocket protocol is intended for real-time communications over +long-lived connections, it is desirable to ensure that connections don't +break, and if they do, to report the problem quickly. + +WebSocket is built on top of HTTP/1.1 where connections are short-lived, even +with ``Connection: keep-alive``. Typically, HTTP/1.1 infrastructure closes +idle connections after 30 to 120 seconds. + +As a consequence, proxies may terminate WebSocket connections prematurely, +when no message was exchanged in 30 seconds. + +In order to avoid this problem, websockets implements a keepalive mechanism +based on WebSocket Ping and Pong frames. Ping and Pong are designed for this +purpose. + +By default, websockets waits 20 seconds, then sends a Ping frame, and expects +to receive the corresponding Pong frame within 20 seconds. Else, it considers +the connection broken and closes it. + +Timings are configurable with ``ping_interval`` and ``ping_timeout``. + +While WebSocket runs on top of TCP, websockets doesn't rely on TCP keepalive +because it's disabled by default and, if enabled, the default interval is no +less than two hours, which doesn't meet requirements. + +Latency between a client and a server may increase for two reasons: + +* Network connectivity is poor. When network packets are lost, TCP attempts to + retransmit them, which manifests as latency. Excessive packet loss makes + the connection unusable in practice. At some point, timing out is a + reasonable choice. + +* Traffic is high. For example, if a client sends messages on the connection + faster than a server can process them, this manifests as latency as well, + because data is waiting in flight, mostly in OS buffers. + + If the server is more than 20 seconds behind, it doesn't see the Pong before + the default timeout elapses. As a consequence, it closes the connection. + This is a reasonable choice to prevent overload. + + If traffic spikes cause unwanted timeouts and you're confident that the + server will catch up eventually, you can increase ``ping_timeout`` or you + can disable keepalive entirely with ``ping_interval=None``. + + The same reasoning applies to situations where the server sends more traffic + than the client can accept. + +You can monitor latency as follows: + +.. code:: python + + import asyncio + import logging + import time + + async def log_latency(websocket, logger): + t0 = time.perf_counter() + pong_waiter = await websocket.ping() + await pong_waiter + t1 = time.perf_counter() + logger.info("Connection latency: %.3f seconds", t1 - t0) + + asyncio.create_task(log_latency(websocket, logging.getLogger())) From 820dad7b339a6dd653938d64b22cf732a9415133 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 May 2021 07:02:33 +0200 Subject: [PATCH 0824/1539] Add ToC to FAQ. --- docs/howto/faq.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 23964dcfb..ce3704a2e 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -11,6 +11,9 @@ FAQ .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html +.. contents:: + :local: + Server side ----------- From f8e081da8d1ee76da19b64cadcb2adf1eaa140d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 May 2021 21:04:26 +0200 Subject: [PATCH 0825/1539] isort now supports import at the bottom of modules. --- src/websockets/frames.py | 2 +- src/websockets/legacy/framing.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 6e5ef1b73..62bea6814 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -338,4 +338,4 @@ def check_close(code: int) -> None: # at the bottom to allow circular import, because Extension depends on Frame -from . import extensions # isort:skip # noqa +from . import extensions # noqa diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 627e6922c..e947c9383 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -124,12 +124,11 @@ def write( write(self.serialize(mask=mask, extensions=extensions)) -# Backwards compatibility with previously documented public APIs -from ..frames import parse_close # isort:skip # noqa -from ..frames import prepare_ctrl as encode_data # isort:skip # noqa -from ..frames import prepare_data # isort:skip # noqa -from ..frames import serialize_close # isort:skip # noqa - - # at the bottom to allow circular import, because Extension depends on Frame -from .. import extensions # isort:skip # noqa +from .. import extensions # noqa + +# Backwards compatibility with previously documented public APIs +from ..frames import parse_close # noqa +from ..frames import prepare_data # noqa +from ..frames import serialize_close # noqa +from ..frames import prepare_ctrl as encode_data # noqa From 07e8a636eeb7187532175b27e92a8c67564f631a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 15:17:04 +0200 Subject: [PATCH 0826/1539] Check that GET requests don't have a body. It is technically possible but doesn't have a meaning in general. --- src/websockets/http11.py | 8 ++++++++ tests/test_http11.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 0754ddabb..6f3cbccc4 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -99,6 +99,14 @@ def parse( headers = yield from parse_headers(read_line) + # https://tools.ietf.org/html/rfc7230#section-3.3.3 + + if "Transfer-Encoding" in headers: + raise NotImplementedError("transfer codings aren't supported") + + if "Content-Length" in headers: + raise ValueError("unsupported request body") + return cls(path, headers) def serialize(self) -> bytes: diff --git a/tests/test_http11.py b/tests/test_http11.py index 1cca2053f..e73365cf0 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -77,6 +77,24 @@ def test_parse_invalid_header(self): "invalid HTTP header line: Oops", ) + def test_parse_body(self): + self.reader.feed_data(b"GET / HTTP/1.1\r\nContent-Length: 3\r\n\r\nYo\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "unsupported request body", + ) + + def test_parse_body_with_transfer_encoding(self): + self.reader.feed_data(b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + with self.assertRaises(NotImplementedError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "transfer codings aren't supported", + ) + def test_serialize(self): # Example from the protocol overview in RFC 6455 request = Request( From 58193e088ace6ce7c4c2dfcb6deea2af7cbe8601 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 17:20:56 +0200 Subject: [PATCH 0827/1539] Add UUID to connections. --- docs/reference/client.rst | 6 ++++++ docs/reference/server.rst | 6 ++++++ src/websockets/connection.py | 4 ++++ src/websockets/legacy/protocol.py | 4 ++++ 4 files changed, 20 insertions(+) diff --git a/docs/reference/client.rst b/docs/reference/client.rst index 374a8197e..c2bb6266b 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -17,6 +17,12 @@ Client .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. attribute:: id + + UUID for the connection. + + Useful for identifying connections in logs. + .. autoattribute:: local_address .. autoattribute:: remote_address diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 1a2dd1c88..898403070 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -27,6 +27,12 @@ Server .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + .. attribute:: id + + UUID for the connection. + + Useful for identifying connections in logs. + .. autoattribute:: local_address .. autoattribute:: remote_address diff --git a/src/websockets/connection.py b/src/websockets/connection.py index aeb774f00..4bd09e282 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,5 +1,6 @@ import enum import logging +import uuid from typing import Generator, List, Optional, Union from .exceptions import InvalidState, PayloadTooBig, ProtocolError @@ -68,6 +69,9 @@ def __init__( state: State = OPEN, max_size: Optional[int] = 2 ** 20, ) -> None: + # Unique identifier. For logs. + self.id = uuid.uuid4() + # Connection side. CLIENT or SERVER. self.side = side diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 4e8958b60..b544b93d0 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -14,6 +14,7 @@ import logging import random import struct +import uuid import warnings from typing import ( Any, @@ -132,6 +133,9 @@ def __init__( self.read_limit = read_limit self.write_limit = write_limit + # Unique identifier. For logs. + self.id = uuid.uuid4() + assert loop is not None # Remove when dropping Python < 3.10 - use get_running_loop instead. self.loop = loop From 2617b2c8f80c19e759a4e1364bc7007ff8d99acd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 May 2021 20:59:26 +0200 Subject: [PATCH 0828/1539] Add human-friendly representation of frames. Fix #765. --- src/websockets/__main__.py | 3 +- src/websockets/exceptions.py | 46 ++----------- src/websockets/frames.py | 89 +++++++++++++++++++++++++ tests/test_frames.py | 126 +++++++++++++++++++++++++++++++++++ 4 files changed, 221 insertions(+), 43 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 46e746da5..4358c323c 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -6,7 +6,8 @@ import threading from typing import Any, Set -from .exceptions import ConnectionClosed, format_close +from .exceptions import ConnectionClosed +from .frames import format_close from .legacy.client import connect diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 9ab9d3ebe..4bd2a41a6 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -68,48 +68,6 @@ class WebSocketException(Exception): """ -# See https://www.iana.org/assignments/websocket/websocket.xhtml -CLOSE_CODES = { - 1000: "OK", - 1001: "going away", - 1002: "protocol error", - 1003: "unsupported type", - # 1004 is reserved - 1005: "no status code [internal]", - 1006: "connection closed abnormally [internal]", - 1007: "invalid data", - 1008: "policy violation", - 1009: "message too big", - 1010: "extension required", - 1011: "unexpected error", - 1012: "service restart", - 1013: "try again later", - 1014: "bad gateway", - 1015: "TLS failure [internal]", -} - - -def format_close(code: int, reason: str) -> str: - """ - Display a human-readable version of the close code and reason. - - """ - if 3000 <= code < 4000: - explanation = "registered" - elif 4000 <= code < 5000: - explanation = "private use" - else: - explanation = CLOSE_CODES.get(code, "unknown") - result = f"code = {code} ({explanation}), " - - if reason: - result += f"reason = {reason}" - else: - result += "no reason" - - return result - - class ConnectionClosed(WebSocketException): """ Raised when trying to interact with a closed connection. @@ -371,3 +329,7 @@ class ProtocolError(WebSocketException): WebSocketProtocolError = ProtocolError # for backwards compatibility + + +# at the bottom to allow circular import, because the frames module imports exceptions +from .frames import format_close # noqa diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 62bea6814..de7fdb941 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -52,6 +52,28 @@ class Opcode(enum.IntEnum): DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG + +# See https://www.iana.org/assignments/websocket/websocket.xhtml +CLOSE_CODES = { + 1000: "OK", + 1001: "going away", + 1002: "protocol error", + 1003: "unsupported type", + # 1004 is reserved + 1005: "no status code [internal]", + 1006: "connection closed abnormally [internal]", + 1007: "invalid data", + 1008: "policy violation", + 1009: "message too big", + 1010: "extension required", + 1011: "unexpected error", + 1012: "service restart", + 1013: "try again later", + 1014: "bad gateway", + 1015: "TLS failure [internal]", +} + + # Close code that are allowed in a close frame. # Using a set optimizes `code in EXTERNAL_CLOSE_CODES`. EXTERNAL_CLOSE_CODES = { @@ -96,6 +118,52 @@ class Frame(NamedTuple): rsv2: bool = False rsv3: bool = False + def __str__(self) -> str: + """ + Return a human-readable represention of a frame. + + """ + coding = None + length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}" + non_final = "" if self.fin else "continued" + + if self.opcode is OP_TEXT: + # Decoding only the beginning and the end is needlessly hard. + # Decode the entire payload then elide later if necessary. + data = self.data.decode() + elif self.opcode is OP_BINARY: + # We'll show at most the first 16 bytes and the last 8 bytes. + # Encode just what we need, plus two dummy bytes to elide later. + binary = self.data + if len(binary) > 25: + binary = binary[:16] + b"\x00\x00" + binary[-8:] + data = " ".join(f"{byte:02x}" for byte in binary) + elif self.opcode is OP_CLOSE: + code, reason = parse_close(self.data) + data = format_close(code, reason) + elif self.data: + # We don't know if a Continuation frame contains text or binary. + # Ping and Pong frames could contain UTF-8. Attempt to decode as + # UTF-8 and display it as text; fallback to binary. + try: + data = self.data.decode() + coding = "text" + except UnicodeDecodeError: + binary = self.data + if len(binary) > 25: + binary = binary[:16] + b"\x00\x00" + binary[-8:] + data = " ".join(f"{byte:02x}" for byte in binary) + coding = "binary" + else: + data = "" + + if len(data) > 75: + data = data[:48] + "..." + data[-24:] + + metadata = ", ".join(filter(None, [coding, length, non_final])) + + return f"{self.opcode.name} {data} [{metadata}]" + @classmethod def parse( cls, @@ -337,5 +405,26 @@ def check_close(code: int) -> None: raise ProtocolError("invalid status code") +def format_close(code: int, reason: str) -> str: + """ + Display a human-readable version of the close code and reason. + + """ + if 3000 <= code < 4000: + explanation = "registered" + elif 4000 <= code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODES.get(code, "unknown") + result = f"code = {code} ({explanation}), " + + if reason: + result += f"reason = {reason}" + else: + result += "no reason" + + return result + + # at the bottom to allow circular import, because Extension depends on Frame from . import extensions # noqa diff --git a/tests/test_frames.py b/tests/test_frames.py index 13a712322..491386566 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -193,6 +193,132 @@ def decode(frame, *, max_size=None): ) +class StrTests(unittest.TestCase): + def test_cont_text(self): + self.assertEqual( + str(Frame(False, OP_CONT, b" cr\xc3\xa8me")), + "CONT crème [text, 7 bytes, continued]", + ) + + def test_cont_binary(self): + self.assertEqual( + str(Frame(False, OP_CONT, b"\xfc\xfd\xfe\xff")), + "CONT fc fd fe ff [binary, 4 bytes, continued]", + ) + + def test_cont_final_text(self): + self.assertEqual( + str(Frame(True, OP_CONT, b" cr\xc3\xa8me")), + "CONT crème [text, 7 bytes]", + ) + + def test_cont_final_binary(self): + self.assertEqual( + str(Frame(True, OP_CONT, b"\xfc\xfd\xfe\xff")), + "CONT fc fd fe ff [binary, 4 bytes]", + ) + + def test_cont_text_truncated(self): + self.assertEqual( + str(Frame(False, OP_CONT, b"caf\xc3\xa9 " * 16)), + "CONT café café café café café café café café café caf..." + "afé café café café café [text, 96 bytes, continued]", + ) + + def test_cont_binary_truncated(self): + self.assertEqual( + str(Frame(False, OP_CONT, bytes(range(256)))), + "CONT 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." + " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", + ) + + def test_text(self): + self.assertEqual( + str(Frame(True, OP_TEXT, b"caf\xc3\xa9")), + "TEXT café [5 bytes]", + ) + + def test_text_non_final(self): + self.assertEqual( + str(Frame(False, OP_TEXT, b"caf\xc3\xa9")), + "TEXT café [5 bytes, continued]", + ) + + def test_text_truncated(self): + self.assertEqual( + str(Frame(True, OP_TEXT, b"caf\xc3\xa9 " * 16)), + "TEXT café café café café café café café café café caf..." + "afé café café café café [96 bytes]", + ) + + def test_binary(self): + self.assertEqual( + str(Frame(True, OP_BINARY, b"\x00\x01\x02\x03")), + "BINARY 00 01 02 03 [4 bytes]", + ) + + def test_binary_non_final(self): + self.assertEqual( + str(Frame(False, OP_BINARY, b"\x00\x01\x02\x03")), + "BINARY 00 01 02 03 [4 bytes, continued]", + ) + + def test_binary_truncated(self): + self.assertEqual( + str(Frame(True, OP_BINARY, bytes(range(256)))), + "BINARY 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." + " f8 f9 fa fb fc fd fe ff [256 bytes]", + ) + + def test_close(self): + self.assertEqual( + str(Frame(True, OP_CLOSE, b"\x03\xe8")), + "CLOSE code = 1000 (OK), no reason [2 bytes]", + ) + + def test_close_reason(self): + self.assertEqual( + str(Frame(True, OP_CLOSE, b"\x03\xe9Bye!")), + "CLOSE code = 1001 (going away), reason = Bye! [6 bytes]", + ) + + def test_ping(self): + self.assertEqual( + str(Frame(True, OP_PING, b"")), + "PING [0 bytes]", + ) + + def test_ping_text(self): + self.assertEqual( + str(Frame(True, OP_PING, b"ping")), + "PING ping [text, 4 bytes]", + ) + + def test_ping_binary(self): + self.assertEqual( + str(Frame(True, OP_PING, b"\xff\x00\xff\x00")), + "PING ff 00 ff 00 [binary, 4 bytes]", + ) + + def test_pong(self): + self.assertEqual( + str(Frame(True, OP_PONG, b"")), + "PONG [0 bytes]", + ) + + def test_pong_text(self): + self.assertEqual( + str(Frame(True, OP_PONG, b"pong")), + "PONG pong [text, 4 bytes]", + ) + + def test_pong_binary(self): + self.assertEqual( + str(Frame(True, OP_PONG, b"\xff\x00\xff\x00")), + "PONG ff 00 ff 00 [binary, 4 bytes]", + ) + + class PrepareDataTests(unittest.TestCase): def test_prepare_data_str(self): self.assertEqual( From 09f829f66a57ded2832279a18a64257b5fd3f875 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 16:13:43 +0200 Subject: [PATCH 0829/1539] Inject logger in Sans-I/O layer. --- src/websockets/__init__.py | 2 ++ src/websockets/client.py | 18 +++++++++++------- src/websockets/connection.py | 22 ++++++++++++++-------- src/websockets/server.py | 26 +++++++++++++++----------- src/websockets/typing.py | 12 +++++++++++- tests/test_client.py | 9 +++++++++ tests/test_server.py | 9 +++++++++ 7 files changed, 71 insertions(+), 27 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 65d9fb913..f136a4e45 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -27,6 +27,7 @@ "InvalidStatusCode", "InvalidUpgrade", "InvalidURI", + "LoggerLike", "NegotiationError", "Origin", "parse_uri", @@ -92,6 +93,7 @@ "WebSocketServerProtocol": ".legacy.server", "WebSocketServer": ".legacy.server", "Data": ".typing", + "LoggerLike": ".typing", "Origin": ".typing", "ExtensionHeader": ".typing", "ExtensionParameter": ".typing", diff --git a/src/websockets/client.py b/src/websockets/client.py index 0ddf19f00..9d3f0e7f7 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,5 +1,4 @@ import collections -import logging from typing import Generator, List, Optional, Sequence from .connection import CLIENT, CONNECTING, OPEN, Connection @@ -27,6 +26,7 @@ from .typing import ( ConnectionOption, ExtensionHeader, + LoggerLike, Origin, Subprotocol, UpgradeProtocol, @@ -41,8 +41,6 @@ __all__ = ["ClientConnection"] -logger = logging.getLogger(__name__) - class ClientConnection(Connection): def __init__( @@ -53,8 +51,14 @@ def __init__( subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, max_size: Optional[int] = 2 ** 20, + logger: Optional[LoggerLike] = None, ): - super().__init__(side=CLIENT, state=CONNECTING, max_size=max_size) + super().__init__( + side=CLIENT, + state=CONNECTING, + max_size=max_size, + logger=logger, + ) self.wsuri = parse_uri(uri) self.origin = origin self.available_extensions = extensions @@ -271,8 +275,8 @@ def send_request(self, request: Request) -> None: Send a WebSocket handshake request to the server. """ - logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) - logger.debug("%s > %r", self.side, request.headers) + self.logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) + self.logger.debug("%s > %r", self.side, request.headers) self.writes.append(request.serialize()) @@ -285,7 +289,7 @@ def parse(self) -> Generator[None, None, None]: self.process_response(response) except InvalidHandshake as exc: response = response._replace(exception=exc) - logger.debug("Invalid handshake", exc_info=True) + self.logger.debug("Invalid handshake", exc_info=True) else: self.set_state(OPEN) finally: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4bd09e282..4852560e3 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -18,7 +18,7 @@ ) from .http11 import Request, Response from .streams import StreamReader -from .typing import Origin, Subprotocol +from .typing import LoggerLike, Origin, Subprotocol __all__ = [ @@ -28,8 +28,6 @@ "SEND_EOF", ] -logger = logging.getLogger(__name__) - Event = Union[Request, Response, Frame] @@ -68,6 +66,7 @@ def __init__( side: Side, state: State = OPEN, max_size: Optional[int] = 2 ** 20, + logger: Optional[LoggerLike] = None, ) -> None: # Unique identifier. For logs. self.id = uuid.uuid4() @@ -76,12 +75,19 @@ def __init__( self.side = side # Connnection state. CONNECTING and CLOSED states are handled in subclasses. - logger.debug("%s - initial state: %s", self.side, state.name) self.state = state # Maximum size of incoming messages in bytes. self.max_size = max_size + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger(f"websockets.{side.name.lower()}") + self.logger = logger + + # Must wait until we have the logger to log the initial state! + self.logger.debug("%s - initial state: %s", self.side, state.name) + # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. self.cur_size: Optional[int] = None @@ -117,7 +123,7 @@ def __init__( self.parser_exc: Optional[Exception] = None def set_state(self, state: State) -> None: - logger.debug( + self.logger.debug( "%s - state change: %s > %s", self.side, self.state.name, state.name ) self.state = state @@ -286,7 +292,7 @@ def step_parser(self) -> None: self.parser_exc = exc raise except Exception as exc: - logger.error("unexpected exception in parser", exc_info=True) + self.logger.error("unexpected exception in parser", exc_info=True) # Don't include exception details, which may be security-sensitive. self.fail_connection(1011) self.parser_exc = exc @@ -401,7 +407,7 @@ def send_frame(self, frame: Frame) -> None: f"cannot write to a WebSocket in the {self.state.name} state" ) - logger.debug("%s > %r", self.side, frame) + self.logger.debug("%s > %r", self.side, frame) self.writes.append( frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) ) @@ -409,5 +415,5 @@ def send_frame(self, frame: Frame) -> None: def send_eof(self) -> None: assert not self.eof_sent self.eof_sent = True - logger.debug("%s > EOF", self.side) + self.logger.debug("%s > EOF", self.side) self.writes.append(SEND_EOF) diff --git a/src/websockets/server.py b/src/websockets/server.py index f57d36b70..4a76ac886 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -3,7 +3,6 @@ import collections import email.utils import http -import logging from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast from .connection import CONNECTING, OPEN, SERVER, Connection @@ -29,6 +28,7 @@ from .typing import ( ConnectionOption, ExtensionHeader, + LoggerLike, Origin, Subprotocol, UpgradeProtocol, @@ -42,8 +42,6 @@ __all__ = ["ServerConnection"] -logger = logging.getLogger(__name__) - HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] @@ -59,8 +57,14 @@ def __init__( subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, max_size: Optional[int] = 2 ** 20, + logger: Optional[LoggerLike] = None, ): - super().__init__(side=SERVER, state=CONNECTING, max_size=max_size) + super().__init__( + side=SERVER, + state=CONNECTING, + max_size=max_size, + logger=logger, + ) self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols @@ -80,13 +84,13 @@ def accept(self, request: Request) -> Response: try: key, extensions_header, protocol_header = self.process_request(request) except InvalidOrigin as exc: - logger.debug("Invalid origin", exc_info=True) + self.logger.debug("Invalid origin", exc_info=True) return self.reject( http.HTTPStatus.FORBIDDEN, f"Failed to open a WebSocket connection: {exc}.\n", )._replace(exception=exc) except InvalidUpgrade as exc: - logger.debug("Invalid upgrade", exc_info=True) + self.logger.debug("Invalid upgrade", exc_info=True) return self.reject( http.HTTPStatus.UPGRADE_REQUIRED, ( @@ -98,13 +102,13 @@ def accept(self, request: Request) -> Response: headers=Headers([("Upgrade", "websocket")]), )._replace(exception=exc) except InvalidHandshake as exc: - logger.debug("Invalid handshake", exc_info=True) + self.logger.debug("Invalid handshake", exc_info=True) return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc}.\n", )._replace(exception=exc) except Exception as exc: - logger.warning("Error in opening handshake", exc_info=True) + self.logger.warning("Error in opening handshake", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( @@ -410,15 +414,15 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: self.set_state(OPEN) - logger.debug( + self.logger.debug( "%s > HTTP/1.1 %d %s", self.side, response.status_code, response.reason_phrase, ) - logger.debug("%s > %r", self.side, response.headers) + self.logger.debug("%s > %r", self.side, response.headers) if response.body is not None: - logger.debug("%s > body (%d bytes)", self.side, len(response.body)) + self.logger.debug("%s > body (%d bytes)", self.side, len(response.body)) self.writes.append(response.serialize()) diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 630a9fbe3..b1858d73e 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -1,7 +1,15 @@ +import logging from typing import List, NewType, Optional, Tuple, Union -__all__ = ["Data", "Origin", "ExtensionHeader", "ExtensionParameter", "Subprotocol"] +__all__ = [ + "Data", + "LoggerLike", + "Origin", + "ExtensionHeader", + "ExtensionParameter", + "Subprotocol", +] Data = Union[str, bytes] Data.__doc__ = """ @@ -12,6 +20,8 @@ """ +LoggerLike = Union[logging.Logger, logging.LoggerAdapter] +LoggerLike.__doc__ = """"Types accepted where :class:`~logging.Logger` is expected""" Origin = NewType("Origin", str) Origin.__doc__ = """Value of a Origin header""" diff --git a/tests/test_client.py b/tests/test_client.py index 747594bf3..b96ebd272 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,4 @@ +import logging import unittest import unittest.mock @@ -568,3 +569,11 @@ def test_unsupported_subprotocol(self): with self.assertRaises(InvalidHandshake) as raised: raise response.exception self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") + + +class MiscTests(unittest.TestCase): + def test_custom_logger(self): + logger = logging.getLogger("test") + with self.assertLogs("test", logging.DEBUG) as logs: + ClientConnection("wss://example.com/test", logger=logger) + self.assertEqual(len(logs.records), 1) diff --git a/tests/test_server.py b/tests/test_server.py index ad56a37bc..6a25f0d25 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,5 @@ import http +import logging import unittest import unittest.mock @@ -625,3 +626,11 @@ def test_extra_headers_overrides_server(self): self.assertEqual(response.status_code, 101) self.assertEqual(response.headers["Server"], "Other") + + +class MiscTests(unittest.TestCase): + def test_custom_logger(self): + logger = logging.getLogger("test") + with self.assertLogs("test", logging.DEBUG) as logs: + ServerConnection(logger=logger) + self.assertEqual(len(logs.records), 1) From f5eae5657c2d948cdf20174a03190c6a4206341c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 16:23:30 +0200 Subject: [PATCH 0830/1539] Inject logger in legacy asyncio layer. --- docs/reference/client.rst | 6 +-- docs/reference/server.rst | 6 +-- src/websockets/legacy/client.py | 19 +++++---- src/websockets/legacy/protocol.py | 68 ++++++++++++++++-------------- src/websockets/legacy/server.py | 41 ++++++++++-------- tests/legacy/test_client_server.py | 16 +++++++ 6 files changed, 94 insertions(+), 62 deletions(-) diff --git a/docs/reference/client.rst b/docs/reference/client.rst index c2bb6266b..c7c738c81 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -6,16 +6,16 @@ Client Opening a connection -------------------- - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None, **kwds) :async: - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None, **kwds) :async: Using a connection ------------------ - .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None) .. attribute:: id diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 898403070..2d54eca7a 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -6,10 +6,10 @@ Server Starting a server ----------------- - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None, **kwds) :async: - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None, **kwds) :async: Stopping a server @@ -25,7 +25,7 @@ Server Using a connection ------------------ - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None) .. attribute:: id diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b77b4e86d..760309c3f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -31,7 +31,7 @@ parse_subprotocol, ) from ..http import USER_AGENT, build_host -from ..typing import ExtensionHeader, Origin, Subprotocol +from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri from .compatibility import asyncio_get_running_loop from .handshake import build_request, check_response @@ -41,8 +41,6 @@ __all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] -logger = logging.getLogger("websockets.server") - class WebSocketClientProtocol(WebSocketCommonProtocol): """ @@ -152,13 +150,16 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: + if logger is None: + logger = logging.getLogger("websockets.client") + super().__init__(logger=logger, **kwargs) self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers - super().__init__(**kwargs) def write_http_request(self, path: str, headers: Headers) -> None: """ @@ -168,8 +169,8 @@ def write_http_request(self, path: str, headers: Headers) -> None: self.path = path self.request_headers = headers - logger.debug("%s > GET %s HTTP/1.1", self.side, path) - logger.debug("%s > %r", self.side, headers) + self.logger.debug("%s > GET %s HTTP/1.1", self.side, path) + self.logger.debug("%s > %r", self.side, headers) # Since the path and headers only contain ASCII characters, # we can keep this simple. @@ -198,8 +199,8 @@ async def read_http_response(self) -> Tuple[int, Headers]: except Exception as exc: raise InvalidMessage("did not receive a valid HTTP response") from exc - logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) - logger.debug("%s < %r", self.side, headers) + self.logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) + self.logger.debug("%s < %r", self.side, headers) self.response_headers = headers @@ -478,6 +479,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -542,6 +544,7 @@ def __init__( extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, + logger=logger, ) if kwargs.pop("unix", False): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b544b93d0..940e330a9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -54,15 +54,13 @@ prepare_data, serialize_close, ) -from ..typing import Data, Subprotocol +from ..typing import Data, LoggerLike, Subprotocol from .compatibility import loop_if_py_lt_38 from .framing import Frame __all__ = ["WebSocketCommonProtocol"] -logger = logging.getLogger("websockets.protocol") - # A WebSocket connection goes through the following four states, in order: @@ -108,6 +106,7 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, + logger: Optional[LoggerLike] = None, # The following arguments are kept only for backwards compatibility. host: Optional[str] = None, port: Optional[int] = None, @@ -132,6 +131,9 @@ def __init__( self.max_queue = max_queue self.read_limit = read_limit self.write_limit = write_limit + if logger is None: + logger = logging.getLogger("websockets.protocol") + self.logger = logger # Unique identifier. For logs. self.id = uuid.uuid4() @@ -162,7 +164,7 @@ def __init__( # Subclasses implement the opening handshake and, on success, execute # :meth:`connection_open` to change the state to OPEN. self.state = State.CONNECTING - logger.debug("%s - state = CONNECTING", self.side) + self.logger.debug("%s - state = CONNECTING", self.side) # HTTP protocol parameters. self.path: str @@ -246,7 +248,7 @@ def connection_open(self) -> None: # 4.1. The WebSocket Connection is Established. assert self.state is State.CONNECTING self.state = State.OPEN - logger.debug("%s - state = OPEN", self.side) + self.logger.debug("%s - state = OPEN", self.side) # Start the task that receives incoming WebSocket messages. self.transfer_data_task = self.loop.create_task(self.transfer_data()) # Start the task that sends pings at regular intervals. @@ -812,7 +814,7 @@ async def transfer_data(self) -> None: # This shouldn't happen often because exceptions expected under # regular circumstances are handled above. If it does, consider # catching and handling more exceptions. - logger.error("Error in data transfer", exc_info=True) + self.logger.error("Error in data transfer", exc_info=True) self.transfer_data_exc = exc self.fail_connection(1011) @@ -923,7 +925,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PING: # Answer pings. ping_hex = frame.data.hex() or "[empty]" - logger.debug( + self.logger.debug( "%s - received ping, sending pong: %s", self.side, ping_hex ) await self.pong(frame.data) @@ -931,7 +933,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PONG: # Acknowledge pings on solicited pongs. if frame.data in self.pings: - logger.debug( + self.logger.debug( "%s - received solicited pong: %s", self.side, frame.data.hex() or "[empty]", @@ -956,14 +958,14 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: ping_id.hex() or "[empty]" for ping_id in ping_ids ) plural = "s" if len(ping_ids) > 1 else "" - logger.debug( + self.logger.debug( "%s - acknowledged previous ping%s: %s", self.side, plural, pings_hex, ) else: - logger.debug( + self.logger.debug( "%s - received unsolicited pong: %s", self.side, frame.data.hex() or "[empty]", @@ -984,7 +986,7 @@ async def read_frame(self, max_size: Optional[int]) -> Frame: max_size=max_size, extensions=self.extensions, ) - logger.debug("%s < %r", self.side, frame) + self.logger.debug("%s < %r", self.side, frame) return frame async def write_frame( @@ -997,7 +999,7 @@ async def write_frame( ) frame = Frame(fin, Opcode(opcode), data) - logger.debug("%s > %r", self.side, frame) + self.logger.debug("%s > %r", self.side, frame) frame.write( self.transport.write, mask=self.is_client, extensions=self.extensions ) @@ -1029,7 +1031,7 @@ async def write_close_frame(self, data: bytes = b"") -> None: if self.state is State.OPEN: # 7.1.3. The WebSocket Closing Handshake is Started self.state = State.CLOSING - logger.debug("%s - state = CLOSING", self.side) + self.logger.debug("%s - state = CLOSING", self.side) # 7.1.2. Start the WebSocket Closing Handshake await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) @@ -1071,7 +1073,7 @@ async def keepalive_ping(self) -> None: **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: - logger.debug("%s ! timed out waiting for pong", self.side) + self.logger.debug("%s ! timed out waiting for pong", self.side) self.fail_connection(1011) break @@ -1084,7 +1086,9 @@ async def keepalive_ping(self) -> None: pass except Exception: - logger.warning("Unexpected exception in keepalive ping task", exc_info=True) + self.logger.warning( + "Unexpected exception in keepalive ping task", exc_info=True + ) async def close_connection(self) -> None: """ @@ -1116,18 +1120,18 @@ async def close_connection(self) -> None: # Coverage marks this line as a partially executed branch. # I supect a bug in coverage. Ignore it for now. return # pragma: no cover - logger.debug("%s ! timed out waiting for TCP close", self.side) + self.logger.debug("%s ! timed out waiting for TCP close", self.side) # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): - logger.debug("%s x half-closing TCP connection", self.side) + self.logger.debug("%s x half-closing TCP connection", self.side) self.transport.write_eof() if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. # I supect a bug in coverage. Ignore it for now. return # pragma: no cover - logger.debug("%s ! timed out waiting for TCP close", self.side) + self.logger.debug("%s ! timed out waiting for TCP close", self.side) finally: # The try/finally ensures that the transport never remains open, @@ -1140,15 +1144,15 @@ async def close_connection(self) -> None: return # Close the TCP connection. Buffers are flushed asynchronously. - logger.debug("%s x closing TCP connection", self.side) + self.logger.debug("%s x closing TCP connection", self.side) self.transport.close() if await self.wait_for_connection_lost(): return - logger.debug("%s ! timed out waiting for TCP close", self.side) + self.logger.debug("%s ! timed out waiting for TCP close", self.side) # Abort the TCP connection. Buffers are discarded. - logger.debug("%s x aborting TCP connection", self.side) + self.logger.debug("%s x aborting TCP connection", self.side) self.transport.abort() # connection_lost() is called quickly after aborting. @@ -1196,7 +1200,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: (The specification describes these steps in the opposite order.) """ - logger.debug( + self.logger.debug( "%s ! failing %s WebSocket connection with code %d", self.side, self.state.name, @@ -1226,10 +1230,10 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # and write_frame(). self.state = State.CLOSING - logger.debug("%s - state = CLOSING", self.side) + self.logger.debug("%s - state = CLOSING", self.side) frame = Frame(True, OP_CLOSE, frame_data) - logger.debug("%s > %r", self.side, frame) + self.logger.debug("%s > %r", self.side, frame) frame.write( self.transport.write, mask=self.is_client, extensions=self.extensions ) @@ -1259,7 +1263,7 @@ def abort_pings(self) -> None: if self.pings: pings_hex = ", ".join(ping_id.hex() or "[empty]" for ping_id in self.pings) plural = "s" if len(self.pings) > 1 else "" - logger.debug( + self.logger.debug( "%s - aborted pending ping%s: %s", self.side, plural, pings_hex ) @@ -1279,7 +1283,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: which means it's the best point for configuring it. """ - logger.debug("%s - event = connection_made(%s)", self.side, transport) + self.logger.debug("%s - event = connection_made(%s)", self.side, transport) transport = cast(asyncio.Transport, transport) transport.set_write_buffer_limits(self.write_limit) @@ -1293,14 +1297,14 @@ def connection_lost(self, exc: Optional[Exception]) -> None: 7.1.4. The WebSocket Connection is Closed. """ - logger.debug("%s - event = connection_lost(%s)", self.side, exc) + self.logger.debug("%s - event = connection_lost(%s)", self.side, exc) self.state = State.CLOSED - logger.debug("%s - state = CLOSED", self.side) + self.logger.debug("%s - state = CLOSED", self.side) if not hasattr(self, "close_code"): self.close_code = 1006 if not hasattr(self, "close_reason"): self.close_reason = "" - logger.debug( + self.logger.debug( "%s x code = %d, reason = %s", self.side, self.close_code, @@ -1351,7 +1355,9 @@ def resume_writing(self) -> None: # pragma: no cover waiter.set_result(None) def data_received(self, data: bytes) -> None: - logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data)) + self.logger.debug( + "%s - event = data_received(<%d bytes>)", self.side, len(data) + ) self.reader.feed_data(data) def eof_received(self) -> None: @@ -1367,5 +1373,5 @@ def eof_received(self) -> None: Besides, that doesn't work on TLS connections. """ - logger.debug("%s - event = eof_received()", self.side) + self.logger.debug("%s - event = eof_received()", self.side) self.reader.feed_eof() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 5bd7d0f56..5f76500df 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -41,7 +41,7 @@ from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import build_extension, parse_extension, parse_subprotocol from ..http import USER_AGENT -from ..typing import ExtensionHeader, Origin, Subprotocol +from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from .compatibility import asyncio_get_running_loop, loop_if_py_lt_38 from .handshake import build_response, check_request from .http import read_request @@ -50,8 +50,6 @@ __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] -logger = logging.getLogger("websockets.server") - HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] @@ -184,8 +182,12 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: + if logger is None: + logger = logging.getLogger("websockets.server") + super().__init__(logger=logger, **kwargs) # For backwards compatibility with 6.0 or earlier. if origins is not None and "" in origins: warnings.warn("use None instead of '' in origins", DeprecationWarning) @@ -198,7 +200,6 @@ def __init__( self.extra_headers = extra_headers self._process_request = process_request self._select_subprotocol = select_subprotocol - super().__init__(**kwargs) def connection_made(self, transport: asyncio.BaseTransport) -> None: """ @@ -236,20 +237,22 @@ async def handler(self) -> None: except asyncio.CancelledError: # pragma: no cover raise except ConnectionError: - logger.debug("Connection error in opening handshake", exc_info=True) + self.logger.debug( + "Connection error in opening handshake", exc_info=True + ) raise except Exception as exc: if isinstance(exc, AbortHandshake): status, headers, body = exc.status, exc.headers, exc.body elif isinstance(exc, InvalidOrigin): - logger.debug("Invalid origin", exc_info=True) + self.logger.debug("Invalid origin", exc_info=True) status, headers, body = ( http.HTTPStatus.FORBIDDEN, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) elif isinstance(exc, InvalidUpgrade): - logger.debug("Invalid upgrade", exc_info=True) + self.logger.debug("Invalid upgrade", exc_info=True) status, headers, body = ( http.HTTPStatus.UPGRADE_REQUIRED, Headers([("Upgrade", "websocket")]), @@ -261,14 +264,14 @@ async def handler(self) -> None: ).encode(), ) elif isinstance(exc, InvalidHandshake): - logger.debug("Invalid handshake", exc_info=True) + self.logger.debug("Invalid handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) else: - logger.warning("Error in opening handshake", exc_info=True) + self.logger.warning("Error in opening handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.INTERNAL_SERVER_ERROR, Headers(), @@ -292,7 +295,7 @@ async def handler(self) -> None: try: await self.ws_handler(self, path) except Exception: - logger.error("Error in connection handler", exc_info=True) + self.logger.error("Error in connection handler", exc_info=True) if not self.closed: self.fail_connection(1011) raise @@ -300,10 +303,12 @@ async def handler(self) -> None: try: await self.close() except ConnectionError: - logger.debug("Connection error in closing handshake", exc_info=True) + self.logger.debug( + "Connection error in closing handshake", exc_info=True + ) raise except Exception: - logger.warning("Error in closing handshake", exc_info=True) + self.logger.warning("Error in closing handshake", exc_info=True) raise except Exception: @@ -338,8 +343,8 @@ async def read_http_request(self) -> Tuple[str, Headers]: except Exception as exc: raise InvalidMessage("did not receive a valid HTTP request") from exc - logger.debug("%s < GET %s HTTP/1.1", self.side, path) - logger.debug("%s < %r", self.side, headers) + self.logger.debug("%s < GET %s HTTP/1.1", self.side, path) + self.logger.debug("%s < %r", self.side, headers) self.path = path self.request_headers = headers @@ -357,8 +362,8 @@ def write_http_response( """ self.response_headers = headers - logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) - logger.debug("%s > %r", self.side, headers) + self.logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) + self.logger.debug("%s > %r", self.side, headers) # Since the status line and headers only contain ASCII characters, # we can keep this simple. @@ -368,7 +373,7 @@ def write_http_response( self.transport.write(response.encode()) if body is not None: - logger.debug("%s > body (%d bytes)", self.side, len(body)) + self.logger.debug("%s > body (%d bytes)", self.side, len(body)) self.transport.write(body) async def process_request( @@ -962,6 +967,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -1025,6 +1031,7 @@ def __init__( extra_headers=extra_headers, process_request=process_request, select_subprotocol=select_subprotocol, + logger=logger, ) if kwargs.pop("unix", False): diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 353e5b370..60c0a14ae 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -2,6 +2,7 @@ import contextlib import functools import http +import logging import pathlib import random import socket @@ -1471,3 +1472,18 @@ async def run_client(): self.loop.run_until_complete(run_client()) self.assertEqual(messages, self.MESSAGES) + + +class LoggerTests(ClientServerTestsMixin, AsyncioTestCase): + def test_logger_client(self): + with self.assertLogs("test.server", logging.DEBUG) as server_logs: + self.start_server(logger=logging.getLogger("test.server")) + with self.assertLogs("test.client", logging.DEBUG) as client_logs: + self.start_client(logger=logging.getLogger("test.client")) + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + self.stop_client() + self.stop_server() + + self.assertGreater(len(server_logs.records), 0) + self.assertGreater(len(client_logs.records), 0) From 57ff93f051d0731d28c8be5182740a0167a85d23 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 May 2021 07:25:52 +0200 Subject: [PATCH 0831/1539] Document how to disable logging. Fix #759. --- docs/howto/index.rst | 1 + docs/howto/logging.rst | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 docs/howto/logging.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index ac1182705..e5af8488e 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -15,6 +15,7 @@ The following guides will help you integrate websockets into a broader system. :maxdepth: 2 django + logging The WebSocket protocol makes provisions for extending or specializing its features, which websockets supports fully. diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst new file mode 100644 index 000000000..214e05105 --- /dev/null +++ b/docs/howto/logging.rst @@ -0,0 +1,26 @@ +Configure logging +================= + +Disable logging +--------------- + +If your application doesn't configure :mod:`logging`, Python outputs messages +of severity :data:`~logging.WARNING` and higher to :data:`~sys.stderr`. If +you want to disable this behavior for websockets, you can add +a :class:`~logging.NullHandler`:: + + logging.getLogger("websockets").addHandler(logging.NullHandler()) + +Additionally, if your application configures :mod:`logging`, you must disable +propagation to the root logger, or else its handlers could output logs:: + + logging.getLogger("websockets").propagate = False + +Alternatively, you could set the log level to :data:`~logging.CRITICAL` for +websockets, as the highest level currently used is :data:`~logging.ERROR`:: + + logging.getLogger("websockets").setLevel(logging.CRITICAL) + +Or you could configure a filter to drop all messages:: + + logging.getLogger("websockets").addFilter(lambda record: None) From 3388225957c6bba1cf46748f6f83c6bb8f7ff7f6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 14:13:07 +0200 Subject: [PATCH 0832/1539] Separate prints in examples from logs. They used the same < and > symbols which was confusing. --- example/client.py | 4 ++-- example/secure_client.py | 4 ++-- example/secure_server.py | 4 ++-- example/server.py | 4 ++-- example/unix_client.py | 4 ++-- example/unix_server.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/example/client.py b/example/client.py index e39df81f7..062540202 100755 --- a/example/client.py +++ b/example/client.py @@ -11,9 +11,9 @@ async def hello(): name = input("What's your name? ") await websocket.send(name) - print(f"> {name}") + print(f">>> {name}") greeting = await websocket.recv() - print(f"< {greeting}") + print(f"<<< {greeting}") asyncio.run(hello()) diff --git a/example/secure_client.py b/example/secure_client.py index 2657ba68b..518819dd1 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -20,9 +20,9 @@ async def hello(): name = input("What's your name? ") await websocket.send(name) - print(f"> {name}") + print(f">>> {name}") greeting = await websocket.recv() - print(f"< {greeting}") + print(f"<<< {greeting}") asyncio.run(hello()) diff --git a/example/secure_server.py b/example/secure_server.py index e0ef6e53b..96c300390 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -9,12 +9,12 @@ async def hello(websocket, path): name = await websocket.recv() - print(f"< {name}") + print(f"<<< {name}") greeting = f"Hello {name}!" await websocket.send(greeting) - print(f"> {greeting}") + print(f">>> {greeting}") ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") diff --git a/example/server.py b/example/server.py index 98dbb5acd..4dcf317f5 100755 --- a/example/server.py +++ b/example/server.py @@ -7,12 +7,12 @@ async def hello(websocket, path): name = await websocket.recv() - print(f"< {name}") + print(f"<<< {name}") greeting = f"Hello {name}!" await websocket.send(greeting) - print(f"> {greeting}") + print(f">>> {greeting}") async def main(): async with websockets.serve(hello, "localhost", 8765): diff --git a/example/unix_client.py b/example/unix_client.py index 434638c80..926156730 100755 --- a/example/unix_client.py +++ b/example/unix_client.py @@ -11,9 +11,9 @@ async def hello(): async with websockets.unix_connect(socket_path) as websocket: name = input("What's your name? ") await websocket.send(name) - print(f"> {name}") + print(f">>> {name}") greeting = await websocket.recv() - print(f"< {greeting}") + print(f"<<< {greeting}") asyncio.run(hello()) diff --git a/example/unix_server.py b/example/unix_server.py index 223d97301..192f31bb0 100755 --- a/example/unix_server.py +++ b/example/unix_server.py @@ -8,12 +8,12 @@ async def hello(websocket, path): name = await websocket.recv() - print(f"< {name}") + print(f"<<< {name}") greeting = f"Hello {name}!" await websocket.send(greeting) - print(f"> {greeting}") + print(f">>> {greeting}") async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") From acb6166aa3ab20dbd4ca8e80abf1a614b76f92c1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 15:06:06 +0200 Subject: [PATCH 0833/1539] Overhaul logging. * Log connections open and closed at the info level. * Put debug logs behind `if self.debug:` to minimize overhead. * Standardize logging style. * Add documentation. --- docs/howto/logging.rst | 129 +++++++++++++++++++++++++- docs/project/changelog.rst | 2 + setup.cfg | 8 +- src/websockets/client.py | 20 +++- src/websockets/connection.py | 38 ++++---- src/websockets/legacy/client.py | 12 ++- src/websockets/legacy/protocol.py | 148 ++++++++++++++++-------------- src/websockets/legacy/server.py | 42 +++++---- src/websockets/server.py | 35 ++++--- tests/test_connection.py | 8 +- tests/test_http11.py | 4 +- 11 files changed, 315 insertions(+), 131 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index 214e05105..c83ec6906 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -1,12 +1,98 @@ +Logging +======= + +.. currentmodule:: websockets + +Logs contents +------------- + +When you run a WebSocket client, your code calls coroutines provided by +websockets. + +If an error occurs, websockets tells you by raising an exception. For example, +it raises a :exc:`~exception.ConnectionClosed` exception if the other side +closes the connection. + +When you run a WebSocket server, websockets accepts connections, performs the +opening handshake, runs the connection handler coroutine that you provided, +and performs the closing handshake. + +Given this `inversion of control`_, if an error happens in the opening +handshake or if the connection handler crashes, there is no way to raise an +exception that you can handle. + +.. _inversion of control: https://en.wikipedia.org/wiki/Inversion_of_control + +Logs tell you about these errors. + +Besides errors, you may want to record the activity of the server. + +In a request/response protocol such as HTTP, there's an obvious way to record +activity: log one event per request/response. Unfortunately, this solution +doesn't work well for a bidirectional protocol such as WebSocket. + +Instead, when running as a server, websockets logs one event when a +`connection is established`_ and another event when a `connection is +closed`_. + +.. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455#section-4 +.. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 + +websockets doesn't log an event for every message because that would be +excessive for many applications exchanging small messages at a fast rate. +However, you could add this level of logging in your own code if necessary. + +See :ref:`log levels ` below for details of events logged by +websockets at each level. + Configure logging -================= +----------------- + +websockets relies on the :mod:`logging` module from the standard library in +order to maximize compatibility and integrate nicely with other libraries:: + + import logging + +websockets logs to the ``"websockets.client"`` and ``"websockets.server"`` +loggers. + +websockets doesn't provide a default logging configuration because +requirements vary a lot depending on the environment. + +Here's a basic configuration for a server in production:: + + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.INFO, + ) + +Here's how to enable debug logs for development:: + + logging.basicConfig( + format="%(message)s", + level=logging.DEBUG, + ) + +You can select a different :class:`~logging.Logger` with the ``logger`` +argument:: + + import websockets + + async with websockets.serve( + ..., + logger=logging.getLogger("interface.websocket"), + ): + ... Disable logging --------------- If your application doesn't configure :mod:`logging`, Python outputs messages -of severity :data:`~logging.WARNING` and higher to :data:`~sys.stderr`. If -you want to disable this behavior for websockets, you can add +of severity :data:`~logging.WARNING` and higher to :data:`~sys.stderr`. As a +consequence, you will see a message and a stack trace if a connection handler +coroutine crashes or if you hit a bug in websockets. + +If you want to disable this behavior for websockets, you can add a :class:`~logging.NullHandler`:: logging.getLogger("websockets").addHandler(logging.NullHandler()) @@ -24,3 +110,40 @@ websockets, as the highest level currently used is :data:`~logging.ERROR`:: Or you could configure a filter to drop all messages:: logging.getLogger("websockets").addFilter(lambda record: None) + +.. _log-levels: + +Log levels +---------- + +Here's what websockets logs at each level. + +:attr:`~logging.ERROR` +...................... + +* Exceptions raised by connection handler coroutines in servers +* Exceptions resulting from bugs in websockets + +:attr:`~logging.INFO` +..................... + +* Connections opened and closed in servers + +:attr:`~logging.DEBUG` +...................... + +* Changes to the state of connections +* Handshake requests and responses +* All frames sent and received +* Steps to close a connection +* Keepalive pings and pongs +* Errors handled transparently + +Debug messages have cute prefixes that make logs easier to scan: + +* ``>`` - send something +* ``<`` - receive something +* ``=`` - set connection state +* ``x`` - shut down connection +* ``%`` - manage pings and pongs +* ``!`` - handle errors and timeouts diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 919081a3a..caf637614 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,6 +43,8 @@ They may change at any time. * Added compatibility with Python 3.10. +* Improved logging. + * Optimized default compression settings to reduce memory usage. * Made it easier to customize authentication with diff --git a/setup.cfg b/setup.cfg index 0625798a2..9ff1939c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,7 +20,8 @@ lines_after_imports = 2 [coverage:run] branch = True -omit = */__main__.py +omit = + */__main__.py source = websockets tests @@ -29,3 +30,8 @@ source = source = src/websockets .tox/*/lib/python*/site-packages/websockets + +[coverage:report] +exclude_lines = + if self.debug: + pragma: no cover diff --git a/src/websockets/client.py b/src/websockets/client.py index 9d3f0e7f7..738b6c998 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -275,21 +275,33 @@ def send_request(self, request: Request) -> None: Send a WebSocket handshake request to the server. """ - self.logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) - self.logger.debug("%s > %r", self.side, request.headers) + if self.debug: + self.logger.debug("> GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("> %s: %s", key, value) self.writes.append(request.serialize()) def parse(self) -> Generator[None, None, None]: response = yield from Response.parse( - self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, ) + + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("< HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + if response.body is not None: + self.logger.debug("< [body] (%d bytes)", len(response.body)) + assert self.state == CONNECTING try: self.process_response(response) except InvalidHandshake as exc: response = response._replace(exception=exc) - self.logger.debug("Invalid handshake", exc_info=True) else: self.set_state(OPEN) finally: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4852560e3..5052df3f7 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -71,23 +71,23 @@ def __init__( # Unique identifier. For logs. self.id = uuid.uuid4() + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger(f"websockets.{side.name.lower()}") + self.logger = logger + + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) + # Connection side. CLIENT or SERVER. self.side = side # Connnection state. CONNECTING and CLOSED states are handled in subclasses. - self.state = state + self.set_state(state) # Maximum size of incoming messages in bytes. self.max_size = max_size - # Logger or LoggerAdapter for this connection. - if logger is None: - logger = logging.getLogger(f"websockets.{side.name.lower()}") - self.logger = logger - - # Must wait until we have the logger to log the initial state! - self.logger.debug("%s - initial state: %s", self.side, state.name) - # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. self.cur_size: Optional[int] = None @@ -123,9 +123,8 @@ def __init__( self.parser_exc: Optional[Exception] = None def set_state(self, state: State) -> None: - self.logger.debug( - "%s - state change: %s > %s", self.side, self.state.name, state.name - ) + if self.debug: + self.logger.debug("= connection is %s", state.name) self.state = state # Public APIs for receiving data. @@ -273,7 +272,7 @@ def step_parser(self) -> None: # EOF because receive_data() or receive_eof() would fail earlier.) assert self.parser_exc is not None raise RuntimeError( - "cannot receive data or EOF after an error" + "parser cannot receive data or EOF after an error" ) from self.parser_exc except ProtocolError as exc: self.fail_connection(1002, str(exc)) @@ -292,7 +291,7 @@ def step_parser(self) -> None: self.parser_exc = exc raise except Exception as exc: - self.logger.error("unexpected exception in parser", exc_info=True) + self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. self.fail_connection(1011) self.parser_exc = exc @@ -302,6 +301,8 @@ def parse(self) -> Generator[None, None, None]: while True: eof = yield from self.reader.at_eof() if eof: + if self.debug: + self.logger.debug("< EOF") if self.close_frame_received: if not self.eof_sent: self.send_eof() @@ -330,6 +331,9 @@ def parse(self) -> Generator[None, None, None]: extensions=self.extensions, ) + if self.debug: + self.logger.debug("< %s", frame) + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: # 5.5.1 Close: "The application MUST NOT send any more data # frames after sending a Close frame." @@ -407,7 +411,8 @@ def send_frame(self, frame: Frame) -> None: f"cannot write to a WebSocket in the {self.state.name} state" ) - self.logger.debug("%s > %r", self.side, frame) + if self.debug: + self.logger.debug("> %s", frame) self.writes.append( frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) ) @@ -415,5 +420,6 @@ def send_frame(self, frame: Frame) -> None: def send_eof(self) -> None: assert not self.eof_sent self.eof_sent = True - self.logger.debug("%s > EOF", self.side) + if self.debug: + self.logger.debug("> EOF") self.writes.append(SEND_EOF) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 760309c3f..5c281d3b8 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -169,8 +169,10 @@ def write_http_request(self, path: str, headers: Headers) -> None: self.path = path self.request_headers = headers - self.logger.debug("%s > GET %s HTTP/1.1", self.side, path) - self.logger.debug("%s > %r", self.side, headers) + if self.debug: + self.logger.debug("> GET %s HTTP/1.1", path) + for key, value in headers.raw_items(): + self.logger.debug("> %s: %s", key, value) # Since the path and headers only contain ASCII characters, # we can keep this simple. @@ -199,8 +201,10 @@ async def read_http_response(self) -> Tuple[int, Headers]: except Exception as exc: raise InvalidMessage("did not receive a valid HTTP response") from exc - self.logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) - self.logger.debug("%s < %r", self.side, headers) + if self.debug: + self.logger.debug("< HTTP/1.1 %d %s", status_code, reason) + for key, value in headers.raw_items(): + self.logger.debug("< %s: %s", key, value) self.response_headers = headers diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 940e330a9..ec8abaed9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -131,12 +131,17 @@ def __init__( self.max_queue = max_queue self.read_limit = read_limit self.write_limit = write_limit + + # Unique identifier. For logs. + self.id = uuid.uuid4() + + # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger("websockets.protocol") self.logger = logger - # Unique identifier. For logs. - self.id = uuid.uuid4() + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) assert loop is not None # Remove when dropping Python < 3.10 - use get_running_loop instead. @@ -164,7 +169,8 @@ def __init__( # Subclasses implement the opening handshake and, on success, execute # :meth:`connection_open` to change the state to OPEN. self.state = State.CONNECTING - self.logger.debug("%s - state = CONNECTING", self.side) + if self.debug: + self.logger.debug("= connection is CONNECTING") # HTTP protocol parameters. self.path: str @@ -248,7 +254,8 @@ def connection_open(self) -> None: # 4.1. The WebSocket Connection is Established. assert self.state is State.CONNECTING self.state = State.OPEN - self.logger.debug("%s - state = OPEN", self.side) + if self.debug: + self.logger.debug("= connection is OPEN") # Start the task that receives incoming WebSocket messages. self.transfer_data_task = self.loop.create_task(self.transfer_data()) # Start the task that sends pings at regular intervals. @@ -814,7 +821,7 @@ async def transfer_data(self) -> None: # This shouldn't happen often because exceptions expected under # regular circumstances are handled above. If it does, consider # catching and handling more exceptions. - self.logger.error("Error in data transfer", exc_info=True) + self.logger.error("data transfer failed", exc_info=True) self.transfer_data_exc = exc self.fail_connection(1011) @@ -925,20 +932,20 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PING: # Answer pings. ping_hex = frame.data.hex() or "[empty]" - self.logger.debug( - "%s - received ping, sending pong: %s", self.side, ping_hex - ) + if self.debug: + self.logger.debug("%% received ping, sending pong: %s", ping_hex) await self.pong(frame.data) elif frame.opcode == OP_PONG: # Acknowledge pings on solicited pongs. if frame.data in self.pings: - self.logger.debug( - "%s - received solicited pong: %s", - self.side, - frame.data.hex() or "[empty]", - ) + if self.debug: + self.logger.debug( + "%% received solicited pong: %s", + frame.data.hex() or "[empty]", + ) # Acknowledge all pings up to the one matching this pong. + # Sending a pong for only the most recent ping is legal. ping_id = None ping_ids = [] for ping_id, ping in self.pings.items(): @@ -952,24 +959,25 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: # Remove acknowledged pings from self.pings. for ping_id in ping_ids: del self.pings[ping_id] - ping_ids = ping_ids[:-1] - if ping_ids: - pings_hex = ", ".join( - ping_id.hex() or "[empty]" for ping_id in ping_ids - ) - plural = "s" if len(ping_ids) > 1 else "" + # Log previous pings acknowledged. + if self.debug: + ping_ids = ping_ids[:-1] + if ping_ids: + pings_hex = ", ".join( + ping_id.hex() or "[empty]" for ping_id in ping_ids + ) + plural = "s" if len(ping_ids) > 1 else "" + self.logger.debug( + "%% acknowledged previous ping%s: %s", + plural, + pings_hex, + ) + else: + if self.debug: self.logger.debug( - "%s - acknowledged previous ping%s: %s", - self.side, - plural, - pings_hex, + "%% received unsolicited pong: %s", + frame.data.hex() or "[empty]", ) - else: - self.logger.debug( - "%s - received unsolicited pong: %s", - self.side, - frame.data.hex() or "[empty]", - ) # 5.6. Data Frames else: @@ -986,7 +994,8 @@ async def read_frame(self, max_size: Optional[int]) -> Frame: max_size=max_size, extensions=self.extensions, ) - self.logger.debug("%s < %r", self.side, frame) + if self.debug: + self.logger.debug("< %s", frame) return frame async def write_frame( @@ -999,9 +1008,12 @@ async def write_frame( ) frame = Frame(fin, Opcode(opcode), data) - self.logger.debug("%s > %r", self.side, frame) + if self.debug: + self.logger.debug("> %s", frame) frame.write( - self.transport.write, mask=self.is_client, extensions=self.extensions + self.transport.write, + mask=self.is_client, + extensions=self.extensions, ) try: @@ -1031,7 +1043,8 @@ async def write_close_frame(self, data: bytes = b"") -> None: if self.state is State.OPEN: # 7.1.3. The WebSocket Closing Handshake is Started self.state = State.CLOSING - self.logger.debug("%s - state = CLOSING", self.side) + if self.debug: + self.logger.debug("= connection is CLOSING") # 7.1.2. Start the WebSocket Closing Handshake await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) @@ -1073,7 +1086,8 @@ async def keepalive_ping(self) -> None: **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: - self.logger.debug("%s ! timed out waiting for pong", self.side) + if self.debug: + self.logger.debug("! timed out waiting for pong") self.fail_connection(1011) break @@ -1086,9 +1100,7 @@ async def keepalive_ping(self) -> None: pass except Exception: - self.logger.warning( - "Unexpected exception in keepalive ping task", exc_info=True - ) + self.logger.error("keepalive ping failed", exc_info=True) async def close_connection(self) -> None: """ @@ -1120,18 +1132,21 @@ async def close_connection(self) -> None: # Coverage marks this line as a partially executed branch. # I supect a bug in coverage. Ignore it for now. return # pragma: no cover - self.logger.debug("%s ! timed out waiting for TCP close", self.side) + if self.debug: + self.logger.debug("! timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): - self.logger.debug("%s x half-closing TCP connection", self.side) + if self.debug: + self.logger.debug("x half-closing TCP connection") self.transport.write_eof() if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. # I supect a bug in coverage. Ignore it for now. return # pragma: no cover - self.logger.debug("%s ! timed out waiting for TCP close", self.side) + if self.debug: + self.logger.debug("! timed out waiting for TCP close") finally: # The try/finally ensures that the transport never remains open, @@ -1144,15 +1159,18 @@ async def close_connection(self) -> None: return # Close the TCP connection. Buffers are flushed asynchronously. - self.logger.debug("%s x closing TCP connection", self.side) + if self.debug: + self.logger.debug("x closing TCP connection") self.transport.close() if await self.wait_for_connection_lost(): return - self.logger.debug("%s ! timed out waiting for TCP close", self.side) + if self.debug: + self.logger.debug("! timed out waiting for TCP close") # Abort the TCP connection. Buffers are discarded. - self.logger.debug("%s x aborting TCP connection", self.side) + if self.debug: + self.logger.debug("x aborting TCP connection") self.transport.abort() # connection_lost() is called quickly after aborting. @@ -1200,12 +1218,8 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: (The specification describes these steps in the opposite order.) """ - self.logger.debug( - "%s ! failing %s WebSocket connection with code %d", - self.side, - self.state.name, - code, - ) + if self.debug: + self.logger.debug("! failing connection with code %d", code) # Cancel transfer_data_task if the opening handshake succeeded. # cancel() is idempotent and ignored if the task is done already. @@ -1230,12 +1244,16 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # and write_frame(). self.state = State.CLOSING - self.logger.debug("%s - state = CLOSING", self.side) + if self.debug: + self.logger.debug("= connection is CLOSING") frame = Frame(True, OP_CLOSE, frame_data) - self.logger.debug("%s > %r", self.side, frame) + if self.debug: + self.logger.debug("> %s", frame) frame.write( - self.transport.write, mask=self.is_client, extensions=self.extensions + self.transport.write, + mask=self.is_client, + extensions=self.extensions, ) # Start close_connection_task if the opening handshake didn't succeed. @@ -1260,12 +1278,13 @@ def abort_pings(self) -> None: # nothing, but it prevents logging the exception. ping.cancel() - if self.pings: - pings_hex = ", ".join(ping_id.hex() or "[empty]" for ping_id in self.pings) - plural = "s" if len(self.pings) > 1 else "" - self.logger.debug( - "%s - aborted pending ping%s: %s", self.side, plural, pings_hex - ) + if self.debug: + if self.pings: + pings_hex = ", ".join( + ping_id.hex() or "[empty]" for ping_id in self.pings + ) + plural = "s" if len(self.pings) > 1 else "" + self.logger.debug("% aborted pending ping%s: %s", plural, pings_hex) # asyncio.Protocol methods @@ -1283,8 +1302,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: which means it's the best point for configuring it. """ - self.logger.debug("%s - event = connection_made(%s)", self.side, transport) - transport = cast(asyncio.Transport, transport) transport.set_write_buffer_limits(self.write_limit) self.transport = transport @@ -1297,20 +1314,19 @@ def connection_lost(self, exc: Optional[Exception]) -> None: 7.1.4. The WebSocket Connection is Closed. """ - self.logger.debug("%s - event = connection_lost(%s)", self.side, exc) self.state = State.CLOSED - self.logger.debug("%s - state = CLOSED", self.side) if not hasattr(self, "close_code"): self.close_code = 1006 if not hasattr(self, "close_reason"): self.close_reason = "" self.logger.debug( - "%s x code = %d, reason = %s", - self.side, + "= connection is CLOSED - code = %d, reason = %s", self.close_code, self.close_reason or "[no reason]", ) + self.abort_pings() + # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. @@ -1355,9 +1371,6 @@ def resume_writing(self) -> None: # pragma: no cover waiter.set_result(None) def data_received(self, data: bytes) -> None: - self.logger.debug( - "%s - event = data_received(<%d bytes>)", self.side, len(data) - ) self.reader.feed_data(data) def eof_received(self) -> None: @@ -1373,5 +1386,4 @@ def eof_received(self) -> None: Besides, that doesn't work on TLS connections. """ - self.logger.debug("%s - event = eof_received()", self.side) self.reader.feed_eof() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 5f76500df..31ad0b18e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -237,22 +237,21 @@ async def handler(self) -> None: except asyncio.CancelledError: # pragma: no cover raise except ConnectionError: - self.logger.debug( - "Connection error in opening handshake", exc_info=True - ) raise except Exception as exc: if isinstance(exc, AbortHandshake): status, headers, body = exc.status, exc.headers, exc.body elif isinstance(exc, InvalidOrigin): - self.logger.debug("Invalid origin", exc_info=True) + if self.debug: + self.logger.debug("! invalid origin", exc_info=True) status, headers, body = ( http.HTTPStatus.FORBIDDEN, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) elif isinstance(exc, InvalidUpgrade): - self.logger.debug("Invalid upgrade", exc_info=True) + if self.debug: + self.logger.debug("! invalid upgrade", exc_info=True) status, headers, body = ( http.HTTPStatus.UPGRADE_REQUIRED, Headers([("Upgrade", "websocket")]), @@ -264,14 +263,15 @@ async def handler(self) -> None: ).encode(), ) elif isinstance(exc, InvalidHandshake): - self.logger.debug("Invalid handshake", exc_info=True) + if self.debug: + self.logger.debug("! invalid handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) else: - self.logger.warning("Error in opening handshake", exc_info=True) + self.logger.error("opening handshake failed", exc_info=True) status, headers, body = ( http.HTTPStatus.INTERNAL_SERVER_ERROR, Headers(), @@ -288,6 +288,9 @@ async def handler(self) -> None: headers.setdefault("Connection", "close") self.write_http_response(status, headers, body) + self.logger.info( + "connection failed (%d %s)", status.value, status.phrase + ) self.fail_connection() await self.wait_closed() return @@ -295,7 +298,7 @@ async def handler(self) -> None: try: await self.ws_handler(self, path) except Exception: - self.logger.error("Error in connection handler", exc_info=True) + self.logger.error("connection handler failed", exc_info=True) if not self.closed: self.fail_connection(1011) raise @@ -303,12 +306,9 @@ async def handler(self) -> None: try: await self.close() except ConnectionError: - self.logger.debug( - "Connection error in closing handshake", exc_info=True - ) raise except Exception: - self.logger.warning("Error in closing handshake", exc_info=True) + self.logger.error("closing handshake failed", exc_info=True) raise except Exception: @@ -324,6 +324,7 @@ async def handler(self) -> None: # task because the server waits for tasks attached to registered # connections before terminating. self.ws_server.unregister(self) + self.logger.info("connection closed") async def read_http_request(self) -> Tuple[str, Headers]: """ @@ -343,8 +344,10 @@ async def read_http_request(self) -> Tuple[str, Headers]: except Exception as exc: raise InvalidMessage("did not receive a valid HTTP request") from exc - self.logger.debug("%s < GET %s HTTP/1.1", self.side, path) - self.logger.debug("%s < %r", self.side, headers) + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", path) + for key, value in headers.raw_items(): + self.logger.debug("< %s: %s", key, value) self.path = path self.request_headers = headers @@ -362,8 +365,12 @@ def write_http_response( """ self.response_headers = headers - self.logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) - self.logger.debug("%s > %r", self.side, headers) + if self.debug: + self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase) + for key, value in headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + if body is not None: + self.logger.debug("> [body] (%d bytes)", len(body)) # Since the status line and headers only contain ASCII characters, # we can keep this simple. @@ -373,7 +380,6 @@ def write_http_response( self.transport.write(response.encode()) if body is not None: - self.logger.debug("%s > body (%d bytes)", self.side, len(body)) self.transport.write(body) async def process_request( @@ -692,6 +698,8 @@ async def handshake( self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) + self.logger.info("connection open") + self.connection_open() return path diff --git a/src/websockets/server.py b/src/websockets/server.py index 4a76ac886..87aa07b51 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -84,13 +84,15 @@ def accept(self, request: Request) -> Response: try: key, extensions_header, protocol_header = self.process_request(request) except InvalidOrigin as exc: - self.logger.debug("Invalid origin", exc_info=True) + if self.debug: + self.logger.debug("! invalid origin", exc_info=True) return self.reject( http.HTTPStatus.FORBIDDEN, f"Failed to open a WebSocket connection: {exc}.\n", )._replace(exception=exc) except InvalidUpgrade as exc: - self.logger.debug("Invalid upgrade", exc_info=True) + if self.debug: + self.logger.debug("! invalid upgrade", exc_info=True) return self.reject( http.HTTPStatus.UPGRADE_REQUIRED, ( @@ -102,13 +104,14 @@ def accept(self, request: Request) -> Response: headers=Headers([("Upgrade", "websocket")]), )._replace(exception=exc) except InvalidHandshake as exc: - self.logger.debug("Invalid handshake", exc_info=True) + if self.debug: + self.logger.debug("! invalid handshake", exc_info=True) return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc}.\n", )._replace(exception=exc) except Exception as exc: - self.logger.warning("Error in opening handshake", exc_info=True) + self.logger.error("opening handshake failed", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( @@ -145,6 +148,7 @@ def accept(self, request: Request) -> Response: headers.setdefault("Date", email.utils.formatdate(usegmt=True)) headers.setdefault("Server", USER_AGENT) + self.logger.info("connection open") return Response(101, "Switching Protocols", headers) def process_request( @@ -404,6 +408,7 @@ def reject( headers.setdefault("Content-Length", str(len(body))) headers.setdefault("Content-Type", "text/plain; charset=utf-8") headers.setdefault("Connection", "close") + self.logger.info("connection failed (%d %s)", status.value, status.phrase) return Response(status.value, status.phrase, headers, body) def send_response(self, response: Response) -> None: @@ -414,20 +419,24 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: self.set_state(OPEN) - self.logger.debug( - "%s > HTTP/1.1 %d %s", - self.side, - response.status_code, - response.reason_phrase, - ) - self.logger.debug("%s > %r", self.side, response.headers) - if response.body is not None: - self.logger.debug("%s > body (%d bytes)", self.side, len(response.body)) + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("> HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + if response.body is not None: + self.logger.debug("> [body] (%d bytes)", len(response.body)) self.writes.append(response.serialize()) def parse(self) -> Generator[None, None, None]: request = yield from Request.parse(self.reader.read_line) + + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + assert self.state == CONNECTING self.events.append(request) yield from super().parse() diff --git a/tests/test_connection.py b/tests/test_connection.py index 3e39a3f9e..6203a1469 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1356,7 +1356,7 @@ def test_client_receives_data_after_exception(self): with self.assertRaises(RuntimeError) as raised: client.receive_data(b"\x00\x00") self.assertEqual( - str(raised.exception), "cannot receive data or EOF after an error" + str(raised.exception), "parser cannot receive data or EOF after an error" ) def test_server_receives_data_after_exception(self): @@ -1367,7 +1367,7 @@ def test_server_receives_data_after_exception(self): with self.assertRaises(RuntimeError) as raised: server.receive_data(b"\x00\x00") self.assertEqual( - str(raised.exception), "cannot receive data or EOF after an error" + str(raised.exception), "parser cannot receive data or EOF after an error" ) def test_client_receives_eof_after_exception(self): @@ -1378,7 +1378,7 @@ def test_client_receives_eof_after_exception(self): with self.assertRaises(RuntimeError) as raised: client.receive_eof() self.assertEqual( - str(raised.exception), "cannot receive data or EOF after an error" + str(raised.exception), "parser cannot receive data or EOF after an error" ) def test_server_receives_eof_after_exception(self): @@ -1389,7 +1389,7 @@ def test_server_receives_eof_after_exception(self): with self.assertRaises(RuntimeError) as raised: server.receive_eof() self.assertEqual( - str(raised.exception), "cannot receive data or EOF after an error" + str(raised.exception), "parser cannot receive data or EOF after an error" ) def test_client_receives_data_after_eof(self): diff --git a/tests/test_http11.py b/tests/test_http11.py index e73365cf0..afd85f64a 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -132,7 +132,9 @@ def setUp(self): def parse(self): return Response.parse( - self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, ) def test_parse(self): From 6c138c7f48cfb41151bad829f1bc5d41a7a95e90 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 20:30:07 +0200 Subject: [PATCH 0834/1539] Add protocol instance to logging context. --- docs/howto/logging.rst | 49 ++++++++++++++++++++++++++++--- docs/spelling_wordlist.txt | 1 + src/websockets/legacy/protocol.py | 4 ++- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index c83ec6906..5aaf31ee3 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -73,14 +73,55 @@ Here's how to enable debug logs for development:: level=logging.DEBUG, ) -You can select a different :class:`~logging.Logger` with the ``logger`` -argument:: +Furthermore, websockets adds a ``websocket`` attribute to every log record, so +you can include additional information about connections in logs. - import websockets +You could attempt to add information with a formatter:: + + # this doesn't work! + logging.basicConfig( + format="{asctime} {websocket.id} {websocket.remote_address[0]} {message}", + level=logging.INFO, + style="{", + ) + +However, this technique has two downsides: + +* The formatter applies to all records. It will crash if it receives a record + that doesn't have a ``websocket`` attribute. You could configure logging to + work around this problem but that could get complicated quickly. + +* Even with :meth:`str.format` style, you're restricted to attribute and index + lookups, which isn't enough to implement some fairly simple requirements. + +There's a better way. :func:`~server.serve` accepts a ``logger`` argument to +override the default :class:`~logging.Logger`. You can set ``logger`` to +a :class:`~logging.LoggerAdapter` that enriches logs. + +For example, if the server is behind a reverse proxy, ``remote_address`` gives +the IP address of the proxy, which isn't useful. IP addresses of clients are +generally available in a HTTP header set by the proxy. + +Here's how to include them in logs, assuming they're in the +``X-Forwarded-For`` header:: + + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.INFO, + ) + + class LoggerAdapter(logging.LoggerAdapter): + def process(self, msg, kwargs): + try: + websocket = kwargs["extra"]["websocket"] + except KeyError: + return msg, kwargs + xff = websocket.request_headers.get("X-Forwarded-For") + return f"{websocket.id} {xff} {msg}", kwargs async with websockets.serve( ..., - logger=logging.getLogger("interface.websocket"), + logger=LoggerAdapter(logging.getLogger("websockets.server")), ): ... diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 4fa44fbaa..b460ef033 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -25,6 +25,7 @@ daemonize datastructures django dyno +formatter fractalideas gunicorn haproxy diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index ec8abaed9..90683aaf9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -138,7 +138,9 @@ def __init__( # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger("websockets.protocol") - self.logger = logger + # https://github.com/python/typeshed/issues/5561 + logger = cast(logging.Logger, logger) + self.logger = logging.LoggerAdapter(logger, {"websocket": self}) # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) From e93f5e916291479c9b6be85609744e805681eb41 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 May 2021 07:57:04 +0200 Subject: [PATCH 0835/1539] Document structured logging. --- docs/howto/logging.rst | 40 +++++++++++++++++++++++++++++++++++ example/json_log_formatter.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 example/json_log_formatter.py diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index 5aaf31ee3..b780c0162 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -125,6 +125,46 @@ Here's how to include them in logs, assuming they're in the ): ... +Logging to JSON +--------------- + +Even though :mod:`logging` predates structured logging, it's still possible to +output logs as JSON with a bit of effort. + +First, we need a :class:`~logging.Formatter` that renders JSON: + +.. literalinclude:: ../../example/json_log_formatter.py + +Then, we configure logging to apply this formatter:: + + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + logger = logging.getLogger() + logger.addHandler(handler) + logger.setLevel(logging.INFO) + +Finally, we populate the ``event_data`` custom attribute in log records with +a :class:`~logging.LoggerAdapter`:: + + class LoggerAdapter(logging.LoggerAdapter): + def process(self, msg, kwargs): + try: + websocket = kwargs["extra"]["websocket"] + except KeyError: + return msg, kwargs + kwargs["extra"]["event_data"] = { + "connection_id": str(websocket.id), + "remote_addr": websocket.request_headers.get("X-Forwarded-For"), + } + return msg, kwargs + + async with websockets.serve( + ..., + logger=LoggerAdapter(logging.getLogger("websockets.server")), + ): + ... + Disable logging --------------- diff --git a/example/json_log_formatter.py b/example/json_log_formatter.py new file mode 100644 index 000000000..b8fc8d6dc --- /dev/null +++ b/example/json_log_formatter.py @@ -0,0 +1,33 @@ +import json +import logging +import datetime + +class JSONFormatter(logging.Formatter): + """ + Render logs as JSON. + + To add details to a log record, store them in a ``event_data`` + custom attribute. This dict is merged into the event. + + """ + def __init__(self): + pass # override logging.Formatter constructor + + def format(self, record): + event = { + "timestamp": self.getTimestamp(record.created), + "message": record.getMessage(), + "level": record.levelname, + "logger": record.name, + } + event_data = getattr(record, "event_data", None) + if event_data: + event.update(event_data) + if record.exc_info: + event["exc_info"] = self.formatException(record.exc_info) + if record.stack_info: + event["stack_info"] = self.formatStack(record.stack_info) + return json.dumps(event) + + def getTimestamp(self, created): + return datetime.datetime.utcfromtimestamp(created).isoformat() From a61ab26038e97e48b2f894c3e66e28bdfa3d8f3f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 19:02:58 +0200 Subject: [PATCH 0836/1539] Log when server starts and stops. --- docs/howto/logging.rst | 13 +++++++------ src/websockets/legacy/server.py | 27 +++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index b780c0162..f69ee47b9 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -73,8 +73,8 @@ Here's how to enable debug logs for development:: level=logging.DEBUG, ) -Furthermore, websockets adds a ``websocket`` attribute to every log record, so -you can include additional information about connections in logs. +Furthermore, websockets adds a ``websocket`` attribute to log records, so you +can include additional information about the current connection in logs. You could attempt to add information with a formatter:: @@ -85,11 +85,11 @@ You could attempt to add information with a formatter:: style="{", ) -However, this technique has two downsides: +However, this technique runs into two problems: * The formatter applies to all records. It will crash if it receives a record - that doesn't have a ``websocket`` attribute. You could configure logging to - work around this problem but that could get complicated quickly. + without a ``websocket`` attribute. For example, this happens when logging + that the server starts because there is no current connection. * Even with :meth:`str.format` style, you're restricted to attribute and index lookups, which isn't enough to implement some fairly simple requirements. @@ -208,7 +208,8 @@ Here's what websockets logs at each level. :attr:`~logging.INFO` ..................... -* Connections opened and closed in servers +* Server starting and stopping +* Server establishing and closing connections :attr:`~logging.DEBUG` ...................... diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 31ad0b18e..67fffff2b 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -725,10 +725,16 @@ class WebSocketServer: """ - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, loop: asyncio.AbstractEventLoop, logger: Optional[LoggerLike] = None + ) -> None: # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + # Keep track of active connections. self.websockets: Set[WebSocketServerProtocol] = set() @@ -753,6 +759,19 @@ def wrap(self, server: asyncio.AbstractServer) -> None: """ self.server = server + assert server.sockets is not None + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) def register(self, protocol: WebSocketServerProtocol) -> None: """ @@ -803,6 +822,8 @@ async def _close(self) -> None: then closes open connections with close code 1001. """ + self.logger.info("server closing") + # Stop accepting new connections. self.server.close() @@ -839,6 +860,8 @@ async def _close(self) -> None: # Tell wait_closed() to return. self.closed_waiter.set_result(None) + self.logger.info("server closed") + async def wait_closed(self) -> None: """ Wait until the server is closed. @@ -1008,7 +1031,7 @@ def __init__( else: warnings.warn("remove loop argument", DeprecationWarning) - ws_server = WebSocketServer(loop=loop) + ws_server = WebSocketServer(logger=logger, loop=loop) secure = kwargs.get("ssl") is not None From 848e0500a61e14180cebd8ad2eadaf9c9d3e4176 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Jun 2021 07:45:10 +0200 Subject: [PATCH 0837/1539] Simplify logging of ping/pong. Focus on keepalive ping/pongs, as they're the only ones that users run in trouble with. The readable logs should do the rest. Ref #765. --- src/websockets/legacy/protocol.py | 34 ++++--------------------------- 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 90683aaf9..6a6a00012 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -933,21 +933,12 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PING: # Answer pings. - ping_hex = frame.data.hex() or "[empty]" - if self.debug: - self.logger.debug("%% received ping, sending pong: %s", ping_hex) await self.pong(frame.data) elif frame.opcode == OP_PONG: - # Acknowledge pings on solicited pongs. if frame.data in self.pings: - if self.debug: - self.logger.debug( - "%% received solicited pong: %s", - frame.data.hex() or "[empty]", - ) - # Acknowledge all pings up to the one matching this pong. # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] for ping_id, ping in self.pings.items(): @@ -961,25 +952,6 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: # Remove acknowledged pings from self.pings. for ping_id in ping_ids: del self.pings[ping_id] - # Log previous pings acknowledged. - if self.debug: - ping_ids = ping_ids[:-1] - if ping_ids: - pings_hex = ", ".join( - ping_id.hex() or "[empty]" for ping_id in ping_ids - ) - plural = "s" if len(ping_ids) > 1 else "" - self.logger.debug( - "%% acknowledged previous ping%s: %s", - plural, - pings_hex, - ) - else: - if self.debug: - self.logger.debug( - "%% received unsolicited pong: %s", - frame.data.hex() or "[empty]", - ) # 5.6. Data Frames else: @@ -1078,6 +1050,7 @@ async def keepalive_ping(self) -> None: # ping() raises ConnectionClosed if the connection is lost, # when connection_lost() calls abort_pings(). + self.logger.debug("%% sending keepalive ping") pong_waiter = await self.ping() if self.ping_timeout is not None: @@ -1087,9 +1060,10 @@ async def keepalive_ping(self) -> None: self.ping_timeout, **loop_if_py_lt_38(self.loop), ) + self.logger.debug("%% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for pong") + self.logger.debug("! timed out waiting for keepalive pong") self.fail_connection(1011) break From 2883ca8a41bb4ce2a933e58a1fdbf958a97ccb84 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Jun 2021 08:57:50 +0200 Subject: [PATCH 0838/1539] Simplify connection shutdown on handshake failure. --- src/websockets/legacy/protocol.py | 48 +++++++++++++++++-------------- src/websockets/legacy/server.py | 3 +- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 6a6a00012..7a3570139 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1127,32 +1127,38 @@ async def close_connection(self) -> None: finally: # The try/finally ensures that the transport never remains open, # even if this coroutine is canceled (for example). + await self.close_transport() - # If connection_lost() was called, the TCP connection is closed. - # However, if TLS is enabled, the transport still needs closing. - # Else asyncio complains: ResourceWarning: unclosed transport. - if self.connection_lost_waiter.done() and self.transport.is_closing(): - return + async def close_transport(self) -> None: + """ + Close the TCP connection. - # Close the TCP connection. Buffers are flushed asynchronously. - if self.debug: - self.logger.debug("x closing TCP connection") - self.transport.close() + """ + # If connection_lost() was called, the TCP connection is closed. + # However, if TLS is enabled, the transport still needs closing. + # Else asyncio complains: ResourceWarning: unclosed transport. + if self.connection_lost_waiter.done() and self.transport.is_closing(): + return - if await self.wait_for_connection_lost(): - return - if self.debug: - self.logger.debug("! timed out waiting for TCP close") + # Close the TCP connection. Buffers are flushed asynchronously. + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() - # Abort the TCP connection. Buffers are discarded. - if self.debug: - self.logger.debug("x aborting TCP connection") - self.transport.abort() + if await self.wait_for_connection_lost(): + return + if self.debug: + self.logger.debug("! timed out waiting for TCP close") + + # Abort the TCP connection. Buffers are discarded. + if self.debug: + self.logger.debug("x aborting TCP connection") + self.transport.abort() - # connection_lost() is called quickly after aborting. - # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. - await self.wait_for_connection_lost() # pragma: no cover + # connection_lost() is called quickly after aborting. + # Coverage marks this line as a partially executed branch. + # I supect a bug in coverage. Ignore it for now. + await self.wait_for_connection_lost() # pragma: no cover async def wait_for_connection_lost(self) -> bool: """ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 67fffff2b..1704ae083 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -291,8 +291,7 @@ async def handler(self) -> None: self.logger.info( "connection failed (%d %s)", status.value, status.phrase ) - self.fail_connection() - await self.wait_closed() + await self.close_transport() return try: From 6722857ebefd365c9299a79e1526644e25fe1015 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 06:37:52 +0200 Subject: [PATCH 0839/1539] Make it possibly to bypass the handshake. This can facilitate integration in other projects that bring their own HTTP stack. --- src/websockets/client.py | 51 +++++++++++++++++++++------------------- src/websockets/server.py | 21 ++++++++++------- tests/test_client.py | 7 ++++++ tests/test_server.py | 7 ++++++ 4 files changed, 53 insertions(+), 33 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 738b6c998..698228d3a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ import collections from typing import Generator, List, Optional, Sequence -from .connection import CLIENT, CONNECTING, OPEN, Connection +from .connection import CLIENT, CONNECTING, OPEN, Connection, State from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( InvalidHandshake, @@ -50,12 +50,13 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + state: State = CONNECTING, max_size: Optional[int] = 2 ** 20, logger: Optional[LoggerLike] = None, ): super().__init__( side=CLIENT, - state=CONNECTING, + state=state, max_size=max_size, logger=logger, ) @@ -283,27 +284,29 @@ def send_request(self, request: Request) -> None: self.writes.append(request.serialize()) def parse(self) -> Generator[None, None, None]: - response = yield from Response.parse( - self.reader.read_line, - self.reader.read_exact, - self.reader.read_to_eof, - ) + if self.state is CONNECTING: + response = yield from Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + ) + + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("< HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + if response.body is not None: + self.logger.debug("< [body] (%d bytes)", len(response.body)) + + try: + self.process_response(response) + except InvalidHandshake as exc: + response = response._replace(exception=exc) + else: + assert self.state == CONNECTING + self.set_state(OPEN) + finally: + self.events.append(response) - if self.debug: - code, phrase = response.status_code, response.reason_phrase - self.logger.debug("< HTTP/1.1 %d %s", code, phrase) - for key, value in response.headers.raw_items(): - self.logger.debug("< %s: %s", key, value) - if response.body is not None: - self.logger.debug("< [body] (%d bytes)", len(response.body)) - - assert self.state == CONNECTING - try: - self.process_response(response) - except InvalidHandshake as exc: - response = response._replace(exception=exc) - else: - self.set_state(OPEN) - finally: - self.events.append(response) yield from super().parse() diff --git a/src/websockets/server.py b/src/websockets/server.py index 87aa07b51..1ce943891 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import http from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast -from .connection import CONNECTING, OPEN, SERVER, Connection +from .connection import CONNECTING, OPEN, SERVER, Connection, State from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( InvalidHandshake, @@ -56,12 +56,13 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, + state: State = CONNECTING, max_size: Optional[int] = 2 ** 20, logger: Optional[LoggerLike] = None, ): super().__init__( side=SERVER, - state=CONNECTING, + state=state, max_size=max_size, logger=logger, ) @@ -417,6 +418,7 @@ def send_response(self, response: Response) -> None: """ if response.status_code == 101: + assert self.state is CONNECTING self.set_state(OPEN) if self.debug: @@ -430,13 +432,14 @@ def send_response(self, response: Response) -> None: self.writes.append(response.serialize()) def parse(self) -> Generator[None, None, None]: - request = yield from Request.parse(self.reader.read_line) + if self.state == CONNECTING: + request = yield from Request.parse(self.reader.read_line) - if self.debug: - self.logger.debug("< GET %s HTTP/1.1", request.path) - for key, value in request.headers.raw_items(): - self.logger.debug("< %s: %s", key, value) + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + + self.events.append(request) - assert self.state == CONNECTING - self.events.append(request) yield from super().parse() diff --git a/tests/test_client.py b/tests/test_client.py index b96ebd272..840b7148f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,6 +6,7 @@ from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers from websockets.exceptions import InvalidHandshake, InvalidHeader +from websockets.frames import OP_TEXT, Frame from websockets.http import USER_AGENT from websockets.http11 import Request, Response from websockets.utils import accept_key @@ -572,6 +573,12 @@ def test_unsupported_subprotocol(self): class MiscTests(unittest.TestCase): + def test_bypass_handshake(self): + client = ClientConnection("ws://example.com/test", state=OPEN) + client.receive_data(b"\x81\x06Hello!") + [frame] = client.events_received() + self.assertEqual(frame, Frame(True, OP_TEXT, b"Hello!")) + def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: diff --git a/tests/test_server.py b/tests/test_server.py index 6a25f0d25..db34e6ba3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -6,6 +6,7 @@ from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade +from websockets.frames import OP_TEXT, Frame from websockets.http import USER_AGENT from websockets.http11 import Request, Response from websockets.server import * @@ -629,6 +630,12 @@ def test_extra_headers_overrides_server(self): class MiscTests(unittest.TestCase): + def test_bypass_handshake(self): + server = ServerConnection(state=OPEN) + server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") + [frame] = server.events_received() + self.assertEqual(frame, Frame(True, OP_TEXT, b"Hello!")) + def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: From 0cce70a8645abeac25cae59eac6bd0fb008afa38 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 09:26:10 +0200 Subject: [PATCH 0840/1539] Add error reason in close on ping timeout. Fix #636. --- src/websockets/legacy/protocol.py | 2 +- tests/legacy/test_protocol.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7a3570139..992d678e2 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1064,7 +1064,7 @@ async def keepalive_ping(self) -> None: except asyncio.TimeoutError: if self.debug: self.logger.debug("! timed out waiting for keepalive pong") - self.fail_connection(1011) + self.fail_connection(1011, "keepalive ping timeout") break # Remove this branch when dropping support for Python < 3.8 diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 58444ce5a..6f6e6f686 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1143,7 +1143,9 @@ def test_keepalive_ping_not_acknowledged_closes_connection(self): # Connection is closed at 6ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1011, "")) + self.assertOneFrameSent( + True, OP_CLOSE, serialize_close(1011, "keepalive ping timeout") + ) # The keepalive ping task is complete. self.assertEqual(self.protocol.keepalive_ping_task.result(), None) From f447a5699a2c6b85f6d6fbcba10982aa3e8d5afd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 07:12:51 +0200 Subject: [PATCH 0841/1539] Convert WebSocketURI to a dataclass. --- src/websockets/uri.py | 20 +++++--------------- tests/test_uri.py | 31 ++++++++++++++++++++++++------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index ce21b445b..958975b22 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -7,8 +7,9 @@ """ +import dataclasses import urllib.parse -from typing import NamedTuple, Optional, Tuple +from typing import Optional, Tuple from .exceptions import InvalidURI @@ -16,10 +17,8 @@ __all__ = ["parse_uri", "WebSocketURI"] -# Consider converting to a dataclass when dropping support for Python < 3.7. - - -class WebSocketURI(NamedTuple): +@dataclasses.dataclass +class WebSocketURI: """ WebSocket URI. @@ -37,16 +36,7 @@ class WebSocketURI(NamedTuple): host: str port: int resource_name: str - user_info: Optional[Tuple[str, str]] - - -# Work around https://bugs.python.org/issue19931 - -WebSocketURI.secure.__doc__ = "" -WebSocketURI.host.__doc__ = "" -WebSocketURI.port.__doc__ = "" -WebSocketURI.resource_name.__doc__ = "" -WebSocketURI.user_info.__doc__ = "" + user_info: Optional[Tuple[str, str]] = None # All characters from the gen-delims and sub-delims sets in RFC 3987. diff --git a/tests/test_uri.py b/tests/test_uri.py index 9eeb8431d..a91bcb083 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -5,15 +5,32 @@ VALID_URIS = [ - ("ws://localhost/", (False, "localhost", 80, "/", None)), - ("wss://localhost/", (True, "localhost", 443, "/", None)), - ("ws://localhost/path?query", (False, "localhost", 80, "/path?query", None)), - ("WS://LOCALHOST/PATH?QUERY", (False, "localhost", 80, "/PATH?QUERY", None)), - ("ws://user:pass@localhost/", (False, "localhost", 80, "/", ("user", "pass"))), - ("ws://høst/", (False, "xn--hst-0na", 80, "/", None)), + ( + "ws://localhost/", + WebSocketURI(False, "localhost", 80, "/", None), + ), + ( + "wss://localhost/", + WebSocketURI(True, "localhost", 443, "/", None), + ), + ( + "ws://localhost/path?query", + WebSocketURI(False, "localhost", 80, "/path?query", None), + ), + ( + "WS://LOCALHOST/PATH?QUERY", + WebSocketURI(False, "localhost", 80, "/PATH?QUERY", None), + ), + ( + "ws://user:pass@localhost/", + WebSocketURI(False, "localhost", 80, "/", ("user", "pass")), + ), + ("ws://høst/", WebSocketURI(False, "xn--hst-0na", 80, "/", None)), ( "ws://üser:påss@høst/πass", - (False, "xn--hst-0na", 80, "/%CF%80ass", ("%C3%BCser", "p%C3%A5ss")), + WebSocketURI( + False, "xn--hst-0na", 80, "/%CF%80ass", ("%C3%BCser", "p%C3%A5ss") + ), ), ] From 5dd1290f202a8274e68d48582b400013fb953010 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 07:33:21 +0200 Subject: [PATCH 0842/1539] Convert Request/Response to dataclasses. --- src/websockets/client.py | 2 +- src/websockets/http11.py | 18 +++++++++--------- src/websockets/server.py | 15 ++++++++------- tests/test_server.py | 32 ++++++++++++++++---------------- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 698228d3a..e9bc12cbe 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -302,7 +302,7 @@ def parse(self) -> Generator[None, None, None]: try: self.process_response(response) except InvalidHandshake as exc: - response = response._replace(exception=exc) + response.exception = exc else: assert self.state == CONNECTING self.set_state(OPEN) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 6f3cbccc4..22488dc89 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -1,5 +1,6 @@ +import dataclasses import re -from typing import Callable, Generator, NamedTuple, Optional +from typing import Callable, Generator, Optional from .datastructures import Headers from .exceptions import SecurityError @@ -37,10 +38,8 @@ def d(value: bytes) -> str: _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") -# Consider converting to dataclasses when dropping support for Python < 3.7. - - -class Request(NamedTuple): +@dataclasses.dataclass +class Request: """ WebSocket handshake request. @@ -52,6 +51,9 @@ class Request(NamedTuple): headers: Headers # body isn't useful is the context of this library + # If processing the request triggers an exception, it's stored here. + exception: Optional[Exception] = None + @classmethod def parse( cls, read_line: Callable[[], Generator[None, None, bytes]] @@ -121,10 +123,8 @@ def serialize(self) -> bytes: return request -# Consider converting to dataclasses when dropping support for Python < 3.7. - - -class Response(NamedTuple): +@dataclasses.dataclass +class Response: """ WebSocket handshake response. diff --git a/src/websockets/server.py b/src/websockets/server.py index 1ce943891..483ddbe1e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -79,19 +79,18 @@ def accept(self, request: Request) -> Response: connection, which may be unexpected. """ - # TODO: when changing Request to a dataclass, set the exception - # attribute on the request rather than the Response, which will - # be semantically more correct. try: key, extensions_header, protocol_header = self.process_request(request) except InvalidOrigin as exc: + request.exception = exc if self.debug: self.logger.debug("! invalid origin", exc_info=True) return self.reject( http.HTTPStatus.FORBIDDEN, f"Failed to open a WebSocket connection: {exc}.\n", - )._replace(exception=exc) + ) except InvalidUpgrade as exc: + request.exception = exc if self.debug: self.logger.debug("! invalid upgrade", exc_info=True) return self.reject( @@ -103,15 +102,17 @@ def accept(self, request: Request) -> Response: f"with a browser. You need a WebSocket client.\n" ), headers=Headers([("Upgrade", "websocket")]), - )._replace(exception=exc) + ) except InvalidHandshake as exc: + request.exception = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc}.\n", - )._replace(exception=exc) + ) except Exception as exc: + request.exception = exc self.logger.error("opening handshake failed", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, @@ -119,7 +120,7 @@ def accept(self, request: Request) -> Response: "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ), - )._replace(exception=exc) + ) headers = Headers() diff --git a/tests/test_server.py b/tests/test_server.py index db34e6ba3..86fa3f34d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -187,7 +187,7 @@ def test_unexpected_exception(self): self.assertEqual(response.status_code, 500) with self.assertRaises(Exception) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "BOOM") def test_missing_connection(self): @@ -199,7 +199,7 @@ def test_missing_connection(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): @@ -212,7 +212,7 @@ def test_invalid_connection(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): @@ -224,7 +224,7 @@ def test_missing_upgrade(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): @@ -237,7 +237,7 @@ def test_invalid_upgrade(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_key(self): @@ -248,7 +248,7 @@ def test_missing_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") def test_multiple_key(self): @@ -259,7 +259,7 @@ def test_multiple_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: " @@ -275,7 +275,7 @@ def test_invalid_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" ) @@ -291,7 +291,7 @@ def test_truncated_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" ) @@ -304,7 +304,7 @@ def test_missing_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") def test_multiple_version(self): @@ -315,7 +315,7 @@ def test_multiple_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: " @@ -331,7 +331,7 @@ def test_invalid_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: 11" ) @@ -343,7 +343,7 @@ def test_no_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Origin header") def test_origin(self): @@ -363,7 +363,7 @@ def test_unexpected_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Origin header: https://other.example.com" ) @@ -381,7 +381,7 @@ def test_multiple_origin(self): # 400 Bad Request rather than 403 Forbidden. self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Origin header: more than one Origin header found", @@ -408,7 +408,7 @@ def test_unsupported_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Origin header: https://original.example.com" ) From 634ee6a29327146d317564c88010310556c4c384 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 22:30:49 +0200 Subject: [PATCH 0843/1539] Wrap NewFrame instead of inheriting it. This ensures Frame remains a NamedTuple for backwards compatibility. --- src/websockets/legacy/framing.py | 46 +++++++++++++++++++++++++++----- tests/legacy/test_framing.py | 1 + 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index e947c9383..016677615 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -11,7 +11,7 @@ """ import struct -from typing import Any, Awaitable, Callable, Optional, Sequence +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence from ..exceptions import PayloadTooBig, ProtocolError from ..frames import Frame as NewFrame, Opcode @@ -23,7 +23,32 @@ from ..utils import apply_mask -class Frame(NewFrame): +class Frame(NamedTuple): + + fin: bool + opcode: Opcode + data: bytes + rsv1: bool = False + rsv2: bool = False + rsv3: bool = False + + @property + def new_frame(self) -> NewFrame: + return NewFrame( + self.fin, + self.opcode, + self.data, + self.rsv1, + self.rsv2, + self.rsv3, + ) + + def __str__(self) -> str: + return str(self.new_frame) + + def check(self) -> None: + return self.new_frame.check() + @classmethod async def read( cls, @@ -86,16 +111,23 @@ async def read( if mask: data = apply_mask(data, mask_bits) - frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) + new_frame = NewFrame(fin, opcode, data, rsv1, rsv2, rsv3) if extensions is None: extensions = [] for extension in reversed(extensions): - frame = cls(*extension.decode(frame, max_size=max_size)) + new_frame = extension.decode(new_frame, max_size=max_size) - frame.check() + new_frame.check() - return frame + return cls( + new_frame.fin, + new_frame.opcode, + new_frame.data, + new_frame.rsv1, + new_frame.rsv2, + new_frame.rsv3, + ) def write( self, @@ -121,7 +153,7 @@ def write( # The frame is written in a single call to write in order to prevent # TCP fragmentation. See #68 for details. This also makes it safe to # send frames concurrently from multiple coroutines. - write(self.serialize(mask=mask, extensions=extensions)) + write(self.new_frame.serialize(mask=mask, extensions=extensions)) # at the bottom to allow circular import, because Extension depends on Frame diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index ac870c79e..0e2670faa 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -43,6 +43,7 @@ def encode(self, frame, mask=False, extensions=None): def round_trip(self, message, expected, mask=False, extensions=None): decoded = self.decode(message, mask, extensions=extensions) + decoded.check() self.assertEqual(decoded, expected) encoded = self.encode(decoded, mask, extensions=extensions) if mask: # non-deterministic encoding From d500274f9340b4336d27665a48e9ce34f23c9910 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 09:16:21 +0200 Subject: [PATCH 0844/1539] Convert Frame to a dataclass. --- .../extensions/permessage_deflate.py | 9 +++--- src/websockets/frames.py | 9 +++--- tests/extensions/test_permessage_deflate.py | 29 +++++++++++++------ tests/extensions/utils.py | 6 ++-- tests/legacy/test_framing.py | 3 +- tests/test_frames.py | 3 +- 6 files changed, 37 insertions(+), 22 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 34cc1f950..56ef03e0c 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -4,6 +4,7 @@ """ +import dataclasses import zlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -115,7 +116,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: else: if not frame.rsv1: return frame - frame = frame._replace(rsv1=False) + frame = dataclasses.replace(frame, rsv1=False) if not frame.fin: self.decode_cont_data = True @@ -138,7 +139,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: if frame.fin and self.remote_no_context_takeover: del self.decoder - return frame._replace(data=data) + return dataclasses.replace(frame, data=data) def encode(self, frame: Frame) -> Frame: """ @@ -154,7 +155,7 @@ def encode(self, frame: Frame) -> Frame: if frame.opcode != OP_CONT: # Set the rsv1 flag on the first frame of a compressed message. - frame = frame._replace(rsv1=True) + frame = dataclasses.replace(frame, rsv1=True) # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( @@ -170,7 +171,7 @@ def encode(self, frame: Frame) -> Frame: if frame.fin and self.local_no_context_takeover: del self.encoder - return frame._replace(data=data) + return dataclasses.replace(frame, data=data) def _build_parameters( diff --git a/src/websockets/frames.py b/src/websockets/frames.py index de7fdb941..79de63e47 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -3,11 +3,12 @@ """ +import dataclasses import enum import io import secrets import struct -from typing import Callable, Generator, NamedTuple, Optional, Sequence, Tuple +from typing import Callable, Generator, Optional, Sequence, Tuple from .exceptions import PayloadTooBig, ProtocolError from .typing import Data @@ -92,10 +93,8 @@ class Opcode(enum.IntEnum): } -# Consider converting to a dataclass when dropping support for Python < 3.7. - - -class Frame(NamedTuple): +@dataclasses.dataclass +class Frame: """ WebSocket frame. diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index d3c6f0ac6..2af04a806 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,3 +1,4 @@ +import dataclasses import unittest import zlib @@ -82,7 +83,9 @@ def test_encode_decode_text_frame(self): enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"JNL;\xbc\x12\x00")) + self.assertEqual( + enc_frame, dataclasses.replace(frame, rsv1=True, data=b"JNL;\xbc\x12\x00") + ) dec_frame = self.extension.decode(enc_frame) @@ -93,7 +96,9 @@ def test_encode_decode_binary_frame(self): enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"*IM\x04\x00")) + self.assertEqual( + enc_frame, dataclasses.replace(frame, rsv1=True, data=b"*IM\x04\x00") + ) dec_frame = self.extension.decode(enc_frame) @@ -110,13 +115,15 @@ def test_encode_decode_fragmented_text_frame(self): self.assertEqual( enc_frame1, - frame1._replace(rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff"), + dataclasses.replace( + frame1, rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff" + ), ) self.assertEqual( - enc_frame2, frame2._replace(data=b"RPS\x00\x00\x00\x00\xff\xff") + enc_frame2, dataclasses.replace(frame2, data=b"RPS\x00\x00\x00\x00\xff\xff") ) self.assertEqual( - enc_frame3, frame3._replace(data=b"J.\xca\xcf,.N\xcc+)\x06\x00") + enc_frame3, dataclasses.replace(frame3, data=b"J.\xca\xcf,.N\xcc+)\x06\x00") ) dec_frame1 = self.extension.decode(enc_frame1) @@ -136,11 +143,13 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual( enc_frame1, - frame1._replace(rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff"), + dataclasses.replace( + frame1, rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff" + ), ) self.assertEqual( enc_frame2, - frame2._replace(data=b"*\xc9\xccM\x05\x00"), + dataclasses.replace(frame2, data=b"*\xc9\xccM\x05\x00"), ) dec_frame1 = self.extension.decode(enc_frame1) @@ -242,8 +251,10 @@ def test_compress_settings(self): self.assertEqual( enc_frame, - frame._replace( - rsv1=True, data=b"\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00" # not compressed + dataclasses.replace( + frame, + rsv1=True, + data=b"\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00", # not compressed ), ) diff --git a/tests/extensions/utils.py b/tests/extensions/utils.py index 81990bb07..1eabc163f 100644 --- a/tests/extensions/utils.py +++ b/tests/extensions/utils.py @@ -1,3 +1,5 @@ +import dataclasses + from websockets.exceptions import NegotiationError @@ -49,11 +51,11 @@ class Rsv2Extension: def decode(self, frame, *, max_size=None): assert frame.rsv2 - return frame._replace(rsv2=False) + return dataclasses.replace(frame, rsv2=False) def encode(self, frame): assert not frame.rsv2 - return frame._replace(rsv2=True) + return dataclasses.replace(frame, rsv2=True) def __eq__(self, other): return isinstance(other, Rsv2Extension) diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 0e2670faa..2baa827a9 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -1,5 +1,6 @@ import asyncio import codecs +import dataclasses import unittest import unittest.mock import warnings @@ -160,7 +161,7 @@ def encode(frame): assert frame.opcode == OP_TEXT text = frame.data.decode() data = codecs.encode(text, "rot13").encode() - return frame._replace(data=data) + return dataclasses.replace(frame, data=data) # This extensions is symmetrical. @staticmethod diff --git a/tests/test_frames.py b/tests/test_frames.py index 491386566..c05fa43a5 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -1,4 +1,5 @@ import codecs +import dataclasses import unittest import unittest.mock @@ -178,7 +179,7 @@ def encode(frame): assert frame.opcode == OP_TEXT text = frame.data.decode() data = codecs.encode(text, "rot13").encode() - return frame._replace(data=data) + return dataclasses.replace(frame, data=data) # This extensions is symmetrical. @staticmethod From ba4be454b6fd7f283d2ea8ec4354e2e6c9627267 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 22:20:48 +0200 Subject: [PATCH 0845/1539] Give FIN a default value of False. This makes the default case, non fragmented messages, convenient. Also it uniformises with rsv1/2/3. --- src/websockets/connection.py | 18 +- src/websockets/frames.py | 4 +- src/websockets/legacy/framing.py | 4 +- tests/extensions/test_permessage_deflate.py | 44 ++--- tests/test_client.py | 2 +- tests/test_connection.py | 186 ++++++++++---------- tests/test_frames.py | 64 +++---- tests/test_server.py | 2 +- 8 files changed, 161 insertions(+), 163 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 5052df3f7..dbdc15bb5 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -166,7 +166,7 @@ def send_continuation(self, data: bytes, fin: bool) -> None: if not self.expect_continuation_frame: raise ProtocolError("unexpected continuation frame") self.expect_continuation_frame = not fin - self.send_frame(Frame(fin, OP_CONT, data)) + self.send_frame(Frame(OP_CONT, data, fin)) def send_text(self, data: bytes, fin: bool = True) -> None: """ @@ -176,7 +176,7 @@ def send_text(self, data: bytes, fin: bool = True) -> None: if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") self.expect_continuation_frame = not fin - self.send_frame(Frame(fin, OP_TEXT, data)) + self.send_frame(Frame(OP_TEXT, data, fin)) def send_binary(self, data: bytes, fin: bool = True) -> None: """ @@ -186,7 +186,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") self.expect_continuation_frame = not fin - self.send_frame(Frame(fin, OP_BINARY, data)) + self.send_frame(Frame(OP_BINARY, data, fin)) def send_close(self, code: Optional[int] = None, reason: str = "") -> None: """ @@ -201,7 +201,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: data = b"" else: data = serialize_close(code, reason) - self.send_frame(Frame(True, OP_CLOSE, data)) + self.send_frame(Frame(OP_CLOSE, data)) # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started self.set_state(CLOSING) @@ -213,14 +213,14 @@ def send_ping(self, data: bytes) -> None: Send a ping frame. """ - self.send_frame(Frame(True, OP_PING, data)) + self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: """ Send a pong frame. """ - self.send_frame(Frame(True, OP_PONG, data)) + self.send_frame(Frame(OP_PONG, data)) # Public API for getting incoming events after receiving data. @@ -257,7 +257,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because of # an error reading from or writing to the network. if code != 1006 and self.state is OPEN: - self.send_frame(Frame(True, OP_CLOSE, serialize_close(code, reason))) + self.send_frame(Frame(OP_CLOSE, serialize_close(code, reason))) self.set_state(CLOSING) if not self.eof_sent: self.send_eof() @@ -365,7 +365,7 @@ def parse(self) -> Generator[None, None, None]: # send a Pong frame in response, unless it already received a # Close frame." if not self.close_frame_received: - pong_frame = Frame(True, OP_PONG, frame.data) + pong_frame = Frame(OP_PONG, frame.data) self.send_frame(pong_frame) elif frame.opcode is OP_PONG: @@ -391,7 +391,7 @@ def parse(self) -> Generator[None, None, None]: # serialize_close() because that fails when the close frame # is empty and parse_close() synthetizes a 1005 close code. # The rest is identical to send_close(). - self.send_frame(Frame(True, OP_CLOSE, frame.data)) + self.send_frame(Frame(OP_CLOSE, frame.data)) self.set_state(CLOSING) if self.side is SERVER: self.send_eof() diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 79de63e47..ec6ff2258 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -110,9 +110,9 @@ class Frame: """ - fin: bool opcode: Opcode data: bytes + fin: bool = True rsv1: bool = False rsv2: bool = False rsv3: bool = False @@ -224,7 +224,7 @@ def parse( if mask: data = apply_mask(data, mask_bytes) - frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) + frame = cls(opcode, data, fin, rsv1, rsv2, rsv3) if extensions is None: extensions = [] diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 016677615..901b2e2e3 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -35,9 +35,9 @@ class Frame(NamedTuple): @property def new_frame(self) -> NewFrame: return NewFrame( - self.fin, self.opcode, self.data, + self.fin, self.rsv1, self.rsv2, self.rsv3, @@ -111,7 +111,7 @@ async def read( if mask: data = apply_mask(data, mask_bits) - new_frame = NewFrame(fin, opcode, data, rsv1, rsv2, rsv3) + new_frame = NewFrame(opcode, data, fin, rsv1, rsv2, rsv3) if extensions is None: extensions = [] diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 2af04a806..5ba4d8ddf 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -56,21 +56,21 @@ def test_repr(self): # Control frames aren't encoded or decoded. def test_no_encode_decode_ping_frame(self): - frame = Frame(True, OP_PING, b"") + frame = Frame(OP_PING, b"") self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_pong_frame(self): - frame = Frame(True, OP_PONG, b"") + frame = Frame(OP_PONG, b"") self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_close_frame(self): - frame = Frame(True, OP_CLOSE, serialize_close(1000, "")) + frame = Frame(OP_CLOSE, serialize_close(1000, "")) self.assertEqual(self.extension.encode(frame), frame) @@ -79,7 +79,7 @@ def test_no_encode_decode_close_frame(self): # Data frames are encoded and decoded. def test_encode_decode_text_frame(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame = self.extension.encode(frame) @@ -92,7 +92,7 @@ def test_encode_decode_text_frame(self): self.assertEqual(dec_frame, frame) def test_encode_decode_binary_frame(self): - frame = Frame(True, OP_BINARY, b"tea") + frame = Frame(OP_BINARY, b"tea") enc_frame = self.extension.encode(frame) @@ -105,9 +105,9 @@ def test_encode_decode_binary_frame(self): self.assertEqual(dec_frame, frame) def test_encode_decode_fragmented_text_frame(self): - frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) - frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) - frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) + frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) @@ -135,8 +135,8 @@ def test_encode_decode_fragmented_text_frame(self): self.assertEqual(dec_frame3, frame3) def test_encode_decode_fragmented_binary_frame(self): - frame1 = Frame(False, OP_TEXT, b"tea ") - frame2 = Frame(True, OP_CONT, b"time") + frame1 = Frame(OP_TEXT, b"tea ", fin=False) + frame2 = Frame(OP_CONT, b"time") enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) @@ -159,21 +159,21 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_no_decode_text_frame(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_binary_frame(self): - frame = Frame(True, OP_TEXT, b"tea") + frame = Frame(OP_TEXT, b"tea") # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_fragmented_text_frame(self): - frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) - frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) - frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) + frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -184,8 +184,8 @@ def test_no_decode_fragmented_text_frame(self): self.assertEqual(dec_frame3, frame3) def test_no_decode_fragmented_binary_frame(self): - frame1 = Frame(False, OP_TEXT, b"tea ") - frame2 = Frame(True, OP_CONT, b"time") + frame1 = Frame(OP_TEXT, b"tea ", fin=False) + frame2 = Frame(OP_CONT, b"time") dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -194,7 +194,7 @@ def test_no_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_context_takeover(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -206,7 +206,7 @@ def test_remote_no_context_takeover(self): # No context takeover when decoding messages. self.extension = PerMessageDeflate(True, False, 15, 15) - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -225,7 +225,7 @@ def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. self.extension = PerMessageDeflate(True, True, 15, 15) - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -245,7 +245,7 @@ def test_compress_settings(self): # Configure an extension so that no compression actually occurs. extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame = extension.encode(frame) @@ -261,7 +261,7 @@ def test_compress_settings(self): # Frames aren't decoded beyond max_size. def test_decompress_max_size(self): - frame = Frame(True, OP_TEXT, ("a" * 20).encode("utf-8")) + frame = Frame(OP_TEXT, ("a" * 20).encode("utf-8")) enc_frame = self.extension.encode(frame) diff --git a/tests/test_client.py b/tests/test_client.py index 840b7148f..2ef1f6a95 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -577,7 +577,7 @@ def test_bypass_handshake(self): client = ClientConnection("ws://example.com/test", state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() - self.assertEqual(frame, Frame(True, OP_TEXT, b"Hello!")) + self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): logger = logging.getLogger("test") diff --git a/tests/test_connection.py b/tests/test_connection.py index 6203a1469..677881238 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -59,7 +59,6 @@ def assertConnectionClosing(self, connection, code=None, reason=""): """ close_frame = Frame( - True, OP_CLOSE, b"" if code is None else serialize_close(code, reason), ) @@ -76,7 +75,6 @@ def assertConnectionFailing(self, connection, code=None, reason=""): """ close_frame = Frame( - True, OP_CLOSE, b"" if code is None else serialize_close(code, reason), ) @@ -113,7 +111,7 @@ def test_client_receives_unmasked_frame(self): client.receive_data(self.unmasked_text_frame_date) self.assertFrameReceived( client, - Frame(True, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam"), ) def test_server_receives_masked_frame(self): @@ -121,7 +119,7 @@ def test_server_receives_masked_frame(self): server.receive_data(self.masked_text_frame_data) self.assertFrameReceived( server, - Frame(True, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam"), ) def test_client_receives_masked_frame(self): @@ -235,7 +233,7 @@ def test_client_receives_text(self): client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, - Frame(True, OP_TEXT, "😀".encode()), + Frame(OP_TEXT, "😀".encode()), ) def test_server_receives_text(self): @@ -243,7 +241,7 @@ def test_server_receives_text(self): server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, - Frame(True, OP_TEXT, "😀".encode()), + Frame(OP_TEXT, "😀".encode()), ) def test_client_receives_text_over_size_limit(self): @@ -265,7 +263,7 @@ def test_client_receives_text_without_size_limit(self): client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, - Frame(True, OP_TEXT, "😀".encode()), + Frame(OP_TEXT, "😀".encode()), ) def test_server_receives_text_without_size_limit(self): @@ -273,7 +271,7 @@ def test_server_receives_text_without_size_limit(self): server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, - Frame(True, OP_TEXT, "😀".encode()), + Frame(OP_TEXT, "😀".encode()), ) def test_client_sends_fragmented_text(self): @@ -304,17 +302,17 @@ def test_client_receives_fragmented_text(self): client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_CONT, "😀😀".encode()[2:6]), + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) client.receive_data(b"\x80\x02\x98\x80") self.assertFrameReceived( client, - Frame(True, OP_CONT, "😀".encode()[2:]), + Frame(OP_CONT, "😀".encode()[2:]), ) def test_server_receives_fragmented_text(self): @@ -322,17 +320,17 @@ def test_server_receives_fragmented_text(self): server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_CONT, "😀😀".encode()[2:6]), + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertFrameReceived( server, - Frame(True, OP_CONT, "😀".encode()[2:]), + Frame(OP_CONT, "😀".encode()[2:]), ) def test_client_receives_fragmented_text_over_size_limit(self): @@ -340,7 +338,7 @@ def test_client_receives_fragmented_text_over_size_limit(self): client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) with self.assertRaises(PayloadTooBig) as raised: client.receive_data(b"\x80\x02\x98\x80") @@ -352,7 +350,7 @@ def test_server_receives_fragmented_text_over_size_limit(self): server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) with self.assertRaises(PayloadTooBig) as raised: server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") @@ -364,17 +362,17 @@ def test_client_receives_fragmented_text_without_size_limit(self): client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_CONT, "😀😀".encode()[2:6]), + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) client.receive_data(b"\x80\x02\x98\x80") self.assertFrameReceived( client, - Frame(True, OP_CONT, "😀".encode()[2:]), + Frame(OP_CONT, "😀".encode()[2:]), ) def test_server_receives_fragmented_text_without_size_limit(self): @@ -382,17 +380,17 @@ def test_server_receives_fragmented_text_without_size_limit(self): server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_CONT, "😀😀".encode()[2:6]), + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertFrameReceived( server, - Frame(True, OP_CONT, "😀".encode()[2:]), + Frame(OP_CONT, "😀".encode()[2:]), ) def test_client_sends_unexpected_text(self): @@ -414,7 +412,7 @@ def test_client_receives_unexpected_text(self): client.receive_data(b"\x01\x00") self.assertFrameReceived( client, - Frame(False, OP_TEXT, b""), + Frame(OP_TEXT, b"", fin=False), ) with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\x01\x00") @@ -426,7 +424,7 @@ def test_server_receives_unexpected_text(self): server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertFrameReceived( server, - Frame(False, OP_TEXT, b""), + Frame(OP_TEXT, b"", fin=False), ) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\x01\x80\x00\x00\x00\x00") @@ -489,7 +487,7 @@ def test_client_receives_binary(self): client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertFrameReceived( client, - Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), + Frame(OP_BINARY, b"\x01\x02\xfe\xff"), ) def test_server_receives_binary(self): @@ -497,7 +495,7 @@ def test_server_receives_binary(self): server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertFrameReceived( server, - Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), + Frame(OP_BINARY, b"\x01\x02\xfe\xff"), ) def test_client_receives_binary_over_size_limit(self): @@ -542,17 +540,17 @@ def test_client_receives_fragmented_binary(self): client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, - Frame(False, OP_BINARY, b"\x01\x02"), + Frame(OP_BINARY, b"\x01\x02", fin=False), ) client.receive_data(b"\x00\x04\xfe\xff\x01\x02") self.assertFrameReceived( client, - Frame(False, OP_CONT, b"\xfe\xff\x01\x02"), + Frame(OP_CONT, b"\xfe\xff\x01\x02", fin=False), ) client.receive_data(b"\x80\x02\xfe\xff") self.assertFrameReceived( client, - Frame(True, OP_CONT, b"\xfe\xff"), + Frame(OP_CONT, b"\xfe\xff"), ) def test_server_receives_fragmented_binary(self): @@ -560,17 +558,17 @@ def test_server_receives_fragmented_binary(self): server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, - Frame(False, OP_BINARY, b"\x01\x02"), + Frame(OP_BINARY, b"\x01\x02", fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") self.assertFrameReceived( server, - Frame(False, OP_CONT, b"\xee\xff\x01\x02"), + Frame(OP_CONT, b"\xee\xff\x01\x02", fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertFrameReceived( server, - Frame(True, OP_CONT, b"\xfe\xff"), + Frame(OP_CONT, b"\xfe\xff"), ) def test_client_receives_fragmented_binary_over_size_limit(self): @@ -578,7 +576,7 @@ def test_client_receives_fragmented_binary_over_size_limit(self): client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, - Frame(False, OP_BINARY, b"\x01\x02"), + Frame(OP_BINARY, b"\x01\x02", fin=False), ) with self.assertRaises(PayloadTooBig) as raised: client.receive_data(b"\x80\x02\xfe\xff") @@ -590,7 +588,7 @@ def test_server_receives_fragmented_binary_over_size_limit(self): server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, - Frame(False, OP_BINARY, b"\x01\x02"), + Frame(OP_BINARY, b"\x01\x02", fin=False), ) with self.assertRaises(PayloadTooBig) as raised: server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") @@ -616,7 +614,7 @@ def test_client_receives_unexpected_binary(self): client.receive_data(b"\x02\x00") self.assertFrameReceived( client, - Frame(False, OP_BINARY, b""), + Frame(OP_BINARY, b"", fin=False), ) with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\x02\x00") @@ -628,7 +626,7 @@ def test_server_receives_unexpected_binary(self): server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertFrameReceived( server, - Frame(False, OP_BINARY, b""), + Frame(OP_BINARY, b"", fin=False), ) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\x02\x80\x00\x00\x00\x00") @@ -690,14 +688,14 @@ def test_client_receives_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): client.receive_data(b"\x88\x00") - self.assertEqual(client.events_received(), [Frame(True, OP_CLOSE, b"")]) + self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, State.CLOSING) def test_server_receives_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertEqual(server.events_received(), [Frame(True, OP_CLOSE, b"")]) + self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) self.assertIs(server.state, State.CLOSING) @@ -707,10 +705,10 @@ def test_client_sends_close_then_receives_close(self): client.send_close() self.assertFrameReceived(client, None) - self.assertFrameSent(client, Frame(True, OP_CLOSE, b"")) + self.assertFrameSent(client, Frame(OP_CLOSE, b"")) client.receive_data(b"\x88\x00") - self.assertFrameReceived(client, Frame(True, OP_CLOSE, b"")) + self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) self.assertFrameSent(client, None) client.receive_eof() @@ -723,10 +721,10 @@ def test_server_sends_close_then_receives_close(self): server.send_close() self.assertFrameReceived(server, None) - self.assertFrameSent(server, Frame(True, OP_CLOSE, b""), eof=True) + self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertFrameReceived(server, Frame(True, OP_CLOSE, b"")) + self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) self.assertFrameSent(server, None) server.receive_eof() @@ -738,8 +736,8 @@ def test_client_receives_close_then_sends_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x00") - self.assertFrameReceived(client, Frame(True, OP_CLOSE, b"")) - self.assertFrameSent(client, Frame(True, OP_CLOSE, b"")) + self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) + self.assertFrameSent(client, Frame(OP_CLOSE, b"")) client.receive_eof() self.assertFrameReceived(client, None) @@ -750,8 +748,8 @@ def test_server_receives_close_then_sends_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertFrameReceived(server, Frame(True, OP_CLOSE, b"")) - self.assertFrameSent(server, Frame(True, OP_CLOSE, b""), eof=True) + self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) + self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) server.receive_eof() self.assertFrameReceived(server, None) @@ -882,11 +880,11 @@ def test_client_receives_ping(self): client.receive_data(b"\x89\x00") self.assertFrameReceived( client, - Frame(True, OP_PING, b""), + Frame(OP_PING, b""), ) self.assertFrameSent( client, - Frame(True, OP_PONG, b""), + Frame(OP_PONG, b""), ) def test_server_receives_ping(self): @@ -894,11 +892,11 @@ def test_server_receives_ping(self): server.receive_data(b"\x89\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, - Frame(True, OP_PING, b""), + Frame(OP_PING, b""), ) self.assertFrameSent( server, - Frame(True, OP_PONG, b""), + Frame(OP_PONG, b""), ) def test_client_sends_ping_with_data(self): @@ -919,11 +917,11 @@ def test_client_receives_ping_with_data(self): client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, - Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent( client, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_ping_with_data(self): @@ -931,25 +929,25 @@ def test_server_receives_ping_with_data(self): server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, - Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent( server, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_client_sends_fragmented_ping_frame(self): client = Connection(Side.CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: - client.send_frame(Frame(False, OP_PING, b"")) + client.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_ping_frame(self): server = Connection(Side.SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: - server.send_frame(Frame(False, OP_PING, b"")) + server.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_ping_frame(self): @@ -1000,7 +998,7 @@ def test_client_receives_ping_after_receiving_close(self): client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, - Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent(client, None) @@ -1011,7 +1009,7 @@ def test_server_receives_ping_after_receiving_close(self): server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, - Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent(server, None) @@ -1038,7 +1036,7 @@ def test_client_receives_pong(self): client.receive_data(b"\x8a\x00") self.assertFrameReceived( client, - Frame(True, OP_PONG, b""), + Frame(OP_PONG, b""), ) def test_server_receives_pong(self): @@ -1046,7 +1044,7 @@ def test_server_receives_pong(self): server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, - Frame(True, OP_PONG, b""), + Frame(OP_PONG, b""), ) def test_client_sends_pong_with_data(self): @@ -1067,7 +1065,7 @@ def test_client_receives_pong_with_data(self): client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_pong_with_data(self): @@ -1075,21 +1073,21 @@ def test_server_receives_pong_with_data(self): server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_client_sends_fragmented_pong_frame(self): client = Connection(Side.CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: - client.send_frame(Frame(False, OP_PONG, b"")) + client.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_pong_frame(self): server = Connection(Side.SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: - server.send_frame(Frame(False, OP_PONG, b"")) + server.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_pong_frame(self): @@ -1130,7 +1128,7 @@ def test_client_receives_pong_after_receiving_close(self): client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_pong_after_receiving_close(self): @@ -1140,7 +1138,7 @@ def test_server_receives_pong_after_receiving_close(self): server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) @@ -1155,59 +1153,59 @@ class FragmentationTests(ConnectionTestCase): def test_client_send_ping_pong_in_fragmented_message(self): client = Connection(Side.CLIENT) client.send_text(b"Spam", fin=False) - self.assertFrameSent(client, Frame(False, OP_TEXT, b"Spam")) + self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) client.send_ping(b"Ping") - self.assertFrameSent(client, Frame(True, OP_PING, b"Ping")) + self.assertFrameSent(client, Frame(OP_PING, b"Ping")) client.send_continuation(b"Ham", fin=False) - self.assertFrameSent(client, Frame(False, OP_CONT, b"Ham")) + self.assertFrameSent(client, Frame(OP_CONT, b"Ham", fin=False)) client.send_pong(b"Pong") - self.assertFrameSent(client, Frame(True, OP_PONG, b"Pong")) + self.assertFrameSent(client, Frame(OP_PONG, b"Pong")) client.send_continuation(b"Eggs", fin=True) - self.assertFrameSent(client, Frame(True, OP_CONT, b"Eggs")) + self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) def test_server_send_ping_pong_in_fragmented_message(self): server = Connection(Side.SERVER) server.send_text(b"Spam", fin=False) - self.assertFrameSent(server, Frame(False, OP_TEXT, b"Spam")) + self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) server.send_ping(b"Ping") - self.assertFrameSent(server, Frame(True, OP_PING, b"Ping")) + self.assertFrameSent(server, Frame(OP_PING, b"Ping")) server.send_continuation(b"Ham", fin=False) - self.assertFrameSent(server, Frame(False, OP_CONT, b"Ham")) + self.assertFrameSent(server, Frame(OP_CONT, b"Ham", fin=False)) server.send_pong(b"Pong") - self.assertFrameSent(server, Frame(True, OP_PONG, b"Pong")) + self.assertFrameSent(server, Frame(OP_PONG, b"Pong")) server.send_continuation(b"Eggs", fin=True) - self.assertFrameSent(server, Frame(True, OP_CONT, b"Eggs")) + self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) def test_client_receive_ping_pong_in_fragmented_message(self): client = Connection(Side.CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, - Frame(False, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam", fin=False), ) client.receive_data(b"\x89\x04Ping") self.assertFrameReceived( client, - Frame(True, OP_PING, b"Ping"), + Frame(OP_PING, b"Ping"), ) self.assertFrameSent( client, - Frame(True, OP_PONG, b"Ping"), + Frame(OP_PONG, b"Ping"), ) client.receive_data(b"\x00\x03Ham") self.assertFrameReceived( client, - Frame(False, OP_CONT, b"Ham"), + Frame(OP_CONT, b"Ham", fin=False), ) client.receive_data(b"\x8a\x04Pong") self.assertFrameReceived( client, - Frame(True, OP_PONG, b"Pong"), + Frame(OP_PONG, b"Pong"), ) client.receive_data(b"\x80\x04Eggs") self.assertFrameReceived( client, - Frame(True, OP_CONT, b"Eggs"), + Frame(OP_CONT, b"Eggs"), ) def test_server_receive_ping_pong_in_fragmented_message(self): @@ -1215,37 +1213,37 @@ def test_server_receive_ping_pong_in_fragmented_message(self): server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, - Frame(False, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam", fin=False), ) server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") self.assertFrameReceived( server, - Frame(True, OP_PING, b"Ping"), + Frame(OP_PING, b"Ping"), ) self.assertFrameSent( server, - Frame(True, OP_PONG, b"Ping"), + Frame(OP_PONG, b"Ping"), ) server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") self.assertFrameReceived( server, - Frame(False, OP_CONT, b"Ham"), + Frame(OP_CONT, b"Ham", fin=False), ) server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") self.assertFrameReceived( server, - Frame(True, OP_PONG, b"Pong"), + Frame(OP_PONG, b"Pong"), ) server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") self.assertFrameReceived( server, - Frame(True, OP_CONT, b"Eggs"), + Frame(OP_CONT, b"Eggs"), ) def test_client_send_close_in_fragmented_message(self): client = Connection(Side.CLIENT) client.send_text(b"Spam", fin=False) - self.assertFrameSent(client, Frame(False, OP_TEXT, b"Spam")) + self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the # endpoint must not send a data frame after a close frame, a close @@ -1258,7 +1256,7 @@ def test_client_send_close_in_fragmented_message(self): def test_server_send_close_in_fragmented_message(self): server = Connection(Side.CLIENT) server.send_text(b"Spam", fin=False) - self.assertFrameSent(server, Frame(False, OP_TEXT, b"Spam")) + self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the # endpoint must not send a data frame after a close frame, a close @@ -1272,7 +1270,7 @@ def test_client_receive_close_in_fragmented_message(self): client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, - Frame(False, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam", fin=False), ) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the @@ -1288,7 +1286,7 @@ def test_server_receive_close_in_fragmented_message(self): server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, - Frame(False, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam", fin=False), ) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the @@ -1477,10 +1475,10 @@ def test_client_extension_decodes_frame(self): client = Connection(Side.CLIENT) client.extensions = [Rsv2Extension()] client.receive_data(b"\xaa\x00") - self.assertEqual(client.events_received(), [Frame(True, OP_PONG, b"")]) + self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) def test_server_extension_decodes_frame(self): server = Connection(Side.SERVER) server.extensions = [Rsv2Extension()] server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") - self.assertEqual(server.events_received(), [Frame(True, OP_PONG, b"")]) + self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) diff --git a/tests/test_frames.py b/tests/test_frames.py index c05fa43a5..d36344639 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -48,77 +48,77 @@ def assertFrameData(self, frame, data, mask, extensions=None): class FrameTests(FramesTestCase): def test_text_unmasked(self): self.assertFrameData( - Frame(True, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam"), b"\x81\x04Spam", mask=False, ) def test_text_masked(self): self.assertFrameData( - Frame(True, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam"), b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", mask=True, ) def test_binary_unmasked(self): self.assertFrameData( - Frame(True, OP_BINARY, b"Eggs"), + Frame(OP_BINARY, b"Eggs"), b"\x82\x04Eggs", mask=False, ) def test_binary_masked(self): self.assertFrameData( - Frame(True, OP_BINARY, b"Eggs"), + Frame(OP_BINARY, b"Eggs"), b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", mask=True, ) def test_non_ascii_text_unmasked(self): self.assertFrameData( - Frame(True, OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode("utf-8")), b"\x81\x05caf\xc3\xa9", mask=False, ) def test_non_ascii_text_masked(self): self.assertFrameData( - Frame(True, OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode("utf-8")), b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", mask=True, ) def test_close(self): self.assertFrameData( - Frame(True, OP_CLOSE, b""), + Frame(OP_CLOSE, b""), b"\x88\x00", mask=False, ) def test_ping(self): self.assertFrameData( - Frame(True, OP_PING, b"ping"), + Frame(OP_PING, b"ping"), b"\x89\x04ping", mask=False, ) def test_pong(self): self.assertFrameData( - Frame(True, OP_PONG, b"pong"), + Frame(OP_PONG, b"pong"), b"\x8a\x04pong", mask=False, ) def test_long(self): self.assertFrameData( - Frame(True, OP_BINARY, 126 * b"a"), + Frame(OP_BINARY, 126 * b"a"), b"\x82\x7e\x00\x7e" + 126 * b"a", mask=False, ) def test_very_long(self): self.assertFrameData( - Frame(True, OP_BINARY, 65536 * b"a"), + Frame(OP_BINARY, 65536 * b"a"), b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", mask=False, ) @@ -187,7 +187,7 @@ def decode(frame, *, max_size=None): return Rot13.encode(frame) self.assertFrameData( - Frame(True, OP_TEXT, b"hello"), + Frame(OP_TEXT, b"hello"), b"\x81\x05uryyb", mask=False, extensions=[Rot13()], @@ -197,125 +197,125 @@ def decode(frame, *, max_size=None): class StrTests(unittest.TestCase): def test_cont_text(self): self.assertEqual( - str(Frame(False, OP_CONT, b" cr\xc3\xa8me")), + str(Frame(OP_CONT, b" cr\xc3\xa8me", fin=False)), "CONT crème [text, 7 bytes, continued]", ) def test_cont_binary(self): self.assertEqual( - str(Frame(False, OP_CONT, b"\xfc\xfd\xfe\xff")), + str(Frame(OP_CONT, b"\xfc\xfd\xfe\xff", fin=False)), "CONT fc fd fe ff [binary, 4 bytes, continued]", ) def test_cont_final_text(self): self.assertEqual( - str(Frame(True, OP_CONT, b" cr\xc3\xa8me")), + str(Frame(OP_CONT, b" cr\xc3\xa8me")), "CONT crème [text, 7 bytes]", ) def test_cont_final_binary(self): self.assertEqual( - str(Frame(True, OP_CONT, b"\xfc\xfd\xfe\xff")), + str(Frame(OP_CONT, b"\xfc\xfd\xfe\xff")), "CONT fc fd fe ff [binary, 4 bytes]", ) def test_cont_text_truncated(self): self.assertEqual( - str(Frame(False, OP_CONT, b"caf\xc3\xa9 " * 16)), + str(Frame(OP_CONT, b"caf\xc3\xa9 " * 16, fin=False)), "CONT café café café café café café café café café caf..." "afé café café café café [text, 96 bytes, continued]", ) def test_cont_binary_truncated(self): self.assertEqual( - str(Frame(False, OP_CONT, bytes(range(256)))), + str(Frame(OP_CONT, bytes(range(256)), fin=False)), "CONT 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", ) def test_text(self): self.assertEqual( - str(Frame(True, OP_TEXT, b"caf\xc3\xa9")), + str(Frame(OP_TEXT, b"caf\xc3\xa9")), "TEXT café [5 bytes]", ) def test_text_non_final(self): self.assertEqual( - str(Frame(False, OP_TEXT, b"caf\xc3\xa9")), + str(Frame(OP_TEXT, b"caf\xc3\xa9", fin=False)), "TEXT café [5 bytes, continued]", ) def test_text_truncated(self): self.assertEqual( - str(Frame(True, OP_TEXT, b"caf\xc3\xa9 " * 16)), + str(Frame(OP_TEXT, b"caf\xc3\xa9 " * 16)), "TEXT café café café café café café café café café caf..." "afé café café café café [96 bytes]", ) def test_binary(self): self.assertEqual( - str(Frame(True, OP_BINARY, b"\x00\x01\x02\x03")), + str(Frame(OP_BINARY, b"\x00\x01\x02\x03")), "BINARY 00 01 02 03 [4 bytes]", ) def test_binary_non_final(self): self.assertEqual( - str(Frame(False, OP_BINARY, b"\x00\x01\x02\x03")), + str(Frame(OP_BINARY, b"\x00\x01\x02\x03", fin=False)), "BINARY 00 01 02 03 [4 bytes, continued]", ) def test_binary_truncated(self): self.assertEqual( - str(Frame(True, OP_BINARY, bytes(range(256)))), + str(Frame(OP_BINARY, bytes(range(256)))), "BINARY 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." " f8 f9 fa fb fc fd fe ff [256 bytes]", ) def test_close(self): self.assertEqual( - str(Frame(True, OP_CLOSE, b"\x03\xe8")), + str(Frame(OP_CLOSE, b"\x03\xe8")), "CLOSE code = 1000 (OK), no reason [2 bytes]", ) def test_close_reason(self): self.assertEqual( - str(Frame(True, OP_CLOSE, b"\x03\xe9Bye!")), + str(Frame(OP_CLOSE, b"\x03\xe9Bye!")), "CLOSE code = 1001 (going away), reason = Bye! [6 bytes]", ) def test_ping(self): self.assertEqual( - str(Frame(True, OP_PING, b"")), + str(Frame(OP_PING, b"")), "PING [0 bytes]", ) def test_ping_text(self): self.assertEqual( - str(Frame(True, OP_PING, b"ping")), + str(Frame(OP_PING, b"ping")), "PING ping [text, 4 bytes]", ) def test_ping_binary(self): self.assertEqual( - str(Frame(True, OP_PING, b"\xff\x00\xff\x00")), + str(Frame(OP_PING, b"\xff\x00\xff\x00")), "PING ff 00 ff 00 [binary, 4 bytes]", ) def test_pong(self): self.assertEqual( - str(Frame(True, OP_PONG, b"")), + str(Frame(OP_PONG, b"")), "PONG [0 bytes]", ) def test_pong_text(self): self.assertEqual( - str(Frame(True, OP_PONG, b"pong")), + str(Frame(OP_PONG, b"pong")), "PONG pong [text, 4 bytes]", ) def test_pong_binary(self): self.assertEqual( - str(Frame(True, OP_PONG, b"\xff\x00\xff\x00")), + str(Frame(OP_PONG, b"\xff\x00\xff\x00")), "PONG ff 00 ff 00 [binary, 4 bytes]", ) diff --git a/tests/test_server.py b/tests/test_server.py index 86fa3f34d..d2c41598e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -634,7 +634,7 @@ def test_bypass_handshake(self): server = ServerConnection(state=OPEN) server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") [frame] = server.events_received() - self.assertEqual(frame, Frame(True, OP_TEXT, b"Hello!")) + self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): logger = logging.getLogger("test") From d5ccd3658b29fea1b57b869d7b9911efa333ff8d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 21:04:19 +0200 Subject: [PATCH 0846/1539] Fix docstring after refactoring. --- src/websockets/frames.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index ec6ff2258..440f5a93b 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -106,7 +106,7 @@ class Frame: :param bytes data: payload data Only these fields are needed. The MASK bit, payload length and masking-key - are handled on the fly by :func:`parse_frame` and :meth:`serialize_frame`. + are handled on the fly by :meth:`parse` and :meth:`serialize`. """ From d3d1d0550c62ac68edb09dc6f6a11b4715148251 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 07:39:01 +0200 Subject: [PATCH 0847/1539] Store response headers in InvalidStatusCode. Also introduce InvalidStatus for the new implementation, storing the entire response (but not for the legacy implementationn). Fix #712. --- docs/project/changelog.rst | 2 +- src/websockets/__init__.py | 2 ++ src/websockets/client.py | 4 ++-- src/websockets/exceptions.py | 25 ++++++++++++++++++++++--- src/websockets/legacy/client.py | 2 +- tests/test_exceptions.py | 2 +- 6 files changed, 29 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index caf637614..2a0199147 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -548,7 +548,7 @@ Also: * Added an optional C extension to speed up low-level operations. * An invalid response status code during :func:`~legacy.client.connect` now - raises :class:`~exceptions.InvalidStatusCode` with a ``code`` attribute. + raises :class:`~exceptions.InvalidStatusCode`. * Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer crashes. diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index f136a4e45..8c69a6d63 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -24,6 +24,7 @@ "InvalidParameterName", "InvalidParameterValue", "InvalidState", + "InvalidStatus", "InvalidStatusCode", "InvalidUpgrade", "InvalidURI", @@ -73,6 +74,7 @@ "InvalidHeaderValue": ".exceptions", "InvalidOrigin": ".exceptions", "InvalidUpgrade": ".exceptions", + "InvalidStatus": ".exceptions", "InvalidStatusCode": ".exceptions", "NegotiationError": ".exceptions", "DuplicateParameter": ".exceptions", diff --git a/src/websockets/client.py b/src/websockets/client.py index e9bc12cbe..a517140cd 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -7,7 +7,7 @@ InvalidHandshake, InvalidHeader, InvalidHeaderValue, - InvalidStatusCode, + InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -127,7 +127,7 @@ def process_response(self, response: Response) -> None: """ if response.status_code != 101: - raise InvalidStatusCode(response.status_code) + raise InvalidStatus(response) headers = response.headers diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 4bd2a41a6..644735258 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -13,7 +13,8 @@ * :exc:`InvalidHeaderValue` * :exc:`InvalidOrigin` * :exc:`InvalidUpgrade` - * :exc:`InvalidStatusCode` + * :exc:`InvalidStatus` + * :exc:`InvalidStatusCode` (legacy) * :exc:`NegotiationError` * :exc:`DuplicateParameter` * :exc:`InvalidParameterName` @@ -46,6 +47,7 @@ "InvalidHeaderValue", "InvalidOrigin", "InvalidUpgrade", + "InvalidStatus", "InvalidStatusCode", "NegotiationError", "DuplicateParameter", @@ -190,16 +192,30 @@ class InvalidUpgrade(InvalidHeader): """ +class InvalidStatus(InvalidHandshake): + """ + Raised when a handshake response rejects the WebSocket upgrade. + + """ + + def __init__(self, response: "Response") -> None: + self.response = response + message = f"server rejected WebSocket connection: HTTP {response.status_code:d}" + super().__init__(message) + + class InvalidStatusCode(InvalidHandshake): """ Raised when a handshake response status code is invalid. - The integer status code is available in the ``status_code`` attribute. + The integer status code is available in the ``status_code`` attribute and + HTTP headers in the ``headers`` attribute. """ - def __init__(self, status_code: int) -> None: + def __init__(self, status_code: int, headers: Headers) -> None: self.status_code = status_code + self.headers = headers message = f"server rejected WebSocket connection: HTTP {status_code}" super().__init__(message) @@ -333,3 +349,6 @@ class ProtocolError(WebSocketException): # at the bottom to allow circular import, because the frames module imports exceptions from .frames import format_close # noqa + +# at the bottom to allow circular import, because the http11 module imports exceptions +from .http11 import Response # noqa diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 5c281d3b8..68ab9acf1 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -391,7 +391,7 @@ async def handshake( raise InvalidHeader("Location") raise RedirectHandshake(response_headers["Location"]) elif status_code != 101: - raise InvalidStatusCode(status_code) + raise InvalidStatusCode(status_code, response_headers) check_response(response_headers, key) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b800d4f91..094fb6d33 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -88,7 +88,7 @@ def test_str(self): "invalid Connection header: websocket", ), ( - InvalidStatusCode(403), + InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", ), ( From 0c157d457b871a3055329a50ac30cef640787b67 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 08:21:40 +0200 Subject: [PATCH 0848/1539] Add timeout to connect. Fix #574. --- docs/project/changelog.rst | 10 ++++++++++ src/websockets/legacy/client.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 2a0199147..603466593 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,14 @@ They may change at any time. **Version 10.0 drops compatibility with Python 3.6.** +.. note:: + + **Version 10.0 enables a timeout of 10 seconds on** + :func:`~legacy.client.connect` **by default.** + + You can adjust the timeout with the ``open_timeout`` parameter. Set it to + ``None`` to disable the timeout entirely. + .. note:: **Version 10.0 deprecates the** ``loop`` **parameter from all APIs.** @@ -43,6 +51,8 @@ They may change at any time. * Added compatibility with Python 3.10. +* Added ``open_timeout`` to :func:`~legacy.client.connect`. + * Improved logging. * Optimized default compression settings to reduce memory usage. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 68ab9acf1..df1a2f57c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -439,6 +439,10 @@ class Connect: be replaced by a wrapper or a subclass to customize the protocol that manages the connection. + If the WebSocket connection isn't established within ``open_timeout`` + seconds, :func:`connect` raises :exc:`~asyncio.TimeoutError`. The default + is 10 seconds. Set ``open_timeout`` to ``None`` to disable the timeout. + The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is described in :class:`WebSocketClientProtocol`. @@ -471,6 +475,7 @@ def __init__( uri: str, *, create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None, + open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, @@ -571,6 +576,7 @@ def __init__( loop.create_connection, factory, host, port, **kwargs ) + self.open_timeout = open_timeout # This is a coroutine function. self._create_connection = create_connection self._wsuri = wsuri @@ -626,7 +632,10 @@ async def __aexit__( def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. - return self.__await_impl__().__await__() + return self.__await_impl_timeout__().__await__() + + async def __await_impl_timeout__(self) -> WebSocketClientProtocol: + return await asyncio.wait_for(self.__await_impl__(), self.open_timeout) async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): From 0779eb973d2779c33b560488212ea8b652601596 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 4 Jun 2021 23:08:11 +0200 Subject: [PATCH 0849/1539] Prevent crash of data transfer task When the read buffer isn't empty, the following scenario is possible: * TCP connection is closed (one way or another) * transfer_data is still processing buffered data * it encounters a ping and attempts to send a pong transfer_data must never crash, so ignore the error in this case. Fix #977 --- src/websockets/legacy/protocol.py | 9 ++++++--- tests/legacy/test_protocol.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 992d678e2..353317ab5 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -926,14 +926,17 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: # is empty and parse_close() synthetizes a 1005 close code. await self.write_close_frame(frame.data) except ConnectionClosed: - # It doesn't really matter if the connection was closed - # before we could send back a close frame. + # Connection closed before we could echo the close frame. pass return None elif frame.opcode == OP_PING: # Answer pings. - await self.pong(frame.data) + try: + await self.pong(frame.data) + except ConnectionClosed: + # Connection closed before we could respond to the ping. + pass elif frame.opcode == OP_PONG: if frame.data in self.pings: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 6f6e6f686..c1948c838 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -877,6 +877,16 @@ def test_answer_ping(self): self.run_loop_once() self.assertOneFrameSent(True, OP_PONG, b"test") + def test_answer_ping_does_not_crash_if_connection_closed(self): + self.make_drain_slow() + # Drop the connection right after receiving a ping frame, + # which prevents responding wwith a pong frame properly. + self.receive_frame(Frame(True, OP_PING, b"test")) + self.receive_eof() + + with self.assertNoLogs(): + self.loop.run_until_complete(self.protocol.close()) + def test_ignore_pong(self): self.receive_frame(Frame(True, OP_PONG, b"test")) self.run_loop_once() From ebc8448b6603977c4e81e182e9827705c9a5889c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 13:14:04 +0200 Subject: [PATCH 0850/1539] Improve representation of frames. Specifically this avoids cutting log messages in two lines. Ref #765. --- src/websockets/frames.py | 6 +++--- tests/test_frames.py | 42 ++++++++++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 440f5a93b..9011ac867 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -129,7 +129,7 @@ def __str__(self) -> str: if self.opcode is OP_TEXT: # Decoding only the beginning and the end is needlessly hard. # Decode the entire payload then elide later if necessary. - data = self.data.decode() + data = repr(self.data.decode()) elif self.opcode is OP_BINARY: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. @@ -145,7 +145,7 @@ def __str__(self) -> str: # Ping and Pong frames could contain UTF-8. Attempt to decode as # UTF-8 and display it as text; fallback to binary. try: - data = self.data.decode() + data = repr(self.data.decode()) coding = "text" except UnicodeDecodeError: binary = self.data @@ -154,7 +154,7 @@ def __str__(self) -> str: data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: - data = "" + data = "''" if len(data) > 75: data = data[:48] + "..." + data[-24:] diff --git a/tests/test_frames.py b/tests/test_frames.py index d36344639..85414e9d3 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -198,7 +198,7 @@ class StrTests(unittest.TestCase): def test_cont_text(self): self.assertEqual( str(Frame(OP_CONT, b" cr\xc3\xa8me", fin=False)), - "CONT crème [text, 7 bytes, continued]", + "CONT ' crème' [text, 7 bytes, continued]", ) def test_cont_binary(self): @@ -210,7 +210,7 @@ def test_cont_binary(self): def test_cont_final_text(self): self.assertEqual( str(Frame(OP_CONT, b" cr\xc3\xa8me")), - "CONT crème [text, 7 bytes]", + "CONT ' crème' [text, 7 bytes]", ) def test_cont_final_binary(self): @@ -222,8 +222,8 @@ def test_cont_final_binary(self): def test_cont_text_truncated(self): self.assertEqual( str(Frame(OP_CONT, b"caf\xc3\xa9 " * 16, fin=False)), - "CONT café café café café café café café café café caf..." - "afé café café café café [text, 96 bytes, continued]", + "CONT 'café café café café café café café café café ca..." + "fé café café café café ' [text, 96 bytes, continued]", ) def test_cont_binary_truncated(self): @@ -236,20 +236,26 @@ def test_cont_binary_truncated(self): def test_text(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9")), - "TEXT café [5 bytes]", + "TEXT 'café' [5 bytes]", ) def test_text_non_final(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9", fin=False)), - "TEXT café [5 bytes, continued]", + "TEXT 'café' [5 bytes, continued]", ) def test_text_truncated(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9 " * 16)), - "TEXT café café café café café café café café café caf..." - "afé café café café café [96 bytes]", + "TEXT 'café café café café café café café café café ca..." + "fé café café café café ' [96 bytes]", + ) + + def test_text_with_newline(self): + self.assertEqual( + str(Frame(OP_TEXT, b"Hello\nworld!")), + "TEXT 'Hello\\nworld!' [12 bytes]", ) def test_binary(self): @@ -286,13 +292,19 @@ def test_close_reason(self): def test_ping(self): self.assertEqual( str(Frame(OP_PING, b"")), - "PING [0 bytes]", + "PING '' [0 bytes]", ) def test_ping_text(self): self.assertEqual( str(Frame(OP_PING, b"ping")), - "PING ping [text, 4 bytes]", + "PING 'ping' [text, 4 bytes]", + ) + + def test_ping_text_with_newline(self): + self.assertEqual( + str(Frame(OP_PING, b"ping\n")), + "PING 'ping\\n' [text, 5 bytes]", ) def test_ping_binary(self): @@ -304,13 +316,19 @@ def test_ping_binary(self): def test_pong(self): self.assertEqual( str(Frame(OP_PONG, b"")), - "PONG [0 bytes]", + "PONG '' [0 bytes]", ) def test_pong_text(self): self.assertEqual( str(Frame(OP_PONG, b"pong")), - "PONG pong [text, 4 bytes]", + "PONG 'pong' [text, 4 bytes]", + ) + + def test_pong_text_with_newline(self): + self.assertEqual( + str(Frame(OP_PONG, b"pong\n")), + "PONG 'pong\\n' [text, 5 bytes]", ) def test_pong_binary(self): From e5458a16f7c4162289c248a386e21fce24fa621f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 22:29:22 +0200 Subject: [PATCH 0851/1539] Simplify table. Fix #985. --- docs/topics/compression.rst | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index 206bfa7b7..e23319636 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -109,25 +109,18 @@ Here's how various compression settings affect memory usage of a single connection on a 64-bit system, as well a benchmark of compressed size and compression time for a corpus of small JSON documents. -+-------------+-------------+--------------+--------------+------------------+------------------+ -| Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | -+=============+=============+==============+==============+==================+==================+ -| | 15 | 8 | 322 KiB | -4.0% | +15% + -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 14 | 7 | 178 KiB | -2.6% | +10% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 13 | 6 | 106 KiB | -1.4% | +5% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| *default* | 12 | 5 | 70 KiB | = | = | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 11 | 4 | 52 KiB | +3.7% | -5% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 10 | 3 | 43 KiB | +90% | +50% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 9 | 2 | 39 KiB | +160% | +100% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| *disabled* | — | — | 19 KiB | +452% | — | -+-------------+-------------+--------------+--------------+------------------+------------------+ +=========== ============ ============ ================ ================ +Window Bits Memory Level Memory usage Size vs. default Time vs. default +=========== ============ ============ ================ ================ +15 8 322 KiB -4.0% +15% +14 7 178 KiB -2.6% +10% +13 6 106 KiB -1.4% +5% +**12** **5** **70 KiB** **=** **=** +11 4 52 KiB +3.7% -5% +10 3 43 KiB +90% +50% +9 2 39 KiB +160% +100% +— — 19 KiB +452% — +=========== ============ ============ ================ ================ Window Bits and Memory Level don't have to move in lockstep. However, other combinations don't yield significantly better results than those shown above. From e444fb57b88c5c446fbe406c66d230e9ce15a8d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 7 Jun 2021 20:20:51 +0200 Subject: [PATCH 0852/1539] Add CVE reference. --- docs/project/changelog.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 603466593..fc751bbd0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -69,7 +69,11 @@ They may change at any time. **Version 9.1 fixes a security issue introduced in version 8.0.** - Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords. + Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords + (`CVE-2021-33880`_). + + .. _CVE-2021-33880: https://nvd.nist.gov/vuln/detail/CVE-2021-33880 + 9.0.2 ..... From cc1254b28867fcd2391e436fe5f10f7f40c77729 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 7 Jun 2021 21:01:23 +0200 Subject: [PATCH 0853/1539] Don't use get_running_loop outside of coroutines. get_event_loop() returns the running loop if there's one anyway. --- src/websockets/legacy/client.py | 3 +-- src/websockets/legacy/compatibility.py | 11 ----------- src/websockets/legacy/server.py | 4 ++-- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index df1a2f57c..55bc77422 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -33,7 +33,6 @@ from ..http import USER_AGENT, build_host from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .compatibility import asyncio_get_running_loop from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol @@ -517,7 +516,7 @@ def __init__( # Backwards compatibility: the loop parameter used to be supported. loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) if loop is None: - loop = asyncio_get_running_loop() + loop = asyncio.get_event_loop() else: warnings.warn("remove loop argument", DeprecationWarning) diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 86f6715fd..96df028e4 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -9,14 +9,3 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ return {"loop": loop} if sys.version_info[:2] < (3, 8) else {} - - -def asyncio_get_running_loop() -> asyncio.AbstractEventLoop: - """ - Helper for the deprecation of get_event_loop in Python 3.10. - - """ - if sys.version_info[:2] < (3, 10): # pragma: no cover - return asyncio.get_event_loop() - else: # pragma: no cover - return asyncio.get_running_loop() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 1704ae083..ab03112b2 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -42,7 +42,7 @@ from ..headers import build_extension, parse_extension, parse_subprotocol from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol -from .compatibility import asyncio_get_running_loop, loop_if_py_lt_38 +from .compatibility import loop_if_py_lt_38 from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol @@ -1026,7 +1026,7 @@ def __init__( # Backwards compatibility: the loop parameter used to be supported. loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) if loop is None: - loop = asyncio_get_running_loop() + loop = asyncio.get_event_loop() else: warnings.warn("remove loop argument", DeprecationWarning) From cb11516e0ed4fe2b67ab6c1511650bd42115d0b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 7 Jun 2021 22:11:26 +0200 Subject: [PATCH 0854/1539] Stop requiring loop in WebSocketCommonProtocol. Change tests to avoid passing a loop argument. Fix #988. --- src/websockets/legacy/client.py | 7 ++--- src/websockets/legacy/protocol.py | 8 ++++-- src/websockets/legacy/server.py | 7 ++--- tests/legacy/test_client_server.py | 35 ++++++++++++------------- tests/legacy/test_protocol.py | 41 +++++++++++++++++++++--------- 5 files changed, 59 insertions(+), 39 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 55bc77422..695da3cdb 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -514,10 +514,11 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) - if loop is None: + _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + if _loop is None: loop = asyncio.get_event_loop() else: + loop = _loop warnings.warn("remove loop argument", DeprecationWarning) wsuri = parse_uri(uri) @@ -543,7 +544,7 @@ def __init__( max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, - loop=loop, + loop=_loop, host=wsuri.host, port=wsuri.port, secure=wsuri.secure, diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 353317ab5..ca6076142 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -124,6 +124,12 @@ def __init__( if close_timeout is None: close_timeout = timeout + # Backwards compatibility: the loop parameter used to be supported. + if loop is None: + loop = asyncio.get_event_loop() + else: + warnings.warn("remove loop argument", DeprecationWarning) + self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout @@ -145,8 +151,6 @@ def __init__( # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) - assert loop is not None - # Remove when dropping Python < 3.10 - use get_running_loop instead. self.loop = loop self._host = host diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index ab03112b2..c8425993d 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1024,10 +1024,11 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) - if loop is None: + _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + if _loop is None: loop = asyncio.get_event_loop() else: + loop = _loop warnings.warn("remove loop argument", DeprecationWarning) ws_server = WebSocketServer(logger=logger, loop=loop) @@ -1053,7 +1054,7 @@ def __init__( max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, - loop=loop, + loop=_loop, legacy_recv=legacy_recv, origins=origins, extensions=extensions, diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 60c0a14ae..041eca28e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -213,17 +213,17 @@ def start_server(self, deprecation_warnings=None, **kwargs): kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) - # Python 3.10 dislikes not having a running event loop - if sys.version_info[:2] >= (3, 10): # pragma: no cover - kwargs.setdefault("loop", self.loop) with warnings.catch_warnings(record=True) as recorded_warnings: start_server = serve(handler, "localhost", 0, **kwargs) self.server = self.loop.run_until_complete(start_server) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if sys.version_info[:2] >= (3, 10): # pragma: no cover - expected_warnings += ["remove loop argument"] + if ( + sys.version_info[:2] >= (3, 10) + and "remove loop argument" not in expected_warnings + ): # pragma: no cover + expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_redirecting_server( @@ -234,10 +234,6 @@ def start_redirecting_server( deprecation_warnings=None, **kwargs, ): - # Python 3.10 dislikes not having a running event loop - if sys.version_info[:2] >= (3, 10): # pragma: no cover - kwargs.setdefault("loop", self.loop) - async def process_request(path, headers): server_uri = get_server_uri(self.server, self.secure, path) if force_insecure: @@ -259,8 +255,11 @@ async def process_request(path, headers): self.redirecting_server = self.loop.run_until_complete(start_server) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if sys.version_info[:2] >= (3, 10): # pragma: no cover - expected_warnings += ["remove loop argument"] + if ( + sys.version_info[:2] >= (3, 10) + and "remove loop argument" not in expected_warnings + ): # pragma: no cover + expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_client( @@ -270,9 +269,6 @@ def start_client( kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) - # Python 3.10 dislikes not having a running event loop - if sys.version_info[:2] >= (3, 10): # pragma: no cover - kwargs.setdefault("loop", self.loop) secure = kwargs.get("ssl") is not None try: @@ -286,8 +282,11 @@ def start_client( self.client = self.loop.run_until_complete(start_client) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if sys.version_info[:2] >= (3, 10): # pragma: no cover - expected_warnings += ["remove loop argument"] + if ( + sys.version_info[:2] >= (3, 10) + and "remove loop argument" not in expected_warnings + ): # pragma: no cover + expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): @@ -409,8 +408,6 @@ def test_infinite_redirect(self): with temp_test_redirecting_server( self, http.HTTPStatus.FOUND, - loop=self.loop, - deprecation_warnings=["remove loop argument"], ): self.server = self.redirecting_server with self.assertRaises(InvalidHandshake): @@ -430,7 +427,7 @@ def test_redirect_missing_location(self): with temp_test_client(self): self.fail("Did not raise") # pragma: no cover - def test_explicit_event_loop(self): + def test_loop_backwards_compatibility(self): with self.temp_server( loop=self.loop, deprecation_warnings=["remove loop argument"] ): diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index c1948c838..ccbbffe7c 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import sys import unittest import unittest.mock import warnings @@ -86,8 +87,9 @@ class CommonTests: def setUp(self): super().setUp() - # Disable pings to make it easier to test what frames are sent exactly. - self.protocol = WebSocketCommonProtocol(ping_interval=None, loop=self.loop) + with warnings.catch_warnings(record=True): + # Disable pings to make it easier to test what frames are sent exactly. + self.protocol = WebSocketCommonProtocol(ping_interval=None) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) @@ -309,14 +311,29 @@ def assertCompletesWithin(self, min_time, max_time): # Test constructor. def test_timeout_backwards_compatibility(self): - with warnings.catch_warnings(record=True) as recorded_warnings: - protocol = WebSocketCommonProtocol(timeout=5, loop=self.loop) + with warnings.catch_warnings(record=True) as recorded: + protocol = WebSocketCommonProtocol(timeout=5) self.assertEqual(protocol.close_timeout, 5) - self.assertDeprecationWarnings( - recorded_warnings, ["rename timeout to close_timeout"] - ) + expected = ["rename timeout to close_timeout"] + if sys.version_info[:2] >= (3, 10): # pragma: no cover + expected += ["There is no current event loop"] + + self.assertDeprecationWarnings(recorded, expected) + + def test_loop_backwards_compatibility(self): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + with warnings.catch_warnings(record=True) as recorded: + protocol = WebSocketCommonProtocol(loop=loop) + + self.assertEqual(protocol.loop, loop) + + expected = ["remove loop argument"] + + self.assertDeprecationWarnings(recorded, expected) # Test public attributes. @@ -1116,11 +1133,11 @@ def restart_protocol_with_keepalive_ping( self.transport.close() self.loop.run_until_complete(self.protocol.close()) # copied from setUp, but enables keepalive pings - self.protocol = WebSocketCommonProtocol( - ping_interval=ping_interval, - ping_timeout=ping_timeout, - loop=self.loop, - ) + with warnings.catch_warnings(record=True): + self.protocol = WebSocketCommonProtocol( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) self.protocol.is_client = initial_protocol.is_client From 29094a27e65ae34b9219effd2f4bd755f0a63ff5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 11:49:15 +0200 Subject: [PATCH 0855/1539] Add support for reconnecting automatically. Fix #414. --- docs/howto/logging.rst | 1 + docs/project/changelog.rst | 3 ++ docs/spelling_wordlist.txt | 2 + src/websockets/legacy/client.py | 69 +++++++++++++++++++++++--- src/websockets/legacy/server.py | 2 +- tests/legacy/test_client_server.py | 79 ++++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+), 7 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index f69ee47b9..824812959 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -210,6 +210,7 @@ Here's what websockets logs at each level. * Server starting and stopping * Server establishing and closing connections +* Client reconnecting automatically :attr:`~logging.DEBUG` ...................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index fc751bbd0..a0cd8f07c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -51,6 +51,9 @@ They may change at any time. * Added compatibility with Python 3.10. +* Added support for reconnecting automatically by using + :func:`~legacy.client.connect` as an asynchronous iterator. + * Added ``open_timeout`` to :func:`~legacy.client.connect`. * Improved logging. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b460ef033..8346acefa 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -6,6 +6,7 @@ autoscaler awaitable aymeric backend +backoff backpressure balancer balancers @@ -52,6 +53,7 @@ pong pongs proxying pythonic +reconnection redis retransmit runtime diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 695da3cdb..63fdbf9b2 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -9,7 +9,18 @@ import logging import warnings from types import TracebackType -from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, Type, cast +from typing import ( + Any, + AsyncIterator, + Callable, + Generator, + List, + Optional, + Sequence, + Tuple, + Type, + cast, +) from ..datastructures import Headers, HeadersLike from ..exceptions import ( @@ -412,12 +423,23 @@ class Connect: Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. - :func:`connect` can also be used as a asynchronous context manager:: + :func:`connect` can be used as a asynchronous context manager:: async with connect(...) as websocket: ... - In that case, the connection is closed when exiting the context. + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + ... + + You must catch all exceptions, or else you will exit the loop prematurely. + As above, connections are closed automatically. Connection attempts are + delayed with exponential backoff, starting at three seconds and + increasing up to one minute. :func:`connect` is a wrapper around the event loop's :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments @@ -577,6 +599,10 @@ def __init__( ) self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.client") + self.logger = logger + # This is a coroutine function. self._create_connection = create_connection self._wsuri = wsuri @@ -615,7 +641,38 @@ def handle_redirect(self, uri: str) -> None: # Set the new WebSocket URI. This suffices for same-origin redirects. self._wsuri = new_wsuri - # async with connect(...) + # async for ... in connect(...): + + BACKOFF_MIN = 2.0 + BACKOFF_MAX = 60.0 + BACKOFF_FACTOR = 1.5 + + async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: + backoff_delay = self.BACKOFF_MIN + while True: + try: + async with self as protocol: + yield protocol + # Remove this branch when dropping support for Python < 3.8 + # because CancelledError no longer inherits Exception. + except asyncio.CancelledError: # pragma: no cover + raise + except Exception: + # Connection timed out - increase backoff delay + backoff_delay = backoff_delay * self.BACKOFF_FACTOR + backoff_delay = min(backoff_delay, self.BACKOFF_MAX) + self.logger.info( + "! connect failed; retrying in %d seconds", + int(backoff_delay), + exc_info=True, + ) + await asyncio.sleep(backoff_delay) + continue + else: + # Connection succeeded - reset backoff delay + backoff_delay = self.BACKOFF_MIN + + # async with connect(...) as ...: async def __aenter__(self) -> WebSocketClientProtocol: return await self @@ -628,7 +685,7 @@ async def __aexit__( ) -> None: await self.protocol.close() - # await connect(...) + # ... = await connect(...) def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. @@ -665,7 +722,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: else: raise SecurityError("too many redirects") - # yield from connect(...) + # ... = yield from connect(...) __iter__ = __await__ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c8425993d..9ace9a2c6 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -904,7 +904,7 @@ class Serve: :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their current or next interaction with the WebSocket connection. - :func:`serve` can also be used as an asynchronous context manager:: + :func:`serve` can be used as an asynchronous context manager:: stop = asyncio.Future() # set this future to exit the server diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 041eca28e..2f754aa89 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1471,6 +1471,85 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) +class ReconnectionTests(ClientServerTestsMixin, AsyncioTestCase): + async def echo_handler(ws, path): + async for msg in ws: + await ws.send(msg) + + service_available = True + + async def maybe_service_unavailable(path, headers): + if not ReconnectionTests.service_available: + return http.HTTPStatus.SERVICE_UNAVAILABLE, [], b"" + + async def disable_server(self, duration): + ReconnectionTests.service_available = False + await asyncio.sleep(duration) + ReconnectionTests.service_available = True + + @with_server(handler=echo_handler, process_request=maybe_service_unavailable) + def test_reconnect(self): + # Big, ugly integration test :-( + + async def run_client(): + iteration = 0 + connect_inst = connect(get_server_uri(self.server)) + connect_inst.BACKOFF_MIN = 10 * MS + connect_inst.BACKOFF_MAX = 200 * MS + async for ws in connect_inst: + await ws.send("spam") + msg = await ws.recv() + self.assertEqual(msg, "spam") + + iteration += 1 + if iteration == 1: + # Exit block normally. + pass + elif iteration == 2: + # Disable server for a little bit + asyncio.create_task(self.disable_server(70 * MS)) + await asyncio.sleep(0) + elif iteration == 3: + # Exit block after catching connection error. + server_ws = next(iter(self.server.websockets)) + await server_ws.close() + with self.assertRaises(ConnectionClosed): + await ws.recv() + else: + # Exit block with an exception. + raise Exception("BOOM!") + + with self.assertLogs("websockets", logging.INFO) as logs: + with self.assertRaisesRegex(Exception, "BOOM!"): + self.loop.run_until_complete(run_client()) + + self.assertEqual( + [record.getMessage() for record in logs.records], + [ + # Iteration 1 + "connection open", + "connection closed", + # Iteration 2 + "connection open", + "connection closed", + # Iteration 3 + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection open", + "connection closed", + # Iteration 4 + "connection open", + ], + ) + + class LoggerTests(ClientServerTestsMixin, AsyncioTestCase): def test_logger_client(self): with self.assertLogs("test.server", logging.DEBUG) as server_logs: From 0a60fac6c2cd922198bd17e034c9a4406e273db7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 9 Jun 2021 08:15:00 +0200 Subject: [PATCH 0856/1539] Make test less timing-sensitive. Fix #991. --- tests/legacy/test_client_server.py | 36 ++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2f754aa89..d3e3f1e9f 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1507,7 +1507,7 @@ async def run_client(): pass elif iteration == 2: # Disable server for a little bit - asyncio.create_task(self.disable_server(70 * MS)) + asyncio.create_task(self.disable_server(50 * MS)) await asyncio.sleep(0) elif iteration == 3: # Exit block after catching connection error. @@ -1523,28 +1523,40 @@ async def run_client(): with self.assertRaisesRegex(Exception, "BOOM!"): self.loop.run_until_complete(run_client()) + # Iteration 1 self.assertEqual( - [record.getMessage() for record in logs.records], + [record.getMessage() for record in logs.records][:2], [ - # Iteration 1 "connection open", "connection closed", - # Iteration 2 + ], + ) + # Iteration 2 + self.assertEqual( + [record.getMessage() for record in logs.records][2:4], + [ "connection open", "connection closed", - # Iteration 3 - "connection failed (503 Service Unavailable)", - "connection closed", - "! connect failed; retrying in 0 seconds", - "connection failed (503 Service Unavailable)", - "connection closed", - "! connect failed; retrying in 0 seconds", + ], + ) + # Iteration 3 + self.assertEqual( + [record.getMessage() for record in logs.records][4:-1], + [ "connection failed (503 Service Unavailable)", "connection closed", "! connect failed; retrying in 0 seconds", + ] + * ((len(logs.records) - 5) // 3) + + [ "connection open", "connection closed", - # Iteration 4 + ], + ) + # Iteration 4 + self.assertEqual( + [record.getMessage() for record in logs.records][-1:], + [ "connection open", ], ) From 298cdabef68e9619d89100684eb27f719fc361a6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 10 Jun 2021 22:58:57 +0200 Subject: [PATCH 0857/1539] Postpone evaluation of annotations. See PEP 563. --- src/websockets/__main__.py | 11 +++++++---- src/websockets/client.py | 2 ++ src/websockets/connection.py | 2 ++ src/websockets/datastructures.py | 2 ++ src/websockets/exceptions.py | 4 +++- src/websockets/extensions/base.py | 2 ++ src/websockets/extensions/permessage_deflate.py | 6 ++++-- src/websockets/frames.py | 6 ++++-- src/websockets/headers.py | 2 ++ src/websockets/http.py | 2 ++ src/websockets/http11.py | 2 ++ src/websockets/imports.py | 2 ++ src/websockets/legacy/auth.py | 1 + src/websockets/legacy/client.py | 2 ++ src/websockets/legacy/compatibility.py | 2 ++ src/websockets/legacy/framing.py | 6 ++++-- src/websockets/legacy/handshake.py | 2 ++ src/websockets/legacy/http.py | 2 ++ src/websockets/legacy/protocol.py | 2 ++ src/websockets/legacy/server.py | 6 ++++-- src/websockets/server.py | 2 ++ src/websockets/streams.py | 2 ++ src/websockets/typing.py | 2 ++ src/websockets/uri.py | 2 ++ src/websockets/utils.py | 2 ++ 25 files changed, 63 insertions(+), 13 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 4358c323c..dae165cd2 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import asyncio import os @@ -44,7 +46,8 @@ def win_enable_vt100() -> None: def exit_from_event_loop_thread( - loop: asyncio.AbstractEventLoop, stop: "asyncio.Future[None]" + loop: asyncio.AbstractEventLoop, + stop: asyncio.Future[None], ) -> None: loop.stop() if not stop.done(): @@ -92,8 +95,8 @@ def print_over_input(string: str) -> None: async def run_client( uri: str, loop: asyncio.AbstractEventLoop, - inputs: "asyncio.Queue[str]", - stop: "asyncio.Future[None]", + inputs: asyncio.Queue[str], + stop: asyncio.Future[None], ) -> None: try: websocket = await connect(uri) @@ -183,7 +186,7 @@ async def queue_factory() -> "asyncio.Queue[str]": return asyncio.Queue() # Create a queue of user inputs. There's no need to limit its size. - inputs: "asyncio.Queue[str]" = loop.run_until_complete(queue_factory()) + inputs: asyncio.Queue[str] = loop.run_until_complete(queue_factory()) # Create a stop condition when receiving SIGINT or SIGTERM. stop: asyncio.Future[None] = loop.create_future() diff --git a/src/websockets/client.py b/src/websockets/client.py index a517140cd..cd83a59bc 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections from typing import Generator, List, Optional, Sequence diff --git a/src/websockets/connection.py b/src/websockets/connection.py index dbdc15bb5..c98f5c39b 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum import logging import uuid diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 66f91e9bb..117ffd4f2 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -3,6 +3,8 @@ """ +from __future__ import annotations + from typing import ( Any, Dict, diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 644735258..061b902fb 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -28,6 +28,8 @@ """ +from __future__ import annotations + import http from typing import Optional @@ -198,7 +200,7 @@ class InvalidStatus(InvalidHandshake): """ - def __init__(self, response: "Response") -> None: + def __init__(self, response: Response) -> None: self.response = response message = f"server rejected WebSocket connection: HTTP {response.status_code:d}" super().__init__(message) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index cfc090799..82ae5e27f 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -8,6 +8,8 @@ """ +from __future__ import annotations + from typing import List, Optional, Sequence, Tuple from ..frames import Frame diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 56ef03e0c..d578a7a1d 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -4,6 +4,8 @@ """ +from __future__ import annotations + import dataclasses import zlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -328,7 +330,7 @@ def get_request_params(self) -> List[ExtensionParameter]: def process_response_params( self, params: Sequence[ExtensionParameter], - accepted_extensions: Sequence["Extension"], + accepted_extensions: Sequence[Extension], ) -> PerMessageDeflate: """ Process response parameters. @@ -510,7 +512,7 @@ def __init__( def process_request_params( self, params: Sequence[ExtensionParameter], - accepted_extensions: Sequence["Extension"], + accepted_extensions: Sequence[Extension], ) -> Tuple[List[ExtensionParameter], PerMessageDeflate]: """ Process request parameters. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 9011ac867..99e43388b 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -3,6 +3,8 @@ """ +from __future__ import annotations + import dataclasses import enum import io @@ -170,7 +172,7 @@ def parse( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["extensions.Extension"]] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, ) -> Generator[None, None, "Frame"]: """ Read a WebSocket frame. @@ -239,7 +241,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence["extensions.Extension"]] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, ) -> bytes: """ Write a WebSocket frame. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 6779c9c04..12d2a4e94 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,6 +4,8 @@ """ +from __future__ import annotations + import base64 import binascii import re diff --git a/src/websockets/http.py b/src/websockets/http.py index 9092836c2..6168c5144 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ipaddress import sys diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 22488dc89..aaa61f8c7 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import re from typing import Callable, Generator, Optional diff --git a/src/websockets/imports.py b/src/websockets/imports.py index 06917ce1d..c9508d188 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from typing import Any, Dict, Iterable, Optional diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 16016e6fd..5f2b1311a 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -4,6 +4,7 @@ """ +from __future__ import annotations import functools import hmac diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 63fdbf9b2..468b5c15c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -3,6 +3,8 @@ """ +from __future__ import annotations + import asyncio import collections.abc import functools diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 96df028e4..df81de9db 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import sys from typing import Any, Dict diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 901b2e2e3..14667925f 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -10,6 +10,8 @@ """ +from __future__ import annotations + import struct from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence @@ -56,7 +58,7 @@ async def read( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["extensions.Extension"]] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, ) -> "Frame": """ Read a WebSocket frame. @@ -134,7 +136,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence["extensions.Extension"]] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, ) -> None: """ Write a WebSocket frame. diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 44da72d21..49d08cfe8 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -25,6 +25,8 @@ """ +from __future__ import annotations + import base64 import binascii from typing import List diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index c18e08e8d..0b9a92267 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import re from typing import Tuple diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index ca6076142..f8daf544b 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -7,6 +7,8 @@ """ +from __future__ import annotations + import asyncio import codecs import collections diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9ace9a2c6..a7a98e006 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -3,6 +3,8 @@ """ +from __future__ import annotations + import asyncio import collections.abc import email.utils @@ -169,8 +171,8 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, - ws_handler: Callable[["WebSocketServerProtocol", str], Awaitable[Any]], - ws_server: "WebSocketServer", + ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_server: WebSocketServer, *, origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, diff --git a/src/websockets/server.py b/src/websockets/server.py index 483ddbe1e..09ed63150 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import binascii import collections diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 6f3163034..e02a6ab39 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Generator diff --git a/src/websockets/typing.py b/src/websockets/typing.py index b1858d73e..13b172f15 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from typing import List, NewType, Optional, Tuple, Union diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 958975b22..7406a60a8 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -7,6 +7,8 @@ """ +from __future__ import annotations + import dataclasses import urllib.parse from typing import Optional, Tuple diff --git a/src/websockets/utils.py b/src/websockets/utils.py index 59210e438..ffb706963 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import hashlib import itertools From 71dbbffabcaaaba5dbcfb21df20f98e2bdbe01f6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 10 Jun 2021 23:16:30 +0200 Subject: [PATCH 0858/1539] Break import cycles. Import modules rather than their contents. Fix #989. --- src/websockets/client.py | 2 +- src/websockets/connection.py | 2 +- src/websockets/exceptions.py | 19 +++++---------- src/websockets/extensions/base.py | 11 ++++++--- .../extensions/permessage_deflate.py | 19 +++++++++------ src/websockets/frames.py | 24 +++++++++---------- src/websockets/http11.py | 19 +++++++-------- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/framing.py | 15 +++++------- src/websockets/legacy/protocol.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- src/websockets/uri.py | 6 ++--- 13 files changed, 61 insertions(+), 64 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index cd83a59bc..3217090d0 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -13,7 +13,7 @@ InvalidUpgrade, NegotiationError, ) -from .extensions.base import ClientExtensionFactory, Extension +from .extensions import ClientExtensionFactory, Extension from .headers import ( build_authorization_basic, build_extension, diff --git a/src/websockets/connection.py b/src/websockets/connection.py index c98f5c39b..57d3a2227 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -6,7 +6,7 @@ from typing import Generator, List, Optional, Union from .exceptions import InvalidState, PayloadTooBig, ProtocolError -from .extensions.base import Extension +from .extensions import Extension from .frames import ( OP_BINARY, OP_CLOSE, diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 061b902fb..f110322aa 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -33,7 +33,7 @@ import http from typing import Optional -from .datastructures import Headers, HeadersLike +from . import datastructures, frames, http11 __all__ = [ @@ -84,7 +84,7 @@ class ConnectionClosed(WebSocketException): def __init__(self, code: int, reason: str) -> None: self.code = code self.reason = reason - super().__init__(format_close(code, reason)) + super().__init__(frames.format_close(code, reason)) class ConnectionClosedError(ConnectionClosed): @@ -200,7 +200,7 @@ class InvalidStatus(InvalidHandshake): """ - def __init__(self, response: Response) -> None: + def __init__(self, response: http11.Response) -> None: self.response = response message = f"server rejected WebSocket connection: HTTP {response.status_code:d}" super().__init__(message) @@ -215,7 +215,7 @@ class InvalidStatusCode(InvalidHandshake): """ - def __init__(self, status_code: int, headers: Headers) -> None: + def __init__(self, status_code: int, headers: datastructures.Headers) -> None: self.status_code = status_code self.headers = headers message = f"server rejected WebSocket connection: HTTP {status_code}" @@ -284,11 +284,11 @@ class AbortHandshake(InvalidHandshake): def __init__( self, status: http.HTTPStatus, - headers: HeadersLike, + headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: self.status = status - self.headers = Headers(headers) + self.headers = datastructures.Headers(headers) self.body = body message = f"HTTP {status:d}, {len(self.headers)} headers, {len(body)} bytes" super().__init__(message) @@ -347,10 +347,3 @@ class ProtocolError(WebSocketException): WebSocketProtocolError = ProtocolError # for backwards compatibility - - -# at the bottom to allow circular import, because the frames module imports exceptions -from .frames import format_close # noqa - -# at the bottom to allow circular import, because the http11 module imports exceptions -from .http11 import Response # noqa diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 82ae5e27f..2de9176bd 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -12,7 +12,7 @@ from typing import List, Optional, Sequence, Tuple -from ..frames import Frame +from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -32,7 +32,12 @@ def name(self) -> ExtensionName: """ - def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: + def decode( + self, + frame: frames.Frame, + *, + max_size: Optional[int] = None, + ) -> frames.Frame: """ Decode an incoming frame. @@ -41,7 +46,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: """ - def encode(self, frame: Frame) -> Frame: + def encode(self, frame: frames.Frame) -> frames.Frame: """ Encode an outgoing frame. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index d578a7a1d..5604fb8f9 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -10,6 +10,7 @@ import zlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from .. import frames from ..exceptions import ( DuplicateParameter, InvalidParameterName, @@ -17,7 +18,6 @@ NegotiationError, PayloadTooBig, ) -from ..frames import CTRL_OPCODES, OP_CONT, Frame from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -93,19 +93,24 @@ def __repr__(self) -> str: f"local_max_window_bits={self.local_max_window_bits})" ) - def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: + def decode( + self, + frame: frames.Frame, + *, + max_size: Optional[int] = None, + ) -> frames.Frame: """ Decode an incoming frame. """ # Skip control frames. - if frame.opcode in CTRL_OPCODES: + if frame.opcode in frames.CTRL_OPCODES: return frame # Handle continuation data frames: # - skip if the message isn't encoded # - reset "decode continuation data" flag if it's a final frame - if frame.opcode == OP_CONT: + if frame.opcode is frames.OP_CONT: if not self.decode_cont_data: return frame if frame.fin: @@ -143,19 +148,19 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: return dataclasses.replace(frame, data=data) - def encode(self, frame: Frame) -> Frame: + def encode(self, frame: frames.Frame) -> frames.Frame: """ Encode an outgoing frame. """ # Skip control frames. - if frame.opcode in CTRL_OPCODES: + if frame.opcode in frames.CTRL_OPCODES: return frame # Since we always encode messages, there's no "encode continuation # data" flag similar to "decode continuation data" at this time. - if frame.opcode != OP_CONT: + if frame.opcode is not frames.OP_CONT: # Set the rsv1 flag on the first frame of a compressed message. frame = dataclasses.replace(frame, rsv1=True) # Re-initialize per-message decoder. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 99e43388b..510fc6198 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -12,7 +12,7 @@ import struct from typing import Callable, Generator, Optional, Sequence, Tuple -from .exceptions import PayloadTooBig, ProtocolError +from . import exceptions, extensions from .typing import Data @@ -204,10 +204,10 @@ def parse( try: opcode = Opcode(head1 & 0b00001111) except ValueError as exc: - raise ProtocolError("invalid opcode") from exc + raise exceptions.ProtocolError("invalid opcode") from exc if (True if head2 & 0b10000000 else False) != mask: - raise ProtocolError("incorrect masking") + raise exceptions.ProtocolError("incorrect masking") length = head2 & 0b01111111 if length == 126: @@ -217,7 +217,9 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise exceptions.PayloadTooBig( + f"over size limit ({length} > {max_size} bytes)" + ) if mask: mask_bytes = yield from read_exact(4) @@ -306,13 +308,13 @@ def check(self) -> None: """ if self.rsv1 or self.rsv2 or self.rsv3: - raise ProtocolError("reserved bits must be 0") + raise exceptions.ProtocolError("reserved bits must be 0") if self.opcode in CTRL_OPCODES: if len(self.data) > 125: - raise ProtocolError("control frame too long") + raise exceptions.ProtocolError("control frame too long") if not self.fin: - raise ProtocolError("fragmented control frame") + raise exceptions.ProtocolError("fragmented control frame") def prepare_data(data: Data) -> Tuple[int, bytes]: @@ -380,7 +382,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: return 1005, "" else: assert length == 1 - raise ProtocolError("close frame too short") + raise exceptions.ProtocolError("close frame too short") def serialize_close(code: int, reason: str) -> bytes: @@ -403,7 +405,7 @@ def check_close(code: int) -> None: """ if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): - raise ProtocolError("invalid status code") + raise exceptions.ProtocolError("invalid status code") def format_close(code: int, reason: str) -> str: @@ -425,7 +427,3 @@ def format_close(code: int, reason: str) -> str: result += "no reason" return result - - -# at the bottom to allow circular import, because Extension depends on Frame -from . import extensions # noqa diff --git a/src/websockets/http11.py b/src/websockets/http11.py index aaa61f8c7..11b9d7f39 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -4,8 +4,7 @@ import re from typing import Callable, Generator, Optional -from .datastructures import Headers -from .exceptions import SecurityError +from . import datastructures, exceptions MAX_HEADERS = 256 @@ -50,7 +49,7 @@ class Request: """ path: str - headers: Headers + headers: datastructures.Headers # body isn't useful is the context of this library # If processing the request triggers an exception, it's stored here. @@ -75,7 +74,7 @@ def parse( :param read_line: generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data :raises EOFError: if the connection is closed without a full HTTP request - :raises SecurityError: if the request exceeds a security limit + :raises exceptions.SecurityError: if the request exceeds a security limit :raises ValueError: if the request isn't well formatted """ @@ -134,7 +133,7 @@ class Response: status_code: int reason_phrase: str - headers: Headers + headers: datastructures.Headers body: Optional[bytes] = None # If processing the response triggers an exception, it's stored here. @@ -162,7 +161,7 @@ def parse( :param read_exact: generator-based coroutine that reads the requested number of bytes or raises an exception if there isn't enough data :raises EOFError: if the connection is closed without a full HTTP response - :raises SecurityError: if the response exceeds a security limit + :raises exceptions.SecurityError: if the response exceeds a security limit :raises LookupError: if the response isn't well formatted :raises ValueError: if the response isn't well formatted @@ -242,7 +241,7 @@ def serialize(self) -> bytes: def parse_headers( read_line: Callable[[], Generator[None, None, bytes]] -) -> Generator[None, None, Headers]: +) -> Generator[None, None, datastructures.Headers]: """ Parse HTTP headers. @@ -256,7 +255,7 @@ def parse_headers( # We don't attempt to support obsolete line folding. - headers = Headers() + headers = datastructures.Headers() for _ in range(MAX_HEADERS + 1): try: line = yield from parse_line(read_line) @@ -280,7 +279,7 @@ def parse_headers( headers[name] = value else: - raise SecurityError("too many HTTP headers") + raise exceptions.SecurityError("too many HTTP headers") return headers @@ -301,7 +300,7 @@ def parse_line( line = yield from read_line() # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: - raise SecurityError("line too long") + raise exceptions.SecurityError("line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 468b5c15c..9197bd504 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -34,7 +34,7 @@ RedirectHandshake, SecurityError, ) -from ..extensions.base import ClientExtensionFactory, Extension +from ..extensions import ClientExtensionFactory, Extension from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import ( build_authorization_basic, diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 14667925f..12e006911 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -15,8 +15,8 @@ import struct from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence +from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError -from ..frames import Frame as NewFrame, Opcode try: @@ -28,15 +28,15 @@ class Frame(NamedTuple): fin: bool - opcode: Opcode + opcode: frames.Opcode data: bytes rsv1: bool = False rsv2: bool = False rsv3: bool = False @property - def new_frame(self) -> NewFrame: - return NewFrame( + def new_frame(self) -> frames.Frame: + return frames.Frame( self.opcode, self.data, self.fin, @@ -89,7 +89,7 @@ async def read( rsv3 = True if head1 & 0b00010000 else False try: - opcode = Opcode(head1 & 0b00001111) + opcode = frames.Opcode(head1 & 0b00001111) except ValueError as exc: raise ProtocolError("invalid opcode") from exc @@ -113,7 +113,7 @@ async def read( if mask: data = apply_mask(data, mask_bits) - new_frame = NewFrame(opcode, data, fin, rsv1, rsv2, rsv3) + new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3) if extensions is None: extensions = [] @@ -158,9 +158,6 @@ def write( write(self.new_frame.serialize(mask=mask, extensions=extensions)) -# at the bottom to allow circular import, because Extension depends on Frame -from .. import extensions # noqa - # Backwards compatibility with previously documented public APIs from ..frames import parse_close # noqa from ..frames import prepare_data # noqa diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index f8daf544b..45204ea3f 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -42,7 +42,7 @@ PayloadTooBig, ProtocolError, ) -from ..extensions.base import Extension +from ..extensions import Extension from ..frames import ( OP_BINARY, OP_CLOSE, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index a7a98e006..10416abc5 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -39,7 +39,7 @@ InvalidUpgrade, NegotiationError, ) -from ..extensions.base import Extension, ServerExtensionFactory +from ..extensions import Extension, ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import build_extension, parse_extension, parse_subprotocol from ..http import USER_AGENT diff --git a/src/websockets/server.py b/src/websockets/server.py index 09ed63150..a53798f65 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -17,7 +17,7 @@ InvalidUpgrade, NegotiationError, ) -from .extensions.base import Extension, ServerExtensionFactory +from .extensions import Extension, ServerExtensionFactory from .headers import ( build_extension, parse_connection, diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 7406a60a8..ed4521d53 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -13,7 +13,7 @@ import urllib.parse from typing import Optional, Tuple -from .exceptions import InvalidURI +from . import exceptions __all__ = ["parse_uri", "WebSocketURI"] @@ -59,7 +59,7 @@ def parse_uri(uri: str) -> WebSocketURI: assert parsed.fragment == "" assert parsed.hostname is not None except AssertionError as exc: - raise InvalidURI(uri) from exc + raise exceptions.InvalidURI(uri) from exc secure = parsed.scheme == "wss" host = parsed.hostname @@ -72,7 +72,7 @@ def parse_uri(uri: str) -> WebSocketURI: # urllib.parse.urlparse accepts URLs with a username but without a # password. This doesn't make sense for HTTP Basic Auth credentials. if parsed.password is None: - raise InvalidURI(uri) + raise exceptions.InvalidURI(uri) user_info = (parsed.username, parsed.password) try: From d6188e71df9b5cf6dabadb352ac0ce489e4408cf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 11 Jun 2021 09:57:29 +0200 Subject: [PATCH 0859/1539] Standardize import style. --- .../extensions/permessage_deflate.py | 45 ++++++++----------- src/websockets/headers.py | 41 +++++++++++------ 2 files changed, 46 insertions(+), 40 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 5604fb8f9..0c9088a9e 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -10,14 +10,7 @@ import zlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -from .. import frames -from ..exceptions import ( - DuplicateParameter, - InvalidParameterName, - InvalidParameterValue, - NegotiationError, - PayloadTooBig, -) +from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -140,7 +133,7 @@ def decode( max_length = 0 if max_size is None else max_size data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -224,40 +217,40 @@ def _extract_parameters( if name == "server_no_context_takeover": if server_no_context_takeover: - raise DuplicateParameter(name) + raise exceptions.DuplicateParameter(name) if value is None: server_no_context_takeover = True else: - raise InvalidParameterValue(name, value) + raise exceptions.InvalidParameterValue(name, value) elif name == "client_no_context_takeover": if client_no_context_takeover: - raise DuplicateParameter(name) + raise exceptions.DuplicateParameter(name) if value is None: client_no_context_takeover = True else: - raise InvalidParameterValue(name, value) + raise exceptions.InvalidParameterValue(name, value) elif name == "server_max_window_bits": if server_max_window_bits is not None: - raise DuplicateParameter(name) + raise exceptions.DuplicateParameter(name) if value in _MAX_WINDOW_BITS_VALUES: server_max_window_bits = int(value) else: - raise InvalidParameterValue(name, value) + raise exceptions.InvalidParameterValue(name, value) elif name == "client_max_window_bits": if client_max_window_bits is not None: - raise DuplicateParameter(name) + raise exceptions.DuplicateParameter(name) if is_server and value is None: # only in handshake requests client_max_window_bits = True elif value in _MAX_WINDOW_BITS_VALUES: client_max_window_bits = int(value) else: - raise InvalidParameterValue(name, value) + raise exceptions.InvalidParameterValue(name, value) else: - raise InvalidParameterName(name) + raise exceptions.InvalidParameterName(name) return ( server_no_context_takeover, @@ -344,7 +337,7 @@ def process_response_params( """ if any(other.name == self.name for other in accepted_extensions): - raise NegotiationError(f"received duplicate {self.name}") + raise exceptions.NegotiationError(f"received duplicate {self.name}") # Request parameters are available in instance variables. @@ -370,7 +363,7 @@ def process_response_params( if self.server_no_context_takeover: if not server_no_context_takeover: - raise NegotiationError("expected server_no_context_takeover") + raise exceptions.NegotiationError("expected server_no_context_takeover") # client_no_context_takeover # @@ -400,9 +393,9 @@ def process_response_params( else: if server_max_window_bits is None: - raise NegotiationError("expected server_max_window_bits") + raise exceptions.NegotiationError("expected server_max_window_bits") elif server_max_window_bits > self.server_max_window_bits: - raise NegotiationError("unsupported server_max_window_bits") + raise exceptions.NegotiationError("unsupported server_max_window_bits") # client_max_window_bits @@ -418,7 +411,7 @@ def process_response_params( if self.client_max_window_bits is None: if client_max_window_bits is not None: - raise NegotiationError("unexpected client_max_window_bits") + raise exceptions.NegotiationError("unexpected client_max_window_bits") elif self.client_max_window_bits is True: pass @@ -427,7 +420,7 @@ def process_response_params( if client_max_window_bits is None: client_max_window_bits = self.client_max_window_bits elif client_max_window_bits > self.client_max_window_bits: - raise NegotiationError("unsupported client_max_window_bits") + raise exceptions.NegotiationError("unsupported client_max_window_bits") return PerMessageDeflate( server_no_context_takeover, # remote_no_context_takeover @@ -526,7 +519,7 @@ def process_request_params( """ if any(other.name == self.name for other in accepted_extensions): - raise NegotiationError(f"skipped duplicate {self.name}") + raise exceptions.NegotiationError(f"skipped duplicate {self.name}") # Load request parameters in local variables. ( @@ -604,7 +597,7 @@ def process_request_params( else: if client_max_window_bits is None: - raise NegotiationError("required client_max_window_bits") + raise exceptions.NegotiationError("required client_max_window_bits") elif client_max_window_bits is True: client_max_window_bits = self.client_max_window_bits elif self.client_max_window_bits < client_max_window_bits: diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 12d2a4e94..82aaa8848 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -11,7 +11,7 @@ import re from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast -from .exceptions import InvalidHeaderFormat, InvalidHeaderValue +from . import exceptions from .typing import ( ConnectionOption, ExtensionHeader, @@ -87,7 +87,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: """ match = _token_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected token", header, pos) + raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos) return match.group(), match.end() @@ -110,7 +110,9 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i """ match = _quoted_string_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos) + raise exceptions.InvalidHeaderFormat( + header_name, "expected quoted string", header, pos + ) return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() @@ -181,7 +183,9 @@ def parse_list( if peek_ahead(header, pos) == ",": pos = parse_OWS(header, pos + 1) else: - raise InvalidHeaderFormat(header_name, "expected comma", header, pos) + raise exceptions.InvalidHeaderFormat( + header_name, "expected comma", header, pos + ) # Remove extra delimiters before the next item. while peek_ahead(header, pos) == ",": @@ -244,7 +248,9 @@ def parse_upgrade_protocol( """ match = _protocol_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected protocol", header, pos) + raise exceptions.InvalidHeaderFormat( + header_name, "expected protocol", header, pos + ) return cast(UpgradeProtocol, match.group()), match.end() @@ -285,7 +291,7 @@ def parse_extension_item_param( # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value # after quoted-string unescaping MUST conform to the 'token' ABNF. if _token_re.fullmatch(value) is None: - raise InvalidHeaderFormat( + raise exceptions.InvalidHeaderFormat( header_name, "invalid quoted header content", header, pos_before ) else: @@ -451,7 +457,9 @@ def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: """ match = _token68_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected token68", header, pos) + raise exceptions.InvalidHeaderFormat( + header_name, "expected token68", header, pos + ) return match.group(), match.end() @@ -461,7 +469,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None: """ if pos < len(header): - raise InvalidHeaderFormat(header_name, "trailing data", header, pos) + raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos) def parse_authorization_basic(header: str) -> Tuple[str, str]: @@ -479,9 +487,12 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: # https://tools.ietf.org/html/rfc7617#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": - raise InvalidHeaderValue("Authorization", f"unsupported scheme: {scheme}") + raise exceptions.InvalidHeaderValue( + "Authorization", + f"unsupported scheme: {scheme}", + ) if peek_ahead(header, pos) != " ": - raise InvalidHeaderFormat( + raise exceptions.InvalidHeaderFormat( "Authorization", "expected space after scheme", header, pos ) pos += 1 @@ -491,14 +502,16 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: try: user_pass = base64.b64decode(basic_credentials.encode()).decode() except binascii.Error: - raise InvalidHeaderValue( - "Authorization", "expected base64-encoded credentials" + raise exceptions.InvalidHeaderValue( + "Authorization", + "expected base64-encoded credentials", ) from None try: username, password = user_pass.split(":", 1) except ValueError: - raise InvalidHeaderValue( - "Authorization", "expected username:password credentials" + raise exceptions.InvalidHeaderValue( + "Authorization", + "expected username:password credentials", ) from None return username, password From 85e8799de819da0d6e370c61b10f72a5368fa40b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 14:26:46 +0200 Subject: [PATCH 0860/1539] Deduplicate State class. --- src/websockets/legacy/protocol.py | 9 +-------- tests/legacy/test_client_server.py | 2 +- tests/legacy/test_protocol.py | 3 ++- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 45204ea3f..c7dd4d22c 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -12,7 +12,6 @@ import asyncio import codecs import collections -import enum import logging import random import struct @@ -33,6 +32,7 @@ cast, ) +from ..connection import State from ..datastructures import Headers from ..exceptions import ( ConnectionClosed, @@ -64,13 +64,6 @@ __all__ = ["WebSocketCommonProtocol"] -# A WebSocket connection goes through the following four states, in order: - - -class State(enum.IntEnum): - CONNECTING, OPEN, CLOSING, CLOSED = range(4) - - # In order to ensure consistency, the code always checks the current value of # WebSocketCommonProtocol.state before assigning a new value and never yields # between the check and the assignment. diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index d3e3f1e9f..f3e96ac4e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -15,6 +15,7 @@ import urllib.request import warnings +from websockets.connection import State from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -32,7 +33,6 @@ from websockets.legacy.client import * from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response -from websockets.legacy.protocol import State from websockets.legacy.server import * from websockets.uri import parse_uri diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ccbbffe7c..11589baf1 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -5,6 +5,7 @@ import unittest.mock import warnings +from websockets.connection import State from websockets.exceptions import ConnectionClosed, InvalidState from websockets.frames import ( OP_BINARY, @@ -17,7 +18,7 @@ ) from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame -from websockets.legacy.protocol import State, WebSocketCommonProtocol +from websockets.legacy.protocol import WebSocketCommonProtocol from .utils import MS, AsyncioTestCase From f96c0d0e71e781e8603561c1d0e3c2322ffd76ed Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 14:20:31 +0200 Subject: [PATCH 0861/1539] Add a utility function for broadcasting messages. Fix #870. --- docs/project/changelog.rst | 2 + docs/reference/utilities.rst | 5 +++ src/websockets/__init__.py | 2 + src/websockets/legacy/protocol.py | 73 ++++++++++++++++++++++++++----- src/websockets/legacy/server.py | 4 +- tests/legacy/test_protocol.py | 48 +++++++++++++++++++- 6 files changed, 120 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a0cd8f07c..c8553d17d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -51,6 +51,8 @@ They may change at any time. * Added compatibility with Python 3.10. +* Added :func:`~websockets.broadcast` to send a message to many clients. + * Added support for reconnecting automatically by using :func:`~legacy.client.connect` as an asynchronous iterator. diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index 198e928b0..f1d89eddc 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -1,6 +1,11 @@ Utilities ========= +Broadcast +--------- + +.. autofunction:: websockets.broadcast + Data structures --------------- diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 8c69a6d63..6378b82cf 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -6,6 +6,7 @@ "AbortHandshake", "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", + "broadcast", "ClientConnection", "connect", "ConnectionClosed", @@ -56,6 +57,7 @@ "auth": ".legacy", "basic_auth_protocol_factory": ".legacy.auth", "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "broadcast": ".legacy.protocol", "ClientConnection": ".client", "connect": ".legacy.client", "unix_connect": ".legacy.client", diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c7dd4d22c..8a9557792 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -61,7 +61,7 @@ from .framing import Frame -__all__ = ["WebSocketCommonProtocol"] +__all__ = ["WebSocketCommonProtocol", "broadcast"] # In order to ensure consistency, the code always checks the current value of @@ -974,15 +974,7 @@ async def read_frame(self, max_size: Optional[int]) -> Frame: self.logger.debug("< %s", frame) return frame - async def write_frame( - self, fin: bool, opcode: int, data: bytes, *, _expected_state: int = State.OPEN - ) -> None: - # Defensive assertion for protocol compliance. - if self.state is not _expected_state: # pragma: no cover - raise InvalidState( - f"Cannot write to a WebSocket in the {self.state.name} state" - ) - + def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None: frame = Frame(fin, Opcode(opcode), data) if self.debug: self.logger.debug("> %s", frame) @@ -992,6 +984,7 @@ async def write_frame( extensions=self.extensions, ) + async def drain(self) -> None: try: # drain() cannot be called concurrently by multiple coroutines: # http://bugs.python.org/issue29930. Remove this lock when no @@ -1006,6 +999,17 @@ async def write_frame( # with the correct code and reason. await self.ensure_open() + async def write_frame( + self, fin: bool, opcode: int, data: bytes, *, _state: int = State.OPEN + ) -> None: + # Defensive assertion for protocol compliance. + if self.state is not _state: # pragma: no cover + raise InvalidState( + f"Cannot write to a WebSocket in the {self.state.name} state" + ) + self.write_frame_sync(fin, opcode, data) + await self.drain() + async def write_close_frame(self, data: bytes = b"") -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1023,7 +1027,7 @@ async def write_close_frame(self, data: bytes = b"") -> None: self.logger.debug("= connection is CLOSING") # 7.1.2. Start the WebSocket Closing Handshake - await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) + await self.write_frame(True, OP_CLOSE, data, _state=State.CLOSING) async def keepalive_ping(self) -> None: """ @@ -1371,3 +1375,50 @@ def eof_received(self) -> None: """ self.reader.feed_eof() + + +def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> None: + """ + Broadcast a message to several WebSocket connections. + + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. + + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + + :func:`broadcast` pushes the message synchronously to all connections even + if their write buffers overflow ``write_limit``. There's no backpressure. + + :func:`broadcast` skips silently connections that aren't open in order to + avoid errors on connections where the closing handshake is in progress. + + If you broadcast messages faster than a connection can handle them, + messages will pile up in its write buffer until the connection times out. + Keep low values for ``ping_interval`` and ``ping_timeout`` to prevent + excessive memory use by slow connections when you use :func:`broadcast`. + + Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering + them in memory, while :func:`broadcast` buffers one copy per connection + as fast as possible. + + :raises RuntimeError: if a connection is busy sending a fragmented message + :raises TypeError: for unsupported inputs + + """ + if not isinstance(message, (str, bytes, bytearray, memoryview)): + raise TypeError("data must be bytes, str, or iterable") + + opcode, data = prepare_data(message) + + for websocket in websockets: + if websocket.state is not State.OPEN: + continue + + if websocket._fragmented_message_waiter is not None: + raise RuntimeError("busy sending a fragmented message") + + websocket.write_frame_sync(True, opcode, data) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 10416abc5..ae749aa5f 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -836,8 +836,8 @@ async def _close(self) -> None: await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING conections with a HTTP 503 error. - # Wait until all connections are closed. + # closed, handshake() closes OPENING connections with a HTTP 503 + # error. Wait until all connections are closed. # asyncio.wait doesn't accept an empty first argument if self.websockets: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 11589baf1..97aa01596 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -18,7 +18,7 @@ ) from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame -from websockets.legacy.protocol import WebSocketCommonProtocol +from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast from .utils import MS, AsyncioTestCase @@ -1388,6 +1388,52 @@ def test_remote_close_during_send(self): # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. + def test_broadcast_text(self): + broadcast([self.protocol], "café") + self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + + def test_broadcast_binary(self): + broadcast([self.protocol], b"tea") + self.assertOneFrameSent(True, OP_BINARY, b"tea") + + def test_broadcast_type_error(self): + with self.assertRaises(TypeError): + broadcast([self.protocol], ["ca", "fé"]) + + def test_broadcast_no_clients(self): + broadcast([], "café") + self.assertNoFrameSent() + + def test_broadcast_two_clients(self): + broadcast([self.protocol, self.protocol], "café") + self.assertFramesSent( + (True, OP_TEXT, "café".encode("utf-8")), + (True, OP_TEXT, "café".encode("utf-8")), + ) + + def test_broadcast_skips_closed_connection(self): + self.close_connection() + + broadcast([self.protocol], "café") + self.assertNoFrameSent() + + def test_broadcast_skips_closing_connection(self): + close_task = self.half_close_connection_local() + + broadcast([self.protocol], "café") + self.assertNoFrameSent() + + self.loop.run_until_complete(close_task) # cleanup + + def test_broadcast_within_fragmented_text(self): + self.make_drain_slow() + self.loop.create_task(self.protocol.send(["ca", "fé"])) + self.run_loop_once() + self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + + with self.assertRaises(RuntimeError): + broadcast([self.protocol], "café") + class ServerTests(CommonTests, AsyncioTestCase): def setUp(self): From a3958847c033c5d532eaef0b335c103171f9f7e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Jun 2021 20:02:02 +0200 Subject: [PATCH 0862/1539] Add broadcast benchmarking scripts. --- experiments/broadcast/clients.py | 61 ++++++++++++ experiments/broadcast/server.py | 153 +++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 experiments/broadcast/clients.py create mode 100644 experiments/broadcast/server.py diff --git a/experiments/broadcast/clients.py b/experiments/broadcast/clients.py new file mode 100644 index 000000000..fe39dfe05 --- /dev/null +++ b/experiments/broadcast/clients.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +import asyncio +import statistics +import sys +import time + +import websockets + + +LATENCIES = {} + + +async def log_latency(interval): + while True: + await asyncio.sleep(interval) + p = statistics.quantiles(LATENCIES.values(), n=100) + print(f"clients = {len(LATENCIES)}") + print( + f"p50 = {p[49] / 1e6:.1f}ms, " + f"p95 = {p[94] / 1e6:.1f}ms, " + f"p99 = {p[98] / 1e6:.1f}ms" + ) + print() + + +async def client(): + try: + async with websockets.connect( + "ws://localhost:8765", + ping_timeout=None, + ) as websocket: + async for msg in websocket: + client_time = time.time_ns() + server_time = int(msg[:19].decode()) + LATENCIES[websocket] = client_time - server_time + except Exception as exc: + print(exc) + + +async def main(count, interval): + asyncio.create_task(log_latency(interval)) + clients = [] + for _ in range(count): + clients.append(asyncio.create_task(client())) + await asyncio.sleep(0.001) # 1ms between each connection + await asyncio.wait(clients) + + +if __name__ == "__main__": + try: + count = int(sys.argv[1]) + interval = float(sys.argv[2]) + except Exception as exc: + print(f"Usage: {sys.argv[0]} count interval") + print(" Connect clients e.g. 1000") + print(" Report latency every seconds e.g. 1") + print() + print(exc) + else: + asyncio.run(main(count, interval)) diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py new file mode 100644 index 000000000..355d6b5e9 --- /dev/null +++ b/experiments/broadcast/server.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python + +import asyncio +import functools +import os +import sys +import time + +import websockets + + +CLIENTS = set() + + +async def send(websocket, message): + try: + await websocket.send(message) + except websockets.ConnectionClosed: + pass + + +async def relay(queue, websocket): + while True: + message = await queue.get() + await websocket.send(message) + + +class PubSub: + def __init__(self): + self.waiter = asyncio.Future() + + def publish(self, value): + waiter, self.waiter = self.waiter, asyncio.Future() + waiter.set_result((value, self.waiter)) + + async def subscribe(self): + waiter = self.waiter + while True: + value, waiter = await waiter + yield value + + __aiter__ = subscribe + + +PUBSUB = PubSub() + + +async def handler(websocket, path, method=None): + if method in ["default", "naive", "task", "wait"]: + CLIENTS.add(websocket) + try: + await websocket.wait_closed() + finally: + CLIENTS.remove(websocket) + elif method == "queue": + queue = asyncio.Queue() + relay_task = asyncio.create_task(relay(queue, websocket)) + CLIENTS.add(queue) + try: + await websocket.wait_closed() + finally: + CLIENTS.remove(queue) + relay_task.cancel() + elif method == "pubsub": + async for message in PUBSUB: + await websocket.send(message) + else: + raise NotImplementedError(f"unsupported method: {method}") + + +async def broadcast(method, size, delay): + """Broadcast messages at regular intervals.""" + load_average = 0 + time_average = 0 + pc1, pt1 = time.perf_counter_ns(), time.process_time_ns() + await asyncio.sleep(delay) + while True: + print(f"clients = {len(CLIENTS)}") + pc0, pt0 = time.perf_counter_ns(), time.process_time_ns() + load_average = 0.9 * load_average + 0.1 * (pt0 - pt1) / (pc0 - pc1) + print( + f"load = {(pt0 - pt1) / (pc0 - pc1) * 100:.1f}% / " + f"average = {load_average * 100:.1f}%, " + f"late = {(pc0 - pc1 - delay * 1e9) / 1e6:.1f} ms" + ) + pc1, pt1 = pc0, pt0 + + assert size > 20 + message = str(time.time_ns()).encode() + b" " + os.urandom(size - 20) + + if method == "default": + websockets.broadcast(CLIENTS, message) + elif method == "naive": + # Since the loop can yield control, make a copy of CLIENTS + # to avoid: RuntimeError: Set changed size during iteration + for websocket in CLIENTS.copy(): + await send(websocket, message) + elif method == "task": + for websocket in CLIENTS: + asyncio.create_task(send(websocket, message)) + elif method == "wait": + if CLIENTS: # asyncio.wait doesn't accept an empty list + await asyncio.wait( + [ + asyncio.create_task(send(websocket, message)) + for websocket in CLIENTS + ] + ) + elif method == "queue": + for queue in CLIENTS: + queue.put_nowait(message) + elif method == "pubsub": + PUBSUB.publish(message) + else: + raise NotImplementedError(f"unsupported method: {method}") + + pc2 = time.perf_counter_ns() + wait = delay + (pc1 - pc2) / 1e9 + time_average = 0.9 * time_average + 0.1 * (pc2 - pc1) + print( + f"broadcast = {(pc2 - pc1) / 1e6:.1f}ms / " + f"average = {time_average / 1e6:.1f}ms, " + f"wait = {wait * 1e3:.1f}ms" + ) + await asyncio.sleep(wait) + print() + + +async def main(method, size, delay): + async with websockets.serve( + functools.partial(handler, method=method), + "localhost", + 8765, + compression=None, + ping_timeout=None, + ): + await broadcast(method, size, delay) + + +if __name__ == "__main__": + try: + method = sys.argv[1] + assert method in ["default", "naive", "task", "wait", "queue", "pubsub"] + size = int(sys.argv[2]) + delay = float(sys.argv[3]) + except Exception as exc: + print(f"Usage: {sys.argv[0]} method size delay") + print(" Start a server broadcasting messages with e.g. naive") + print(" Send a payload of bytes every seconds") + print() + print(exc) + else: + asyncio.run(main(method, size, delay)) From 185e9c6e076aecdff0aee3e858049f569cc0ed8e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 17:55:29 +0200 Subject: [PATCH 0863/1539] Discuss broadcasting messages. Fix #653. --- docs/spelling_wordlist.txt | 1 + docs/topics/broadcast.rst | 349 +++++++++++++++++++++++++++++++++++++ docs/topics/index.rst | 1 + 3 files changed, 351 insertions(+) create mode 100644 docs/topics/broadcast.rst diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 8346acefa..86a41ff7a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -59,6 +59,7 @@ retransmit runtime scalable serializers +stateful subclasses subclassing subprotocol diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst new file mode 100644 index 000000000..c43a8fa36 --- /dev/null +++ b/docs/topics/broadcast.rst @@ -0,0 +1,349 @@ +Broadcasting messages +===================== + +.. currentmodule: websockets + + +.. note:: + + If you just want to send a message to all connected clients, use + :func:`~websockets.broadcast`. + + If you want to learn about its design in depth, continue reading this + document. + +WebSocket servers often send the same message to all connected clients to a +subset of clients for which the message is relevant. + +Let's explore options for broadcasting a message, explain the design +of :func:`~websockets.broadcast`, and discuss alternatives. + +For each option, we'll provide a connection handler called ``handler()`` and a +function or coroutine called ``broadcast()`` that sends a message to all +connected clients. + +Integrating them is left as an exercise for the reader. You could start with:: + + import asyncio + import websockets + + async def handler(websocket, path): + ... + + async def broadcast(message): + ... + + async def broadcast_messages(): + while True: + await asyncio.sleep(1) + message = ... # your application logic goes here + await broadcast(message) + + async def main(): + async with websockets.serve(handler, "localhost", 8765): + await broadcast_messages() # runs forever + + if __name__ == "__main__": + asyncio.run(main()) + +``broadcast_messages()`` must yield control to the event loop between each +message, or else it will never let the server run. That's why it includes +``await asyncio.sleep(1)``. + +A complete example is available in the `experiments/broadcast`_ directory. + +.. _experiments/broadcast: https://github.com/aaugustin/websockets/tree/main/experiments/broadcast + +The naive way +------------- + +The most obvious way to send a message to all connected clients consists in +keeping track of them and sending the message to each of them. + +Here's a connection handler that registers clients in a global variable:: + + CLIENTS = set() + + async def handler(websocket, path): + CLIENTS.add(websocket) + try: + await websocket.wait_closed() + finally: + CLIENTS.remove(websocket) + +This implementation assumes that the client will never send any messages. If +you'd rather not make this assumption, you can change:: + + await websocket.wait_closed() + +to:: + + async for _ in websocket: + pass + +Here's a coroutine that broadcasts a message to all clients:: + + async def broadcast(message): + for websocket in CLIENTS.copy(): + try: + await websocket.send(message) + except websockets.ConnectionClosed: + pass + +There are two tricks in this version of ``broadcast()``. + +First, it makes a copy of ``CLIENTS`` before iterating it. Else, if a client +connects or disconnects while ``broadcast()`` is running, the loop would fail +with:: + + RuntimeError: Set changed size during iteration + +Second, it ignores :exc:`~exceptions.ConnectionClosed` exceptions because a +client could disconnect between the moment ``broadcast()`` makes a copy of +``CLIENTS`` and the moment it sends a message to this client. This is fine: a +client that disconnected doesn't belongs to "all connected clients" anymore. + +The naive way can be very fast. Indeed, if all connections have enough free +space in their write buffers, ``await websocket.send(message)`` writes the +message and returns immediately, as it doesn't need to wait for the buffer to +drain. In this case, ``broadcast()`` doesn't yield control to the event loop, +which minimizes overhead. + +The naive way can also fail badly. If the write buffer of a connection reaches +``write_limit``, ``broadcast()`` waits for the buffer to drain before sending +the message to other clients. This can cause a massive drop in performance. + +As a consequence, this pattern works only when write buffers never fill up, +which is usually outside of the control of the server. + +If you know for sure that you will never write more than ``write_limit`` bytes +within ``ping_interval + ping_timeout``, then websockets will terminate slow +connections before the write buffer has time to fill up. + +Don't set extreme ``write_limit``, ``ping_interval``, and ``ping_timeout`` +values to ensure that this condition holds. Set reasonable values and use the +built-in :func:`~websockets.broadcast` function instead. + +The concurrent way +------------------ + +The naive way didn't work well because it serialized writes, while the whole +point of asynchronous I/O is to perform I/O concurrently. + +Let's modify ``broadcast()`` to send messages concurrently:: + + async def send(websocket, message): + try: + await websocket.send(message) + except websockets.ConnectionClosed: + pass + + def broadcast(message): + for websocket in CLIENTS: + asyncio.create_task(send(websocket, message)) + +We move the error handling logic in a new coroutine and we schedule +a :class:`~asyncio.Task` to run it instead of executing it immediately. + +Since ``broadcast()`` no longer awaits coroutines, we can make it a function +rather than a coroutine and do away with the copy of ``CLIENTS``. + +This version of ``broadcast()`` makes clients independent from one another: a +slow client won't block others. As a side effect, it makes messages +independent from one another. + +If you broadcast several messages, there is no strong guarantee that they will +be sent in the expected order. Fortunately, the event loop runs tasks in the +order in which they are created, so the order is correct in practice. + +Technically, this is an implementation detail of the event loop. However, it +seems unlikely for an event loop to run tasks in an order other than FIFO. + +If you wanted to enforce the order without relying this implementation detail, +you could be tempted to wait until all clients have received the message:: + + async def broadcast(message): + if CLIENTS: # asyncio.wait doesn't accept an empty list + await asyncio.wait([ + asyncio.create_task(send(websocket, message)) + for websocket in CLIENTS + ]) + +However, this doesn't really work in practice. Quite often, it will block +until the slowest client times out. + +Backpressure meets broadcast +---------------------------- + +At this point, it becomes apparent that backpressure, usually a good practice, +doesn't work well when broadcasting a message to thousands of clients. + +When you're sending messages to a single client, you don't want to send them +faster than the network can transfer them and the client accept them. This is +why :meth:`~server.WebSocketServerProtocol.send` checks if the write buffer +is full and, if it is, waits until it drain, giving the network and the +client time to catch up. This provides backpressure. + +Without backpressure, you could pile up data in the write buffer until the +server process runs out of memory and the operating system kills it. + +The :meth:`~server.WebSocketServerProtocol.send` API is designed to enforce +backpressure by default. This helps users of websockets write robust programs +even if they never heard about backpressure. + +For comparison, :class:`asyncio.StreamWriter` requires users to understand +backpressure and to await :meth:`~asyncio.StreamWriter.drain` explicitly +after each :meth:`~asyncio.StreamWriter.write`. + +When broadcasting messages, backpressure consists in slowing down all clients +in an attempt to let the slowest client catch up. With thousands of clients, +the slowest one is probably timing out and isn't going to receive the message +anyway. So it doesn't make sense to synchronize with the slowest client. + +How do we avoid running out of memory when slow clients can't keep up with the +broadcast rate, then? The most straightforward option is to disconnect them. + +If a client gets too far behind, eventually it reaches the limit defined by +``ping_timeout`` and websockets terminates the connection. You can read the +discussion of :doc:`keepalive and timeouts <./timeouts>` for details. + +How :func:`~websockets.broadcast` works +--------------------------------------- + +The built-in :func:`~websockets.broadcast` function is similar to the naive +way. The main difference is that it doesn't apply backpressure. + +This provides the best performance by avoiding the overhead of scheduling and +running one task per client. + +Also, when sending text messages, encoding to UTF-8 happens only once rather +than once per client, providing a small performance gain. + +Per-client queues +----------------- + +At this point, we deal with slow clients rather brutally: we disconnect then. + +Can we do better? For example, we could decide to skip or to batch messages, +depending on how far behind a client is. + +To implement this logic, we can create a queue of messages for each client and +run a task that gets messages from the queue and sends them to the client:: + + import asyncio + + CLIENTS = set() + + async def relay(queue, websocket): + while True: + # Implement custom logic based on queue.qsize() and + # websocket.transport.get_write_buffer_size() here. + message = await queue.get() + await websocket.send(message) + + async def handler(websocket, path): + queue = asyncio.Queue() + relay_task = asyncio.create_task(relay(queue, websocket)) + CLIENTS.add(queue) + try: + await websocket.wait_closed() + finally: + CLIENTS.remove(queue) + relay_task.cancel() + +Then we can broadcast a message by pushing it to all queues:: + + def broadcast(message): + for queue in CLIENTS: + queue.put_nowait(message) + +The queues provide an additional buffer between the ``broadcast()`` function +and clients. This makes it easier to support slow clients without excessive +memory usage because queued messages aren't duplicated to write buffers +until ``relay()`` processes them. + +Publish–subscribe +----------------- + +Can we avoid centralizing the list of connected clients in a global variable? + +If each client subscribes to a stream a messages, then broadcasting becomes as +simple as publishing a message to the stream. + +Here's a message stream that supports multiple consumers:: + + class PubSub: + def __init__(self): + self.waiter = asyncio.Future() + + def publish(self, value): + waiter, self.waiter = self.waiter, asyncio.Future() + waiter.set_result((value, self.waiter)) + + async def subscribe(self): + waiter = self.waiter + while True: + value, waiter = await waiter + yield value + + __aiter__ = subscribe + + PUBSUB = PubSub() + +The stream is implemented as a linked list of futures. It isn't necessary to +synchronize consumers. They can read the stream at their own pace, +independently from one another. Once all consumers read a message, there are +no references left, therefore the garbage collector deletes it. + +The connection handler subscribes to the stream and sends messages:: + + async def handler(websocket, path): + async for message in PUBSUB: + await websocket.send(message) + +The broadcast function publishes to the stream:: + + def broadcast(message): + PUBSUB.publish(message) + +Like per-client queues, this version supports slow clients with limited memory +usage. Unlike per-client queues, it makes it difficult to tell how far behind +a client is. The ``PubSub`` class could be extended or refactored to provide +this information. + +The ``for`` loop is gone from this version of the ``broadcast()`` function. +However, there's still a ``for`` loop iterating on all clients hidden deep +inside :mod:`asyncio`. When ``publish()`` sets the result of the ``waiter`` +future, :mod:`asyncio` loops on callbacks registered with this future and +schedules them. This is how connection handlers receive the next value from +the asynchronous iterator returned by ``subscribe()``. + +Performance considerations +-------------------------- + +The built-in :func:`~websockets.broadcast` function sends all messages without +yielding control to the event loop. So does the naive way when the network +and clients are fast and reliable. + +For each client, a WebSocket frame is prepared and sent to the network. This +is the minimum amount of work required to broadcast a message. + +It would be tempting to prepare a frame and reuse it for all connections. +However, this isn't possible in general for two reasons: + +* Clients can negotiate different extensions. You would have to enforce the + same extensions with the same parameters. For example, you would have to + select some compression settings and reject clients that cannot support + these settings. + +* Extensions can be stateful, producing different encodings of the same + message depending on previous messages. For example, you would have to + disable context takeover to make compression stateless, resulting in poor + compression rates. + +All other patterns discussed above yield control to the event loop once per +client because messages are sent by different tasks. This makes them slower +than the built-in :func:`~websockets.broadcast` function. + +There is no major difference between the performance of per-message queues and +publish–subscribe. diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 9269c585a..dc192a290 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -6,6 +6,7 @@ Topics deployment authentication + broadcast compression timeouts design From fa497e501b505ba7255315cb660128d01cfbef44 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 22:13:15 +0200 Subject: [PATCH 0864/1539] Take advantage of broadcast() in examples. Fix #995. --- docs/intro/index.rst | 2 +- example/counter.py | 20 ++++---------------- example/django/notifications.py | 9 ++++++--- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 49e17b668..1df92ed59 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -181,7 +181,7 @@ unregister them when they disconnect. connected.add(websocket) try: # Broadcast a message to all connected clients. - await asyncio.wait([ws.send("Hello!") for ws in connected]) + websockets.broadcast(connected, "Hello!") await asyncio.sleep(10) finally: # Unregister. diff --git a/example/counter.py b/example/counter.py index a9ed61893..e41f6fabb 100755 --- a/example/counter.py +++ b/example/counter.py @@ -22,23 +22,11 @@ def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) -async def broadcast(message): - # asyncio.wait doesn't accept an empty list - if not USERS: - return - # Ignore return value. If a user disconnects before we send - # the message to them, there's nothing we can do anyway. - await asyncio.wait([ - asyncio.create_task(user.send(message)) - for user in USERS - ]) - - async def counter(websocket, path): try: # Register user USERS.add(websocket) - await broadcast(users_event()) + websockets.broadcast(USERS, users_event()) # Send current state to user await websocket.send(state_event()) # Manage state changes @@ -46,16 +34,16 @@ async def counter(websocket, path): data = json.loads(message) if data["action"] == "minus": STATE["value"] -= 1 - await broadcast(state_event()) + websockets.broadcast(USERS, state_event()) elif data["action"] == "plus": STATE["value"] += 1 - await broadcast(state_event()) + websockets.broadcast(USERS, state_event()) else: logging.error("unsupported event: %s", data) finally: # Unregister user USERS.remove(websocket) - await broadcast(users_event()) + websockets.broadcast(USERS, users_event()) async def main(): diff --git a/example/django/notifications.py b/example/django/notifications.py index 641643f92..41fb719dc 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -55,9 +55,12 @@ async def process_events(): payload = message["data"].decode() # Broadcast event to all users who have permissions to see it. event = json.loads(payload) - for websocket, connection in CONNECTIONS.items(): - if event["content_type_id"] in connection["content_type_ids"]: - asyncio.create_task(websocket.send(payload)) + recipients = ( + websocket + for websocket, connection in CONNECTIONS.items() + if event["content_type_id"] in connection["content_type_ids"] + ) + websockets.broadcast(recipients, payload) async def main(): From 4c8d987e02895e0c372ddfca63b73d86aace35e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 22:54:15 +0200 Subject: [PATCH 0865/1539] Add performance tips. Fix #968. --- docs/spelling_wordlist.txt | 1 + docs/topics/index.rst | 1 + docs/topics/performance.rst | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+) create mode 100644 docs/topics/performance.rst diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 86a41ff7a..b3389a920 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -73,6 +73,7 @@ unregister uple username uvicorn +uvloop virtualenv WebSocket websocket diff --git a/docs/topics/index.rst b/docs/topics/index.rst index dc192a290..993303106 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -12,3 +12,4 @@ Topics design memory security + performance diff --git a/docs/topics/performance.rst b/docs/topics/performance.rst new file mode 100644 index 000000000..45e23b239 --- /dev/null +++ b/docs/topics/performance.rst @@ -0,0 +1,20 @@ +Performance +=========== + +Here are tips to optimize performance. + +uvloop +------ + +You can make a websockets application faster by running it with uvloop_. + +(This advice isn't specific to websockets. It applies to any :mod:`asyncio` +application.) + +.. _uvloop: https://github.com/MagicStack/uvloop + +broadcast +--------- + +:func:`~websockets.broadcast` is the most efficient way to send a message to +many clients. From 897d1b27f362c4113308e76b989321ece3bfd4af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 14 Jun 2021 20:53:10 +0200 Subject: [PATCH 0866/1539] Validate subprotocols argument. Fix #263. --- src/websockets/headers.py | 19 +++++++++++++++++-- src/websockets/legacy/client.py | 4 ++++ src/websockets/legacy/server.py | 10 +++++++++- tests/legacy/test_client_server.py | 9 +++++++++ tests/test_headers.py | 15 +++++++++++++++ 5 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 82aaa8848..181e976e3 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -29,6 +29,7 @@ "build_extension", "parse_subprotocol", "build_subprotocol", + "validate_subprotocols", "build_www_authenticate_basic", "parse_authorization_basic", "build_authorization_basic", @@ -417,19 +418,33 @@ def parse_subprotocol(header: str) -> List[Subprotocol]: parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility -def build_subprotocol(protocols: Sequence[Subprotocol]) -> str: +def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str: """ Build a ``Sec-WebSocket-Protocol`` header. This is the reverse of :func:`parse_subprotocol`. """ - return ", ".join(protocols) + return ", ".join(subprotocols) build_subprotocol_list = build_subprotocol # alias for backwards compatibility +def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None: + """ + Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`. + + """ + if not isinstance(subprotocols, Sequence): + raise TypeError("subprotocols must be a list") + if isinstance(subprotocols, str): + raise TypeError("subprotocols must be a list, not a str") + for subprotocol in subprotocols: + if not _token_re.fullmatch(subprotocol): + raise ValueError(f"invalid subprotocol: {subprotocol}") + + def build_www_authenticate_basic(realm: str) -> str: """ Build a ``WWW-Authenticate`` header for HTTP Basic Auth. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 9197bd504..d94e9d7b9 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -42,6 +42,7 @@ build_subprotocol, parse_extension, parse_subprotocol, + validate_subprotocols, ) from ..http import USER_AGENT, build_host from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol @@ -559,6 +560,9 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + if subprotocols is not None: + validate_subprotocols(subprotocols) + factory = functools.partial( create_protocol, ping_interval=ping_interval, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index ae749aa5f..80d72f93e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -41,7 +41,12 @@ ) from ..extensions import Extension, ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate -from ..headers import build_extension, parse_extension, parse_subprotocol +from ..headers import ( + build_extension, + parse_extension, + parse_subprotocol, + validate_subprotocols, +) from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from .compatibility import loop_if_py_lt_38 @@ -1042,6 +1047,9 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + if subprotocols is not None: + validate_subprotocols(subprotocols) + factory = functools.partial( create_protocol, ws_handler, diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f3e96ac4e..755fcefdd 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1013,6 +1013,15 @@ def test_subprotocol(self): self.assertEqual(server_subprotocol, repr("chat")) self.assertEqual(self.client.subprotocol, "chat") + def test_invalid_subprotocol_server(self): + with self.assertRaises(TypeError): + self.start_server(subprotocols="sip") + + @with_server() + def test_invalid_subprotocol_client(self): + with self.assertRaises(TypeError): + self.start_client(subprotocols="sip") + @with_server(subprotocols=["superchat"]) @with_client("/subprotocol", subprotocols=["otherchat"]) def test_subprotocol_not_accepted(self): diff --git a/tests/test_headers.py b/tests/test_headers.py index 26d85fa5e..badec5a86 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -126,6 +126,21 @@ def test_parse_subprotocol_invalid_header(self): with self.assertRaises(InvalidHeaderFormat): parse_subprotocol(header) + def test_validate_subprotocols(self): + for subprotocols in [[], ["sip"], ["v1.usp"], ["sip", "v1.usp"]]: + with self.subTest(subprotocols=subprotocols): + validate_subprotocols(subprotocols) + + def test_validate_subprotocols_invalid(self): + for subprotocols, exception in [ + ({"sip": None}, TypeError), + ("sip", TypeError), + ([""], ValueError), + ]: + with self.subTest(subprotocols=subprotocols): + with self.assertRaises(exception): + validate_subprotocols(subprotocols) + def test_build_www_authenticate_basic(self): # Test vector from RFC 7617 self.assertEqual( From 047222c434fa8f54d92bb904438267113a09f5a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 07:42:46 +0200 Subject: [PATCH 0867/1539] Add an abstraction for close codes and reasons. --- src/websockets/frames.py | 104 ++++++++++++++++++++++++++------------- 1 file changed, 70 insertions(+), 34 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 510fc6198..175478794 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -372,17 +372,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: :raises UnicodeDecodeError: if the reason isn't valid UTF-8 """ - length = len(data) - if length >= 2: - (code,) = struct.unpack("!H", data[:2]) - check_close(code) - reason = data[2:].decode("utf-8") - return code, reason - elif length == 0: - return 1005, "" - else: - assert length == 1 - raise exceptions.ProtocolError("close frame too short") + return dataclasses.astuple(Close.parse(data)) # type: ignore def serialize_close(code: int, reason: str) -> bytes: @@ -392,38 +382,84 @@ def serialize_close(code: int, reason: str) -> bytes: This is the reverse of :func:`parse_close`. """ - check_close(code) - return struct.pack("!H", code) + reason.encode("utf-8") + return Close(code, reason).serialize() -def check_close(code: int) -> None: +def format_close(code: int, reason: str) -> str: """ - Check that the close code has an acceptable value for a close frame. - - :raises ~websockets.exceptions.ProtocolError: if the close code - is invalid + Display a human-readable version of the close code and reason. """ - if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): - raise exceptions.ProtocolError("invalid status code") + return str(Close(code, reason)) -def format_close(code: int, reason: str) -> str: +@dataclasses.dataclass +class Close: """ - Display a human-readable version of the close code and reason. + WebSocket close code and reason. """ - if 3000 <= code < 4000: - explanation = "registered" - elif 4000 <= code < 5000: - explanation = "private use" - else: - explanation = CLOSE_CODES.get(code, "unknown") - result = f"code = {code} ({explanation}), " - if reason: - result += f"reason = {reason}" - else: - result += "no reason" + code: int + reason: str + + def __str__(self) -> str: + """ + Return a human-readable represention of a close code and reason. + + """ + if 3000 <= self.code < 4000: + explanation = "registered" + elif 4000 <= self.code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODES.get(self.code, "unknown") + result = f"code = {self.code} ({explanation}), " + + if self.reason: + result += f"reason = {self.reason}" + else: + result += "no reason" + + return result - return result + @classmethod + def parse(cls, data: bytes) -> Close: + """ + Parse the payload of a close frame. + + :raises ~websockets.exceptions.ProtocolError: if data is ill-formed + :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + + """ + if len(data) >= 2: + (code,) = struct.unpack("!H", data[:2]) + reason = data[2:].decode("utf-8") + close = cls(code, reason) + close.check() + return close + elif len(data) == 0: + return cls(1005, "") + else: + raise exceptions.ProtocolError("close frame too short") + + def serialize(self) -> bytes: + """ + Serialize the payload for a close frame. + + This is the reverse of :meth:`parse`. + + """ + self.check() + return struct.pack("!H", self.code) + self.reason.encode("utf-8") + + def check(self) -> None: + """ + Check that the close code has a valid value for a close frame. + + :raises ~websockets.exceptions.ProtocolError: if the close code + is invalid + + """ + if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): + raise exceptions.ProtocolError("invalid status code") From 5bb653f2da9351ee00275e04c63bef19da565528 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 07:59:27 +0200 Subject: [PATCH 0868/1539] Adopt the abstraction for close code and reasons. Preserve backwards compatibility in the framing module. --- src/websockets/__main__.py | 4 +-- src/websockets/connection.py | 14 ++++---- src/websockets/exceptions.py | 2 +- src/websockets/frames.py | 37 ++------------------- src/websockets/legacy/framing.py | 32 +++++++++++++++--- src/websockets/legacy/protocol.py | 14 ++++---- tests/extensions/test_permessage_deflate.py | 4 +-- tests/legacy/test_framing.py | 31 +++++++++++++++++ tests/legacy/test_protocol.py | 22 ++++++------ tests/test_connection.py | 6 ++-- tests/test_frames.py | 34 +++++++++---------- 11 files changed, 110 insertions(+), 90 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index dae165cd2..6347ee278 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -9,7 +9,7 @@ from typing import Any, Set from .exceptions import ConnectionClosed -from .frames import format_close +from .frames import Close from .legacy.client import connect @@ -143,7 +143,7 @@ async def run_client( finally: await websocket.close() - close_status = format_close(websocket.close_code, websocket.close_reason) + close_status = Close(websocket.close_code, websocket.close_reason) print_over_input(f"Connection closed: {close_status}.") diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 57d3a2227..067ad54ce 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -14,9 +14,8 @@ OP_PING, OP_PONG, OP_TEXT, + Close, Frame, - parse_close, - serialize_close, ) from .http11 import Request, Response from .streams import StreamReader @@ -202,7 +201,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: raise ValueError("cannot send a reason without a code") data = b"" else: - data = serialize_close(code, reason) + data = Close(code, reason).serialize() self.send_frame(Frame(OP_CLOSE, data)) # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started @@ -259,7 +258,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because of # an error reading from or writing to the network. if code != 1006 and self.state is OPEN: - self.send_frame(Frame(OP_CLOSE, serialize_close(code, reason))) + self.send_frame(Frame(OP_CLOSE, Close(code, reason).serialize())) self.set_state(CLOSING) if not self.eof_sent: self.send_eof() @@ -379,7 +378,8 @@ def parse(self) -> Generator[None, None, None]: self.close_frame_received = True # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - self.close_code, self.close_reason = parse_close(frame.data) + close = Close.parse(frame.data) + self.close_code, self.close_reason = close.code, close.reason if self.cur_size is not None: raise ProtocolError("incomplete fragmented message") @@ -390,8 +390,8 @@ def parse(self) -> Generator[None, None, None]: # received.)" if self.state is OPEN: # Echo the original data instead of re-serializing it with - # serialize_close() because that fails when the close frame - # is empty and parse_close() synthetizes a 1005 close code. + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthetizes a 1005 close code. # The rest is identical to send_close(). self.send_frame(Frame(OP_CLOSE, frame.data)) self.set_state(CLOSING) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f110322aa..59922be9c 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -84,7 +84,7 @@ class ConnectionClosed(WebSocketException): def __init__(self, code: int, reason: str) -> None: self.code = code self.reason = reason - super().__init__(frames.format_close(code, reason)) + super().__init__(str(frames.Close(code, reason))) class ConnectionClosedError(ConnectionClosed): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 175478794..b0de645ec 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -35,8 +35,7 @@ "Frame", "prepare_data", "prepare_ctrl", - "parse_close", - "serialize_close", + "Close", ] @@ -140,8 +139,7 @@ def __str__(self) -> str: binary = binary[:16] + b"\x00\x00" + binary[-8:] data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: - code, reason = parse_close(self.data) - data = format_close(code, reason) + data = str(Close.parse(self.data)) elif self.data: # We don't know if a Continuation frame contains text or binary. # Ping and Pong frames could contain UTF-8. Attempt to decode as @@ -362,37 +360,6 @@ def prepare_ctrl(data: Data) -> bytes: raise TypeError("data must be bytes-like or str") -def parse_close(data: bytes) -> Tuple[int, str]: - """ - Parse the payload from a close frame. - - Return ``(code, reason)``. - - :raises ~websockets.exceptions.ProtocolError: if data is ill-formed - :raises UnicodeDecodeError: if the reason isn't valid UTF-8 - - """ - return dataclasses.astuple(Close.parse(data)) # type: ignore - - -def serialize_close(code: int, reason: str) -> bytes: - """ - Serialize the payload for a close frame. - - This is the reverse of :func:`parse_close`. - - """ - return Close(code, reason).serialize() - - -def format_close(code: int, reason: str) -> str: - """ - Display a human-readable version of the close code and reason. - - """ - return str(Close(code, reason)) - - @dataclasses.dataclass class Close: """ diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 12e006911..c8ae48690 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -12,8 +12,9 @@ from __future__ import annotations +import dataclasses import struct -from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -159,7 +160,28 @@ def write( # Backwards compatibility with previously documented public APIs -from ..frames import parse_close # noqa -from ..frames import prepare_data # noqa -from ..frames import serialize_close # noqa -from ..frames import prepare_ctrl as encode_data # noqa + +from ..frames import Close, prepare_ctrl as encode_data, prepare_data # noqa + + +def parse_close(data: bytes) -> Tuple[int, str]: + """ + Parse the payload from a close frame. + + Return ``(code, reason)``. + + :raises ~websockets.exceptions.ProtocolError: if data is ill-formed + :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + + """ + return dataclasses.astuple(Close.parse(data)) # type: ignore + + +def serialize_close(code: int, reason: str) -> bytes: + """ + Serialize the payload for a close frame. + + This is the reverse of :func:`parse_close`. + + """ + return Close(code, reason).serialize() diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 8a9557792..c3221bb45 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -50,11 +50,10 @@ OP_PING, OP_PONG, OP_TEXT, + Close, Opcode, - parse_close, prepare_ctrl, prepare_data, - serialize_close, ) from ..typing import Data, LoggerLike, Subprotocol from .compatibility import loop_if_py_lt_38 @@ -609,7 +608,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: """ try: await asyncio.wait_for( - self.write_close_frame(serialize_close(code, reason)), + self.write_close_frame(Close(code, reason).serialize()), self.close_timeout, **loop_if_py_lt_38(self.loop), ) @@ -918,11 +917,12 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: if frame.opcode == OP_CLOSE: # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - self.close_code, self.close_reason = parse_close(frame.data) + close = Close.parse(frame.data) + self.close_code, self.close_reason = close.code, close.reason try: # Echo the original data instead of re-serializing it with - # serialize_close() because that fails when the close frame - # is empty and parse_close() synthetizes a 1005 close code. + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthetizes a 1005 close code. await self.write_close_frame(frame.data) except ConnectionClosed: # Connection closed before we could echo the close frame. @@ -1220,7 +1220,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # Don't send a close frame if the connection is broken. if code != 1006 and self.state is State.OPEN: - frame_data = serialize_close(code, reason) + frame_data = Close(code, reason).serialize() # Write the close frame without draining the write buffer. diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 5ba4d8ddf..bcd08a7ef 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -17,8 +17,8 @@ OP_PING, OP_PONG, OP_TEXT, + Close, Frame, - serialize_close, ) from .test_base import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory @@ -70,7 +70,7 @@ def test_no_encode_decode_pong_frame(self): self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_close_frame(self): - frame = Frame(OP_CLOSE, serialize_close(1000, "")) + frame = Frame(OP_CLOSE, Close(1000, "").serialize()) self.assertEqual(self.extension.encode(frame), frame) diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 2baa827a9..4646817a8 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -171,3 +171,34 @@ def decode(frame, *, max_size=None): self.round_trip( b"\x81\x05uryyb", Frame(True, OP_TEXT, b"hello"), extensions=[Rot13()] ) + + +class ParseAndSerializeCloseTests(unittest.TestCase): + def assertCloseData(self, code, reason, data): + """ + Serializing code / reason yields data. Parsing data yields code / reason. + + """ + serialized = serialize_close(code, reason) + self.assertEqual(serialized, data) + parsed = parse_close(data) + self.assertEqual(parsed, (code, reason)) + + def test_parse_close_and_serialize_close(self): + self.assertCloseData(1000, "", b"\x03\xe8") + self.assertCloseData(1000, "OK", b"\x03\xe8OK") + + def test_parse_close_empty(self): + self.assertEqual(parse_close(b""), (1005, "")) + + def test_parse_close_errors(self): + with self.assertRaises(ProtocolError): + parse_close(b"\x03") + with self.assertRaises(ProtocolError): + parse_close(b"\x03\xe7") + with self.assertRaises(UnicodeDecodeError): + parse_close(b"\x03\xe8\xff\xff") + + def test_serialize_close_errors(self): + with self.assertRaises(ProtocolError): + serialize_close(999, "") diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 97aa01596..1e3f1b77e 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -14,7 +14,7 @@ OP_PING, OP_PONG, OP_TEXT, - serialize_close, + Close, ) from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame @@ -113,9 +113,9 @@ async def delayed_drain(): self.protocol._drain = delayed_drain - close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) - local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) - remote_close = Frame(True, OP_CLOSE, serialize_close(1000, "remote")) + close_frame = Frame(True, OP_CLOSE, Close(1000, "close").serialize()) + local_close = Frame(True, OP_CLOSE, Close(1000, "local").serialize()) + remote_close = Frame(True, OP_CLOSE, Close(1000, "remote").serialize()) def receive_frame(self, frame): """ @@ -156,7 +156,7 @@ def close_connection(self, code=1000, reason="close"): This puts the connection in the CLOSED state. """ - close_frame_data = serialize_close(code, reason) + close_frame_data = Close(code, reason).serialize() # Prepare the response to the closing handshake from the remote side. self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) self.receive_eof_if_client() @@ -178,7 +178,7 @@ def half_close_connection_local(self, code=1000, reason="close"): canceled, else asyncio complains about destroying a pending task. """ - close_frame_data = serialize_close(code, reason) + close_frame_data = Close(code, reason).serialize() # Trigger the closing handshake from the local endpoint. close_task = self.loop.create_task(self.protocol.close(code, reason)) self.run_loop_once() # wait_for executes @@ -213,7 +213,7 @@ def half_close_connection_remote(self, code=1000, reason="close"): if not self.protocol.is_client: self.make_drain_slow() - close_frame_data = serialize_close(code, reason) + close_frame_data = Close(code, reason).serialize() # Trigger the closing handshake from the remote endpoint. self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) self.run_loop_once() # read_frame executes @@ -298,7 +298,7 @@ def assertConnectionFailed(self, code, message): if code == 1006: self.assertNoFrameSent() else: - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(code, message)) + self.assertOneFrameSent(True, OP_CLOSE, Close(code, message).serialize()) @contextlib.contextmanager def assertCompletesWithin(self, min_time, max_time): @@ -649,7 +649,7 @@ def test_send_iterable_mixed_type_error(self): self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) self.assertFramesSent( (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, serialize_close(1011, "")), + (True, OP_CLOSE, Close(1011, "").serialize()), ) def test_send_iterable_prevents_concurrent_send(self): @@ -723,7 +723,7 @@ def test_send_async_iterable_mixed_type_error(self): ) self.assertFramesSent( (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, serialize_close(1011, "")), + (True, OP_CLOSE, Close(1011, "").serialize()), ) def test_send_async_iterable_prevents_concurrent_send(self): @@ -1172,7 +1172,7 @@ def test_keepalive_ping_not_acknowledged_closes_connection(self): # Connection is closed at 6ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) self.assertOneFrameSent( - True, OP_CLOSE, serialize_close(1011, "keepalive ping timeout") + True, OP_CLOSE, Close(1011, "keepalive ping timeout").serialize() ) # The keepalive ping task is complete. diff --git a/tests/test_connection.py b/tests/test_connection.py index 677881238..f2ce8de46 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -9,8 +9,8 @@ OP_PING, OP_PONG, OP_TEXT, + Close, Frame, - serialize_close, ) from .extensions.utils import Rsv2Extension @@ -60,7 +60,7 @@ def assertConnectionClosing(self, connection, code=None, reason=""): """ close_frame = Frame( OP_CLOSE, - b"" if code is None else serialize_close(code, reason), + b"" if code is None else Close(code, reason).serialize(), ) # A close frame was received. self.assertFrameReceived(connection, close_frame) @@ -76,7 +76,7 @@ def assertConnectionFailing(self, connection, code=None, reason=""): """ close_frame = Frame( OP_CLOSE, - b"" if code is None else serialize_close(code, reason), + b"" if code is None else Close(code, reason).serialize(), ) # No frame was received. self.assertFrameReceived(connection, None) diff --git a/tests/test_frames.py b/tests/test_frames.py index 85414e9d3..f5ed7ebc1 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -394,32 +394,32 @@ def test_prepare_ctrl_none(self): prepare_ctrl(None) -class ParseAndSerializeCloseTests(unittest.TestCase): - def assertCloseData(self, code, reason, data): +class CloseTests(unittest.TestCase): + def assertCloseData(self, close, data): """ - Serializing code / reason yields data. Parsing data yields code / reason. + Serializing close yields data. Parsing data yields close. """ - serialized = serialize_close(code, reason) + serialized = close.serialize() self.assertEqual(serialized, data) - parsed = parse_close(data) - self.assertEqual(parsed, (code, reason)) + parsed = Close.parse(data) + self.assertEqual(parsed, close) - def test_parse_close_and_serialize_close(self): - self.assertCloseData(1000, "", b"\x03\xe8") - self.assertCloseData(1000, "OK", b"\x03\xe8OK") + def test_parse_and_serialize(self): + self.assertCloseData(Close(1000, ""), b"\x03\xe8") + self.assertCloseData(Close(1000, "OK"), b"\x03\xe8OK") - def test_parse_close_empty(self): - self.assertEqual(parse_close(b""), (1005, "")) + def test_parse_empty(self): + self.assertEqual(Close.parse(b""), Close(1005, "")) - def test_parse_close_errors(self): + def test_parse_errors(self): with self.assertRaises(ProtocolError): - parse_close(b"\x03") + Close.parse(b"\x03") with self.assertRaises(ProtocolError): - parse_close(b"\x03\xe7") + Close.parse(b"\x03\xe7") with self.assertRaises(UnicodeDecodeError): - parse_close(b"\x03\xe8\xff\xff") + Close.parse(b"\x03\xe8\xff\xff") - def test_serialize_close_errors(self): + def test_serialize_errors(self): with self.assertRaises(ProtocolError): - serialize_close(999, "") + Close(999, "").serialize() From 2dd793ffc0e0faab2ed34e4e06fc0b9702a84999 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 08:42:03 +0200 Subject: [PATCH 0869/1539] Simplify display of close codes and reasons. --- src/websockets/frames.py | 6 ++---- src/websockets/legacy/protocol.py | 2 +- tests/test_exceptions.py | 12 ++++++------ tests/test_frames.py | 4 ++-- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index b0de645ec..15b1a9de2 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -381,12 +381,10 @@ def __str__(self) -> str: explanation = "private use" else: explanation = CLOSE_CODES.get(self.code, "unknown") - result = f"code = {self.code} ({explanation}), " + result = f"{self.code} ({explanation})" if self.reason: - result += f"reason = {self.reason}" - else: - result += "no reason" + result = f"{result} {self.reason}" return result diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c3221bb45..7443a565e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1308,7 +1308,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: if not hasattr(self, "close_reason"): self.close_reason = "" self.logger.debug( - "= connection is CLOSED - code = %d, reason = %s", + "= connection is CLOSED - %d %s", self.close_code, self.close_reason or "[no reason]", ) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 094fb6d33..9c8eef4fc 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -14,27 +14,27 @@ def test_str(self): ), ( ConnectionClosed(1000, ""), - "code = 1000 (OK), no reason", + "1000 (OK)", ), ( ConnectionClosed(1006, None), - "code = 1006 (connection closed abnormally [internal]), no reason" + "1006 (connection closed abnormally [internal])" ), ( ConnectionClosed(3000, None), - "code = 3000 (registered), no reason" + "3000 (registered)" ), ( ConnectionClosed(4000, None), - "code = 4000 (private use), no reason" + "4000 (private use)" ), ( ConnectionClosedError(1016, None), - "code = 1016 (unknown), no reason" + "1016 (unknown)" ), ( ConnectionClosedOK(1001, "bye"), - "code = 1001 (going away), reason = bye", + "1001 (going away) bye", ), ( InvalidHandshake("invalid request"), diff --git a/tests/test_frames.py b/tests/test_frames.py index f5ed7ebc1..7620fe415 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -280,13 +280,13 @@ def test_binary_truncated(self): def test_close(self): self.assertEqual( str(Frame(OP_CLOSE, b"\x03\xe8")), - "CLOSE code = 1000 (OK), no reason [2 bytes]", + "CLOSE 1000 (OK) [2 bytes]", ) def test_close_reason(self): self.assertEqual( str(Frame(OP_CLOSE, b"\x03\xe9Bye!")), - "CLOSE code = 1001 (going away), reason = Bye! [6 bytes]", + "CLOSE 1001 (going away) Bye! [6 bytes]", ) def test_ping(self): From 62eb267c34034e97813bc210f0b04a6acbedd7a5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 22:21:01 +0200 Subject: [PATCH 0870/1539] Add details to ConnectionClosed. Fix #587. Ref #767. --- docs/howto/django.rst | 4 +- docs/howto/faq.rst | 11 +-- docs/howto/haproxy.rst | 2 +- docs/howto/heroku.rst | 4 +- docs/howto/kubernetes.rst | 4 +- docs/howto/nginx.rst | 2 +- docs/howto/supervisor.rst | 4 +- docs/project/changelog.rst | 9 +++ docs/reference/client.rst | 12 +--- docs/reference/server.rst | 12 +--- src/websockets/__main__.py | 1 + src/websockets/exceptions.py | 54 ++++++++++---- src/websockets/frames.py | 6 +- src/websockets/legacy/protocol.py | 114 ++++++++++++++++++++---------- tests/legacy/test_protocol.py | 14 ++++ tests/test_exceptions.py | 33 +++++---- tests/test_frames.py | 9 ++- 17 files changed, 195 insertions(+), 100 deletions(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 001d1cb73..c87c3821c 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -146,7 +146,7 @@ prompt: Connected to ws://localhost:8888/ > < Hello ! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). It works! @@ -158,7 +158,7 @@ closes the connection: $ python -m websockets ws://localhost:8888/ Connected to ws://localhost:8888. > not a token - Connection closed: code = 1011 (unexpected error), reason = authentication failed. + Connection closed: 1011 (unexpected error) authentication failed. You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index ce3704a2e..a4ad84680 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -244,8 +244,8 @@ See `issue 867`_. Both sides ---------- -What does ``ConnectionClosedError: code = 1006`` mean? -...................................................... +What does ``ConnectionClosedError: no close frame received or sent`` mean? +.......................................................................... If you're seeing this traceback in the logs of a server: @@ -260,7 +260,7 @@ If you're seeing this traceback in the logs of a server: Traceback (most recent call last): ... - websockets.exceptions.ConnectionClosedError: code = 1006 (connection closed abnormally [internal]), no reason + websockets.exceptions.ConnectionClosedError: no close frame received or sent or if a client crashes with this traceback: @@ -274,10 +274,11 @@ or if a client crashes with this traceback: Traceback (most recent call last): ... - websockets.exceptions.ConnectionClosedError: code = 1006 (connection closed abnormally [internal]), no reason + websockets.exceptions.ConnectionClosedError: no close frame received or sent it means that the TCP connection was lost. As a consequence, the WebSocket -connection was closed without receiving a close frame, which is abnormal. +connection was closed without receiving and sending a close frame, which is +abnormal. You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it from being logged. diff --git a/docs/howto/haproxy.rst b/docs/howto/haproxy.rst index d520d278a..0ecb46a04 100644 --- a/docs/howto/haproxy.rst +++ b/docs/howto/haproxy.rst @@ -58,4 +58,4 @@ You can confirm that HAProxy proxies connections properly: Connected to ws://localhost:8080/. > Hello! < Hello! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 92dee5f27..b728106e9 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -142,7 +142,7 @@ then press Ctrl-D to terminate the connection: > Hello! < Hello! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). You can also confirm that your application shuts down gracefully. Connect an interactive client again — remember to replace ``websockets-echo`` with your app: @@ -167,7 +167,7 @@ away). $ python -m websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. - Connection closed: code = 1001 (going away), no reason. + Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing handshake and the connection would be closed with code 1006 (connection closed diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst index ef5c963d7..0e77aeac1 100644 --- a/docs/howto/kubernetes.rst +++ b/docs/howto/kubernetes.rst @@ -73,7 +73,7 @@ shut down gracefully: Connected to ws://localhost:32080/. > Hey there! < Hey there! - Connection closed: code = 1001 (going away), no reason. + Connection closed: 1001 (going away). If it didn't, you'd get code 1006 (connection closed abnormally). @@ -119,7 +119,7 @@ You can connect to the service — press Ctrl-D to exit: $ python -m websockets ws://localhost:32080/ Connected to ws://localhost:32080/. - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). Validate deployment ------------------- diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index cb4dc83f2..e20f82098 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -81,4 +81,4 @@ You can confirm that nginx proxies connections properly: Connected to ws://localhost:8080/. > Hello! < Hello! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). diff --git a/docs/howto/supervisor.rst b/docs/howto/supervisor.rst index 6d679ca2b..0c07aebae 100644 --- a/docs/howto/supervisor.rst +++ b/docs/howto/supervisor.rst @@ -68,7 +68,7 @@ press Ctrl-D to exit: Connected to ws://localhost:8080/. > Hello! < Hello! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). Look at the pid of an instance of the app in the logs and terminate it: @@ -117,7 +117,7 @@ And you can see that the connection to the app was closed gracefully: $ python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. - Connection closed: code = 1001 (going away), no reason. + Connection closed: 1001 (going away). In this example, we've been sharing the same virtualenv for supervisor and websockets. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c8553d17d..cf38c6159 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -49,6 +49,13 @@ They may change at any time. This reflects a decision made in Python 3.8. See the release notes of Python 3.10 for details. +.. note:: + + **Version 10.0 changes parameters of** ``ConnectionClosed.__init__`` **.** + + If you raise :exc:`~exceptions.ConnectionClosed` or a subclass — rather + than catch them when websockets raises them — you must change your code. + * Added compatibility with Python 3.10. * Added :func:`~websockets.broadcast` to send a message to many clients. @@ -60,6 +67,8 @@ They may change at any time. * Improved logging. +* Provided additional information in :exc:`ConnectionClosed` exceptions. + * Optimized default compression settings to reduce memory usage. * Made it easier to customize authentication with diff --git a/docs/reference/client.rst b/docs/reference/client.rst index c7c738c81..84f66a19a 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -55,17 +55,9 @@ Client Available once the connection is open. - .. attribute:: close_code + .. autoattribute:: close_code - WebSocket close code. - - Available once the connection is closed. - - .. attribute:: close_reason - - WebSocket close reason. - - Available once the connection is closed. + .. autoattribute:: close_reason .. automethod:: recv diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 2d54eca7a..667c0b9d0 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -65,17 +65,9 @@ Server Available once the connection is open. - .. attribute:: close_code + .. autoattribute:: close_code - WebSocket close code. - - Available once the connection is closed. - - .. attribute:: close_reason - - WebSocket close reason. - - Available once the connection is closed. + .. autoattribute:: close_reason .. automethod:: process_request diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 6347ee278..785d2c3c9 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -143,6 +143,7 @@ async def run_client( finally: await websocket.close() + assert websocket.close_code is not None and websocket.close_reason is not None close_status = Close(websocket.close_code, websocket.close_reason) print_over_input(f"Connection closed: {close_status}.") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 59922be9c..67745da61 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -81,37 +81,63 @@ class ConnectionClosed(WebSocketException): """ - def __init__(self, code: int, reason: str) -> None: - self.code = code - self.reason = reason - super().__init__(str(frames.Close(code, reason))) + def __init__( + self, + rcvd: Optional[frames.Close], + sent: Optional[frames.Close], + rcvd_then_sent: Optional[bool] = None, + ) -> None: + self.rcvd = rcvd + self.sent = sent + self.rcvd_then_sent = rcvd_then_sent + if rcvd is None: + if sent is None: + assert rcvd_then_sent is None + msg = "no close frame received or sent" + else: + assert rcvd_then_sent is None + msg = f"sent {sent}; no close frame received" + else: + if sent is None: + assert rcvd_then_sent is None + msg = f"received {rcvd}; no close frame sent" + else: + assert rcvd_then_sent is not None + if rcvd_then_sent: + msg = f"received {rcvd}; then sent {sent}" + else: + msg = f"sent {sent}; then received {rcvd}" + super().__init__(msg) + + # code and reason attributes are provided for backwards-compatibility + + @property + def code(self) -> int: + return 1006 if self.rcvd is None else self.rcvd.code + + @property + def reason(self) -> str: + return "" if self.rcvd is None else self.rcvd.reason class ConnectionClosedError(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated with an error. - This means the close code is different from 1000 (OK) and 1001 (going away). + A close code other than 1000 (OK) or 1001 (going away) was received or + sent, or the closing handshake didn't complete properly. """ - def __init__(self, code: int, reason: str) -> None: - assert code != 1000 and code != 1001 - super().__init__(code, reason) - class ConnectionClosedOK(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated properly. - This means the close code is 1000 (OK) or 1001 (going away). + A close code 1000 (OK) or 1001 (going away) was received and sent. """ - def __init__(self, code: int, reason: str) -> None: - assert code == 1000 or code == 1001 - super().__init__(code, reason) - class InvalidHandshake(WebSocketException): """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 15b1a9de2..4c57386b7 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -93,18 +93,20 @@ class Opcode(enum.IntEnum): 1014, } +OK_CLOSE_CODES = {1000, 1001} + @dataclasses.dataclass class Frame: """ WebSocket frame. + :param int opcode: opcode + :param bytes data: payload data :param bool fin: FIN bit :param bool rsv1: RSV1 bit :param bool rsv2: RSV2 bit :param bool rsv3: RSV3 bit - :param int opcode: opcode - :param bytes data: payload data Only these fields are needed. The MASK bit, payload length and masking-key are handled on the fly by :meth:`parse` and :meth:`serialize`. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7443a565e..badd1e0d8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -44,6 +44,7 @@ ) from ..extensions import Extension from ..frames import ( + OK_CLOSE_CODES, OP_BINARY, OP_CLOSE, OP_CONT, @@ -181,10 +182,10 @@ def __init__( self.extensions: List[Extension] = [] self.subprotocol: Optional[Subprotocol] = None - # The close code and reason are set when receiving a close frame or - # losing the TCP connection. - self.close_code: int - self.close_reason: str + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Optional[Close] = None + self.close_sent: Optional[Close] = None + self.close_rcvd_then_sent: Optional[bool] = None # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost` callback to a :class:`~asyncio.Future` @@ -339,6 +340,36 @@ def closed(self) -> bool: """ return self.state is State.CLOSED + @property + def close_code(self) -> Optional[int]: + """ + WebSocket close code received in a close frame. + + Available once the connection is closed. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return 1006 + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> Optional[str]: + """ + WebSocket close reason received in a close frame. + + Available once the connection is closed. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + async def wait_closed(self) -> None: """ Wait until the connection is closed. @@ -608,7 +639,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: """ try: await asyncio.wait_for( - self.write_close_frame(Close(code, reason).serialize()), + self.write_close_frame(Close(code, reason)), self.close_timeout, **loop_if_py_lt_38(self.loop), ) @@ -714,14 +745,27 @@ async def pong(self, data: Data = b"") -> None: # Private methods - no guarantees. def connection_closed_exc(self) -> ConnectionClosed: - exception: ConnectionClosed - if self.close_code == 1000 or self.close_code == 1001: - exception = ConnectionClosedOK(self.close_code, self.close_reason) + exc: ConnectionClosed + if ( + self.close_rcvd is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent is not None + and self.close_sent.code in OK_CLOSE_CODES + ): + exc = ConnectionClosedOK( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) else: - exception = ConnectionClosedError(self.close_code, self.close_reason) + exc = ConnectionClosedError( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) # Chain to the exception that terminated data transfer, if any. - exception.__cause__ = self.transfer_data_exc - return exception + exc.__cause__ = self.transfer_data_exc + return exc async def ensure_open(self) -> None: """ @@ -917,13 +961,14 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: if frame.opcode == OP_CLOSE: # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - close = Close.parse(frame.data) - self.close_code, self.close_reason = close.code, close.reason + self.close_rcvd = Close.parse(frame.data) + if self.close_sent is not None: + self.close_rcvd_then_sent = False try: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame # is empty and Close.parse() synthetizes a 1005 close code. - await self.write_close_frame(frame.data) + await self.write_close_frame(self.close_rcvd, frame.data) except ConnectionClosed: # Connection closed before we could echo the close frame. pass @@ -1010,7 +1055,9 @@ async def write_frame( self.write_frame_sync(fin, opcode, data) await self.drain() - async def write_close_frame(self, data: bytes = b"") -> None: + async def write_close_frame( + self, close: Close, data: Optional[bytes] = None + ) -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1026,6 +1073,12 @@ async def write_close_frame(self, data: bytes = b"") -> None: if self.debug: self.logger.debug("= connection is CLOSING") + self.close_sent = close + if self.close_rcvd is not None: + self.close_rcvd_then_sent = True + if data is None: + data = close.serialize() + # 7.1.2. Start the WebSocket Closing Handshake await self.write_frame(True, OP_CLOSE, data, _state=State.CLOSING) @@ -1219,8 +1272,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # an error reading from or writing to the network. # Don't send a close frame if the connection is broken. if code != 1006 and self.state is State.OPEN: - - frame_data = Close(code, reason).serialize() + close = Close(code, reason) # Write the close frame without draining the write buffer. @@ -1228,21 +1280,19 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # get stuck and simplifies the implementation of the callers. # Not drainig the write buffer is acceptable in this context. - # This duplicates a few lines of code from write_close_frame() - # and write_frame(). + # This duplicates a few lines of code from write_close_frame(). self.state = State.CLOSING if self.debug: self.logger.debug("= connection is CLOSING") - frame = Frame(True, OP_CLOSE, frame_data) - if self.debug: - self.logger.debug("> %s", frame) - frame.write( - self.transport.write, - mask=self.is_client, - extensions=self.extensions, - ) + # If self.close_rcvd was set, the connection state would be + # CLOSING. Therefore self.close_rcvd isn't set and we don't + # have to set self.close_rcvd_then_sent. + assert self.close_rcvd is None + self.close_sent = close + + self.write_frame_sync(True, OP_CLOSE, close.serialize()) # Start close_connection_task if the opening handshake didn't succeed. if not hasattr(self, "close_connection_task"): @@ -1303,15 +1353,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: """ self.state = State.CLOSED - if not hasattr(self, "close_code"): - self.close_code = 1006 - if not hasattr(self, "close_reason"): - self.close_reason = "" - self.logger.debug( - "= connection is CLOSED - %d %s", - self.close_code, - self.close_reason or "[no reason]", - ) + self.logger.debug("= connection is CLOSED") self.abort_pings() diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 1e3f1b77e..033cebe17 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -386,6 +386,20 @@ def test_wait_closed(self): self.close_connection() self.assertTrue(wait_closed.done()) + def test_close_code(self): + self.close_connection(1001, "Bye!") + self.assertEqual(self.protocol.close_code, 1001) + + def test_close_reason(self): + self.close_connection(1001, "Bye!") + self.assertEqual(self.protocol.close_reason, "Bye!") + + def test_close_code_not_set(self): + self.assertIsNone(self.protocol.close_code) + + def test_close_reason_not_set(self): + self.assertIsNone(self.protocol.close_reason) + # Test the recv coroutine. def test_recv_text(self): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 9c8eef4fc..85ebc24da 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -2,6 +2,7 @@ from websockets.datastructures import Headers from websockets.exceptions import * +from websockets.frames import Close class ExceptionsTests(unittest.TestCase): @@ -13,28 +14,36 @@ def test_str(self): "something went wrong", ), ( - ConnectionClosed(1000, ""), - "1000 (OK)", + ConnectionClosed(Close(1000, ""), Close(1000, ""), True), + "received 1000 (OK); then sent 1000 (OK)", ), ( - ConnectionClosed(1006, None), - "1006 (connection closed abnormally [internal])" + ConnectionClosed(Close(1001, "Bye!"), Close(1001, "Bye!"), False), + "sent 1001 (going away) Bye!; then received 1001 (going away) Bye!", ), ( - ConnectionClosed(3000, None), - "3000 (registered)" + ConnectionClosed(Close(1000, "race"), Close(1000, "cond"), True), + "received 1000 (OK) race; then sent 1000 (OK) cond", ), ( - ConnectionClosed(4000, None), - "4000 (private use)" + ConnectionClosed(Close(1000, "cond"), Close(1000, "race"), False), + "sent 1000 (OK) race; then received 1000 (OK) cond", ), ( - ConnectionClosedError(1016, None), - "1016 (unknown)" + ConnectionClosed(None, Close(1009, ""), None), + "sent 1009 (message too big); no close frame received", ), ( - ConnectionClosedOK(1001, "bye"), - "1001 (going away) bye", + ConnectionClosed(Close(1002, ""), None, None), + "received 1002 (protocol error); no close frame sent", + ), + ( + ConnectionClosedOK(Close(1000, ""), Close(1000, ""), True), + "received 1000 (OK); then sent 1000 (OK)", + ), + ( + ConnectionClosedError(None, None, None), + "no close frame received or sent" ), ( InvalidHandshake("invalid request"), diff --git a/tests/test_frames.py b/tests/test_frames.py index 7620fe415..c8f9867d4 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -405,8 +405,15 @@ def assertCloseData(self, close, data): parsed = Close.parse(data) self.assertEqual(parsed, close) + def test_str(self): + self.assertEqual(str(Close(1000, "")), "1000 (OK)") + self.assertEqual(str(Close(1001, "Bye!")), "1001 (going away) Bye!") + self.assertEqual(str(Close(3000, "")), "3000 (registered)") + self.assertEqual(str(Close(4000, "")), "4000 (private use)") + self.assertEqual(str(Close(5000, "")), "5000 (unknown)") + def test_parse_and_serialize(self): - self.assertCloseData(Close(1000, ""), b"\x03\xe8") + self.assertCloseData(Close(1001, ""), b"\x03\xe9") self.assertCloseData(Close(1000, "OK"), b"\x03\xe8OK") def test_parse_empty(self): From 1186dd6cfc072a7038d5478f6f5722038efba51c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 22:51:50 +0200 Subject: [PATCH 0871/1539] Store details of close frames. Ref #587. --- src/websockets/connection.py | 70 ++++++++++++++++++++++++++---------- tests/test_connection.py | 50 +++++++++++++++++++++++++- 2 files changed, 101 insertions(+), 19 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 067ad54ce..d682226ab 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -102,15 +102,10 @@ def __init__( self.extensions: List[Extension] = [] self.subprotocol: Optional[Subprotocol] = None - # Connection state isn't enough to tell if a close frame was received: - # when this side closes the connection, state is CLOSING as soon as a - # close frame is sent, before a close frame is received. - self.close_frame_received = False - - # Close code and reason. Set when receiving a close frame or when the - # TCP connection drops. - self.close_code: int - self.close_reason: str + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Optional[Close] = None + self.close_sent: Optional[Close] = None + self.close_rcvd_then_sent: Optional[bool] = None # Track if send_eof() was called. self.eof_sent = False @@ -128,6 +123,36 @@ def set_state(self, state: State) -> None: self.logger.debug("= connection is %s", state.name) self.state = state + @property + def close_code(self) -> Optional[int]: + """ + WebSocket close code received in a close frame. + + Available once the connection is closed. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return 1006 + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> Optional[str]: + """ + WebSocket close reason received in a close frame. + + Available once the connection is closed. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + # Public APIs for receiving data. def receive_data(self, data: bytes) -> None: @@ -199,10 +224,13 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: if code is None: if reason != "": raise ValueError("cannot send a reason without a code") + close = Close(1005, "") data = b"" else: - data = Close(code, reason).serialize() + close = Close(code, reason) + data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started self.set_state(CLOSING) @@ -258,7 +286,9 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because of # an error reading from or writing to the network. if code != 1006 and self.state is OPEN: - self.send_frame(Frame(OP_CLOSE, Close(code, reason).serialize())) + close = Close(code, reason) + self.send_frame(Frame(OP_CLOSE, close.serialize())) + self.close_sent = close self.set_state(CLOSING) if not self.eof_sent: self.send_eof() @@ -304,7 +334,7 @@ def parse(self) -> Generator[None, None, None]: if eof: if self.debug: self.logger.debug("< EOF") - if self.close_frame_received: + if self.close_rcvd is not None: if not self.eof_sent: self.send_eof() yield @@ -338,7 +368,7 @@ def parse(self) -> Generator[None, None, None]: if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: # 5.5.1 Close: "The application MUST NOT send any more data # frames after sending a Close frame." - if self.close_frame_received: + if self.close_rcvd is not None: raise ProtocolError("data frame after close frame") if self.cur_size is not None: @@ -351,7 +381,7 @@ def parse(self) -> Generator[None, None, None]: elif frame.opcode is OP_CONT: # 5.5.1 Close: "The application MUST NOT send any more data # frames after sending a Close frame." - if self.close_frame_received: + if self.close_rcvd is not None: raise ProtocolError("data frame after close frame") if self.cur_size is None: @@ -365,7 +395,7 @@ def parse(self) -> Generator[None, None, None]: # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST # send a Pong frame in response, unless it already received a # Close frame." - if not self.close_frame_received: + if self.close_rcvd is None: pong_frame = Frame(OP_PONG, frame.data) self.send_frame(pong_frame) @@ -375,11 +405,12 @@ def parse(self) -> Generator[None, None, None]: pass elif frame.opcode is OP_CLOSE: - self.close_frame_received = True # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - close = Close.parse(frame.data) - self.close_code, self.close_reason = close.code, close.reason + self.close_rcvd = Close.parse(frame.data) + if self.state is CLOSING: + assert self.close_sent is not None + self.close_rcvd_then_sent = False if self.cur_size is not None: raise ProtocolError("incomplete fragmented message") @@ -388,12 +419,15 @@ def parse(self) -> Generator[None, None, None]: # Close frame in response. (When sending a Close frame in # response, the endpoint typically echos the status code it # received.)" + if self.state is OPEN: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame # is empty and Close.parse() synthetizes a 1005 close code. # The rest is identical to send_close(). self.send_frame(Frame(OP_CLOSE, frame.data)) + self.close_sent = self.close_rcvd + self.close_rcvd_then_sent = True self.set_state(CLOSING) if self.side is SERVER: self.send_eof() diff --git a/tests/test_connection.py b/tests/test_connection.py index f2ce8de46..2a887564b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -667,10 +667,58 @@ def test_server_receives_binary_after_receiving_close(self): class CloseTests(ConnectionTestCase): """ - Test close frames. See 5.5.1. Close in RFC 6544. + Test close frames. + + See RFC 6544: + + 5.5.1. Close + 7.1.6. The WebSocket Connection Close Reason + 7.1.7. Fail the WebSocket Connection """ + def test_close_code(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x04\x03\xe8OK") + client.set_state(State.CLOSED) + self.assertEqual(client.close_code, 1000) + + def test_close_reason(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") + server.set_state(State.CLOSED) + self.assertEqual(server.close_reason, "OK") + + def test_close_code_not_provided(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x00\x00\x00\x00") + server.set_state(State.CLOSED) + self.assertEqual(server.close_code, 1005) + + def test_close_reason_not_provided(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + client.set_state(State.CLOSED) + self.assertEqual(client.close_reason, "") + + def test_close_code_not_available(self): + client = Connection(Side.CLIENT) + client.set_state(State.CLOSED) + self.assertEqual(client.close_code, 1006) + + def test_close_reason_not_available(self): + server = Connection(Side.SERVER) + server.set_state(State.CLOSED) + self.assertEqual(server.close_reason, "") + + def test_close_code_not_available_yet(self): + server = Connection(Side.SERVER) + self.assertIsNone(server.close_code) + + def test_close_reason_not_available_yet(self): + client = Connection(Side.CLIENT) + self.assertIsNone(client.close_reason) + def test_client_sends_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): From 8152df6f0b02c1e8c655a3845b7d39e9904c5a9b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 06:57:47 +0200 Subject: [PATCH 0872/1539] Refactor exceptions to create messages on demand. --- src/websockets/exceptions.py | 96 +++++++++++++++++++++--------------- tests/test_exceptions.py | 5 ++ 2 files changed, 60 insertions(+), 41 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 67745da61..c8ae1d6b5 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -90,24 +90,25 @@ def __init__( self.rcvd = rcvd self.sent = sent self.rcvd_then_sent = rcvd_then_sent - if rcvd is None: - if sent is None: - assert rcvd_then_sent is None - msg = "no close frame received or sent" + + def __str__(self) -> str: + if self.rcvd is None: + if self.sent is None: + assert self.rcvd_then_sent is None + return "no close frame received or sent" else: - assert rcvd_then_sent is None - msg = f"sent {sent}; no close frame received" + assert self.rcvd_then_sent is None + return f"sent {self.sent}; no close frame received" else: - if sent is None: - assert rcvd_then_sent is None - msg = f"received {rcvd}; no close frame sent" + if self.sent is None: + assert self.rcvd_then_sent is None + return f"received {self.rcvd}; no close frame sent" else: - assert rcvd_then_sent is not None - if rcvd_then_sent: - msg = f"received {rcvd}; then sent {sent}" + assert self.rcvd_then_sent is not None + if self.rcvd_then_sent: + return f"received {self.rcvd}; then sent {self.sent}" else: - msg = f"sent {sent}; then received {rcvd}" - super().__init__(msg) + return f"sent {self.sent}; then received {self.rcvd}" # code and reason attributes are provided for backwards-compatibility @@ -171,13 +172,14 @@ class InvalidHeader(InvalidHandshake): def __init__(self, name: str, value: Optional[str] = None) -> None: self.name = name self.value = value - if value is None: - message = f"missing {name} header" - elif value == "": - message = f"empty {name} header" + + def __str__(self) -> str: + if self.value is None: + return f"missing {self.name} header" + elif self.value == "": + return f"empty {self.name} header" else: - message = f"invalid {name} header: {value}" - super().__init__(message) + return f"invalid {self.name} header: {self.value}" class InvalidHeaderFormat(InvalidHeader): @@ -189,9 +191,7 @@ class InvalidHeaderFormat(InvalidHeader): """ def __init__(self, name: str, error: str, header: str, pos: int) -> None: - self.name = name - error = f"{error} at {pos} in {header}" - super().__init__(name, error) + super().__init__(name, f"{error} at {pos} in {header}") class InvalidHeaderValue(InvalidHeader): @@ -228,8 +228,12 @@ class InvalidStatus(InvalidHandshake): def __init__(self, response: http11.Response) -> None: self.response = response - message = f"server rejected WebSocket connection: HTTP {response.status_code:d}" - super().__init__(message) + + def __str__(self) -> str: + return ( + "server rejected WebSocket connection: " + f"HTTP {self.response.status_code:d}" + ) class InvalidStatusCode(InvalidHandshake): @@ -244,8 +248,9 @@ class InvalidStatusCode(InvalidHandshake): def __init__(self, status_code: int, headers: datastructures.Headers) -> None: self.status_code = status_code self.headers = headers - message = f"server rejected WebSocket connection: HTTP {status_code}" - super().__init__(message) + + def __str__(self) -> str: + return f"server rejected WebSocket connection: HTTP {self.status_code}" class NegotiationError(InvalidHandshake): @@ -263,8 +268,9 @@ class DuplicateParameter(NegotiationError): def __init__(self, name: str) -> None: self.name = name - message = f"duplicate parameter: {name}" - super().__init__(message) + + def __str__(self) -> str: + return f"duplicate parameter: {self.name}" class InvalidParameterName(NegotiationError): @@ -275,8 +281,9 @@ class InvalidParameterName(NegotiationError): def __init__(self, name: str) -> None: self.name = name - message = f"invalid parameter name: {name}" - super().__init__(message) + + def __str__(self) -> str: + return f"invalid parameter name: {self.name}" class InvalidParameterValue(NegotiationError): @@ -288,13 +295,14 @@ class InvalidParameterValue(NegotiationError): def __init__(self, name: str, value: Optional[str]) -> None: self.name = name self.value = value - if value is None: - message = f"missing value for parameter {name}" - elif value == "": - message = f"empty value for parameter {name}" + + def __str__(self) -> str: + if self.value is None: + return f"missing value for parameter {self.name}" + elif self.value == "": + return f"empty value for parameter {self.name}" else: - message = f"invalid value for parameter {name}: {value}" - super().__init__(message) + return f"invalid value for parameter {self.name}: {self.value}" class AbortHandshake(InvalidHandshake): @@ -316,8 +324,13 @@ def __init__( self.status = status self.headers = datastructures.Headers(headers) self.body = body - message = f"HTTP {status:d}, {len(self.headers)} headers, {len(body)} bytes" - super().__init__(message) + + def __str__(self) -> str: + return ( + f"HTTP {self.status:d}, " + f"{len(self.headers)} headers, " + f"{len(self.body)} bytes" + ) class RedirectHandshake(InvalidHandshake): @@ -354,8 +367,9 @@ class InvalidURI(WebSocketException): def __init__(self, uri: str) -> None: self.uri = uri - message = "{} isn't a valid URI".format(uri) - super().__init__(message) + + def __str__(self) -> str: + return f"{self.uri} isn't a valid URI" class PayloadTooBig(WebSocketException): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 85ebc24da..e172cdd02 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,6 +3,7 @@ from websockets.datastructures import Headers from websockets.exceptions import * from websockets.frames import Close +from websockets.http11 import Response class ExceptionsTests(unittest.TestCase): @@ -96,6 +97,10 @@ def test_str(self): InvalidUpgrade("Connection", "websocket"), "invalid Connection header: websocket", ), + ( + InvalidStatus(Response(401, "Unauthorized", Headers())), + "server rejected WebSocket connection: HTTP 401", + ), ( InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", From 0bfa9f2ea86dbc12f24e4a2756cb339efafe76d8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 07:17:58 +0200 Subject: [PATCH 0873/1539] Document ConnectionClosed attributes. --- src/websockets/exceptions.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index c8ae1d6b5..b3462484f 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -76,8 +76,17 @@ class ConnectionClosed(WebSocketException): """ Raised when trying to interact with a closed connection. - Provides the connection close code and reason in its ``code`` and - ``reason`` attributes respectively. + If a close frame was received, its code and reason are available in the + ``rcvd.code`` and ``rcvd.reason`` attributes. Else, the ``rcvd`` + attribute is ``None``. + + Likewise, if a close frame was sent, its code and reason are available in + the ``sent.code`` and ``sent.reason`` attributes. Else, the ``sent`` + attribute is ``None``. + + If close frames were received and sent, the ``rcvd_then_sent`` attribute + tells in which order this happened, from the perspective of this side of + the connection. """ From f148d821f3b9ce620c65011281056f7655ea6fa6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 22:06:16 +0200 Subject: [PATCH 0874/1539] Remove superfluous logging. --- src/websockets/legacy/protocol.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index badd1e0d8..cab72caef 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1316,14 +1316,6 @@ def abort_pings(self) -> None: # nothing, but it prevents logging the exception. ping.cancel() - if self.debug: - if self.pings: - pings_hex = ", ".join( - ping_id.hex() or "[empty]" for ping_id in self.pings - ) - plural = "s" if len(self.pings) > 1 else "" - self.logger.debug("% aborted pending ping%s: %s", plural, pings_hex) - # asyncio.Protocol methods def connection_made(self, transport: asyncio.BaseTransport) -> None: From b2a95c45fae19fff0a3473158dd02afe2ca42604 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 22:30:00 +0200 Subject: [PATCH 0875/1539] Don't answer pings on closing connection. Technically, this is the wrong behavior, but I'll live with this in the legacy layer. The new Sans I/O layer has the right behavior. Fix #669. --- src/websockets/legacy/protocol.py | 13 +++++++------ tests/legacy/test_protocol.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index cab72caef..df74bdb38 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -975,12 +975,13 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: return None elif frame.opcode == OP_PING: - # Answer pings. - try: - await self.pong(frame.data) - except ConnectionClosed: - # Connection closed before we could respond to the ping. - pass + # Answer pings, unless connection is CLOSING. + if self.state is State.OPEN: + try: + await self.pong(frame.data) + except ConnectionClosed: + # Connection closed while draining write buffer. + pass elif frame.opcode == OP_PONG: if frame.data in self.pings: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 033cebe17..61a5fe7cf 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -909,10 +909,20 @@ def test_answer_ping(self): self.run_loop_once() self.assertOneFrameSent(True, OP_PONG, b"test") + def test_answer_ping_does_not_crash_if_connection_closing(self): + close_task = self.half_close_connection_local() + + self.receive_frame(Frame(True, OP_PING, b"test")) + + with self.assertNoLogs(): + self.loop.run_until_complete(self.protocol.close()) + + self.loop.run_until_complete(close_task) # cleanup + def test_answer_ping_does_not_crash_if_connection_closed(self): self.make_drain_slow() # Drop the connection right after receiving a ping frame, - # which prevents responding wwith a pong frame properly. + # which prevents responding with a pong frame properly. self.receive_frame(Frame(True, OP_PING, b"test")) self.receive_eof() From 361f38066622636eff8d9e0d3d071e93e8e02b05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 23:31:31 +0200 Subject: [PATCH 0876/1539] Handle connection drop during handshake. Fix #984. --- src/websockets/legacy/server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 80d72f93e..3fde99568 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -29,6 +29,7 @@ cast, ) +from ..connection import State from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( AbortHandshake, @@ -656,6 +657,10 @@ async def handshake( warnings.warn("declare process_request as a coroutine", DeprecationWarning) early_response = early_response_awaitable # type: ignore + # The connection may drop while process_request is running. + if self.state is State.CLOSED: + raise self.connection_closed_exc() # pragma: no cover + # Change the response to a 503 error if the server is shutting down. if not self.ws_server.is_serving(): early_response = ( From f9c28e8b7810c5a2fcadc88825046618f1fdc012 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jun 2021 18:00:30 +0200 Subject: [PATCH 0877/1539] Polish class structure. --- src/websockets/connection.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index d682226ab..9fd7db490 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -118,10 +118,7 @@ def __init__( next(self.parser) # start coroutine self.parser_exc: Optional[Exception] = None - def set_state(self, state: State) -> None: - if self.debug: - self.logger.debug("= connection is %s", state.name) - self.state = state + # Public attributes @property def close_code(self) -> Optional[int]: @@ -153,6 +150,13 @@ def close_reason(self) -> Optional[str]: else: return self.close_rcvd.reason + # Private attributes + + def set_state(self, state: State) -> None: + if self.debug: + self.logger.debug("= connection is %s", state.name) + self.state = state + # Public APIs for receiving data. def receive_data(self, data: bytes) -> None: From ebda766d856b59faa86e14f9cf53a5e0ede0a355 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jun 2021 18:03:01 +0200 Subject: [PATCH 0878/1539] Fix connection closure. * In a regular closing handshake, server closes the connection after receiving and sending a close frame. * When failing the connection, server closes the connection after sending a close frame. --- src/websockets/connection.py | 55 +++++++++++++++++++++--------------- tests/test_connection.py | 26 +++++++++-------- 2 files changed, 46 insertions(+), 35 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 9fd7db490..e5c9b866c 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -238,8 +238,6 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started self.set_state(CLOSING) - if self.side is SERVER: - self.send_eof() def send_ping(self, data: bytes) -> None: """ @@ -255,6 +253,22 @@ def send_pong(self, data: bytes) -> None: """ self.send_frame(Frame(OP_PONG, data)) + def fail(self, code: Optional[int] = None, reason: str = "") -> None: + """ + Fail the WebSocket connection. + + """ + # 7.1.7. Fail the WebSocket Connection + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because + # of an error reading from or writing to the network. + if self.state is OPEN and code != 1006: + self.send_close(code, reason) + + if self.side is SERVER and not self.eof_sent: + self.send_eof() + # Public API for getting incoming events after receiving data. def events_received(self) -> List[Event]: @@ -274,8 +288,9 @@ def data_to_send(self) -> List[bytes]: """ Return data to write to the connection. - Call this method immediately after calling any of the ``receive_*()`` - or ``send_*()`` methods and write the data to the connection. + Call this method immediately after calling any of the ``receive_*()``, + ``send_*()``, or ``fail()`` methods and write the data to the + connection. The empty bytestring signals the end of the data stream. @@ -285,18 +300,6 @@ def data_to_send(self) -> List[bytes]: # Private APIs for receiving data. - def fail_connection(self, code: int = 1006, reason: str = "") -> None: - # Send a close frame when the state is OPEN (a close frame was already - # sent if it's CLOSING), except when failing the connection because of - # an error reading from or writing to the network. - if code != 1006 and self.state is OPEN: - close = Close(code, reason) - self.send_frame(Frame(OP_CLOSE, close.serialize())) - self.close_sent = close - self.set_state(CLOSING) - if not self.eof_sent: - self.send_eof() - def step_parser(self) -> None: # Run parser until more data is needed or EOF try: @@ -310,25 +313,25 @@ def step_parser(self) -> None: "parser cannot receive data or EOF after an error" ) from self.parser_exc except ProtocolError as exc: - self.fail_connection(1002, str(exc)) + self.fail(1002, str(exc)) self.parser_exc = exc raise except EOFError as exc: - self.fail_connection(1006, str(exc)) + self.fail(1006, str(exc)) self.parser_exc = exc raise except UnicodeDecodeError as exc: - self.fail_connection(1007, f"{exc.reason} at position {exc.start}") + self.fail(1007, f"{exc.reason} at position {exc.start}") self.parser_exc = exc raise except PayloadTooBig as exc: - self.fail_connection(1009, str(exc)) + self.fail(1009, str(exc)) self.parser_exc = exc raise except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. - self.fail_connection(1011) + self.fail(1011) self.parser_exc = exc raise @@ -418,6 +421,7 @@ def parse(self) -> Generator[None, None, None]: if self.cur_size is not None: raise ProtocolError("incomplete fragmented message") + # 5.5.1 Close: "If an endpoint receives a Close frame and did # not previously send a Close frame, the endpoint MUST send a # Close frame in response. (When sending a Close frame in @@ -433,8 +437,13 @@ def parse(self) -> Generator[None, None, None]: self.close_sent = self.close_rcvd self.close_rcvd_then_sent = True self.set_state(CLOSING) - if self.side is SERVER: - self.send_eof() + + # 7.1.2. Start the WebSocket Closing Handshake: "Once an + # endpoint has both sent and received a Close control frame, + # that endpoint SHOULD _Close the WebSocket Connection_" + + if self.side is SERVER: + self.send_eof() else: # pragma: no cover # This can't happen because Frame.parse() validates opcodes. diff --git a/tests/test_connection.py b/tests/test_connection.py index 2a887564b..2d6ccb3f9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -80,8 +80,10 @@ def assertConnectionFailing(self, connection, code=None, reason=""): ) # No frame was received. self.assertFrameReceived(connection, None) - # A close frame and the end of stream were sent. - self.assertFrameSent(connection, close_frame, eof=True) + # A close frame and possibly the end of stream were sent. + self.assertFrameSent( + connection, close_frame, eof=connection.side is Side.SERVER + ) class MaskingTests(ConnectionTestCase): @@ -187,7 +189,7 @@ def test_server_sends_continuation_after_sending_close(self): # this is the same test as test_server_sends_unexpected_continuation. server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") @@ -442,7 +444,7 @@ def test_client_sends_text_after_sending_close(self): def test_server_sends_text_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_text(b"") @@ -644,7 +646,7 @@ def test_client_sends_binary_after_sending_close(self): def test_server_sends_binary_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_binary(b"") @@ -729,7 +731,7 @@ def test_client_sends_close(self): def test_server_sends_close(self): server = Connection(Side.SERVER) server.send_close() - self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x00"]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close(self): @@ -769,11 +771,11 @@ def test_server_sends_close_then_receives_close(self): server.send_close() self.assertFrameReceived(server, None) - self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) + self.assertFrameSent(server, Frame(OP_CLOSE, b"")) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) - self.assertFrameSent(server, None) + self.assertFrameSent(server, None, eof=True) server.receive_eof() self.assertFrameReceived(server, None) @@ -813,7 +815,7 @@ def test_client_sends_close_with_code(self): def test_server_sends_close_with_code(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close_with_code(self): @@ -840,7 +842,7 @@ def test_client_sends_close_with_code_and_reason(self): def test_server_sends_close_with_code_and_reason(self): server = Connection(Side.SERVER) server.send_close(1000, "OK") - self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close_with_code_and_reason(self): @@ -1029,7 +1031,7 @@ def test_client_sends_ping_after_sending_close(self): def test_server_sends_ping_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) # before the connection is closed" but websockets doesn't support # sending a Ping frame after a Close frame. @@ -1164,7 +1166,7 @@ def test_client_sends_pong_after_sending_close(self): def test_server_sends_pong_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # websockets doesn't support sending a Pong frame after a Close frame. with self.assertRaises(InvalidState): server.send_pong(b"") From 8f6c4d9e22c7105248cebbe0cb50e47008e12257 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jun 2021 18:35:10 +0200 Subject: [PATCH 0879/1539] Stop processing after failing connection. --- src/websockets/connection.py | 15 ++++++----- src/websockets/streams.py | 8 ++++++ tests/test_connection.py | 49 ++++++++++++++++++++++++------------ tests/test_streams.py | 9 +++++++ 4 files changed, 57 insertions(+), 24 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index e5c9b866c..966547f59 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -269,6 +269,11 @@ def fail(self, code: Optional[int] = None, reason: str = "") -> None: if self.side is SERVER and not self.eof_sent: self.send_eof() + # "An endpoint MUST NOT continue to attempt to process data + # (including a responding Close frame) from the remote endpoint + # after being instructed to _Fail the WebSocket Connection_." + self.reader.abort() + # Public API for getting incoming events after receiving data. def events_received(self) -> List[Event]: @@ -304,14 +309,8 @@ def step_parser(self) -> None: # Run parser until more data is needed or EOF try: next(self.parser) - except StopIteration: - # This happens if receive_data() or receive_eof() is called after - # the parser raised an exception. (It cannot happen after reaching - # EOF because receive_data() or receive_eof() would fail earlier.) - assert self.parser_exc is not None - raise RuntimeError( - "parser cannot receive data or EOF after an error" - ) from self.parser_exc + except StopIteration: # pragma: no cover + raise AssertionError("parser shouldn't exit") except ProtocolError as exc: self.fail(1002, str(exc)) self.parser_exc = exc diff --git a/src/websockets/streams.py b/src/websockets/streams.py index e02a6ab39..92a356265 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -115,3 +115,11 @@ def feed_eof(self) -> None: if self.eof: raise EOFError("stream ended") self.eof = True + + def abort(self) -> None: + """ + End the stream, discarding all buffered data. + + """ + self.eof = True + del self.buffer[:] diff --git a/tests/test_connection.py b/tests/test_connection.py index 2d6ccb3f9..45a0594d2 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1192,6 +1192,31 @@ def test_server_receives_pong_after_receiving_close(self): ) +class FailTests(ConnectionTestCase): + """ + Test failing the connection. + + See 7.1.7. Fail the WebSocket Connection in RFC 6544. + + """ + + def test_client_stops_processing_frames_after_fail(self): + client = Connection(Side.CLIENT) + client.fail(1002) + self.assertConnectionFailing(client, 1002) + with self.assertRaises(EOFError) as raised: + client.receive_data(b"\x88\x02\x03\xea") + self.assertEqual(str(raised.exception), "stream ended") + + def test_server_stops_processing_frames_after_fail(self): + server = Connection(Side.SERVER) + server.fail(1002) + self.assertConnectionFailing(server, 1002) + with self.assertRaises(EOFError) as raised: + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") + self.assertEqual(str(raised.exception), "stream ended") + + class FragmentationTests(ConnectionTestCase): """ Test message fragmentation. @@ -1401,44 +1426,36 @@ def test_client_receives_data_after_exception(self): with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\xff\xff") self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(EOFError) as raised: client.receive_data(b"\x00\x00") - self.assertEqual( - str(raised.exception), "parser cannot receive data or EOF after an error" - ) + self.assertEqual(str(raised.exception), "stream ended") def test_server_receives_data_after_exception(self): server = Connection(Side.SERVER) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\xff\xff") self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(EOFError) as raised: server.receive_data(b"\x00\x00") - self.assertEqual( - str(raised.exception), "parser cannot receive data or EOF after an error" - ) + self.assertEqual(str(raised.exception), "stream ended") def test_client_receives_eof_after_exception(self): client = Connection(Side.CLIENT) with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\xff\xff") self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(EOFError) as raised: client.receive_eof() - self.assertEqual( - str(raised.exception), "parser cannot receive data or EOF after an error" - ) + self.assertEqual(str(raised.exception), "stream ended") def test_server_receives_eof_after_exception(self): server = Connection(Side.SERVER) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\xff\xff") self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(EOFError) as raised: server.receive_eof() - self.assertEqual( - str(raised.exception), "parser cannot receive data or EOF after an error" - ) + self.assertEqual(str(raised.exception), "stream ended") def test_client_receives_data_after_eof(self): client = Connection(Side.CLIENT) diff --git a/tests/test_streams.py b/tests/test_streams.py index 566deb2db..e942c8d12 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -144,3 +144,12 @@ def test_feed_eof_after_feed_eof(self): with self.assertRaises(EOFError) as raised: self.reader.feed_eof() self.assertEqual(str(raised.exception), "stream ended") + + def test_abort(self): + gen = self.reader.read_to_eof() + + self.reader.feed_data(b"spam") + self.reader.abort() + + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"") From bff6397ffb69dd52c3e91a23268915582ce0b5e2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jun 2021 21:52:03 +0200 Subject: [PATCH 0880/1539] Factor out bytes-like types. --- src/websockets/frames.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 4c57386b7..839e2f7b7 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -96,6 +96,9 @@ class Opcode(enum.IntEnum): OK_CLOSE_CODES = {1000, 1001} +BytesLike = bytes, bytearray, memoryview + + @dataclasses.dataclass class Frame: """ @@ -334,7 +337,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: """ if isinstance(data, str): return OP_TEXT, data.encode("utf-8") - elif isinstance(data, (bytes, bytearray, memoryview)): + elif isinstance(data, BytesLike): return OP_BINARY, data else: raise TypeError("data must be bytes-like or str") @@ -356,7 +359,7 @@ def prepare_ctrl(data: Data) -> bytes: """ if isinstance(data, str): return data.encode("utf-8") - elif isinstance(data, (bytes, bytearray, memoryview)): + elif isinstance(data, BytesLike): return bytes(data) else: raise TypeError("data must be bytes-like or str") From 82ba1ff963f623b80f2d6f70c684f545390708a2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Jun 2021 11:46:22 +0200 Subject: [PATCH 0881/1539] Clarify comments. Attributes are also APIs. --- src/websockets/connection.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 966547f59..620fc6f78 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -157,7 +157,7 @@ def set_state(self, state: State) -> None: self.logger.debug("= connection is %s", state.name) self.state = state - # Public APIs for receiving data. + # Public methods for receiving data. def receive_data(self, data: bytes) -> None: """ @@ -186,7 +186,7 @@ def receive_eof(self) -> None: self.reader.feed_eof() self.step_parser() - # Public APIs for sending events. + # Public methods for sending events. def send_continuation(self, data: bytes, fin: bool) -> None: """ @@ -274,7 +274,7 @@ def fail(self, code: Optional[int] = None, reason: str = "") -> None: # after being instructed to _Fail the WebSocket Connection_." self.reader.abort() - # Public API for getting incoming events after receiving data. + # Public method for getting incoming events after receiving data. def events_received(self) -> List[Event]: """ @@ -287,7 +287,7 @@ def events_received(self) -> List[Event]: events, self.events = self.events, [] return events - # Public API for getting outgoing data after receiving data or sending events. + # Public method for getting outgoing data after receiving data or sending events. def data_to_send(self) -> List[bytes]: """ @@ -303,7 +303,7 @@ def data_to_send(self) -> List[bytes]: writes, self.writes = self.writes, [] return writes - # Private APIs for receiving data. + # Private methods for receiving data. def step_parser(self) -> None: # Run parser until more data is needed or EOF @@ -450,7 +450,7 @@ def parse(self) -> Generator[None, None, None]: self.events.append(frame) - # Private APIs for sending events. + # Private methods for sending events. def send_frame(self, frame: Frame) -> None: # Defensive assertion for protocol compliance. From be42b07606a5cd3975dc92ea0e8a8f2821b55af0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Jun 2021 11:46:58 +0200 Subject: [PATCH 0882/1539] Uniformize signature. --- src/websockets/legacy/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index d94e9d7b9..c87b5f8d5 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -737,7 +737,9 @@ async def __await_impl__(self) -> WebSocketClientProtocol: def unix_connect( - path: Optional[str], uri: str = "ws://localhost/", **kwargs: Any + path: Optional[str] = None, + uri: str = "ws://localhost/", + **kwargs: Any, ) -> Connect: """ Similar to :func:`connect`, but for connecting to a Unix socket. From 98e8b8d31326557c5c1a85fa0bd75d390b34b3f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 07:39:17 +0200 Subject: [PATCH 0883/1539] Compare enum with identity. --- src/websockets/client.py | 2 +- src/websockets/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 3217090d0..cc807e97c 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -306,7 +306,7 @@ def parse(self) -> Generator[None, None, None]: except InvalidHandshake as exc: response.exception = exc else: - assert self.state == CONNECTING + assert self.state is CONNECTING self.set_state(OPEN) finally: self.events.append(response) diff --git a/src/websockets/server.py b/src/websockets/server.py index a53798f65..0710dc0be 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -435,7 +435,7 @@ def send_response(self, response: Response) -> None: self.writes.append(response.serialize()) def parse(self) -> Generator[None, None, None]: - if self.state == CONNECTING: + if self.state is CONNECTING: request = yield from Request.parse(self.reader.read_line) if self.debug: From 81a4a2a369190ae4a30e7914ead889c5931172d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 07:41:45 +0200 Subject: [PATCH 0884/1539] Close connecting when opening handshake fails. --- src/websockets/server.py | 10 ++++++---- tests/test_server.py | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index 0710dc0be..545b38c98 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -420,10 +420,6 @@ def send_response(self, response: Response) -> None: Send a WebSocket handshake response to the client. """ - if response.status_code == 101: - assert self.state is CONNECTING - self.set_state(OPEN) - if self.debug: code, phrase = response.status_code, response.reason_phrase self.logger.debug("> HTTP/1.1 %d %s", code, phrase) @@ -434,6 +430,12 @@ def send_response(self, response: Response) -> None: self.writes.append(response.serialize()) + if response.status_code == 101: + assert self.state is CONNECTING + self.set_state(OPEN) + else: + self.send_eof() + def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: request = yield from Request.parse(self.reader.read_line) diff --git a/tests/test_server.py b/tests/test_server.py index d2c41598e..042d64a31 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -122,7 +122,8 @@ def test_send_reject(self): f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" f"\r\n" - f"Sorry folks.\n".encode() + f"Sorry folks.\n".encode(), + b"", ], ) self.assertEqual(server.state, CONNECTING) From 5495af64f016c332f4a6f91811e9896f07e0ea59 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 09:22:16 +0200 Subject: [PATCH 0885/1539] Prevent client from closing TCP connection Read until EOF and wait for the server to close the connection first. --- src/websockets/connection.py | 273 +++++++++++++++++++---------------- src/websockets/streams.py | 5 +- tests/test_connection.py | 64 ++++---- tests/test_streams.py | 6 +- 4 files changed, 190 insertions(+), 158 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 620fc6f78..0a3a87bd5 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -170,7 +170,7 @@ def receive_data(self, data: bytes) -> None: """ self.reader.feed_data(data) - self.step_parser() + next(self.parser) def receive_eof(self) -> None: """ @@ -184,7 +184,7 @@ def receive_eof(self) -> None: """ self.reader.feed_eof() - self.step_parser() + next(self.parser) # Public methods for sending events. @@ -233,10 +233,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: else: close = Close(code, reason) data = close.serialize() - self.send_frame(Frame(OP_CLOSE, data)) - self.close_sent = close # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close self.set_state(CLOSING) def send_ping(self, data: bytes) -> None: @@ -253,7 +253,7 @@ def send_pong(self, data: bytes) -> None: """ self.send_frame(Frame(OP_PONG, data)) - def fail(self, code: Optional[int] = None, reason: str = "") -> None: + def fail(self, code: int, reason: str = "") -> None: """ Fail the WebSocket connection. @@ -263,8 +263,13 @@ def fail(self, code: Optional[int] = None, reason: str = "") -> None: # Send a close frame when the state is OPEN (a close frame was already # sent if it's CLOSING), except when failing the connection because # of an error reading from or writing to the network. - if self.state is OPEN and code != 1006: - self.send_close(code, reason) + if self.state is OPEN: + if code != 1006: + close = Close(code, reason) + data = close.serialize() + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close + self.set_state(CLOSING) if self.side is SERVER and not self.eof_sent: self.send_eof() @@ -272,7 +277,8 @@ def fail(self, code: Optional[int] = None, reason: str = "") -> None: # "An endpoint MUST NOT continue to attempt to process data # (including a responding Close frame) from the remote endpoint # after being instructed to _Fail the WebSocket Connection_." - self.reader.abort() + self.parser = self.discard() + next(self.parser) # start coroutine # Public method for getting incoming events after receiving data. @@ -305,28 +311,65 @@ def data_to_send(self) -> List[bytes]: # Private methods for receiving data. - def step_parser(self) -> None: - # Run parser until more data is needed or EOF + def parse(self) -> Generator[None, None, None]: try: - next(self.parser) - except StopIteration: # pragma: no cover - raise AssertionError("parser shouldn't exit") + while True: + if (yield from self.reader.at_eof()): + if self.debug: + self.logger.debug("< EOF") + if self.close_rcvd_then_sent is not None: + if self.side is CLIENT: + self.send_eof() + # If parse() completes normally, execution ends here. + yield + # Once the reader reaches EOF, its feed_data/eof() + # methods raise an error, so our receive_data/eof() + # methods don't step the generator. + raise AssertionError( + "parse() shouldn't step after EOF" + ) # pragma: no cover + else: + raise EOFError("unexpected end of stream") + + if self.max_size is None: + max_size = None + elif self.cur_size is None: + max_size = self.max_size + else: + max_size = self.max_size - self.cur_size + + frame = yield from Frame.parse( + self.reader.read_exact, + mask=self.side is SERVER, + max_size=max_size, + extensions=self.extensions, + ) + + if self.debug: + self.logger.debug("< %s", frame) + + self.recv_frame(frame) + except ProtocolError as exc: self.fail(1002, str(exc)) self.parser_exc = exc raise + except EOFError as exc: self.fail(1006, str(exc)) self.parser_exc = exc raise + except UnicodeDecodeError as exc: self.fail(1007, f"{exc.reason} at position {exc.start}") self.parser_exc = exc raise + except PayloadTooBig as exc: self.fail(1009, str(exc)) self.parser_exc = exc raise + except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. @@ -334,121 +377,97 @@ def step_parser(self) -> None: self.parser_exc = exc raise - def parse(self) -> Generator[None, None, None]: - while True: - eof = yield from self.reader.at_eof() - if eof: - if self.debug: - self.logger.debug("< EOF") - if self.close_rcvd is not None: - if not self.eof_sent: - self.send_eof() - yield - # Once the reader reaches EOF, its feed_data/eof() methods - # raise an error, so our receive_data/eof() methods never - # call step_parser(), so the generator shouldn't resume - # executing until it's garbage collected. - raise AssertionError( - "parser shouldn't step after EOF" - ) # pragma: no cover - else: - raise EOFError("unexpected end of stream") - - if self.max_size is None: - max_size = None - elif self.cur_size is None: - max_size = self.max_size + def discard(self) -> Generator[None, None, None]: + while not (yield from self.reader.at_eof()): + self.reader.discard() + if self.side is CLIENT: + self.send_eof() + # If discard() completes normally, execution ends here. + yield + # Once the reader reaches EOF, its feed_data/eof() + # methods raise an error, so our receive_data/eof() + # methods don't step the generator. + raise AssertionError("discard() shouldn't step after EOF") # pragma: no cover + + def recv_frame(self, frame: Frame) -> None: + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: + # 5.5.1 Close: "The application MUST NOT send any more data + # frames after sending a Close frame." + if self.close_rcvd is not None: + raise ProtocolError("data frame after close frame") + + if self.cur_size is not None: + raise ProtocolError("expected a continuation frame") + if frame.fin: + self.cur_size = None else: - max_size = self.max_size - self.cur_size - - frame = yield from Frame.parse( - self.reader.read_exact, - mask=self.side is SERVER, - max_size=max_size, - extensions=self.extensions, - ) - - if self.debug: - self.logger.debug("< %s", frame) - - if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: - # 5.5.1 Close: "The application MUST NOT send any more data - # frames after sending a Close frame." - if self.close_rcvd is not None: - raise ProtocolError("data frame after close frame") - - if self.cur_size is not None: - raise ProtocolError("expected a continuation frame") - if frame.fin: - self.cur_size = None - else: - self.cur_size = len(frame.data) - - elif frame.opcode is OP_CONT: - # 5.5.1 Close: "The application MUST NOT send any more data - # frames after sending a Close frame." - if self.close_rcvd is not None: - raise ProtocolError("data frame after close frame") - - if self.cur_size is None: - raise ProtocolError("unexpected continuation frame") - if frame.fin: - self.cur_size = None - else: - self.cur_size += len(frame.data) - - elif frame.opcode is OP_PING: - # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST - # send a Pong frame in response, unless it already received a - # Close frame." - if self.close_rcvd is None: - pong_frame = Frame(OP_PONG, frame.data) - self.send_frame(pong_frame) - - elif frame.opcode is OP_PONG: - # 5.5.3 Pong: "A response to an unsolicited Pong frame is not - # expected." - pass - - elif frame.opcode is OP_CLOSE: - # 7.1.5. The WebSocket Connection Close Code - # 7.1.6. The WebSocket Connection Close Reason - self.close_rcvd = Close.parse(frame.data) - if self.state is CLOSING: - assert self.close_sent is not None - self.close_rcvd_then_sent = False - - if self.cur_size is not None: - raise ProtocolError("incomplete fragmented message") - - # 5.5.1 Close: "If an endpoint receives a Close frame and did - # not previously send a Close frame, the endpoint MUST send a - # Close frame in response. (When sending a Close frame in - # response, the endpoint typically echos the status code it - # received.)" - - if self.state is OPEN: - # Echo the original data instead of re-serializing it with - # Close.serialize() because that fails when the close frame - # is empty and Close.parse() synthetizes a 1005 close code. - # The rest is identical to send_close(). - self.send_frame(Frame(OP_CLOSE, frame.data)) - self.close_sent = self.close_rcvd - self.close_rcvd_then_sent = True - self.set_state(CLOSING) - - # 7.1.2. Start the WebSocket Closing Handshake: "Once an - # endpoint has both sent and received a Close control frame, - # that endpoint SHOULD _Close the WebSocket Connection_" - - if self.side is SERVER: - self.send_eof() - - else: # pragma: no cover - # This can't happen because Frame.parse() validates opcodes. - raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") - - self.events.append(frame) + self.cur_size = len(frame.data) + + elif frame.opcode is OP_CONT: + # 5.5.1 Close: "The application MUST NOT send any more data + # frames after sending a Close frame." + if self.close_rcvd is not None: + raise ProtocolError("data frame after close frame") + + if self.cur_size is None: + raise ProtocolError("unexpected continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size += len(frame.data) + + elif frame.opcode is OP_PING: + # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST + # send a Pong frame in response, unless it already received a + # Close frame." + if self.close_rcvd is None: + pong_frame = Frame(OP_PONG, frame.data) + self.send_frame(pong_frame) + + elif frame.opcode is OP_PONG: + # 5.5.3 Pong: "A response to an unsolicited Pong frame is not + # expected." + pass + + elif frame.opcode is OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_rcvd = Close.parse(frame.data) + if self.state is CLOSING: + assert self.close_sent is not None + self.close_rcvd_then_sent = False + + if self.cur_size is not None: + raise ProtocolError("incomplete fragmented message") + + # 5.5.1 Close: "If an endpoint receives a Close frame and did + # not previously send a Close frame, the endpoint MUST send a + # Close frame in response. (When sending a Close frame in + # response, the endpoint typically echos the status code it + # received.)" + + if self.state is OPEN: + # Echo the original data instead of re-serializing it with + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthetizes a 1005 close code. + # The rest is identical to send_close(). + self.send_frame(Frame(OP_CLOSE, frame.data)) + self.close_sent = self.close_rcvd + self.close_rcvd_then_sent = True + self.set_state(CLOSING) + + # 7.1.2. Start the WebSocket Closing Handshake: "Once an + # endpoint has both sent and received a Close control frame, + # that endpoint SHOULD _Close the WebSocket Connection_" + + if self.side is SERVER: + self.send_eof() + + else: # pragma: no cover + # This can't happen because Frame.parse() validates opcodes. + raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") + + self.events.append(frame) # Private methods for sending events. diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 92a356265..d1ce377e7 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -116,10 +116,9 @@ def feed_eof(self) -> None: raise EOFError("stream ended") self.eof = True - def abort(self) -> None: + def discard(self) -> None: """ - End the stream, discarding all buffered data. + Discarding all buffered data, but don't end the stream. """ - self.eof = True del self.buffer[:] diff --git a/tests/test_connection.py b/tests/test_connection.py index 45a0594d2..b8843028d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1204,17 +1204,15 @@ def test_client_stops_processing_frames_after_fail(self): client = Connection(Side.CLIENT) client.fail(1002) self.assertConnectionFailing(client, 1002) - with self.assertRaises(EOFError) as raised: - client.receive_data(b"\x88\x02\x03\xea") - self.assertEqual(str(raised.exception), "stream ended") + client.receive_data(b"\x88\x02\x03\xea") + self.assertFrameReceived(client, None) def test_server_stops_processing_frames_after_fail(self): server = Connection(Side.SERVER) server.fail(1002) self.assertConnectionFailing(server, 1002) - with self.assertRaises(EOFError) as raised: - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") - self.assertEqual(str(raised.exception), "stream ended") + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") + self.assertFrameReceived(server, None) class FragmentationTests(ConnectionTestCase): @@ -1423,39 +1421,53 @@ def test_server_receives_eof_inside_frame(self): def test_client_receives_data_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: + with self.assertRaises(ProtocolError): client.receive_data(b"\xff\xff") - self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(EOFError) as raised: - client.receive_data(b"\x00\x00") - self.assertEqual(str(raised.exception), "stream ended") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_data(b"\x00\x00") + self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: + with self.assertRaises(ProtocolError): server.receive_data(b"\xff\xff") - self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(EOFError) as raised: - server.receive_data(b"\x00\x00") - self.assertEqual(str(raised.exception), "stream ended") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_data(b"\x00\x00") + self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: + with self.assertRaises(ProtocolError): client.receive_data(b"\xff\xff") - self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_eof() + self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: + with self.assertRaises(ProtocolError): server.receive_data(b"\xff\xff") - self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_eof() + self.assertFrameSent(server, None) + + def test_client_receives_data_and_eof_after_exception(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError): + client.receive_data(b"\xff\xff") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_data(b"\x00\x00") + client.receive_eof() + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_data_and_eof_after_exception(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError): + server.receive_data(b"\xff\xff") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_data(b"\x00\x00") + server.receive_eof() + self.assertFrameSent(server, None) def test_client_receives_data_after_eof(self): client = Connection(Side.CLIENT) diff --git a/tests/test_streams.py b/tests/test_streams.py index e942c8d12..8abefbcc9 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -145,11 +145,13 @@ def test_feed_eof_after_feed_eof(self): self.reader.feed_eof() self.assertEqual(str(raised.exception), "stream ended") - def test_abort(self): + def test_discard(self): gen = self.reader.read_to_eof() self.reader.feed_data(b"spam") - self.reader.abort() + self.reader.discard() + self.assertGeneratorRunning(gen) + self.reader.feed_eof() data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"") From 008c160365bcda5b38ddeccf5613b1b2a8395599 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 10:36:12 +0200 Subject: [PATCH 0886/1539] Prevent receive_data/eof from raising exceptions. --- src/websockets/connection.py | 17 ++- tests/test_connection.py | 248 +++++++++++++++++------------------ 2 files changed, 132 insertions(+), 133 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 0a3a87bd5..f8637ab20 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -168,6 +168,8 @@ def receive_data(self, data: bytes) -> None: - You must call :meth:`data_to_send` and send this data. - You should call :meth:`events_received` and process these events. + :raises EOFError: if :meth:`receive_eof` was called before + """ self.reader.feed_data(data) next(self.parser) @@ -179,9 +181,11 @@ def receive_eof(self) -> None: After calling this method: - You must call :meth:`data_to_send` and send this data. - - You shouldn't call :meth:`events_received` as it won't + - You aren't exepcted to call :meth:`events_received` as it won't return any new events. + :raises EOFError: if :meth:`receive_eof` was called before + """ self.reader.feed_eof() next(self.parser) @@ -324,7 +328,7 @@ def parse(self) -> Generator[None, None, None]: yield # Once the reader reaches EOF, its feed_data/eof() # methods raise an error, so our receive_data/eof() - # methods don't step the generator. + # methods don't step parse(). raise AssertionError( "parse() shouldn't step after EOF" ) # pragma: no cover @@ -353,29 +357,28 @@ def parse(self) -> Generator[None, None, None]: except ProtocolError as exc: self.fail(1002, str(exc)) self.parser_exc = exc - raise except EOFError as exc: self.fail(1006, str(exc)) self.parser_exc = exc - raise except UnicodeDecodeError as exc: self.fail(1007, f"{exc.reason} at position {exc.start}") self.parser_exc = exc - raise except PayloadTooBig as exc: self.fail(1009, str(exc)) self.parser_exc = exc - raise except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. self.fail(1011) self.parser_exc = exc - raise + + yield + # If an error occurs, parse() is replaced by discard(). + raise AssertionError("parse() shouldn't step after EOF") # pragma: no cover def discard(self) -> Generator[None, None, None]: while not (yield from self.reader.at_eof()): diff --git a/tests/test_connection.py b/tests/test_connection.py index b8843028d..ec62fc7fc 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -126,16 +126,16 @@ def test_server_receives_masked_frame(self): def test_client_receives_masked_frame(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(self.masked_text_frame_data) - self.assertEqual(str(raised.exception), "incorrect masking") + client.receive_data(self.masked_text_frame_data) + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "incorrect masking") self.assertConnectionFailing(client, 1002, "incorrect masking") def test_server_receives_unmasked_frame(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(self.unmasked_text_frame_date) - self.assertEqual(str(raised.exception), "incorrect masking") + server.receive_data(self.unmasked_text_frame_date) + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "incorrect masking") self.assertConnectionFailing(server, 1002, "incorrect masking") @@ -159,16 +159,16 @@ def test_server_sends_unexpected_continuation(self): def test_client_receives_unexpected_continuation(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x00\x00") - self.assertEqual(str(raised.exception), "unexpected continuation frame") + client.receive_data(b"\x00\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "unexpected continuation frame") self.assertConnectionFailing(client, 1002, "unexpected continuation frame") def test_server_receives_unexpected_continuation(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x00\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "unexpected continuation frame") + server.receive_data(b"\x00\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "unexpected continuation frame") self.assertConnectionFailing(server, 1002, "unexpected continuation frame") def test_client_sends_continuation_after_sending_close(self): @@ -198,17 +198,17 @@ def test_client_receives_continuation_after_receiving_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x00\x00") - self.assertEqual(str(raised.exception), "data frame after close frame") + client.receive_data(b"\x00\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "data frame after close frame") def test_server_receives_continuation_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x00\x80\x00\xff\x00\xff") - self.assertEqual(str(raised.exception), "data frame after close frame") + server.receive_data(b"\x00\x80\x00\xff\x00\xff") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "data frame after close frame") class TextTests(ConnectionTestCase): @@ -248,16 +248,16 @@ def test_server_receives_text(self): def test_client_receives_text_over_size_limit(self): client = Connection(Side.CLIENT, max_size=3) - with self.assertRaises(PayloadTooBig) as raised: - client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") - self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") def test_server_receives_text_over_size_limit(self): server = Connection(Side.SERVER, max_size=3) - with self.assertRaises(PayloadTooBig) as raised: - server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") - self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") def test_client_receives_text_without_size_limit(self): @@ -342,9 +342,9 @@ def test_client_receives_fragmented_text_over_size_limit(self): client, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) - with self.assertRaises(PayloadTooBig) as raised: - client.receive_data(b"\x80\x02\x98\x80") - self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + client.receive_data(b"\x80\x02\x98\x80") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") def test_server_receives_fragmented_text_over_size_limit(self): @@ -354,9 +354,9 @@ def test_server_receives_fragmented_text_over_size_limit(self): server, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) - with self.assertRaises(PayloadTooBig) as raised: - server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") - self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") def test_client_receives_fragmented_text_without_size_limit(self): @@ -416,9 +416,9 @@ def test_client_receives_unexpected_text(self): client, Frame(OP_TEXT, b"", fin=False), ) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x01\x00") - self.assertEqual(str(raised.exception), "expected a continuation frame") + client.receive_data(b"\x01\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "expected a continuation frame") self.assertConnectionFailing(client, 1002, "expected a continuation frame") def test_server_receives_unexpected_text(self): @@ -428,9 +428,9 @@ def test_server_receives_unexpected_text(self): server, Frame(OP_TEXT, b"", fin=False), ) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x01\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "expected a continuation frame") + server.receive_data(b"\x01\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "expected a continuation frame") self.assertConnectionFailing(server, 1002, "expected a continuation frame") def test_client_sends_text_after_sending_close(self): @@ -452,17 +452,17 @@ def test_client_receives_text_after_receiving_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x81\x00") - self.assertEqual(str(raised.exception), "data frame after close frame") + client.receive_data(b"\x81\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "data frame after close frame") def test_server_receives_text_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x81\x80\x00\xff\x00\xff") - self.assertEqual(str(raised.exception), "data frame after close frame") + server.receive_data(b"\x81\x80\x00\xff\x00\xff") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "data frame after close frame") class BinaryTests(ConnectionTestCase): @@ -502,16 +502,16 @@ def test_server_receives_binary(self): def test_client_receives_binary_over_size_limit(self): client = Connection(Side.CLIENT, max_size=3) - with self.assertRaises(PayloadTooBig) as raised: - client.receive_data(b"\x82\x04\x01\x02\xfe\xff") - self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + client.receive_data(b"\x82\x04\x01\x02\xfe\xff") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") def test_server_receives_binary_over_size_limit(self): server = Connection(Side.SERVER, max_size=3) - with self.assertRaises(PayloadTooBig) as raised: - server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") - self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") def test_client_sends_fragmented_binary(self): @@ -580,9 +580,9 @@ def test_client_receives_fragmented_binary_over_size_limit(self): client, Frame(OP_BINARY, b"\x01\x02", fin=False), ) - with self.assertRaises(PayloadTooBig) as raised: - client.receive_data(b"\x80\x02\xfe\xff") - self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + client.receive_data(b"\x80\x02\xfe\xff") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") def test_server_receives_fragmented_binary_over_size_limit(self): @@ -592,9 +592,9 @@ def test_server_receives_fragmented_binary_over_size_limit(self): server, Frame(OP_BINARY, b"\x01\x02", fin=False), ) - with self.assertRaises(PayloadTooBig) as raised: - server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") - self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") def test_client_sends_unexpected_binary(self): @@ -618,9 +618,9 @@ def test_client_receives_unexpected_binary(self): client, Frame(OP_BINARY, b"", fin=False), ) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x02\x00") - self.assertEqual(str(raised.exception), "expected a continuation frame") + client.receive_data(b"\x02\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "expected a continuation frame") self.assertConnectionFailing(client, 1002, "expected a continuation frame") def test_server_receives_unexpected_binary(self): @@ -630,9 +630,9 @@ def test_server_receives_unexpected_binary(self): server, Frame(OP_BINARY, b"", fin=False), ) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x02\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "expected a continuation frame") + server.receive_data(b"\x02\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "expected a continuation frame") self.assertConnectionFailing(server, 1002, "expected a continuation frame") def test_client_sends_binary_after_sending_close(self): @@ -654,17 +654,17 @@ def test_client_receives_binary_after_receiving_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x82\x00") - self.assertEqual(str(raised.exception), "data frame after close frame") + client.receive_data(b"\x82\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "data frame after close frame") def test_server_receives_binary_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x82\x80\x00\xff\x00\xff") - self.assertEqual(str(raised.exception), "data frame after close frame") + server.receive_data(b"\x82\x80\x00\xff\x00\xff") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "data frame after close frame") class CloseTests(ConnectionTestCase): @@ -871,26 +871,27 @@ def test_server_sends_close_with_reason_only(self): def test_client_receives_close_with_truncated_code(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x88\x01\x03") - self.assertEqual(str(raised.exception), "close frame too short") + client.receive_data(b"\x88\x01\x03") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "close frame too short") self.assertConnectionFailing(client, 1002, "close frame too short") self.assertIs(client.state, State.CLOSING) def test_server_receives_close_with_truncated_code(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") - self.assertEqual(str(raised.exception), "close frame too short") + server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "close frame too short") self.assertConnectionFailing(server, 1002, "close frame too short") self.assertIs(server.state, State.CLOSING) def test_client_receives_close_with_non_utf8_reason(self): client = Connection(Side.CLIENT) - with self.assertRaises(UnicodeDecodeError) as raised: - client.receive_data(b"\x88\x04\x03\xe8\xff\xff") + + client.receive_data(b"\x88\x04\x03\xe8\xff\xff") + self.assertIsInstance(client.parser_exc, UnicodeDecodeError) self.assertEqual( - str(raised.exception), + str(client.parser_exc), "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") @@ -898,10 +899,11 @@ def test_client_receives_close_with_non_utf8_reason(self): def test_server_receives_close_with_non_utf8_reason(self): server = Connection(Side.SERVER) - with self.assertRaises(UnicodeDecodeError) as raised: - server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") + + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") + self.assertIsInstance(server.parser_exc, UnicodeDecodeError) self.assertEqual( - str(raised.exception), + str(server.parser_exc), "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") @@ -1002,16 +1004,16 @@ def test_server_sends_fragmented_ping_frame(self): def test_client_receives_fragmented_ping_frame(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x09\x00") - self.assertEqual(str(raised.exception), "fragmented control frame") + client.receive_data(b"\x09\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing(client, 1002, "fragmented control frame") def test_server_receives_fragmented_ping_frame(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") - self.assertEqual(str(raised.exception), "fragmented control frame") + server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing(server, 1002, "fragmented control frame") def test_client_sends_ping_after_sending_close(self): @@ -1142,16 +1144,16 @@ def test_server_sends_fragmented_pong_frame(self): def test_client_receives_fragmented_pong_frame(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x0a\x00") - self.assertEqual(str(raised.exception), "fragmented control frame") + client.receive_data(b"\x0a\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing(client, 1002, "fragmented control frame") def test_server_receives_fragmented_pong_frame(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") - self.assertEqual(str(raised.exception), "fragmented control frame") + server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing(server, 1002, "fragmented control frame") def test_client_sends_pong_after_sending_close(self): @@ -1349,9 +1351,9 @@ def test_client_receive_close_in_fragmented_message(self): # frames in the middle of a fragmented message." However, since the # endpoint must not send a data frame after a close frame, a close # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x88\x02\x03\xe8") - self.assertEqual(str(raised.exception), "incomplete fragmented message") + client.receive_data(b"\x88\x02\x03\xe8") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "incomplete fragmented message") self.assertConnectionFailing(client, 1002, "incomplete fragmented message") def test_server_receive_close_in_fragmented_message(self): @@ -1365,9 +1367,9 @@ def test_server_receive_close_in_fragmented_message(self): # frames in the middle of a fragmented message." However, since the # endpoint must not send a data frame after a close frame, a close # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertEqual(str(raised.exception), "incomplete fragmented message") + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "incomplete fragmented message") self.assertConnectionFailing(server, 1002, "incomplete fragmented message") @@ -1391,70 +1393,65 @@ def test_server_receives_eof(self): def test_client_receives_eof_between_frames(self): client = Connection(Side.CLIENT) - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "unexpected end of stream") + client.receive_eof() + self.assertIsInstance(client.parser_exc, EOFError) + self.assertEqual(str(client.parser_exc), "unexpected end of stream") def test_server_receives_eof_between_frames(self): server = Connection(Side.SERVER) - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "unexpected end of stream") + server.receive_eof() + self.assertIsInstance(server.parser_exc, EOFError) + self.assertEqual(str(server.parser_exc), "unexpected end of stream") def test_client_receives_eof_inside_frame(self): client = Connection(Side.CLIENT) client.receive_data(b"\x81") - with self.assertRaises(EOFError) as raised: - client.receive_eof() + client.receive_eof() + self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual( - str(raised.exception), "stream ends after 1 bytes, expected 2 bytes" + str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes" ) def test_server_receives_eof_inside_frame(self): server = Connection(Side.SERVER) server.receive_data(b"\x81") - with self.assertRaises(EOFError) as raised: - server.receive_eof() + server.receive_eof() + self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual( - str(raised.exception), "stream ends after 1 bytes, expected 2 bytes" + str(server.parser_exc), "stream ends after 1 bytes, expected 2 bytes" ) def test_client_receives_data_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError): - client.receive_data(b"\xff\xff") + client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_data(b"\x00\x00") self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError): - server.receive_data(b"\xff\xff") + server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_data(b"\x00\x00") self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError): - client.receive_data(b"\xff\xff") + client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_eof() self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError): - server.receive_data(b"\xff\xff") + server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_eof() self.assertFrameSent(server, None) def test_client_receives_data_and_eof_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError): - client.receive_data(b"\xff\xff") + client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_data(b"\x00\x00") client.receive_eof() @@ -1462,8 +1459,7 @@ def test_client_receives_data_and_eof_after_exception(self): def test_server_receives_data_and_eof_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError): - server.receive_data(b"\xff\xff") + server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_data(b"\x00\x00") server.receive_eof() @@ -1516,18 +1512,18 @@ def test_client_hits_internal_error_reading_frame(self): client = Connection(Side.CLIENT) # This isn't supposed to happen, so we're simulating it. with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): - with self.assertRaises(RuntimeError) as raised: - client.receive_data(b"\x81\x00") - self.assertEqual(str(raised.exception), "BOOM") + client.receive_data(b"\x81\x00") + self.assertIsInstance(client.parser_exc, RuntimeError) + self.assertEqual(str(client.parser_exc), "BOOM") self.assertConnectionFailing(client, 1011, "") def test_server_hits_internal_error_reading_frame(self): server = Connection(Side.SERVER) # This isn't supposed to happen, so we're simulating it. with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): - with self.assertRaises(RuntimeError) as raised: - server.receive_data(b"\x81\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "BOOM") + server.receive_data(b"\x81\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, RuntimeError) + self.assertEqual(str(server.parser_exc), "BOOM") self.assertConnectionFailing(server, 1011, "") From 8add3a648ae4886b78db018f5b672abc0f6e5a8c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 14:59:43 +0200 Subject: [PATCH 0887/1539] Indicate when TCP connection should close --- src/websockets/connection.py | 16 ++++++++ tests/test_connection.py | 71 +++++++++++++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index f8637ab20..581614e49 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -313,6 +313,22 @@ def data_to_send(self) -> List[bytes]: writes, self.writes = self.writes, [] return writes + def close_expected(self) -> bool: + """ + Tell whether the TCP connection is expected to close soon. + + Call this method immediately after calling any of the ``receive_*()`` + methods and, if it returns ``True``, schedule closing the TCP + connection after a short timeout. + + """ + # We already got a TCP Close if and only if the state is CLOSED. + # We expect a TCP close if and only if we sent a close frame: + # * Normal closure: once we send a close frame, we expect a TCP close. + # * Abnormal closure: we always send a close frame except on EOFError, + # but that's fine because we already got the TCP close. + return self.state is not CLOSED and self.close_sent is not None + # Private methods for receiving data. def parse(self) -> Generator[None, None, None]: diff --git a/tests/test_connection.py b/tests/test_connection.py index ec62fc7fc..e420c7853 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1027,7 +1027,8 @@ def test_client_sends_ping_after_sending_close(self): with self.assertRaises(InvalidState) as raised: client.send_ping(b"") self.assertEqual( - str(raised.exception), "cannot write to a WebSocket in the CLOSING state" + str(raised.exception), + "cannot write to a WebSocket in the CLOSING state", ) def test_server_sends_ping_after_sending_close(self): @@ -1040,7 +1041,8 @@ def test_server_sends_ping_after_sending_close(self): with self.assertRaises(InvalidState) as raised: server.send_ping(b"") self.assertEqual( - str(raised.exception), "cannot write to a WebSocket in the CLOSING state" + str(raised.exception), + "cannot write to a WebSocket in the CLOSING state", ) def test_client_receives_ping_after_receiving_close(self): @@ -1375,7 +1377,7 @@ def test_server_receive_close_in_fragmented_message(self): class EOFTests(ConnectionTestCase): """ - Test connection termination. + Test half-closes on connection termination. """ @@ -1409,7 +1411,8 @@ def test_client_receives_eof_inside_frame(self): client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual( - str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes" + str(client.parser_exc), + "stream ends after 1 bytes, expected 2 bytes", ) def test_server_receives_eof_inside_frame(self): @@ -1418,7 +1421,8 @@ def test_server_receives_eof_inside_frame(self): server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual( - str(server.parser_exc), "stream ends after 1 bytes, expected 2 bytes" + str(server.parser_exc), + "stream ends after 1 bytes, expected 2 bytes", ) def test_client_receives_data_after_exception(self): @@ -1502,6 +1506,63 @@ def test_server_receives_eof_after_eof(self): self.assertEqual(str(raised.exception), "stream ended") +class TCPCloseTests(ConnectionTestCase): + """ + Test expectation of TCP close on connection termination. + + """ + + def test_client_default(self): + client = Connection(Side.CLIENT) + self.assertFalse(client.close_expected()) + + def test_server_default(self): + server = Connection(Side.SERVER) + self.assertFalse(server.close_expected()) + + def test_client_sends_close(self): + client = Connection(Side.CLIENT) + client.send_close() + self.assertTrue(client.close_expected()) + + def test_server_sends_close(self): + server = Connection(Side.SERVER) + server.send_close() + self.assertTrue(server.close_expected()) + + def test_client_receives_close(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + self.assertTrue(client.close_expected()) + + def test_client_receives_close_then_eof(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + client.receive_eof() + self.assertFalse(client.close_expected()) + + def test_server_receives_close_then_eof(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + server.receive_eof() + self.assertFalse(server.close_expected()) + + def test_server_receives_close(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertTrue(server.close_expected()) + + def test_client_fails_connection(self): + client = Connection(Side.CLIENT) + client.fail(1002) + self.assertTrue(client.close_expected()) + + def test_server_fails_connection(self): + server = Connection(Side.SERVER) + server.fail(1002) + self.assertTrue(server.close_expected()) + + class ErrorTests(ConnectionTestCase): """ Test other error cases. From d5ab29687ae19170cbf6ee80ecd26d7de0862f5a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 15:15:41 +0200 Subject: [PATCH 0888/1539] Set connection state to closed after EOF --- src/websockets/connection.py | 6 ++++-- tests/test_connection.py | 22 ++++++++++++++-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 581614e49..2e6871c00 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -128,7 +128,7 @@ def close_code(self) -> Optional[int]: Available once the connection is closed. """ - if self.state is not State.CLOSED: + if self.state is not CLOSED: return None elif self.close_rcvd is None: return 1006 @@ -143,7 +143,7 @@ def close_reason(self) -> Optional[str]: Available once the connection is closed. """ - if self.state is not State.CLOSED: + if self.state is not CLOSED: return None elif self.close_rcvd is None: return "" @@ -340,6 +340,7 @@ def parse(self) -> Generator[None, None, None]: if self.close_rcvd_then_sent is not None: if self.side is CLIENT: self.send_eof() + self.set_state(CLOSED) # If parse() completes normally, execution ends here. yield # Once the reader reaches EOF, its feed_data/eof() @@ -402,6 +403,7 @@ def discard(self) -> Generator[None, None, None]: if self.side is CLIENT: self.send_eof() # If discard() completes normally, execution ends here. + self.set_state(CLOSED) yield # Once the reader reaches EOF, its feed_data/eof() # methods raise an error, so our receive_data/eof() diff --git a/tests/test_connection.py b/tests/test_connection.py index e420c7853..fa559bc43 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -682,35 +682,35 @@ class CloseTests(ConnectionTestCase): def test_close_code(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") - client.set_state(State.CLOSED) + client.receive_eof() self.assertEqual(client.close_code, 1000) def test_close_reason(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") - server.set_state(State.CLOSED) + server.receive_eof() self.assertEqual(server.close_reason, "OK") def test_close_code_not_provided(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x00\x00\x00\x00") - server.set_state(State.CLOSED) + server.receive_eof() self.assertEqual(server.close_code, 1005) def test_close_reason_not_provided(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x00") - client.set_state(State.CLOSED) + client.receive_eof() self.assertEqual(client.close_reason, "") def test_close_code_not_available(self): client = Connection(Side.CLIENT) - client.set_state(State.CLOSED) + client.receive_eof() self.assertEqual(client.close_code, 1006) def test_close_reason_not_available(self): server = Connection(Side.SERVER) - server.set_state(State.CLOSED) + server.receive_eof() self.assertEqual(server.close_reason, "") def test_close_code_not_available_yet(self): @@ -1385,25 +1385,29 @@ def test_client_receives_eof(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) - client.receive_eof() # does not raise an exception + client.receive_eof() + self.assertIs(client.state, State.CLOSED) def test_server_receives_eof(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) - server.receive_eof() # does not raise an exception + server.receive_eof() + self.assertIs(server.state, State.CLOSED) def test_client_receives_eof_between_frames(self): client = Connection(Side.CLIENT) client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual(str(client.parser_exc), "unexpected end of stream") + self.assertIs(client.state, State.CLOSED) def test_server_receives_eof_between_frames(self): server = Connection(Side.SERVER) server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual(str(server.parser_exc), "unexpected end of stream") + self.assertIs(server.state, State.CLOSED) def test_client_receives_eof_inside_frame(self): client = Connection(Side.CLIENT) @@ -1414,6 +1418,7 @@ def test_client_receives_eof_inside_frame(self): str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) + self.assertIs(client.state, State.CLOSED) def test_server_receives_eof_inside_frame(self): server = Connection(Side.SERVER) @@ -1424,6 +1429,7 @@ def test_server_receives_eof_inside_frame(self): str(server.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) + self.assertIs(server.state, State.CLOSED) def test_client_receives_data_after_exception(self): client = Connection(Side.CLIENT) From 48527e4c1c07a9d2c68ac16f91dd2f0c267c18b5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 18:43:56 +0200 Subject: [PATCH 0889/1539] Replace set_state by a property setter. --- src/websockets/client.py | 2 +- src/websockets/connection.py | 35 +++++++++++++++++++++-------------- src/websockets/server.py | 2 +- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index cc807e97c..6a0bb580e 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -307,7 +307,7 @@ def parse(self) -> Generator[None, None, None]: response.exception = exc else: assert self.state is CONNECTING - self.set_state(OPEN) + self.state = OPEN finally: self.events.append(response) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 2e6871c00..8f407f5c7 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -83,8 +83,8 @@ def __init__( # Connection side. CLIENT or SERVER. self.side = side - # Connnection state. CONNECTING and CLOSED states are handled in subclasses. - self.set_state(state) + # Connnection state. Initially OPEN because subclasses handle CONNECTING. + self.state = state # Maximum size of incoming messages in bytes. self.max_size = max_size @@ -120,6 +120,20 @@ def __init__( # Public attributes + @property + def state(self) -> State: + """ + WebSocket connection state. + + """ + return self._state + + @state.setter + def state(self, state: State) -> None: + if self.debug: + self.logger.debug("= connection is %s", state.name) + self._state = state + @property def close_code(self) -> Optional[int]: """ @@ -150,13 +164,6 @@ def close_reason(self) -> Optional[str]: else: return self.close_rcvd.reason - # Private attributes - - def set_state(self, state: State) -> None: - if self.debug: - self.logger.debug("= connection is %s", state.name) - self.state = state - # Public methods for receiving data. def receive_data(self, data: bytes) -> None: @@ -241,7 +248,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: # 7.1.3. The WebSocket Closing Handshake is Started self.send_frame(Frame(OP_CLOSE, data)) self.close_sent = close - self.set_state(CLOSING) + self.state = CLOSING def send_ping(self, data: bytes) -> None: """ @@ -273,7 +280,7 @@ def fail(self, code: int, reason: str = "") -> None: data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) self.close_sent = close - self.set_state(CLOSING) + self.state = CLOSING if self.side is SERVER and not self.eof_sent: self.send_eof() @@ -340,7 +347,7 @@ def parse(self) -> Generator[None, None, None]: if self.close_rcvd_then_sent is not None: if self.side is CLIENT: self.send_eof() - self.set_state(CLOSED) + self.state = CLOSED # If parse() completes normally, execution ends here. yield # Once the reader reaches EOF, its feed_data/eof() @@ -403,7 +410,7 @@ def discard(self) -> Generator[None, None, None]: if self.side is CLIENT: self.send_eof() # If discard() completes normally, execution ends here. - self.set_state(CLOSED) + self.state = CLOSED yield # Once the reader reaches EOF, its feed_data/eof() # methods raise an error, so our receive_data/eof() @@ -475,7 +482,7 @@ def recv_frame(self, frame: Frame) -> None: self.send_frame(Frame(OP_CLOSE, frame.data)) self.close_sent = self.close_rcvd self.close_rcvd_then_sent = True - self.set_state(CLOSING) + self.state = CLOSING # 7.1.2. Start the WebSocket Closing Handshake: "Once an # endpoint has both sent and received a Close control frame, diff --git a/src/websockets/server.py b/src/websockets/server.py index 545b38c98..d6d2143bb 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -432,7 +432,7 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: assert self.state is CONNECTING - self.set_state(OPEN) + self.state = OPEN else: self.send_eof() From b96d82202cf6d7fbe2196c07c23679cf2318055b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 21:26:47 +0200 Subject: [PATCH 0890/1539] Add API for getting connection closed exception. --- src/websockets/connection.py | 49 +++++++++++++++-- tests/test_connection.py | 103 ++++++++++++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 7 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 8f407f5c7..4d8851787 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -3,11 +3,19 @@ import enum import logging import uuid -from typing import Generator, List, Optional, Union - -from .exceptions import InvalidState, PayloadTooBig, ProtocolError +from typing import Generator, List, Optional, Type, Union + +from .exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) from .extensions import Extension from .frames import ( + OK_CLOSE_CODES, OP_BINARY, OP_CLOSE, OP_CONT, @@ -123,7 +131,7 @@ def __init__( @property def state(self) -> State: """ - WebSocket connection state. + Connection State defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. """ return self._state @@ -137,7 +145,7 @@ def state(self, state: State) -> None: @property def close_code(self) -> Optional[int]: """ - WebSocket close code received in a close frame. + Connection Close Code defined in 7.1.5 of :rfc:`6455`. Available once the connection is closed. @@ -152,7 +160,7 @@ def close_code(self) -> Optional[int]: @property def close_reason(self) -> Optional[str]: """ - WebSocket close reason received in a close frame. + Connection Close Reason defined in 7.1.6 of :rfc:`6455`. Available once the connection is closed. @@ -164,6 +172,35 @@ def close_reason(self) -> Optional[str]: else: return self.close_rcvd.reason + @property + def connection_closed_exc(self) -> ConnectionClosed: + """ + Exception raised when trying to interact with a closed connection. + + Available once the connection is closed. If you need to raise this + exception while the connection is closing, wait until it's closed. + + """ + assert self.state is CLOSED + exc_type: Type[ConnectionClosed] + if ( + self.close_rcvd is not None + and self.close_sent is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent.code in OK_CLOSE_CODES + ): + exc_type = ConnectionClosedOK + else: + exc_type = ConnectionClosedError + exc: ConnectionClosed = exc_type( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + # Chain to the exception raised in the parser, if any. + exc.__cause__ = self.parser_exc + return exc + # Public methods for receiving data. def receive_data(self, data: bytes) -> None: diff --git a/tests/test_connection.py b/tests/test_connection.py index fa559bc43..10eb58eeb 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,7 +1,13 @@ import unittest.mock from websockets.connection import * -from websockets.exceptions import InvalidState, PayloadTooBig, ProtocolError +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) from websockets.frames import ( OP_BINARY, OP_CLOSE, @@ -1569,6 +1575,101 @@ def test_server_fails_connection(self): self.assertTrue(server.close_expected()) +class ConnectionClosedTests(ConnectionTestCase): + """ + Test connection closed exception. + + """ + + def test_client_sends_close_then_receives_close(self): + # Client-initiated close handshake on the client side complete. + client = Connection(Side.CLIENT) + client.send_close(1000, "") + client.receive_data(b"\x88\x02\x03\xe8") + client.receive_eof() + exc = client.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertFalse(exc.rcvd_then_sent) + + def test_server_sends_close_then_receives_close(self): + # Server-initiated close handshake on the server side complete. + server = Connection(Side.SERVER) + server.send_close(1000, "") + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") + server.receive_eof() + exc = server.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertFalse(exc.rcvd_then_sent) + + def test_client_receives_close_then_sends_close(self): + # Server-initiated close handshake on the client side complete. + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + client.receive_eof() + exc = client.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertTrue(exc.rcvd_then_sent) + + def test_server_receives_close_then_sends_close(self): + # Client-initiated close handshake on the server side complete. + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") + server.receive_eof() + exc = server.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertTrue(exc.rcvd_then_sent) + + def test_client_sends_close_then_receives_eof(self): + # Client-initiated close handshake on the client side times out. + client = Connection(Side.CLIENT) + client.send_close(1000, "") + client.receive_eof() + exc = client.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertIsNone(exc.rcvd_then_sent) + + def test_server_sends_close_then_receives_eof(self): + # Server-initiated close handshake on the server side times out. + server = Connection(Side.SERVER) + server.send_close(1000, "") + server.receive_eof() + exc = server.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertIsNone(exc.rcvd_then_sent) + + def test_client_receives_eof(self): + # Server-initiated close handshake on the client side times out. + client = Connection(Side.CLIENT) + client.receive_eof() + exc = client.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertIsNone(exc.sent) + self.assertIsNone(exc.rcvd_then_sent) + + def test_server_receives_eof(self): + # Client-initiated close handshake on the server side times out. + server = Connection(Side.SERVER) + server.receive_eof() + exc = server.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertIsNone(exc.sent) + self.assertIsNone(exc.rcvd_then_sent) + + class ErrorTests(ConnectionTestCase): """ Test other error cases. From a9de47d9247335d4a4b2df551b1732049597ec58 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Jun 2021 08:29:04 +0200 Subject: [PATCH 0891/1539] Rename API with a shorter name. --- src/websockets/connection.py | 2 +- tests/test_connection.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4d8851787..4a4b5ddd0 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -173,7 +173,7 @@ def close_reason(self) -> Optional[str]: return self.close_rcvd.reason @property - def connection_closed_exc(self) -> ConnectionClosed: + def close_exc(self) -> ConnectionClosed: """ Exception raised when trying to interact with a closed connection. diff --git a/tests/test_connection.py b/tests/test_connection.py index 10eb58eeb..b82428eab 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1587,7 +1587,7 @@ def test_client_sends_close_then_receives_close(self): client.send_close(1000, "") client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() - exc = client.connection_closed_exc + exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(1000, "")) self.assertEqual(exc.sent, Close(1000, "")) @@ -1599,7 +1599,7 @@ def test_server_sends_close_then_receives_close(self): server.send_close(1000, "") server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() - exc = server.connection_closed_exc + exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(1000, "")) self.assertEqual(exc.sent, Close(1000, "")) @@ -1610,7 +1610,7 @@ def test_client_receives_close_then_sends_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() - exc = client.connection_closed_exc + exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(1000, "")) self.assertEqual(exc.sent, Close(1000, "")) @@ -1621,7 +1621,7 @@ def test_server_receives_close_then_sends_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() - exc = server.connection_closed_exc + exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(1000, "")) self.assertEqual(exc.sent, Close(1000, "")) @@ -1632,7 +1632,7 @@ def test_client_sends_close_then_receives_eof(self): client = Connection(Side.CLIENT) client.send_close(1000, "") client.receive_eof() - exc = client.connection_closed_exc + exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertEqual(exc.sent, Close(1000, "")) @@ -1643,7 +1643,7 @@ def test_server_sends_close_then_receives_eof(self): server = Connection(Side.SERVER) server.send_close(1000, "") server.receive_eof() - exc = server.connection_closed_exc + exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertEqual(exc.sent, Close(1000, "")) @@ -1653,7 +1653,7 @@ def test_client_receives_eof(self): # Server-initiated close handshake on the client side times out. client = Connection(Side.CLIENT) client.receive_eof() - exc = client.connection_closed_exc + exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertIsNone(exc.sent) @@ -1663,7 +1663,7 @@ def test_server_receives_eof(self): # Client-initiated close handshake on the server side times out. server = Connection(Side.SERVER) server.receive_eof() - exc = server.connection_closed_exc + exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertIsNone(exc.sent) From 5d8e0c5dbefa1c487c14736943f858e3b0dc9921 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Jun 2021 08:44:11 +0200 Subject: [PATCH 0892/1539] Stop processing data after receiving close frame. --- src/websockets/connection.py | 46 +++++++++++++----------------------- tests/test_connection.py | 46 ++++++++++++++---------------------- 2 files changed, 34 insertions(+), 58 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4a4b5ddd0..87a4e01b1 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -381,20 +381,11 @@ def parse(self) -> Generator[None, None, None]: if (yield from self.reader.at_eof()): if self.debug: self.logger.debug("< EOF") - if self.close_rcvd_then_sent is not None: - if self.side is CLIENT: - self.send_eof() - self.state = CLOSED - # If parse() completes normally, execution ends here. - yield - # Once the reader reaches EOF, its feed_data/eof() - # methods raise an error, so our receive_data/eof() - # methods don't step parse(). - raise AssertionError( - "parse() shouldn't step after EOF" - ) # pragma: no cover - else: - raise EOFError("unexpected end of stream") + # If the WebSocket connection is closed cleanly, with a + # closing handhshake, recv_frame() substitutes parse() + # with discard(). This branch is reached only when the + # connection isn't closed cleanly. + raise EOFError("unexpected end of stream") if self.max_size is None: max_size = None @@ -444,6 +435,8 @@ def parse(self) -> Generator[None, None, None]: def discard(self) -> Generator[None, None, None]: while not (yield from self.reader.at_eof()): self.reader.discard() + if self.debug: + self.logger.debug("< EOF") if self.side is CLIENT: self.send_eof() # If discard() completes normally, execution ends here. @@ -456,11 +449,6 @@ def discard(self) -> Generator[None, None, None]: def recv_frame(self, frame: Frame) -> None: if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: - # 5.5.1 Close: "The application MUST NOT send any more data - # frames after sending a Close frame." - if self.close_rcvd is not None: - raise ProtocolError("data frame after close frame") - if self.cur_size is not None: raise ProtocolError("expected a continuation frame") if frame.fin: @@ -469,11 +457,6 @@ def recv_frame(self, frame: Frame) -> None: self.cur_size = len(frame.data) elif frame.opcode is OP_CONT: - # 5.5.1 Close: "The application MUST NOT send any more data - # frames after sending a Close frame." - if self.close_rcvd is not None: - raise ProtocolError("data frame after close frame") - if self.cur_size is None: raise ProtocolError("unexpected continuation frame") if frame.fin: @@ -483,11 +466,9 @@ def recv_frame(self, frame: Frame) -> None: elif frame.opcode is OP_PING: # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST - # send a Pong frame in response, unless it already received a - # Close frame." - if self.close_rcvd is None: - pong_frame = Frame(OP_PONG, frame.data) - self.send_frame(pong_frame) + # send a Pong frame in response" + pong_frame = Frame(OP_PONG, frame.data) + self.send_frame(pong_frame) elif frame.opcode is OP_PONG: # 5.5.3 Pong: "A response to an unsolicited Pong frame is not @@ -528,6 +509,12 @@ def recv_frame(self, frame: Frame) -> None: if self.side is SERVER: self.send_eof() + # 1.4. Closing Handshake: "after receiving a control frame + # indicating the connection should be closed, a peer discards + # any further data received." + self.parser = self.discard() + next(self.parser) # start coroutine + else: # pragma: no cover # This can't happen because Frame.parse() validates opcodes. raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") @@ -537,7 +524,6 @@ def recv_frame(self, frame: Frame) -> None: # Private methods for sending events. def send_frame(self, frame: Frame) -> None: - # Defensive assertion for protocol compliance. if self.state is not OPEN: raise InvalidState( f"cannot write to a WebSocket in the {self.state.name} state" diff --git a/tests/test_connection.py b/tests/test_connection.py index b82428eab..ac0df46ce 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -205,16 +205,16 @@ def test_client_receives_continuation_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x00\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "data frame after close frame") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) def test_server_receives_continuation_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x00\x80\x00\xff\x00\xff") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "data frame after close frame") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) class TextTests(ConnectionTestCase): @@ -459,16 +459,16 @@ def test_client_receives_text_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x81\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "data frame after close frame") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) def test_server_receives_text_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x81\x80\x00\xff\x00\xff") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "data frame after close frame") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) class BinaryTests(ConnectionTestCase): @@ -661,16 +661,16 @@ def test_client_receives_binary_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x82\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "data frame after close frame") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) def test_server_receives_binary_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x82\x80\x00\xff\x00\xff") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "data frame after close frame") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) class CloseTests(ConnectionTestCase): @@ -1056,10 +1056,7 @@ def test_client_receives_ping_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") - self.assertFrameReceived( - client, - Frame(OP_PING, b"\x22\x66\xaa\xee"), - ) + self.assertFrameReceived(client, None) self.assertFrameSent(client, None) def test_server_receives_ping_after_receiving_close(self): @@ -1067,10 +1064,7 @@ def test_server_receives_ping_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived( - server, - Frame(OP_PING, b"\x22\x66\xaa\xee"), - ) + self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -1186,20 +1180,16 @@ def test_client_receives_pong_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") - self.assertFrameReceived( - client, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) def test_server_receives_pong_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived( - server, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) class FailTests(ConnectionTestCase): From d5cf1efb737a943583b1e5b4ceca5376bdc3995f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Jun 2021 08:22:32 +0200 Subject: [PATCH 0893/1539] Improve comments on EOF handling. --- src/websockets/connection.py | 68 ++++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 14 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 87a4e01b1..857698f4f 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -319,12 +319,17 @@ def fail(self, code: int, reason: str = "") -> None: self.close_sent = close self.state = CLOSING + # When failing the connection, a server closes the TCP connection + # without waiting for the client to complete the handshake, while a + # client waits for the server to close the TCP connection, possibly + # after sending a close frame that the client will ignore. if self.side is SERVER and not self.eof_sent: self.send_eof() - # "An endpoint MUST NOT continue to attempt to process data - # (including a responding Close frame) from the remote endpoint - # after being instructed to _Fail the WebSocket Connection_." + # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue + # to attempt to process data(including a responding Close frame) from + # the remote endpoint after being instructed to _Fail the WebSocket + # Connection_." self.parser = self.discard() next(self.parser) # start coroutine @@ -362,20 +367,30 @@ def close_expected(self) -> bool: Tell whether the TCP connection is expected to close soon. Call this method immediately after calling any of the ``receive_*()`` - methods and, if it returns ``True``, schedule closing the TCP - connection after a short timeout. + or ``fail_*()`` methods and, if it returns ``True``, schedule closing + the TCP connection after a short timeout. """ # We already got a TCP Close if and only if the state is CLOSED. # We expect a TCP close if and only if we sent a close frame: - # * Normal closure: once we send a close frame, we expect a TCP close. - # * Abnormal closure: we always send a close frame except on EOFError, - # but that's fine because we already got the TCP close. + # * Normal closure: once we send a close frame, we expect a TCP close: + # server waits for client to complete the TCP closing handshake; + # client waits for server to initiate the TCP closing handshake. + # * Abnormal closure: we always send a close frame and the same logic + # applies, except on EOFError where we don't send a close frame + # because we already received the TCP close, so we don't expect it. return self.state is not CLOSED and self.close_sent is not None # Private methods for receiving data. def parse(self) -> Generator[None, None, None]: + """ + Parse incoming data into frames. + + :meth:`receive_data` and :meth:`receive_eof` run this generator + coroutine until it needs more data or reaches EOF. + + """ try: while True: if (yield from self.reader.at_eof()): @@ -394,6 +409,9 @@ def parse(self) -> Generator[None, None, None]: else: max_size = self.max_size - self.cur_size + # During a normal closure, execution ends here on the next + # iteration of the loop after receiving a close frame. At + # this point, recv_frame() replaced parse() by discard(). frame = yield from Frame.parse( self.reader.read_exact, mask=self.side is SERVER, @@ -428,26 +446,46 @@ def parse(self) -> Generator[None, None, None]: self.fail(1011) self.parser_exc = exc + # During an abnormal closure, execution ends here after catching an + # exception. At this point, fail() replaced parse() by discard(). yield - # If an error occurs, parse() is replaced by discard(). - raise AssertionError("parse() shouldn't step after EOF") # pragma: no cover + raise AssertionError("parse() shouldn't step after error") # pragma: no cover def discard(self) -> Generator[None, None, None]: + """ + Discard incoming data. + + This coroutine replaces :meth:`parse`: + + - after receiving a close frame, during a normal closure (1.4); + - after sending a close frame, during an abnormal closure (7.1.7). + + """ + # The server close the TCP connection in the same circumstances where + # discard() replaces parse(). The client closes the connection later, + # after the server closes the connection or a timeout elapses. + # (The latter case cannot be handled in this Sans-I/O layer.) + assert (self.side is SERVER) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. if self.side is CLIENT: self.send_eof() - # If discard() completes normally, execution ends here. self.state = CLOSED + # If discard() completes normally, execution ends here. yield - # Once the reader reaches EOF, its feed_data/eof() - # methods raise an error, so our receive_data/eof() - # methods don't step the generator. + # Once the reader reaches EOF, its feed_data/eof() methods raise an + # error, so our receive_data/eof() methods don't step the generator. raise AssertionError("discard() shouldn't step after EOF") # pragma: no cover def recv_frame(self, frame: Frame) -> None: + """ + Process an incoming frame. + + """ if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: if self.cur_size is not None: raise ProtocolError("expected a continuation frame") @@ -506,6 +544,8 @@ def recv_frame(self, frame: Frame) -> None: # endpoint has both sent and received a Close control frame, # that endpoint SHOULD _Close the WebSocket Connection_" + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. if self.side is SERVER: self.send_eof() From 4ed76b2cc43b9480428121cc9f94ee36af4ea27f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 09:43:37 +0200 Subject: [PATCH 0894/1539] Document all types visible in public APIs. --- docs/reference/utilities.rst | 20 +++++++++++++++-- src/websockets/__init__.py | 2 +- src/websockets/datastructures.py | 3 +-- src/websockets/typing.py | 38 ++++++++++++++++++-------------- 4 files changed, 42 insertions(+), 21 deletions(-) diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index f1d89eddc..e7f489fbd 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -10,7 +10,12 @@ Data structures --------------- .. automodule:: websockets.datastructures - :members: + + .. autoclass:: Headers + + .. autodata:: HeadersLike + + .. autoexception:: MultipleValuesError Exceptions ---------- @@ -22,4 +27,15 @@ Types ----- .. automodule:: websockets.typing - :members: + + .. autodata:: Data + + .. autodata:: LoggerLike + + .. autodata:: Origin + + .. autodata:: Subprotocol + + .. autodata:: ExtensionName + + .. autodata:: ExtensionParameter diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 6378b82cf..5883c3d65 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -14,7 +14,7 @@ "ConnectionClosedOK", "Data", "DuplicateParameter", - "ExtensionHeader", + "ExtensionName", "ExtensionParameter", "InvalidHandshake", "InvalidHeader", diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 117ffd4f2..3e3f9705a 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -159,5 +159,4 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] -HeadersLike__doc__ = """Types accepted wherever :class:`Headers` is expected""" -HeadersLike.__doc__ = HeadersLike__doc__ +HeadersLike.__doc__ = """Types accepted where :class:`Headers` is expected""" diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 13b172f15..1bd118071 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -8,46 +8,52 @@ "Data", "LoggerLike", "Origin", - "ExtensionHeader", - "ExtensionParameter", "Subprotocol", + "ExtensionName", + "ExtensionParameter", ] + +# Public types used in the signature of public APIs + Data = Union[str, bytes] Data.__doc__ = """ Types supported in a WebSocket message: -- :class:`str` for text messages -- :class:`bytes` for binary messages - +- :class:`str` for text messages; +- :class:`bytes` for binary messages. """ + LoggerLike = Union[logging.Logger, logging.LoggerAdapter] -LoggerLike.__doc__ = """"Types accepted where :class:`~logging.Logger` is expected""" +LoggerLike.__doc__ = """Types accepted where :class:`~logging.Logger` is expected.""" + Origin = NewType("Origin", str) -Origin.__doc__ = """Value of a Origin header""" +Origin.__doc__ = """Value of a Origin header.""" + + +Subprotocol = NewType("Subprotocol", str) +Subprotocol.__doc__ = """Subprotocol in a Sec-WebSocket-Protocol header.""" ExtensionName = NewType("ExtensionName", str) -ExtensionName.__doc__ = """Name of a WebSocket extension""" +ExtensionName.__doc__ = """Name of a WebSocket extension.""" ExtensionParameter = Tuple[str, Optional[str]] -ExtensionParameter.__doc__ = """Parameter of a WebSocket extension""" +ExtensionParameter.__doc__ = """Parameter of a WebSocket extension.""" -ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] -ExtensionHeader.__doc__ = """Extension in a Sec-WebSocket-Extensions header""" +# Private types - -Subprotocol = NewType("Subprotocol", str) -Subprotocol.__doc__ = """Subprotocol value in a Sec-WebSocket-Protocol header""" +ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] +ExtensionHeader.__doc__ = """Extension in a Sec-WebSocket-Extensions header.""" ConnectionOption = NewType("ConnectionOption", str) -ConnectionOption.__doc__ = """Connection option in a Connection header""" +ConnectionOption.__doc__ = """Connection option in a Connection header.""" UpgradeProtocol = NewType("UpgradeProtocol", str) -UpgradeProtocol.__doc__ = """Upgrade protocol in an Upgrade header""" +UpgradeProtocol.__doc__ = """Upgrade protocol in an Upgrade header.""" From 6aae7c69e30c74bbdc9bd81c52f7a5f90b4c8037 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 18:21:36 +0200 Subject: [PATCH 0895/1539] Explain ping timeout errors in the FAQ. Fix #1012. --- docs/howto/faq.rst | 54 +++++++++++++++++++++++++++++++++++++--- docs/topics/timeouts.rst | 3 ++- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index a4ad84680..7cec6d43e 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -251,10 +251,10 @@ If you're seeing this traceback in the logs of a server: .. code-block:: pytb - Error in connection handler + connection handler failed Traceback (most recent call last): ... - asyncio.streams.IncompleteReadError: 0 bytes read on a total of 2 expected bytes + asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes The above exception was the direct cause of the following exception: @@ -290,11 +290,59 @@ There are several reasons why long-lived connections may be lost: a wired network, enter airplane mode, be put to sleep, etc. * HTTP load balancers or proxies that aren't configured for long-lived connections may terminate connections after a short amount of time, usually - 30 seconds. + 30 seconds, despite websockets' keepalive mechanism. If you're facing a reproducible issue, :ref:`enable debug logs ` to see when and how connections are closed. +What does ``ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received`` mean? +....................................................................................................................... + +If you're seeing this traceback in the logs of a server: + +.. code-block:: pytb + + connection handler failed + Traceback (most recent call last): + ... + asyncio.exceptions.CancelledError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + +or if a client crashes with this traceback: + +.. code-block:: pytb + + Traceback (most recent call last): + ... + asyncio.exceptions.CancelledError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + +it means that the WebSocket connection suffered from excessive latency and was +closed after reaching the timeout of websockets' keepalive mechanism. + +You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it +from being logged. + +There are two main reasons why latency may increase: + +* Poor network connectivity. +* More traffic than the recipient can handle. + +See the discussion of :doc:`timeouts <../topics/timeouts>` for details. + +If websockets' default timeout of 20 seconds is too short for your use case, +you can adjust it with the ``ping_timeout`` argument. + How do I set a timeout on ``recv()``? ..................................... diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 828d22a0b..8febfce9f 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -20,7 +20,8 @@ By default, websockets waits 20 seconds, then sends a Ping frame, and expects to receive the corresponding Pong frame within 20 seconds. Else, it considers the connection broken and closes it. -Timings are configurable with ``ping_interval`` and ``ping_timeout``. +Timings are configurable with the ``ping_interval`` and ``ping_timeout`` +arguments of :func:`~websockets.connect` and :func:`~websockets.serve`. While WebSocket runs on top of TCP, websockets doesn't rely on TCP keepalive because it's disabled by default and, if enabled, the default interval is no From 4e1dac362a3639b9cf0e5bcf382601e6e32cfede Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 18:22:26 +0200 Subject: [PATCH 0896/1539] Show DEBUG log level in example. --- docs/howto/cheatsheet.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst index 86684c44c..edfb00baa 100644 --- a/docs/howto/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -69,7 +69,7 @@ If you don't understand what websockets is doing, enable logging:: import logging logger = logging.getLogger('websockets') - logger.setLevel(logging.INFO) + logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) The logs contain: From 21af14cdd61cb31a4a4c003d73f9eb489a7a22f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 21:37:01 +0200 Subject: [PATCH 0897/1539] Fix display of license in docs. --- docs/project/license.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/project/license.rst b/docs/project/license.rst index 426110020..0a3b8703d 100644 --- a/docs/project/license.rst +++ b/docs/project/license.rst @@ -1,4 +1,4 @@ License ======= -.. literalinclude:: ../../LICENSE +.. include:: ../../LICENSE From bf02ad01e91a05777d607f926673f8d8b0efc2e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 21:42:10 +0200 Subject: [PATCH 0898/1539] Memory use => usage. --- README.rst | 2 +- docs/topics/memory.rst | 2 +- docs/topics/security.rst | 12 ++++++------ src/websockets/legacy/protocol.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index 70dec88db..6b9b67672 100644 --- a/README.rst +++ b/README.rst @@ -104,7 +104,7 @@ The development of ``websockets`` is shaped by four principles: under 100% branch coverage. Also it passes the industry-standard `Autobahn Testsuite`_. -4. **Performance**: memory use is configurable. An extension written in C +4. **Performance**: memory usage is configurable. An extension written in C accelerates expensive operations. It's pre-compiled for Linux, macOS and Windows and packaged in the wheel format for each system and Python version. diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index e5b9ed9a2..c880d5579 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -26,7 +26,7 @@ Buffers Under normal circumstances, buffers are almost always empty. Under high load, if a server receives more messages than it can process, -bufferbloat can result in excessive memory use. +bufferbloat can result in excessive memory usage. By default websockets has generous limits. It is strongly recommended to adapt them to your application. When you call :func:`~legacy.server.serve`: diff --git a/docs/topics/security.rst b/docs/topics/security.rst index 6c541db06..d3dec21bd 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -9,8 +9,8 @@ For production use, a server should require encrypted connections. See this example of :ref:`encrypting connections with TLS `. -Memory use ----------- +Memory usage +------------ .. warning:: @@ -21,11 +21,11 @@ Memory use With the default settings, opening a connection uses 70 KiB of memory. -Sending some highly compressed messages could use up to 128 MiB of memory -with an amplification factor of 1000 between network traffic and memory use. +Sending some highly compressed messages could use up to 128 MiB of memory with +an amplification factor of 1000 between network traffic and memory usage. -Configuring a server to :doc:`memory` will improve security in addition to -improving performance. +Configuring a server to :doc:`optimize memory usage ` will improve +security in addition to improving performance. Other limits ------------ diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index df74bdb38..b0d2ed832 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1432,7 +1432,7 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N If you broadcast messages faster than a connection can handle them, messages will pile up in its write buffer until the connection times out. Keep low values for ``ping_interval`` and ``ping_timeout`` to prevent - excessive memory use by slow connections when you use :func:`broadcast`. + excessive memory usage by slow connections when you use :func:`broadcast`. Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, :func:`broadcast` doesn't support sending fragmented messages. Indeed, From 59eb1b53a9b3b868a05e0eecc8255ee3b11b8bcc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 21:45:52 +0200 Subject: [PATCH 0899/1539] Fix typo. --- docs/topics/broadcast.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index c43a8fa36..f9cd9e281 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -12,7 +12,7 @@ Broadcasting messages If you want to learn about its design in depth, continue reading this document. -WebSocket servers often send the same message to all connected clients to a +WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. Let's explore options for broadcasting a message, explain the design From 0867d9f0f9fa0dfc267e6e0b91cb1aaa5858ccaf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 21:51:41 +0200 Subject: [PATCH 0900/1539] Clarify error message. --- src/websockets/frames.py | 4 ++-- src/websockets/legacy/protocol.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 839e2f7b7..a0ce1d350 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -340,7 +340,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: elif isinstance(data, BytesLike): return OP_BINARY, data else: - raise TypeError("data must be bytes-like or str") + raise TypeError("data must be str or bytes-like") def prepare_ctrl(data: Data) -> bytes: @@ -362,7 +362,7 @@ def prepare_ctrl(data: Data) -> bytes: elif isinstance(data, BytesLike): return bytes(data) else: - raise TypeError("data must be bytes-like or str") + raise TypeError("data must be str or bytes-like") @dataclasses.dataclass diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b0d2ed832..339733db2 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -507,7 +507,7 @@ async def send( :raises ~websockets.exceptions.ConnectionClosed: when the connection is closed - :raises TypeError: for unsupported inputs + :raises TypeError: if ``message`` doesn't have a supported type """ await self.ensure_open() @@ -613,7 +613,7 @@ async def send( self._fragmented_message_waiter = None else: - raise TypeError("data must be bytes, str, or iterable") + raise TypeError("data must be str, bytes-like, or iterable") async def close(self, code: int = 1000, reason: str = "") -> None: """ @@ -1441,11 +1441,11 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N as fast as possible. :raises RuntimeError: if a connection is busy sending a fragmented message - :raises TypeError: for unsupported inputs + :raises TypeError: if ``message`` doesn't have a supported type """ if not isinstance(message, (str, bytes, bytearray, memoryview)): - raise TypeError("data must be bytes, str, or iterable") + raise TypeError("data must be str or bytes-like") opcode, data = prepare_data(message) From 786134542d9f0d6ad3959198c5e543ce55fff4aa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Jul 2021 09:35:19 +0200 Subject: [PATCH 0901/1539] Document that time.sleep doesn't work in asyncio. This belongs in asyncio's docs, not in websockets', but I can't find it (easily) in Python docs and it keeps tripping people up. Fix #1025, #1024, #923, #901, #703, etc. --- docs/howto/faq.rst | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 7cec6d43e..394920b58 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -214,7 +214,7 @@ the connection terminates. Why does my program never receives any messages? ................................................ -Your program runs a coroutine that never yield control to the event loop. The +Your program runs a coroutine that never yields control to the event loop. The coroutine that receives messages never gets a chance to run. Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough @@ -240,6 +240,15 @@ See `issue 867`_. .. _issue 867: https://github.com/aaugustin/websockets/issues/867 +Why does my very simple program misbehave mysteriously? +....................................................... + +You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which +blocks the event loop and prevents asyncio from operating normally. + +This may lead to messages getting send but not received, to connection +timeouts, and to unexpected results of shotgun debugging e.g. adding an +unnecessary call to ``send()`` makes the program functional. Both sides ---------- From b343fc60824737d4faa421aafe7983a8d3d0c9df Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 6 Aug 2021 22:42:41 +0200 Subject: [PATCH 0902/1539] Refactor tests for redirects. --- tests/legacy/test_client_server.py | 116 ++++++++--------------------- 1 file changed, 29 insertions(+), 87 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 755fcefdd..6ce100fa2 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -75,6 +75,20 @@ async def default_handler(ws, path): await ws.send((await ws.recv())) +async def redirect_request(path, headers, test, status): + if path == "/redirect": + location = get_server_uri(test.server, test.secure, "/") + elif path == "/infinite": + location = get_server_uri(test.server, test.secure, "/infinite") + elif path == "/force_insecure": + location = get_server_uri(test.server, False, "/") + elif path == "/missing_location": + return status, {}, b"" + else: + return None + return status, {"Location": location}, b"" + + @contextlib.contextmanager def temp_test_server(test, **kwargs): test.start_server(**kwargs) @@ -84,15 +98,9 @@ def temp_test_server(test, **kwargs): test.stop_server() -@contextlib.contextmanager -def temp_test_redirecting_server( - test, status, include_location=True, force_insecure=False, **kwargs -): - test.start_redirecting_server(status, include_location, force_insecure, **kwargs) - try: - yield - finally: - test.stop_redirecting_server() +def temp_test_redirecting_server(test, status=http.HTTPStatus.FOUND, **kwargs): + process_request = functools.partial(redirect_request, test=test, status=status) + return temp_test_server(test, process_request=process_request, **kwargs) @contextlib.contextmanager @@ -201,11 +209,6 @@ class ClientServerTestsMixin: def setUp(self): super().setUp() self.server = None - self.redirecting_server = None - - @property - def server_context(self): - return None def start_server(self, deprecation_warnings=None, **kwargs): handler = kwargs.pop("handler", default_handler) @@ -226,42 +229,6 @@ def start_server(self, deprecation_warnings=None, **kwargs): expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) - def start_redirecting_server( - self, - status, - include_location=True, - force_insecure=False, - deprecation_warnings=None, - **kwargs, - ): - async def process_request(path, headers): - server_uri = get_server_uri(self.server, self.secure, path) - if force_insecure: - server_uri = server_uri.replace("wss:", "ws:") - headers = {"Location": server_uri} if include_location else [] - return status, headers, b"" - - with warnings.catch_warnings(record=True) as recorded_warnings: - start_server = serve( - default_handler, - "localhost", - 0, - compression=None, - ping_interval=None, - process_request=process_request, - ssl=self.server_context, - **kwargs, - ) - self.redirecting_server = self.loop.run_until_complete(start_server) - - expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if ( - sys.version_info[:2] >= (3, 10) - and "remove loop argument" not in expected_warnings - ): # pragma: no cover - expected_warnings += ["There is no current event loop"] - self.assertDeprecationWarnings(recorded_warnings, expected_warnings) - def start_client( self, resource_name="/", user_info=None, deprecation_warnings=None, **kwargs ): @@ -274,8 +241,7 @@ def start_client( try: server_uri = kwargs.pop("uri") except KeyError: - server = self.redirecting_server if self.redirecting_server else self.server - server_uri = get_server_uri(server, secure, resource_name, user_info) + server_uri = get_server_uri(self.server, secure, resource_name, user_info) with warnings.catch_warnings(record=True) as recorded_warnings: start_client = connect(server_uri, **kwargs) @@ -306,17 +272,6 @@ def stop_server(self): except asyncio.TimeoutError: # pragma: no cover self.fail("Server failed to stop") - def stop_redirecting_server(self): - self.redirecting_server.close() - try: - self.loop.run_until_complete( - asyncio.wait_for(self.redirecting_server.wait_closed(), timeout=1) - ) - except asyncio.TimeoutError: # pragma: no cover - self.fail("Redirecting server failed to stop") - finally: - self.redirecting_server = None - @contextlib.contextmanager def temp_server(self, **kwargs): with temp_test_server(self, **kwargs): @@ -388,7 +343,6 @@ def test_basic(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - @with_server() def test_redirect(self): redirect_statuses = [ http.HTTPStatus.MOVED_PERMANENTLY, @@ -399,40 +353,31 @@ def test_redirect(self): ] for status in redirect_statuses: with temp_test_redirecting_server(self, status): - with temp_test_client(self): + with self.temp_client("/redirect"): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") def test_infinite_redirect(self): - with temp_test_redirecting_server( - self, - http.HTTPStatus.FOUND, - ): - self.server = self.redirecting_server + with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): - with temp_test_client(self): + with self.temp_client("/infinite"): self.fail("Did not raise") # pragma: no cover - @with_server() def test_redirect_missing_location(self): - with temp_test_redirecting_server( - self, - http.HTTPStatus.FOUND, - include_location=False, - loop=self.loop, - deprecation_warnings=["remove loop argument"], - ): + with temp_test_redirecting_server(self): with self.assertRaises(InvalidHeader): - with temp_test_client(self): + with self.temp_client("/missing_location"): self.fail("Did not raise") # pragma: no cover def test_loop_backwards_compatibility(self): with self.temp_server( - loop=self.loop, deprecation_warnings=["remove loop argument"] + loop=self.loop, + deprecation_warnings=["remove loop argument"], ): with self.temp_client( - loop=self.loop, deprecation_warnings=["remove loop argument"] + loop=self.loop, + deprecation_warnings=["remove loop argument"], ): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) @@ -1274,13 +1219,10 @@ def test_ws_uri_is_rejected(self): uri=get_server_uri(self.server, secure=False), ssl=self.client_context ) - @with_server() def test_redirect_insecure(self): - with temp_test_redirecting_server( - self, http.HTTPStatus.FOUND, force_insecure=True - ): + with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): - with temp_test_client(self): + with self.temp_client("/force_insecure"): self.fail("Did not raise") # pragma: no cover From 1f462000ac9b4f8ad80deb565083ffe430e09acf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 11 Aug 2021 09:53:50 +0200 Subject: [PATCH 0903/1539] Fix handling of relative redirects. Fix #1023. --- docs/project/changelog.rst | 2 ++ src/websockets/legacy/client.py | 7 ++++++- tests/legacy/test_client_server.py | 13 +++++++++++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index cf38c6159..277ee5022 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -74,6 +74,8 @@ They may change at any time. * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. +* Fixed handling of relative redirects in :func:`~legacy.client.connect`. + 9.1 ... diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index c87b5f8d5..5d6d8130a 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -9,6 +9,7 @@ import collections.abc import functools import logging +import urllib.parse import warnings from types import TracebackType from typing import ( @@ -611,12 +612,15 @@ def __init__( # This is a coroutine function. self._create_connection = create_connection + self._uri = uri self._wsuri = wsuri def handle_redirect(self, uri: str) -> None: # Update the state of this instance to connect to a new URI. + old_uri = self._uri old_wsuri = self._wsuri - new_wsuri = parse_uri(uri) + new_uri = urllib.parse.urljoin(old_uri, uri) + new_wsuri = parse_uri(new_uri) # Forbid TLS downgrade. if old_wsuri.secure and not new_wsuri.secure: @@ -645,6 +649,7 @@ def handle_redirect(self, uri: str) -> None: ) # Set the new WebSocket URI. This suffices for same-origin redirects. + self._uri = new_uri self._wsuri = new_wsuri # async for ... in connect(...): diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 6ce100fa2..02d7a7aa3 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -76,8 +76,10 @@ async def default_handler(ws, path): async def redirect_request(path, headers, test, status): - if path == "/redirect": + if path == "/absolute_redirect": location = get_server_uri(test.server, test.secure, "/") + elif path == "/relative_redirect": + location = "/" elif path == "/infinite": location = get_server_uri(test.server, test.secure, "/infinite") elif path == "/force_insecure": @@ -353,11 +355,18 @@ def test_redirect(self): ] for status in redirect_statuses: with temp_test_redirecting_server(self, status): - with self.temp_client("/redirect"): + with self.temp_client("/absolute_redirect"): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") + def test_redirect_relative_location(self): + with temp_test_redirecting_server(self): + with self.temp_client("/relative_redirect"): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + def test_infinite_redirect(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): From 90a20f550ea57f3ae6474d560d623158de97e490 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 07:09:04 +0200 Subject: [PATCH 0904/1539] Simplify exceptions raised by send_close. --- src/websockets/connection.py | 2 +- tests/test_connection.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 857698f4f..52fd9bb81 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -275,7 +275,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: raise ProtocolError("expected a continuation frame") if code is None: if reason != "": - raise ValueError("cannot send a reason without a code") + raise ProtocolError("cannot send a reason without a code") close = Close(1005, "") data = b"" else: diff --git a/tests/test_connection.py b/tests/test_connection.py index ac0df46ce..3d4d98436 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -865,13 +865,13 @@ def test_server_receives_close_with_code_and_reason(self): def test_client_sends_close_with_reason_only(self): client = Connection(Side.CLIENT) - with self.assertRaises(ValueError) as raised: + with self.assertRaises(ProtocolError) as raised: client.send_close(reason="going away") self.assertEqual(str(raised.exception), "cannot send a reason without a code") def test_server_sends_close_with_reason_only(self): server = Connection(Side.SERVER) - with self.assertRaises(ValueError) as raised: + with self.assertRaises(ProtocolError) as raised: server.send_close(reason="OK") self.assertEqual(str(raised.exception), "cannot send a reason without a code") From dc42ecba8d809e14f1856e2e65816c27eaeb8fb6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 08:43:04 +0200 Subject: [PATCH 0905/1539] Make Headers comparison case-insensitive. --- src/websockets/datastructures.py | 2 +- tests/test_datastructures.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 3e3f9705a..42282e209 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -129,7 +129,7 @@ def __delitem__(self, key: str) -> None: def __eq__(self, other: Any) -> bool: if not isinstance(other, Headers): return NotImplemented - return self._list == other._list + return self._dict == other._dict def clear(self) -> None: """ diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 628cbcb02..a361a1b4d 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -97,6 +97,10 @@ def test_eq(self): other_headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) self.assertEqual(self.headers, other_headers) + def test_eq_case_insensitive(self): + other_headers = Headers(connection="Upgrade", server="websockets") + self.assertEqual(self.headers, other_headers) + def test_eq_not_equal(self): other_headers = Headers([("Connection", "close"), ("Server", "websockets")]) self.assertNotEqual(self.headers, other_headers) From 5de7b41a2b2003aaf1db4042ac3d5e2ab4c24cdb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 08:43:46 +0200 Subject: [PATCH 0906/1539] Fix initializing Headers from Headers. It didn't work when a header had multiple values. --- src/websockets/datastructures.py | 16 +++- tests/test_datastructures.py | 149 ++++++++++++++++++++++++++----- 2 files changed, 138 insertions(+), 27 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 42282e209..65c5d4115 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -74,10 +74,10 @@ class Headers(MutableMapping[str, str]): __slots__ = ["_dict", "_list"] - def __init__(self, *args: Any, **kwargs: str) -> None: + # Like dict, Headers accepts an optional "mapping or iterable" argument. + def __init__(self, *args: HeadersLike, **kwargs: str) -> None: self._dict: Dict[str, List[str]] = {} self._list: List[Tuple[str, str]] = [] - # MutableMapping.update calls __setitem__ for each (name, value) pair. self.update(*args, **kwargs) def __str__(self) -> str: @@ -86,7 +86,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"{self.__class__.__name__}({self._list!r})" - def copy(self) -> "Headers": + def copy(self) -> Headers: copy = self.__class__() copy._dict = self._dict.copy() copy._list = self._list.copy() @@ -139,6 +139,16 @@ def clear(self) -> None: self._dict = {} self._list = [] + def update(self, *args: HeadersLike, **kwargs: str) -> None: + """ + Update from a Headers instance and/or keyword arguments. + + """ + args = tuple( + arg.raw_items() if isinstance(arg, Headers) else arg for arg in args + ) + super().update(*args, **kwargs) + # Methods for handling multiple values def get_all(self, key: str) -> List[str]: diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index a361a1b4d..32b79817a 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -3,10 +3,68 @@ from websockets.datastructures import * +class MultipleValuesErrorTests(unittest.TestCase): + def test_multiple_values_error_str(self): + self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'") + self.assertEqual(str(MultipleValuesError()), "") + + class HeadersTests(unittest.TestCase): def setUp(self): self.headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) + def test_init(self): + self.assertEqual( + Headers(), + Headers(), + ) + + def test_init_from_kwargs(self): + self.assertEqual( + Headers(connection="Upgrade", server="websockets"), + self.headers, + ) + + def test_init_from_headers(self): + self.assertEqual( + Headers(self.headers), + self.headers, + ) + + def test_init_from_headers_and_kwargs(self): + self.assertEqual( + Headers(Headers(connection="Upgrade"), server="websockets"), + self.headers, + ) + + def test_init_from_mapping(self): + self.assertEqual( + Headers({"Connection": "Upgrade", "Server": "websockets"}), + self.headers, + ) + + def test_init_from_mapping_and_kwargs(self): + self.assertEqual( + Headers({"Connection": "Upgrade"}, server="websockets"), + self.headers, + ) + + def test_init_from_iterable(self): + self.assertEqual( + Headers([("Connection", "Upgrade"), ("Server", "websockets")]), + self.headers, + ) + + def test_init_from_iterable_and_kwargs(self): + self.assertEqual( + Headers([("Connection", "Upgrade")], server="websockets"), + self.headers, + ) + + def test_init_multiple_positional_arguments(self): + with self.assertRaises(TypeError): + Headers(Headers(connection="Upgrade"), Headers(server="websockets")) + def test_str(self): self.assertEqual( str(self.headers), "Connection: Upgrade\r\nServer: websockets\r\n\r\n" @@ -27,10 +85,6 @@ def test_serialize(self): b"Connection: Upgrade\r\nServer: websockets\r\n\r\n", ) - def test_multiple_values_error_str(self): - self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'") - self.assertEqual(str(MultipleValuesError()), "") - def test_contains(self): self.assertIn("Server", self.headers) @@ -59,11 +113,6 @@ def test_getitem_key_error(self): with self.assertRaises(KeyError): self.headers["Upgrade"] - def test_getitem_multiple_values_error(self): - self.headers["Server"] = "2" - with self.assertRaises(MultipleValuesError): - self.headers["Server"] - def test_setitem(self): self.headers["Upgrade"] = "websocket" self.assertEqual(self.headers["Upgrade"], "websocket") @@ -72,11 +121,6 @@ def test_setitem_case_insensitive(self): self.headers["upgrade"] = "websocket" self.assertEqual(self.headers["Upgrade"], "websocket") - def test_setitem_multiple_values(self): - self.headers["Connection"] = "close" - with self.assertRaises(MultipleValuesError): - self.headers["Connection"] - def test_delitem(self): del self.headers["Connection"] with self.assertRaises(KeyError): @@ -87,12 +131,6 @@ def test_delitem_case_insensitive(self): with self.assertRaises(KeyError): self.headers["Connection"] - def test_delitem_multiple_values(self): - self.headers["Connection"] = "close" - del self.headers["Connection"] - with self.assertRaises(KeyError): - self.headers["Connection"] - def test_eq(self): other_headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) self.assertEqual(self.headers, other_headers) @@ -124,12 +162,75 @@ def test_get_all_case_insensitive(self): def test_get_all_no_values(self): self.assertEqual(self.headers.get_all("Upgrade"), []) - def test_get_all_multiple_values(self): - self.headers["Connection"] = "close" - self.assertEqual(self.headers.get_all("Connection"), ["Upgrade", "close"]) - def test_raw_items(self): self.assertEqual( list(self.headers.raw_items()), [("Connection", "Upgrade"), ("Server", "websockets")], ) + + +class MultiValueHeadersTests(unittest.TestCase): + def setUp(self): + self.headers = Headers([("Server", "Python"), ("Server", "websockets")]) + + def test_init_from_headers(self): + self.assertEqual( + Headers(self.headers), + self.headers, + ) + + def test_init_from_headers_and_kwargs(self): + self.assertEqual( + Headers(Headers(server="Python"), server="websockets"), + self.headers, + ) + + def test_str(self): + self.assertEqual( + str(self.headers), "Server: Python\r\nServer: websockets\r\n\r\n" + ) + + def test_repr(self): + self.assertEqual( + repr(self.headers), + "Headers([('Server', 'Python'), ('Server', 'websockets')])", + ) + + def test_copy(self): + self.assertEqual(repr(self.headers.copy()), repr(self.headers)) + + def test_serialize(self): + self.assertEqual( + self.headers.serialize(), + b"Server: Python\r\nServer: websockets\r\n\r\n", + ) + + def test_iter(self): + self.assertEqual(set(iter(self.headers)), {"server"}) + + def test_len(self): + self.assertEqual(len(self.headers), 1) + + def test_getitem_multiple_values_error(self): + with self.assertRaises(MultipleValuesError): + self.headers["Server"] + + def test_setitem(self): + self.headers["Server"] = "redux" + self.assertEqual( + self.headers.get_all("Server"), ["Python", "websockets", "redux"] + ) + + def test_delitem(self): + del self.headers["Server"] + with self.assertRaises(KeyError): + self.headers["Server"] + + def test_get_all(self): + self.assertEqual(self.headers.get_all("Server"), ["Python", "websockets"]) + + def test_raw_items(self): + self.assertEqual( + list(self.headers.raw_items()), + [("Server", "Python"), ("Server", "websockets")], + ) From 26c17794dffef336ec5f43405d09608a42a78bca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 09:12:23 +0200 Subject: [PATCH 0907/1539] Simplify extra_headers implementation. --- src/websockets/client.py | 9 +----- src/websockets/legacy/client.py | 10 ++---- src/websockets/legacy/server.py | 8 +---- src/websockets/server.py | 8 +---- tests/legacy/test_client_server.py | 52 +++--------------------------- 5 files changed, 10 insertions(+), 77 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 6a0bb580e..105f42775 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,6 +1,5 @@ from __future__ import annotations -import collections from typing import Generator, List, Optional, Sequence from .connection import CLIENT, CONNECTING, OPEN, Connection, State @@ -105,13 +104,7 @@ def connect(self) -> Request: # noqa: F811 headers["Sec-WebSocket-Protocol"] = protocol_header if self.extra_headers is not None: - extra_headers = self.extra_headers - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - headers[name] = value + headers.update(self.extra_headers) headers.setdefault("User-Agent", USER_AGENT) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 5d6d8130a..bad2cbeea 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -6,7 +6,6 @@ from __future__ import annotations import asyncio -import collections.abc import functools import logging import urllib.parse @@ -387,13 +386,8 @@ async def handshake( protocol_header = build_subprotocol(available_subprotocols) request_headers["Sec-WebSocket-Protocol"] = protocol_header - if extra_headers is not None: - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - request_headers[name] = value + if self.extra_headers is not None: + request_headers.update(self.extra_headers) request_headers.setdefault("User-Agent", USER_AGENT) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3fde99568..297fc5664 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -6,7 +6,6 @@ from __future__ import annotations import asyncio -import collections.abc import email.utils import functools import http @@ -697,12 +696,7 @@ async def handshake( if callable(extra_headers): extra_headers = extra_headers(path, self.request_headers) if extra_headers is not None: - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - response_headers[name] = value + response_headers.update(extra_headers) response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) response_headers.setdefault("Server", USER_AGENT) diff --git a/src/websockets/server.py b/src/websockets/server.py index d6d2143bb..ae0cfdd58 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -2,7 +2,6 @@ import base64 import binascii -import collections import email.utils import http from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast @@ -142,12 +141,7 @@ def accept(self, request: Request) -> Response: else: extra_headers = self.extra_headers if extra_headers is not None: - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - headers[name] = value + headers.update(extra_headers) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) headers.setdefault("Server", USER_AGENT) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 02d7a7aa3..3fcd0b044 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -608,52 +608,24 @@ def test_protocol_headers(self): self.assertEqual(server_req, repr(client_req)) self.assertEqual(server_resp, repr(client_resp)) - @with_server() - @with_client("/headers", extra_headers=Headers({"X-Spam": "Eggs"})) - def test_protocol_custom_request_headers(self): - req_headers = self.loop.run_until_complete(self.client.recv()) - self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", req_headers) - @with_server() @with_client("/headers", extra_headers={"X-Spam": "Eggs"}) - def test_protocol_custom_request_headers_dict(self): - req_headers = self.loop.run_until_complete(self.client.recv()) - self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", req_headers) - - @with_server() - @with_client("/headers", extra_headers=[("X-Spam", "Eggs")]) - def test_protocol_custom_request_headers_list(self): + def test_protocol_custom_request_headers(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client("/headers", extra_headers=[("User-Agent", "Eggs")]) + @with_client("/headers", extra_headers={"User-Agent": "Eggs"}) def test_protocol_custom_request_user_agent(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) self.assertIn("('User-Agent', 'Eggs')", req_headers) - @with_server(extra_headers=lambda p, r: Headers({"X-Spam": "Eggs"})) - @with_client("/headers") - def test_protocol_custom_response_headers_callable(self): - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=lambda p, r: {"X-Spam": "Eggs"}) @with_client("/headers") - def test_protocol_custom_response_headers_callable_dict(self): - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - - @with_server(extra_headers=lambda p, r: [("X-Spam", "Eggs")]) - @with_client("/headers") - def test_protocol_custom_response_headers_callable_list(self): + def test_protocol_custom_response_headers_callable(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @@ -664,28 +636,14 @@ def test_protocol_custom_response_headers_callable_none(self): self.loop.run_until_complete(self.client.recv()) # doesn't crash self.loop.run_until_complete(self.client.recv()) # nothing to check - @with_server(extra_headers=Headers({"X-Spam": "Eggs"})) - @with_client("/headers") - def test_protocol_custom_response_headers(self): - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers={"X-Spam": "Eggs"}) @with_client("/headers") - def test_protocol_custom_response_headers_dict(self): - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - - @with_server(extra_headers=[("X-Spam", "Eggs")]) - @with_client("/headers") - def test_protocol_custom_response_headers_list(self): + def test_protocol_custom_response_headers(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=[("Server", "Eggs")]) + @with_server(extra_headers={"Server": "Eggs"}) @with_client("/headers") def test_protocol_custom_response_user_agent(self): self.loop.run_until_complete(self.client.recv()) From 6edb363af07365867b4c372405b6ab3177a57830 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 09:22:38 +0200 Subject: [PATCH 0908/1539] Remove extra_headers from Sans-I/O layer. Request and response objects are explicitly passed through the integration layer, making this API unnecessary. --- src/websockets/client.py | 9 ++------- src/websockets/server.py | 26 +++++++------------------- tests/test_client.py | 22 ---------------------- tests/test_server.py | 37 ++++++------------------------------- 4 files changed, 15 insertions(+), 79 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 105f42775..e21fd36c0 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -3,7 +3,7 @@ from typing import Generator, List, Optional, Sequence from .connection import CLIENT, CONNECTING, OPEN, Connection, State -from .datastructures import Headers, HeadersLike, MultipleValuesError +from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, InvalidHeader, @@ -50,7 +50,6 @@ def __init__( origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, state: State = CONNECTING, max_size: Optional[int] = 2 ** 20, logger: Optional[LoggerLike] = None, @@ -65,7 +64,6 @@ def __init__( self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols - self.extra_headers = extra_headers self.key = generate_key() def connect(self) -> Request: # noqa: F811 @@ -103,10 +101,7 @@ def connect(self) -> Request: # noqa: F811 protocol_header = build_subprotocol(self.available_subprotocols) headers["Sec-WebSocket-Protocol"] = protocol_header - if self.extra_headers is not None: - headers.update(self.extra_headers) - - headers.setdefault("User-Agent", USER_AGENT) + headers["User-Agent"] = USER_AGENT return Request(self.wsuri.resource_name, headers) diff --git a/src/websockets/server.py b/src/websockets/server.py index ae0cfdd58..67183c685 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,10 +4,10 @@ import binascii import email.utils import http -from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast +from typing import Generator, List, Optional, Sequence, Tuple, cast from .connection import CONNECTING, OPEN, SERVER, Connection, State -from .datastructures import Headers, HeadersLike, MultipleValuesError +from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, InvalidHeader, @@ -44,9 +44,6 @@ __all__ = ["ServerConnection"] -HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] - - class ServerConnection(Connection): side = SERVER @@ -56,7 +53,6 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, state: State = CONNECTING, max_size: Optional[int] = 2 ** 20, logger: Optional[LoggerLike] = None, @@ -70,7 +66,6 @@ def __init__( self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols - self.extra_headers = extra_headers def accept(self, request: Request) -> Response: """ @@ -125,6 +120,8 @@ def accept(self, request: Request) -> Response: headers = Headers() + headers["Date"] = email.utils.formatdate(usegmt=True) + headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Accept"] = accept_key(key) @@ -135,16 +132,7 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - extra_headers: Optional[HeadersLike] - if callable(self.extra_headers): - extra_headers = self.extra_headers(request.path, request.headers) - else: - extra_headers = self.extra_headers - if extra_headers is not None: - headers.update(extra_headers) - - headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Server", USER_AGENT) + headers["Server"] = USER_AGENT self.logger.info("connection open") return Response(101, "Switching Protocols", headers) @@ -402,10 +390,10 @@ def reject( if headers is None: headers = Headers() headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Server", USER_AGENT) + headers.setdefault("Connection", "close") headers.setdefault("Content-Length", str(len(body))) headers.setdefault("Content-Type", "text/plain; charset=utf-8") - headers.setdefault("Connection", "close") + headers.setdefault("Server", USER_AGENT) self.logger.info("connection failed (%d %s)", status.value, status.phrase) return Response(status.value, status.phrase, headers, body) diff --git a/tests/test_client.py b/tests/test_client.py index 2ef1f6a95..015b93b3f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -108,28 +108,6 @@ def test_subprotocols(self): self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") - def test_extra_headers(self): - for extra_headers in [ - Headers({"X-Spam": "Eggs"}), - {"X-Spam": "Eggs"}, - [("X-Spam", "Eggs")], - ]: - with self.subTest(extra_headers=extra_headers): - client = ClientConnection( - "wss://example.com/", extra_headers=extra_headers - ) - request = client.connect() - - self.assertEqual(request.headers["X-Spam"], "Eggs") - - def test_extra_headers_overrides_user_agent(self): - client = ClientConnection( - "wss://example.com/", extra_headers={"User-Agent": "Other"} - ) - request = client.connect() - - self.assertEqual(request.headers["User-Agent"], "Other") - class AcceptRejectTests(unittest.TestCase): def test_receive_accept(self): diff --git a/tests/test_server.py b/tests/test_server.py index 042d64a31..54699c3ef 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -96,10 +96,10 @@ def test_send_accept(self): server.data_to_send(), [ f"HTTP/1.1 101 Switching Protocols\r\n" + f"Date: {DATE}\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" - f"Date: {DATE}\r\n" f"Server: {USER_AGENT}\r\n" f"\r\n".encode() ], @@ -117,10 +117,10 @@ def test_send_reject(self): [ f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" + f"Connection: close\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" - f"Connection: close\r\n" + f"Server: {USER_AGENT}\r\n" f"\r\n" f"Sorry folks.\n".encode(), b"", @@ -139,10 +139,10 @@ def test_accept_response(self): response.headers, Headers( { + "Date": DATE, "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": ACCEPT, - "Date": DATE, "Server": USER_AGENT, } ), @@ -161,10 +161,10 @@ def test_reject_response(self): Headers( { "Date": DATE, - "Server": USER_AGENT, + "Connection": "close", "Content-Length": "13", "Content-Type": "text/plain; charset=utf-8", - "Connection": "close", + "Server": USER_AGENT, } ), ) @@ -604,31 +604,6 @@ def test_unsupported_subprotocol(self): self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) - def test_extra_headers(self): - for extra_headers in [ - Headers({"X-Spam": "Eggs"}), - {"X-Spam": "Eggs"}, - [("X-Spam", "Eggs")], - lambda path, headers: Headers({"X-Spam": "Eggs"}), - lambda path, headers: {"X-Spam": "Eggs"}, - lambda path, headers: [("X-Spam", "Eggs")], - ]: - with self.subTest(extra_headers=extra_headers): - server = ServerConnection(extra_headers=extra_headers) - request = self.make_request() - response = server.accept(request) - - self.assertEqual(response.status_code, 101) - self.assertEqual(response.headers["X-Spam"], "Eggs") - - def test_extra_headers_overrides_server(self): - server = ServerConnection(extra_headers={"Server": "Other"}) - request = self.make_request() - response = server.accept(request) - - self.assertEqual(response.status_code, 101) - self.assertEqual(response.headers["Server"], "Other") - class MiscTests(unittest.TestCase): def test_bypass_handshake(self): From af0f5646ad14e62ee614fa898cc4651ae7d02b93 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 08:36:27 +0200 Subject: [PATCH 0909/1539] Fix rendering of extensions how-to. --- docs/howto/extensions.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index fdaf09f63..9c49de172 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -1,6 +1,8 @@ Writing an extension ==================== +.. currentmodule:: websockets.extensions + During the opening handshake, WebSocket clients and servers negotiate which extensions will be used with which parameters. Then each frame is processed by extensions before being sent or after being received. @@ -24,8 +26,8 @@ As a consequence, writing an extension requires implementing several classes: websockets provides abstract base classes for extension factories and extensions. See the API documentation for details on their methods: -* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for - :extension factories, +* :class:`ClientExtensionFactory` and :class:`ServerExtensionFactory` for + extension factories, * :class:`Extension` for extensions. From 017a832fb0070318a874626b3c69e36d2af669b3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 08:45:55 +0200 Subject: [PATCH 0910/1539] Regenerate Sphinx configuration. Recent Sphinx versions have a simpler template. --- docs/Makefile | 166 +++----------------------------- docs/conf.py | 258 +++++++------------------------------------------- docs/make.bat | 35 +++++++ 3 files changed, 83 insertions(+), 376 deletions(-) create mode 100644 docs/make.bat diff --git a/docs/Makefile b/docs/Makefile index bb25aa49d..d4bb2cbb9 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,160 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . BUILDDIR = _build -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . -# the i18n builder cannot share the environment and doctrees with the others -I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext - +# Put it first so that "make" without argument is like "make help". help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " texinfo to make Texinfo files" - @echo " info to make Texinfo files and run them through makeinfo" - @echo " gettext to make PO message catalogs" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - @echo " spelling to check for typos in documentation" - -clean: - -rm -rf $(BUILDDIR)/* - -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." - -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/websockets.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/websockets.qhc" - -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/websockets" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/websockets" - @echo "# devhelp" - -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." - -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -texinfo: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo - @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." - @echo "Run \`make' in that directory to run these through makeinfo" \ - "(use \`make info' here to do that automatically)." - -info: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo "Running Texinfo files through makeinfo..." - make -C $(BUILDDIR)/texinfo info - @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." - -gettext: - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale - @echo - @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." - -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." +.PHONY: help Makefile -spelling: - $(SPHINXBUILD) -b spelling $(ALLSPHINXOPTS) $(BUILDDIR)/spelling - @echo - @echo "Check finished. Wrong words can be found in " \ - "$(BUILDDIR)/spelling/output.txt." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/conf.py b/docs/conf.py index 2246c0287..145abafa4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,112 +1,64 @@ -# -*- coding: utf-8 -*- +# Configuration file for the Sphinx documentation builder. # -# websockets documentation build configuration file, created by -# sphinx-quickstart on Sun Mar 31 20:48:44 2013. -# -# This file is execfile()d with the current directory set to its containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html -import sys, os, datetime +import datetime +import os +import sys + +# -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. + sys.path.insert(0, os.path.join(os.path.abspath('..'), 'src')) -# -- General configuration ----------------------------------------------------- -# If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# -- Project information ----------------------------------------------------- + +project = 'websockets' +copyright = f'2013-{datetime.date.today().year}, Aymeric Augustin and contributors' +author = 'Aymeric Augustin' + +# The full version, including alpha/beta/rc tags +release = '9.1' + -# Add any Sphinx extension module names here, as strings. They can be extensions -# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.viewcode', 'sphinx_autodoc_typehints', + 'sphinxcontrib.spelling', 'sphinxcontrib_trio', - ] +] -# Spelling check needs an additional module that is not installed by default. -# Add it only if spelling check is requested so docs can be generated without it. -if 'spelling' in sys.argv: - extensions.append('sphinxcontrib.spelling') +intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] -# The suffix of source filenames. -source_suffix = '.rst' - -# The encoding of source files. -#source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = 'websockets' -copyright = f'2013-{datetime.date.today().year}, Aymeric Augustin and contributors' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -version = '9.1' -# The full version, including alpha/beta/rc tags. -release = '9.1' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -#language = None - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' - # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] - -# The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - - -# -- Options for HTML output --------------------------------------------------- +# -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. +# html_theme = 'alabaster' -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. html_theme_options = { 'logo': 'websockets.svg', 'description': 'A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.', @@ -117,39 +69,6 @@ 'tidelift_url': 'https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs', } -# Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None - -# A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -#html_logo = None - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -#html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. html_sidebars = { '**': [ 'about.html', @@ -160,114 +79,7 @@ ] } -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_domain_indices = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None - -# Output file base name for HTML help builder. -htmlhelp_basename = 'websocketsdoc' - - -# -- Options for LaTeX output -------------------------------------------------- - -latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, author, documentclass [howto/manual]). -latex_documents = [ - ('index', 'websockets.tex', 'websockets Documentation', - 'Aymeric Augustin', 'manual'), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True - - -# -- Options for manual page output -------------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'websockets', 'websockets Documentation', - ['Aymeric Augustin'], 1) -] - -# If true, show URL addresses after external links. -#man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------------ - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ('index', 'websockets', 'websockets Documentation', - 'Aymeric Augustin', 'websockets', 'One line description of project.', - 'Miscellaneous'), -] - -# Documents to append as an appendix to all manuals. -#texinfo_appendices = [] - -# If false, no module index is generated. -#texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' - - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/3/': None} +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 000000000..2119f5109 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd From cc35695df01241752d01266ad3d86510505af070 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 09:10:48 +0200 Subject: [PATCH 0911/1539] Reformat conf.py with black. --- docs/conf.py | 60 +++++++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 145abafa4..3523477c2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,18 +13,17 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. - -sys.path.insert(0, os.path.join(os.path.abspath('..'), 'src')) +sys.path.insert(0, os.path.join(os.path.abspath(".."), "src")) # -- Project information ----------------------------------------------------- -project = 'websockets' -copyright = f'2013-{datetime.date.today().year}, Aymeric Augustin and contributors' -author = 'Aymeric Augustin' +project = "websockets" +copyright = f"2013-{datetime.date.today().year}, Aymeric Augustin and contributors" +author = "Aymeric Augustin" # The full version, including alpha/beta/rc tags -release = '9.1' +release = "9.1" # -- General configuration --------------------------------------------------- @@ -33,53 +32,52 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.viewcode', - 'sphinx_autodoc_typehints', - 'sphinxcontrib.spelling', - 'sphinxcontrib_trio', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", + "sphinxcontrib.spelling", + "sphinxcontrib_trio", ] -intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} +intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -# -html_theme = 'alabaster' +html_theme = "alabaster" html_theme_options = { - 'logo': 'websockets.svg', - 'description': 'A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.', - 'github_button': True, - 'github_type': 'star', - 'github_user': 'aaugustin', - 'github_repo': 'websockets', - 'tidelift_url': 'https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs', + "logo": "websockets.svg", + "description": "A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.", + "github_button": True, + "github_type": "star", + "github_user": "aaugustin", + "github_repo": "websockets", + "tidelift_url": "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs", } html_sidebars = { - '**': [ - 'about.html', - 'searchbox.html', - 'navigation.html', - 'relations.html', - 'donate.html', + "**": [ + "about.html", + "searchbox.html", + "navigation.html", + "relations.html", + "donate.html", ] } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] From 323c93297a0876f7fc47993b908200888a50afbb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 09:07:58 +0200 Subject: [PATCH 0912/1539] Update spelling wordlist. List generated with this command, then checked and adjusted manually: $ make -C docs spelling | grep 'Spell check' | awk '{ print $4 }' \ sed 's/:$//' | tr '[:upper:]' '[:lower:]' | sort | uniq --- docs/conf.py | 2 ++ docs/spelling_wordlist.txt | 22 +++------------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 3523477c2..e4273570d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -42,6 +42,8 @@ intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} +spelling_show_suggestions = True + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b3389a920..3d05752d5 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1,16 +1,12 @@ -api -attr augustin auth autoscaler -awaitable aymeric backend backoff backpressure balancer balancers -bitcoin bottlenecked bufferbloat bugfix @@ -22,14 +18,10 @@ coroutines cryptocurrencies cryptocurrency ctrl -daemonize -datastructures django dyno -formatter fractalideas gunicorn -haproxy hypercorn iframe IPv @@ -37,48 +29,40 @@ istio iterable keepalive KiB -Kubernetes +kubernetes lifecycle linkerd liveness lookups MiB -mypy nginx -onmessage -parsers permessage pid -pong -pongs proxying pythonic reconnection redis +redistributions retransmit runtime scalable -serializers stateful subclasses subclassing +subpackages subprotocol subprotocols supervisord tidelift tls tox -unparse unregister uple -username uvicorn -uvloop virtualenv WebSocket websocket websockets ws wsgi -wss www From 5a6bd258ca29665ad8acd7709d989c57c2a8b2bc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 11:31:15 +0200 Subject: [PATCH 0913/1539] Enable and configure Furo theme. This removes two features from the sidebar: the GitHub "stars" count and the Tidelift banner. --- docs/Makefile | 3 ++ docs/_static/favicon.ico | 1 + docs/conf.py | 37 +++++++++++++------------ docs/requirements.txt | 5 ++++ experiments/authentication/favicon.ico | Bin 5430 -> 22 bytes logo/favicon.ico | Bin 0 -> 5430 bytes 6 files changed, 29 insertions(+), 17 deletions(-) create mode 120000 docs/_static/favicon.ico mode change 100644 => 120000 experiments/authentication/favicon.ico create mode 100644 logo/favicon.ico diff --git a/docs/Makefile b/docs/Makefile index d4bb2cbb9..7a04f7827 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -18,3 +18,6 @@ help: # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +livehtml: + sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/favicon.ico b/docs/_static/favicon.ico new file mode 120000 index 000000000..dd7df921e --- /dev/null +++ b/docs/_static/favicon.ico @@ -0,0 +1 @@ +../../logo/favicon.ico \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e4273570d..04d4c0557 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,8 +36,11 @@ "sphinx.ext.intersphinx", "sphinx.ext.viewcode", "sphinx_autodoc_typehints", + "sphinx_copybutton", + "sphinx_inline_tabs", "sphinxcontrib.spelling", "sphinxcontrib_trio", + "sphinxext.opengraph", ] intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} @@ -57,29 +60,29 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = "alabaster" +html_theme = "furo" html_theme_options = { - "logo": "websockets.svg", - "description": "A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.", - "github_button": True, - "github_type": "star", - "github_user": "aaugustin", - "github_repo": "websockets", - "tidelift_url": "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs", + "light_css_variables": { + "color-brand-primary": "#306998", # blue from logo + "color-brand-content": "#0b487a", # blue more saturated and less dark + }, + "dark_css_variables": { + "color-brand-primary": "#ffd43bcc", # yellow from logo, more muted than content + "color-brand-content": "#ffd43bd9", # yellow from logo, transparent like text + }, + "sidebar_hide_name": True, } -html_sidebars = { - "**": [ - "about.html", - "searchbox.html", - "navigation.html", - "relations.html", - "donate.html", - ] -} +html_logo = "_static/websockets.svg" + +html_favicon = "_static/favicon.ico" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] + +html_copy_source = False + +html_show_sphinx = False diff --git a/docs/requirements.txt b/docs/requirements.txt index 0eaf94fbe..b9c371228 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,9 @@ +furo sphinx +sphinx-autobuild sphinx-autodoc-typehints +sphinx-copybutton +sphinx-inline-tabs sphinxcontrib-spelling sphinxcontrib-trio +sphinxext-opengraph diff --git a/experiments/authentication/favicon.ico b/experiments/authentication/favicon.ico deleted file mode 100644 index 36e855029d705e72d44428bda6e8cb6d3dd317ed..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5430 zcmeH~eNdFw6~-?im3Ahb_(wY9G?@uAnl{s!P5}k6EMLo)Am6YM4N*)C21rB%KSG3E zv;hMxiW&u@(MF93wOCQ3rD-sRXbhdyP7o3K#!S*q0u${7l(0J*XQZVahM#RhT&Mw%^BqMk{U@J1~1H2BJEU64r)0!B+s<_@|*2 zqhiBw1#eC^Kd~FBvDYyg*$#wV#Xwjm9t&v2p)n2r@@vI8+WgsgcG9|#V(P|dLI)7j zfdSf?v#0~lgtfyurUjuNe*Ta2+T*8ca~J%;(ZdETkY^gva%XNcqcj2kdN-G|0yeVF}ZFP=4Z!J2R#i0#02+Ir+*zVG@y z*5YS?yJuBTDvF9 z+e4qHUo$?zisgL(ZFDT{$HRZz5|LeB6ok5h0Mx8E{O!%Wpod<2-phQZDi;R}!Uptt z_Sffbz;E>+X1=#0_Q!9(68Y81O(FP$*?76Z67toHt(cKzH39p~78|EMbIvF7dZI9CxZHU=B0 z<9Jap>Rt-q+)$@;bvu`KC6(K3*mRuQ9MaslWomPE8-LktGQPhv{>AfKV-~kmXM9f| z|2D6HQbVN2tW*=RJWA^l;TTP>949?wSa4Hnl)qW|Kbb4f8Fvyy2h!?Rc2a%__sk=;L*m&%aOk`2^Q3=D%J$`lrna32wt&`ua?J=IYOX`)ey| zsDH*}F--J3_a!wE;q=XC^%sGy0H6D|y~p17*k>&_PGWEtANNEx9#Nf`$LIWDXd7VU z-lc0`F21g#6;|0_CyK+_IPED%Vmjd+bq#|NSGn)ee~4%xL9&NlK@0!?spEQZeA9Sw zOg84YE{XKby>e*LH9+>d2-$me4ei*@T2#ST@l)pX-aNEoMD}B9TkIH*)9-fc#{JDF zsS`tE`y^thLD#X6zKIf~ACX*0P0;qeI{2O1>Ze+hd$YWhu%sV8DP0&!?gSFKPsfuu z=^dyc`Wh;U66imQy~wci5ZZc7tYMe4y3>E-zOC46crNM2&=cI_Q|MnbFb{~Qf9lvq zl)zjpqW|G=)`eCR4jO{~dh5&OHAgQLZpnV9ydbWbam~@=o9J4d>8Y0WqGf3>kS6_3 zH=ygFmaRky%tJKeAJJ-p{_kqA)#Yh(*$efw^_AAoNKZthCz1G^u_xP8J>82TYJl`x zPtw0=VBb|il)!vMb4^CrJ8A0?#hMfy%!Slv9#Q+&`23lKphoC2`RmZF?C`@YT|BPQis$fd)qZc3H#T~U(G%W)mfglo{sU;_HV?#Giog+ zWdZKb7(gC1)GY7CGNOdEKE$SWVT5(5)r3};^sUup2XfhuD&bJR@3U{@`dIhp%!9oK z`hCKgZ~0C9j|Y3#aMr|#eN9k{rX#rbA2Ga;A`w_nTG-C ztBDe-vb<91XNIG%>P!!=&(~otdZi|$F7=v_x=f{SqJ-l`K`Vb(5MVvJ!Jze!ng;g7 z>?QBK{=AowD1llEwR+6*IO(WiJm0l|D{Ep{uL)g4S}$^l5>9OnTX~|`Xtnjy{y6%g zO~ax51$Wq&ClMvEKQ9vBhc`yq?ukr~>U(@1IJr6Gc0;i-fhgfV}k+ptmI=o*Qx0@RDGY}eIiR4)*!i%Vr#UY#Q>p{*{nuh%O zw-{rDA=?=E^vn-xf;^a)yc&4S#>DB;t~+h17%G6M9Y7Y%ttpeEt)~ z*kB1&)1jq0b@m6ll1AkWCmP=Q^&;&&DfRHy%i&r*$hok-P}6XevH3fT@AR)C)pTj8 zN-n>!BLP|-wu%PrJgTk5T@A${HyM{c{o!N z{uyKPol0D`gW)%O|F9}G?&6Mwv&}o=zc^nRJKR(eeeKWX(PfP1M^laIn|c0j*0+@b zZDY{i*oZ&{!b(QfqZlSb-*-dLEEhC+xWJOGU};h)wgh3bYC?&N1*5L&qtO{IrNa-n mx{c17;Wp~=fE$Kp5swEk(0J*XQZVahM#RhT&Mw%^BqMk{U@J1~1H2BJEU64r)0!B+s<_@|*2 zqhiBw1#eC^Kd~FBvDYyg*$#wV#Xwjm9t&v2p)n2r@@vI8+WgsgcG9|#V(P|dLI)7j zfdSf?v#0~lgtfyurUjuNe*Ta2+T*8ca~J%;(ZdETkY^gva%XNcqcj2kdN-G|0yeVF}ZFP=4Z!J2R#i0#02+Ir+*zVG@y z*5YS?yJuBTDvF9 z+e4qHUo$?zisgL(ZFDT{$HRZz5|LeB6ok5h0Mx8E{O!%Wpod<2-phQZDi;R}!Uptt z_Sffbz;E>+X1=#0_Q!9(68Y81O(FP$*?76Z67toHt(cKzH39p~78|EMbIvF7dZI9CxZHU=B0 z<9Jap>Rt-q+)$@;bvu`KC6(K3*mRuQ9MaslWomPE8-LktGQPhv{>AfKV-~kmXM9f| z|2D6HQbVN2tW*=RJWA^l;TTP>949?wSa4Hnl)qW|Kbb4f8Fvyy2h!?Rc2a%__sk=;L*m&%aOk`2^Q3=D%J$`lrna32wt&`ua?J=IYOX`)ey| zsDH*}F--J3_a!wE;q=XC^%sGy0H6D|y~p17*k>&_PGWEtANNEx9#Nf`$LIWDXd7VU z-lc0`F21g#6;|0_CyK+_IPED%Vmjd+bq#|NSGn)ee~4%xL9&NlK@0!?spEQZeA9Sw zOg84YE{XKby>e*LH9+>d2-$me4ei*@T2#ST@l)pX-aNEoMD}B9TkIH*)9-fc#{JDF zsS`tE`y^thLD#X6zKIf~ACX*0P0;qeI{2O1>Ze+hd$YWhu%sV8DP0&!?gSFKPsfuu z=^dyc`Wh;U66imQy~wci5ZZc7tYMe4y3>E-zOC46crNM2&=cI_Q|MnbFb{~Qf9lvq zl)zjpqW|G=)`eCR4jO{~dh5&OHAgQLZpnV9ydbWbam~@=o9J4d>8Y0WqGf3>kS6_3 zH=ygFmaRky%tJKeAJJ-p{_kqA)#Yh(*$efw^_AAoNKZthCz1G^u_xP8J>82TYJl`x zPtw0=VBb|il)!vMb4^CrJ8A0?#hMfy%!Slv9#Q+&`23lKphoC2`RmZF?C`@YT|BPQis$fd)qZc3H#T~U(G%W)mfglo{sU;_HV?#Giog+ zWdZKb7(gC1)GY7CGNOdEKE$SWVT5(5)r3};^sUup2XfhuD&bJR@3U{@`dIhp%!9oK z`hCKgZ~0C9j|Y3#aMr|#eN9k{rX#rbA2Ga;A`w_nTG-C ztBDe-vb<91XNIG%>P!!=&(~otdZi|$F7=v_x=f{SqJ-l`K`Vb(5MVvJ!Jze!ng;g7 z>?QBK{=AowD1llEwR+6*IO(WiJm0l|D{Ep{uL)g4S}$^l5>9OnTX~|`Xtnjy{y6%g zO~ax51$Wq&ClMvEKQ9vBhc`yq?ukr~>U(@1IJr6Gc0;i-fhgfV}k+ptmI=o*Qx0@RDGY}eIiR4)*!i%Vr#UY#Q>p{*{nuh%O zw-{rDA=?=E^vn-xf;^a)yc&4S#>DB;t~+h17%G6M9Y7Y%ttpeEt)~ z*kB1&)1jq0b@m6ll1AkWCmP=Q^&;&&DfRHy%i&r*$hok-P}6XevH3fT@AR)C)pTj8 zN-n>!BLP|-wu%PrJgTk5T@A${HyM{c{o!N z{uyKPol0D`gW)%O|F9}G?&6Mwv&}o=zc^nRJKR(eeeKWX(PfP1M^laIn|c0j*0+@b zZDY{i*oZ&{!b(QfqZlSb-*-dLEEhC+xWJOGU};h)wgh3bYC?&N1*5L&qtO{IrNa-n mx{c17;Wp~=fE$Kp5swEk Date: Sun, 15 Aug 2021 07:34:33 +0200 Subject: [PATCH 0914/1539] Update links to RFC after IETF website changes. --- docs/howto/logging.rst | 4 ++-- docs/project/changelog.rst | 2 +- docs/reference/extensions.rst | 2 +- docs/reference/limitations.rst | 4 +++- docs/topics/compression.rst | 2 +- docs/topics/design.rst | 10 ++++---- src/websockets/extensions/base.py | 2 +- .../extensions/permessage_deflate.py | 4 ++-- src/websockets/headers.py | 23 ++++++++++--------- src/websockets/http.py | 2 +- src/websockets/http11.py | 14 +++++------ src/websockets/legacy/client.py | 4 ++-- src/websockets/legacy/framing.py | 2 +- src/websockets/legacy/handshake.py | 2 +- src/websockets/legacy/http.py | 10 ++++---- src/websockets/legacy/protocol.py | 10 ++++---- src/websockets/legacy/server.py | 6 ++--- src/websockets/server.py | 2 +- src/websockets/uri.py | 4 ++-- 19 files changed, 56 insertions(+), 53 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index 824812959..4e91557fa 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -35,8 +35,8 @@ Instead, when running as a server, websockets logs one event when a `connection is established`_ and another event when a `connection is closed`_. -.. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455#section-4 -.. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 +.. _connection is established: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 +.. _connection is closed: https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.4 websockets doesn't log an event for every message because that would be excessive for many applications exchanging small messages at a fast rate. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 277ee5022..07942783f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -747,7 +747,7 @@ Also: * Added support for providing and checking Origin_. -.. _Origin: https://tools.ietf.org/html/rfc6455#section-10.2 +.. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 2.0 ... diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index 6ca20dbc8..bae583a21 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -8,7 +8,7 @@ The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. -.. _extensions: https://tools.ietf.org/html/rfc6455#section-9 +.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name Per-Message Deflate diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst index ecdde23bf..81f1445b5 100644 --- a/docs/reference/limitations.rst +++ b/docs/reference/limitations.rst @@ -6,9 +6,11 @@ Client The client doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -mandated by :rfc:`6455`. However, :func:`~websockets.connect()` isn't the +`mandated by RFC 6455`_. However, :func:`~websockets.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. +.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 + The client doesn't support connecting through a HTTP proxy (`issue 364`_) or a SOCKS proxy (`issue 475`_). diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index e23319636..f78e32748 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -162,4 +162,4 @@ settings affect memory usage and how to optimize them. This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory Level = 4 for optimizing memory usage. -.. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html +.. _experiment by Peter Thorson: https://mailarchive.ietf.org/arch/msg/hybi/F9t4uPufVEy8KBLuL36cZjCmM_Y/ diff --git a/docs/topics/design.rst b/docs/topics/design.rst index fa2093433..2c9d505aa 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -173,16 +173,16 @@ differences between a server and a client: - `closing the TCP connection`_: the server closes the connection immediately; the client waits for the server to do it. -.. _client-to-server masking: https://tools.ietf.org/html/rfc6455#section-5.3 -.. _closing the TCP connection: https://tools.ietf.org/html/rfc6455#section-5.5.1 +.. _client-to-server masking: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.3 +.. _closing the TCP connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.1 These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented in the same class, :class:`~legacy.protocol.WebSocketCommonProtocol`. -.. _data framing: https://tools.ietf.org/html/rfc6455#section-5 -.. _sending and receiving data: https://tools.ietf.org/html/rfc6455#section-6 -.. _closing the connection: https://tools.ietf.org/html/rfc6455#section-7 +.. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 +.. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 +.. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 The :attr:`~legacy.protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 2de9176bd..7217aa513 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -4,7 +4,7 @@ See `section 9 of RFC 6455`_. -.. _section 9 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-9 +.. _section 9 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 0c9088a9e..59f019d53 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -268,7 +268,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): ``True`` to include them in the negotiation offer without a value or to an integer value to include them with this value. - .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1 + .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 :param server_no_context_takeover: defaults to ``False`` :param client_no_context_takeover: defaults to ``False`` @@ -466,7 +466,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): ``True`` to include them in the negotiation offer without a value or to an integer value to include them with this value. - .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1 + .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 :param server_no_context_takeover: defaults to ``False`` :param client_no_context_takeover: defaults to ``False`` diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 181e976e3..ee6dd1672 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -40,8 +40,8 @@ # To avoid a dependency on a parsing library, we implement manually the ABNF -# described in https://tools.ietf.org/html/rfc6455#section-9.1 with the -# definitions from https://tools.ietf.org/html/rfc7230#appendix-B. +# described in https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 and +# https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. def peek_ahead(header: str, pos: int) -> Optional[str]: @@ -161,9 +161,9 @@ def parse_list( :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ - # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST - # parse and ignore a reasonable number of empty list elements"; hence - # while loops that remove extra delimiters. + # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient + # MUST parse and ignore a reasonable number of empty list elements"; + # hence while loops that remove extra delimiters. # Remove extra delimiters before the first item. while peek_ahead(header, pos) == ",": @@ -289,8 +289,9 @@ def parse_extension_item_param( if peek_ahead(header, pos) == '"': pos_before = pos # for proper error reporting below value, pos = parse_quoted_string(header, pos, header_name) - # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value - # after quoted-string unescaping MUST conform to the 'token' ABNF. + # https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 says: + # the value after quoted-string unescaping MUST conform to + # the 'token' ABNF. if _token_re.fullmatch(value) is None: raise exceptions.InvalidHeaderFormat( header_name, "invalid quoted header content", header, pos_before @@ -452,7 +453,7 @@ def build_www_authenticate_basic(realm: str) -> str: :param realm: authentication realm """ - # https://tools.ietf.org/html/rfc7617#section-2 + # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 realm = build_quoted_string(realm) charset = build_quoted_string("UTF-8") return f"Basic realm={realm}, charset={charset}" @@ -498,8 +499,8 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: :raises InvalidHeaderValue: on unsupported inputs """ - # https://tools.ietf.org/html/rfc7235#section-2.1 - # https://tools.ietf.org/html/rfc7617#section-2 + # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 + # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": raise exceptions.InvalidHeaderValue( @@ -539,7 +540,7 @@ def build_authorization_basic(username: str, password: str) -> str: This is the reverse of :func:`parse_authorization_basic`. """ - # https://tools.ietf.org/html/rfc7617#section-2 + # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 assert ":" not in username user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() diff --git a/src/websockets/http.py b/src/websockets/http.py index 6168c5144..38848b56d 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -36,7 +36,7 @@ def build_host(host: str, port: int, secure: bool) -> str: Build a ``Host`` header. """ - # https://tools.ietf.org/html/rfc3986#section-3.2.2 + # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 # IPv6 addresses must be enclosed in brackets. try: address = ipaddress.ip_address(host) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 11b9d7f39..daa0efffb 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -19,7 +19,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://tools.ietf.org/html/rfc7230#appendix-B. +# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. # Regex for validating header names. @@ -78,7 +78,7 @@ def parse( :raises ValueError: if the request isn't well formatted """ - # https://tools.ietf.org/html/rfc7230#section-3.1.1 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -102,7 +102,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://tools.ietf.org/html/rfc7230#section-3.3.3 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -166,7 +166,7 @@ def parse( :raises ValueError: if the response isn't well formatted """ - # https://tools.ietf.org/html/rfc7230#section-3.1.2 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 # As in parse_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. @@ -197,7 +197,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://tools.ietf.org/html/rfc7230#section-3.3.3 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -251,7 +251,7 @@ def parse_headers( line or raises an exception if there isn't enough data """ - # https://tools.ietf.org/html/rfc7230#section-3.2 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 # We don't attempt to support obsolete line folding. @@ -301,7 +301,7 @@ def parse_line( # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: raise exceptions.SecurityError("line too long") - # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 + # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index bad2cbeea..6d976e0df 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -83,14 +83,14 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): timeouts on inactive connections. Set ``ping_interval`` to ``None`` to disable this behavior. - .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 + .. _Ping frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to ``None`` to disable this behavior. - .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 + .. _Pong frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index c8ae48690..40cbd41bf 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -6,7 +6,7 @@ See `section 5 of RFC 6455`_. -.. _section 5 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-5 +.. _section 5 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 """ diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 49d08cfe8..7cde58ac1 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -3,7 +3,7 @@ See `section 4 of RFC 6455`_. -.. _section 4 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 +.. _section 4 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 Some checks cannot be performed because they depend too much on the context; instead, they're documented below. diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 0b9a92267..3725fa9c3 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -22,7 +22,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://tools.ietf.org/html/rfc7230#appendix-B. +# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. # Regex for validating header names. @@ -61,7 +61,7 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: :raises ValueError: if the request isn't well formatted """ - # https://tools.ietf.org/html/rfc7230#section-3.1.1 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -105,7 +105,7 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers :raises ValueError: if the response isn't well formatted """ - # https://tools.ietf.org/html/rfc7230#section-3.1.2 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. @@ -144,7 +144,7 @@ async def read_headers(stream: asyncio.StreamReader) -> Headers: Non-ASCII characters are represented with surrogate escapes. """ - # https://tools.ietf.org/html/rfc7230#section-3.2 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 # We don't attempt to support obsolete line folding. @@ -189,7 +189,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: raise SecurityError("line too long") - # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 + # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 339733db2..7b24ccf98 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -3,7 +3,7 @@ See `sections 4 to 8 of RFC 6455`_. -.. _sections 4 to 8 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 +.. _sections 4 to 8 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 """ @@ -480,8 +480,8 @@ async def send( bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a `Binary frame`_. - .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 - .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Text frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects. In that case the message @@ -1420,8 +1420,8 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a `Binary frame`_. - .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 - .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Text frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 :func:`broadcast` pushes the message synchronously to all connections even if their write buffers overflow ``write_limit``. There's no backpressure. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 297fc5664..ddf6d9f87 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -101,14 +101,14 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): timeouts on inactive connections. Set ``ping_interval`` to ``None`` to disable this behavior. - .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 + .. _Ping frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to ``None`` to disable this behavior. - .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 + .. _Pong frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy @@ -454,7 +454,7 @@ def process_origin( """ # "The user agent MUST NOT include more than one Origin header field" - # per https://tools.ietf.org/html/rfc6454#section-7.3. + # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: origin = cast(Optional[Origin], headers.get("Origin")) except MultipleValuesError as exc: diff --git a/src/websockets/server.py b/src/websockets/server.py index 67183c685..0ae0ae940 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -222,7 +222,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: """ # "The user agent MUST NOT include more than one Origin header field" - # per https://tools.ietf.org/html/rfc6454#section-7.3. + # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: origin = cast(Optional[Origin], headers.get("Origin")) except MultipleValuesError as exc: diff --git a/src/websockets/uri.py b/src/websockets/uri.py index ed4521d53..1ab895a21 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -3,7 +3,7 @@ See `section 3 of RFC 6455`_. -.. _section 3 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-3 +.. _section 3 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-3 """ @@ -31,7 +31,7 @@ class WebSocketURI: :param str user_info: ``(username, password)`` tuple when the URI contains `User Information`_, else ``None``. - .. _User Information: https://tools.ietf.org/html/rfc3986#section-3.2.1 + .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 """ secure: bool From 4407d02b8f69f654c90cfdc8ce261377e008ce16 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Aug 2021 08:01:56 +0200 Subject: [PATCH 0915/1539] Auto-rebuild docs on docstring changes. --- docs/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index 7a04f7827..045870645 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -20,4 +20,4 @@ help: @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) livehtml: - sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + sphinx-autobuild --watch "$(SOURCEDIR)/../src" "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) From 232351294683b422c98451536824d5d1cdca7f58 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Aug 2021 09:19:35 +0200 Subject: [PATCH 0916/1539] Use a better exception type. --- src/websockets/legacy/protocol.py | 4 ++-- tests/legacy/test_protocol.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7b24ccf98..99a821be6 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -697,7 +697,7 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: :raises ~websockets.exceptions.ConnectionClosed: when the connection is closed - :raises ValueError: if another ping was sent with the same data and + :raises RuntimeError: if another ping was sent with the same data and the corresponding pong wasn't received yet """ @@ -708,7 +708,7 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: # Protect against duplicates if a payload is explicitly set. if data in self.pings: - raise ValueError("already waiting for a pong with the same data") + raise RuntimeError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.pings: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 61a5fe7cf..1672ab1ed 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1021,7 +1021,7 @@ def test_canceled_ping(self): def test_duplicate_ping(self): self.loop.run_until_complete(self.protocol.ping(b"foobar")) self.assertOneFrameSent(True, OP_PING, b"foobar") - with self.assertRaises(ValueError): + with self.assertRaises(RuntimeError): self.loop.run_until_complete(self.protocol.ping(b"foobar")) self.assertNoFrameSent() From 6a0cb60d069c9c4db6be1da3e539b672b54a8c93 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 17 Aug 2021 08:08:57 +0200 Subject: [PATCH 0917/1539] Move logging document to topics. It's more a discussion than a how-to. --- docs/howto/index.rst | 3 +-- docs/topics/index.rst | 1 + docs/{howto => topics}/logging.rst | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename docs/{howto => topics}/logging.rst (100%) diff --git a/docs/howto/index.rst b/docs/howto/index.rst index e5af8488e..a5779f6c9 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -9,13 +9,12 @@ If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. faq cheatsheet -The following guides will help you integrate websockets into a broader system. +This guide will help you integrate websockets into a broader system. .. toctree:: :maxdepth: 2 django - logging The WebSocket protocol makes provisions for extending or specializing its features, which websockets supports fully. diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 993303106..e3b20d73f 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -5,6 +5,7 @@ Topics :maxdepth: 2 deployment + logging authentication broadcast compression diff --git a/docs/howto/logging.rst b/docs/topics/logging.rst similarity index 100% rename from docs/howto/logging.rst rename to docs/topics/logging.rst From 6881cea2cea2ecfc9430673ac042e7188b8a4125 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 17 Aug 2021 08:47:36 +0200 Subject: [PATCH 0918/1539] Improve intro slightly. --- docs/intro/index.rst | 32 +++++++++++++++++++++----------- example/hello.py | 3 +-- example/secure_client.py | 5 +---- example/secure_server.py | 5 +---- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 1df92ed59..ff3ae0ffd 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -8,9 +8,11 @@ Requirements websockets requires Python ≥ 3.7. -You should use the latest version of Python if possible. If you're using an -older version, be aware that for each minor version (3.x), only the latest -bugfix release (3.x.y) is officially supported. +.. admonition:: Use the most recent Python release + :class: tip + + For each minor version (3.x), only the latest bugfix or security release + (3.x.y) is officially supported. Installation ------------ @@ -33,7 +35,7 @@ It reads a name from the client, sends a greeting, and closes the connection. .. _client-example: -On the server side, websockets executes the handler coroutine ``hello`` once +On the server side, websockets executes the handler coroutine ``hello()`` once for each WebSocket connection. It closes the connection when the handler coroutine returns. @@ -42,8 +44,8 @@ Here's a corresponding WebSocket client example. .. literalinclude:: ../../example/client.py :emphasize-lines: 10 -Using :func:`connect` as an asynchronous context manager ensures the -connection is closed before exiting the ``hello`` coroutine. +Using :func:`~client.connect` as an asynchronous context manager ensures the +connection is closed before exiting the ``hello()`` coroutine. .. _secure-server-example: @@ -53,20 +55,28 @@ Secure example Secure WebSocket connections improve confidentiality and also reliability because they reduce the risk of interference by bad proxies. -The WSS protocol is to WS what HTTPS is to HTTP: the connection is encrypted -with Transport Layer Security (TLS) — which is often referred to as Secure -Sockets Layer (SSL). WSS requires TLS certificates like HTTPS. +The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. The +connection is encrypted with TLS_ (Transport Layer Security). ``wss`` +requires certificates like ``https``. + +.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security + +.. admonition:: TLS vs. SSL + :class: tip + + TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an + earlier encryption protocol; the name stuck. Here's how to adapt the server example to provide secure connections. See the documentation of the :mod:`ssl` module for configuring the context securely. .. literalinclude:: ../../example/secure_server.py - :emphasize-lines: 19-21,26 + :emphasize-lines: 19-21,24 Here's how to adapt the client. .. literalinclude:: ../../example/secure_client.py - :emphasize-lines: 10-12,18 + :emphasize-lines: 10-12,16 This client needs a context because the server uses a self-signed certificate. diff --git a/example/hello.py b/example/hello.py index 96095dd02..84f55dc52 100755 --- a/example/hello.py +++ b/example/hello.py @@ -4,8 +4,7 @@ import websockets async def hello(): - uri = "ws://localhost:8765" - async with websockets.connect(uri) as websocket: + async with websockets.connect("ws://localhost:8765") as websocket: await websocket.send("Hello world!") await websocket.recv() diff --git a/example/secure_client.py b/example/secure_client.py index 518819dd1..8a1551e29 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -13,10 +13,7 @@ async def hello(): uri = "wss://localhost:8765" - async with websockets.connect( - uri, - ssl=ssl_context, - ) as websocket: + async with websockets.connect(uri, ssl=ssl_context) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/secure_server.py b/example/secure_server.py index 96c300390..cd8ee0cc1 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -21,10 +21,7 @@ async def hello(websocket, path): ssl_context.load_cert_chain(localhost_pem) async def main(): - async with websockets.serve( - hello, "localhost", 8765, - ssl=ssl_context, - ): + async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): await asyncio.Future() # run forever asyncio.run(main()) From 73ae74cdbcb5b43efeaa1615da88ca90bb62ecc4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 17 Aug 2021 08:51:31 +0200 Subject: [PATCH 0919/1539] Clean up tables of contents. --- docs/howto/index.rst | 8 ++++---- docs/index.rst | 43 +--------------------------------------- docs/project/index.rst | 8 +++++--- docs/reference/index.rst | 6 +++--- docs/topics/index.rst | 8 +++++--- 5 files changed, 18 insertions(+), 55 deletions(-) diff --git a/docs/howto/index.rst b/docs/howto/index.rst index a5779f6c9..d7c83dd7a 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -4,7 +4,7 @@ How-to guides If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. .. toctree:: - :maxdepth: 2 + :titlesonly: faq cheatsheet @@ -12,7 +12,7 @@ If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. This guide will help you integrate websockets into a broader system. .. toctree:: - :maxdepth: 2 + :titlesonly: django @@ -20,7 +20,7 @@ The WebSocket protocol makes provisions for extending or specializing its features, which websockets supports fully. .. toctree:: - :maxdepth: 2 + :titlesonly: extensions @@ -29,7 +29,7 @@ features, which websockets supports fully. Once your application is ready, learn how to deploy it on various platforms. .. toctree:: - :maxdepth: 2 + :titlesonly: heroku kubernetes diff --git a/docs/index.rst b/docs/index.rst index 000c19ca7..30d01c2f8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,52 +40,11 @@ And here's an echo server: Do you like it? Let's dive in! -Tutorials ---------- - -If you're new to websockets, this is the place to start. - .. toctree:: - :maxdepth: 2 + :hidden: intro/index - -How-to guides -------------- - -These guides will help you build and deploy a websockets application. - -.. toctree:: - :maxdepth: 2 - howto/index - -Reference ---------- - -Find all the details you could ask for, and then some. - -.. toctree:: - :maxdepth: 2 - reference/index - -Topics ------- - -Get a deeper understanding of how websockets is built and why. - -.. toctree:: - :maxdepth: 2 - topics/index - -Project -------- - -This is about websockets-the-project rather than websockets-the-software. - -.. toctree:: - :maxdepth: 2 - project/index diff --git a/docs/project/index.rst b/docs/project/index.rst index 931fbe2d4..459146345 100644 --- a/docs/project/index.rst +++ b/docs/project/index.rst @@ -1,8 +1,10 @@ -Project -======= +About websockets +================ + +This is about websockets-the-project rather than websockets-the-software. .. toctree:: - :maxdepth: 2 + :titlesonly: changelog contributing diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 9fc4a0092..8d01c5b40 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -1,5 +1,5 @@ -Reference -========= +API reference +============= websockets provides complete client and server implementations, as shown in the :doc:`getting started guide <../intro/index>`. @@ -38,7 +38,7 @@ client and server connections. For convenience, common methods are documented both in the client API and server API. .. toctree:: - :maxdepth: 2 + :titlesonly: client server diff --git a/docs/topics/index.rst b/docs/topics/index.rst index e3b20d73f..120a3dd32 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -1,8 +1,10 @@ -Topics -====== +Topic guides +============ + +Get a deeper understanding of how websockets is built and why. .. toctree:: - :maxdepth: 2 + :titlesonly: deployment logging From f1d6345d2d88109f907e7f4e9d71f6204508e746 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 18 Aug 2021 08:23:33 +0200 Subject: [PATCH 0920/1539] Review logging doc. --- docs/topics/logging.rst | 50 +++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 4e91557fa..393efeb71 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -10,7 +10,7 @@ When you run a WebSocket client, your code calls coroutines provided by websockets. If an error occurs, websockets tells you by raising an exception. For example, -it raises a :exc:`~exception.ConnectionClosed` exception if the other side +it raises a :exc:`~exceptions.ConnectionClosed` exception if the other side closes the connection. When you run a WebSocket server, websockets accepts connections, performs the @@ -38,12 +38,15 @@ closed`_. .. _connection is established: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 .. _connection is closed: https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.4 -websockets doesn't log an event for every message because that would be -excessive for many applications exchanging small messages at a fast rate. -However, you could add this level of logging in your own code if necessary. +By default, websockets doesn't log an event for every message. That would be +excessive for many applications exchanging small messages at a fast rate. If +you need this level of detail, you could add logging in your own code. -See :ref:`log levels ` below for details of events logged by -websockets at each level. +Finally, you can enable debug logs to get details about everything websockets +is doing. This can be useful when developing clients as well as servers. + +See :ref:`log levels ` below for a list of events logged by +websockets logs at each log level. Configure logging ----------------- @@ -94,13 +97,14 @@ However, this technique runs into two problems: * Even with :meth:`str.format` style, you're restricted to attribute and index lookups, which isn't enough to implement some fairly simple requirements. -There's a better way. :func:`~server.serve` accepts a ``logger`` argument to -override the default :class:`~logging.Logger`. You can set ``logger`` to -a :class:`~logging.LoggerAdapter` that enriches logs. +There's a better way. :func:`~client.connect` and :func:`~server.serve` accept +a ``logger`` argument to override the default :class:`~logging.Logger`. You +can set ``logger`` to a :class:`~logging.LoggerAdapter` that enriches logs. -For example, if the server is behind a reverse proxy, ``remote_address`` gives +For example, if the server is behind a reverse +proxy, :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` gives the IP address of the proxy, which isn't useful. IP addresses of clients are -generally available in a HTTP header set by the proxy. +provided in a HTTP header set by the proxy. Here's how to include them in logs, assuming they're in the ``X-Forwarded-For`` header:: @@ -111,6 +115,7 @@ Here's how to include them in logs, assuming they're in the ) class LoggerAdapter(logging.LoggerAdapter): + """Add connection ID and client IP address to websockets logs.""" def process(self, msg, kwargs): try: websocket = kwargs["extra"]["websocket"] @@ -148,6 +153,7 @@ Finally, we populate the ``event_data`` custom attribute in log records with a :class:`~logging.LoggerAdapter`:: class LoggerAdapter(logging.LoggerAdapter): + """Add connection ID and client IP address to websockets logs.""" def process(self, msg, kwargs): try: websocket = kwargs["extra"]["websocket"] @@ -169,9 +175,9 @@ Disable logging --------------- If your application doesn't configure :mod:`logging`, Python outputs messages -of severity :data:`~logging.WARNING` and higher to :data:`~sys.stderr`. As a -consequence, you will see a message and a stack trace if a connection handler -coroutine crashes or if you hit a bug in websockets. +of severity ``WARNING`` and higher to :data:`~sys.stderr`. As a consequence, +you will see a message and a stack trace if a connection handler coroutine +crashes or if you hit a bug in websockets. If you want to disable this behavior for websockets, you can add a :class:`~logging.NullHandler`:: @@ -183,8 +189,8 @@ propagation to the root logger, or else its handlers could output logs:: logging.getLogger("websockets").propagate = False -Alternatively, you could set the log level to :data:`~logging.CRITICAL` for -websockets, as the highest level currently used is :data:`~logging.ERROR`:: +Alternatively, you could set the log level to ``CRITICAL`` for the +``"websockets"`` logger, as the highest level currently used is ``ERROR``:: logging.getLogger("websockets").setLevel(logging.CRITICAL) @@ -199,21 +205,21 @@ Log levels Here's what websockets logs at each level. -:attr:`~logging.ERROR` -...................... +``ERROR`` +......... * Exceptions raised by connection handler coroutines in servers * Exceptions resulting from bugs in websockets -:attr:`~logging.INFO` -..................... +``INFO`` +........ * Server starting and stopping * Server establishing and closing connections * Client reconnecting automatically -:attr:`~logging.DEBUG` -...................... +``DEBUG`` +......... * Changes to the state of connections * Handshake requests and responses From c7fc0d36bd8ea2aeb7c4321f53d208fb1297db85 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 20 Aug 2021 01:53:46 +1200 Subject: [PATCH 0921/1539] Speed up Python apply_mask 20x by using int.from_bytes/to_bytes (#1034) Speed up Python apply_mask 20x by using int.from_bytes/to_bytes This speeds up the Python version of utils.apply_mask about 20 times, using int.from_bytes so that the XOR is done in a single Python operation -- in other words, the loop over the bytes is in C rather than in Python. Note that it is a trade-off as it uses more memory: this version allocates roughly len(data) bytes for each of the intermediate values (e.g., data_int, mask_repeated, mask_int, the XOR result); whereas I believe the original version only allocates for the return value. Still, most websocket packets aren't huge, and I believe the massive speed gain here makes it worth it. (And people that use the speedups.c version won't be affected.) Obviously the speedups.c version is still significantly faster again, but this change makes the library more usable in environments when it's not feasible to use the C extension. Data Size ForLoop IntXor Speedups ------------------------------------ 1KB 78.6us 3.79us 151ns 1MB 79.7ms 4.38ms 55.4us I got these timings by using commands like the following (with the function call adjusted, and 1024 replaced with 1024*1024 as needed). python3 -m timeit \ -s 'from websockets.utils import apply_mask' \ -s 'data=b"x"*1024; mask=b"abcd"' \ 'apply_mask(data, mask)' This idea came from Will McGugan's blog post "Speeding up Websockets 60X": https://www.willmcgugan.com/blog/tech/post/speeding-up-websockets-60x/ That post contains an ever faster (about 50% faster) way to solve it using a pre-calculated XOR lookup table, but that pre-allocates a 64K-entry table at import time, which didn't seem ideal. Still, that is how aiohttp does it, so maybe it's worth considering: https://github.com/aio-libs/aiohttp/blob/6ec33c5d841c8e845c27ebdd9384bbf72651cbb8/aiohttp/http_websocket.py#L115-L140 The int.from_bytes approach is also the approach used by the websocket-client library: https://github.com/websocket-client/websocket-client/blob/5f32b3c0cfb836c016ad2a5f6caeff2978a6a16f/websocket/_abnf.py#L46-L50 Co-authored-by: Aymeric Augustin --- src/websockets/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index ffb706963..c6e4b788c 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -2,8 +2,8 @@ import base64 import hashlib -import itertools import secrets +import sys __all__ = ["accept_key", "apply_mask"] @@ -43,4 +43,7 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: if len(mask) != 4: raise ValueError("mask must contain 4 bytes") - return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask))) + data_int = int.from_bytes(data, sys.byteorder) + mask_repeated = mask * (len(data) // 4) + mask[: len(data) % 4] + mask_int = int.from_bytes(mask_repeated, sys.byteorder) + return (data_int ^ mask_int).to_bytes(len(data), sys.byteorder) From 4b10c2c8e0334535ae4874d95dc7fdd6392bf68a Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Wed, 25 Aug 2021 11:02:42 +0200 Subject: [PATCH 0922/1539] Fix typo in the FAQ. --- docs/howto/faq.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 394920b58..d43d5fd67 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -389,7 +389,7 @@ How do I keep idle connections open? websockets sends pings at 20 seconds intervals to keep the connection open. -In closes the connection if it doesn't get a pong within 20 seconds. +It closes the connection if it doesn't get a pong within 20 seconds. You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. From 9e75d1e6fc4edc53f27d546362711d3bd22291b3 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 27 Aug 2021 14:14:48 +1200 Subject: [PATCH 0923/1539] Fix docstring of parse_uri to reflect the exception type raised It doesn't raise a ValueError, but an InvalidURI error (see below). If we want the exception type to be a ValueError as well, you could make InvalidURI inherit from ValueError as well as WebSocketException. But just changing the doc seems like the right fix here. >>> from websockets.uri import parse_uri >>> parse_uri('foo://example.com') Traceback (most recent call last): File "/home/ben/w/websockets/src/websockets/uri.py", line 58, in parse_uri assert parsed.scheme in ["ws", "wss"] AssertionError The above exception was the direct cause of the following exception: Traceback (most recent call last): File "", line 1, in File "/home/ben/w/websockets/src/websockets/uri.py", line 63, in parse_uri raise exceptions.InvalidURI(uri) from exc websockets.exceptions.InvalidURI: foo://example.com isn't a valid URI >>> import sys >>> exc = sys.last_value >>> isinstance(exc, ValueError) False --- src/websockets/uri.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 1ab895a21..c99c3f16e 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -49,7 +49,8 @@ def parse_uri(uri: str) -> WebSocketURI: """ Parse and validate a WebSocket URI. - :raises ValueError: if ``uri`` isn't a valid WebSocket URI. + :raises ~websockets.exceptions.InvalidURI: if ``uri`` isn't a valid + WebSocket URI. """ parsed = urllib.parse.urlparse(uri) From 5e0f002152fd42f08eb8b9f050fff6211384e30b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Aug 2021 08:36:36 +0200 Subject: [PATCH 0924/1539] Switch from viewcode to linkcode. This makes docs builds noticeably faster as they no longer need to syntax-highlight and render all source code. --- docs/conf.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 04d4c0557..d5f527bf6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -5,7 +5,10 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html import datetime +import importlib +import inspect import os +import subprocess import sys # -- Path setup -------------------------------------------------------------- @@ -34,7 +37,7 @@ extensions = [ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", - "sphinx.ext.viewcode", + "sphinx.ext.linkcode", "sphinx_autodoc_typehints", "sphinx_copybutton", "sphinx_inline_tabs", @@ -55,6 +58,52 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +# Configure viewcode extension. +try: + git_sha1 = subprocess.run( + "git rev-parse --short HEAD", + capture_output=True, + shell=True, + check=True, + text=True, + ).stdout.strip() +except subprocess.SubprocessError as exc: + print("Cannot get git commit, disabling linkcode:", exc) + extensions.remove("sphinx.ext.linkcode") +else: + code_url = f"https://github.com/aaugustin/websockets/blob/{git_sha1}" + + +def linkcode_resolve(domain, info): + assert domain == "py" + + mod = importlib.import_module(info["module"]) + if "." in info["fullname"]: + objname, attrname = info["fullname"].split(".") + obj = getattr(mod, objname) + try: + # object is a method of a class + obj = getattr(obj, attrname) + except AttributeError: + # object is an attribute of a class + return None + else: + obj = getattr(mod, info["fullname"]) + + try: + file = inspect.getsourcefile(obj) + lines = inspect.getsourcelines(obj) + except TypeError: + # e.g. object is a typing.Union + return None + file = os.path.relpath(file, os.path.abspath("..")) + if not file.startswith("src/websockets"): + # e.g. object is a typing.NewType + return None + start, end = lines[1], lines[1] + len(lines[0]) - 1 + + return f"{code_url}/{file}#L{start}-L{end}" + # -- Options for HTML output ------------------------------------------------- From 1b0b2de59f13b79062a888a8d17ae5e9e5a60ef9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Aug 2021 22:44:09 +0200 Subject: [PATCH 0925/1539] Advertise support for client_max_window_bits by default. This doesn't change anything in the default setup (compression="deflate") because it sets client_max_window_bits to 12 explicitly. If a client is configured with extensions=[ClientPerMessageDeflateFactory()] this change gives the server the option to limit client_max_window_bits. Specifically, this makes such a client compatible with a websockets server using the default setup (compression="deflate"). --- src/websockets/extensions/permessage_deflate.py | 2 +- tests/legacy/test_client_server.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 59f019d53..a377abb55 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -286,7 +286,7 @@ def __init__( server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[Union[int, bool]] = None, + client_max_window_bits: Optional[Union[int, bool]] = True, compress_settings: Optional[Dict[str, Any]] = None, ) -> None: """ diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 3fcd0b044..016f08e73 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -834,7 +834,12 @@ def test_extension_client_rejection(self): ServerPerMessageDeflateFactory(), ] ) - @with_client("/extensions", extensions=[ClientPerMessageDeflateFactory()]) + @with_client( + "/extensions", + extensions=[ + ClientPerMessageDeflateFactory(client_max_window_bits=None), + ], + ) def test_extension_no_match_then_match(self): # The order requested by the client has priority. server_extensions = self.loop.run_until_complete(self.client.recv()) From 88ae5eb25fa3e8caa3f5d7be7517ec8d5c68847d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Sep 2021 20:50:03 +0200 Subject: [PATCH 0926/1539] Make error messages more error friendly. Avoid weird exception chaining to assertion errors. --- docs/project/changelog.rst | 11 +++++++++++ src/websockets/exceptions.py | 5 +++-- src/websockets/uri.py | 15 +++++++-------- tests/test_exceptions.py | 4 ++-- tests/test_uri.py | 5 +++++ 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 07942783f..b9329d7b5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -56,6 +56,15 @@ They may change at any time. If you raise :exc:`~exceptions.ConnectionClosed` or a subclass — rather than catch them when websockets raises them — you must change your code. +.. note:: + + **Version 10.0 adds a ``msg`` parameter to** ``InvalidURI.__init__`` **.** + + If you raise :exc:`~exceptions.InvalidURI` — rather than catch them when + websockets raises them — you must change your code. + +Also: + * Added compatibility with Python 3.10. * Added :func:`~websockets.broadcast` to send a message to many clients. @@ -150,6 +159,8 @@ They may change at any time. from websockets.client import connect from websockets.server import serve +Also: + * Added compatibility with Python 3.9. * Added support for IRIs in addition to URIs. diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b3462484f..6bbea324c 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -374,11 +374,12 @@ class InvalidURI(WebSocketException): """ - def __init__(self, uri: str) -> None: + def __init__(self, uri: str, msg: str) -> None: self.uri = uri + self.msg = msg def __str__(self) -> str: - return f"{self.uri} isn't a valid URI" + return f"{self.uri} isn't a valid URI: {self.msg}" class PayloadTooBig(WebSocketException): diff --git a/src/websockets/uri.py b/src/websockets/uri.py index c99c3f16e..397c23116 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -54,13 +54,12 @@ def parse_uri(uri: str) -> WebSocketURI: """ parsed = urllib.parse.urlparse(uri) - try: - assert parsed.scheme in ["ws", "wss"] - assert parsed.params == "" - assert parsed.fragment == "" - assert parsed.hostname is not None - except AssertionError as exc: - raise exceptions.InvalidURI(uri) from exc + if parsed.scheme not in ["ws", "wss"]: + raise exceptions.InvalidURI(uri, "scheme isn't ws or wss") + if parsed.hostname is None: + raise exceptions.InvalidURI(uri, "hostname isn't provided") + if parsed.fragment != "": + raise exceptions.InvalidURI(uri, "fragment identifier is meaningless") secure = parsed.scheme == "wss" host = parsed.hostname @@ -73,7 +72,7 @@ def parse_uri(uri: str) -> WebSocketURI: # urllib.parse.urlparse accepts URLs with a username but without a # password. This doesn't make sense for HTTP Basic Auth credentials. if parsed.password is None: - raise exceptions.InvalidURI(uri) + raise exceptions.InvalidURI(uri, "username provided without password") user_info = (parsed.username, parsed.password) try: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index e172cdd02..3ede25fdb 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -142,8 +142,8 @@ def test_str(self): "WebSocket connection isn't established yet", ), ( - InvalidURI("|"), - "| isn't a valid URI", + InvalidURI("|", "not at all!"), + "| isn't a valid URI: not at all!", ), ( PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), diff --git a/tests/test_uri.py b/tests/test_uri.py index a91bcb083..f937d2949 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -17,6 +17,10 @@ "ws://localhost/path?query", WebSocketURI(False, "localhost", 80, "/path?query", None), ), + ( + "ws://localhost/path;params", + WebSocketURI(False, "localhost", 80, "/path;params", None), + ), ( "WS://LOCALHOST/PATH?QUERY", WebSocketURI(False, "localhost", 80, "/PATH?QUERY", None), @@ -39,6 +43,7 @@ "https://localhost/", "ws://localhost/path#fragment", "ws://user@localhost/", + "ws:///path", ] From 27f6539abc678182970af621c007646403968f82 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 15:01:47 +0200 Subject: [PATCH 0927/1539] Review reference docs. Remove sphinx-autodoc-typehints and use native autodoc features instead, mainly because I can't find a workaround for: https://github.com/agronholm/sphinx-autodoc-typehints/issues/132 Re-introduce the "common" docs, else linking to send() and recv() is a mess. --- docs/conf.py | 33 +- docs/howto/cheatsheet.rst | 16 +- docs/howto/django.rst | 2 +- docs/howto/extensions.rst | 15 +- docs/howto/faq.rst | 16 +- docs/intro/index.rst | 2 +- docs/project/changelog.rst | 108 ++--- docs/reference/client.rst | 62 +-- docs/reference/common.rst | 51 ++ docs/reference/exceptions.rst | 6 + docs/reference/extensions.rst | 36 +- docs/reference/index.rst | 46 +- docs/reference/limitations.rst | 4 +- docs/reference/server.rst | 84 ++-- docs/reference/types.rst | 20 + docs/reference/utilities.rst | 36 +- docs/requirements.txt | 1 - docs/spelling_wordlist.txt | 2 + docs/topics/broadcast.rst | 24 +- docs/topics/compression.rst | 7 +- docs/topics/deployment.rst | 2 +- docs/topics/design.rst | 38 +- docs/topics/memory.rst | 10 +- docs/topics/timeouts.rst | 9 +- src/websockets/__init__.py | 2 + src/websockets/__main__.py | 2 +- src/websockets/auth.py | 2 + src/websockets/client.py | 40 +- src/websockets/connection.py | 10 +- src/websockets/datastructures.py | 10 +- src/websockets/exceptions.py | 31 +- src/websockets/extensions/base.py | 100 ++-- .../extensions/permessage_deflate.py | 57 +-- src/websockets/frames.py | 94 ++-- src/websockets/headers.py | 69 +-- src/websockets/http11.py | 116 +++-- src/websockets/imports.py | 14 +- src/websockets/legacy/auth.py | 71 +-- src/websockets/legacy/client.py | 288 +++++------- src/websockets/legacy/framing.py | 65 ++- src/websockets/legacy/handshake.py | 62 +-- src/websockets/legacy/http.py | 22 +- src/websockets/legacy/protocol.py | 380 ++++++++++----- src/websockets/legacy/server.py | 442 ++++++++---------- src/websockets/server.py | 109 +++-- src/websockets/streams.py | 36 +- src/websockets/typing.py | 21 +- src/websockets/uri.py | 30 +- src/websockets/utils.py | 8 +- 49 files changed, 1467 insertions(+), 1244 deletions(-) create mode 100644 docs/reference/common.rst create mode 100644 docs/reference/exceptions.rst create mode 100644 docs/reference/types.rst diff --git a/docs/conf.py b/docs/conf.py index d5f527bf6..d9e3cd598 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,6 +31,28 @@ # -- General configuration --------------------------------------------------- +nitpicky = True + +nitpick_ignore = [ + # topics/design.rst discusses undocumented APIs + ("py:meth", "client.WebSocketClientProtocol.handshake"), + ("py:meth", "server.WebSocketServerProtocol.handshake"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.is_client"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.messages"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.close_connection"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.close_connection_task"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.transfer_data"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.transfer_data_task"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_open"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.ensure_open"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.fail_connection"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_lost"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.read_message"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.write_frame"), +] + # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. @@ -38,7 +60,7 @@ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", "sphinx.ext.linkcode", - "sphinx_autodoc_typehints", + "sphinx.ext.napoleon", "sphinx_copybutton", "sphinx_inline_tabs", "sphinxcontrib.spelling", @@ -46,6 +68,15 @@ "sphinxext.opengraph", ] +autodoc_typehints = "description" + +autodoc_typehints_description_target = "documented" + +# Workaround for https://github.com/sphinx-doc/sphinx/issues/9560 +from sphinx.domains.python import PythonDomain +assert PythonDomain.object_types['data'].roles == ('data', 'obj') +PythonDomain.object_types['data'].roles = ('data', 'class', 'obj') + intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} spelling_show_suggestions = True diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst index edfb00baa..95b551f67 100644 --- a/docs/howto/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -26,27 +26,27 @@ Server :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't needed in general. -* Create a server with :func:`~legacy.server.serve` which is similar to asyncio's - :meth:`~asyncio.AbstractEventLoop.create_server`. You can also use it as an - asynchronous context manager. +* Create a server with :func:`~server.serve` which is similar to asyncio's + :meth:`~asyncio.loop.create_server`. You can also use it as an asynchronous + context manager. * The server takes care of establishing connections, then lets the handler execute the application logic, and finally closes the connection after the handler exits normally or with an exception. * For advanced customization, you may subclass - :class:`~legacy.server.WebSocketServerProtocol` and pass either this subclass or + :class:`~server.WebSocketServerProtocol` and pass either this subclass or a factory function as the ``create_protocol`` argument. Client ------ -* Create a client with :func:`~legacy.client.connect` which is similar to asyncio's - :meth:`~asyncio.BaseEventLoop.create_connection`. You can also use it as an +* Create a client with :func:`~client.connect` which is similar to asyncio's + :meth:`~asyncio.loop.create_connection`. You can also use it as an asynchronous context manager. * For advanced customization, you may subclass - :class:`~legacy.server.WebSocketClientProtocol` and pass either this subclass or + :class:`~client.WebSocketClientProtocol` and pass either this subclass or a factory function as the ``create_protocol`` argument. * Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and @@ -57,7 +57,7 @@ Client :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't needed in general. -* If you aren't using :func:`~legacy.client.connect` as a context manager, call +* If you aren't using :func:`~client.connect` as a context manager, call :meth:`~legacy.protocol.WebSocketCommonProtocol.close` to terminate the connection. .. _debugging: diff --git a/docs/howto/django.rst b/docs/howto/django.rst index c87c3821c..67bf582f1 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -124,7 +124,7 @@ support asynchronous I/O. It would block the event loop if it didn't run in a separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. -Finally, we start a server with :func:`~websockets.serve`. +Finally, we start a server with :func:`~websockets.server.serve`. We're ready to test! diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index 9c49de172..2baead3f0 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -4,8 +4,10 @@ Writing an extension .. currentmodule:: websockets.extensions During the opening handshake, WebSocket clients and servers negotiate which -extensions will be used with which parameters. Then each frame is processed by -extensions before being sent or after being received. +extensions_ will be used with which parameters. Then each frame is processed +by extensions before being sent or after being received. + +.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 As a consequence, writing an extension requires implementing several classes: @@ -23,11 +25,8 @@ As a consequence, writing an extension requires implementing several classes: Extensions are initialized by extension factories, so they don't need to be part of the public API of an extension. -websockets provides abstract base classes for extension factories and -extensions. See the API documentation for details on their methods: - -* :class:`ClientExtensionFactory` and :class:`ServerExtensionFactory` for - extension factories, -* :class:`Extension` for extensions. +websockets provides base classes for extension factories and extensions. +See :class:`ClientExtensionFactory`, :class:`ServerExtensionFactory`, +and :class:`Extension` for details. diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index d43d5fd67..ec657bb06 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -5,7 +5,7 @@ FAQ .. note:: - Many questions asked in :mod:`websockets`' issue tracker are actually + Many questions asked in websockets' issue tracker are actually about :mod:`asyncio`. Python's documentation about `developing with asyncio`_ is a good complement. @@ -99,13 +99,13 @@ How do I get access HTTP headers, for example cookies? ...................................................... To access HTTP headers during the WebSocket handshake, you can override -:attr:`~legacy.server.WebSocketServerProtocol.process_request`:: +:attr:`~server.WebSocketServerProtocol.process_request`:: async def process_request(self, path, request_headers): cookies = request_header["Cookie"] Once the connection is established, they're available in -:attr:`~legacy.protocol.WebSocketServerProtocol.request_headers`:: +:attr:`~server.WebSocketServerProtocol.request_headers`:: async def handler(websocket, path): cookies = websocket.request_headers["Cookie"] @@ -123,7 +123,7 @@ How do I set which IP addresses my server listens to? Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. -:func:`serve` accepts the same arguments as +:func:`~server.serve` accepts the same arguments as :meth:`~asyncio.loop.create_server`. How do I close a connection properly? @@ -143,7 +143,7 @@ Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the -:attr:`~legacy.server.WebSocketServerProtocol.process_request` hook. +:attr:`~server.WebSocketServerProtocol.process_request` hook. If you need more, pick a HTTP server and run it separately. @@ -169,7 +169,7 @@ change it to:: How do I close a connection properly? ..................................... -The easiest is to use :func:`connect` as a context manager:: +The easiest is to use :func:`~client.connect` as a context manager:: async with connect(...) as websocket: ... @@ -196,7 +196,7 @@ How do I disable TLS/SSL certificate verification? Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. -:func:`connect` accepts the same arguments as +:func:`~client.connect` accepts the same arguments as :meth:`~asyncio.loop.create_connection`. asyncio usage @@ -449,4 +449,4 @@ I'm having problems with threads You shouldn't use threads. Use tasks instead. -:func:`~asyncio.AbstractEventLoop.call_soon_threadsafe` may help. +:meth:`~asyncio.loop.call_soon_threadsafe` may help. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index ff3ae0ffd..c8426719c 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -82,7 +82,7 @@ This client needs a context because the server uses a self-signed certificate. A client connecting to a secure WebSocket server with a valid certificate (i.e. signed by a CA that your Python installation trusts) can simply pass -``ssl=True`` to :func:`connect` instead of building a context. +``ssl=True`` to :func:`~client.connect` instead of building a context. Browser-based example --------------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b9329d7b5..073514ecd 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -37,10 +37,10 @@ They may change at any time. .. note:: **Version 10.0 enables a timeout of 10 seconds on** - :func:`~legacy.client.connect` **by default.** + :func:`~client.connect` **by default.** You can adjust the timeout with the ``open_timeout`` parameter. Set it to - ``None`` to disable the timeout entirely. + :obj:`None` to disable the timeout entirely. .. note:: @@ -51,7 +51,8 @@ They may change at any time. .. note:: - **Version 10.0 changes parameters of** ``ConnectionClosed.__init__`` **.** + **Version 10.0 changes arguments of** + :exc:`~exceptions.ConnectionClosed` **.** If you raise :exc:`~exceptions.ConnectionClosed` or a subclass — rather than catch them when websockets raises them — you must change your code. @@ -70,27 +71,30 @@ Also: * Added :func:`~websockets.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using - :func:`~legacy.client.connect` as an asynchronous iterator. + :func:`~client.connect` as an asynchronous iterator. -* Added ``open_timeout`` to :func:`~legacy.client.connect`. +* Added ``open_timeout`` to :func:`~client.connect`. * Improved logging. -* Provided additional information in :exc:`ConnectionClosed` exceptions. +* Provided additional information in :exc:`~exceptions.ConnectionClosed` + exceptions. * Optimized default compression settings to reduce memory usage. * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. -* Fixed handling of relative redirects in :func:`~legacy.client.connect`. +* Fixed handling of relative redirects in :func:`~client.connect`. + +* Improved API documentation. 9.1 ... *May 27, 2021* -.. note:: +.. caution:: **Version 9.1 fixes a security issue introduced in version 8.0.** @@ -197,7 +201,7 @@ Also: *July 31, 2019* * Restored the ability to pass a socket with the ``sock`` parameter of - :func:`~legacy.server.serve`. + :func:`~server.serve`. * Removed an incorrect assertion when a connection drops. @@ -224,9 +228,9 @@ Also: Previously, it could be a function or a coroutine. If you're passing a ``process_request`` argument to - :func:`~legacy.server.serve` - or :class:`~legacy.server.WebSocketServerProtocol`, or if you're overriding - :meth:`~legacy.server.WebSocketServerProtocol.process_request` in a subclass, + :func:`~server.serve` or :class:`~server.WebSocketServerProtocol`, or if + you're overriding + :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, define it with ``async def`` instead of ``def``. For backwards compatibility, functions are still mostly supported, but @@ -274,15 +278,15 @@ Also: :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP +* Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth on the server side. -* :func:`~legacy.client.connect` handles redirects from the server during the +* :func:`~client.connect` handles redirects from the server during the handshake. -* :func:`~legacy.client.connect` supports overriding ``host`` and ``port``. +* :func:`~client.connect` supports overriding ``host`` and ``port``. -* Added :func:`~legacy.client.unix_connect` for connecting to Unix sockets. +* Added :func:`~client.unix_connect` for connecting to Unix sockets. * Improved support for sending fragmented messages by accepting asynchronous iterators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. @@ -292,10 +296,10 @@ Also: as a workaround, you can remove it. * Changed :meth:`WebSocketServer.close() - ` to perform a proper closing handshake + ` to perform a proper closing handshake instead of failing the connection. -* Avoided a crash when a ``extra_headers`` callable returns ``None``. +* Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. * Improved error messages when HTTP parsing fails. @@ -327,7 +331,7 @@ Also: **Version 7.0 changes how a server terminates connections when it's closed with** :meth:`WebSocketServer.close() - ` **.** + ` **.** Previously, connections handlers were canceled. Now, connections are closed with close code 1001 (going away). From the perspective of the @@ -345,7 +349,7 @@ Also: .. note:: **Version 7.0 renames the** ``timeout`` **argument of** - :func:`~legacy.server.serve` **and** :func:`~legacy.client.connect` **to** + :func:`~server.serve` **and** :func:`~client.connect` **to** ``close_timeout`` **.** This prevents confusion with ``ping_timeout``. @@ -375,11 +379,11 @@ Also: Also: * Added ``process_request`` and ``select_subprotocol`` arguments to - :func:`~legacy.server.serve` and - :class:`~legacy.server.WebSocketServerProtocol` to customize - :meth:`~legacy.server.WebSocketServerProtocol.process_request` and - :meth:`~legacy.server.WebSocketServerProtocol.select_subprotocol` without - subclassing :class:`~legacy.server.WebSocketServerProtocol`. + :func:`~server.serve` and + :class:`~server.WebSocketServerProtocol` to customize + :meth:`~server.WebSocketServerProtocol.process_request` and + :meth:`~server.WebSocketServerProtocol.select_subprotocol` without + subclassing :class:`~server.WebSocketServerProtocol`. * Added support for sending fragmented messages. @@ -389,7 +393,7 @@ Also: * Added an interactive client: ``python -m websockets ``. * Changed the ``origins`` argument to represent the lack of an origin with - ``None`` rather than ``''``. + :obj:`None` rather than ``''``. * Fixed a data loss bug in :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: @@ -409,7 +413,7 @@ Also: **Version 6.0 introduces the** :class:`~datastructures.Headers` **class for managing HTTP headers and changes several public APIs:** - * :meth:`~legacy.server.WebSocketServerProtocol.process_request` now + * :meth:`~server.WebSocketServerProtocol.process_request` now receives a :class:`~datastructures.Headers` instead of a ``http.client.HTTPMessage`` in the ``request_headers`` argument. @@ -445,14 +449,14 @@ Also: *May 24, 2018* * Fixed a regression in 5.0 that broke some invocations of - :func:`~legacy.server.serve` and :func:`~legacy.client.connect`. + :func:`~server.serve` and :func:`~client.connect`. 5.0 ... *May 22, 2018* -.. note:: +.. caution:: **Version 5.0 fixes a security issue introduced in version 4.0.** @@ -465,14 +469,14 @@ Also: .. note:: **Version 5.0 adds a** ``user_info`` **field to the return value of** - :func:`~uri.parse_uri` **and** :class:`~uri.WebSocketURI` **.** + ``parse_uri`` **and** ``WebSocketURI`` **.** - If you're unpacking :class:`~uri.WebSocketURI` into four variables, adjust - your code to account for that fifth field. + If you're unpacking ``WebSocketURI`` into four variables, adjust your code + to account for that fifth field. Also: -* :func:`~legacy.client.connect` performs HTTP Basic Auth when the URI contains +* :func:`~client.connect` performs HTTP Basic Auth when the URI contains credentials. * Iterating on incoming messages no longer raises an exception when the @@ -481,7 +485,7 @@ Also: * A plain HTTP request now receives a 426 Upgrade Required response and doesn't log a stack trace. -* :func:`~legacy.server.unix_serve` can be used as an asynchronous context +* :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property @@ -536,7 +540,7 @@ Also: Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~legacy.server.serve` or :func:`~legacy.client.connect`. + :func:`~server.serve` or :func:`~client.connect`. .. note:: @@ -549,10 +553,10 @@ Also: * :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. -* Added :func:`~legacy.server.unix_serve` for listening on Unix sockets. +* Added :func:`~server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the - return value of :func:`~legacy.server.serve`. +* Added the :attr:`~server.WebSocketServer.sockets` attribute to the + return value of :func:`~server.serve`. * Reorganized and extended documentation. @@ -572,15 +576,15 @@ Also: *August 20, 2017* -* Renamed :func:`~legacy.server.serve` and :func:`~legacy.client.connect`'s +* Renamed :func:`~server.serve` and :func:`~client.connect`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. -* :func:`~legacy.server.serve` can be used as an asynchronous context manager +* :func:`~server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with - :meth:`~legacy.server.WebSocketServerProtocol.process_request`. + :meth:`~server.WebSocketServerProtocol.process_request`. * Made read and write buffer sizes configurable. @@ -588,10 +592,10 @@ Also: * Added an optional C extension to speed up low-level operations. -* An invalid response status code during :func:`~legacy.client.connect` now +* An invalid response status code during :func:`~client.connect` now raises :class:`~exceptions.InvalidStatusCode`. -* Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer +* Providing a ``sock`` argument to :func:`~client.connect` no longer crashes. 3.3 @@ -611,7 +615,7 @@ Also: *August 17, 2016* * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. + :func:`~client.connect` and :func:`~server.serve`. * Made server shutdown more robust. @@ -637,8 +641,8 @@ Also: **If you're upgrading from 2.x or earlier, please read this carefully.** :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` used to return - ``None`` when the connection was closed. This required checking the return - value of every call:: + :obj:`None` when the connection was closed. This required checking the + return value of every call:: message = await websocket.recv() if message is None: @@ -655,14 +659,14 @@ Also: In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to - :func:`~legacy.server.serve`, :func:`~legacy.client.connect`, - :class:`~legacy.server.WebSocketServerProtocol`, or - :class:`~legacy.client.WebSocketClientProtocol`. ``legacy_recv`` isn't + :func:`~server.serve`, :func:`~client.connect`, + :class:`~server.WebSocketServerProtocol`, or + :class:`~client.WebSocketClientProtocol`. ``legacy_recv`` isn't documented in their signatures but isn't scheduled for deprecation either. Also: -* :func:`~legacy.client.connect` can be used as an asynchronous context +* :func:`~client.connect` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Updated documentation with ``await`` and ``async`` syntax from Python 3.5. @@ -732,8 +736,8 @@ Also: * Added support for subprotocols. -* Added ``loop`` argument to :func:`~legacy.client.connect` and - :func:`~legacy.server.serve`. +* Added ``loop`` argument to :func:`~client.connect` and + :func:`~server.serve`. 2.3 ... diff --git a/docs/reference/client.rst b/docs/reference/client.rst index 84f66a19a..eaa6cdd76 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -6,67 +6,55 @@ Client Opening a connection -------------------- - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None, **kwds) + .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Using a connection ------------------ - .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None) + .. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) - .. attribute:: id - - UUID for the connection. - - Useful for identifying connections in logs. - - .. autoattribute:: local_address - - .. autoattribute:: remote_address - - .. autoattribute:: open - - .. autoattribute:: closed - - .. attribute:: path + .. automethod:: recv - Path of the HTTP request. + .. automethod:: send - Available once the connection is open. + .. automethod:: close - .. attribute:: request_headers + .. automethod:: wait_closed - HTTP request headers as a :class:`~websockets.http.Headers` instance. + .. automethod:: ping - Available once the connection is open. + .. automethod:: pong - .. attribute:: response_headers + WebSocket connection objects also provide these attributes: - HTTP response headers as a :class:`~websockets.http.Headers` instance. + .. autoattribute:: id - Available once the connection is open. + .. autoproperty:: local_address - .. attribute:: subprotocol + .. autoproperty:: remote_address - Subprotocol, if one was negotiated. + .. autoproperty:: open - Available once the connection is open. + .. autoproperty:: closed - .. autoattribute:: close_code + The following attributes are available after the opening handshake, + once the WebSocket connection is open: - .. autoattribute:: close_reason + .. autoattribute:: path - .. automethod:: recv + .. autoattribute:: request_headers - .. automethod:: send + .. autoattribute:: response_headers - .. automethod:: ping + .. autoattribute:: subprotocol - .. automethod:: pong + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: - .. automethod:: close + .. autoproperty:: close_code - .. automethod:: wait_closed + .. autoproperty:: close_reason diff --git a/docs/reference/common.rst b/docs/reference/common.rst new file mode 100644 index 000000000..3b9f34a57 --- /dev/null +++ b/docs/reference/common.rst @@ -0,0 +1,51 @@ +Both sides +========== + +.. automodule:: websockets.legacy.protocol + + Using a connection + ------------------ + + .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst new file mode 100644 index 000000000..907a650d2 --- /dev/null +++ b/docs/reference/exceptions.rst @@ -0,0 +1,6 @@ +Exceptions +========== + +.. automodule:: websockets.exceptions + :members: + diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index bae583a21..a70f1b1e5 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -6,7 +6,7 @@ Extensions The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public -specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. +specification, WebSocket Per-Message Deflate. .. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name @@ -16,21 +16,45 @@ Per-Message Deflate .. automodule:: websockets.extensions.permessage_deflate + :mod:`websockets.extensions.permessage_deflate` implements WebSocket + Per-Message Deflate. + + This extension is specified in :rfc:`7692`. + + Refer to the :doc:`topic guide on compression <../topics/compression>` to + learn more about tuning compression settings. + .. autoclass:: ClientPerMessageDeflateFactory .. autoclass:: ServerPerMessageDeflateFactory -Abstract classes ----------------- +Base classes +------------ .. automodule:: websockets.extensions + :mod:`websockets.extensions` defines base classes for implementing + extensions. + + Refer to the :doc:`how-to guide on extensions <../howto/extensions>` to + learn more about writing an extension. + .. autoclass:: Extension - :members: + + .. autoattribute:: name + + .. automethod:: decode + + .. automethod:: encode .. autoclass:: ClientExtensionFactory - :members: + + .. autoattribute:: name + + .. automethod:: get_request_params + + .. automethod:: process_response_params .. autoclass:: ServerExtensionFactory - :members: + .. automethod:: process_request_params diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 8d01c5b40..385beab29 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -1,24 +1,27 @@ API reference ============= -websockets provides complete client and server implementations, as shown in +.. currentmodule:: websockets + +websockets provides client and server implementations, as shown in the :doc:`getting started guide <../intro/index>`. The process for opening and closing a WebSocket connection depends on which side you're implementing. -* On the client side, connecting to a server with :class:`~websockets.connect` +* On the client side, connecting to a server with :func:`~client.connect` yields a connection object that provides methods for interacting with the connection. Your code can open a connection, then send or receive messages. - If you use :class:`~websockets.connect` as an asynchronous context manager, + If you use :func:`~client.connect` as an asynchronous context manager, then websockets closes the connection on exit. If not, then your code is responsible for closing the connection. -* On the server side, :class:`~websockets.serve` starts listening for client - connections and yields an server object that supports closing the server. +* On the server side, :func:`~server.serve` starts listening for client + connections and yields an server object that you can use to shut down + the server. - Then, when clients connects, the server initializes a connection object and + Then, when a client connects, the server initializes a connection object and passes it to a handler coroutine, which is where your code can send or receive messages. This pattern is called `inversion of control`_. It's common in frameworks implementing servers. @@ -29,28 +32,35 @@ side you're implementing. .. _inversion of control: https://en.wikipedia.org/wiki/Inversion_of_control Once the connection is open, the WebSocket protocol is symmetrical, except for -low-level details that websockets manages under the hood. The same methods are -available on client connections created with :class:`~websockets.connect` and -on server connections passed to the connection handler in the arguments. +low-level details that websockets manages under the hood. The same methods +are available on client connections created with :func:`~client.connect` and +on server connections received in argument by the connection handler +of :func:`~server.serve`. -At this point, websockets provides the same API — and uses the same code — for -client and server connections. For convenience, common methods are documented -both in the client API and server API. +Since websockets provides the same API — and uses the same code — for client +and server connections, common methods are documented in a "Both sides" page. .. toctree:: :titlesonly: client server - extensions + common utilities + exceptions + types + extensions limitations -All public APIs can be imported from the :mod:`websockets` package, unless -noted otherwise. This convenience feature is incompatible with static code -analysis tools such as mypy_, though. +Public API documented in the API reference are subject to the +:ref:`backwards-compatibility policy `. + +Anything that isn't listed in the API reference is a private API. There's no +guarantees of behavior or backwards-compatibility for private APIs. + +For convenience, many public APIs can be imported from the ``websockets`` +package. This feature is incompatible with static code analysis tools such as +mypy_, though. If you're using such tools, use the full import path. .. _mypy: https://github.com/python/mypy -Anything that isn't listed in this API documentation is a private API. There's -no guarantees of behavior or backwards-compatibility for private APIs. diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst index 81f1445b5..3304bdb8c 100644 --- a/docs/reference/limitations.rst +++ b/docs/reference/limitations.rst @@ -1,12 +1,14 @@ Limitations =========== +.. currentmodule:: websockets + Client ------ The client doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -`mandated by RFC 6455`_. However, :func:`~websockets.connect()` isn't the +`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. .. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 667c0b9d0..0a5a060f3 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -6,10 +6,10 @@ Server Starting a server ----------------- - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None, **kwds) + .. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Stopping a server @@ -17,92 +17,80 @@ Server .. autoclass:: WebSocketServer - .. autoattribute:: sockets - .. automethod:: close + .. automethod:: wait_closed + .. autoattribute:: sockets + Using a connection ------------------ - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None) - - .. attribute:: id - - UUID for the connection. - - Useful for identifying connections in logs. + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) - .. autoattribute:: local_address - - .. autoattribute:: remote_address - - .. autoattribute:: open + .. automethod:: recv - .. autoattribute:: closed + .. automethod:: send - .. attribute:: path + .. automethod:: close - Path of the HTTP request. + .. automethod:: wait_closed - Available once the connection is open. + .. automethod:: ping - .. attribute:: request_headers + .. automethod:: pong - HTTP request headers as a :class:`~websockets.http.Headers` instance. + You can customize the opening handshake in a subclass by overriding these methods: - Available once the connection is open. + .. automethod:: process_request - .. attribute:: response_headers + .. automethod:: select_subprotocol - HTTP response headers as a :class:`~websockets.http.Headers` instance. + WebSocket connection objects also provide these attributes: - Available once the connection is open. + .. autoattribute:: id - .. attribute:: subprotocol + .. autoproperty:: local_address - Subprotocol, if one was negotiated. + .. autoproperty:: remote_address - Available once the connection is open. + .. autoproperty:: open - .. autoattribute:: close_code + .. autoproperty:: closed - .. autoattribute:: close_reason + The following attributes are available after the opening handshake, + once the WebSocket connection is open: - .. automethod:: process_request + .. autoattribute:: path - .. automethod:: select_subprotocol + .. autoattribute:: request_headers - .. automethod:: recv + .. autoattribute:: response_headers - .. automethod:: send + .. autoattribute:: subprotocol - .. automethod:: ping + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: - .. automethod:: pong + .. autoproperty:: close_code - .. automethod:: close + .. autoproperty:: close_reason - .. automethod:: wait_closed Basic authentication -------------------- .. automodule:: websockets.auth + websockets supports HTTP Basic Authentication according to + :rfc:`7235` and :rfc:`7617`. + .. autofunction:: basic_auth_protocol_factory .. autoclass:: BasicAuthWebSocketServerProtocol - .. attribute:: realm - - Scope of protection. - - If provided, it should contain only ASCII characters because the - encoding of non-ASCII characters is undefined. - - .. attribute:: username + .. autoattribute:: realm - Username of the authenticated user. + .. autoattribute:: username .. automethod:: check_credentials diff --git a/docs/reference/types.rst b/docs/reference/types.rst new file mode 100644 index 000000000..3dab553af --- /dev/null +++ b/docs/reference/types.rst @@ -0,0 +1,20 @@ +Types +===== + +.. autodata:: websockets.datastructures.HeadersLike + +.. automodule:: websockets.typing + + .. autodata:: Data + + .. autodata:: LoggerLike + + .. autodata:: Origin + + .. autodata:: Subprotocol + + .. autodata:: ExtensionName + + .. autodata:: ExtensionParameter + + diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index e7f489fbd..dc6333847 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -6,36 +6,32 @@ Broadcast .. autofunction:: websockets.broadcast -Data structures ---------------- +WebSocket events +---------------- -.. automodule:: websockets.datastructures - - .. autoclass:: Headers +.. automodule:: websockets.frames - .. autodata:: HeadersLike + .. autoclass:: Frame - .. autoexception:: MultipleValuesError + .. autoclass:: Opcode -Exceptions ----------- + .. autoclass:: Close -.. automodule:: websockets.exceptions - :members: +HTTP events +----------- -Types ------ +.. automodule:: websockets.http11 -.. automodule:: websockets.typing + .. autoclass:: Request - .. autodata:: Data + .. autoclass:: Response - .. autodata:: LoggerLike +.. automodule:: websockets.datastructures - .. autodata:: Origin + .. autoclass:: Headers - .. autodata:: Subprotocol + .. automethod:: get_all - .. autodata:: ExtensionName + .. automethod:: raw_items - .. autodata:: ExtensionParameter + .. autoexception:: MultipleValuesError diff --git a/docs/requirements.txt b/docs/requirements.txt index b9c371228..bcd1d7114 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,6 @@ furo sphinx sphinx-autobuild -sphinx-autodoc-typehints sphinx-copybutton sphinx-inline-tabs sphinxcontrib-spelling diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 3d05752d5..b57d3c77f 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -35,6 +35,7 @@ linkerd liveness lookups MiB +mypy nginx permessage pid @@ -59,6 +60,7 @@ tox unregister uple uvicorn +uvloop virtualenv WebSocket websocket diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index f9cd9e281..a90cc2d70 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -1,13 +1,13 @@ Broadcasting messages ===================== -.. currentmodule: websockets +.. currentmodule:: websockets .. note:: If you just want to send a message to all connected clients, use - :func:`~websockets.broadcast`. + :func:`broadcast`. If you want to learn about its design in depth, continue reading this document. @@ -16,7 +16,7 @@ WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. Let's explore options for broadcasting a message, explain the design -of :func:`~websockets.broadcast`, and discuss alternatives. +of :func:`broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all @@ -122,7 +122,7 @@ connections before the write buffer has time to fill up. Don't set extreme ``write_limit``, ``ping_interval``, and ``ping_timeout`` values to ensure that this condition holds. Set reasonable values and use the -built-in :func:`~websockets.broadcast` function instead. +built-in :func:`broadcast` function instead. The concurrent way ------------------ @@ -207,11 +207,11 @@ If a client gets too far behind, eventually it reaches the limit defined by ``ping_timeout`` and websockets terminates the connection. You can read the discussion of :doc:`keepalive and timeouts <./timeouts>` for details. -How :func:`~websockets.broadcast` works ---------------------------------------- +How :func:`broadcast` works +--------------------------- -The built-in :func:`~websockets.broadcast` function is similar to the naive -way. The main difference is that it doesn't apply backpressure. +The built-in :func:`broadcast` function is similar to the naive way. The main +difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. @@ -321,9 +321,9 @@ the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- -The built-in :func:`~websockets.broadcast` function sends all messages without -yielding control to the event loop. So does the naive way when the network -and clients are fast and reliable. +The built-in :func:`broadcast` function sends all messages without yielding +control to the event loop. So does the naive way when the network and clients +are fast and reliable. For each client, a WebSocket frame is prepared and sent to the network. This is the minimum amount of work required to broadcast a message. @@ -343,7 +343,7 @@ However, this isn't possible in general for two reasons: All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower -than the built-in :func:`~websockets.broadcast` function. +than the built-in :func:`broadcast` function. There is no major difference between the performance of per-message queues and publish–subscribe. diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index f78e32748..d40c4257d 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -1,6 +1,8 @@ Compression =========== +.. currentmodule:: websockets.extensions.permessage_deflate + Most WebSocket servers exchange JSON messages because they're convenient to parse and serialize in a browser. These messages contain text data and tend to be repetitive. @@ -29,9 +31,8 @@ If you want to disable compression, set ``compression=None``:: websockets.serve(..., compression=None) If you want to customize compression settings, you can enable the Per-Message -Deflate extension explicitly with -:class:`~permessage_deflate.ClientPerMessageDeflateFactory` or -:class:`~permessage_deflate.ServerPerMessageDeflateFactory`:: +Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or +:class:`ServerPerMessageDeflateFactory`:: import websockets from websockets.extensions import permessage_deflate diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index d30c5568e..ac0a8ed4c 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -78,7 +78,7 @@ Option 2 almost always combines with option 3. How do I start a process? ......................... -Run a Python program that invokes :func:`~serve`. That's it. +Run a Python program that invokes :func:`~server.serve`. That's it. Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're alternatives to websockets, not complements. diff --git a/docs/topics/design.rst b/docs/topics/design.rst index 2c9d505aa..b5c55afc9 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -35,7 +35,7 @@ Transitions happen in the following places: :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` which runs when the :ref:`opening handshake ` completes and the WebSocket connection is established — not to be confused with - :meth:`~asyncio.Protocol.connection_made` which runs when the TCP connection + :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP connection is established; - ``OPEN -> CLOSING``: in :meth:`~legacy.protocol.WebSocketCommonProtocol.write_frame` immediately before @@ -58,7 +58,7 @@ connection lifecycle on the client side. :target: _images/lifecycle.svg The lifecycle is identical on the server side, except inversion of control -makes the equivalent of :meth:`~legacy.client.connect` implicit. +makes the equivalent of :meth:`~client.connect` implicit. Coroutines shown in green are called by the application. Multiple coroutines may interact with the WebSocket connection concurrently. @@ -113,7 +113,7 @@ Opening handshake ----------------- websockets performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~legacy.client.connect` executes it +connection. On the client side, :meth:`~client.connect` executes it before returning the protocol to the caller. On the server side, it's executed before passing the protocol to the ``ws_handler`` coroutine handling the connection. @@ -123,26 +123,26 @@ request and the server replies with an HTTP Switching Protocols response — websockets aims at keeping the implementation of both sides consistent with one another. -On the client side, :meth:`~legacy.client.WebSocketClientProtocol.handshake`: +On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: - builds a HTTP request based on the ``uri`` and parameters passed to - :meth:`~legacy.client.connect`; + :meth:`~client.connect`; - writes the HTTP request to the network; - reads a HTTP response from the network; - checks the HTTP response, validates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - moves to the ``OPEN`` state. -On the server side, :meth:`~legacy.server.WebSocketServerProtocol.handshake`: +On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: - reads a HTTP request from the network; -- calls :meth:`~legacy.server.WebSocketServerProtocol.process_request` which may +- calls :meth:`~server.WebSocketServerProtocol.process_request` which may abort the WebSocket handshake and return a HTTP response instead; this hook only makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds a HTTP response based on the above and parameters passed to - :meth:`~legacy.server.serve`; + :meth:`~server.serve`; - writes the HTTP response to the network; - moves to the ``OPEN`` state; - returns the ``path`` part of the ``uri``. @@ -186,8 +186,8 @@ in the same class, :class:`~legacy.protocol.WebSocketCommonProtocol`. The :attr:`~legacy.protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the -:attr:`~legacy.server.WebSocketServerProtocol` and -:attr:`~legacy.client.WebSocketClientProtocol` classes. +:attr:`~server.WebSocketServerProtocol` and +:attr:`~client.WebSocketClientProtocol` classes. Data flow ......... @@ -264,14 +264,14 @@ Closing handshake When the other side of the connection initiates the closing handshake, :meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives a close frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a -close frame, and returns ``None``, causing +close frame, and returns :obj:`None`, causing :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. When this side of the connection initiates the closing handshake with :meth:`~legacy.protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` state and sends a close frame. When the other side sends a close frame, :meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives it in the -``CLOSING`` state and returns ``None``, also causing +``CLOSING`` state and returns :obj:`None`, also causing :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close @@ -313,7 +313,7 @@ of canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_ta and failing to close the TCP connection, thus leaking resources. Then :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` cancels -:attr:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no +:meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no protocol compliance responsibilities. Terminating it to avoid leaking it is the only concern. @@ -445,15 +445,15 @@ is canceled, which is correct at this point. to prevent cancellation. :meth:`~legacy.protocol.WebSocketCommonProtocol.close` and -:func:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` are the only +:meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` are the only places where :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` may be canceled. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connnection_task` starts by +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` starts by waiting for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`. It catches :exc:`~asyncio.CancelledError` to prevent a cancellation of :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` from propagating -to :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connnection_task`. +to :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`. .. _backpressure: @@ -520,21 +520,21 @@ For each connection, the receiving side contains these buffers: - OS buffers: tuning them is an advanced optimization. - :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64 KiB. You can set another limit by passing a ``read_limit`` keyword argument to - :func:`~legacy.client.connect()` or :func:`~legacy.server.serve`. + :func:`~client.connect()` or :func:`~server.serve`. - Incoming messages :class:`~collections.deque`: its size depends both on the size and the number of messages it contains. By default the maximum UTF-8 encoded size is 1 MiB and the maximum number is 32. In the worst case, after UTF-8 decoding, a single message could take up to 4 MiB of memory and the overall memory consumption could reach 128 MiB. You should adjust these limits by setting the ``max_size`` and ``max_queue`` keyword arguments of - :func:`~legacy.client.connect()` or :func:`~legacy.server.serve` according to your + :func:`~client.connect()` or :func:`~server.serve` according to your application's requirements. For each connection, the sending side contains these buffers: - :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64 KiB. You can set another limit by passing a ``write_limit`` keyword argument to - :func:`~legacy.client.connect()` or :func:`~legacy.server.serve`. + :func:`~client.connect()` or :func:`~server.serve`. - OS buffers: tuning them is an advanced optimization. Concurrency diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index c880d5579..ee0109c35 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -1,6 +1,8 @@ Memory usage ============ +.. currentmodule:: websockets + In most cases, memory usage of a WebSocket server is proportional to the number of open connections. When a server handles thousands of connections, memory usage can become a bottleneck. @@ -17,8 +19,8 @@ Baseline Compression settings are the main factor affecting the baseline amount of memory used by each connection. -Read to the topic guide on :doc:`../topics/compression` to learn more about -tuning compression settings. +Refer to the :doc:`topic guide on compression <../topics/compression>` to +learn more about tuning compression settings. Buffers ------- @@ -29,7 +31,7 @@ Under high load, if a server receives more messages than it can process, bufferbloat can result in excessive memory usage. By default websockets has generous limits. It is strongly recommended to adapt -them to your application. When you call :func:`~legacy.server.serve`: +them to your application. When you call :func:`~server.serve`: - Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of messages your application generates. @@ -40,4 +42,4 @@ them to your application. When you call :func:`~legacy.server.serve`: Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: 64 KiB) to reduce the size of buffers for incoming and outgoing data. -The design document provides :ref:`more details about buffers`. +The design document provides :ref:`more details about buffers `. diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 8febfce9f..51666ceea 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -1,6 +1,8 @@ Timeouts ======== +.. currentmodule:: websockets + Since the WebSocket protocol is intended for real-time communications over long-lived connections, it is desirable to ensure that connections don't break, and if they do, to report the problem quickly. @@ -13,15 +15,18 @@ As a consequence, proxies may terminate WebSocket connections prematurely, when no message was exchanged in 30 seconds. In order to avoid this problem, websockets implements a keepalive mechanism -based on WebSocket Ping and Pong frames. Ping and Pong are designed for this +based on WebSocket Ping_ and Pong_ frames. Ping and Pong are designed for this purpose. +.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 +.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + By default, websockets waits 20 seconds, then sends a Ping frame, and expects to receive the corresponding Pong frame within 20 seconds. Else, it considers the connection broken and closes it. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` -arguments of :func:`~websockets.connect` and :func:`~websockets.serve`. +arguments of :func:`~client.connect` and :func:`~server.serve`. While WebSocket runs on top of TCP, websockets doesn't rely on TCP keepalive because it's disabled by default and, if enabled, the default interval is no diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 5883c3d65..ec3484124 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .imports import lazy_import from .version import version as __version__ # noqa diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 785d2c3c9..860e4b1fa 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -183,7 +183,7 @@ def main() -> None: # Due to zealous removal of the loop parameter in the Queue constructor, # we need a factory coroutine to run in the freshly created event loop. - async def queue_factory() -> "asyncio.Queue[str]": + async def queue_factory() -> asyncio.Queue[str]: return asyncio.Queue() # Create a queue of user inputs. There's no need to limit its size. diff --git a/src/websockets/auth.py b/src/websockets/auth.py index f97c1feb0..afcb38cff 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,2 +1,4 @@ +from __future__ import annotations + # See #940 for why lazy_import isn't used here for backwards compatibility. from .legacy.auth import * # noqa diff --git a/src/websockets/client.py b/src/websockets/client.py index e21fd36c0..13c8a8ad0 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -68,7 +68,7 @@ def __init__( def connect(self) -> Request: # noqa: F811 """ - Create a WebSocket handshake request event to send to the server. + Create a WebSocket handshake request event to open a connection. """ headers = Headers() @@ -107,12 +107,13 @@ def connect(self) -> Request: # noqa: F811 def process_response(self, response: Response) -> None: """ - Check a handshake response received from the server. + Check a handshake response. - :param response: response - :param key: comes from :func:`build_request` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake response - is invalid + Args: + request: WebSocket handshake response received from the server. + + Raises: + InvalidHandshake: if the handshake response is invalid. """ @@ -162,11 +163,6 @@ def process_extensions(self, headers: Headers) -> List[Extension]: Check that each extension is supported, as well as its parameters. - Return the list of accepted extensions. - - Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the - connection. - :rfc:`6455` leaves the rules up to the specification of each extension. @@ -182,6 +178,15 @@ def process_extensions(self, headers: Headers) -> List[Extension]: Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. + Args: + headers: WebSocket handshake response headers. + + Returns: + List[Extension]: List of accepted extensions. + + Raises: + InvalidHandshake: to abort the handshake. + """ accepted_extensions: List[Extension] = [] @@ -232,9 +237,13 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: """ Handle the Sec-WebSocket-Protocol HTTP response header. - Check that it contains exactly one supported subprotocol. + If provided, check that it contains exactly one supported subprotocol. - Return the selected subprotocol. + Args: + headers: WebSocket handshake response headers. + + Returns: + Optional[Subprotocol]: Subprotocol, if one was selected. """ subprotocol: Optional[Subprotocol] = None @@ -263,7 +272,10 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: def send_request(self, request: Request) -> None: """ - Send a WebSocket handshake request to the server. + Send a handshake request to the server. + + Args: + request: WebSocket handshake request event to send. """ if self.debug: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 52fd9bb81..684664860 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -212,7 +212,8 @@ def receive_data(self, data: bytes) -> None: - You must call :meth:`data_to_send` and send this data. - You should call :meth:`events_received` and process these events. - :raises EOFError: if :meth:`receive_eof` was called before + Raises: + EOFError: if :meth:`receive_eof` was called before. """ self.reader.feed_data(data) @@ -228,7 +229,8 @@ def receive_eof(self) -> None: - You aren't exepcted to call :meth:`events_received` as it won't return any new events. - :raises EOFError: if :meth:`receive_eof` was called before + Raises: + EOFError: if :meth:`receive_eof` was called before. """ self.reader.feed_eof() @@ -367,8 +369,8 @@ def close_expected(self) -> bool: Tell whether the TCP connection is expected to close soon. Call this method immediately after calling any of the ``receive_*()`` - or ``fail_*()`` methods and, if it returns ``True``, schedule closing - the TCP connection after a short timeout. + or ``fail_*()`` methods and, if it returns :obj:`True`, schedule + closing the TCP connection after a short timeout. """ # We already got a TCP Close if and only if the state is CLOSED. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 65c5d4115..1ff586abd 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,8 +1,3 @@ -""" -:mod:`websockets.datastructures` defines a class for manipulating HTTP headers. - -""" - from __future__ import annotations from typing import ( @@ -141,7 +136,7 @@ def clear(self) -> None: def update(self, *args: HeadersLike, **kwargs: str) -> None: """ - Update from a Headers instance and/or keyword arguments. + Update from a :class:`Headers` instance and/or keyword arguments. """ args = tuple( @@ -155,7 +150,8 @@ def get_all(self, key: str) -> List[str]: """ Return the (possibly empty) list of all values for a header. - :param key: header name + Args: + key: header name. """ return self._dict.get(key.lower(), []) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 6bbea324c..0c4fc5185 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -67,7 +67,7 @@ class WebSocketException(Exception): """ - Base class for all exceptions defined by :mod:`websockets`. + Base class for all exceptions defined by websockets. """ @@ -76,17 +76,14 @@ class ConnectionClosed(WebSocketException): """ Raised when trying to interact with a closed connection. - If a close frame was received, its code and reason are available in the - ``rcvd.code`` and ``rcvd.reason`` attributes. Else, the ``rcvd`` - attribute is ``None``. - - Likewise, if a close frame was sent, its code and reason are available in - the ``sent.code`` and ``sent.reason`` attributes. Else, the ``sent`` - attribute is ``None``. - - If close frames were received and sent, the ``rcvd_then_sent`` attribute - tells in which order this happened, from the perspective of this side of - the connection. + Attributes: + rcvd (Optional[Close]): if a close frame was received, its code and + reason are available in ``rcvd.code`` and ``rcvd.reason``. + sent (Optional[Close]): if a close frame was sent, its code and reason + are available in ``sent.code`` and ``sent.reason``. + rcvd_then_sent (Optional[bool]): if close frames were received and + sent, this attribute tells in which order this happened, from the + perspective of this side of the connection. """ @@ -249,9 +246,6 @@ class InvalidStatusCode(InvalidHandshake): """ Raised when a handshake response status code is invalid. - The integer status code is available in the ``status_code`` attribute and - HTTP headers in the ``headers`` attribute. - """ def __init__(self, status_code: int, headers: datastructures.Headers) -> None: @@ -320,8 +314,13 @@ class AbortHandshake(InvalidHandshake): This exception is an implementation detail. - The public API is :meth:`~legacy.server.WebSocketServerProtocol.process_request`. + The public API + is :meth:`~websockets.server.WebSocketServerProtocol.process_request`. + Attributes: + status (~http.HTTPStatus): HTTP status code. + headers (Headers): HTTP response headers. + body (bytes): HTTP response body. """ def __init__( diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 7217aa513..060967618 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,13 +1,3 @@ -""" -:mod:`websockets.extensions.base` defines abstract classes for implementing -extensions. - -See `section 9 of RFC 6455`_. - -.. _section 9 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 - -""" - from __future__ import annotations from typing import List, Optional, Sequence, Tuple @@ -21,16 +11,12 @@ class Extension: """ - Abstract class for extensions. + Base class for extensions. """ - @property - def name(self) -> ExtensionName: - """ - Extension identifier. - - """ + name: ExtensionName + """Extension identifier.""" def decode( self, @@ -41,8 +27,15 @@ def decode( """ Decode an incoming frame. - :param frame: incoming frame - :param max_size: maximum payload size in bytes + Args: + frame (Frame): incoming frame. + max_size: maximum payload size in bytes. + + Returns: + Frame: Decoded frame. + + Raises: + PayloadTooBig: if decoding the payload exceeds ``max_size``. """ @@ -50,29 +43,30 @@ def encode(self, frame: frames.Frame) -> frames.Frame: """ Encode an outgoing frame. - :param frame: outgoing frame + Args: + frame (Frame): outgoing frame. + + Returns: + Frame: Encoded frame. """ class ClientExtensionFactory: """ - Abstract class for client-side extension factories. + Base class for client-side extension factories. """ - @property - def name(self) -> ExtensionName: - """ - Extension identifier. - - """ + name: ExtensionName + """Extension identifier.""" def get_request_params(self) -> List[ExtensionParameter]: """ - Build request parameters. + Build parameters to send to the server for this extension. - Return a list of ``(name, value)`` pairs. + Returns: + List[ExtensionParameter]: Parameters to send to the server. """ @@ -82,28 +76,31 @@ def process_response_params( accepted_extensions: Sequence[Extension], ) -> Extension: """ - Process response parameters received from the server. + Process parameters received from the server. + + Args: + params (Sequence[ExtensionParameter]): parameters received from + the server for this extension. + accepted_extensions (Sequence[Extension]): list of previously + accepted extensions. - :param params: list of ``(name, value)`` pairs. - :param accepted_extensions: list of previously accepted extensions. - :raises ~websockets.exceptions.NegotiationError: if parameters aren't - acceptable + Returns: + Extension: An extension instance. + + Raises: + NegotiationError: if parameters aren't acceptable. """ class ServerExtensionFactory: """ - Abstract class for server-side extension factories. + Base class for server-side extension factories. """ - @property - def name(self) -> ExtensionName: - """ - Extension identifier. - - """ + name: ExtensionName + """Extension identifier.""" def process_request_params( self, @@ -111,16 +108,21 @@ def process_request_params( accepted_extensions: Sequence[Extension], ) -> Tuple[List[ExtensionParameter], Extension]: """ - Process request parameters received from the client. + Process parameters received from the client. - To accept the offer, return a 2-uple containing: + Args: + params (Sequence[ExtensionParameter]): parameters received from + the client for this extension. + accepted_extensions (Sequence[Extension]): list of previously + accepted extensions. - - response parameters: a list of ``(name, value)`` pairs - - an extension: an instance of a subclass of :class:`Extension` + Returns: + Tuple[List[ExtensionParameter], Extension]: To accept the offer, + parameters to send to the client for this extension and an + extension instance. - :param params: list of ``(name, value)`` pairs. - :param accepted_extensions: list of previously accepted extensions. - :raises ~websockets.exceptions.NegotiationError: to reject the offer, - if parameters aren't acceptable + Raises: + NegotiationError: to reject the offer, if parameters received from + the client aren't acceptable. """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index a377abb55..da2bc153e 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -1,9 +1,3 @@ -""" -:mod:`websockets.extensions.permessage_deflate` implements the Compression -Extensions for WebSocket as specified in :rfc:`7692`. - -""" - from __future__ import annotations import dataclasses @@ -204,8 +198,8 @@ def _extract_parameters( """ Extract compression parameters from a list of ``(name, value)`` pairs. - If ``is_server`` is ``True``, ``client_max_window_bits`` may be provided - without a value. This is only allow in handshake requests. + If ``is_server`` is :obj:`True`, ``client_max_window_bits`` may be + provided without a value. This is only allowed in handshake requests. """ server_no_context_takeover: bool = False @@ -264,18 +258,23 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): """ Client-side extension factory for the Per-Message Deflate extension. - Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to - ``True`` to include them in the negotiation offer without a value or to an - integer value to include them with this value. + Parameters behave as described in `section 7.1 of RFC 7692`_. .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 - :param server_no_context_takeover: defaults to ``False`` - :param client_no_context_takeover: defaults to ``False`` - :param server_max_window_bits: optional, defaults to ``None`` - :param client_max_window_bits: optional, defaults to ``None`` - :param compress_settings: optional, keyword arguments for - :func:`zlib.compressobj`, excluding ``wbits`` + Set them to :obj:`True` to include them in the negotiation offer without a + value or to an integer value to include them with this value. + + Args: + server_no_context_takeover: prevent server from using context takeover. + client_no_context_takeover: prevent client from using context takeover. + server_max_window_bits: maximum size of the server's LZ77 sliding window + in bits, between 8 and 15. + client_max_window_bits: maximum size of the client's LZ77 sliding window + in bits, between 8 and 15, or :obj:`True` to indicate support without + setting a limit. + compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + excluding ``wbits``. """ @@ -440,7 +439,6 @@ def enable_client_permessage_deflate( If the extension is already present, perhaps with non-default settings, the configuration isn't changed. - """ if extensions is None: extensions = [] @@ -462,18 +460,23 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): """ Server-side extension factory for the Per-Message Deflate extension. - Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to - ``True`` to include them in the negotiation offer without a value or to an - integer value to include them with this value. + Parameters behave as described in `section 7.1 of RFC 7692`_. .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 - :param server_no_context_takeover: defaults to ``False`` - :param client_no_context_takeover: defaults to ``False`` - :param server_max_window_bits: optional, defaults to ``None`` - :param client_max_window_bits: optional, defaults to ``None`` - :param compress_settings: optional, keyword arguments for - :func:`zlib.compressobj`, excluding ``wbits`` + Set them to :obj:`True` to include them in the negotiation offer without a + value or to an integer value to include them with this value. + + Args: + server_no_context_takeover: prevent server from using context takeover. + client_no_context_takeover: prevent client from using context takeover. + server_max_window_bits: maximum size of the server's LZ77 sliding window + in bits, between 8 and 15. + client_max_window_bits: maximum size of the client's LZ77 sliding window + in bits, between 8 and 15, or :obj:`True` to indicate support without + setting a limit. + compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + excluding ``wbits``. """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index a0ce1d350..9a97f2530 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -1,8 +1,3 @@ -""" -Parse and serialize WebSocket frames. - -""" - from __future__ import annotations import dataclasses @@ -104,15 +99,16 @@ class Frame: """ WebSocket frame. - :param int opcode: opcode - :param bytes data: payload data - :param bool fin: FIN bit - :param bool rsv1: RSV1 bit - :param bool rsv2: RSV2 bit - :param bool rsv3: RSV3 bit + Args: + opcode: opcode. + data: payload data. + fin: FIN bit. + rsv1: RSV1 bit. + rsv2: RSV2 bit. + rsv3: RSV3 bit. Only these fields are needed. The MASK bit, payload length and masking-key - are handled on the fly by :meth:`parse` and :meth:`serialize`. + are handled on the fly when parsing and serializing frames. """ @@ -176,22 +172,23 @@ def parse( mask: bool, max_size: Optional[int] = None, extensions: Optional[Sequence[extensions.Extension]] = None, - ) -> Generator[None, None, "Frame"]: + ) -> Generator[None, None, Frame]: """ - Read a WebSocket frame. - - :param read_exact: generator-based coroutine that reads the requested - number of bytes or raises an exception if there isn't enough data - :param mask: whether the frame should be masked i.e. whether the read - happens on the server side - :param max_size: maximum payload size in bytes - :param extensions: list of classes with a ``decode()`` method that - transforms the frame and return a new frame; extensions are applied - in reverse order - :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds - ``max_size`` - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values + Parse a WebSocket frame. + + This is a generator-based coroutine. + + Args: + read_exact: generator-based coroutine that reads the requested + bytes or raises an exception if there isn't enough data. + mask: whether the frame should be masked i.e. whether the read + happens on the server side. + max_size: maximum payload size in bytes. + extensions: list of extensions, applied in reverse order. + + Raises: + PayloadTooBig: if the frame's payload size exceeds ``max_size``. + ProtocolError: if the frame contains incorrect values. """ # Read the header. @@ -249,16 +246,15 @@ def serialize( extensions: Optional[Sequence[extensions.Extension]] = None, ) -> bytes: """ - Write a WebSocket frame. + Serialize a WebSocket frame. + + Args: + mask: whether the frame should be masked i.e. whether the write + happens on the client side. + extensions: list of extensions, applied in order. - :param frame: frame to write - :param mask: whether the frame should be masked i.e. whether the write - happens on the client side - :param extensions: list of classes with an ``encode()`` method that - transform the frame and return a new frame; extensions are applied - in order - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values + Raises: + ProtocolError: if the frame contains incorrect values. """ self.check() @@ -306,8 +302,8 @@ def check(self) -> None: """ Check that reserved bits and opcode have acceptable values. - :raises ~websockets.exceptions.ProtocolError: if a reserved - bit or the opcode is invalid + Raises: + ProtocolError: if a reserved bit or the opcode is invalid. """ if self.rsv1 or self.rsv2 or self.rsv3: @@ -332,7 +328,8 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like object. - :raises TypeError: if ``data`` doesn't have a supported type + Raises: + TypeError: if ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -354,7 +351,8 @@ def prepare_ctrl(data: Data) -> bytes: If ``data`` is a bytes-like object, return a :class:`bytes` object. - :raises TypeError: if ``data`` doesn't have a supported type + Raises: + TypeError: if ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -398,8 +396,12 @@ def parse(cls, data: bytes) -> Close: """ Parse the payload of a close frame. - :raises ~websockets.exceptions.ProtocolError: if data is ill-formed - :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + Args: + data: payload of the close frame. + + Raises: + ProtocolError: if data is ill-formed. + UnicodeDecodeError: if the reason isn't valid UTF-8. """ if len(data) >= 2: @@ -415,9 +417,7 @@ def parse(cls, data: bytes) -> Close: def serialize(self) -> bytes: """ - Serialize the payload for a close frame. - - This is the reverse of :meth:`parse`. + Serialize the payload of a close frame. """ self.check() @@ -427,8 +427,8 @@ def check(self) -> None: """ Check that the close code has a valid value for a close frame. - :raises ~websockets.exceptions.ProtocolError: if the close code - is invalid + Raises: + ProtocolError: if the close code is invalid. """ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): diff --git a/src/websockets/headers.py b/src/websockets/headers.py index ee6dd1672..a2fdfdd30 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -1,9 +1,3 @@ -""" -:mod:`websockets.headers` provides parsers and serializers for HTTP headers -used in WebSocket handshake messages. - -""" - from __future__ import annotations import base64 @@ -48,7 +42,7 @@ def peek_ahead(header: str, pos: int) -> Optional[str]: """ Return the next character from ``header`` at the given position. - Return ``None`` at the end of ``header``. + Return :obj:`None` at the end of ``header``. We never need to peek more than one character ahead. @@ -83,7 +77,8 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ match = _token_re.match(header, pos) @@ -106,7 +101,8 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i Return the unquoted value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ match = _quoted_string_re.match(header, pos) @@ -158,7 +154,8 @@ def parse_list( Return a list of items. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient @@ -211,7 +208,8 @@ def parse_connection_option( Return the protocol value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -224,8 +222,11 @@ def parse_connection(header: str) -> List[ConnectionOption]: Return a list of HTTP connection options. - :param header: value of the ``Connection`` header - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Args + header: value of the ``Connection`` header. + + Raises: + InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_connection_option, header, 0, "Connection") @@ -244,7 +245,8 @@ def parse_upgrade_protocol( Return the protocol value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ match = _protocol_re.match(header, pos) @@ -261,8 +263,11 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: Return a list of HTTP protocols. - :param header: value of the ``Upgrade`` header - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Args: + header: value of the ``Upgrade`` header. + + Raises: + InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") @@ -276,7 +281,8 @@ def parse_extension_item_param( Return a ``(name, value)`` pair and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ # Extract parameter name. @@ -312,7 +318,8 @@ def parse_extension_item( Return an ``(extension name, parameters)`` pair, where ``parameters`` is a list of ``(name, value)`` pairs, and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ # Extract extension name. @@ -344,9 +351,10 @@ def parse_extension(header: str) -> List[ExtensionHeader]: ... ] - Parameter values are ``None`` when no value is provided. + Parameter values are :obj:`None` when no value is provided. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") @@ -397,7 +405,8 @@ def parse_subprotocol_item( Return the subprotocol value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -410,7 +419,8 @@ def parse_subprotocol(header: str) -> List[Subprotocol]: Return a list of WebSocket subprotocols. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") @@ -450,7 +460,8 @@ def build_www_authenticate_basic(realm: str) -> str: """ Build a ``WWW-Authenticate`` header for HTTP Basic Auth. - :param realm: authentication realm + Args: + realm: identifier of the protection space. """ # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 @@ -468,7 +479,8 @@ def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ match = _token68_re.match(header, pos) @@ -494,9 +506,12 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: Return a ``(username, password)`` tuple. - :param header: value of the ``Authorization`` header - :raises InvalidHeaderFormat: on invalid inputs - :raises InvalidHeaderValue: on unsupported inputs + Args: + header: value of the ``Authorization`` header. + + Raises: + InvalidHeaderFormat: on invalid inputs. + InvalidHeaderValue: on unsupported inputs. """ # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 diff --git a/src/websockets/http11.py b/src/websockets/http11.py index daa0efffb..b82a0bfdc 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -44,38 +44,46 @@ class Request: """ WebSocket handshake request. - :param path: path and optional query - :param headers: + Attributes: + path: Request path, including optional query. + headers: Request headers. + exception: If processing the response triggers an exception, + the exception is stored in this attribute. """ path: str headers: datastructures.Headers - # body isn't useful is the context of this library + # body isn't useful is the context of this library. - # If processing the request triggers an exception, it's stored here. exception: Optional[Exception] = None @classmethod def parse( - cls, read_line: Callable[[], Generator[None, None, bytes]] - ) -> Generator[None, None, "Request"]: + cls, + read_line: Callable[[], Generator[None, None, bytes]], + ) -> Generator[None, None, Request]: """ - Parse an HTTP/1.1 GET request and return ``(path, headers)``. + Parse a WebSocket handshake request. + + This is a generator-based coroutine. - ``path`` isn't URL-decoded or validated in any way. + The request path isn't URL-decoded or validated in any way. - ``path`` and ``headers`` are expected to contain only ASCII characters. - Other characters are represented with surrogate escapes. + The request path and headers are expected to contain only ASCII + characters. Other characters are represented with surrogate escapes. - :func:`parse_request` doesn't attempt to read the request body because + :meth:`parse` doesn't attempt to read the request body because WebSocket handshake requests don't have one. If the request contains a - body, it may be read from ``stream`` after this coroutine returns. + body, it may be read from the data stream after :meth:`parse` returns. - :param read_line: generator-based coroutine that reads a LF-terminated - line or raises an exception if there isn't enough data - :raises EOFError: if the connection is closed without a full HTTP request - :raises exceptions.SecurityError: if the request exceeds a security limit - :raises ValueError: if the request isn't well formatted + Args: + read_line: generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data + + Raises: + EOFError: if the connection is closed without a full HTTP request. + SecurityError: if the request exceeds a security limit. + ValueError: if the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -114,10 +122,10 @@ def parse( def serialize(self) -> bytes: """ - Serialize an HTTP/1.1 GET request. + Serialize a WebSocket handshake request. """ - # Since the path and headers only contain ASCII characters, + # Since the request line and headers only contain ASCII characters, # we can keep this simple. request = f"GET {self.path} HTTP/1.1\r\n".encode() request += self.headers.serialize() @@ -129,6 +137,14 @@ class Response: """ WebSocket handshake response. + Attributes: + status_code: Response code. + reason_phrase: Response reason. + headers: Response headers. + body: Response body, if any. + exception: if processing the response triggers an exception, + the exception is stored in this attribute. + """ status_code: int @@ -136,7 +152,6 @@ class Response: headers: datastructures.Headers body: Optional[bytes] = None - # If processing the response triggers an exception, it's stored here. exception: Optional[Exception] = None @classmethod @@ -145,32 +160,32 @@ def parse( read_line: Callable[[], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], read_to_eof: Callable[[], Generator[None, None, bytes]], - ) -> Generator[None, None, "Response"]: + ) -> Generator[None, None, Response]: """ - Parse an HTTP/1.1 response and return ``(status_code, reason, headers)``. + Parse a WebSocket handshake response. - ``reason`` and ``headers`` are expected to contain only ASCII characters. - Other characters are represented with surrogate escapes. + This is a generator-based coroutine. - :func:`parse_request` doesn't attempt to read the response body because - WebSocket handshake responses don't have one. If the response contains a - body, it may be read from ``stream`` after this coroutine returns. + The reason phrase and headers are expected to contain only ASCII + characters. Other characters are represented with surrogate escapes. - :param read_line: generator-based coroutine that reads a LF-terminated - line or raises an exception if there isn't enough data - :param read_exact: generator-based coroutine that reads the requested - number of bytes or raises an exception if there isn't enough data - :raises EOFError: if the connection is closed without a full HTTP response - :raises exceptions.SecurityError: if the response exceeds a security limit - :raises LookupError: if the response isn't well formatted - :raises ValueError: if the response isn't well formatted + Args: + read_line: generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data. + read_exact: generator-based coroutine that reads the requested + bytes or raises an exception if there isn't enough data. + read_to_eof: generator-based coroutine that reads until the end + of the strem. + + Raises: + EOFError: if the connection is closed without a full HTTP response. + SecurityError: if the response exceeds a security limit. + LookupError: if the response isn't well formatted. + ValueError: if the response isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 - # As in parse_request, parsing is simple because a fixed value is expected - # for version, status_code is a 3-digit number, and reason can be ignored. - try: status_line = yield from parse_line(read_line) except EOFError as exc: @@ -227,7 +242,7 @@ def parse( def serialize(self) -> bytes: """ - Serialize an HTTP/1.1 GET response. + Serialize a WebSocket handshake response. """ # Since the status line and headers only contain ASCII characters, @@ -240,15 +255,21 @@ def serialize(self) -> bytes: def parse_headers( - read_line: Callable[[], Generator[None, None, bytes]] + read_line: Callable[[], Generator[None, None, bytes]], ) -> Generator[None, None, datastructures.Headers]: """ Parse HTTP headers. Non-ASCII characters are represented with surrogate escapes. - :param read_line: generator-based coroutine that reads a LF-terminated - line or raises an exception if there isn't enough data + Args: + read_line: generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: if the connection is closed without complete headers. + SecurityError: if the request exceeds a security limit. + ValueError: if the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 @@ -285,15 +306,20 @@ def parse_headers( def parse_line( - read_line: Callable[[], Generator[None, None, bytes]] + read_line: Callable[[], Generator[None, None, bytes]], ) -> Generator[None, None, bytes]: """ Parse a single line. CRLF is stripped from the return value. - :param read_line: generator-based coroutine that reads a LF-terminated - line or raises an exception if there isn't enough data + Args: + read_line: generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: if the connection is closed without a CRLF. + SecurityError: if the response exceeds a security limit. """ # Security: TODO: add a limit here diff --git a/src/websockets/imports.py b/src/websockets/imports.py index c9508d188..a6a59d4c2 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -9,15 +9,15 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: """ - Import from in . + Import ``name`` from ``source`` in ``namespace``. - There are two cases: + There are two use cases: - - is an object defined in - - is a submodule of source + - ``name`` is an object defined in ``source``; + - ``name`` is a submodule of ``source``. - Neither __import__ nor importlib.import_module does exactly this. - __import__ is closer to the intended behavior. + Neither :func:`__import__` nor :func:`~importlib.import_module` does + exactly this. :func:`__import__` is closer to the intended behavior. """ level = 0 @@ -49,7 +49,7 @@ def lazy_import( } ) - This function defines __getattr__ and __dir__ per PEP 562. + This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`. """ if aliases is None: diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 5f2b1311a..8825c14ec 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -1,9 +1,3 @@ -""" -:mod:`websockets.legacy.auth` provides HTTP Basic Authentication according to -:rfc:`7235` and :rfc:`7617`. - -""" - from __future__ import annotations import functools @@ -37,7 +31,16 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): """ - realm = "" + realm: str = "" + """ + Scope of protection. + + If provided, it should contain only ASCII characters because the + encoding of non-ASCII characters is undefined. + """ + + username: Optional[str] = None + """Username of the authenticated user.""" def __init__( self, @@ -55,13 +58,17 @@ async def check_credentials(self, username: str, password: str) -> bool: """ Check whether credentials are authorized. - If ``check_credentials`` returns ``True``, the WebSocket handshake - continues. If it returns ``False``, the handshake fails with a HTTP - 401 error. - This coroutine may be overridden in a subclass, for example to authenticate against a database or an external service. + Args: + username: HTTP Basic Auth username. + password: HTTP Basic Auth password. + + Returns: + bool: :obj:`True` if the handshake should continue; + :obj:`False` if it should fail with a HTTP 401 error. + """ if self._check_credentials is not None: return await self._check_credentials(username, password) @@ -116,8 +123,8 @@ def basic_auth_protocol_factory( """ Protocol factory that enforces HTTP Basic Auth. - ``basic_auth_protocol_factory`` is designed to integrate with - :func:`~websockets.legacy.server.serve` like this:: + :func:`basic_auth_protocol_factory` is designed to integrate with + :func:`~websockets.server.serve` like this:: websockets.serve( ..., @@ -127,28 +134,22 @@ def basic_auth_protocol_factory( ) ) - ``realm`` indicates the scope of protection. It should contain only ASCII - characters because the encoding of non-ASCII characters is undefined. - Refer to section 2.2 of :rfc:`7235` for details. - - ``credentials`` defines hard coded authorized credentials. It can be a - ``(username, password)`` pair or a list of such pairs. - - ``check_credentials`` defines a coroutine that checks whether credentials - are authorized. This coroutine receives ``username`` and ``password`` - arguments and returns a :class:`bool`. - - One of ``credentials`` or ``check_credentials`` must be provided but not - both. - - By default, ``basic_auth_protocol_factory`` creates a factory for building - :class:`BasicAuthWebSocketServerProtocol` instances. You can override this - with the ``create_protocol`` parameter. - - :param realm: scope of protection - :param credentials: hard coded credentials - :param check_credentials: coroutine that verifies credentials - :raises TypeError: if the credentials argument has the wrong type + Args: + realm: indicates the scope of protection. It should contain only ASCII + characters because the encoding of non-ASCII characters is + undefined. Refer to section 2.2 of :rfc:`7235` for details. + credentials: defines hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: defines a coroutine that verifies credentials. + This coroutine receives ``username`` and ``password`` arguments + and returns a :class:`bool`. One of ``credentials`` or + ``check_credentials`` must be provided but not both. + create_protocol: factory that creates the protocol. By default, this + is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced + by a subclass. + Raises: + TypeError: if the ``credentials`` or ``check_credentials`` argument is + wrong. """ if (credentials is None) == (check_credentials is None): diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 6d976e0df..e5743cc0e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -1,8 +1,3 @@ -""" -:mod:`websockets.legacy.client` defines the WebSocket client APIs. - -""" - from __future__ import annotations import asyncio @@ -57,99 +52,27 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ - :class:`~asyncio.Protocol` subclass implementing a WebSocket client. - - :class:`WebSocketClientProtocol`: + WebSocket client connection. - * performs the opening handshake to establish the connection; - * provides :meth:`recv` and :meth:`send` coroutines for receiving and - sending messages; - * deals with control frames automatically; - * performs the closing handshake to terminate the connection. + :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send` + coroutines for receiving and sending messages. - :class:`WebSocketClientProtocol` supports asynchronous iteration:: + It supports asynchronous iteration to receive incoming messages:: async for message in websocket: await process(message) - The iterator yields incoming messages. It exits normally when the - connection is closed with the close code 1000 (OK) or 1001 (going away). - It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other code. - - Once the connection is open, a `Ping frame`_ is sent every - ``ping_interval`` seconds. This serves as a keepalive. It helps keeping - the connection open, especially in the presence of proxies with short - timeouts on inactive connections. Set ``ping_interval`` to ``None`` to - disable this behavior. - - .. _Ping frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 - - If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with - code 1011. This ensures that the remote endpoint remains responsive. Set - ``ping_timeout`` to ``None`` to disable this behavior. - - .. _Pong frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 - - The ``close_timeout`` parameter defines a maximum wait time for completing - the closing handshake and terminating the TCP connection. For legacy - reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds. - - ``close_timeout`` needs to be a parameter of the protocol because - websockets usually calls :meth:`close` implicitly upon exit when - :func:`connect` is used as a context manager. - - To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. - - The ``max_size`` parameter enforces the maximum size for incoming messages - in bytes. The default value is 1 MiB. ``None`` disables the limit. If a - message larger than the maximum size is received, :meth:`recv` will - raise :exc:`~websockets.exceptions.ConnectionClosedError` and the - connection will be closed with code 1009. - - The ``max_queue`` parameter sets the maximum length of the queue that - holds incoming messages. The default value is ``32``. ``None`` disables - the limit. Messages are added to an in-memory queue when they're received; - then :meth:`recv` pops from that queue. In order to prevent excessive - memory consumption when messages are received faster than they can be - processed, the queue must be bounded. If the queue fills up, the protocol - stops processing incoming data until :meth:`recv` is called. In this - situation, various receive buffers (at least in :mod:`asyncio` and in the - OS) will fill up, then the TCP receive window will shrink, slowing down - transmission to avoid packet loss. - - Since Python can use up to 4 bytes of memory to represent a single - character, each connection may use up to ``4 * max_size * max_queue`` - bytes of memory to store incoming messages. By default, this is 128 MiB. - You may want to lower the limits, depending on your application's - requirements. - - The ``read_limit`` argument sets the high-water limit of the buffer for - incoming bytes. The low-water limit is half the high-water limit. The - default value is 64 KiB, half of asyncio's default (based on the current - implementation of :class:`~asyncio.StreamReader`). - - The ``write_limit`` argument sets the high-water limit of the buffer for - outgoing bytes. The low-water limit is a quarter of the high-water limit. - The default value is 64 KiB, equal to asyncio's default (based on the - current implementation of ``FlowControlMixin``). - - As soon as the HTTP request and response in the opening handshake are - processed: - - * the request path is available in the :attr:`path` attribute; - * the request and response HTTP headers are available in the - :attr:`request_headers` and :attr:`response_headers` attributes, - which are :class:`~websockets.http.Headers` instances. - - If a subprotocol was negotiated, it's available in the :attr:`subprotocol` - attribute. - - Once the connection is closed, the code is available in the - :attr:`close_code` attribute and the reason in :attr:`close_reason`. - - All attributes must be treated as read-only. + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away). It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection + is closed with any other code. + + See :func:`connect` for the documentation of ``logger``, ``origin``, + ``extensions``, ``subprotocols``, and ``extra_headers``. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. """ @@ -159,11 +82,11 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__( self, *, + logger: Optional[LoggerLike] = None, origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, - logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: if logger is None: @@ -201,8 +124,9 @@ async def read_http_response(self) -> Tuple[int, Headers]: If the response contains a body, it may be read from ``self.reader`` after this coroutine returns. - :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is - malformed or isn't an HTTP/1.1 GET response + Raises: + InvalidMessage: if the HTTP message is malformed or isn't an + HTTP/1.1 GET response. """ try: @@ -346,17 +270,17 @@ async def handshake( """ Perform the client side of the opening handshake. - :param origin: sets the Origin HTTP header - :param available_extensions: list of supported extensions in the order - in which they should be used - :param available_subprotocols: list of supported subprotocols in order - of decreasing preference - :param extra_headers: sets additional HTTP request headers; it must be - a :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, - value)`` pairs - :raises ~websockets.exceptions.InvalidHandshake: if the handshake - fails + Args: + wsuri: URI of the WebSocket server. + origin: value of the ``Origin`` header. + available_extensions: list of supported extensions, in order in + which they should be tried. + available_subprotocols: list of supported subprotocols, in order + of decreasing preference. + extra_headers: arbitrary HTTP headers to add to the request. + + Raises: + InvalidHandshake: if the handshake fails. """ request_headers = Headers() @@ -416,14 +340,14 @@ async def handshake( class Connect: """ - Connect to the WebSocket server at the given ``uri``. + Connect to the WebSocket server at ``uri``. Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. :func:`connect` can be used as a asynchronous context manager:: - async with connect(...) as websocket: + async with websockets.connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -431,59 +355,71 @@ class Connect: :func:`connect` can be used as an infinite asynchronous iterator to reconnect automatically on errors:: - async for websocket in connect(...): - ... + async for websocket in websockets.connect(...): + try: + ... + except websockets.ConnectionClosed: + continue - You must catch all exceptions, or else you will exit the loop prematurely. - As above, connections are closed automatically. Connection attempts are - delayed with exponential backoff, starting at three seconds and - increasing up to one minute. - - :func:`connect` is a wrapper around the event loop's - :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments - are passed to :meth:`~asyncio.loop.create_connection`. - - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to - a ``wss://`` URI, if this argument isn't provided explicitly, - :func:`ssl.create_default_context` is called to create a context. - - You can connect to a different host and port from those found in ``uri`` - by setting ``host`` and ``port`` keyword arguments. This only changes the - destination of the TCP connection. The host name from ``uri`` is still - used in the TLS handshake for secure connections and in the ``Host`` HTTP - header. - - ``create_protocol`` defaults to :class:`WebSocketClientProtocol`. It may - be replaced by a wrapper or a subclass to customize the protocol that - manages the connection. - - If the WebSocket connection isn't established within ``open_timeout`` - seconds, :func:`connect` raises :exc:`~asyncio.TimeoutError`. The default - is 10 seconds. Set ``open_timeout`` to ``None`` to disable the timeout. - - The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, - ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is - described in :class:`WebSocketClientProtocol`. - - :func:`connect` also accepts the following optional arguments: - - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression. - * ``origin`` sets the Origin HTTP header. - * ``extensions`` is a list of supported extensions in order of - decreasing preference. - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference. - * ``extra_headers`` sets additional HTTP request headers; it can be a - :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` - pairs. - - :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid - :raises ~websockets.handshake.InvalidHandshake: if the opening handshake - fails + The connection is closed automatically after each iteration of the loop. + + If an error occurs while establishing the connection, :func:`connect` + retries with exponential backoff. The backoff delay starts at three + seconds and increases up to one minute. + + If an error occurs in the body of the loop, you can handle the exception + and :func:`connect` will reconnect with the next iteration; or you can + let the exception bubble up and break out of the loop. This lets you + decide which errors trigger a reconnection and which errors are fatal. + + Args: + uri: URI of the WebSocket server. + create_protocol: factory for the :class:`asyncio.Protocol` managing + the connection; defaults to :class:`WebSocketClientProtocol`; may + be set to a wrapper or a subclass to customize connection handling. + logger: logger for this connection; + defaults to ``logging.getLogger("websockets.client")``; + see the :doc:`logging guide <../topics/logging>` for details. + compression: shortcut that enables the "permessage-deflate" extension + by default; may be set to :obj:`None` to disable compression; + see the :doc:`compression guide <../topics/compression>` for details. + origin: value of the ``Origin`` header. This is useful when connecting + to a server that validates the ``Origin`` header to defend against + Cross-Site WebSocket Hijacking attacks. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of decreasing + preference. + extra_headers: arbitrary HTTP headers to add to the request. + open_timeout: timeout for opening the connection in seconds; + :obj:`None` to disable the timeout + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + Any other keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS + settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't + provided, a TLS context is created + with :func:`~ssl.create_default_context`. + + * You can set ``host`` and ``port`` to connect to a different host and + port from those found in ``uri``. This only changes the destination of + the TCP connection. The host name from ``uri`` is still used in the TLS + handshake for secure connections and in the ``Host`` header. + + Returns: + WebSocketClientProtocol: WebSocket connection. + + Raises: + InvalidURI: if ``uri`` isn't a valid WebSocket URI. + InvalidHandshake: if the opening handshake fails. + ~asyncio.TimeoutError: if the opening handshake times out. """ @@ -494,6 +430,12 @@ def __init__( uri: str, *, create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None, + logger: Optional[LoggerLike] = None, + compression: Optional[str] = "deflate", + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -502,12 +444,6 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - compression: Optional[str] = "deflate", - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -560,6 +496,11 @@ def __init__( factory = functools.partial( create_protocol, + logger=logger, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, @@ -567,16 +508,11 @@ def __init__( max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, - loop=_loop, host=wsuri.host, port=wsuri.port, secure=wsuri.secure, legacy_recv=legacy_recv, - origin=origin, - extensions=extensions, - subprotocols=subprotocols, - extra_headers=extra_headers, - logger=logger, + loop=_loop, ) if kwargs.pop("unix", False): @@ -743,15 +679,17 @@ def unix_connect( """ Similar to :func:`connect`, but for connecting to a Unix socket. - This function calls the event loop's + This function builds upon the event loop's :meth:`~asyncio.loop.create_unix_connection` method. It is only available on Unix. It's mainly useful for debugging servers listening on Unix sockets. - :param path: file system path to the Unix socket - :param uri: WebSocket URI + Args: + path: file system path to the Unix socket. + uri: URI of the WebSocket server; the host is used in the TLS + handshake for secure connections and in the ``Host`` header. """ return connect(uri=uri, path=path, unix=True, **kwargs) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 40cbd41bf..c4de7eb28 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,15 +1,3 @@ -""" -:mod:`websockets.legacy.framing` reads and writes WebSocket frames. - -It deals with a single frame at a time. Anything that depends on the sequence -of frames is implemented in :mod:`websockets.legacy.protocol`. - -See `section 5 of RFC 6455`_. - -.. _section 5 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 - -""" - from __future__ import annotations import dataclasses @@ -60,22 +48,21 @@ async def read( mask: bool, max_size: Optional[int] = None, extensions: Optional[Sequence[extensions.Extension]] = None, - ) -> "Frame": + ) -> Frame: """ Read a WebSocket frame. - :param reader: coroutine that reads exactly the requested number of - bytes, unless the end of file is reached - :param mask: whether the frame should be masked i.e. whether the read - happens on the server side - :param max_size: maximum payload size in bytes - :param extensions: list of classes with a ``decode()`` method that - transforms the frame and return a new frame; extensions are applied - in reverse order - :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds - ``max_size`` - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values + Args: + reader: coroutine that reads exactly the requested number of + bytes, unless the end of file is reached. + mask: whether the frame should be masked i.e. whether the read + happens on the server side. + max_size: maximum payload size in bytes. + extensions: list of extensions, applied in reverse order. + + Raises: + PayloadTooBig: if the frame exceeds ``max_size``. + ProtocolError: if the frame contains incorrect values. """ @@ -142,15 +129,15 @@ def write( """ Write a WebSocket frame. - :param frame: frame to write - :param write: function that writes bytes - :param mask: whether the frame should be masked i.e. whether the write - happens on the client side - :param extensions: list of classes with an ``encode()`` method that - transform the frame and return a new frame; extensions are applied - in order - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values + Args: + frame: frame to write. + write: function that writes bytes. + mask: whether the frame should be masked i.e. whether the write + happens on the client side. + extensions: list of extensions, applied in order. + + Raises: + ProtocolError: if the frame contains incorrect values. """ # The frame is written in a single call to write in order to prevent @@ -168,10 +155,12 @@ def parse_close(data: bytes) -> Tuple[int, str]: """ Parse the payload from a close frame. - Return ``(code, reason)``. + Returns: + Tuple[int, str]: close code and reason. - :raises ~websockets.exceptions.ProtocolError: if data is ill-formed - :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + Raises: + ProtocolError: if data is ill-formed. + UnicodeDecodeError: if the reason isn't valid UTF-8. """ return dataclasses.astuple(Close.parse(data)) # type: ignore @@ -181,7 +170,5 @@ def serialize_close(code: int, reason: str) -> bytes: """ Serialize the payload for a close frame. - This is the reverse of :func:`parse_close`. - """ return Close(code, reason).serialize() diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 7cde58ac1..569937bb9 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -1,30 +1,3 @@ -""" -:mod:`websockets.legacy.handshake` provides helpers for the WebSocket handshake. - -See `section 4 of RFC 6455`_. - -.. _section 4 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 - -Some checks cannot be performed because they depend too much on the -context; instead, they're documented below. - -To accept a connection, a server must: - -- Read the request, check that the method is GET, and check the headers with - :func:`check_request`, -- Send a 101 response to the client with the headers created by - :func:`build_response` if the request is valid; otherwise, send an - appropriate HTTP error code. - -To open a connection, a client must: - -- Send a GET request to the server with the headers created by - :func:`build_request`, -- Read the response, check that the status code is 101, and check the headers - with :func:`check_response`. - -""" - from __future__ import annotations import base64 @@ -47,8 +20,11 @@ def build_request(headers: Headers) -> str: Update request headers passed in argument. - :param headers: request headers - :returns: ``key`` which must be passed to :func:`check_response` + Args: + headers: handshake request headers. + + Returns: + str: ``key`` that must be passed to :func:`check_response`. """ key = generate_key() @@ -68,10 +44,15 @@ def check_request(headers: Headers) -> str: are usually performed earlier in the HTTP request handling code. They're the responsibility of the caller. - :param headers: request headers - :returns: ``key`` which must be passed to :func:`build_response` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake request - is invalid; then the server must return 400 Bad Request error + Args: + headers: handshake request headers. + + Returns: + str: ``key`` that must be passed to :func:`build_response`. + + Raises: + InvalidHandshake: if the handshake request is invalid; + then the server must return 400 Bad Request error. """ connection: List[ConnectionOption] = sum( @@ -128,8 +109,9 @@ def build_response(headers: Headers, key: str) -> None: Update response headers passed in argument. - :param headers: response headers - :param key: comes from :func:`check_request` + Args: + headers: handshake response headers. + key: returned by :func:`check_request`. """ headers["Upgrade"] = "websocket" @@ -145,10 +127,12 @@ def check_response(headers: Headers, key: str) -> None: response with a 101 status code. These controls are the responsibility of the caller. - :param headers: response headers - :param key: comes from :func:`build_request` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake response - is invalid + Args: + headers: handshake response headers. + key: returned by :func:`build_request`. + + Raises: + InvalidHandshake: if the handshake response is invalid. """ connection: List[ConnectionOption] = sum( diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 3725fa9c3..cc2ef1f06 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -55,10 +55,13 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: WebSocket handshake requests don't have one. If the request contains a body, it may be read from ``stream`` after this coroutine returns. - :param stream: input to read the request from - :raises EOFError: if the connection is closed without a full HTTP request - :raises SecurityError: if the request exceeds a security limit - :raises ValueError: if the request isn't well formatted + Args: + stream: input to read the request from + + Raises: + EOFError: if the connection is closed without a full HTTP request + SecurityError: if the request exceeds a security limit + ValueError: if the request isn't well formatted """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -99,10 +102,13 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers WebSocket handshake responses don't have one. If the response contains a body, it may be read from ``stream`` after this coroutine returns. - :param stream: input to read the response from - :raises EOFError: if the connection is closed without a full HTTP response - :raises SecurityError: if the response exceeds a security limit - :raises ValueError: if the response isn't well formatted + Args: + stream: input to read the response from + + Raises: + EOFError: if the connection is closed without a full HTTP response + SecurityError: if the response exceeds a security limit + ValueError: if the response isn't well formatted """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 99a821be6..4631151e6 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1,12 +1,3 @@ -""" -:mod:`websockets.legacy.protocol` handles WebSocket control and data frames. - -See `sections 4 to 8 of RFC 6455`_. - -.. _sections 4 to 8 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 - -""" - from __future__ import annotations import asyncio @@ -71,17 +62,99 @@ class WebSocketCommonProtocol(asyncio.Protocol): """ - :class:`~asyncio.Protocol` subclass implementing the data transfer phase. - - Once the WebSocket connection is established, during the data transfer - phase, the protocol is almost symmetrical between the server side and the - client side. :class:`WebSocketCommonProtocol` implements logic that's - shared between servers and clients. - - Subclasses such as - :class:`~websockets.legacy.server.WebSocketServerProtocol` and - :class:`~websockets.legacy.client.WebSocketClientProtocol` implement the - opening handshake, which is different between servers and clients. + WebSocket connection. + + :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket + servers and clients. You shouldn't use it directly. Instead, use + :class:`~websockets.client.WebSocketClientProtocol` or + :class:`~websockets.server.WebSocketServerProtocol`. + + This documentation focuses on low-level details that aren't covered in the + documentation of :class:`~websockets.client.WebSocketClientProtocol` and + :class:`~websockets.server.WebSocketServerProtocol` for the sake of + simplicity. + + Once the connection is open, a Ping_ frame is sent every ``ping_interval`` + seconds. This serves as a keepalive. It helps keeping the connection + open, especially in the presence of proxies with short timeouts on + inactive connections. Set ``ping_interval`` to :obj:`None` to disable + this behavior. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + If the corresponding Pong_ frame isn't received within ``ping_timeout`` + seconds, the connection is considered unusable and is closed with code + 1011. This ensures that the remote endpoint remains responsive. Set + ``ping_timeout`` to :obj:`None` to disable this behavior. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + The ``close_timeout`` parameter defines a maximum wait time for completing + the closing handshake and terminating the TCP connection. For legacy + reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds + for clients and ``4 * close_timeout`` for servers. + + See the discussion of :doc:`timeouts <../topics/timeouts>` for details. + + ``close_timeout`` needs to be a parameter of the protocol because + websockets usually calls :meth:`close` implicitly upon exit: + + * on the client side, when :func:`~websockets.client.connect` is used as a + context manager; + * on the server side, when the connection handler terminates; + + To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. + + The ``max_size`` parameter enforces the maximum size for incoming messages + in bytes. The default value is 1 MiB. If a larger message is received, + :meth:`recv` will raise :exc:`~websockets.exceptions.ConnectionClosedError` + and the connection will be closed with code 1009. + + The ``max_queue`` parameter sets the maximum length of the queue that + holds incoming messages. The default value is ``32``. Messages are added + to an in-memory queue when they're received; then :meth:`recv` pops from + that queue. In order to prevent excessive memory consumption when + messages are received faster than they can be processed, the queue must + be bounded. If the queue fills up, the protocol stops processing incoming + data until :meth:`recv` is called. In this situation, various receive + buffers (at least in :mod:`asyncio` and in the OS) will fill up, then the + TCP receive window will shrink, slowing down transmission to avoid packet + loss. + + Since Python can use up to 4 bytes of memory to represent a single + character, each connection may use up to ``4 * max_size * max_queue`` + bytes of memory to store incoming messages. By default, this is 128 MiB. + You may want to lower the limits, depending on your application's + requirements. + + The ``read_limit`` argument sets the high-water limit of the buffer for + incoming bytes. The low-water limit is half the high-water limit. The + default value is 64 KiB, half of asyncio's default (based on the current + implementation of :class:`~asyncio.StreamReader`). + + The ``write_limit`` argument sets the high-water limit of the buffer for + outgoing bytes. The low-water limit is a quarter of the high-water limit. + The default value is 64 KiB, equal to asyncio's default (based on the + current implementation of ``FlowControlMixin``). + + See the discussion of :doc:`memory usage <../topics/memory>` for details. + + Args: + logger: logger for this connection; + defaults to ``logging.getLogger("websockets.protocol")``; + see the :doc:`logging guide <../topics/logging>` for details. + ping_interval: delay between keepalive pings in seconds; + :obj:`None` to disable keepalive pings. + ping_timeout: timeout for keepalive pings in seconds; + :obj:`None` to disable timeouts. + close_timeout: timeout for closing the connection in seconds; + for legacy reasons, the actual timeout is 4 or 5 times larger. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + max_queue: maximum number of incoming messages in receive buffer; + :obj:`None` to disable the limit. + read_limit: high-water mark of read buffer in bytes. + write_limit: high-water mark of write buffer in bytes. """ @@ -94,6 +167,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): def __init__( self, *, + logger: Optional[LoggerLike] = None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, @@ -101,14 +175,13 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - logger: Optional[LoggerLike] = None, # The following arguments are kept only for backwards compatibility. host: Optional[str] = None, port: Optional[int] = None, secure: Optional[bool] = None, - timeout: Optional[float] = None, legacy_recv: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, + timeout: Optional[float] = None, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: @@ -134,7 +207,8 @@ def __init__( self.write_limit = write_limit # Unique identifier. For logs. - self.id = uuid.uuid4() + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" # Logger or LoggerAdapter for this connection. if logger is None: @@ -175,12 +249,16 @@ def __init__( # HTTP protocol parameters. self.path: str + """Path of the opening handshake request.""" self.request_headers: Headers + """Opening handshake request headers.""" self.response_headers: Headers + """Opening handshake response headers.""" # WebSocket protocol parameters. self.extensions: List[Extension] = [] self.subprotocol: Optional[Subprotocol] = None + """Subprotocol, if one was negotiated.""" # Close code and reason, set when a close frame is sent or received. self.close_rcvd: Optional[Close] = None @@ -286,9 +364,14 @@ def secure(self) -> Optional[bool]: @property def local_address(self) -> Any: """ - Local address of the connection as a ``(host, port)`` tuple. + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family; + see :meth:`~socket.socket.getsockname`. - When the connection isn't open, ``local_address`` is ``None``. + :obj:`None` if the TCP connection isn't established yet. """ try: @@ -301,9 +384,14 @@ def local_address(self) -> Any: @property def remote_address(self) -> Any: """ - Remote address of the connection as a ``(host, port)`` tuple. + Remote address of the connection. - When the connection isn't open, ``remote_address`` is ``None``. + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family; + see :meth:`~socket.socket.getpeername`. + + :obj:`None` if the TCP connection isn't established yet. """ try: @@ -316,13 +404,11 @@ def remote_address(self) -> Any: @property def open(self) -> bool: """ - ``True`` when the connection is usable. + :obj:`True` when the connection is open; :obj:`False` otherwise. - It may be used to detect disconnections. However, this approach is - discouraged per the EAFP_ principle. - - When ``open`` is ``False``, using the connection raises a - :exc:`~websockets.exceptions.ConnectionClosed` exception. + This attribute may be used to detect disconnections. However, this + approach is discouraged per the EAFP_ principle. Instead, you should + handle :exc:`~websockets.exceptions.ConnectionClosed` exceptions. .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp @@ -332,10 +418,10 @@ def open(self) -> bool: @property def closed(self) -> bool: """ - ``True`` once the connection is closed. + :obj:`True` when the connection is closed; :obj:`False` otherwise. - Be aware that both :attr:`open` and :attr:`closed` are ``False`` during - the opening and closing sequences. + Be aware that both :attr:`open` and :attr:`closed` are :obj:`False` + during the opening and closing sequences. """ return self.state is State.CLOSED @@ -343,9 +429,12 @@ def closed(self) -> bool: @property def close_code(self) -> Optional[int]: """ - WebSocket close code received in a close frame. + WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. + + .. _section 7.1.5 of RFC 6455: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 - Available once the connection is closed. + :obj:`None` if the connection isn't closed yet. """ if self.state is not State.CLOSED: @@ -358,9 +447,12 @@ def close_code(self) -> Optional[int]: @property def close_reason(self) -> Optional[str]: """ - WebSocket close reason received in a close frame. + WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. - Available once the connection is closed. + .. _section 7.1.6 of RFC 6455: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + + :obj:`None` if the connection isn't closed yet. """ if self.state is not State.CLOSED: @@ -370,25 +462,14 @@ def close_reason(self) -> Optional[str]: else: return self.close_rcvd.reason - async def wait_closed(self) -> None: - """ - Wait until the connection is closed. - - This is identical to :attr:`closed`, except it can be awaited. - - This can make it easier to handle connection termination, regardless - of its cause, in tasks that interact with the WebSocket connection. - - """ - await asyncio.shield(self.connection_lost_waiter) - async def __aiter__(self) -> AsyncIterator[Data]: """ - Iterate on received messages. - - Exit normally when the connection is closed with code 1000 or 1001. + Iterate on incoming messages. - Raise an exception in other cases. + The iterator exits normally when the connection is closed with the + close code 1000 (OK) or 1001(going away). It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` exception when + the connection is closed with any other code. """ try: @@ -401,24 +482,30 @@ async def recv(self) -> Data: """ Receive the next message. - Return a :class:`str` for a text frame and :class:`bytes` for a binary - frame. - - When the end of the message stream is reached, :meth:`recv` raises + When the connection is closed, :meth:`recv` raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal connection closure and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol - error or a network failure. + error or a network failure. This is how you detect the end of the + message stream. Canceling :meth:`recv` is safe. There's no risk of losing the next - message. The next invocation of :meth:`recv` will return it. This - makes it possible to enforce a timeout by wrapping :meth:`recv` in - :func:`~asyncio.wait_for`. + message. The next invocation of :meth:`recv` will return it. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` + in :func:`~asyncio.wait_for`. - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed - :raises RuntimeError: if two coroutines call :meth:`recv` concurrently + Returns: + Data: A string (:class:`str`) for a Text_ frame. A bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: when the connection is closed. + RuntimeError: if two coroutines call :meth:`recv` concurrently. """ if self._pop_message_waiter is not None: @@ -471,43 +558,58 @@ async def recv(self) -> Data: return message async def send( - self, message: Union[Data, Iterable[Data], AsyncIterable[Data]] + self, + message: Union[Data, Iterable[Data], AsyncIterable[Data]], ) -> None: """ Send a message. - A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like object (:class:`bytes`, :class:`bytearray`, or - :class:`memoryview`) is sent as a `Binary frame`_. + :class:`memoryview`) is sent as a Binary_ frame. - .. _Text frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of - strings, bytestrings, or bytes-like objects. In that case the message - is fragmented. Each item is treated as a message fragment and sent in - its own frame. All items must be of the same type, or else - :meth:`send` will raise a :exc:`TypeError` and the connection will be - closed. + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. - If you wish to send the keys of a dict-like object as fragments, call - its :meth:`~dict.keys` method and pass the result to :meth:`send`. + (If you want to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`.) Canceling :meth:`send` is discouraged. Instead, you should close the connection with :meth:`close`. Indeed, there are only two situations - where :meth:`send` may yield control to the event loop: + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: 1. The write buffer is full. If you don't want to wait until enough data is sent, your only alternative is to close the connection. :meth:`close` will likely time out then abort the TCP connection. 2. ``message`` is an asynchronous iterator that yields control. Stopping in the middle of a fragmented message will cause a - protocol error. Closing the connection has the same effect. + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed - :raises TypeError: if ``message`` doesn't have a supported type + Args: + message (Union[Data, Iterable[Data], AsyncIterable[Data]): message + to send. + + Raises: + ConnectionClosed: when the connection is closed. + TypeError: if ``message`` doesn't have a supported type. """ await self.ensure_open() @@ -621,7 +723,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: :meth:`close` waits for the other end to complete the handshake and for the TCP connection to terminate. As a consequence, there's no need - to await :meth:`wait_closed`; :meth:`close` already does it. + to await :meth:`wait_closed` after :meth:`close`. :meth:`close` is idempotent: it doesn't do anything once the connection is closed. @@ -631,10 +733,12 @@ async def close(self, code: int = 1000, reason: str = "") -> None: Canceling :meth:`close` is discouraged. If it takes too long, you can set a shorter ``close_timeout``. If you don't want to wait, let the - Python process exit, then the OS will close the TCP connection. + Python process exit, then the OS will take care of closing the TCP + connection. - :param code: WebSocket close code - :param reason: WebSocket close reason + Args: + code: WebSocket close code. + reason: WebSocket close reason. """ try: @@ -653,7 +757,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # the data transfer task and raises TimeoutError. # If close() is called multiple times concurrently and one of these - # calls hits the timeout, the data transfer task will be cancelled. + # calls hits the timeout, the data transfer task will be canceled. # Other calls will receive a CancelledError here. try: @@ -670,23 +774,27 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # Wait for the close connection task to close the TCP connection. await asyncio.shield(self.close_connection_task) - async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + async def wait_closed(self) -> None: """ - Send a ping. + Wait until the connection is closed. - Return a :class:`~asyncio.Future` that will be completed when the - corresponding pong is received. You can ignore it if you don't intend - to wait. + This coroutine is identical to the :attr:`closed` attribute, except it + can be awaited. - A ping may serve as a keepalive or as a check that the remote endpoint - received all messages up to this point:: + This can make it easier to detect connection termination, regardless + of its cause, in tasks that interact with the WebSocket connection. - pong_waiter = await ws.ping() - await pong_waiter # only if you want to wait for the pong + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 - By default, the ping contains four random bytes. This payload may be - overridden with the optional ``data`` argument which must be a string - (which will be encoded to UTF-8) or a bytes-like object. + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return immediately, it means the write buffer is full. If you don't want to @@ -695,10 +803,25 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no effect. - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed - :raises RuntimeError: if another ping was sent with the same data and - the corresponding pong wasn't received yet + Args: + data (Optional[Data]): payload of the ping; a string will be + encoded to UTF-8; or :obj:`None` to generate a payload + containing four random bytes. + + Returns: + ~asyncio.Future: A future that will be completed when the + corresponding pong is received. You can ignore it if you + don't intend to wait. + + :: + + pong_waiter = await ws.ping() + await pong_waiter # only if you want to wait for the pong + + Raises: + ConnectionClosed: when the connection is closed. + RuntimeError: if another ping was sent with the same data and + the corresponding pong wasn't received yet. """ await self.ensure_open() @@ -722,18 +845,22 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: async def pong(self, data: Data = b"") -> None: """ - Send a pong. + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. - The payload may be set with the optional ``data`` argument which must - be a string (which will be encoded to UTF-8) or a bytes-like object. + Canceling :meth:`pong` is discouraged. If :meth:`pong` doesn't return + immediately, it means the write buffer is full. If you don't want to + wait, you should close the connection. - Canceling :meth:`pong` is discouraged for the same reason as - :meth:`ping`. + Args: + data (Data): payload of the pong; a string will be encoded to + UTF-8. - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed + Raises: + ConnectionClosed: when the connection is closed. """ await self.ensure_open() @@ -876,7 +1003,7 @@ async def read_message(self) -> Optional[Data]: Re-assemble data frames if the message is fragmented. - Return ``None`` when the closing handshake is started. + Return :obj:`None` when the closing handshake is started. """ frame = await self.read_data_frame(max_size=self.max_size) @@ -950,7 +1077,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: Process control frames received before the next data frame. - Return ``None`` if a close frame is encountered before any data frame. + Return :obj:`None` if a close frame is encountered before any data frame. """ # 6.2. Receiving Data @@ -1224,7 +1351,8 @@ async def wait_for_connection_lost(self) -> bool: """ Wait until the TCP connection is closed or ``self.close_timeout`` elapses. - Return ``True`` if the connection is closed and ``False`` otherwise. + Return :obj:`True` if the connection is closed and :obj:`False` + otherwise. """ if not self.connection_lost_waiter.done(): @@ -1404,7 +1532,7 @@ def eof_received(self) -> None: the TCP or TLS connection after sending and receiving a close frame. As a consequence, they never need to write after receiving EOF, so - there's no reason to keep the transport open by returning ``True``. + there's no reason to keep the transport open by returning :obj:`True`. Besides, that doesn't work on TLS connections. @@ -1416,15 +1544,15 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N """ Broadcast a message to several WebSocket connections. - A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like object (:class:`bytes`, :class:`bytearray`, or - :class:`memoryview`) is sent as a `Binary frame`_. + :class:`memoryview`) is sent as a Binary_ frame. - .. _Text frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 :func:`broadcast` pushes the message synchronously to all connections even - if their write buffers overflow ``write_limit``. There's no backpressure. + if their write buffers are overflowing. There's no backpressure. :func:`broadcast` skips silently connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. @@ -1440,8 +1568,14 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N them in memory, while :func:`broadcast` buffers one copy per connection as fast as possible. - :raises RuntimeError: if a connection is busy sending a fragmented message - :raises TypeError: if ``message`` doesn't have a supported type + Args: + websockets (Iterable[WebSocketCommonProtocol]): WebSocket connections + to which the message will be sent. + message (Data): message to send. + + Raises: + RuntimeError: if a connection is busy sending a fragmented message. + TypeError: if ``message`` doesn't have a supported type. """ if not isinstance(message, (str, bytes, bytearray, memoryview)): diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index ddf6d9f87..673888c3d 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1,8 +1,3 @@ -""" -:mod:`websockets.legacy.server` defines the WebSocket server APIs. - -""" - from __future__ import annotations import asyncio @@ -65,109 +60,33 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ - :class:`~asyncio.Protocol` subclass implementing a WebSocket server. + WebSocket server connection. - :class:`WebSocketServerProtocol`: + :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send` + coroutines for receiving and sending messages. - * performs the opening handshake to establish the connection; - * provides :meth:`recv` and :meth:`send` coroutines for receiving and - sending messages; - * deals with control frames automatically; - * performs the closing handshake to terminate the connection. + It supports asynchronous iteration to receive messages:: - You may customize the opening handshake by subclassing - :class:`WebSocketServer` and overriding: + async for message in websocket: + await process(message) - * :meth:`process_request` to intercept the client request before any - processing and, if appropriate, to abort the WebSocket request and - return a HTTP response instead; - * :meth:`select_subprotocol` to select a subprotocol, if the client and - the server have multiple subprotocols in common and the default logic - for choosing one isn't suitable (this is rarely needed). + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away). It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection + is closed with any other code. - :class:`WebSocketServerProtocol` supports asynchronous iteration:: + You may customize the opening handshake in a subclass by + overriding :meth:`process_request` or :meth:`select_subprotocol`. - async for message in websocket: - await process(message) + Args: + ws_server: WebSocket server that created this connection. - The iterator yields incoming messages. It exits normally when the - connection is closed with the close code 1000 (OK) or 1001 (going away). - It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other code. - - Once the connection is open, a `Ping frame`_ is sent every - ``ping_interval`` seconds. This serves as a keepalive. It helps keeping - the connection open, especially in the presence of proxies with short - timeouts on inactive connections. Set ``ping_interval`` to ``None`` to - disable this behavior. - - .. _Ping frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 - - If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with - code 1011. This ensures that the remote endpoint remains responsive. Set - ``ping_timeout`` to ``None`` to disable this behavior. - - .. _Pong frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 - - The ``close_timeout`` parameter defines a maximum wait time for completing - the closing handshake and terminating the TCP connection. For legacy - reasons, :meth:`close` completes in at most ``4 * close_timeout`` seconds. - - ``close_timeout`` needs to be a parameter of the protocol because - websockets usually calls :meth:`close` implicitly when the connection - handler terminates. - - To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. - - The ``max_size`` parameter enforces the maximum size for incoming messages - in bytes. The default value is 1 MiB. ``None`` disables the limit. If a - message larger than the maximum size is received, :meth:`recv` will - raise :exc:`~websockets.exceptions.ConnectionClosedError` and the - connection will be closed with code 1009. - - The ``max_queue`` parameter sets the maximum length of the queue that - holds incoming messages. The default value is ``32``. ``None`` disables - the limit. Messages are added to an in-memory queue when they're received; - then :meth:`recv` pops from that queue. In order to prevent excessive - memory consumption when messages are received faster than they can be - processed, the queue must be bounded. If the queue fills up, the protocol - stops processing incoming data until :meth:`recv` is called. In this - situation, various receive buffers (at least in :mod:`asyncio` and in the - OS) will fill up, then the TCP receive window will shrink, slowing down - transmission to avoid packet loss. - - Since Python can use up to 4 bytes of memory to represent a single - character, each connection may use up to ``4 * max_size * max_queue`` - bytes of memory to store incoming messages. By default, this is 128 MiB. - You may want to lower the limits, depending on your application's - requirements. - - The ``read_limit`` argument sets the high-water limit of the buffer for - incoming bytes. The low-water limit is half the high-water limit. The - default value is 64 KiB, half of asyncio's default (based on the current - implementation of :class:`~asyncio.StreamReader`). - - The ``write_limit`` argument sets the high-water limit of the buffer for - outgoing bytes. The low-water limit is a quarter of the high-water limit. - The default value is 64 KiB, equal to asyncio's default (based on the - current implementation of ``FlowControlMixin``). - - As soon as the HTTP request and response in the opening handshake are - processed: - - * the request path is available in the :attr:`path` attribute; - * the request and response HTTP headers are available in the - :attr:`request_headers` and :attr:`response_headers` attributes, - which are :class:`~websockets.http.Headers` instances. - - If a subprotocol was negotiated, it's available in the :attr:`subprotocol` - attribute. - - Once the connection is closed, the code is available in the - :attr:`close_code` attribute and the reason in :attr:`close_reason`. - - All attributes must be treated as read-only. + See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``, + ``extensions``, ``subprotocols``, and ``extra_headers``. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. """ @@ -179,6 +98,7 @@ def __init__( ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], ws_server: WebSocketServer, *, + logger: Optional[LoggerLike] = None, origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, @@ -189,7 +109,6 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, - logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: if logger is None: @@ -339,8 +258,9 @@ async def read_http_request(self) -> Tuple[str, Headers]: If the request contains a body, it may be read from ``self.reader`` after this coroutine returns. - :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is - malformed or isn't an HTTP/1.1 GET request + Raises: + InvalidMessage: if the HTTP message is malformed or isn't an + HTTP/1.1 GET request. """ try: @@ -394,18 +314,7 @@ async def process_request( """ Intercept the HTTP request and return an HTTP response if appropriate. - If ``process_request`` returns ``None``, the WebSocket handshake - continues. If it returns 3-uple containing a status code, response - headers and a response body, that HTTP response is sent and the - connection is closed. In that case: - - * The HTTP status must be a :class:`~http.HTTPStatus`. - * HTTP headers must be a :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, - value)`` pairs. - * The HTTP response body must be :class:`bytes`. It may be empty. - - This coroutine may be overridden in a :class:`WebSocketServerProtocol` + You may override this method in a :class:`WebSocketServerProtocol` subclass, for example: * to return a HTTP 200 OK response on a given path; then a load @@ -413,19 +322,27 @@ async def process_request( * to authenticate the request and return a HTTP 401 Unauthorized or a HTTP 403 Forbidden when authentication fails. - Instead of subclassing, it is possible to override this method by - passing a ``process_request`` argument to the :func:`serve` function - or the :class:`WebSocketServerProtocol` constructor. This is - equivalent, except ``process_request`` won't have access to the + You may also override this method with the ``process_request`` + argument of :func:`serve` and :class:`WebSocketServerProtocol`. This + is equivalent, except ``process_request`` won't have access to the protocol instance, so it can't store information for later use. - ``process_request`` is expected to complete quickly. If it may run for - a long time, then it should await :meth:`wait_closed` and exit if + :meth:`process_request` is expected to complete quickly. If it may run + for a long time, then it should await :meth:`wait_closed` and exit if :meth:`wait_closed` completes, or else it could prevent the server from shutting down. - :param path: request path, including optional query string - :param request_headers: request headers + Args: + path: request path, including optional query string. + request_headers: request headers. + + Returns: + Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]: :obj:`None` + to continue the WebSocket handshake normally. + + An HTTP response, represented by a 3-uple of the response status, + headers, and body, to abort the WebSocket handshake and return + that HTTP response instead. """ if self._process_request is not None: @@ -447,10 +364,12 @@ def process_origin( """ Handle the Origin HTTP request header. - :param headers: request headers - :param origins: optional list of acceptable origins - :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't - acceptable + Args: + headers: request headers. + origins: optional list of acceptable origins. + + Raises: + InvalidOrigin: if the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -496,10 +415,12 @@ def process_extensions( Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. - :param headers: request headers - :param extensions: optional list of supported extensions - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code + Args: + headers: request headers. + extensions: optional list of supported extensions. + + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. """ response_header_value: Optional[str] = None @@ -557,10 +478,12 @@ def process_subprotocol( Return Sec-WebSocket-Protocol HTTP response header, which is the same as the selected subprotocol. - :param headers: request headers - :param available_subprotocols: optional list of supported subprotocols - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code + Args: + headers: request headers. + available_subprotocols: optional list of supported subprotocols. + + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. """ subprotocol: Optional[Subprotocol] = None @@ -588,22 +511,27 @@ def select_subprotocol( Pick a subprotocol among those offered by the client. If several subprotocols are supported by the client and the server, - the default implementation selects the preferred subprotocols by + the default implementation selects the preferred subprotocol by giving equal value to the priorities of the client and the server. - If no subprotocol is supported by the client and the server, it proceeds without a subprotocol. - This is unlikely to be the most useful implementation in practice, as - many servers providing a subprotocol will require that the client uses - that subprotocol. Such rules can be implemented in a subclass. + This is unlikely to be the most useful implementation in practice. + Many servers providing a subprotocol will require that the client + uses that subprotocol. Such rules can be implemented in a subclass. + + You may also override this method with the ``select_subprotocol`` + argument of :func:`serve` and :class:`WebSocketServerProtocol`. - Instead of subclassing, it is possible to override this method by - passing a ``select_subprotocol`` argument to the :func:`serve` - function or the :class:`WebSocketServerProtocol` constructor. + Args: + client_subprotocols: list of subprotocols offered by the client. + server_subprotocols: list of subprotocols available on the server. + + Returns: + Optional[Subprotocol]: Selected subprotocol. + + :obj:`None` to continue without a subprotocol. - :param client_subprotocols: list of subprotocols offered by the client - :param server_subprotocols: list of subprotocols available on the server """ if self._select_subprotocol is not None: @@ -629,19 +557,18 @@ async def handshake( Return the path of the URI of the request. - :param origins: list of acceptable values of the Origin HTTP header; - include ``None`` if the lack of an origin is acceptable - :param available_extensions: list of supported extensions in the order - in which they should be used - :param available_subprotocols: list of supported subprotocols in order - of decreasing preference - :param extra_headers: sets additional HTTP response headers when the - handshake succeeds; it can be a :class:`~websockets.http.Headers` - instance, a :class:`~collections.abc.Mapping`, an iterable of - ``(name, value)`` pairs, or a callable taking the request path and - headers in arguments and returning one of the above. - :raises ~websockets.exceptions.InvalidHandshake: if the handshake - fails + Args: + origins: list of acceptable values of the Origin HTTP header; + include :obj:`None` if the lack of an origin is acceptable. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of + decreasing preference. + extra_headers: arbitrary HTTP headers to add to the response when + the handshake succeeds. + + Raises: + InvalidHandshake: if the handshake fails. """ path, request_headers = await self.read_http_request() @@ -714,24 +641,24 @@ class WebSocketServer: """ WebSocket server returned by :func:`serve`. - This class provides the same interface as - :class:`~asyncio.AbstractServer`, namely the - :meth:`~asyncio.AbstractServer.close` and - :meth:`~asyncio.AbstractServer.wait_closed` methods. + This class provides the same interface as :class:`~asyncio.Server`, + notably the :meth:`~asyncio.Server.close` + and :meth:`~asyncio.Server.wait_closed` methods. It keeps track of WebSocket connections in order to close them properly when shutting down. - Instances of this class store a reference to the :class:`~asyncio.Server` - object returned by :meth:`~asyncio.loop.create_server` rather than inherit - from :class:`~asyncio.Server` in part because - :meth:`~asyncio.loop.create_server` doesn't support passing a custom - :class:`~asyncio.Server` class. + Args: + logger: logger for this server; + defaults to ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../topics/logging>` for details. """ def __init__( - self, loop: asyncio.AbstractEventLoop, logger: Optional[LoggerLike] = None + self, + loop: asyncio.AbstractEventLoop, + logger: Optional[LoggerLike] = None, ) -> None: # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop @@ -874,15 +801,26 @@ async def wait_closed(self) -> None: When :meth:`wait_closed` returns, all TCP connections are closed and all connection handlers have returned. + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + """ await asyncio.shield(self.closed_waiter) @property def sockets(self) -> Optional[List[socket.socket]]: """ - List of :class:`~socket.socket` objects the server is listening to. + List of :obj:`~socket.socket` objects the server is listening on. - ``None`` if the server is closed. + :obj:`None` if the server is closed. """ return self.server.sockets @@ -890,25 +828,21 @@ def sockets(self) -> Optional[List[socket.socket]]: class Serve: """ + Start a WebSocket server listening on ``host`` and ``port``. - Create, start, and return a WebSocket server on ``host`` and ``port``. - - Whenever a client connects, the server accepts the connection, creates a + Whenever a client connects, the server creates a :class:`WebSocketServerProtocol`, performs the opening handshake, and - delegates to the connection handler defined by ``ws_handler``. Once the - handler completes, either normally or with an exception, the server - performs the closing handshake and closes the connection. + delegates to the connection handler, ``ws_handler``. - Awaiting :func:`serve` yields a :class:`WebSocketServer`. This instance - provides :meth:`~WebSocketServer.close` and - :meth:`~WebSocketServer.wait_closed` methods for terminating the server - and cleaning up its resources. + The handler receives the :class:`WebSocketServerProtocol` and uses it to + send and receive messages. + + Once the handler completes, either normally or with an exception, the + server performs the closing handshake and closes the connection. - When a server is closed with :meth:`~WebSocketServer.close`, it closes all - connections with close code 1001 (going away). Connections handlers, which - are running the ``ws_handler`` coroutine, will receive a - :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their - current or next interaction with the WebSocket connection. + Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object + provides :meth:`~WebSocketServer.close` and + :meth:`~WebSocketServer.wait_closed` methods for shutting down the server. :func:`serve` can be used as an asynchronous context manager:: @@ -917,64 +851,61 @@ class Serve: async with serve(...): await stop - In this case, the server is shut down when exiting the context. - - :func:`serve` is a wrapper around the event loop's - :meth:`~asyncio.loop.create_server` method. It creates and starts a - :class:`asyncio.Server` with :meth:`~asyncio.loop.create_server`. Then it - wraps the :class:`asyncio.Server` in a :class:`WebSocketServer` and - returns the :class:`WebSocketServer`. - - ``ws_handler`` is the WebSocket handler. It must be a coroutine accepting - two arguments: the WebSocket connection, which is an instance of - :class:`WebSocketServerProtocol`, and the path of the request. - - The ``host`` and ``port`` arguments, as well as unrecognized keyword - arguments, are passed to :meth:`~asyncio.loop.create_server`. - - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enable TLS. - - ``create_protocol`` defaults to :class:`WebSocketServerProtocol`. It may - be replaced by a wrapper or a subclass to customize the protocol that - manages the connection. - - The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, - ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is - described in :class:`WebSocketServerProtocol`. - - :func:`serve` also accepts the following optional arguments: - - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression. - * ``origins`` defines acceptable Origin HTTP headers; include ``None`` in - the list if the lack of an origin is acceptable. - * ``extensions`` is a list of supported extensions in order of - decreasing preference. - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference. - * ``extra_headers`` sets additional HTTP response headers when the - handshake succeeds; it can be a :class:`~websockets.http.Headers` - instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, - value)`` pairs, or a callable taking the request path and headers in - arguments and returning one of the above. - * ``process_request`` allows intercepting the HTTP request; it must be a - coroutine taking the request path and headers in argument; see - :meth:`~WebSocketServerProtocol.process_request` for details. - * ``select_subprotocol`` allows customizing the logic for selecting a - subprotocol; it must be a callable taking the subprotocols offered by - the client and available on the server in argument; see - :meth:`~WebSocketServerProtocol.select_subprotocol` for details. - - Since there's no useful way to propagate exceptions triggered in handlers, - they're sent to the ``"websockets.server"`` logger instead. - Debugging is much easier if you configure logging to print them:: - - import logging - logger = logging.getLogger("websockets.server") - logger.setLevel(logging.ERROR) - logger.addHandler(logging.StreamHandler()) + The server is shut down automatically when exiting the context. + + Args: + ws_handler: connection handler. It must be a coroutine accepting + two arguments: the WebSocket connection, which is a + :class:`WebSocketServerProtocol`, and the path of the request. + host: network interfaces the server is bound to; + see :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on; + see :meth:`~asyncio.loop.create_server` for details. + create_protocol: factory for the :class:`asyncio.Protocol` managing + the connection; defaults to :class:`WebSocketServerProtocol`; may + be set to a wrapper or a subclass to customize connection handling. + logger: logger for this server; + defaults to ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../topics/logging>` for details. + compression: shortcut that enables the "permessage-deflate" extension + by default; may be set to :obj:`None` to disable compression; + see the :doc:`compression guide <../topics/compression>` for details. + origins: acceptable values of the ``Origin`` header; include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket + Hijacking attacks. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of decreasing + preference. + extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]): + arbitrary HTTP headers to add to the request; this can be + a :data:`~websockets.datastructures.HeadersLike` or a callable + taking the request path and headers in arguments and returning + a :data:`~websockets.datastructures.HeadersLike`. + process_request (Optional[Callable[[str, Headers], \ + Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): + intercept HTTP request before the opening handshake; + see :meth:`~WebSocketServerProtocol.process_request` for details. + select_subprotocol: select a subprotocol supported by the client; + see :meth:`~WebSocketServerProtocol.select_subprotocol` for details. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + Any other keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to a :obj:`~socket.socket` that you created + outside of websockets. + + Returns: + WebSocketServer: WebSocket server. """ @@ -985,13 +916,7 @@ def __init__( port: Optional[int] = None, *, create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, + logger: Optional[LoggerLike] = None, compression: Optional[str] = "deflate", origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -1003,7 +928,13 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, - logger: Optional[LoggerLike] = None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -1131,14 +1062,15 @@ def unix_serve( """ Similar to :func:`serve`, but for listening on Unix sockets. - This function calls the event loop's - :meth:`~asyncio.loop.create_unix_server` method. + This function builds upon the event + loop's :meth:`~asyncio.loop.create_unix_server` method. It is only available on Unix. It's useful for deploying a server behind a reverse proxy such as nginx. - :param path: file system path to the Unix socket + Args: + path: file system path to the Unix socket. """ return serve(ws_handler, path=path, unix=True, **kwargs) diff --git a/src/websockets/server.py b/src/websockets/server.py index 0ae0ae940..5f7bec30d 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -69,14 +69,24 @@ def __init__( def accept(self, request: Request) -> Response: """ - Create a WebSocket handshake response event to send to the client. + Create a WebSocket handshake response event to accept the connection. - If the connection cannot be established, the response rejects the - connection, which may be unexpected. + If the connection cannot be established, create a HTTP response event + to reject the handshake. + + Args: + request: handshake request event received from the client. + + Returns: + Response: handshake response event to send to the client. """ try: - key, extensions_header, protocol_header = self.process_request(request) + ( + accept_header, + extensions_header, + protocol_header, + ) = self.process_request(request) except InvalidOrigin as exc: request.exception = exc if self.debug: @@ -124,7 +134,7 @@ def accept(self, request: Request) -> Response: headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" - headers["Sec-WebSocket-Accept"] = accept_key(key) + headers["Sec-WebSocket-Accept"] = accept_header if extensions_header is not None: headers["Sec-WebSocket-Extensions"] = extensions_header @@ -141,17 +151,24 @@ def process_request( self, request: Request ) -> Tuple[str, Optional[str], Optional[str]]: """ - Check a handshake request received from the client. + Check a handshake request and negociate extensions and subprotocol. - This function doesn't verify that the request is an HTTP/1.1 or higher GET - request and doesn't perform ``Host`` and ``Origin`` checks. These controls - are usually performed earlier in the HTTP request handling code. They're + This function doesn't verify that the request is an HTTP/1.1 or higher + GET request and doesn't check the ``Host`` header. These controls are + usually performed earlier in the HTTP request handling code. They're the responsibility of the caller. - :param request: request - :returns: ``key`` which must be passed to :func:`build_response` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake request - is invalid; then the server must return 400 Bad Request error + Args: + request: WebSocket handshake request received from the client. + + Returns: + Tuple[str, Optional[str], Optional[str]]: + ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and + ``Sec-WebSocket-Protocol`` headers for the handshake response. + + Raises: + InvalidHandshake: if the handshake request is invalid; + then the server must return 400 Bad Request error. """ headers = request.headers @@ -204,21 +221,32 @@ def process_request( if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) + accept_header = accept_key(key) + self.origin = self.process_origin(headers) extensions_header, self.extensions = self.process_extensions(headers) protocol_header = self.subprotocol = self.process_subprotocol(headers) - return key, extensions_header, protocol_header + return ( + accept_header, + extensions_header, + protocol_header, + ) def process_origin(self, headers: Headers) -> Optional[Origin]: """ Handle the Origin HTTP request header. - :param headers: request headers - :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't - acceptable + Args: + headers: WebSocket handshake request headers. + + Returns: + Optional[Origin]: origin, if it is acceptable. + + Raises: + InvalidOrigin: if the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -242,9 +270,6 @@ def process_extensions( Accept or reject each extension proposed in the client request. Negotiate parameters for accepted extensions. - Return the Sec-WebSocket-Extensions HTTP response header and the list - of accepted extensions. - :rfc:`6455` leaves the rules up to the specification of each :extension. @@ -263,9 +288,15 @@ def process_extensions( Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. - :param headers: request headers - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code + Args: + headers: WebSocket handshake request headers. + + Returns: + Tuple[Optional[str], List[Extension]]: ``Sec-WebSocket-Extensions`` + HTTP response header and list of accepted extensions. + + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. """ response_header_value: Optional[str] = None @@ -317,12 +348,15 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: """ Handle the Sec-WebSocket-Protocol HTTP request header. - Return Sec-WebSocket-Protocol HTTP response header, which is the same - as the selected subprotocol. + Args: + headers: WebSocket handshake request headers. + + Returns: + Optional[Subprotocol]: Subprotocol, if one was selected; this is + also the value of the ``Sec-WebSocket-Protocol`` response header. - :param headers: request headers - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. """ subprotocol: Optional[Subprotocol] = None @@ -360,8 +394,13 @@ def select_subprotocol( many servers providing a subprotocol will require that the client uses that subprotocol. - :param client_subprotocols: list of subprotocols offered by the client - :param server_subprotocols: list of subprotocols available on the server + Args: + client_subprotocols: list of subprotocols offered by the client. + server_subprotocols: list of subprotocols available on the server. + + Returns: + Optional[Subprotocol]: Subprotocol, if a common subprotocol was + found. """ subprotocols = set(client_subprotocols) & set(server_subprotocols) @@ -380,11 +419,14 @@ def reject( exception: Optional[Exception] = None, ) -> Response: """ - Create a HTTP response event to send to the client. + Create a HTTP response event to reject the connection. A short plain text response is the best fallback when failing to establish a WebSocket connection. + Returns: + Response: HTTP handshake response to send to the client. + """ body = text.encode() if headers is None: @@ -399,7 +441,10 @@ def reject( def send_response(self, response: Response) -> None: """ - Send a WebSocket handshake response to the client. + Send a handshake response to the client. + + Args: + response: WebSocket handshake response event to send. """ if self.debug: diff --git a/src/websockets/streams.py b/src/websockets/streams.py index d1ce377e7..094cbb53a 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -7,8 +7,8 @@ class StreamReader: """ Generator-based stream reader. - This class doesn't support concurrent calls to :meth:`read_line()`, - :meth:`read_exact()`, or :meth:`read_to_eof()`. Make sure calls are + This class doesn't support concurrent calls to :meth:`read_line`, + :meth:`read_exact`, or :meth:`read_to_eof`. Make sure calls are serialized. """ @@ -21,11 +21,12 @@ def read_line(self) -> Generator[None, None, bytes]: """ Read a LF-terminated line from the stream. - The return value includes the LF character. - This is a generator-based coroutine. - :raises EOFError: if the stream ends without a LF + The return value includes the LF character. + + Raises: + EOFError: if the stream ends without a LF. """ n = 0 # number of bytes to read @@ -44,11 +45,15 @@ def read_line(self) -> Generator[None, None, bytes]: def read_exact(self, n: int) -> Generator[None, None, bytes]: """ - Read ``n`` bytes from the stream. + Read a given number of bytes from the stream. This is a generator-based coroutine. - :raises EOFError: if the stream ends in less than ``n`` bytes + Args: + n: how many bytes to read. + + Raises: + EOFError: if the stream ends in less than ``n`` bytes. """ assert n >= 0 @@ -92,11 +97,15 @@ def at_eof(self) -> Generator[None, None, bool]: def feed_data(self, data: bytes) -> None: """ - Write ``data`` to the stream. + Write data to the stream. + + :meth:`feed_data` cannot be called after :meth:`feed_eof`. - :meth:`feed_data()` cannot be called after :meth:`feed_eof()`. + Args: + data: data to write. - :raises EOFError: if the stream has ended + Raises: + EOFError: if the stream has ended. """ if self.eof: @@ -107,9 +116,10 @@ def feed_eof(self) -> None: """ End the stream. - :meth:`feed_eof()` must be called at must once. + :meth:`feed_eof` cannot be called more than once. - :raises EOFError: if the stream has ended + Raises: + EOFError: if the stream has ended. """ if self.eof: @@ -118,7 +128,7 @@ def feed_eof(self) -> None: def discard(self) -> None: """ - Discarding all buffered data, but don't end the stream. + Discard all buffered data, but don't end the stream. """ del self.buffer[:] diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 1bd118071..dadee7aba 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -17,24 +17,25 @@ # Public types used in the signature of public APIs Data = Union[str, bytes] -Data.__doc__ = """ -Types supported in a WebSocket message: +Data.__doc__ = """Types supported in a WebSocket message: +:class:`str` for a Text_ frame, :class:`bytes` for a Binary_. + +.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 +.. _Binary : https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 -- :class:`str` for text messages; -- :class:`bytes` for binary messages. """ LoggerLike = Union[logging.Logger, logging.LoggerAdapter] -LoggerLike.__doc__ = """Types accepted where :class:`~logging.Logger` is expected.""" +LoggerLike.__doc__ = """Types accepted where a :class:`~logging.Logger` is expected.""" Origin = NewType("Origin", str) -Origin.__doc__ = """Value of a Origin header.""" +Origin.__doc__ = """Value of a ``Origin`` header.""" Subprotocol = NewType("Subprotocol", str) -Subprotocol.__doc__ = """Subprotocol in a Sec-WebSocket-Protocol header.""" +Subprotocol.__doc__ = """Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" ExtensionName = NewType("ExtensionName", str) @@ -48,12 +49,12 @@ # Private types ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] -ExtensionHeader.__doc__ = """Extension in a Sec-WebSocket-Extensions header.""" +ExtensionHeader.__doc__ = """Extension in a ``Sec-WebSocket-Extensions`` header.""" ConnectionOption = NewType("ConnectionOption", str) -ConnectionOption.__doc__ = """Connection option in a Connection header.""" +ConnectionOption.__doc__ = """Connection option in a ``Connection`` header.""" UpgradeProtocol = NewType("UpgradeProtocol", str) -UpgradeProtocol.__doc__ = """Upgrade protocol in an Upgrade header.""" +UpgradeProtocol.__doc__ = """Upgrade protocol in an ``Upgrade`` header.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 397c23116..3d8f7cd95 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -1,12 +1,3 @@ -""" -:mod:`websockets.uri` parses WebSocket URIs. - -See `section 3 of RFC 6455`_. - -.. _section 3 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-3 - -""" - from __future__ import annotations import dataclasses @@ -24,14 +15,16 @@ class WebSocketURI: """ WebSocket URI. - :param bool secure: secure flag - :param str host: lower-case host - :param int port: port, always set even if it's the default - :param str resource_name: path and optional query - :param str user_info: ``(username, password)`` tuple when the URI contains - `User Information`_, else ``None``. + Attributes: + secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI. + host: Host, normalized to lower case. + port: Port, always set even if it's the default. + resource_name: Path and optional query. + user_info: ``(username, password)`` when the URI contains + `User Information`_, else :obj:`None`. .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 + """ secure: bool @@ -49,8 +42,11 @@ def parse_uri(uri: str) -> WebSocketURI: """ Parse and validate a WebSocket URI. - :raises ~websockets.exceptions.InvalidURI: if ``uri`` isn't a valid - WebSocket URI. + Args: + uri: WebSocket URI. + + Raises: + InvalidURI: if ``uri`` isn't a valid WebSocket URI. """ parsed = urllib.parse.urlparse(uri) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index c6e4b788c..c40404906 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -25,7 +25,8 @@ def accept_key(key: str) -> str: """ Compute the value of the Sec-WebSocket-Accept header. - :param key: value of the Sec-WebSocket-Key header + Args: + key: value of the Sec-WebSocket-Key header. """ sha1 = hashlib.sha1((key + GUID).encode()).digest() @@ -36,8 +37,9 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: """ Apply masking to the data of a WebSocket message. - :param data: Data to mask - :param mask: 4-bytes mask + Args: + data: data to mask. + mask: 4-bytes mask. """ if len(mask) != 4: From f635e5e6afc01e0138cd54cacade764143f5b050 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 08:53:30 +0200 Subject: [PATCH 0928/1539] Improve changlog. * Use admonitions with custom titles and classes. * Separate changes by category. --- docs/project/changelog.rst | 649 ++++++++++++++++++++++--------------- 1 file changed, 395 insertions(+), 254 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 073514ecd..9161176fc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -6,7 +6,7 @@ Changelog .. _backwards-compatibility policy: Backwards-compatibility policy -.............................. +------------------------------ websockets is intended for production use. Therefore, stability is a goal. @@ -26,45 +26,44 @@ Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. 10.0 -.... +---- *In development* -.. warning:: - - **Version 10.0 drops compatibility with Python 3.6.** - -.. note:: - - **Version 10.0 enables a timeout of 10 seconds on** - :func:`~client.connect` **by default.** +Backwards-incompatible changes +.............................. - You can adjust the timeout with the ``open_timeout`` parameter. Set it to - :obj:`None` to disable the timeout entirely. +.. admonition:: websockets 10.0 requires Python ≥ 3.7. + :class: tip -.. note:: + websockets 9.1 is the last version supporting Python 3.6. - **Version 10.0 deprecates the** ``loop`` **parameter from all APIs.** +.. admonition:: The ``loop`` parameter is deprecated from all APIs. + :class: caution This reflects a decision made in Python 3.8. See the release notes of Python 3.10 for details. -.. note:: +.. admonition:: :func:`~client.connect` times out after 10 seconds by default. + :class: note - **Version 10.0 changes arguments of** - :exc:`~exceptions.ConnectionClosed` **.** + You can adjust the timeout with the ``open_timeout`` parameter. Set it to + :obj:`None` to disable the timeout entirely. - If you raise :exc:`~exceptions.ConnectionClosed` or a subclass — rather - than catch them when websockets raises them — you must change your code. +.. admonition:: The signature of :exc:`~exceptions.ConnectionClosed` changed. + :class: note -.. note:: + If you raise :exc:`~exceptions.ConnectionClosed` or a subclass, rather + than catch them when websockets raises them, you must change your code. - **Version 10.0 adds a ``msg`` parameter to** ``InvalidURI.__init__`` **.** +.. admonition:: A ``msg`` parameter was added to :exc:`~exceptions.InvalidURI`. + :class: note - If you raise :exc:`~exceptions.InvalidURI` — rather than catch them when - websockets raises them — you must change your code. + If you raise :exc:`~exceptions.InvalidURI`, rather than catch it when + websockets raises it, you must change your code. -Also: +New features +............ * Added compatibility with Python 3.10. @@ -75,6 +74,9 @@ Also: * Added ``open_timeout`` to :func:`~client.connect`. +Improvements +............ + * Improved logging. * Provided additional information in :exc:`~exceptions.ConnectionClosed` @@ -85,88 +87,100 @@ Also: * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. -* Fixed handling of relative redirects in :func:`~client.connect`. +* Supported relative redirects in :func:`~client.connect`. * Improved API documentation. 9.1 -... +--- *May 27, 2021* -.. caution:: +Security fix +............ - **Version 9.1 fixes a security issue introduced in version 8.0.** +.. admonition:: websockets 9.1 fixes a security issue introduced in 8.0. + :class: important Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords (`CVE-2021-33880`_). .. _CVE-2021-33880: https://nvd.nist.gov/vuln/detail/CVE-2021-33880 - 9.0.2 -..... +----- *May 15, 2021* +Bug fixes +......... + * Restored compatibility of ``python -m websockets`` with Python < 3.9. * Restored compatibility with mypy. 9.0.1 -..... +----- *May 2, 2021* +Bug fixes +......... + * Fixed issues with the packaging of the 9.0 release. 9.0 -... +--- *May 1, 2021* -.. note:: +Backwards-incompatible changes +.............................. - **Version 9.0 moves or deprecates several APIs.** +.. admonition:: Several modules are moved or deprecated. + :class: caution - Aliases provide backwards compatibility for all previously public APIs. + Aliases provide compatibility for all previously public APIs according to + the `backwards-compatibility policy`_ * :class:`~datastructures.Headers` and - :exc:`~datastructures.MultipleValuesError` were moved from + :exc:`~datastructures.MultipleValuesError` are moved from ``websockets.http`` to :mod:`websockets.datastructures`. If you're using them, you should adjust the import path. * The ``client``, ``server``, ``protocol``, and ``auth`` modules were - moved from the websockets package to ``websockets.legacy`` sub-package, - as part of an upcoming refactoring. Despite the name, they're still - fully supported. The refactoring should be a transparent upgrade for - most uses when it's available. The legacy implementation will be - preserved according to the `backwards-compatibility policy`_. + moved from the ``websockets`` package to a ``websockets.legacy`` + sub-package. Despite the name, they're still fully supported. * The ``framing``, ``handshake``, ``headers``, ``http``, and ``uri`` - modules in the websockets package are deprecated. These modules provided - low-level APIs for reuse by other WebSocket implementations, but that - never happened. Keeping these APIs public makes it more difficult to - improve websockets for no actual benefit. + modules in the ``websockets`` package are deprecated. These modules + provided low-level APIs for reuse by other projects, but they didn't + reach that goal. Keeping these APIs public makes it more difficult to + improve websockets. -.. note:: + These changes pave the path for a refactoring that should be a transparent + upgrade for most uses and facilitate integration by other projects. - **Version 9.0 may require changes if you use static code analysis tools.** +.. admonition:: Convenience imports from ``websockets`` are performed lazily. + :class: note - Convenience imports from the websockets module are performed lazily. While - this is supported by Python, static code analysis tools such as mypy are + While Python supports this, static code analysis tools such as mypy are unable to understand the behavior. If you depend on such tools, use the real import path, which can be found - in the API documentation:: + in the API documentation, for example:: from websockets.client import connect from websockets.server import serve -Also: +New features +............ * Added compatibility with Python 3.9. +Improvements +............ + * Added support for IRIs in addition to URIs. * Added close codes 1012, 1013, and 1014. @@ -174,6 +188,11 @@ Also: * Raised an error when passing a :class:`dict` to :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. +* Improved error reporting. + +Bug fixes +......... + * Fixed sending fragmented, compressed messages. * Fixed ``Host`` header sent when connecting to an IPv6 address. @@ -185,98 +204,93 @@ Also: * Ensured cancellation always propagates, even on Python versions where :exc:`~asyncio.CancelledError` inherits :exc:`Exception`. -* Improved error reporting. - - 8.1 -... +--- *November 1, 2019* +New features +............ + * Added compatibility with Python 3.8. 8.0.2 -..... +----- *July 31, 2019* +Bug fixes +......... + * Restored the ability to pass a socket with the ``sock`` parameter of :func:`~server.serve`. * Removed an incorrect assertion when a connection drops. 8.0.1 -..... +----- *July 21, 2019* -* Restored the ability to import ``WebSocketProtocolError`` from websockets. +Bug fixes +......... + +* Restored the ability to import ``WebSocketProtocolError`` from + ``websockets``. 8.0 -... +--- *July 7, 2019* -.. warning:: - - **Version 8.0 drops compatibility with Python 3.4 and 3.5.** +Backwards-incompatible changes +.............................. -.. note:: +.. admonition:: websockets 8.0 requires Python ≥ 3.6. + :class: tip - **Version 8.0 expects** ``process_request`` **to be a coroutine.** + websockets 7.0 is the last version supporting Python 3.4 and 3.5. - Previously, it could be a function or a coroutine. +.. admonition:: ``process_request`` is now expected to be a coroutine. + :class: note If you're passing a ``process_request`` argument to :func:`~server.serve` or :class:`~server.WebSocketServerProtocol`, or if you're overriding :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, - define it with ``async def`` instead of ``def``. - - For backwards compatibility, functions are still mostly supported, but - mixing functions and coroutines won't work in some inheritance scenarios. + define it with ``async def`` instead of ``def``. Previously, both were supported. -.. note:: + For backwards compatibility, functions are still accepted, but mixing + functions and coroutines won't work in some inheritance scenarios. - **Version 8.0 changes the behavior of the** ``max_queue`` **parameter.** +.. admonition:: ``max_queue`` must be :obj:`None` to disable the limit. + :class: note If you were setting ``max_queue=0`` to make the queue of incoming messages unbounded, change it to ``max_queue=None``. -.. note:: - - **Version 8.0 deprecates the** ``host`` **,** ``port`` **, and** ``secure`` - **attributes of** :class:`~legacy.protocol.WebSocketCommonProtocol`. +.. admonition:: The ``host``, ``port``, and ``secure`` attributes + of :class:`~legacy.protocol.WebSocketCommonProtocol` are deprecated. + :class: note Use :attr:`~legacy.protocol.WebSocketCommonProtocol.local_address` in servers and :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` in clients instead of ``host`` and ``port``. -.. note:: - - **Version 8.0 renames the** ``WebSocketProtocolError`` **exception** - to :exc:`~exceptions.ProtocolError` **.** +.. admonition:: ``WebSocketProtocolError`` is renamed + to :exc:`~exceptions.ProtocolError`. + :class: note - A ``WebSocketProtocolError`` alias provides backwards compatibility. + An alias provides backwards compatibility. -.. note:: +.. admonition:: ``read_response()`` now returns the reason phrase. + :class: note - **Version 8.0 adds the reason phrase to the return type of the low-level - API** ``read_response()`` **.** + If you're using this low-level API, you must change your code. -Also: - -* :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, - :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support bytes-like - types :class:`bytearray` and :class:`memoryview` in addition to - :class:`bytes`. - -* Added :exc:`~exceptions.ConnectionClosedOK` and - :exc:`~exceptions.ConnectionClosedError` subclasses of - :exc:`~exceptions.ConnectionClosed` to tell apart normal connection - termination from errors. +New features +............ * Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth on the server side. @@ -288,20 +302,9 @@ Also: * Added :func:`~client.unix_connect` for connecting to Unix sockets. -* Improved support for sending fragmented messages by accepting asynchronous - iterators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. - -* Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` - exceptions in keepalive ping task. If you were using ``ping_timeout=None`` - as a workaround, you can remove it. - -* Changed :meth:`WebSocketServer.close() - ` to perform a proper closing handshake - instead of failing the connection. - -* Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. - -* Improved error messages when HTTP parsing fails. +* Added support for asynchronous generators + in :meth:`~legacy.protocol.WebSocketCommonProtocol.send` + to generate fragmented messages incrementally. * Enabled readline in the interactive client. @@ -313,30 +316,61 @@ Also: * Documented how to optimize memory usage. +Improvements +............ + +* :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, + :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support bytes-like + types :class:`bytearray` and :class:`memoryview` in addition to + :class:`bytes`. + +* Added :exc:`~exceptions.ConnectionClosedOK` and + :exc:`~exceptions.ConnectionClosedError` subclasses of + :exc:`~exceptions.ConnectionClosed` to tell apart normal connection + termination from errors. + +* Changed :meth:`WebSocketServer.close() + ` to perform a proper closing handshake + instead of failing the connection. + +* Improved error messages when HTTP parsing fails. + * Improved API documentation. +Bug fixes +......... + +* Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` + exceptions in keepalive ping task. If you were using ``ping_timeout=None`` + as a workaround, you can remove it. + +* Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. + 7.0 -... +--- *November 1, 2018* -.. warning:: +Backwards-incompatible changes +.............................. - websockets **now sends Ping frames at regular intervals and closes the - connection if it doesn't receive a matching Pong frame.** +.. admonition:: Keepalive is enabled by default. + :class: important + websockets now sends Ping frames at regular intervals and closes the + connection if it doesn't receive a matching Pong frame. See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. -.. warning:: - - **Version 7.0 changes how a server terminates connections when it's closed - with** :meth:`WebSocketServer.close() - ` **.** +.. admonition:: Termination of connections by :meth:`WebSocketServer.close() + ` changes. + :class: caution Previously, connections handlers were canceled. Now, connections are - closed with close code 1001 (going away). From the perspective of the - connection handler, this is the same as if the remote endpoint was - disconnecting. This removes the need to prepare for + closed with close code 1001 (going away). + + From the perspective of the connection handler, this is the same as if the + remote endpoint was disconnecting. This removes the need to prepare for :exc:`~asyncio.CancelledError` in connection handlers. You can restore the previous behavior by adding the following line at the @@ -346,44 +380,45 @@ Also: closed = asyncio.ensure_future(websocket.wait_closed()) closed.add_done_callback(lambda task: task.cancel()) -.. note:: +.. admonition:: Calling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` + concurrently raises a :exc:`RuntimeError`. + :class: note + + Concurrent calls lead to non-deterministic behavior because there are no + guarantees about which coroutine will receive which message. - **Version 7.0 renames the** ``timeout`` **argument of** - :func:`~server.serve` **and** :func:`~client.connect` **to** - ``close_timeout`` **.** +.. admonition:: The ``timeout`` argument of :func:`~server.serve` + and :func:`~client.connect` is renamed to ``close_timeout`` . + :class: note This prevents confusion with ``ping_timeout``. For backwards compatibility, ``timeout`` is still supported. -.. note:: - - **Version 7.0 changes how a** - :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` **that hasn't - received a pong yet behaves when the connection is closed.** +.. admonition:: The ``origins`` argument of :func:`~server.serve` changes. + :class: note - The ping — as in ``ping = await websocket.ping()`` — used to be canceled - when the connection is closed, so that ``await ping`` raised - :exc:`~asyncio.CancelledError`. Now ``await ping`` raises - :exc:`~exceptions.ConnectionClosed` like other public APIs. + Include :obj:`None` in the list rather than ``''`` to allow requests that + don't contain an Origin header. -.. note:: +.. admonition:: Pending pings aren't canceled when the connection is closed. + :class: note - **Version 7.0 raises a** :exc:`RuntimeError` **exception if two coroutines - call** :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` - **concurrently.** + A ping — as in ``ping = await websocket.ping()`` — for which no pong was + received yet used to be canceled when the connection is closed, so that + ``await ping`` raised :exc:`~asyncio.CancelledError`. - Concurrent calls lead to non-deterministic behavior because there are no - guarantees about which coroutine will receive which message. + Now ``await ping`` raises :exc:`~exceptions.ConnectionClosed` like other + public APIs. -Also: +New features +............ * Added ``process_request`` and ``select_subprotocol`` arguments to :func:`~server.serve` and - :class:`~server.WebSocketServerProtocol` to customize + :class:`~server.WebSocketServerProtocol` to facilitate customization of :meth:`~server.WebSocketServerProtocol.process_request` and - :meth:`~server.WebSocketServerProtocol.select_subprotocol` without - subclassing :class:`~server.WebSocketServerProtocol`. + :meth:`~server.WebSocketServerProtocol.select_subprotocol`. * Added support for sending fragmented messages. @@ -392,33 +427,39 @@ Also: * Added an interactive client: ``python -m websockets ``. -* Changed the ``origins`` argument to represent the lack of an origin with - :obj:`None` rather than ``''``. - -* Fixed a data loss bug in - :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: - canceling it at the wrong time could result in messages being dropped. +Improvements +............ * Improved handling of multiple HTTP headers with the same name. * Improved error messages when a required HTTP header is missing. +Bug fixes +......... + +* Fixed a data loss bug in + :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: + canceling it at the wrong time could result in messages being dropped. + 6.0 -... +--- *July 16, 2018* -.. warning:: +Backwards-incompatible changes +.............................. - **Version 6.0 introduces the** :class:`~datastructures.Headers` **class - for managing HTTP headers and changes several public APIs:** +.. admonition:: The :class:`~datastructures.Headers` class is introduced and + several APIs are updated to use it. + :class: caution - * :meth:`~server.WebSocketServerProtocol.process_request` now - receives a :class:`~datastructures.Headers` instead of a - ``http.client.HTTPMessage`` in the ``request_headers`` argument. + * The ``request_headers`` argument + of :meth:`~server.WebSocketServerProtocol.process_request` is now + a :class:`~datastructures.Headers` instead of + an ``http.client.HTTPMessage``. * The ``request_headers`` and ``response_headers`` attributes of - :class:`~legacy.protocol.WebSocketCommonProtocol` are + :class:`~legacy.protocol.WebSocketCommonProtocol` are now :class:`~datastructures.Headers` instead of ``http.client.HTTPMessage``. * The ``raw_request_headers`` and ``raw_response_headers`` attributes of @@ -435,30 +476,35 @@ Also: pairs. Since :class:`~datastructures.Headers` and ``http.client.HTTPMessage`` - provide similar APIs, this change won't affect most of the code dealing - with HTTP headers. + provide similar APIs, much of the code dealing with HTTP headers won't + require changes. - -Also: +New features +............ * Added compatibility with Python 3.7. 5.0.1 -..... +----- *May 24, 2018* +Bug fixes +......... + * Fixed a regression in 5.0 that broke some invocations of :func:`~server.serve` and :func:`~client.connect`. 5.0 -... +--- *May 22, 2018* -.. caution:: +Security fix +............ - **Version 5.0 fixes a security issue introduced in version 4.0.** +.. admonition:: websockets 5.0 fixes a security issue introduced in 4.0. + :class: important Version 4.0 was vulnerable to denial of service by memory exhaustion because it didn't enforce ``max_size`` when decompressing compressed @@ -466,46 +512,43 @@ Also: .. _CVE-2018-1000518: https://nvd.nist.gov/vuln/detail/CVE-2018-1000518 -.. note:: +Backwards-incompatible changes +.............................. - **Version 5.0 adds a** ``user_info`` **field to the return value of** - ``parse_uri`` **and** ``WebSocketURI`` **.** +.. admonition:: A ``user_info`` field is added to the return value of + ``parse_uri`` and ``WebSocketURI``. + :class: note If you're unpacking ``WebSocketURI`` into four variables, adjust your code to account for that fifth field. -Also: +New features +............ * :func:`~client.connect` performs HTTP Basic Auth when the URI contains credentials. -* Iterating on incoming messages no longer raises an exception when the - connection terminates with close code 1001 (going away). - -* A plain HTTP request now receives a 426 Upgrade Required response and - doesn't log a stack trace. - * :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property to protocols. -* If a :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` doesn't receive a - pong, it's canceled when the connection is closed. - -* Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. - * Added new examples in the documentation. -* Updated documentation with new features from Python 3.6. +Improvements +............ -* Improved several other sections of the documentation. +* Iterating on incoming messages no longer raises an exception when the + connection terminates with close code 1001 (going away). -* Fixed missing close code, which caused :exc:`TypeError` on connection close. +* A plain HTTP request now receives a 426 Upgrade Required response and + doesn't log a stack trace. -* Fixed a race condition in the closing handshake that raised - :exc:`~exceptions.InvalidState`. +* If a :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` doesn't receive a + pong, it's canceled when the connection is closed. + +* Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. * Stopped logging stack traces when the TCP connection dies prematurely. @@ -515,40 +558,59 @@ Also: * Prevented processing of incoming frames after failing the connection. +* Updated documentation with new features from Python 3.6. + +* Improved several sections of the documentation. + +Bug fixes +......... + +* Prevented :exc:`TypeError` due to missing close code on connection close. + +* Fixed a race condition in the closing handshake that raised + :exc:`~exceptions.InvalidState`. + 4.0.1 -..... +----- *November 2, 2017* +Bug fixes +......... + * Fixed issues with the packaging of the 4.0 release. 4.0 -... +--- *November 2, 2017* -.. warning:: +Backwards-incompatible changes +.............................. - **Version 4.0 drops compatibility with Python 3.3.** +.. admonition:: websockets 4.0 requires Python ≥ 3.4. + :class: tip -.. note:: + websockets 3.4 is the last version supporting Python 3.3. - **Version 4.0 enables compression with the permessage-deflate extension.** +.. admonition:: Compression is enabled by default. + :class: important - In August 2017, Firefox and Chrome support it, but not Safari and IE. + In August 2017, Firefox and Chrome support the permessage-deflate + extension, but not Safari and IE. Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling :func:`~server.serve` or :func:`~client.connect`. -.. note:: - - **Version 4.0 removes the** ``state_name`` **attribute of protocols.** +.. admonition:: The ``state_name`` attribute of protocols is deprecated. + :class: note Use ``protocol.state.name`` instead of ``protocol.state_name``. -Also: +New features +............ * :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. @@ -558,27 +620,42 @@ Also: * Added the :attr:`~server.WebSocketServer.sockets` attribute to the return value of :func:`~server.serve`. -* Reorganized and extended documentation. +* Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. -* Aborted connections if they don't close within the configured ``timeout``. +Improvements +............ -* Rewrote connection termination to increase robustness in edge cases. +* Reorganized and extended documentation. -* Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on - a connection while it's being closed. +* Rewrote connection termination to increase robustness in edge cases. * Reduced verbosity of "Failing the WebSocket connection" logs. -* Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. +Bug fixes +......... + +* Aborted connections if they don't close within the configured ``timeout``. + +* Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on + a connection while it's being closed. 3.4 -... +--- *August 20, 2017* -* Renamed :func:`~server.serve` and :func:`~client.connect`'s - ``klass`` argument to ``create_protocol`` to reflect that it can also be a - callable. For backwards compatibility, ``klass`` is still supported. +Backwards-incompatible changes +.............................. + +.. admonition:: ``InvalidStatus`` is replaced + by :class:`~exceptions.InvalidStatusCode`. + :class: note + + This exception is raised when :func:`~client.connect` receives an invalid + response status code from the server. + +New features +............ * :func:`~server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. @@ -588,57 +665,85 @@ Also: * Made read and write buffer sizes configurable. +Improvements +............ + +* Renamed :func:`~server.serve` and :func:`~client.connect`'s + ``klass`` argument to ``create_protocol`` to reflect that it can also be a + callable. For backwards compatibility, ``klass`` is still supported. + * Rewrote HTTP handling for simplicity and performance. * Added an optional C extension to speed up low-level operations. -* An invalid response status code during :func:`~client.connect` now - raises :class:`~exceptions.InvalidStatusCode`. +Bug fixes +......... * Providing a ``sock`` argument to :func:`~client.connect` no longer crashes. 3.3 -... +--- *March 29, 2017* +New features +............ + * Ensured compatibility with Python 3.6. +Improvements +............ + * Reduced noise in logs caused by connection resets. +Bug fixes +......... + * Avoided crashing on concurrent writes on slow connections. 3.2 -... +--- *August 17, 2016* +New features +............ + * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to :func:`~client.connect` and :func:`~server.serve`. +Improvements +............ + * Made server shutdown more robust. 3.1 -... +--- *April 21, 2016* -* Avoided a warning when closing a connection before the opening handshake. +New features +............ * Added flow control for incoming data. +Bug fixes +......... + +* Avoided a warning when closing a connection before the opening handshake. + 3.0 -... +--- *December 25, 2015* -.. warning:: - - **Version 3.0 introduces a backwards-incompatible change in the** - :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` **API.** +Backwards-incompatible changes +.............................. - **If you're upgrading from 2.x or earlier, please read this carefully.** +.. admonition:: :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` now + raises an exception when the connection is closed. + :class: caution :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` used to return :obj:`None` when the connection was closed. This required checking the @@ -653,61 +758,77 @@ Also: message = await websocket.recv() - When implementing a server, which is the more popular use case, there's no - strong reason to handle such exceptions. Let them bubble up, terminate the - handler coroutine, and the server will simply ignore them. + When implementing a server, there's no strong reason to handle such + exceptions. Let them bubble up, terminate the handler coroutine, and the + server will simply ignore them. In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to :func:`~server.serve`, :func:`~client.connect`, :class:`~server.WebSocketServerProtocol`, or - :class:`~client.WebSocketClientProtocol`. ``legacy_recv`` isn't - documented in their signatures but isn't scheduled for deprecation either. + :class:`~client.WebSocketClientProtocol`. -Also: +New features +............ * :func:`~client.connect` can be used as an asynchronous context manager on Python ≥ 3.5.1. -* Updated documentation with ``await`` and ``async`` syntax from Python 3.5. - * :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` and :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support data passed as :class:`str` in addition to :class:`bytes`. +* Made ``state_name`` attribute on protocols a public API. + +Improvements +............ + +* Updated documentation with ``await`` and ``async`` syntax from Python 3.5. + * Worked around an :mod:`asyncio` bug affecting connection termination under load. -* Made ``state_name`` attribute on protocols a public API. - * Improved documentation. 2.7 -... +--- *November 18, 2015* +New features +............ + * Added compatibility with Python 3.5. +Improvements +............ + * Refreshed documentation. 2.6 -... +--- *August 18, 2015* +New features +............ + * Added ``local_address`` and ``remote_address`` attributes on protocols. * Closed open connections with code 1001 when a server shuts down. +Bug fixes +......... + * Avoided TCP fragmentation of small frames. 2.5 -... +--- *July 28, 2015* -* Improved documentation. +New features +............ * Provided access to handshake request and response HTTP headers. @@ -715,49 +836,66 @@ Also: * Added support for running on a non-default event loop. -* Returned a 403 status code instead of 400 when the request Origin isn't - allowed. +Improvements +............ -* Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer - drops the next message. +* Improved documentation. + +* Sent a 403 status code instead of 400 when request Origin isn't allowed. * Clarified that the closing handshake can be initiated by the client. * Set the close code and reason more consistently. -* Strengthened connection termination by simplifying the implementation. +* Strengthened connection termination. + +Bug fixes +......... -* Improved tests, added tox configuration, and enforced 100% branch coverage. +* Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer + drops the next message. 2.4 -... +--- *January 31, 2015* +New features +............ + * Added support for subprotocols. * Added ``loop`` argument to :func:`~client.connect` and :func:`~server.serve`. 2.3 -... +--- *November 3, 2014* +Improvements +............ + * Improved compliance of close codes. 2.2 -... +--- *July 28, 2014* +New features +............ + * Added support for limiting message size. 2.1 -... +--- *April 26, 2014* +New features +............ + * Added ``host``, ``port`` and ``secure`` attributes on protocols. * Added support for providing and checking Origin_. @@ -765,36 +903,39 @@ Also: .. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 2.0 -... +--- *February 16, 2014* -.. warning:: +Backwards-incompatible changes +.............................. - **Version 2.0 introduces a backwards-incompatible change in the** - :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, +.. admonition:: :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` **APIs.** - - **If you're upgrading from 1.x or earlier, please read this carefully.** + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` are now coroutines. + :class: caution - These APIs used to be functions. Now they're coroutines. + They used to be functions. Instead of:: websocket.send(message) - you must now write:: + you must write:: await websocket.send(message) -Also: +New features +............ * Added flow control for outgoing data. 1.0 -... +--- *November 14, 2013* +New features +............ + * Initial public release. From ab0e3b9114f65158104a9cdc1b83ee3357438390 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 09:03:39 +0200 Subject: [PATCH 0929/1539] Remove loop parameter from WebSocketServer. --- docs/project/changelog.rst | 3 +++ src/websockets/legacy/server.py | 25 +++++++++++-------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9161176fc..66d8d541d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -44,6 +44,9 @@ Backwards-incompatible changes This reflects a decision made in Python 3.8. See the release notes of Python 3.10 for details. + The ``loop`` parameter is also removed + from :class:`~server.WebSocketServer`. This should be transparent. + .. admonition:: :func:`~client.connect` times out after 10 seconds by default. :class: note diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 673888c3d..4399f0782 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -655,14 +655,7 @@ class WebSocketServer: """ - def __init__( - self, - loop: asyncio.AbstractEventLoop, - logger: Optional[LoggerLike] = None, - ) -> None: - # Store a reference to loop to avoid relying on self.server._loop. - self.loop = loop - + def __init__(self, logger: Optional[LoggerLike] = None): if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger @@ -674,7 +667,7 @@ def __init__( self.close_task: Optional[asyncio.Task[None]] = None # Completed when the server is closed and connections are terminated. - self.closed_waiter: asyncio.Future[None] = loop.create_future() + self.closed_waiter: asyncio.Future[None] def wrap(self, server: asyncio.AbstractServer) -> None: """ @@ -705,6 +698,10 @@ def wrap(self, server: asyncio.AbstractServer) -> None: name = str(sock.getsockname()) self.logger.info("server listening on %s", name) + # Initialized here because we need a reference to the event loop. + # This should be moved back to __init__ in Python 3.10. + self.closed_waiter = server.get_loop().create_future() + def register(self, protocol: WebSocketServerProtocol) -> None: """ Register a connection with this server. @@ -743,7 +740,7 @@ def close(self) -> None: """ if self.close_task is None: - self.close_task = self.loop.create_task(self._close()) + self.close_task = self.server.get_loop().create_task(self._close()) async def _close(self) -> None: """ @@ -764,7 +761,7 @@ async def _close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) + await asyncio.sleep(0, **loop_if_py_lt_38(self.server.get_loop())) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING connections with a HTTP 503 @@ -777,7 +774,7 @@ async def _close(self) -> None: asyncio.create_task(websocket.close(1001)) for websocket in self.websockets ], - **loop_if_py_lt_38(self.loop), + **loop_if_py_lt_38(self.server.get_loop()), ) # Wait until all connection handlers are complete. @@ -786,7 +783,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - **loop_if_py_lt_38(self.loop), + **loop_if_py_lt_38(self.server.get_loop()), ) # Tell wait_closed() to return. @@ -968,7 +965,7 @@ def __init__( loop = _loop warnings.warn("remove loop argument", DeprecationWarning) - ws_server = WebSocketServer(logger=logger, loop=loop) + ws_server = WebSocketServer(logger=logger) secure = kwargs.get("ssl") is not None From 6f5e609b3d7644a42367946f0a8d33afcda854f4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 09:07:53 +0200 Subject: [PATCH 0930/1539] Deprecated legacy_recv. It's been more than 5 years. I would like to get rid of it at some point in the future -- 5 years from now if I apply the backwards-compatibility policy strictly :-/ --- docs/project/changelog.rst | 4 ++++ src/websockets/legacy/protocol.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 66d8d541d..8b6325a1a 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -53,6 +53,10 @@ Backwards-incompatible changes You can adjust the timeout with the ``open_timeout`` parameter. Set it to :obj:`None` to disable the timeout entirely. +.. admonition:: The ``legacy_recv`` option is deprecated. + + See the release notes of websockets 3.0 for details. + .. admonition:: The signature of :exc:`~exceptions.ConnectionClosed` changed. :class: note diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 4631151e6..618e451c2 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -183,6 +183,9 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, timeout: Optional[float] = None, ) -> None: + if legacy_recv: # pragma: no cover + warnings.warn("legacy_recv is deprecated", DeprecationWarning) + # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: timeout = 10 From 6680923c0f253ed85f885f983bb9f29efcd86b30 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 16:44:19 +0200 Subject: [PATCH 0931/1539] Fix typo. --- docs/topics/authentication.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 44c9c6151..8beeea9aa 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -120,7 +120,7 @@ WebSocket server. the old credentials, which may be expired, resulting in an HTTP 401. Then the next connection succeeds. Perhaps errors clear the cache. - When tokens are short-lived on single-use, this bug produces an + When tokens are short-lived or single-use, this bug produces an interesting effect: every other WebSocket connection fails. * Safari 14 ignores credentials entirely. From 903fd24d8cb163e7a8836f3a1faf03fa8869d969 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 17:01:04 +0200 Subject: [PATCH 0932/1539] Minor formatting fixes in FAQ. --- docs/howto/faq.rst | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index ec657bb06..a32e5ec1b 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -11,9 +11,6 @@ FAQ .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html -.. contents:: - :local: - Server side ----------- @@ -221,11 +218,13 @@ Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough to yield control. Awaiting a coroutine may yield control, but there's no guarantee that it will. -For example, ``send()`` only yields control when send buffers are full, which -never happens in most practical cases. +For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields +control when send buffers are full, which never happens in most practical +cases. -If you run a loop that contains only synchronous operations and a ``send()`` -call, you must yield control explicitly with :func:`asyncio.sleep`:: +If you run a loop that contains only synchronous operations and +a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield +control explicitly with :func:`asyncio.sleep`:: async def producer(websocket): message = generate_next_message() @@ -248,7 +247,8 @@ blocks the event loop and prevents asyncio from operating normally. This may lead to messages getting send but not received, to connection timeouts, and to unexpected results of shotgun debugging e.g. adding an -unnecessary call to ``send()`` makes the program functional. +unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +makes the program functional. Both sides ---------- @@ -352,8 +352,8 @@ See the discussion of :doc:`timeouts <../topics/timeouts>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. -How do I set a timeout on ``recv()``? -..................................... +How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? +................................................................................ Use :func:`~asyncio.wait_for`:: From b12adc59e74dc521710973894576d13d03dff869 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 17:02:20 +0200 Subject: [PATCH 0933/1539] Reformat conf.py with black. --- docs/conf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d9e3cd598..ffe61f7ba 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -74,8 +74,9 @@ # Workaround for https://github.com/sphinx-doc/sphinx/issues/9560 from sphinx.domains.python import PythonDomain -assert PythonDomain.object_types['data'].roles == ('data', 'obj') -PythonDomain.object_types['data'].roles = ('data', 'class', 'obj') + +assert PythonDomain.object_types["data"].roles == ("data", "obj") +PythonDomain.object_types["data"].roles = ("data", "class", "obj") intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} From 788f8e149c8c76570fa485a193fdb4191beff69b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 17:05:37 +0200 Subject: [PATCH 0934/1539] Add missing items to changelog. --- docs/project/changelog.rst | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8b6325a1a..84ab6a314 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -81,23 +81,47 @@ New features * Added ``open_timeout`` to :func:`~client.connect`. +* Documented how to integrate with `Django `_. + +* Documented how to deploy websockets in production, with several options. + +* Documented how to authenticate connections. + +* Documented how to broadcast messages to many connections. + Improvements ............ -* Improved logging. - -* Provided additional information in :exc:`~exceptions.ConnectionClosed` - exceptions. +* Improved logging. See the :doc:`logging guide <../topics/logging>`. * Optimized default compression settings to reduce memory usage. +* Optimized processing of client-to-server messages when the C extension isn't + available. + +* Supported relative redirects in :func:`~client.connect`. + +* Handled TCP connection drops during the opening handshake. + * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. -* Supported relative redirects in :func:`~client.connect`. +* Provided additional information in :exc:`~exceptions.ConnectionClosed` + exceptions. + +* Clarified several exceptions or log messages. + +* Restructured documentation. * Improved API documentation. +* Extended FAQ. + +Bug fixes +......... + +* Avoided a crash when receiving a ping while the connection is closing. + 9.1 --- From add0d464b9721a94195f7481aa1a8bbffeed3f98 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 19:05:31 +0200 Subject: [PATCH 0935/1539] Make the logger attribute a public API. --- docs/reference/client.rst | 2 ++ docs/reference/common.rst | 2 ++ docs/reference/server.rst | 2 ++ src/websockets/legacy/protocol.py | 3 ++- 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/reference/client.rst b/docs/reference/client.rst index eaa6cdd76..dc31dc032 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -33,6 +33,8 @@ Client .. autoattribute:: id + .. autoattribute:: logger + .. autoproperty:: local_address .. autoproperty:: remote_address diff --git a/docs/reference/common.rst b/docs/reference/common.rst index 3b9f34a57..f5422bc35 100644 --- a/docs/reference/common.rst +++ b/docs/reference/common.rst @@ -24,6 +24,8 @@ Both sides .. autoattribute:: id + .. autoattribute:: logger + .. autoproperty:: local_address .. autoproperty:: remote_address diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 0a5a060f3..1864594f5 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -50,6 +50,8 @@ Server .. autoattribute:: id + .. autoattribute:: logger + .. autoproperty:: local_address .. autoproperty:: remote_address diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 618e451c2..a31a5c7c8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -218,7 +218,8 @@ def __init__( logger = logging.getLogger("websockets.protocol") # https://github.com/python/typeshed/issues/5561 logger = cast(logging.Logger, logger) - self.logger = logging.LoggerAdapter(logger, {"websocket": self}) + self.logger: LoggerLike = logging.LoggerAdapter(logger, {"websocket": self}) + """Logger for this connection.""" # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) From 0a935b8ec16f4430ffe638cdbfbe45f6f9d7f684 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 20:53:39 +0200 Subject: [PATCH 0936/1539] Prevent unlimited reads. This can mitigate some denial of service scenarios. --- src/websockets/http11.py | 40 +++++++++++++++++------- src/websockets/streams.py | 21 +++++++++++-- tests/test_http11.py | 25 +++++++++++++-- tests/test_streams.py | 65 +++++++++++++++++++++++++++++++-------- 4 files changed, 123 insertions(+), 28 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index b82a0bfdc..052719c67 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -7,8 +7,18 @@ from . import datastructures, exceptions +# Maximum total size of headers is around 256 * 4 KiB = 1 MiB MAX_HEADERS = 256 -MAX_LINE = 4110 + +# We can use the same limit for the request line and header lines: +# "GET <4096 bytes> HTTP/1.1\r\n" = 4111 bytes +# "Set-Cookie: <4097 bytes>\r\n" = 4111 bytes +# (RFC requires 4096 bytes; for some reason Firefox supports 4097 bytes.) +MAX_LINE = 4111 + +# Support for HTTP response bodies is intended to read an error message +# returned by a server. It isn't designed to perform large file transfers. +MAX_BODY = 2 ** 20 # 1 MiB def d(value: bytes) -> str: @@ -60,7 +70,7 @@ class Request: @classmethod def parse( cls, - read_line: Callable[[], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, Request]: """ Parse a WebSocket handshake request. @@ -157,9 +167,9 @@ class Response: @classmethod def parse( cls, - read_line: Callable[[], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], - read_to_eof: Callable[[], Generator[None, None, bytes]], + read_to_eof: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, Response]: """ Parse a WebSocket handshake response. @@ -234,7 +244,16 @@ def parse( content_length = int(raw_content_length) if content_length is None: - body = yield from read_to_eof() + try: + body = yield from read_to_eof(MAX_BODY) + except RuntimeError: + raise exceptions.SecurityError( + f"body too large: over {MAX_BODY} bytes" + ) + elif content_length > MAX_BODY: + raise exceptions.SecurityError( + f"body too large: {content_length} bytes" + ) else: body = yield from read_exact(content_length) @@ -255,7 +274,7 @@ def serialize(self) -> bytes: def parse_headers( - read_line: Callable[[], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, datastructures.Headers]: """ Parse HTTP headers. @@ -306,7 +325,7 @@ def parse_headers( def parse_line( - read_line: Callable[[], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, bytes]: """ Parse a single line. @@ -322,10 +341,9 @@ def parse_line( SecurityError: if the response exceeds a security limit. """ - # Security: TODO: add a limit here - line = yield from read_line() - # Security: this guarantees header values are small (hard-coded = 4 KiB) - if len(line) > MAX_LINE: + try: + line = yield from read_line(MAX_LINE) + except RuntimeError: raise exceptions.SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 094cbb53a..f861d4bd2 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -17,7 +17,7 @@ def __init__(self) -> None: self.buffer = bytearray() self.eof = False - def read_line(self) -> Generator[None, None, bytes]: + def read_line(self, m: int) -> Generator[None, None, bytes]: """ Read a LF-terminated line from the stream. @@ -25,8 +25,12 @@ def read_line(self) -> Generator[None, None, bytes]: The return value includes the LF character. + Args: + m: maximum number bytes to read; this is a security limit. + Raises: EOFError: if the stream ends without a LF. + RuntimeError: if the stream ends in more than ``m`` bytes. """ n = 0 # number of bytes to read @@ -36,9 +40,13 @@ def read_line(self) -> Generator[None, None, bytes]: if n > 0: break p = len(self.buffer) + if p > m: + raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") if self.eof: raise EOFError(f"stream ends after {p} bytes, before end of line") yield + if n > m: + raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes") r = self.buffer[:n] del self.buffer[:n] return r @@ -66,14 +74,23 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: del self.buffer[:n] return r - def read_to_eof(self) -> Generator[None, None, bytes]: + def read_to_eof(self, m: int) -> Generator[None, None, bytes]: """ Read all bytes from the stream. This is a generator-based coroutine. + Args: + m: maximum number bytes to read; this is a security limit. + + Raises: + RuntimeError: if the stream ends in more than ``m`` bytes. + """ while not self.eof: + p = len(self.buffer) + if p > m: + raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") yield r = self.buffer[:] del self.buffer[:] diff --git a/tests/test_http11.py b/tests/test_http11.py index afd85f64a..61d377925 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -158,7 +158,8 @@ def test_parse_empty(self): with self.assertRaises(EOFError) as raised: next(self.parse()) self.assertEqual( - str(raised.exception), "connection closed while reading HTTP status line" + str(raised.exception), + "connection closed while reading HTTP status line", ) def test_parse_invalid_status_line(self): @@ -230,6 +231,24 @@ def test_parse_body_without_content_length(self): response = self.assertGeneratorReturns(gen) self.assertEqual(response.body, b"Hello world!\n") + def test_parse_body_with_content_length_too_long(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\nContent-Length: 1048577\r\n\r\n") + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "body too large: 1048577 bytes", + ) + + def test_parse_body_without_content_length_too_long(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n" + b"a" * 1048577) + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "body too large: over 1048576 bytes", + ) + def test_parse_body_with_transfer_encoding(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: @@ -314,8 +333,8 @@ def test_parse_too_long_value(self): next(self.parse_headers()) def test_parse_too_long_line(self): - # Header line contains 5 + 4104 + 2 = 4111 bytes. - self.reader.feed_data(b"foo: " + b"a" * 4104 + b"\r\n\r\n") + # Header line contains 5 + 4105 + 2 = 4112 bytes. + self.reader.feed_data(b"foo: " + b"a" * 4105 + b"\r\n\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) diff --git a/tests/test_streams.py b/tests/test_streams.py index 8abefbcc9..fd7c66a0b 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -10,24 +10,24 @@ def setUp(self): def test_read_line(self): self.reader.feed_data(b"spam\neggs\n") - gen = self.reader.read_line() + gen = self.reader.read_line(32) line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"spam\n") - gen = self.reader.read_line() + gen = self.reader.read_line(32) line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"eggs\n") def test_read_line_need_more_data(self): self.reader.feed_data(b"spa") - gen = self.reader.read_line() + gen = self.reader.read_line(32) self.assertGeneratorRunning(gen) self.reader.feed_data(b"m\neg") line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"spam\n") - gen = self.reader.read_line() + gen = self.reader.read_line(32) self.assertGeneratorRunning(gen) self.reader.feed_data(b"gs\n") line = self.assertGeneratorReturns(gen) @@ -37,11 +37,34 @@ def test_read_line_not_enough_data(self): self.reader.feed_data(b"spa") self.reader.feed_eof() - gen = self.reader.read_line() + gen = self.reader.read_line(32) with self.assertRaises(EOFError) as raised: next(gen) self.assertEqual( - str(raised.exception), "stream ends after 3 bytes, before end of line" + str(raised.exception), + "stream ends after 3 bytes, before end of line", + ) + + def test_read_line_too_long(self): + self.reader.feed_data(b"spam\neggs\n") + + gen = self.reader.read_line(2) + with self.assertRaises(RuntimeError) as raised: + next(gen) + self.assertEqual( + str(raised.exception), + "read 5 bytes, expected no more than 2 bytes", + ) + + def test_read_line_too_long_need_more_data(self): + self.reader.feed_data(b"spa") + + gen = self.reader.read_line(2) + with self.assertRaises(RuntimeError) as raised: + next(gen) + self.assertEqual( + str(raised.exception), + "read 3 bytes, expected no more than 2 bytes", ) def test_read_exact(self): @@ -78,11 +101,12 @@ def test_read_exact_not_enough_data(self): with self.assertRaises(EOFError) as raised: next(gen) self.assertEqual( - str(raised.exception), "stream ends after 3 bytes, expected 4 bytes" + str(raised.exception), + "stream ends after 3 bytes, expected 4 bytes", ) def test_read_to_eof(self): - gen = self.reader.read_to_eof() + gen = self.reader.read_to_eof(32) self.reader.feed_data(b"spam") self.assertGeneratorRunning(gen) @@ -94,10 +118,21 @@ def test_read_to_eof(self): def test_read_to_eof_at_eof(self): self.reader.feed_eof() - gen = self.reader.read_to_eof() + gen = self.reader.read_to_eof(32) data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"") + def test_read_to_eof_too_long(self): + gen = self.reader.read_to_eof(2) + + self.reader.feed_data(b"spam") + with self.assertRaises(RuntimeError) as raised: + next(gen) + self.assertEqual( + str(raised.exception), + "read 4 bytes, expected no more than 2 bytes", + ) + def test_at_eof_after_feed_data(self): gen = self.reader.at_eof() self.assertGeneratorRunning(gen) @@ -137,16 +172,22 @@ def test_feed_data_after_feed_eof(self): self.reader.feed_eof() with self.assertRaises(EOFError) as raised: self.reader.feed_data(b"spam") - self.assertEqual(str(raised.exception), "stream ended") + self.assertEqual( + str(raised.exception), + "stream ended", + ) def test_feed_eof_after_feed_eof(self): self.reader.feed_eof() with self.assertRaises(EOFError) as raised: self.reader.feed_eof() - self.assertEqual(str(raised.exception), "stream ended") + self.assertEqual( + str(raised.exception), + "stream ended", + ) def test_discard(self): - gen = self.reader.read_to_eof() + gen = self.reader.read_to_eof(32) self.reader.feed_data(b"spam") self.reader.discard() From cf070908da6f0fdb29035458a041cea7a64e9e3c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 21:48:28 +0200 Subject: [PATCH 0937/1539] Prepare removal of generator-based coroutines. --- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index e5743cc0e..57fe7e2c4 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -663,7 +663,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: else: raise SecurityError("too many redirects") - # ... = yield from connect(...) + # ... = yield from connect(...) - remove when dropping Python < 3.10 __iter__ = __await__ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4399f0782..9bda5cdd8 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1043,7 +1043,7 @@ async def __await_impl__(self) -> WebSocketServer: self.ws_server.wrap(server) return self.ws_server - # yield from serve(...) + # yield from serve(...) - remove when dropping Python < 3.10 __iter__ = __await__ From d084f4b457eeb9cf79102b8bd73e8018b769e03f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 22:15:11 +0200 Subject: [PATCH 0938/1539] Tweak constants. For aesthetic reasons. --- src/websockets/legacy/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 57fe7e2c4..ffc54453e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -584,9 +584,9 @@ def handle_redirect(self, uri: str) -> None: # async for ... in connect(...): - BACKOFF_MIN = 2.0 + BACKOFF_MIN = 1.92 BACKOFF_MAX = 60.0 - BACKOFF_FACTOR = 1.5 + BACKOFF_FACTOR = 1.618 async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: backoff_delay = self.BACKOFF_MIN From e7b0c0ff8caecee6f1f2b818940db3d8a6b87027 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 22:23:23 +0200 Subject: [PATCH 0939/1539] Add random delay before first reconnection attempt. --- src/websockets/legacy/client.py | 27 ++++++++++++++++++++------- tests/legacy/test_client_server.py | 12 +++++++++--- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index ffc54453e..63b973ecb 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -3,6 +3,7 @@ import asyncio import functools import logging +import random import urllib.parse import warnings from types import TracebackType @@ -587,6 +588,7 @@ def handle_redirect(self, uri: str) -> None: BACKOFF_MIN = 1.92 BACKOFF_MAX = 60.0 BACKOFF_FACTOR = 1.618 + BACKOFF_INITIAL = 5 async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: backoff_delay = self.BACKOFF_MIN @@ -599,15 +601,26 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: except asyncio.CancelledError: # pragma: no cover raise except Exception: - # Connection timed out - increase backoff delay + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. + if backoff_delay == self.BACKOFF_MIN: + initial_delay = random.random() * self.BACKOFF_INITIAL + self.logger.info( + "! connect failed; reconnecting in %.1f seconds", + initial_delay, + exc_info=True, + ) + await asyncio.sleep(initial_delay) + else: + self.logger.info( + "! connect failed again; retrying in %d seconds", + int(backoff_delay), + exc_info=True, + ) + await asyncio.sleep(int(backoff_delay)) + # Increase delay with truncated exponential backoff. backoff_delay = backoff_delay * self.BACKOFF_FACTOR backoff_delay = min(backoff_delay, self.BACKOFF_MAX) - self.logger.info( - "! connect failed; retrying in %d seconds", - int(backoff_delay), - exc_info=True, - ) - await asyncio.sleep(backoff_delay) continue else: # Connection succeeded - reset backoff delay diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 016f08e73..482b2cd0c 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1418,7 +1418,8 @@ async def run_client(): iteration = 0 connect_inst = connect(get_server_uri(self.server)) connect_inst.BACKOFF_MIN = 10 * MS - connect_inst.BACKOFF_MAX = 200 * MS + connect_inst.BACKOFF_MAX = 99 * MS + connect_inst.BACKOFF_INITIAL = 0 async for ws in connect_inst: await ws.send("spam") msg = await ws.recv() @@ -1468,9 +1469,14 @@ async def run_client(): [ "connection failed (503 Service Unavailable)", "connection closed", - "! connect failed; retrying in 0 seconds", + "! connect failed; reconnecting in 0.0 seconds", ] - * ((len(logs.records) - 5) // 3) + + [ + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed again; retrying in 0 seconds", + ] + * ((len(logs.records) - 8) // 3) + [ "connection open", "connection closed", From 35a7ddcdbd29e1c48216a53c7f1d0ff395595ffc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 21:45:28 +0200 Subject: [PATCH 0940/1539] Standardize documentation of dataclasses. --- src/websockets/frames.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 9a97f2530..c31796cbd 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -99,9 +99,9 @@ class Frame: """ WebSocket frame. - Args: - opcode: opcode. - data: payload data. + Attributes: + opcode: Opcode. + data: Payload data. fin: FIN bit. rsv1: RSV1 bit. rsv2: RSV2 bit. @@ -368,6 +368,10 @@ class Close: """ WebSocket close code and reason. + Attributes: + code: Close code. + reason: Close reason. + """ code: int From dcd64046b39e874cb5c2e8f9629bdacf14cfe2af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 21:45:48 +0200 Subject: [PATCH 0941/1539] Refactor WebSocketURI. --- src/websockets/uri.py | 71 ++++++++++++++++++++++++++++--------------- tests/test_uri.py | 54 ++++++++++++++++++++++++++------ 2 files changed, 91 insertions(+), 34 deletions(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 3d8f7cd95..fff0c3806 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -17,11 +17,12 @@ class WebSocketURI: Attributes: secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI. - host: Host, normalized to lower case. - port: Port, always set even if it's the default. - resource_name: Path and optional query. - user_info: ``(username, password)`` when the URI contains - `User Information`_, else :obj:`None`. + host: Normalized to lower case. + port: Always set even if it's the default. + path: May be empty. + query: May be empty if the URI doesn't include a query component. + username: Available when the URI contains `User Information`_. + password: Available when the URI contains `User Information`_. .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 @@ -30,8 +31,27 @@ class WebSocketURI: secure: bool host: str port: int - resource_name: str - user_info: Optional[Tuple[str, str]] = None + path: str + query: str + username: Optional[str] + password: Optional[str] + + @property + def resource_name(self) -> str: + if self.path: + resource_name = self.path + else: + resource_name = "/" + if self.query: + resource_name += "?" + self.query + return resource_name + + @property + def user_info(self) -> Optional[Tuple[str, str]]: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) # All characters from the gen-delims and sub-delims sets in RFC 3987. @@ -45,6 +65,9 @@ def parse_uri(uri: str) -> WebSocketURI: Args: uri: WebSocket URI. + Returns: + WebSocketURI: Parsed WebSocket URI. + Raises: InvalidURI: if ``uri`` isn't a valid WebSocket URI. @@ -60,16 +83,14 @@ def parse_uri(uri: str) -> WebSocketURI: secure = parsed.scheme == "wss" host = parsed.hostname port = parsed.port or (443 if secure else 80) - resource_name = parsed.path or "/" - if parsed.query: - resource_name += "?" + parsed.query - user_info = None - if parsed.username is not None: - # urllib.parse.urlparse accepts URLs with a username but without a - # password. This doesn't make sense for HTTP Basic Auth credentials. - if parsed.password is None: - raise exceptions.InvalidURI(uri, "username provided without password") - user_info = (parsed.username, parsed.password) + path = parsed.path + query = parsed.query + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise exceptions.InvalidURI(uri, "username provided without password") try: uri.encode("ascii") @@ -77,11 +98,11 @@ def parse_uri(uri: str) -> WebSocketURI: # Input contains non-ASCII characters. # It must be an IRI. Convert it to a URI. host = host.encode("idna").decode() - resource_name = urllib.parse.quote(resource_name, safe=DELIMS) - if user_info is not None: - user_info = ( - urllib.parse.quote(user_info[0], safe=DELIMS), - urllib.parse.quote(user_info[1], safe=DELIMS), - ) - - return WebSocketURI(secure, host, port, resource_name, user_info) + path = urllib.parse.quote(path, safe=DELIMS) + query = urllib.parse.quote(query, safe=DELIMS) + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return WebSocketURI(secure, host, port, path, query, username, password) diff --git a/tests/test_uri.py b/tests/test_uri.py index f937d2949..8acc01c18 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -7,33 +7,46 @@ VALID_URIS = [ ( "ws://localhost/", - WebSocketURI(False, "localhost", 80, "/", None), + WebSocketURI(False, "localhost", 80, "/", "", None, None), ), ( "wss://localhost/", - WebSocketURI(True, "localhost", 443, "/", None), + WebSocketURI(True, "localhost", 443, "/", "", None, None), + ), + ( + "ws://localhost", + WebSocketURI(False, "localhost", 80, "", "", None, None), ), ( "ws://localhost/path?query", - WebSocketURI(False, "localhost", 80, "/path?query", None), + WebSocketURI(False, "localhost", 80, "/path", "query", None, None), ), ( "ws://localhost/path;params", - WebSocketURI(False, "localhost", 80, "/path;params", None), + WebSocketURI(False, "localhost", 80, "/path;params", "", None, None), ), ( "WS://LOCALHOST/PATH?QUERY", - WebSocketURI(False, "localhost", 80, "/PATH?QUERY", None), + WebSocketURI(False, "localhost", 80, "/PATH", "QUERY", None, None), ), ( "ws://user:pass@localhost/", - WebSocketURI(False, "localhost", 80, "/", ("user", "pass")), + WebSocketURI(False, "localhost", 80, "/", "", "user", "pass"), + ), + ( + "ws://høst/", + WebSocketURI(False, "xn--hst-0na", 80, "/", "", None, None), ), - ("ws://høst/", WebSocketURI(False, "xn--hst-0na", 80, "/", None)), ( - "ws://üser:påss@høst/πass", + "ws://üser:påss@høst/πass?qùéry", WebSocketURI( - False, "xn--hst-0na", 80, "/%CF%80ass", ("%C3%BCser", "p%C3%A5ss") + False, + "xn--hst-0na", + 80, + "/%CF%80ass", + "q%C3%B9%C3%A9ry", + "%C3%BCser", + "p%C3%A5ss", ), ), ] @@ -46,6 +59,19 @@ "ws:///path", ] +RESOURCE_NAMES = [ + ("ws://localhost/", "/"), + ("ws://localhost", "/"), + ("ws://localhost/path?query", "/path?query"), + ("ws://høst/πass?qùéry", "/%CF%80ass?q%C3%B9%C3%A9ry"), +] + +USER_INFOS = [ + ("ws://localhost/", None), + ("ws://user:pass@localhost/", ("user", "pass")), + ("ws://üser:påss@høst/", ("%C3%BCser", "p%C3%A5ss")), +] + class URITests(unittest.TestCase): def test_success(self): @@ -58,3 +84,13 @@ def test_error(self): with self.subTest(uri=uri): with self.assertRaises(InvalidURI): parse_uri(uri) + + def test_resource_name(self): + for uri, resource_name in RESOURCE_NAMES: + with self.subTest(uri=uri): + self.assertEqual(parse_uri(uri).resource_name, resource_name) + + def test_user_info(self): + for uri, user_info in USER_INFOS: + with self.subTest(uri=uri): + self.assertEqual(parse_uri(uri).user_info, user_info) From 4a22bdf8d570ff3eba0b797a80435f75f4505e66 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 21:46:37 +0200 Subject: [PATCH 0942/1539] Make websockets.uri a public API (again!) --- docs/reference/utilities.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index dc6333847..e3e53f2d3 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -35,3 +35,12 @@ HTTP events .. automethod:: raw_items .. autoexception:: MultipleValuesError + +URIs +---- + +.. automodule:: websockets.uri + + .. autofunction:: parse_uri + + .. autoclass:: WebSocketURI From abd297b832a74fe23825aacc28bf30e8b00f122d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 21:47:41 +0200 Subject: [PATCH 0943/1539] Expect a WebSocketURI in ClientConnection. Parsing the URI to get the host and port and opening the connection before initializing a ClientConnection feels more natural. --- src/websockets/client.py | 6 +-- tests/test_client.py | 98 ++++++++++++++++++++++++---------------- 2 files changed, 62 insertions(+), 42 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 13c8a8ad0..42c7aeee9 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -32,7 +32,7 @@ Subprotocol, UpgradeProtocol, ) -from .uri import parse_uri +from .uri import WebSocketURI from .utils import accept_key, generate_key @@ -46,7 +46,7 @@ class ClientConnection(Connection): def __init__( self, - uri: str, + wsuri: WebSocketURI, origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, @@ -60,7 +60,7 @@ def __init__( max_size=max_size, logger=logger, ) - self.wsuri = parse_uri(uri) + self.wsuri = wsuri self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols diff --git a/tests/test_client.py b/tests/test_client.py index 015b93b3f..1c1452d41 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,6 +9,7 @@ from websockets.frames import OP_TEXT, Frame from websockets.http import USER_AGENT from websockets.http11 import Request, Response +from websockets.uri import parse_uri from websockets.utils import accept_key from .extensions.utils import ( @@ -24,7 +25,7 @@ class ConnectTests(unittest.TestCase): def test_send_connect(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("wss://example.com/test") + client = ClientConnection(parse_uri("wss://example.com/test")) request = client.connect() self.assertIsInstance(request, Request) client.send_request(request) @@ -44,7 +45,7 @@ def test_send_connect(self): def test_connect_request(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("wss://example.com/test") + client = ClientConnection(parse_uri("wss://example.com/test")) request = client.connect() self.assertEqual(request.path, "/test") self.assertEqual( @@ -62,7 +63,7 @@ def test_connect_request(self): ) def test_path(self): - client = ClientConnection("wss://example.com/endpoint?test=1") + client = ClientConnection(parse_uri("wss://example.com/endpoint?test=1")) request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") @@ -77,33 +78,40 @@ def test_port(self): ("wss://example.com:8443/", "example.com:8443"), ]: with self.subTest(uri=uri): - client = ClientConnection(uri) + client = ClientConnection(parse_uri(uri)) request = client.connect() self.assertEqual(request.headers["Host"], host) def test_user_info(self): - client = ClientConnection("wss://hello:iloveyou@example.com/") + client = ClientConnection(parse_uri("wss://hello:iloveyou@example.com/")) request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): - client = ClientConnection("wss://example.com/", origin="https://example.com") + client = ClientConnection( + parse_uri("wss://example.com/"), + origin="https://example.com", + ) request = client.connect() self.assertEqual(request.headers["Origin"], "https://example.com") def test_extensions(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientOpExtensionFactory()] + parse_uri("wss://example.com/"), + extensions=[ClientOpExtensionFactory()], ) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): - client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + client = ClientConnection( + parse_uri("wss://example.com/"), + subprotocols=["chat"], + ) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") @@ -112,7 +120,7 @@ def test_subprotocols(self): class AcceptRejectTests(unittest.TestCase): def test_receive_accept(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("ws://example.com/test") + client = ClientConnection(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -131,7 +139,7 @@ def test_receive_accept(self): def test_receive_reject(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("ws://example.com/test") + client = ClientConnection(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -151,7 +159,7 @@ def test_receive_reject(self): def test_accept_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("ws://example.com/test") + client = ClientConnection(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -183,7 +191,7 @@ def test_accept_response(self): def test_reject_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("ws://example.com/test") + client = ClientConnection(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -231,7 +239,7 @@ def make_accept_response(self, client): ) def test_basic(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -239,7 +247,7 @@ def test_basic(self): self.assertEqual(client.state, OPEN) def test_missing_connection(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Connection"] client.receive_data(response.serialize()) @@ -251,7 +259,7 @@ def test_missing_connection(self): self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Connection"] response.headers["Connection"] = "close" @@ -264,7 +272,7 @@ def test_invalid_connection(self): self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Upgrade"] client.receive_data(response.serialize()) @@ -276,7 +284,7 @@ def test_missing_upgrade(self): self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Upgrade"] response.headers["Upgrade"] = "h2c" @@ -289,7 +297,7 @@ def test_invalid_upgrade(self): self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_accept(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] client.receive_data(response.serialize()) @@ -301,7 +309,7 @@ def test_missing_accept(self): self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") def test_multiple_accept(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Accept"] = ACCEPT client.receive_data(response.serialize()) @@ -317,7 +325,7 @@ def test_multiple_accept(self): ) def test_invalid_accept(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] response.headers["Sec-WebSocket-Accept"] = ACCEPT @@ -332,7 +340,7 @@ def test_invalid_accept(self): ) def test_no_extensions(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -342,7 +350,8 @@ def test_no_extensions(self): def test_no_extension(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientOpExtensionFactory()] + parse_uri("wss://example.com/"), + extensions=[ClientOpExtensionFactory()], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" @@ -354,7 +363,8 @@ def test_no_extension(self): def test_extension(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientRsv2ExtensionFactory()] + parse_uri("wss://example.com/"), + extensions=[ClientRsv2ExtensionFactory()], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" @@ -365,7 +375,7 @@ def test_extension(self): self.assertEqual(client.extensions, [Rsv2Extension()]) def test_unexpected_extension(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" client.receive_data(response.serialize()) @@ -378,7 +388,8 @@ def test_unexpected_extension(self): def test_unsupported_extension(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientRsv2ExtensionFactory()] + parse_uri("wss://example.com/"), + extensions=[ClientRsv2ExtensionFactory()], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" @@ -395,7 +406,8 @@ def test_unsupported_extension(self): def test_supported_extension_parameters(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientOpExtensionFactory("this")] + parse_uri("wss://example.com/"), + extensions=[ClientOpExtensionFactory("this")], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" @@ -407,7 +419,8 @@ def test_supported_extension_parameters(self): def test_unsupported_extension_parameters(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientOpExtensionFactory("this")] + parse_uri("wss://example.com/"), + extensions=[ClientOpExtensionFactory("this")], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" @@ -424,7 +437,7 @@ def test_unsupported_extension_parameters(self): def test_multiple_supported_extension_parameters(self): client = ClientConnection( - "wss://example.com/", + parse_uri("wss://example.com/"), extensions=[ ClientOpExtensionFactory("this"), ClientOpExtensionFactory("that"), @@ -440,7 +453,7 @@ def test_multiple_supported_extension_parameters(self): def test_multiple_extensions(self): client = ClientConnection( - "wss://example.com/", + parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], ) response = self.make_accept_response(client) @@ -454,7 +467,7 @@ def test_multiple_extensions(self): def test_multiple_extensions_order(self): client = ClientConnection( - "wss://example.com/", + parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], ) response = self.make_accept_response(client) @@ -467,7 +480,7 @@ def test_multiple_extensions_order(self): self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -476,7 +489,9 @@ def test_no_subprotocols(self): self.assertIsNone(client.subprotocol) def test_no_subprotocol(self): - client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + client = ClientConnection( + parse_uri("wss://example.com/"), subprotocols=["chat"] + ) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -485,7 +500,9 @@ def test_no_subprotocol(self): self.assertIsNone(client.subprotocol) def test_subprotocol(self): - client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + client = ClientConnection( + parse_uri("wss://example.com/"), subprotocols=["chat"] + ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" client.receive_data(response.serialize()) @@ -495,7 +512,7 @@ def test_subprotocol(self): self.assertEqual(client.subprotocol, "chat") def test_unexpected_subprotocol(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" client.receive_data(response.serialize()) @@ -508,7 +525,8 @@ def test_unexpected_subprotocol(self): def test_multiple_subprotocols(self): client = ClientConnection( - "wss://example.com/", subprotocols=["superchat", "chat"] + parse_uri("wss://example.com/"), + subprotocols=["superchat", "chat"], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "superchat" @@ -525,7 +543,8 @@ def test_multiple_subprotocols(self): def test_supported_subprotocol(self): client = ClientConnection( - "wss://example.com/", subprotocols=["superchat", "chat"] + parse_uri("wss://example.com/"), + subprotocols=["superchat", "chat"], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" @@ -537,7 +556,8 @@ def test_supported_subprotocol(self): def test_unsupported_subprotocol(self): client = ClientConnection( - "wss://example.com/", subprotocols=["superchat", "chat"] + parse_uri("wss://example.com/"), + subprotocols=["superchat", "chat"], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "otherchat" @@ -552,7 +572,7 @@ def test_unsupported_subprotocol(self): class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - client = ClientConnection("ws://example.com/test", state=OPEN) + client = ClientConnection(parse_uri("ws://example.com/test"), state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) @@ -560,5 +580,5 @@ def test_bypass_handshake(self): def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ClientConnection("wss://example.com/test", logger=logger) + ClientConnection(parse_uri("wss://example.com/test"), logger=logger) self.assertEqual(len(logs.records), 1) From eba7b56eb652a1bfff2d536b5d348499a1e434d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 22:33:10 +0200 Subject: [PATCH 0944/1539] Improve docs of Frame and Close. --- docs/reference/utilities.rst | 12 ++++++++++++ src/websockets/frames.py | 4 +++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index e3e53f2d3..6b5d402fc 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -15,6 +15,18 @@ WebSocket events .. autoclass:: Opcode + .. autoattribute:: CONT + + .. autoattribute:: TEXT + + .. autoattribute:: BINARY + + .. autoattribute:: CLOSE + + .. autoattribute:: PING + + .. autoattribute:: PONG + .. autoclass:: Close HTTP events diff --git a/src/websockets/frames.py b/src/websockets/frames.py index c31796cbd..82b4a1403 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -35,6 +35,8 @@ class Opcode(enum.IntEnum): + """Opcode values for WebSocket frames.""" + CONT, TEXT, BINARY = 0x00, 0x01, 0x02 CLOSE, PING, PONG = 0x08, 0x09, 0x0A @@ -366,7 +368,7 @@ def prepare_ctrl(data: Data) -> bytes: @dataclasses.dataclass class Close: """ - WebSocket close code and reason. + Code and reason for WebSocket close frames. Attributes: code: Close code. From 5fc6fa832c11be7c42739f901b3a893285bddec0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 22:40:53 +0200 Subject: [PATCH 0945/1539] Clarify comment. Fix #1032. --- src/websockets/datastructures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 1ff586abd..5afe86931 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -88,7 +88,7 @@ def copy(self) -> Headers: return copy def serialize(self) -> bytes: - # Headers only contain ASCII characters. + # Since headers only contain ASCII characters, we can keep this simple. return str(self).encode() # Collection methods From a8eb9738244f9f3ee3d2471baa09c3010cf45afc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 8 Sep 2021 21:27:43 +0200 Subject: [PATCH 0946/1539] Avoid creating __doc__ attributes. Sphinx still finds the docstring-that-isn't-really-a-docstring. --- src/websockets/datastructures.py | 2 +- src/websockets/typing.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 5afe86931..d5c061cf8 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -165,4 +165,4 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] -HeadersLike.__doc__ = """Types accepted where :class:`Headers` is expected""" +"""Types accepted where :class:`Headers` is expected.""" diff --git a/src/websockets/typing.py b/src/websockets/typing.py index dadee7aba..e672ba006 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -17,7 +17,7 @@ # Public types used in the signature of public APIs Data = Union[str, bytes] -Data.__doc__ = """Types supported in a WebSocket message: +"""Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 @@ -27,34 +27,34 @@ LoggerLike = Union[logging.Logger, logging.LoggerAdapter] -LoggerLike.__doc__ = """Types accepted where a :class:`~logging.Logger` is expected.""" +"""Types accepted where a :class:`~logging.Logger` is expected.""" Origin = NewType("Origin", str) -Origin.__doc__ = """Value of a ``Origin`` header.""" +"""Value of a ``Origin`` header.""" Subprotocol = NewType("Subprotocol", str) -Subprotocol.__doc__ = """Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" +"""Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" ExtensionName = NewType("ExtensionName", str) -ExtensionName.__doc__ = """Name of a WebSocket extension.""" +"""Name of a WebSocket extension.""" ExtensionParameter = Tuple[str, Optional[str]] -ExtensionParameter.__doc__ = """Parameter of a WebSocket extension.""" +"""Parameter of a WebSocket extension.""" # Private types ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] -ExtensionHeader.__doc__ = """Extension in a ``Sec-WebSocket-Extensions`` header.""" +"""Extension in a ``Sec-WebSocket-Extensions`` header.""" ConnectionOption = NewType("ConnectionOption", str) -ConnectionOption.__doc__ = """Connection option in a ``Connection`` header.""" +"""Connection option in a ``Connection`` header.""" UpgradeProtocol = NewType("UpgradeProtocol", str) -UpgradeProtocol.__doc__ = """Upgrade protocol in an ``Upgrade`` header.""" +"""Upgrade protocol in an ``Upgrade`` header.""" From 0cf844146b4764c54943691ce5983596cd9e3272 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 8 Sep 2021 22:59:09 +0200 Subject: [PATCH 0947/1539] Remove unnecessary parameters from reject(). --- src/websockets/server.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index 5f7bec30d..bb91bdc77 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -99,7 +99,7 @@ def accept(self, request: Request) -> Response: request.exception = exc if self.debug: self.logger.debug("! invalid upgrade", exc_info=True) - return self.reject( + response = self.reject( http.HTTPStatus.UPGRADE_REQUIRED, ( f"Failed to open a WebSocket connection: {exc}.\n" @@ -107,8 +107,9 @@ def accept(self, request: Request) -> Response: f"You cannot access a WebSocket server directly " f"with a browser. You need a WebSocket client.\n" ), - headers=Headers([("Upgrade", "websocket")]), ) + response.headers["Upgrade"] = "websocket" + return response except InvalidHandshake as exc: request.exception = exc if self.debug: @@ -415,8 +416,6 @@ def reject( self, status: http.HTTPStatus, text: str, - headers: Optional[Headers] = None, - exception: Optional[Exception] = None, ) -> Response: """ Create a HTTP response event to reject the connection. @@ -429,13 +428,15 @@ def reject( """ body = text.encode() - if headers is None: - headers = Headers() - headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Connection", "close") - headers.setdefault("Content-Length", str(len(body))) - headers.setdefault("Content-Type", "text/plain; charset=utf-8") - headers.setdefault("Server", USER_AGENT) + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "text/plain; charset=utf-8"), + ("Server", USER_AGENT), + ] + ) self.logger.info("connection failed (%d %s)", status.value, status.phrase) return Response(status.value, status.phrase, headers, body) From 724408ea72c2aaf7598174a8e276c95738c3d996 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 15:43:52 +0200 Subject: [PATCH 0948/1539] Add Sans-I/O howto guide. --- docs/howto/index.rst | 8 + docs/howto/sansio.rst | 318 ++++++++++++++++++++++++++++++++++++++ docs/topics/data-flow.svg | 63 ++++++++ 3 files changed, 389 insertions(+) create mode 100644 docs/howto/sansio.rst create mode 100644 docs/topics/data-flow.svg diff --git a/docs/howto/index.rst b/docs/howto/index.rst index d7c83dd7a..5deb7c767 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -36,3 +36,11 @@ Once your application is ready, learn how to deploy it on various platforms. supervisor nginx haproxy + +If you're integrating the Sans-I/O layer of websockets into a library, rather +than building an application with websockets, follow this guide. + +.. toctree:: + :maxdepth: 2 + + sansio diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst new file mode 100644 index 000000000..83496bff2 --- /dev/null +++ b/docs/howto/sansio.rst @@ -0,0 +1,318 @@ +Integrate the Sans-I/O layer +============================ + +.. currentmodule:: websockets + +This guide explains how to integrate the `Sans-I/O`_ layer of websockets to +add support for WebSocket in another library. + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +As a prerequisite, you should decide how you will handle network I/O and +asynchronous control flow. + +Your integration layer will provide an API for the application on one side, +will talk to the network on the other side, and will rely on websockets to +implement the protocol in the middle. + +.. image:: ../topics/data-flow.svg + :align: center + +Opening a connection +-------------------- + +Client-side +........... + +If you're building a client, parse the URI you'd like to connect to:: + + from websockets.uri import parse_uri + + wsuri = parse_uri("ws://example.com/") + +Open a TCP connection to ``(wsuri.host, wsuri.port)`` and perform a TLS +handshake if ``wsuri.secure`` is :obj:`True`. + +Initialize a :class:`~client.ClientConnection`:: + + from websockets.client import ClientConnection + + connection = ClientConnection(wsuri) + +Create a WebSocket handshake request +with :meth:`~client.ClientConnection.connect` and send it +with :meth:`~client.ClientConnection.send_request`:: + + request = connection.connect() + connection.send_request(request) + +Then, call :meth:`~connection.Connection.data_to_send` and send its output to +the network, as described in `Send data`_ below. + +The first event returned by :meth:`~connection.Connection.events_received` is +the WebSocket handshake response. + +When the handshake fails, the reason is available in ``response.exception``:: + + if response.exception is not None: + raise response.exception + +Else, the WebSocket connection is open. + +A WebSocket client API usually performs the handshake then returns a wrapper +around the network connection and the :class:`~client.ClientConnection`. + +Server-side +........... + +If you're building a server, accept network connections from clients and +perform a TLS handshake if desired. + +For each connection, initialize a :class:`~server.ServerConnection`:: + + from websockets.server import ServerConnection + + connection = ServerConnection() + +The first event returned by :meth:`~connection.Connection.events_received` is +the WebSocket handshake request. + +Create a WebSocket handshake response +with :meth:`~server.ServerConnection.accept` and send it +with :meth:`~server.ServerConnection.send_response`:: + + response = connection.accept(request) + connection.send_response(response) + +Alternatively, you may reject the WebSocket handshake and return a HTTP +response with :meth:`~server.ServerConnection.reject`:: + + response = connection.reject(status, explanation) + connection.send_response(response) + +Then, call :meth:`~connection.Connection.data_to_send` and send its output to +the network, as described in `Send data`_ below. + +Even when you call :meth:`~server.ServerConnection.accept`, the WebSocket +handshake may fail if the request is incorrect or unsupported. + +When the handshake fails, the reason is available in ``request.exception``:: + + if request.exception is not None: + raise request.exception + +Else, the WebSocket connection is open. + +A WebSocket server API usually builds a wrapper around the network connection +and the :class:`~server.ServerConnection`. Then it invokes a connection +handler that accepts the wrapper in argument. + +It may also provide a way to close all connections and to shut down the server +gracefully. + +Going forwards, this guide focuses on handling an individual connection. + +From the network to the application +----------------------------------- + +Go through the five steps below until you reach the end of the data stream. + +Receive data +............ + +When receiving data from the network, feed it to the connection's +:meth:`~connection.Connection.receive_data` method. + +When reaching the end of the data stream, call the connection's +:meth:`~connection.Connection.receive_eof` method. + +For example, if ``sock`` is a :obj:`~socket.socket`:: + + try: + data = sock.recv(4096) + except OSError: # socket closed + data = b"" + if data: + connection.receive_data(data) + else: + connection.receive_eof() + +These methods aren't expected to raise exceptions — unless you call them again +after calling :meth:`~connection.Connection.receive_eof`, which is an error. +(If you get an exception, please file a bug!) + +Send data +......... + +Then, call :meth:`~connection.Connection.data_to_send` and send its output to +the network:: + + for data in connection.data_to_send(): + if data: + sock.sendall(data) + else: + sock.shutdown(socket.SHUT_WR) + +The empty bytestring signals the end of the data stream. When you see it, you +must half-close the TCP connection. + +Sending data right after receiving data is necessary because websockets +responds to ping frames, close frames, and incorrect inputs automatically. + +Expect TCP connection to close +.............................. + +Closing a WebSocket connection normally involves a two-way WebSocket closing +handshake. Then, regardless of whether the closure is normal or abnormal, the +server starts the four-way TCP closing handshake. If the network fails at the +wrong point, you can end up waiting until the TCP timeout, which is very long. + +To prevent dangling TCP connections when you expect the end of the data stream +but you never reach it, call :meth:`~connection.Connection.close_expected` +and, if it returns :obj:`True`, schedule closing the TCP connection after a +short timeout:: + + # start a new execution thread to run this code + sleep(10) + sock.close() # does nothing if the socket is already closed + +If the connection is still open when the timeout elapses, closing the socket +makes the execution thread that reads from the socket reach the end of the +data stream, possibly with an exception. + +Close TCP connection +.................... + +If you called :meth:`~connection.Connection.receive_eof`, close the TCP +connection now. This is a clean closure because the receive buffer is empty. + +After :meth:`~connection.Connection.receive_eof` signals the end of the read +stream, :meth:`~connection.Connection.data_to_send` always signals the end of +the write stream, unless it already ended. So, at this point, the TCP +connection is already half-closed. The only reason for closing it now is to +release resources related to the socket. + +Now you can exit the loop relaying data from the network to the application. + +Receive events +.............. + +Finally, call :meth:`~connection.Connection.events_received` to obtain events +parsed from the data provided to :meth:`~connection.Connection.receive_data`:: + + events = connection.events_received() + +The first event will be the WebSocket opening handshake request or response. +See `Opening a connection`_ above for details. + +All later events are WebSocket frames. There are two types of frames: + +* Data frames contain messages transferred over the WebSocket connections. You + should provide them to the application. See `Fragmentation`_ below for + how to reassemble messages from frames. +* Control frames provide information about the connection's state. The main + use case is to expose an abstraction over ping and pong to the application. + Keep in mind that websockets responds to ping frames and close frames + automatically. Don't duplicate this functionality! + +From the application to the network +----------------------------------- + +The connection object provides one method for each type of WebSocket frame. + +For sending a data frame: + +* :meth:`~connection.Connection.send_continuation` +* :meth:`~connection.Connection.send_text` +* :meth:`~connection.Connection.send_binary` + +These methods raise :exc:`~exceptions.ProtocolError` if you don't set +the :attr:`FIN ` bit correctly in fragmented +messages. + +For sending a control frame: + +* :meth:`~connection.Connection.send_close` +* :meth:`~connection.Connection.send_ping` +* :meth:`~connection.Connection.send_pong` + +:meth:`~connection.Connection.send_close` initiates the closing handshake. +See `Closing a connection`_ below for details. + +If you encounter an unrecoverable error and you must fail the WebSocket +connection, call :meth:`~connection.Connection.fail`. + +After any of the above, call :meth:`~connection.Connection.data_to_send` and +send its output to the network, as shown in `Send data`_ above. + +If you called :meth:`~connection.Connection.send_close` +or :meth:`~connection.Connection.fail`, you expect the end of the data +stream. You should follow the process described in `Close TCP connection`_ +above in order to prevent dangling TCP connections. + +Closing a connection +-------------------- + +Under normal circumstances, when a server wants to close the TCP connection: + +* it closes the write side; +* it reads until the end of the stream, because it expects the client to close + the read side; +* it closes the socket. + +When a client wants to close the TCP connection: + +* it reads until the end of the stream, because it expects the server to close + the read side; +* it closes the write side; +* it closes the socket. + +Applying the rules described earlier in this document gives the intended +result. As a reminder, the rules are: + +* When :meth:`~connection.Connection.data_to_send` returns the empty + bytestring, close the write side of the TCP connection. +* When you reach the end of the read stream, close the TCP connection. +* When :meth:`~connection.Connection.close_expected` returns :obj:`True`, if + you don't reach the end of the read stream quickly, close the TCP connection. + +Fragmentation +------------- + +WebSocket messages may be fragmented. Since this is a protocol-level concern, +you may choose to reassemble fragmented messages before handing them over to +the application. + +To reassemble a message, read data frames until you get a frame where +the :attr:`FIN ` bit is set, then concatenate +the payloads of all frames. + +You will never receive an inconsistent sequence of frames because websockets +raises a :exc:`~exceptions.ProtocolError` and fails the connection when this +happens. However, you may receive an incomplete sequence if the connection +drops in the middle of a fragmented message. + +Tips +---- + +Serialize operations +.................... + +The Sans-I/O layer expects to run sequentially. If your interact with it from +multiple threads or coroutines, you must ensure correct serialization. This +should happen automatically in a cooperative multitasking environment. + +However, you still have to make sure you don't break this property by +accident. For example, serialize writes to the network +when :meth:`~connection.Connection.data_to_send` returns multiple values to +prevent concurrent writes from interleaving incorrectly. + +Avoid buffers +............. + +The Sans-I/O layer doesn't do any buffering. It makes events available in +:meth:`~connection.Connection.events_received` as soon as they're received. + +You should make incoming messages available to the application immediately and +stop further processing until the application fetches them. This will usually +result in the best performance. diff --git a/docs/topics/data-flow.svg b/docs/topics/data-flow.svg new file mode 100644 index 000000000..749d9d482 --- /dev/null +++ b/docs/topics/data-flow.svg @@ -0,0 +1,63 @@ +Integration layerSans-I/O layerApplicationreceivemessagessendmessagesNetworksenddatareceivedatareceivebytessendbytessendeventsreceiveevents \ No newline at end of file From be1203b8f6903905024c8e1e3b0b6a5c4e290c99 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 17:31:54 +0200 Subject: [PATCH 0949/1539] Add API documentation for the Sans-I/O layer. --- docs/reference/client.rst | 112 +++++++++++++------ docs/reference/common.rst | 113 ++++++++++++++----- docs/reference/server.rst | 150 +++++++++++++++++--------- docs/reference/types.rst | 4 +- src/websockets/client.py | 33 +++++- src/websockets/connection.py | 204 ++++++++++++++++++++++++++--------- src/websockets/server.py | 48 +++++++-- 7 files changed, 496 insertions(+), 168 deletions(-) diff --git a/docs/reference/client.rst b/docs/reference/client.rst index dc31dc032..daf01ef58 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -3,60 +3,108 @@ Client .. automodule:: websockets.client - Opening a connection - -------------------- +asyncio +------- - .. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) - :async: +Opening a connection +.................... - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) - :async: +.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: - Using a connection - ------------------ +.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: - .. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +Using a connection +.................. - .. automethod:: recv +.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) - .. automethod:: send + .. automethod:: recv - .. automethod:: close + .. automethod:: send - .. automethod:: wait_closed + .. automethod:: close - .. automethod:: ping + .. automethod:: wait_closed - .. automethod:: pong + .. automethod:: ping - WebSocket connection objects also provide these attributes: + .. automethod:: pong - .. autoattribute:: id + WebSocket connection objects also provide these attributes: - .. autoattribute:: logger + .. autoattribute:: id - .. autoproperty:: local_address + .. autoattribute:: logger - .. autoproperty:: remote_address + .. autoproperty:: local_address - .. autoproperty:: open + .. autoproperty:: remote_address - .. autoproperty:: closed + .. autoproperty:: open - The following attributes are available after the opening handshake, - once the WebSocket connection is open: + .. autoproperty:: closed - .. autoattribute:: path + The following attributes are available after the opening handshake, + once the WebSocket connection is open: - .. autoattribute:: request_headers + .. autoattribute:: path - .. autoattribute:: response_headers + .. autoattribute:: request_headers - .. autoattribute:: subprotocol + .. autoattribute:: response_headers - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: subprotocol - .. autoproperty:: close_code + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: - .. autoproperty:: close_reason + .. autoproperty:: close_code + + .. autoproperty:: close_reason + +Sans-I/O +-------- + +.. autoclass:: ClientConnection(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: connect + + .. automethod:: send_request + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: state + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + .. autoproperty:: close_exc diff --git a/docs/reference/common.rst b/docs/reference/common.rst index f5422bc35..f2683bc77 100644 --- a/docs/reference/common.rst +++ b/docs/reference/common.rst @@ -1,53 +1,114 @@ Both sides ========== +asyncio +------- + .. automodule:: websockets.legacy.protocol - Using a connection - ------------------ +.. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + +Sans-I/O +-------- + +.. automodule:: websockets.connection + +.. autoclass:: Connection(side, state=State.OPEN, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary - .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + .. automethod:: send_close - .. automethod:: recv + .. automethod:: send_ping - .. automethod:: send + .. automethod:: send_pong - .. automethod:: close + .. automethod:: fail - .. automethod:: wait_closed + .. automethod:: events_received - .. automethod:: ping + .. automethod:: data_to_send - .. automethod:: pong + .. automethod:: close_expected - WebSocket connection objects also provide these attributes: + .. autoattribute:: id - .. autoattribute:: id + .. autoattribute:: logger - .. autoattribute:: logger + .. autoproperty:: state - .. autoproperty:: local_address + .. autoproperty:: close_code - .. autoproperty:: remote_address + .. autoproperty:: close_reason - .. autoproperty:: open + .. autoproperty:: close_exc - .. autoproperty:: closed +.. autoclass:: Side - The following attributes are available after the opening handshake, - once the WebSocket connection is open: + .. autoattribute:: SERVER - .. autoattribute:: path + .. autoattribute:: CLIENT - .. autoattribute:: request_headers +.. autoclass:: State - .. autoattribute:: response_headers + .. autoattribute:: CONNECTING - .. autoattribute:: subprotocol + .. autoattribute:: OPEN - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: CLOSING - .. autoproperty:: close_code + .. autoattribute:: CLOSED - .. autoproperty:: close_reason +.. autodata:: SEND_EOF diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 1864594f5..a8eb9c5b4 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -3,96 +3,148 @@ Server .. automodule:: websockets.server - Starting a server - ----------------- +asyncio +------- - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) - :async: +Starting a server +................. - .. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) - :async: +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: - Stopping a server - ----------------- +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: - .. autoclass:: WebSocketServer +Stopping a server +................. - .. automethod:: close +.. autoclass:: WebSocketServer - .. automethod:: wait_closed + .. automethod:: close - .. autoattribute:: sockets + .. automethod:: wait_closed - Using a connection - ------------------ + .. autoattribute:: sockets - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +Using a connection +.................. - .. automethod:: recv +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) - .. automethod:: send + .. automethod:: recv - .. automethod:: close + .. automethod:: send - .. automethod:: wait_closed + .. automethod:: close - .. automethod:: ping + .. automethod:: wait_closed - .. automethod:: pong + .. automethod:: ping - You can customize the opening handshake in a subclass by overriding these methods: + .. automethod:: pong - .. automethod:: process_request + You can customize the opening handshake in a subclass by overriding these methods: - .. automethod:: select_subprotocol + .. automethod:: process_request - WebSocket connection objects also provide these attributes: + .. automethod:: select_subprotocol - .. autoattribute:: id + WebSocket connection objects also provide these attributes: - .. autoattribute:: logger + .. autoattribute:: id - .. autoproperty:: local_address + .. autoattribute:: logger - .. autoproperty:: remote_address + .. autoproperty:: local_address - .. autoproperty:: open + .. autoproperty:: remote_address - .. autoproperty:: closed + .. autoproperty:: open - The following attributes are available after the opening handshake, - once the WebSocket connection is open: + .. autoproperty:: closed - .. autoattribute:: path + The following attributes are available after the opening handshake, + once the WebSocket connection is open: - .. autoattribute:: request_headers + .. autoattribute:: path - .. autoattribute:: response_headers + .. autoattribute:: request_headers - .. autoattribute:: subprotocol + .. autoattribute:: response_headers - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: subprotocol - .. autoproperty:: close_code + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: - .. autoproperty:: close_reason + .. autoproperty:: close_code + + .. autoproperty:: close_reason Basic authentication --------------------- +.................... .. automodule:: websockets.auth - websockets supports HTTP Basic Authentication according to - :rfc:`7235` and :rfc:`7617`. +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: basic_auth_protocol_factory + +.. autoclass:: BasicAuthWebSocketServerProtocol + + .. autoattribute:: realm + + .. autoattribute:: username + + .. automethod:: check_credentials + +.. currentmodule:: websockets.server + +Sans-I/O +-------- + +.. autoclass:: ServerConnection(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: accept + + .. automethod:: reject + + .. automethod:: send_response + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + .. autoattribute:: id - .. autofunction:: basic_auth_protocol_factory + .. autoattribute:: logger - .. autoclass:: BasicAuthWebSocketServerProtocol + .. autoproperty:: state - .. autoattribute:: realm + .. autoproperty:: close_code - .. autoattribute:: username + .. autoproperty:: close_reason - .. automethod:: check_credentials + .. autoproperty:: close_exc diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 3dab553af..4b7952553 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -1,8 +1,6 @@ Types ===== -.. autodata:: websockets.datastructures.HeadersLike - .. automodule:: websockets.typing .. autodata:: Data @@ -17,4 +15,6 @@ Types .. autodata:: ExtensionParameter +.. autodata:: websockets.connection.Event +.. autodata:: websockets.datastructures.HeadersLike diff --git a/src/websockets/client.py b/src/websockets/client.py index 42c7aeee9..6f2da5a6e 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -44,6 +44,28 @@ class ClientConnection(Connection): + """ + Sans-I/O implementation of a WebSocket client connection. + + Args: + wsuri: URI of the WebSocket server, parsed + with :func:`~websockets.uri.parse_uri`. + origin: value of the ``Origin`` header. This is useful when connecting + to a server that validates the ``Origin`` header to defend against + Cross-Site WebSocket Hijacking attacks. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of decreasing + preference. + state: initial state of the WebSocket connection. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + logger: logger for this connection; + defaults to ``logging.getLogger("websockets.client")``; + see the :doc:`logging guide <../topics/logging>` for details. + + """ + def __init__( self, wsuri: WebSocketURI, @@ -68,7 +90,14 @@ def __init__( def connect(self) -> Request: # noqa: F811 """ - Create a WebSocket handshake request event to open a connection. + Create a handshake request to open a connection. + + You must send the handshake request with :meth:`send_request`. + + You can modify it before sending it, for example to add HTTP headers. + + Returns: + Request: WebSocket handshake request event to send to the server. """ headers = Headers() @@ -275,7 +304,7 @@ def send_request(self, request: Request) -> None: Send a handshake request to the server. Args: - request: WebSocket handshake request event to send. + request: WebSocket handshake request event. """ if self.debug: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 684664860..8661a148b 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -38,12 +38,12 @@ ] Event = Union[Request, Response, Frame] - - -# A WebSocket connection is either a server or a client. +"""Events that :meth:`~Connection.events_received` may return.""" class Side(enum.IntEnum): + """A WebSocket connection is either a server or a client.""" + SERVER, CLIENT = range(2) @@ -51,10 +51,9 @@ class Side(enum.IntEnum): CLIENT = Side.CLIENT -# A WebSocket connection goes through the following four states, in order: - - class State(enum.IntEnum): + """A WebSocket connection is in one of these four states.""" + CONNECTING, OPEN, CLOSING, CLOSED = range(4) @@ -64,12 +63,26 @@ class State(enum.IntEnum): CLOSED = State.CLOSED -# Sentinel to signal that the connection should be closed. - SEND_EOF = b"" +"""Sentinel signaling that the TCP connection must be half-closed.""" class Connection: + """ + Sans-I/O implementation of a WebSocket connection. + + Args: + side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. + state: initial state of the WebSocket connection. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + logger: logger for this connection; depending on ``side``, + defaults to ``logging.getLogger("websockets.client")`` + or ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../topics/logging>` for details. + + """ + def __init__( self, side: Side, @@ -78,12 +91,14 @@ def __init__( logger: Optional[LoggerLike] = None, ) -> None: # Unique identifier. For logs. - self.id = uuid.uuid4() + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger(f"websockets.{side.name.lower()}") - self.logger = logger + self.logger: LoggerLike = logger + """Logger for this connection.""" # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) @@ -126,12 +141,12 @@ def __init__( next(self.parser) # start coroutine self.parser_exc: Optional[Exception] = None - # Public attributes - @property def state(self) -> State: """ - Connection State defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. + WebSocket connection state. + + Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. """ return self._state @@ -145,9 +160,12 @@ def state(self, state: State) -> None: @property def close_code(self) -> Optional[int]: """ - Connection Close Code defined in 7.1.5 of :rfc:`6455`. + `WebSocket close code`_. - Available once the connection is closed. + .. _WebSocket close code: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + + :obj:`None` if the connection isn't closed yet. """ if self.state is not CLOSED: @@ -160,9 +178,12 @@ def close_code(self) -> Optional[int]: @property def close_reason(self) -> Optional[str]: """ - Connection Close Reason defined in 7.1.6 of :rfc:`6455`. + `WebSocket close reason`_. + + .. _WebSocket close reason: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 - Available once the connection is closed. + :obj:`None` if the connection isn't closed yet. """ if self.state is not CLOSED: @@ -175,13 +196,20 @@ def close_reason(self) -> Optional[str]: @property def close_exc(self) -> ConnectionClosed: """ - Exception raised when trying to interact with a closed connection. + Exception to raise when trying to interact with a closed connection. + + Don't raise this exception while the connection :attr:`state` + is :attr:`~websockets.connection.State.CLOSING`; wait until + it's :attr:`~websockets.connection.State.CLOSED`. - Available once the connection is closed. If you need to raise this - exception while the connection is closing, wait until it's closed. + Indeed, the exception includes the close code and reason, which are + known only once the connection is closed. + + Raises: + AssertionError: if the connection isn't closed yet. """ - assert self.state is CLOSED + assert self.state is CLOSED, "connection isn't closed yet" exc_type: Type[ConnectionClosed] if ( self.close_rcvd is not None @@ -205,15 +233,15 @@ def close_exc(self) -> ConnectionClosed: def receive_data(self, data: bytes) -> None: """ - Receive data from the connection. + Receive data from the network. After calling this method: - - You must call :meth:`data_to_send` and send this data. - - You should call :meth:`events_received` and process these events. + - You must call :meth:`data_to_send` and send this data to the network. + - You should call :meth:`events_received` and process resulting events. Raises: - EOFError: if :meth:`receive_eof` was called before. + EOFError: if :meth:`receive_eof` was called earlier. """ self.reader.feed_data(data) @@ -221,16 +249,16 @@ def receive_data(self, data: bytes) -> None: def receive_eof(self) -> None: """ - Receive the end of the data stream from the connection. + Receive the end of the data stream from the network. After calling this method: - - You must call :meth:`data_to_send` and send this data. - - You aren't exepcted to call :meth:`events_received` as it won't - return any new events. + - You must call :meth:`data_to_send` and send this data to the network. + - You aren't expected to call :meth:`events_received`; it won't return + any new events. Raises: - EOFError: if :meth:`receive_eof` was called before. + EOFError: if :meth:`receive_eof` was called earlier. """ self.reader.feed_eof() @@ -240,7 +268,19 @@ def receive_eof(self) -> None: def send_continuation(self, data: bytes, fin: bool) -> None: """ - Send a continuation frame. + Send a `Continuation frame`_. + + .. _Continuation frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing the same kind of data + as the initial frame. + fin: FIN bit; set it to :obj:`True` if this is the last frame + of a fragmented message and to :obj:`False` otherwise. + + Raises: + ProtocolError: if a fragmented message isn't in progress. """ if not self.expect_continuation_frame: @@ -250,7 +290,18 @@ def send_continuation(self, data: bytes, fin: bool) -> None: def send_text(self, data: bytes, fin: bool = True) -> None: """ - Send a text frame. + Send a `Text frame`_. + + .. _Text frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing text encoded with UTF-8. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: if a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -260,7 +311,18 @@ def send_text(self, data: bytes, fin: bool = True) -> None: def send_binary(self, data: bytes, fin: bool = True) -> None: """ - Send a binary frame. + Send a `Binary frame`_. + + .. _Binary frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing arbitrary binary data. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: if a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -270,7 +332,18 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: def send_close(self, code: Optional[int] = None, reason: str = "") -> None: """ - Send a connection close frame. + Send a `Close frame`_. + + .. _Close frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + + Parameters: + code: close code. + reason: close reason. + + Raises: + ProtocolError: if a fragmented message is being sent, if the code + isn't valid, or if a reason is provided without a code """ if self.expect_continuation_frame: @@ -291,22 +364,43 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: def send_ping(self, data: bytes) -> None: """ - Send a ping frame. + Send a `Ping frame`_. + + .. _Ping frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + Parameters: + data: payload containing arbitrary binary data. """ self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: """ - Send a pong frame. + Send a `Pong frame`_. + + .. _Pong frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + Parameters: + data: payload containing arbitrary binary data. """ self.send_frame(Frame(OP_PONG, data)) def fail(self, code: int, reason: str = "") -> None: """ - Fail the WebSocket connection. + `Fail the WebSocket connection`_. + .. _Fail the WebSocket connection: + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 + + Parameters: + code: close code + reason: close reason + + Raises: + ProtocolError: if the code isn't valid. """ # 7.1.7. Fail the WebSocket Connection @@ -339,11 +433,14 @@ def fail(self, code: int, reason: str = "") -> None: def events_received(self) -> List[Event]: """ - Return events read from the connection. + Fetch events generated from data received from the network. - Call this method immediately after calling any of the ``receive_*()`` - methods and process the events. + Call this method immediately after any of the ``receive_*()`` methods. + Process resulting events, likely by passing them to the application. + + Returns: + List[Event]: Events read from the connection. """ events, self.events = self.events, [] return events @@ -352,13 +449,19 @@ def events_received(self) -> List[Event]: def data_to_send(self) -> List[bytes]: """ - Return data to write to the connection. + Obtain data to send to the network. + + Call this method immediately after any of the ``receive_*()``, + ``send_*()``, or :meth:`fail` methods. + + Write resulting data to the connection. - Call this method immediately after calling any of the ``receive_*()``, - ``send_*()``, or ``fail()`` methods and write the data to the + The empty bytestring :data:`~websockets.connection.SEND_EOF` signals + the end of the data stream. When you receive it, half-close the TCP connection. - The empty bytestring signals the end of the data stream. + Returns: + List[bytes]: Data to write to the connection. """ writes, self.writes = self.writes, [] @@ -366,11 +469,16 @@ def data_to_send(self) -> List[bytes]: def close_expected(self) -> bool: """ - Tell whether the TCP connection is expected to close soon. + Tell if the TCP connection is expected to close soon. + + Call this method immediately after any of the ``receive_*()`` or + :meth:`fail` methods. + + If it returns :obj:`True`, schedule closing the TCP connection after a + short timeout if the other side hasn't already closed it. - Call this method immediately after calling any of the ``receive_*()`` - or ``fail_*()`` methods and, if it returns :obj:`True`, schedule - closing the TCP connection after a short timeout. + Returns: + bool: Whether the TCP connection is expected to close soon. """ # We already got a TCP Close if and only if the state is CLOSED. diff --git a/src/websockets/server.py b/src/websockets/server.py index bb91bdc77..2d9b4f9a8 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -45,8 +45,26 @@ class ServerConnection(Connection): - - side = SERVER + """ + Sans-I/O implementation of a WebSocket server connection. + + Args: + origins: acceptable values of the ``Origin`` header; include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket + Hijacking attacks. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of decreasing + preference. + state: initial state of the WebSocket connection. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + logger: logger for this connection; + defaults to ``logging.getLogger("websockets.client")``; + see the :doc:`logging guide <../topics/logging>` for details. + + """ def __init__( self, @@ -69,16 +87,20 @@ def __init__( def accept(self, request: Request) -> Response: """ - Create a WebSocket handshake response event to accept the connection. + Create a handshake response to accept the connection. + + If the connection cannot be established, the handshake response + actually rejects the handshake. - If the connection cannot be established, create a HTTP response event - to reject the handshake. + You must send the handshake response with :meth:`send_response`. + + You can modify it before sending it, for example to add HTTP headers. Args: - request: handshake request event received from the client. + request: WebSocket handshake request event received from the client. Returns: - Response: handshake response event to send to the client. + Response: WebSocket handshake response event to send to the client. """ try: @@ -418,13 +440,21 @@ def reject( text: str, ) -> Response: """ - Create a HTTP response event to reject the connection. + Create a handshake response to reject the connection. A short plain text response is the best fallback when failing to establish a WebSocket connection. + You must send the handshake response with :meth:`send_response`. + + You can modify it before sending it, for example to alter HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; will be encoded to UTF-8. + Returns: - Response: HTTP handshake response to send to the client. + Response: WebSocket handshake response event to send to the client. """ body = text.encode() From a04bfdb8f7eaa0071f3b37efe83960763311fa6f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 07:45:58 +0200 Subject: [PATCH 0950/1539] Add changelog. --- docs/project/changelog.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 84ab6a314..736dfde41 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -54,6 +54,7 @@ Backwards-incompatible changes :obj:`None` to disable the timeout entirely. .. admonition:: The ``legacy_recv`` option is deprecated. + :class: note See the release notes of websockets 3.0 for details. @@ -72,6 +73,14 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 10.0 introduces a `Sans-I/O API + `_ for easier integration + in third-party libraries. + :class: important + + If you're integrating websockets in a library, rather than just using it, + look at the :doc:`Sans-I/O integration guide <../howto/sansio>`. + * Added compatibility with Python 3.10. * Added :func:`~websockets.broadcast` to send a message to many clients. From 13eff12bb4c995b50154fdc250281c92ddccaca0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 07:55:31 +0200 Subject: [PATCH 0951/1539] Bump version number. --- docs/conf.py | 2 +- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index ffe61f7ba..d22b85a82 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,7 +26,7 @@ author = "Aymeric Augustin" # The full version, including alpha/beta/rc tags -release = "9.1" +release = "10.0" # -- General configuration --------------------------------------------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 736dfde41..90e08728b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.0 ---- -*In development* +*September 9, 2021* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index a7901ef92..168f8b054 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "9.1" +version = "10.0" From 32d9a52a18960004780a735cabb2d881969cb6ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 20:44:18 +0200 Subject: [PATCH 0952/1539] Remove references to Python 2. RIP --- README.rst | 2 -- docs/howto/faq.rst | 9 --------- 2 files changed, 11 deletions(-) diff --git a/README.rst b/README.rst index 6b9b67672..48d2637c1 100644 --- a/README.rst +++ b/README.rst @@ -125,8 +125,6 @@ Why shouldn't I use ``websockets``? at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. -* If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which - only works on Python 3. ``websockets`` requires Python ≥ 3.7. What else? ---------- diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index a32e5ec1b..62c017a88 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -428,15 +428,6 @@ coroutines make it easier to manage control flow in concurrent code. If you prefer callback-based APIs, you should use another library. -Is there a Python 2 version? -............................ - -No, there isn't. - -Python 2 reached end of life on January 1st, 2020. - -Before that date, websockets required asyncio and therefore Python 3. - Why do I get the error: ``module 'websockets' has no attribute '...'``? ....................................................................... From c8428ced9850b0838edd185605b076b4b28ad406 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 22:01:09 +0200 Subject: [PATCH 0953/1539] Add security policy --- SECURITY.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..556217a4d --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,5 @@ +# Security policy + +Only the latest version receives security updates. + +Please report vulnerabilities [via Tidelift](https://tidelift.com/docs/security). From 9b8a8d1cb560d292aecde52252289e3560760167 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 21:40:10 +0200 Subject: [PATCH 0954/1539] Remove path argument from connection handlers. Ensure backwards-compatibility. Ref #1038. --- README.rst | 2 +- compliance/test_server.py | 2 +- docs/howto/faq.rst | 18 ++++---- docs/intro/index.rst | 12 +++--- docs/project/changelog.rst | 13 ++++++ docs/topics/authentication.rst | 8 ++-- docs/topics/broadcast.rst | 8 ++-- example/basic_auth_server.py | 2 +- example/counter.py | 2 +- example/deployment/haproxy/app.py | 2 +- example/deployment/heroku/app.py | 2 +- example/deployment/kubernetes/app.py | 2 +- example/deployment/nginx/app.py | 2 +- example/deployment/supervisor/app.py | 2 +- example/django/authentication.py | 2 +- example/django/notifications.py | 2 +- example/echo.py | 2 +- example/health_check_server.py | 2 +- example/secure_server.py | 2 +- example/server.py | 2 +- example/show_time.py | 2 +- example/shutdown_server.py | 2 +- example/unix_server.py | 2 +- experiments/authentication/app.py | 10 ++--- experiments/broadcast/server.py | 2 +- experiments/compression/server.py | 2 +- src/websockets/legacy/server.py | 64 ++++++++++++++++++++++------ tests/legacy/test_client_server.py | 39 +++++++++++------ 28 files changed, 140 insertions(+), 72 deletions(-) diff --git a/README.rst b/README.rst index 48d2637c1..99e477867 100644 --- a/README.rst +++ b/README.rst @@ -63,7 +63,7 @@ And here's an echo server: import asyncio from websockets import serve - async def echo(websocket, path): + async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/compliance/test_server.py b/compliance/test_server.py index 14ac90fe6..92f895d92 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -13,7 +13,7 @@ HOST, PORT = "127.0.0.1", 8642 -async def echo(ws, path): +async def echo(ws): async for msg in ws: await ws.send(msg) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 62c017a88..bef61bdab 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -22,12 +22,12 @@ before returning. For example, if your handler has a structure similar to:: - async def handler(websocket, path): + async def handler(websocket): asyncio.create_task(do_some_work()) change it to:: - async def handler(websocket, path): + async def handler(websocket): await do_some_work() Why does the server close the connection after processing one message? @@ -38,12 +38,12 @@ process multiple messages. For example, if your handler looks like this:: - async def handler(websocket, path): + async def handler(websocket): print(websocket.recv()) change it like this:: - async def handler(websocket, path): + async def handler(websocket): async for message in websocket: print(message) @@ -58,12 +58,12 @@ Any call that may take some time must be asynchronous. For example, if you have:: - async def handler(websocket, path): + async def handler(websocket): time.sleep(1) change it to:: - async def handler(websocket, path): + async def handler(websocket): await asyncio.sleep(1) This is part of learning asyncio. It isn't specific to websockets. @@ -82,7 +82,7 @@ You can bind additional arguments to the connection handler with import functools import websockets - async def handler(websocket, path, extra_argument): + async def handler(websocket, extra_argument): ... bound_handler = functools.partial(handler, extra_argument='spam') @@ -104,7 +104,7 @@ To access HTTP headers during the WebSocket handshake, you can override Once the connection is established, they're available in :attr:`~server.WebSocketServerProtocol.request_headers`:: - async def handler(websocket, path): + async def handler(websocket): cookies = websocket.request_headers["Cookie"] How do I get the IP address of the client connecting to my server? @@ -112,7 +112,7 @@ How do I get the IP address of the client connecting to my server? It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: - async def handler(websocket, path): + async def handler(websocket): remote_ip = websocket.remote_address[0] How do I set which IP addresses my server listens to? diff --git a/docs/intro/index.rst b/docs/intro/index.rst index c8426719c..bd7c48f81 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -131,7 +131,7 @@ Consumer For receiving messages and passing them to a ``consumer`` coroutine:: - async def consumer_handler(websocket, path): + async def consumer_handler(websocket): async for message in websocket: await consumer(message) @@ -145,7 +145,7 @@ Producer For getting messages from a ``producer`` coroutine and sending them:: - async def producer_handler(websocket, path): + async def producer_handler(websocket): while True: message = await producer() await websocket.send(message) @@ -163,11 +163,11 @@ Both sides You can read and write messages on the same connection by combining the two patterns shown above and running the two tasks in parallel:: - async def handler(websocket, path): + async def handler(websocket): consumer_task = asyncio.ensure_future( - consumer_handler(websocket, path)) + consumer_handler(websocket)) producer_task = asyncio.ensure_future( - producer_handler(websocket, path)) + producer_handler(websocket)) done, pending = await asyncio.wait( [consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED, @@ -186,7 +186,7 @@ unregister them when they disconnect. connected = set() - async def handler(websocket, path): + async def handler(websocket): # Register. connected.add(websocket) try: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 90e08728b..2a714f945 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,19 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +10.1 +---- + +*In development* + +Improvements +............ + +* Made the second parameter of connection handlers optional. It will be + deprecated in the next major release. The request path is available in + the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of + the first argument. + 10.0 ---- diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 8beeea9aa..4d702b2f6 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -185,7 +185,7 @@ connection: .. code:: python - async def first_message_handler(websocket, path): + async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) if user is None: @@ -224,7 +224,7 @@ the user. If authentication fails, it returns a HTTP 401: self.user = user - async def query_param_handler(websocket, path): + async def query_param_handler(websocket): user = websocket.user ... @@ -273,7 +273,7 @@ the user. If authentication fails, it returns a HTTP 401: self.user = user - async def cookie_handler(websocket, path): + async def cookie_handler(websocket): user = websocket.user ... @@ -311,7 +311,7 @@ the user. If authentication fails, it returns a HTTP 401: self.user = user return True - async def user_info_handler(websocket, path): + async def user_info_handler(websocket): user = websocket.user ... diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index a90cc2d70..531d8ca12 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -27,7 +27,7 @@ Integrating them is left as an exercise for the reader. You could start with:: import asyncio import websockets - async def handler(websocket, path): + async def handler(websocket): ... async def broadcast(message): @@ -64,7 +64,7 @@ Here's a connection handler that registers clients in a global variable:: CLIENTS = set() - async def handler(websocket, path): + async def handler(websocket): CLIENTS.add(websocket) try: await websocket.wait_closed() @@ -241,7 +241,7 @@ run a task that gets messages from the queue and sends them to the client:: message = await queue.get() await websocket.send(message) - async def handler(websocket, path): + async def handler(websocket): queue = asyncio.Queue() relay_task = asyncio.create_task(relay(queue, websocket)) CLIENTS.add(queue) @@ -297,7 +297,7 @@ no references left, therefore the garbage collector deletes it. The connection handler subscribes to the stream and sends messages:: - async def handler(websocket, path): + async def handler(websocket): async for message in PUBSUB: await websocket.send(message) diff --git a/example/basic_auth_server.py b/example/basic_auth_server.py index 532c5bc51..d2efeb7e5 100755 --- a/example/basic_auth_server.py +++ b/example/basic_auth_server.py @@ -5,7 +5,7 @@ import asyncio import websockets -async def hello(websocket, path): +async def hello(websocket): greeting = f"Hello {websocket.username}!" await websocket.send(greeting) diff --git a/example/counter.py b/example/counter.py index e41f6fabb..6e33b3afc 100755 --- a/example/counter.py +++ b/example/counter.py @@ -22,7 +22,7 @@ def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) -async def counter(websocket, path): +async def counter(websocket): try: # Register user USERS.add(websocket) diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py index 2b24790dd..360479b8e 100644 --- a/example/deployment/haproxy/app.py +++ b/example/deployment/haproxy/app.py @@ -7,7 +7,7 @@ import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index ff9ba2775..d4ba3edb5 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -7,7 +7,7 @@ import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py index dcc29bd1c..a8bcef688 100755 --- a/example/deployment/kubernetes/app.py +++ b/example/deployment/kubernetes/app.py @@ -9,7 +9,7 @@ import websockets -async def slow_echo(websocket, path): +async def slow_echo(websocket): async for message in websocket: # Block the event loop! This allows saturating a single asyncio # process without opening an impractical number of connections. diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py index ad42a8b3e..24e608975 100644 --- a/example/deployment/nginx/app.py +++ b/example/deployment/nginx/app.py @@ -7,7 +7,7 @@ import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index 484566bc8..bf61983ef 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -6,7 +6,7 @@ import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/django/authentication.py b/example/django/authentication.py index bbb3db02a..7f60f8275 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -10,7 +10,7 @@ from sesame.utils import get_user -async def handler(websocket, path): +async def handler(websocket): sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) if user is None: diff --git a/example/django/notifications.py b/example/django/notifications.py index 41fb719dc..7275a1ef7 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -28,7 +28,7 @@ def get_content_types(user): } -async def handler(websocket, path): +async def handler(websocket): """Authenticate user and register connection in CONNECTIONS.""" sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) diff --git a/example/echo.py b/example/echo.py index 024f8d8ac..4b673cb17 100755 --- a/example/echo.py +++ b/example/echo.py @@ -3,7 +3,7 @@ import asyncio import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/health_check_server.py b/example/health_check_server.py index 2ca185cde..7b8bded77 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -8,7 +8,7 @@ async def health_check(path, request_headers): if path == "/healthz": return http.HTTPStatus.OK, [], b"OK\n" -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/secure_server.py b/example/secure_server.py index cd8ee0cc1..f0231bc16 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -7,7 +7,7 @@ import ssl import websockets -async def hello(websocket, path): +async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") diff --git a/example/server.py b/example/server.py index 4dcf317f5..7fd7bdf4c 100755 --- a/example/server.py +++ b/example/server.py @@ -5,7 +5,7 @@ import asyncio import websockets -async def hello(websocket, path): +async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") diff --git a/example/show_time.py b/example/show_time.py index 8e39f1776..b5a153b71 100755 --- a/example/show_time.py +++ b/example/show_time.py @@ -7,7 +7,7 @@ import random import websockets -async def time(websocket, path): +async def time(websocket): while True: now = datetime.datetime.utcnow().isoformat() + "Z" await websocket.send(now) diff --git a/example/shutdown_server.py b/example/shutdown_server.py index cabba4014..1bcc9c90b 100755 --- a/example/shutdown_server.py +++ b/example/shutdown_server.py @@ -4,7 +4,7 @@ import signal import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/unix_server.py b/example/unix_server.py index 192f31bb0..335039c35 100755 --- a/example/unix_server.py +++ b/example/unix_server.py @@ -6,7 +6,7 @@ import os.path import websockets -async def hello(websocket, path): +async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index c3c9a0557..6b3b2ae3f 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -84,14 +84,14 @@ async def serve_html(path, request_headers): return http.HTTPStatus.NOT_FOUND, {}, b"Not found\n" -async def noop_handler(websocket, path): +async def noop_handler(websocket): pass # Send credentials as the first message in the WebSocket connection -async def first_message_handler(websocket, path): +async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) if user is None: @@ -119,7 +119,7 @@ async def process_request(self, path, headers): self.user = user -async def query_param_handler(websocket, path): +async def query_param_handler(websocket): user = websocket.user await websocket.send(f"Hello {user}!") @@ -149,7 +149,7 @@ async def process_request(self, path, headers): self.user = user -async def cookie_handler(websocket, path): +async def cookie_handler(websocket): user = websocket.user await websocket.send(f"Hello {user}!") @@ -173,7 +173,7 @@ async def check_credentials(self, username, password): return True -async def user_info_handler(websocket, path): +async def user_info_handler(websocket): user = websocket.user await websocket.send(f"Hello {user}!") diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index 355d6b5e9..9c9907b7f 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -45,7 +45,7 @@ async def subscribe(self): PUBSUB = PubSub() -async def handler(websocket, path, method=None): +async def handler(websocket, method=None): if method in ["default", "naive", "task", "wait"]: CLIENTS.add(websocket) try: diff --git a/experiments/compression/server.py b/experiments/compression/server.py index f7b147006..8d1ee3cd7 100644 --- a/experiments/compression/server.py +++ b/experiments/compression/server.py @@ -18,7 +18,7 @@ MEM_SIZE = [] -async def handler(ws, path): +async def handler(ws): msg = await ws.recv() await ws.send(msg) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9bda5cdd8..4de4959b9 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -4,6 +4,7 @@ import email.utils import functools import http +import inspect import logging import socket import warnings @@ -95,7 +96,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, - ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], ws_server: WebSocketServer, *, logger: Optional[LoggerLike] = None, @@ -118,7 +122,10 @@ def __init__( if origins is not None and "" in origins: warnings.warn("use None instead of '' in origins", DeprecationWarning) origins = [None if origin == "" else origin for origin in origins] - self.ws_handler = ws_handler + # For backwards compatibility with 10.0 or earlier. Done here in + # addition to serve to trigger the deprecation warning on direct + # use of WebSocketServerProtocol. + self.ws_handler = remove_path_argument(ws_handler) self.ws_server = ws_server self.origins = origins self.available_extensions = extensions @@ -152,7 +159,7 @@ async def handler(self) -> None: try: try: - path = await self.handshake( + await self.handshake( origins=self.origins, available_extensions=self.available_extensions, available_subprotocols=self.available_subprotocols, @@ -221,7 +228,7 @@ async def handler(self) -> None: return try: - await self.ws_handler(self, path) + await self.ws_handler(self) except Exception: self.logger.error("connection handler failed", exc_info=True) if not self.closed: @@ -555,8 +562,6 @@ async def handshake( """ Perform the server side of the opening handshake. - Return the path of the URI of the request. - Args: origins: list of acceptable values of the Origin HTTP header; include :obj:`None` if the lack of an origin is acceptable. @@ -567,6 +572,9 @@ async def handshake( extra_headers: arbitrary HTTP headers to add to the response when the handshake succeeds. + Returns: + str: path of the URI of the request. + Raises: InvalidHandshake: if the handshake fails. @@ -851,9 +859,8 @@ class Serve: The server is shut down automatically when exiting the context. Args: - ws_handler: connection handler. It must be a coroutine accepting - two arguments: the WebSocket connection, which is a - :class:`WebSocketServerProtocol`, and the path of the request. + ws_handler: connection handler. It receives the WebSocket connection, + which is a :class:`WebSocketServerProtocol`, in argument. host: network interfaces the server is bound to; see :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on; @@ -908,7 +915,10 @@ class Serve: def __init__( self, - ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], host: Optional[Union[str, Sequence[str]]] = None, port: Optional[int] = None, *, @@ -979,7 +989,10 @@ def __init__( factory = functools.partial( create_protocol, - ws_handler, + # For backwards compatibility with 10.0 or earlier. Done here in + # addition to WebSocketServerProtocol to trigger the deprecation + # warning once per serve() call rather than once per connection. + remove_path_argument(ws_handler), ws_server, host=host, port=port, @@ -1052,7 +1065,10 @@ async def __await_impl__(self) -> WebSocketServer: def unix_serve( - ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], path: Optional[str] = None, **kwargs: Any, ) -> Serve: @@ -1071,3 +1087,27 @@ def unix_serve( """ return serve(ws_handler, path=path, unix=True, **kwargs) + + +def remove_path_argument( + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ] +) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: + if len(inspect.signature(ws_handler).parameters) == 2: + # Enable deprecation warning and announce deprecation in 11.0. + # warnings.warn("remove second argument of ws_handler", DeprecationWarning) + + async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: + return await cast( + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler, + )(websocket, websocket.path) + + return _ws_handler + + return cast( + Callable[[WebSocketServerProtocol], Awaitable[Any]], + ws_handler, + ) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 482b2cd0c..fea2b4178 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -53,22 +53,22 @@ testcert = bytes(pathlib.Path(__file__).parent.with_name("test_localhost.pem")) -async def default_handler(ws, path): - if path == "/deprecated_attributes": +async def default_handler(ws): + if ws.path == "/deprecated_attributes": await ws.recv() # delay that allows catching warnings await ws.send(repr((ws.host, ws.port, ws.secure))) - elif path == "/close_timeout": + elif ws.path == "/close_timeout": await ws.send(repr(ws.close_timeout)) - elif path == "/path": + elif ws.path == "/path": await ws.send(str(ws.path)) - elif path == "/headers": + elif ws.path == "/headers": await ws.send(repr(ws.request_headers)) await ws.send(repr(ws.response_headers)) - elif path == "/extensions": + elif ws.path == "/extensions": await ws.send(repr(ws.extensions)) - elif path == "/subprotocol": + elif ws.path == "/subprotocol": await ws.send(repr(ws.subprotocol)) - elif path == "/slow_stop": + elif ws.path == "/slow_stop": await ws.wait_closed() await asyncio.sleep(2 * MS) else: @@ -476,6 +476,21 @@ def test_unix_socket(self): finally: self.stop_server() + def test_ws_handler_argument_backwards_compatibility(self): + async def handler_with_path(ws, path): + await ws.send(path) + + with self.temp_server( + handler=handler_with_path, + # Enable deprecation warning and announce deprecation in 11.0. + # deprecation_warnings=["remove second argument of ws_handler"], + ): + with self.temp_client("/path"): + self.assertEqual( + self.loop.run_until_complete(self.client.recv()), + "/path", + ) + async def process_request_OK(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" @@ -1336,7 +1351,7 @@ class AsyncIteratorTests(ClientServerTestsMixin, AsyncioTestCase): MESSAGES = ["3", "2", "1", "Fire!"] - async def echo_handler(ws, path): + async def echo_handler(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) @@ -1354,7 +1369,7 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - async def echo_handler_1001(ws, path): + async def echo_handler_1001(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) await ws.close(1001) @@ -1373,7 +1388,7 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - async def echo_handler_1011(ws, path): + async def echo_handler_1011(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) await ws.close(1011) @@ -1395,7 +1410,7 @@ async def run_client(): class ReconnectionTests(ClientServerTestsMixin, AsyncioTestCase): - async def echo_handler(ws, path): + async def echo_handler(ws): async for msg in ws: await ws.send(msg) From c439f1d52aafc05064cc11702d1c3014046799b0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 22:30:05 +0200 Subject: [PATCH 0955/1539] Mirror full asyncio.Server API in WebSocketServer. --- docs/project/changelog.rst | 3 ++ docs/reference/server.rst | 8 +++++ src/websockets/legacy/server.py | 59 +++++++++++++++++++++++++-------- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 2a714f945..a2c9991d5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,9 @@ Improvements the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of the first argument. +* Mirrored the entire :class:`~asyncio.Server` API + in :class:`~server.WebSocketServer`. + 10.0 ---- diff --git a/docs/reference/server.rst b/docs/reference/server.rst index a8eb9c5b4..97bf320b6 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -24,6 +24,14 @@ Stopping a server .. automethod:: wait_closed + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + .. autoattribute:: sockets Using a connection diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4de4959b9..98712ff86 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -724,13 +724,6 @@ def unregister(self, protocol: WebSocketServerProtocol) -> None: """ self.websockets.remove(protocol) - def is_serving(self) -> bool: - """ - Tell whether the server is accepting new connections or shutting down. - - """ - return self.server.is_serving() - def close(self) -> None: """ Close the server. @@ -748,7 +741,7 @@ def close(self) -> None: """ if self.close_task is None: - self.close_task = self.server.get_loop().create_task(self._close()) + self.close_task = self.get_loop().create_task(self._close()) async def _close(self) -> None: """ @@ -769,7 +762,7 @@ async def _close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0, **loop_if_py_lt_38(self.server.get_loop())) + await asyncio.sleep(0, **loop_if_py_lt_38(self.get_loop())) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING connections with a HTTP 503 @@ -782,7 +775,7 @@ async def _close(self) -> None: asyncio.create_task(websocket.close(1001)) for websocket in self.websockets ], - **loop_if_py_lt_38(self.server.get_loop()), + **loop_if_py_lt_38(self.get_loop()), ) # Wait until all connection handlers are complete. @@ -791,7 +784,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - **loop_if_py_lt_38(self.server.get_loop()), + **loop_if_py_lt_38(self.get_loop()), ) # Tell wait_closed() to return. @@ -820,16 +813,54 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.closed_waiter) + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: + """ + See :meth:`asyncio.Server.start_serving`. + + """ + await self.server.start_serving() # pragma: no cover + + async def serve_forever(self) -> None: + """ + See :meth:`asyncio.Server.serve_forever`. + + """ + await self.server.serve_forever() # pragma: no cover + @property def sockets(self) -> Optional[List[socket.socket]]: """ - List of :obj:`~socket.socket` objects the server is listening on. - - :obj:`None` if the server is closed. + See :attr:`asyncio.Server.sockets`. """ return self.server.sockets + async def __aenter__(self) -> WebSocketServer: + return self # pragma: no cover + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() # pragma: no cover + await self.wait_closed() # pragma: no cover + class Serve: """ From 1b16b57ba94c00e20ab08e5596a98511bca95070 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 22:34:19 +0200 Subject: [PATCH 0956/1539] Remove obsolete workaround. --- src/websockets/legacy/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 63b973ecb..595b13a73 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -651,8 +651,6 @@ async def __await_impl_timeout__(self) -> WebSocketClientProtocol: async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): transport, protocol = await self._create_connection() - # https://github.com/python/typeshed/pull/2756 - transport = cast(asyncio.Transport, transport) protocol = cast(WebSocketClientProtocol, protocol) try: From 541f95cdab5d4dd953fc9428c47421b7168ae0b1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 22:48:52 +0200 Subject: [PATCH 0957/1539] Move build_host to headers module. This is where similar functions are defined. --- src/websockets/client.py | 3 ++- src/websockets/headers.py | 25 +++++++++++++++++++++++++ src/websockets/http.py | 26 +------------------------- src/websockets/legacy/client.py | 3 ++- tests/test_headers.py | 22 ++++++++++++++++++++++ tests/test_http.py | 28 +--------------------------- 6 files changed, 53 insertions(+), 54 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 6f2da5a6e..34732b3a6 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -16,13 +16,14 @@ from .headers import ( build_authorization_basic, build_extension, + build_host, build_subprotocol, parse_connection, parse_extension, parse_subprotocol, parse_upgrade, ) -from .http import USER_AGENT, build_host +from .http import USER_AGENT from .http11 import Request, Response from .typing import ( ConnectionOption, diff --git a/src/websockets/headers.py b/src/websockets/headers.py index a2fdfdd30..9ae3035a5 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -2,6 +2,7 @@ import base64 import binascii +import ipaddress import re from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast @@ -17,6 +18,7 @@ __all__ = [ + "build_host", "parse_connection", "parse_upgrade", "parse_extension", @@ -33,6 +35,29 @@ T = TypeVar("T") +def build_host(host: str, port: int, secure: bool) -> str: + """ + Build a ``Host`` header. + + """ + # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 + # IPv6 addresses must be enclosed in brackets. + try: + address = ipaddress.ip_address(host) + except ValueError: + # host is a hostname + pass + else: + # host is an IP address + if address.version == 6: + host = f"[{host}]" + + if port != (443 if secure else 80): + host = f"{host}:{port}" + + return host + + # To avoid a dependency on a parsing library, we implement manually the ABNF # described in https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 and # https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. diff --git a/src/websockets/http.py b/src/websockets/http.py index 38848b56d..b14fa94bd 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,6 +1,5 @@ from __future__ import annotations -import ipaddress import sys from .imports import lazy_import @@ -24,31 +23,8 @@ ) -__all__ = ["USER_AGENT", "build_host"] +__all__ = ["USER_AGENT"] PYTHON_VERSION = "{}.{}".format(*sys.version_info) USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}" - - -def build_host(host: str, port: int, secure: bool) -> str: - """ - Build a ``Host`` header. - - """ - # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 - # IPv6 addresses must be enclosed in brackets. - try: - address = ipaddress.ip_address(host) - except ValueError: - # host is a hostname - pass - else: - # host is an IP address - if address.version == 6: - host = f"[{host}]" - - if port != (443 if secure else 80): - host = f"{host}:{port}" - - return host diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 595b13a73..6704d16ce 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -35,12 +35,13 @@ from ..headers import ( build_authorization_basic, build_extension, + build_host, build_subprotocol, parse_extension, parse_subprotocol, validate_subprotocols, ) -from ..http import USER_AGENT, build_host +from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri from .handshake import build_request, check_response diff --git a/tests/test_headers.py b/tests/test_headers.py index badec5a86..a2d51fc6a 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -5,6 +5,28 @@ class HeadersTests(unittest.TestCase): + def test_build_host(self): + for (host, port, secure), result in [ + (("localhost", 80, False), "localhost"), + (("localhost", 8000, False), "localhost:8000"), + (("localhost", 443, True), "localhost"), + (("localhost", 8443, True), "localhost:8443"), + (("example.com", 80, False), "example.com"), + (("example.com", 8000, False), "example.com:8000"), + (("example.com", 443, True), "example.com"), + (("example.com", 8443, True), "example.com:8443"), + (("127.0.0.1", 80, False), "127.0.0.1"), + (("127.0.0.1", 8000, False), "127.0.0.1:8000"), + (("127.0.0.1", 443, True), "127.0.0.1"), + (("127.0.0.1", 8443, True), "127.0.0.1:8443"), + (("::1", 80, False), "[::1]"), + (("::1", 8000, False), "[::1]:8000"), + (("::1", 443, True), "[::1]"), + (("::1", 8443, True), "[::1]:8443"), + ]: + with self.subTest(host=host, port=port, secure=secure): + self.assertEqual(build_host(host, port, secure), result) + def test_parse_connection(self): for header, parsed in [ # Realistic use cases diff --git a/tests/test_http.py b/tests/test_http.py index ca7c1c0a4..16bec9468 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,27 +1 @@ -import unittest - -from websockets.http import * - - -class HTTPTests(unittest.TestCase): - def test_build_host(self): - for (host, port, secure), result in [ - (("localhost", 80, False), "localhost"), - (("localhost", 8000, False), "localhost:8000"), - (("localhost", 443, True), "localhost"), - (("localhost", 8443, True), "localhost:8443"), - (("example.com", 80, False), "example.com"), - (("example.com", 8000, False), "example.com:8000"), - (("example.com", 443, True), "example.com"), - (("example.com", 8443, True), "example.com:8443"), - (("127.0.0.1", 80, False), "127.0.0.1"), - (("127.0.0.1", 8000, False), "127.0.0.1:8000"), - (("127.0.0.1", 443, True), "127.0.0.1"), - (("127.0.0.1", 8443, True), "127.0.0.1:8443"), - (("::1", 80, False), "[::1]"), - (("::1", 8000, False), "[::1]:8000"), - (("::1", 443, True), "[::1]"), - (("::1", 8443, True), "[::1]:8443"), - ]: - with self.subTest(host=host, port=port, secure=secure): - self.assertEqual(build_host(host, port, secure), result) +from websockets.http import * # noqa From dc0e79d6ca2ce6da4906c861443b8c696bf93dff Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 10 Sep 2021 22:54:12 +0200 Subject: [PATCH 0958/1539] Fix excessive escaping in debug logs. --- src/websockets/legacy/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index a31a5c7c8..a52f9f71d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1241,7 +1241,7 @@ async def keepalive_ping(self) -> None: # ping() raises ConnectionClosed if the connection is lost, # when connection_lost() calls abort_pings(). - self.logger.debug("%% sending keepalive ping") + self.logger.debug("% sending keepalive ping") pong_waiter = await self.ping() if self.ping_timeout is not None: @@ -1251,7 +1251,7 @@ async def keepalive_ping(self) -> None: self.ping_timeout, **loop_if_py_lt_38(self.loop), ) - self.logger.debug("%% received keepalive pong") + self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: self.logger.debug("! timed out waiting for keepalive pong") From 5d67724fa80765ef95b21b0b04a3d22b1014649a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 15:55:04 +0200 Subject: [PATCH 0959/1539] Fix indentation issue. --- docs/howto/django.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 67bf582f1..eef1b8f47 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -39,7 +39,7 @@ Generate tokens We want secure, short-lived tokens containing the user ID. We'll rely on `django-sesame`_, a small library designed exactly for this purpose. - .. _django-sesame: https://github.com/aaugustin/django-sesame +.. _django-sesame: https://github.com/aaugustin/django-sesame Add django-sesame to the dependencies of your Django project, install it, and configure it in the settings of the project: From 397eda4952e1f2b044e352ddbf2b45b336881583 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 17:11:20 +0200 Subject: [PATCH 0960/1539] Mention crypto policy in issue template. --- .github/ISSUE_TEMPLATE/issue.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md index 8efabc617..ee32cf0ee 100644 --- a/.github/ISSUE_TEMPLATE/issue.md +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -20,6 +20,9 @@ https://docs.python.org/3/library/asyncio-dev.html Did you look for similar issues? Please keep the discussion in one place :-) https://github.com/aaugustin/websockets/issues?q=is%3Aissue +Is your issue related to cryptocurrency in any way? Please don't file it. +https://websockets.readthedocs.io/en/stable/project/contributing.html#cryptocurrency-users + For bugs, providing a reproduction helps a lot. Take an existing example and tweak it! https://github.com/aaugustin/websockets/tree/main/example From d8a436fc0eec66e6e9c4b4631124c0039ae63a25 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 17:16:13 +0200 Subject: [PATCH 0961/1539] Crypto keeps getting worse. --- docs/project/contributing.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index c3e8dfc4c..43fd58dc8 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -55,12 +55,12 @@ websockets appears to be quite popular for interfacing with Bitcoin or other cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. I'm aware of efforts to build proof-of-stake models. I'll care once the total -carbon footprint of all cryptocurrencies drops to a non-bullshit level. +energy consumption of all cryptocurrencies drops to a non-bullshit level. -Please stop heating the planet where my children are supposed to live, thanks. +You already negated all of humanity's efforts to develop renewable energy. +Please stop heating the planet where my children will have to live. Since websockets is released under an open-source license, you can use it for any purpose you like. However, I won't spend any of my time to help you. I will summarily close issues related to Bitcoin or cryptocurrency in any way. -Thanks for your understanding. From 744482a2c2ab33e86ec877441f6f8d44ce03b282 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Sep 2021 21:20:41 +0200 Subject: [PATCH 0962/1539] Pin Python 3.9.6 in CI. Ref #1051. --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0042c302c..9d0792c7b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,7 @@ jobs: - name: Install Python 3.x uses: actions/setup-python@v2 with: - python-version: 3.x + python-version: 3.9.6 - name: Install tox run: pip install tox - name: Run tests with coverage @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: [3.7, 3.8, 3.9] + python: [3.7, 3.8, 3.9.6] steps: - name: Check out repository uses: actions/checkout@v2 From a2a61ead84cfaf60752452f245e111eecf4a6e53 Mon Sep 17 00:00:00 2001 From: Oliver Zehentleitner <47597331+oliver-zehentleitner@users.noreply.github.com> Date: Mon, 20 Sep 2021 21:43:31 +0200 Subject: [PATCH 0963/1539] Fix link to tutorial --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 99e477867..8603d91ee 100644 --- a/README.rst +++ b/README.rst @@ -75,7 +75,7 @@ And here's an echo server: Does that look good? -`Get started with the tutorial! `_ +`Get started with the tutorial! `_ .. raw:: html From 17af113f028b8a04e1ff8ba00381e9b57b386cfc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Sep 2021 22:20:54 +0200 Subject: [PATCH 0964/1539] Fix link to FAQ --- .github/ISSUE_TEMPLATE/issue.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md index ee32cf0ee..f2704152c 100644 --- a/.github/ISSUE_TEMPLATE/issue.md +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -12,7 +12,7 @@ assignees: '' Thanks for taking the time to report an issue! Did you check the FAQ? Perhaps you'll find the answer you need: -https://websockets.readthedocs.io/en/stable/faq.html +https://websockets.readthedocs.io/en/stable/howto/faq.html Is your question really about asyncio? Perhaps the dev guide will help: https://docs.python.org/3/library/asyncio-dev.html From 54d59f2f0e2af947714a86419b78a462063e75af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Sep 2021 22:22:02 +0200 Subject: [PATCH 0965/1539] Use the default documentation version. --- docs/howto/django.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index eef1b8f47..a3d4a2ceb 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -12,7 +12,7 @@ WebSocket, you have two main options. technique is well suited when you need to add a small set of real-time features — maybe a notification service — to a HTTP application. -.. _Channels: https://channels.readthedocs.io/en/latest/ +.. _Channels: https://channels.readthedocs.io/ This guide shows how to implement the second technique with websockets. It assumes familiarity with Django. From 737ea76b3f2697da9e69f2736b0c868430c18219 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Sep 2021 22:24:35 +0200 Subject: [PATCH 0966/1539] Remove stale reference. --- README.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/README.rst b/README.rst index 8603d91ee..089e1bfa9 100644 --- a/README.rst +++ b/README.rst @@ -113,7 +113,6 @@ Docs`_ and see for yourself. .. _Read the Docs: https://websockets.readthedocs.io/ .. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers -.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/main/compliance/README.rst Why shouldn't I use ``websockets``? ----------------------------------- From 20960c2792353dff7569554fcbe956111d772ba0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Sep 2021 22:39:11 +0200 Subject: [PATCH 0967/1539] Add python -m websockets --version. --- src/websockets/__main__.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 860e4b1fa..c562d21b5 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -11,6 +11,7 @@ from .exceptions import ConnectionClosed from .frames import Close from .legacy.client import connect +from .version import version as websockets_version if sys.platform == "win32": @@ -152,6 +153,24 @@ async def run_client( def main() -> None: + # Parse command line arguments. + parser = argparse.ArgumentParser( + prog="python -m websockets", + description="Interactive WebSocket client.", + add_help=False, + ) + group = parser.add_mutually_exclusive_group() + group.add_argument("--version", action="store_true") + group.add_argument("uri", metavar="", nargs="?") + args = parser.parse_args() + + if args.version: + print(f"websockets {websockets_version}") + return + + if args.uri is None: + parser.error("the following arguments are required: ") + # If we're on Windows, enable VT100 terminal support. if sys.platform == "win32": try: @@ -169,15 +188,6 @@ def main() -> None: except ImportError: # Windows has no `readline` normally pass - # Parse command line arguments. - parser = argparse.ArgumentParser( - prog="python -m websockets", - description="Interactive WebSocket client.", - add_help=False, - ) - parser.add_argument("uri", metavar="") - args = parser.parse_args() - # Create an event loop that will run in a background thread. loop = asyncio.new_event_loop() From 9f77f4e492fa0030ae3e5c388c606931c130065e Mon Sep 17 00:00:00 2001 From: Vladislav Smirnov Date: Tue, 21 Sep 2021 22:35:04 +0300 Subject: [PATCH 0968/1539] Fix typo in FAQ --- docs/howto/faq.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index bef61bdab..60d80677d 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -208,8 +208,8 @@ achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. Keep track of the tasks and make sure they terminate or you cancel them when the connection terminates. -Why does my program never receives any messages? -................................................ +Why does my program never receive any messages? +............................................... Your program runs a coroutine that never yields control to the event loop. The coroutine that receives messages never gets a chance to run. From 37ef247bc4bb8692d2de187b9e153ba2dc64a9e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Sep 2021 11:18:51 +0200 Subject: [PATCH 0969/1539] Reformat notes. --- docs/howto/faq.rst | 9 +++++---- docs/topics/broadcast.rst | 7 +++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 60d80677d..f97ea73a4 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -3,11 +3,12 @@ FAQ .. currentmodule:: websockets -.. note:: +.. admonition:: Many questions asked in websockets' issue tracker are really + about :mod:`asyncio`. + :class: seealso - Many questions asked in websockets' issue tracker are actually - about :mod:`asyncio`. Python's documentation about `developing with - asyncio`_ is a good complement. + Python's documentation about `developing with asyncio`_ is a good + complement. .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 531d8ca12..6c7ced8b0 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -4,10 +4,9 @@ Broadcasting messages .. currentmodule:: websockets -.. note:: - - If you just want to send a message to all connected clients, use - :func:`broadcast`. +.. admonition:: If you just want to send a message to all connected clients, + use :func:`broadcast`. + :class: tip If you want to learn about its design in depth, continue reading this document. From c70c7d6703ab8bef35cc89965becc67e8a1afe7c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Sep 2021 11:22:06 +0200 Subject: [PATCH 0970/1539] Document how to autoreload in dev. --- docs/howto/autoreload.rst | 31 +++++++++++++++++++++++++++++++ docs/howto/index.rst | 3 ++- docs/project/changelog.rst | 2 ++ 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 docs/howto/autoreload.rst diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst new file mode 100644 index 000000000..4f650acb3 --- /dev/null +++ b/docs/howto/autoreload.rst @@ -0,0 +1,31 @@ +Reload on code changes +====================== + +When developing a websockets server, you may run it locally to test changes. +Unfortunately, whenever you want to try a new version of the code, you must +stop the server and restart it, which slows down your development process. + +Web frameworks such as Django or Flask provide a development server that +reloads the application automatically when you make code changes. There is no +such functionality in websockets because it's designed for production rather +than development. + +However, you can achieve the same result easily. + +Install watchdog_ with the ``watchmedo`` shell utility: + +.. code:: console + + $ pip install watchdog[watchmedo] + +.. _watchdog: https://pypi.org/project/watchdog/ + +Run your server with ``watchmedo auto-restart``: + +.. code:: console + + $ watchmedo auto-restart --pattern "*.py" --recursive --signal SIGTERM \ + python app.py + +This example assumes that the server is defined in a script called ``app.py``. +Adapt it as necessary. diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 5deb7c767..d11789d48 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -1,13 +1,14 @@ How-to guides ============= -If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. +If you're stuck, perhaps you'll find the answer here. .. toctree:: :titlesonly: faq cheatsheet + autoreload This guide will help you integrate websockets into a broader system. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a2c9991d5..3209b39db 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,8 @@ Improvements * Mirrored the entire :class:`~asyncio.Server` API in :class:`~server.WebSocketServer`. +* Documented how to auto-reload on code changes in development. + 10.0 ---- From 27e861fd8566d056d4382992c659868bc57ef01d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 11:34:06 +0200 Subject: [PATCH 0971/1539] Enable Python 3.10 in CI. Fix #935. --- .github/workflows/tests.yml | 4 ++-- .github/workflows/wheels.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9d0792c7b..d8a95a70d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,7 @@ jobs: - name: Install Python 3.x uses: actions/setup-python@v2 with: - python-version: 3.9.6 + python-version: "3.9.6" - name: Install tox run: pip install tox - name: Run tests with coverage @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: [3.7, 3.8, 3.9.6] + python: ["3.7", "3.8", "3.9.6", "3.10"] steps: - name: Check out repository uses: actions/checkout@v2 diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 249cd36ce..7e62237b3 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -50,7 +50,7 @@ jobs: uses: joerick/cibuildwheel@v1.11.0 env: CIBW_ARCHS_LINUX: auto aarch64 - CIBW_BUILD: cp37-* cp38-* cp39-* + CIBW_SKIP: cp36-* - name: Save wheels uses: actions/upload-artifact@v2 with: From 5ba529bf55e271040a122b999f756f5c1919cd11 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 24 Oct 2021 21:20:46 +0200 Subject: [PATCH 0972/1539] Make apply_mask twice as fast on ARM. --- docs/project/changelog.rst | 2 ++ src/websockets/speedups.c | 22 +++++++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 3209b39db..69955ee1d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,8 @@ Improvements * Mirrored the entire :class:`~asyncio.Server` API in :class:`~server.WebSocketServer`. +* Improved performance for large messages on ARM processors. + * Documented how to auto-reload on code changes in development. 10.0 diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index fc328e528..f8d24ec7a 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -2,9 +2,11 @@ #define PY_SSIZE_T_CLEAN #include -#include /* uint32_t, uint64_t */ +#include /* uint8_t, uint32_t, uint64_t */ -#if __SSE2__ +#if __ARM_NEON +#include +#elif __SSE2__ #include #endif @@ -128,7 +130,21 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // We need a new scope for MSVC 2010 (non C99 friendly) { -#if __SSE2__ +#if __ARM_NEON + + // With NEON support, XOR by blocks of 16 bytes = 128 bits. + + Py_ssize_t input_len_128 = input_len & ~15; + uint8x16_t mask_128 = vreinterpretq_u8_u32(vdupq_n_u32(*(uint32_t *)mask)); + + for (; i < input_len_128; i += 16) + { + uint8x16_t in_128 = vld1q_u8((uint8_t *)(input + i)); + uint8x16_t out_128 = veorq_u8(in_128, mask_128); + vst1q_u8((uint8_t *)(output + i), out_128); + } + +#elif __SSE2__ // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. From fc3ade70aeafa5535e433b2c6d05f6fb7c70e9e3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 28 Oct 2021 21:17:23 +0200 Subject: [PATCH 0973/1539] Recommend Sanic for mixing HTTP and WebSocket. Fix #1073. --- README.rst | 7 +++++++ docs/howto/faq.rst | 7 ++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 089e1bfa9..82f89f952 100644 --- a/README.rst +++ b/README.rst @@ -120,11 +120,18 @@ Why shouldn't I use ``websockets``? * If you prefer callbacks over coroutines: ``websockets`` was created to provide the best coroutine-based API to manage WebSocket connections in Python. Pick another library for a callback-based API. + * If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. + If you want do to both in the same server, look at HTTP frameworks that + build on top of ``websockets`` to support WebSocket connections, like + Sanic_. + +.. _Sanic: https://sanicframework.org/en/ + What else? ---------- diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index f97ea73a4..20b957745 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -135,7 +135,7 @@ How do I run a HTTP server and WebSocket server on the same port? You don't. HTTP and WebSockets have widely different operational characteristics. -Running them on the same server is a bad idea. +Running them with the same server becomes inconvenient when you scale. Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. @@ -145,6 +145,11 @@ There's limited support for returning HTTP responses with the If you need more, pick a HTTP server and run it separately. +Alternatively, pick a HTTP framework that builds on top of ``websockets`` to +support WebSocket connections, like Sanic_. + +.. _Sanic: https://sanicframework.org/en/ + Client side ----------- From ed9a7b446c7147f6f88dbeb1d86546ad754e435e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 8 Oct 2021 22:18:24 +0200 Subject: [PATCH 0974/1539] Disable memory optimization on the client side. Also clarify compression docs. Fix #1065. --- docs/project/changelog.rst | 5 ++ docs/topics/compression.rst | 75 ++++++++++++++----- docs/topics/memory.rst | 3 + .../extensions/permessage_deflate.py | 2 - 4 files changed, 65 insertions(+), 20 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 69955ee1d..a4176d572 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,11 @@ Improvements the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of the first argument. +* Reverted optimization of default compression settings for clients, mainly to + avoid triggering bugs in poorly implemented servers like `AWS API Gateway`_. + + .. _AWS API Gateway: https://github.com/aaugustin/websockets/issues/1065 + * Mirrored the entire :class:`~asyncio.Server` API in :class:`~server.WebSocketServer`. diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index d40c4257d..0f264dc66 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -59,35 +59,46 @@ Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or ], ) -The Window Bits and Memory Level values in these examples reduce memory usage at the expense of compression rate. +The Window Bits and Memory Level values in these examples reduce memory usage +at the expense of compression rate. Compression settings -------------------- When a client and a server enable the Per-Message Deflate extension, they negotiate two parameters to guarantee compatibility between compression and -decompression. This affects the trade-off between compression rate and memory -usage for both sides. +decompression. These parameters affect the trade-off between compression rate +and memory usage for both sides. * **Context Takeover** means that the compression context is retained between messages. In other words, compression is applied to the stream of messages - rather than to each message individually. Context takeover should remain - enabled to get good performance on applications that send a stream of - messages with the same structure, that is, most applications. + rather than to each message individually. + + Context takeover should remain enabled to get good performance on + applications that send a stream of messages with similar structure, + that is, most applications. + + This requires retaining the compression context and state between messages, + which increases the memory footprint of a connection. * **Window Bits** controls the size of the compression context. It must be an integer between 9 (lowest memory usage) and 15 (best compression). - websockets defaults to 12. Setting it to 8 is possible but rejected by some - versions of zlib. + Setting it to 8 is possible but rejected by some versions of zlib. + + On the server side, websockets defaults to 12. On the client side, it lets + the server pick a suitable value, which is the same as defaulting to 15. :mod:`zlib` offers additional parameters for tuning compression. They control -the trade-off between compression rate and CPU and memory usage for the -compression side, transparently for the decompression side. +the trade-off between compression rate, memory usage, and CPU usage only for +compressing. They're transparent for decompressing. Unless mentioned +otherwise, websockets inherits defaults of :func:`~zlib.compressobj`. * **Memory Level** controls the size of the compression state. It must be an - integer between 1 (lowest memory usage) and 9 (best compression). websockets - defaults to 5. A lower memory level can increase speed thanks to memory - locality. + integer between 1 (lowest memory usage) and 9 (best compression). + + websockets defaults to 5. This is lower than zlib's default of 8. Not only + does a lower memory level reduce memory usage, but it can also increase + speed thanks to memory locality. * **Compression Level** controls the effort to optimize compression. It must be an integer between 1 (lowest CPU usage) and 9 (best compression). @@ -95,16 +106,17 @@ compression side, transparently for the decompression side. * **Strategy** selects the compression strategy. The best choice depends on the type of data being compressed. -Unless mentioned otherwise, websockets uses the defaults of -:func:`zlib.compressobj` for all these settings. Tuning compression ------------------ +For servers +........... + By default, websockets enables compression with conservative settings that -optimize memory usage at the cost of a slightly worse compression rate: Window -Bits = 12 and Memory Level = 5. This strikes a good balance for small messages -that are typical of WebSocket servers. +optimize memory usage at the cost of a slightly worse compression rate: +Window Bits = 12 and Memory Level = 5. This strikes a good balance for small +messages that are typical of WebSocket servers. Here's how various compression settings affect memory usage of a single connection on a 64-bit system, as well a benchmark of compressed size and @@ -152,6 +164,33 @@ usage is: CPU usage is also higher for compression than decompression. +For clients +........... + +By default, websockets enables compression with Memory Level = 5 but leaves +the Window Bits setting up to the server. + +There's two good reasons and one bad reason for not optimizing the client side +like the server side: + +1. If the maintainers of a server configured some optimized settings, we don't + want to override them with more restrictive settings. + +2. Optimizing memory usage doesn't matter very much for clients because it's + uncommon to open thousands of client connections in a program. + +3. On a more pragmatic note, some servers misbehave badly when a client + configures compression settings. `AWS API Gateway`_ is the worst offender. + + .. _AWS API Gateway: https://github.com/aaugustin/websockets/issues/1065 + + Unfortunately, even though websockets is right and AWS is wrong, many users + jump to the conclusion that websockets doesn't work. + + Until the ecosystem levels up, interoperability with buggy servers seems + more valuable than optimizing memory usage. + + Further reading --------------- diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index ee0109c35..e44247a77 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -19,6 +19,9 @@ Baseline Compression settings are the main factor affecting the baseline amount of memory used by each connection. +With websockets' defaults, on the server side, a single connections uses +70 KiB of memory. + Refer to the :doc:`topic guide on compression <../topics/compression>` to learn more about tuning compression settings. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index da2bc153e..fefa55643 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -448,8 +448,6 @@ def enable_client_permessage_deflate( ): extensions = list(extensions) + [ ClientPerMessageDeflateFactory( - server_max_window_bits=12, - client_max_window_bits=12, compress_settings={"memLevel": 5}, ) ] From b240c047f342411140dee8e774beae7abe55381a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Nov 2021 09:14:00 +0100 Subject: [PATCH 0975/1539] Add wheels for more architectures. --- .github/workflows/wheels.yml | 5 +++-- docs/project/changelog.rst | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 7e62237b3..a6df67743 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -47,9 +47,10 @@ jobs: with: platforms: all - name: Build wheels - uses: joerick/cibuildwheel@v1.11.0 + uses: pypa/cibuildwheel@v2.2.2 env: - CIBW_ARCHS_LINUX: auto aarch64 + CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" + CIBW_ARCHS_LINUX: "auto aarch64" CIBW_SKIP: cp36-* - name: Save wheels uses: actions/upload-artifact@v2 diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a4176d572..50b8ae7a5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -33,6 +33,8 @@ They may change at any time. Improvements ............ +* Added wheels for Python 3.10, PyPy 3.7, and for more platforms. + * Made the second parameter of connection handlers optional. It will be deprecated in the next major release. The request path is available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of From b4d9a2add6a9e716c9405dc74aae8bd521dec3ae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Nov 2021 09:31:38 +0100 Subject: [PATCH 0976/1539] Clarified API change on connection handlers. --- docs/project/changelog.rst | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 50b8ae7a5..5b190fb94 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,16 +30,30 @@ They may change at any time. *In development* -Improvements +New features ............ -* Added wheels for Python 3.10, PyPy 3.7, and for more platforms. - * Made the second parameter of connection handlers optional. It will be deprecated in the next major release. The request path is available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of the first argument. + If you implemented the connection handler of a server as:: + + async def handler(request, path): + ... + + You should replace it by:: + + async def handler(request): + path = request.path # if handler() uses the path argument + ... + +Improvements +............ + +* Added wheels for Python 3.10, PyPy 3.7, and for more platforms. + * Reverted optimization of default compression settings for clients, mainly to avoid triggering bugs in poorly implemented servers like `AWS API Gateway`_. From 56b589368175bf18fa1c3f0036460c0c2f3e4c4b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 11 Nov 2021 19:14:09 +0100 Subject: [PATCH 0977/1539] Unpin Python 3.9.6. Reverts 744482a2. Fix #1051. --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d8a95a70d..3aa579aa4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,7 @@ jobs: - name: Install Python 3.x uses: actions/setup-python@v2 with: - python-version: "3.9.6" + python-version: "3.x" - name: Install tox run: pip install tox - name: Run tests with coverage @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ["3.7", "3.8", "3.9.6", "3.10"] + python: ["3.7", "3.8", "3.9", "3.10"] steps: - name: Check out repository uses: actions/checkout@v2 From 27a439b6d20089dd24ab5a08d07d0585cc31ee34 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 11 Nov 2021 19:31:33 +0100 Subject: [PATCH 0978/1539] Work around bug in coverage. --- tests/legacy/test_client_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index fea2b4178..4eb8229b2 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1454,6 +1454,7 @@ async def run_client(): await server_ws.close() with self.assertRaises(ConnectionClosed): await ws.recv() + pass # work around bug in coverage else: # Exit block with an exception. raise Exception("BOOM!") From b6e25e290143c4dbb5b7bd0acd0c44efd086167e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 11 Nov 2021 20:16:58 +0100 Subject: [PATCH 0979/1539] Work around bug in mypy. --- src/websockets/legacy/protocol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index a52f9f71d..3340c33be 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -750,7 +750,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.write_close_frame(Close(code, reason)), self.close_timeout, **loop_if_py_lt_38(self.loop), - ) + ) # type: ignore # remove when removing loop_if_py_lt_38 except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. @@ -771,7 +771,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.transfer_data_task, self.close_timeout, **loop_if_py_lt_38(self.loop), - ) + ) # type: ignore # remove when removing loop_if_py_lt_38 except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1250,7 +1250,7 @@ async def keepalive_ping(self) -> None: pong_waiter, self.ping_timeout, **loop_if_py_lt_38(self.loop), - ) + ) # type: ignore # remove when removing loop_if_py_lt_38 self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1365,7 +1365,7 @@ async def wait_for_connection_lost(self) -> bool: asyncio.shield(self.connection_lost_waiter), self.close_timeout, **loop_if_py_lt_38(self.loop), - ) + ) # type: ignore # remove when removing loop_if_py_lt_38 except asyncio.TimeoutError: pass # Re-check self.connection_lost_waiter.done() synchronously because From cbbaa26e4ab303ef3c5aa2212cf45db9c15d14d2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 11 Nov 2021 21:12:41 +0100 Subject: [PATCH 0980/1539] Standardize on the code-block directive. --- docs/howto/autoreload.rst | 4 ++-- docs/howto/django.rst | 18 +++++++++--------- docs/howto/haproxy.rst | 6 +++--- docs/howto/heroku.rst | 28 ++++++++++++---------------- docs/howto/kubernetes.rst | 32 ++++++++++++++++---------------- docs/howto/nginx.rst | 6 +++--- docs/howto/supervisor.rst | 20 ++++++++++---------- docs/topics/authentication.rst | 22 +++++++++++----------- docs/topics/timeouts.rst | 2 +- 9 files changed, 67 insertions(+), 71 deletions(-) diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst index 4f650acb3..edd87d0fd 100644 --- a/docs/howto/autoreload.rst +++ b/docs/howto/autoreload.rst @@ -14,7 +14,7 @@ However, you can achieve the same result easily. Install watchdog_ with the ``watchmedo`` shell utility: -.. code:: console +.. code-block:: console $ pip install watchdog[watchmedo] @@ -22,7 +22,7 @@ Install watchdog_ with the ``watchmedo`` shell utility: Run your server with ``watchmedo auto-restart``: -.. code:: console +.. code-block:: console $ watchmedo auto-restart --pattern "*.py" --recursive --signal SIGTERM \ python app.py diff --git a/docs/howto/django.rst b/docs/howto/django.rst index a3d4a2ceb..34cce58e9 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -44,7 +44,7 @@ We want secure, short-lived tokens containing the user ID. We'll rely on Add django-sesame to the dependencies of your Django project, install it, and configure it in the settings of the project: -.. code:: python +.. code-block:: python AUTHENTICATION_BACKENDS = [ "django.contrib.auth.backends.ModelBackend", @@ -62,7 +62,7 @@ We'd like our tokens to be valid for 30 seconds. We expect web pages to load and to establish the WebSocket connection within this delay. Configure django-sesame accordingly in the settings of your Django project: -.. code:: python +.. code-block:: python SESAME_MAX_AGE = 30 @@ -77,7 +77,7 @@ performance. Now you can generate tokens in a ``django-admin shell`` as follows: -.. code:: pycon +.. code-block:: pycon >>> from django.contrib.auth import get_user_model >>> User = get_user_model() @@ -132,7 +132,7 @@ Save this code to a file called ``authentication.py``, make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set properly, and start the websockets server: -.. code:: console +.. code-block:: console $ python authentication.py @@ -140,7 +140,7 @@ Generate a new token — remember, they're only valid for 30 seconds — and use it to connect to your server. Paste your token and press Enter when you get a prompt: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8888/ Connected to ws://localhost:8888/ @@ -153,7 +153,7 @@ It works! If you enter an expired or invalid token, authentication fails and the server closes the connection: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8888/ Connected to ws://localhost:8888. @@ -163,7 +163,7 @@ closes the connection: You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: -.. code:: javascript +.. code-block:: javascript websocket = new WebSocket("ws://localhost:8888/"); websocket.onopen = (event) => websocket.send(""); @@ -205,7 +205,7 @@ the Redis connection directly. Install Redis, add django-redis to the dependencies of your Django project, install it, and configure it in the settings of the project: -.. code:: python +.. code-block:: python CACHES = { "default": { @@ -234,7 +234,7 @@ an event to Redis. Let's check that it works: -.. code:: console +.. code-block:: console $ redis-cli 127.0.0.1:6379> SELECT 1 diff --git a/docs/howto/haproxy.rst b/docs/howto/haproxy.rst index 0ecb46a04..fdaab0401 100644 --- a/docs/howto/haproxy.rst +++ b/docs/howto/haproxy.rst @@ -28,7 +28,7 @@ This configuration runs four instances of the app. Install Supervisor and run it: -.. code:: console +.. code-block:: console $ supervisord -c supervisord.conf -n @@ -46,13 +46,13 @@ servers. This is best for long running connections. Save the configuration to ``haproxy.cfg``, install HAProxy, and run it: -.. code:: console +.. code-block:: console $ haproxy -f haproxy.cfg You can confirm that HAProxy proxies connections properly: -.. code:: console +.. code-block:: console $ PYTHONPATH=src python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index b728106e9..6a7c4d00b 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -14,7 +14,7 @@ Create application Deploying to Heroku requires a git repository. Let's initialize one: -.. code:: console +.. code-block:: console $ mkdir websockets-echo $ cd websockets-echo @@ -32,7 +32,7 @@ Then, create a Heroku app — if you follow these instructions step-by-step, you'll have to pick a different name because I'm already using ``websockets-echo`` on Heroku: -.. code:: console +.. code-block:: console $ heroku create websockets-echo Creating ⬢ websockets-echo... done @@ -61,20 +61,16 @@ Deploy application In order to build the app, Heroku needs to know that it depends on websockets. Create a ``requirements.txt`` file containing this line: -.. code:: - - websockets +.. literalinclude:: ../../example/deployment/heroku/requirements.txt Heroku also needs to know how to run the app. Create a ``Procfile`` with this content: -.. code:: - - web: python app.py +.. literalinclude:: ../../example/deployment/heroku/Procfile Confirm that you created the correct files and commit them to git: -.. code:: console +.. code-block:: console $ ls Procfile app.py requirements.txt @@ -88,7 +84,7 @@ Confirm that you created the correct files and commit them to git: The app is ready. Let's deploy it! -.. code:: console +.. code-block:: console $ git push heroku main @@ -114,7 +110,7 @@ If you're currently building a websockets server, perhaps you're already in a virtualenv where websockets is installed. If not, you can install it in a new virtualenv as follows: -.. code:: console +.. code-block:: console $ python -m venv websockets-client $ . websockets-client/bin/activate @@ -123,7 +119,7 @@ virtualenv as follows: Connect the interactive client — using the name of your Heroku app instead of ``websockets-echo``: -.. code:: console +.. code-block:: console $ python -m websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. @@ -138,7 +134,7 @@ An insecure connection (``ws://``) would also work. Once you're connected, you can send any message and the server will echo it, then press Ctrl-D to terminate the connection: -.. code:: console +.. code-block:: console > Hello! < Hello! @@ -147,7 +143,7 @@ then press Ctrl-D to terminate the connection: You can also confirm that your application shuts down gracefully. Connect an interactive client again — remember to replace ``websockets-echo`` with your app: -.. code:: console +.. code-block:: console $ python -m websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. @@ -155,7 +151,7 @@ interactive client again — remember to replace ``websockets-echo`` with your a In another shell, restart the dyno — again, replace ``websockets-echo`` with your app: -.. code:: console +.. code-block:: console $ heroku dyno:restart -a websockets-echo Restarting dynos on ⬢ websockets-echo... done @@ -163,7 +159,7 @@ In another shell, restart the dyno — again, replace ``websockets-echo`` with y Go back to the first shell. The connection is closed with code 1001 (going away). -.. code:: console +.. code-block:: console $ python -m websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst index 0e77aeac1..26dbf8a94 100644 --- a/docs/howto/kubernetes.rst +++ b/docs/howto/kubernetes.rst @@ -36,13 +36,13 @@ guide, so we'll go for the simplest possible configuration instead: After saving this ``Dockerfile``, build the image: -.. code:: console +.. code-block:: console $ docker build -t websockets-test:1.0 . Test your image by running: -.. code:: console +.. code-block:: console $ docker run --name run-websockets-test --publish 32080:80 --rm \ websockets-test:1.0 @@ -50,7 +50,7 @@ Test your image by running: Then, in another shell, in a virtualenv where websockets is installed, connect to the app and check that it echoes anything you send: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:32080/ Connected to ws://localhost:32080/. @@ -60,14 +60,14 @@ to the app and check that it echoes anything you send: Now, in yet another shell, stop the app with: -.. code:: console +.. code-block:: console $ docker kill -s TERM run-websockets-test Going to the shell where you connected to the app, you can confirm that it shut down gracefully: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:32080/ Connected to ws://localhost:32080/. @@ -96,7 +96,7 @@ to production, you would configure an Ingress_. After saving this to a file called ``deployment.yaml``, you can deploy: -.. code:: console +.. code-block:: console $ kubectl apply -f deployment.yaml service/websockets-test created @@ -104,7 +104,7 @@ After saving this to a file called ``deployment.yaml``, you can deploy: Now you have a deployment with one pod running: -.. code:: console +.. code-block:: console $ kubectl get deployment websockets-test NAME READY UP-TO-DATE AVAILABLE AGE @@ -115,7 +115,7 @@ Now you have a deployment with one pod running: You can connect to the service — press Ctrl-D to exit: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:32080/ Connected to ws://localhost:32080/. @@ -126,7 +126,7 @@ Validate deployment First, let's ensure the liveness probe works by making the app unresponsive: -.. code:: console +.. code-block:: console $ curl http://localhost:32080/inemuri Sleeping for 10s @@ -139,7 +139,7 @@ Therefore Kubernetes should restart the pod after at most 5 seconds. Indeed, after a few seconds, the pod reports a restart: -.. code:: console +.. code-block:: console $ kubectl get pods -l app=websockets-test NAME READY STATUS RESTARTS AGE @@ -147,14 +147,14 @@ Indeed, after a few seconds, the pod reports a restart: Next, let's take it one step further and crash the app: -.. code:: console +.. code-block:: console $ curl http://localhost:32080/seppuku Terminating The pod reports a second restart: -.. code:: console +.. code-block:: console $ kubectl get pods -l app=websockets-test NAME READY STATUS RESTARTS AGE @@ -167,14 +167,14 @@ Scale deployment Of course, Kubernetes is for scaling. Let's scale — modestly — to 10 pods: -.. code:: console +.. code-block:: console $ kubectl scale deployment.apps/websockets-test --replicas=10 deployment.apps/websockets-test scaled After a few seconds, we have 10 pods running: -.. code:: console +.. code-block:: console $ kubectl get deployment websockets-test NAME READY UP-TO-DATE AVAILABLE AGE @@ -191,7 +191,7 @@ over 50 * 6 * 0.1 = 30 seconds. Let's try it: -.. code:: console +.. code-block:: console $ ulimit -n 512 $ time python benchmark.py 500 6 @@ -204,7 +204,7 @@ stabilize the test setup. Finally, we can scale back to one pod. -.. code:: console +.. code-block:: console $ kubectl scale deployment.apps/websockets-test --replicas=1 deployment.apps/websockets-test scaled diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index e20f82098..30545fbc7 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -33,7 +33,7 @@ This configuration runs four instances of the app. Install Supervisor and run it: -.. code:: console +.. code-block:: console $ supervisord -c supervisord.conf -n @@ -69,13 +69,13 @@ Then we combine the `WebSocket proxying`_ and `load balancing`_ guides: Save the configuration to ``nginx.conf``, install nginx, and run it: -.. code:: console +.. code-block:: console $ nginx -c nginx.conf -p . You can confirm that nginx proxies connections properly: -.. code:: console +.. code-block:: console $ PYTHONPATH=src python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. diff --git a/docs/howto/supervisor.rst b/docs/howto/supervisor.rst index 0c07aebae..5eefc7711 100644 --- a/docs/howto/supervisor.rst +++ b/docs/howto/supervisor.rst @@ -14,14 +14,14 @@ connections. Create and activate a virtualenv: -.. code:: console +.. code-block:: console $ python -m venv supervisor-websockets $ . supervisor-websockets/bin/activate Install websockets and Supervisor: -.. code:: console +.. code-block:: console $ pip install websockets $ pip install supervisor @@ -45,7 +45,7 @@ running, restarting them if they exit. Now start Supervisor in the foreground: -.. code:: console +.. code-block:: console $ supervisord -c supervisord.conf -n INFO Increased RLIMIT_NOFILE limit to 1024 @@ -62,7 +62,7 @@ Now start Supervisor in the foreground: In another shell, after activating the virtualenv, we can connect to the app — press Ctrl-D to exit: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. @@ -72,13 +72,13 @@ press Ctrl-D to exit: Look at the pid of an instance of the app in the logs and terminate it: -.. code:: console +.. code-block:: console $ kill -TERM 43597 The logs show that Supervisor restarted this instance: -.. code:: console +.. code-block:: console INFO exited: websockets-test_00 (exit status 0; expected) INFO spawned: 'websockets-test_00' with pid 43629 @@ -87,7 +87,7 @@ The logs show that Supervisor restarted this instance: Now let's check what happens when we shut down Supervisor, but first let's establish a connection and leave it open: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. @@ -95,14 +95,14 @@ establish a connection and leave it open: Look at the pid of supervisord itself in the logs and terminate it: -.. code:: console +.. code-block:: console $ kill -TERM 43596 The logs show that Supervisor terminated all instances of the app before exiting: -.. code:: console +.. code-block:: console WARN received SIGTERM indicating exit request INFO waiting for websockets-test_00, websockets-test_01, websockets-test_02, websockets-test_03 to die @@ -113,7 +113,7 @@ exiting: And you can see that the connection to the app was closed gracefully: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 4d702b2f6..31bfd6465 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -154,7 +154,7 @@ Run the experiment in an environment where websockets is installed: .. _experiments/authentication: https://github.com/aaugustin/websockets/tree/main/experiments/authentication -.. code:: console +.. code-block:: console $ python experiments/authentication/app.py Running on http://localhost:8000/ @@ -172,7 +172,7 @@ First message As soon as the connection is open, the client sends a message containing the token: -.. code:: javascript +.. code-block:: javascript const websocket = new WebSocket("ws://.../"); websocket.onopen = () => websocket.send(token); @@ -183,7 +183,7 @@ At the beginning of the connection handler, the server receives this message and authenticates the user. If authentication fails, the server closes the connection: -.. code:: python +.. code-block:: python async def first_message_handler(websocket): token = await websocket.recv() @@ -200,7 +200,7 @@ Query parameter The client adds the token to the WebSocket URI in a query parameter before opening the connection: -.. code:: javascript +.. code-block:: javascript const uri = `ws://.../?token=${token}`; const websocket = new WebSocket(uri); @@ -210,7 +210,7 @@ opening the connection: The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns a HTTP 401: -.. code:: python +.. code-block:: python class QueryParamProtocol(websockets.WebSocketServerProtocol): async def process_request(self, path, headers): @@ -237,7 +237,7 @@ The client sets a cookie containing the token before opening the connection. The cookie must be set by an iframe loaded from the same origin as the WebSocket server. This requires passing the token to this iframe. -.. code:: javascript +.. code-block:: javascript // in main window iframe.contentWindow.postMessage(token, "http://..."); @@ -256,7 +256,7 @@ This involves several events. Look at the full implementation for details. The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns a HTTP 401: -.. code:: python +.. code-block:: python class CookieProtocol(websockets.WebSocketServerProtocol): async def process_request(self, path, headers): @@ -284,7 +284,7 @@ User information The client adds the token to the WebSocket URI in user information before opening the connection: -.. code:: javascript +.. code-block:: javascript const uri = `ws://token:${token}@.../`; const websocket = new WebSocket(uri); @@ -297,7 +297,7 @@ than a token, we send ``token`` as username and the token as password. The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns a HTTP 401: -.. code:: python +.. code-block:: python class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): @@ -326,7 +326,7 @@ solution in this scenario. To authenticate a websockets client with HTTP Basic Authentication (:rfc:`7617`), include the credentials in the URI: -.. code:: python +.. code-block:: python async with websockets.connect( f"wss://{username}:{password}@example.com", @@ -339,7 +339,7 @@ contain unsafe characters.) To authenticate a websockets client with HTTP Bearer Authentication (:rfc:`6750`), add a suitable ``Authorization`` header: -.. code:: python +.. code-block:: python async with websockets.connect( "wss://example.com", diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 51666ceea..815a29b3f 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -56,7 +56,7 @@ Latency between a client and a server may increase for two reasons: You can monitor latency as follows: -.. code:: python +.. code-block:: python import asyncio import logging From 46d973a3163a92117259122b7b636cff1b142a23 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 14:54:06 +0100 Subject: [PATCH 0981/1539] Upgrade RTD config to v2. --- .readthedocs.yml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 109affab4..0369e0656 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,7 +1,13 @@ +version: 2 + build: - image: latest + os: ubuntu-20.04 + tools: + python: "3.10" -python: - version: 3.7 +sphinx: + configuration: docs/conf.py -requirements_file: docs/requirements.txt +python: + install: + - requirements: docs/requirements.txt From e6b8522b25d66b4847a29fa1d0fa27f135909279 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 16:42:32 +0100 Subject: [PATCH 0982/1539] Remove dead code. --- tests/test_imports.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_imports.py b/tests/test_imports.py index d84808902..8f1625a9b 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,4 +1,3 @@ -import sys import types import unittest import warnings @@ -33,10 +32,6 @@ def test_get_deprecated_alias(self): with warnings.catch_warnings(record=True) as recorded_warnings: self.assertEqual(self.mod.bar, bar) - # No warnings raised on pre-PEP 526 Python. - if sys.version_info[:2] < (3, 7): # pragma: no cover - return - self.assertEqual(len(recorded_warnings), 1) warning = recorded_warnings[0].message self.assertEqual( From 11cbe72428bf78ac2e89b1ec2844969bc3f7ab8a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 15:19:46 +0100 Subject: [PATCH 0983/1539] Introduce development version numbers. This prevents developement versions of websocket from advertising themselves as the previous release while they may contain backwards-incompatible changes since that release. --- docs/conf.py | 18 ++------- src/websockets/version.py | 77 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d22b85a82..750682573 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,8 +25,7 @@ copyright = f"2013-{datetime.date.today().year}, Aymeric Augustin and contributors" author = "Aymeric Augustin" -# The full version, including alpha/beta/rc tags -release = "10.0" +from websockets.version import tag as version, version as release # -- General configuration --------------------------------------------------- @@ -91,20 +90,9 @@ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # Configure viewcode extension. -try: - git_sha1 = subprocess.run( - "git rev-parse --short HEAD", - capture_output=True, - shell=True, - check=True, - text=True, - ).stdout.strip() -except subprocess.SubprocessError as exc: - print("Cannot get git commit, disabling linkcode:", exc) - extensions.remove("sphinx.ext.linkcode") -else: - code_url = f"https://github.com/aaugustin/websockets/blob/{git_sha1}" +from websockets.version import commit +code_url = f"https://github.com/aaugustin/websockets/blob/{commit}" def linkcode_resolve(domain, info): assert domain == "py" diff --git a/src/websockets/version.py b/src/websockets/version.py index 168f8b054..3a6e6aa08 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1,76 @@ -version = "10.0" +from __future__ import annotations + + +__all__ = ["tag", "version", "commit"] + + +# ========= =========== =================== +# release development +# ========= =========== =================== +# tag X.Y X.Y (upcoming) +# version X.Y X.Y.dev1+g5678cde +# commit X.Y 5678cde +# ========= =========== =================== + + +# When tagging a release, set `released = True`. +# After tagging a release, set `released = False` and increment `tag`. + +released = False + +tag = version = commit = "10.1" + + +if not released: # pragma: no cover + import pathlib + import re + import subprocess + + def get_version(tag: str) -> str: + # Since setup.py executes the contents of src/websockets/version.py, + # __file__ can point to either of these two files. + file_path = pathlib.Path(__file__) + root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2] + + # Read version from git if available. This prevents reading stale + # information from src/websockets.egg-info after building a sdist. + try: + description = subprocess.run( + ["git", "describe", "--dirty", "--tags", "--long"], + capture_output=True, + cwd=root_dir, + check=True, + text=True, + ).stdout.strip() + except subprocess.CalledProcessError: + pass + else: + description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7}(?:-dirty)?)" + match = re.fullmatch(description_re, description) + assert match is not None + distance, remainder = match.groups() + remainder = remainder.replace("-", ".") # required by PEP 440 + return f"{tag}.dev{distance}+{remainder}" + + # Read version from package metadata if it is installed. + try: + import importlib.metadata # move up when dropping Python 3.7 + + return importlib.metadata.version("websockets") + except ImportError: + pass + + # Avoid crashing if the development version cannot be determined. + return f"{tag}.dev0+gunknown" + + version = get_version(tag) + + def get_commit(tag: str, version: str) -> str: + # Extract commit from version, falling back to tag if not available. + version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7}|unknown)(?:\.dirty)?" + match = re.fullmatch(version_re, version) + assert match is not None + (commit,) = match.groups() + return tag if commit == "unknown" else commit + + commit = get_commit(tag, version) From 150f2a04551a3af669dccfdc5efe59d1e2cae40a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 08:09:38 +0200 Subject: [PATCH 0984/1539] Update marketing pitch. --- README.rst | 23 +++++++++++------------ docs/index.rst | 8 ++++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/README.rst b/README.rst index 82f89f952..f8df94ba4 100644 --- a/README.rst +++ b/README.rst @@ -25,11 +25,10 @@ What is ``websockets``? ----------------------- -``websockets`` is a library for building WebSocket servers_ and clients_ in -Python with a focus on correctness and simplicity. +websockets is a library for building WebSocket_ servers and clients in Python +with a focus on correctness, simplicity, robustness, and performance. -.. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py +.. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. @@ -92,19 +91,19 @@ Why should I use ``websockets``? The development of ``websockets`` is shaped by four principles: -1. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and - ``await ws.send(msg)``; ``websockets`` takes care of managing connections +1. **Correctness**: ``websockets`` is heavily tested for compliance + with :rfc:`6455`. Continuous integration fails under 100% branch + coverage. + +2. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and + ``await ws.send(msg)``. ``websockets`` takes care of managing connections so you can focus on your application. -2. **Robustness**: ``websockets`` is built for production; for example it was +3. **Robustness**: ``websockets`` is built for production. For example, it was the only library to `handle backpressure correctly`_ before the issue became widely known in the Python community. -3. **Quality**: ``websockets`` is heavily tested. Continuous integration fails - under 100% branch coverage. Also it passes the industry-standard `Autobahn - Testsuite`_. - -4. **Performance**: memory usage is configurable. An extension written in C +4. **Performance**: memory usage is optimized and configurable. A C extension accelerates expensive operations. It's pre-compiled for Linux, macOS and Windows and packaged in the wheel format for each system and Python version. diff --git a/docs/index.rst b/docs/index.rst index 30d01c2f8..def9076af 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,11 +21,11 @@ websockets .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ -websockets is a library for building WebSocket servers_ and clients_ in Python -with a focus on correctness and simplicity. +websockets is a library for building WebSocket_ servers and +clients in Python with a focus on correctness, simplicity, robustness, and +performance. -.. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py +.. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. From 22d881b00e823cce7237ac7c1aaad2c5d09cb95f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 08:10:14 +0200 Subject: [PATCH 0985/1539] Move interactive client to front page. Remove dependency on echo.websocket.org which no longer exists. --- docs/index.rst | 10 ++++++++++ docs/intro/index.rst | 7 ------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index def9076af..bdde6dc74 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,16 @@ And here's an echo server: .. literalinclude:: ../example/echo.py +Also, websockets provides an interactive client: + +.. code-block:: console + + $ python -m websockets ws://localhost:8765/ + Connected to ws://localhost:8765/. + > Hello world! + < Hello world! + Connection closed: 1000 (OK). + Do you like it? Let's dive in! .. toctree:: diff --git a/docs/intro/index.rst b/docs/intro/index.rst index bd7c48f81..f1ea2a2a4 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -210,10 +210,3 @@ You don't have to worry about performing the opening or the closing handshake, answering pings, or any other behavior required by the specification. websockets handles all this under the hood so you don't have to. - -One more thing... ------------------ - -websockets provides an interactive client:: - - $ python -m websockets wss://echo.websocket.org/ From b9e11125b781da3cf567cd2cf3036478c6ee1670 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 08:21:54 +0200 Subject: [PATCH 0986/1539] Move API design principle to front page. --- docs/index.rst | 4 ++++ docs/intro/index.rst | 10 ---------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index bdde6dc74..07835a81c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,10 @@ And here's an echo server: .. literalinclude:: ../example/echo.py +Don't worry about the opening and closing handshakes, pings and pongs, or any +other behavior described in the specification. websockets takes care of this +under the hood so you can focus on your application! + Also, websockets provides an interactive client: .. code-block:: console diff --git a/docs/intro/index.rst b/docs/intro/index.rst index f1ea2a2a4..bf3c96b33 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -200,13 +200,3 @@ unregister them when they disconnect. This simplistic example keeps track of connected clients in memory. This only works as long as you run a single process. In a practical application, the handler may subscribe to some channels on a message broker, for example. - -That's all! ------------ - -The design of the websockets API was driven by simplicity. - -You don't have to worry about performing the opening or the closing handshake, -answering pings, or any other behavior required by the specification. - -websockets handles all this under the hood so you don't have to. From 0fafda84ae722bf319a46d7cad4c67bcfd47c16d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 16:57:48 +0200 Subject: [PATCH 0987/1539] Extract patterns to a howto guide. Improve this guide. Fix #1022. --- docs/howto/index.rst | 1 + docs/howto/patterns.rst | 110 ++++++++++++++++++++++++++++++++++++++++ docs/intro/index.rst | 82 ------------------------------ 3 files changed, 111 insertions(+), 82 deletions(-) create mode 100644 docs/howto/patterns.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index d11789d48..d399cebc2 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -8,6 +8,7 @@ If you're stuck, perhaps you'll find the answer here. faq cheatsheet + patterns autoreload This guide will help you integrate websockets into a broader system. diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst new file mode 100644 index 000000000..c6f325d21 --- /dev/null +++ b/docs/howto/patterns.rst @@ -0,0 +1,110 @@ +Patterns +======== + +.. currentmodule:: websockets + +Here are typical patterns for processing messages in a WebSocket server or +client. You will certainly implement some of them in your application. + +This page gives examples of connection handlers for a server. However, they're +also applicable to a client, simply by assuming that ``websocket`` is a +connection created with :func:`~client.connect`. + +WebSocket connections are long-lived. You will usually write a loop to process +several messages during the lifetime of a connection. + +Consumer +-------- + +To receive messages from the WebSocket connection:: + + async def consumer_handler(websocket): + async for message in websocket: + await consumer(message) + +In this example, ``consumer()`` is a coroutine implementing your business +logic for processing a message received on the WebSocket connection. Each +message may be :class:`str` or :class:`bytes`. + +Iteration terminates when the client disconnects. + +Producer +-------- + +To send messages to the WebSocket connection:: + + async def producer_handler(websocket): + while True: + message = await producer() + await websocket.send(message) + +In this example, ``producer()`` is a coroutine implementing your business +logic for generating the next message to send on the WebSocket connection. +Each message must be :class:`str` or :class:`bytes`. + +Iteration terminates when the client disconnects +because :meth:`~server.WebSocketServerProtocol.send` raises a +:exc:`~exceptions.ConnectionClosed` exception, +which breaks out of the ``while True`` loop. + +Consumer and producer +--------------------- + +You can receive and send messages on the same WebSocket connection by +combining the consumer and producer patterns. This requires running two tasks +in parallel:: + + async def handler(websocket): + await asyncio.gather( + consumer_handler(websocket), + producer_handler(websocket), + ) + +If a task terminates, :func:`~asyncio.gather` doesn't cancel the other task. +This can result in a situation where the producer keeps running after the +consumer finished, which may leak resources. + +Here's a way to exit and close the WebSocket connection as soon as a task +terminates, after canceling the other task:: + + async def handler(websocket): + consumer_task = asyncio.create_task(consumer_handler(websocket)) + producer_task = asyncio.create_task(producer_handler(websocket)) + done, pending = await asyncio.wait( + [consumer_task, producer_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + +Registration +------------ + +To keep track of currently connected clients, you can register them when they +connect and unregister them when they disconnect:: + + connected = set() + + async def handler(websocket): + # Register. + connected.add(websocket) + try: + # Broadcast a message to all connected clients. + websockets.broadcast(connected, "Hello!") + await asyncio.sleep(10) + finally: + # Unregister. + connected.remove(websocket) + +This example maintains the set of connected clients in memory. This works as +long as you run a single process. It doesn't scale to multiple processes. + +Publish–subscribe +----------------- + +If you plan to run multiple processes and you want to communicate updates +between processes, then you must deploy a messaging system. You may find +publish-subscribe functionality useful. + +A complete implementation of this idea with Redis is described in +the :doc:`Django integration guide <../howto/django>`. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index bf3c96b33..e95f0889c 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -118,85 +118,3 @@ Then open this HTML file in several browsers. .. literalinclude:: ../../example/counter.html :language: html - -Common patterns ---------------- - -You will usually want to process several messages during the lifetime of a -connection. Therefore you must write a loop. Here are the basic patterns for -building a WebSocket server. - -Consumer -........ - -For receiving messages and passing them to a ``consumer`` coroutine:: - - async def consumer_handler(websocket): - async for message in websocket: - await consumer(message) - -In this example, ``consumer`` represents your business logic for processing -messages received on the WebSocket connection. - -Iteration terminates when the client disconnects. - -Producer -........ - -For getting messages from a ``producer`` coroutine and sending them:: - - async def producer_handler(websocket): - while True: - message = await producer() - await websocket.send(message) - -In this example, ``producer`` represents your business logic for generating -messages to send on the WebSocket connection. - -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` raises a -:exc:`~exceptions.ConnectionClosed` exception when the client disconnects, -which breaks out of the ``while True`` loop. - -Both sides -.......... - -You can read and write messages on the same connection by combining the two -patterns shown above and running the two tasks in parallel:: - - async def handler(websocket): - consumer_task = asyncio.ensure_future( - consumer_handler(websocket)) - producer_task = asyncio.ensure_future( - producer_handler(websocket)) - done, pending = await asyncio.wait( - [consumer_task, producer_task], - return_when=asyncio.FIRST_COMPLETED, - ) - for task in pending: - task.cancel() - -Registration -............ - -As shown in the synchronization example above, if you need to maintain a list -of currently connected clients, you must register them when they connect and -unregister them when they disconnect. - -:: - - connected = set() - - async def handler(websocket): - # Register. - connected.add(websocket) - try: - # Broadcast a message to all connected clients. - websockets.broadcast(connected, "Hello!") - await asyncio.sleep(10) - finally: - # Unregister. - connected.remove(websocket) - -This simplistic example keeps track of connected clients in memory. This only -works as long as you run a single process. In a practical application, the -handler may subscribe to some channels on a message broker, for example. From 33b38eea7c627fb314aaaa6aa29041917fd7718e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 17:04:15 +0200 Subject: [PATCH 0988/1539] Extract examples to a quick start guide. --- docs/intro/index.rst | 101 ++------------- docs/intro/quickstart.rst | 116 ++++++++++++++++++ example/counter.html | 80 ------------ example/{ => quickstart}/client.py | 5 +- .../client_secure.py} | 5 +- example/quickstart/counter.css | 33 +++++ example/quickstart/counter.html | 18 +++ example/quickstart/counter.js | 26 ++++ example/{ => quickstart}/counter.py | 32 ++--- example/{ => quickstart}/localhost.pem | 0 example/{ => quickstart}/server.py | 5 +- .../server_secure.py} | 5 +- example/quickstart/show_time.html | 9 ++ example/quickstart/show_time.js | 12 ++ example/quickstart/show_time.py | 18 +++ example/show_time.html | 20 --- example/show_time.py | 20 --- 17 files changed, 261 insertions(+), 244 deletions(-) create mode 100644 docs/intro/quickstart.rst delete mode 100644 example/counter.html rename example/{ => quickstart}/client.py (86%) rename example/{secure_client.py => quickstart/client_secure.py} (86%) create mode 100644 example/quickstart/counter.css create mode 100644 example/quickstart/counter.html create mode 100644 example/quickstart/counter.js rename example/{ => quickstart}/counter.py (57%) rename example/{ => quickstart}/localhost.pem (100%) rename example/{ => quickstart}/server.py (87%) rename example/{secure_server.py => quickstart/server_secure.py} (86%) create mode 100644 example/quickstart/show_time.html create mode 100644 example/quickstart/show_time.js create mode 100755 example/quickstart/show_time.py delete mode 100644 example/show_time.html delete mode 100755 example/show_time.py diff --git a/docs/intro/index.rst b/docs/intro/index.rst index e95f0889c..a5b68bfbf 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -14,6 +14,8 @@ websockets requires Python ≥ 3.7. For each minor version (3.x), only the latest bugfix or security release (3.x.y) is officially supported. +It doesn't have any dependencies. + Installation ------------ @@ -21,100 +23,13 @@ Install websockets with:: pip install websockets -Basic example -------------- - -.. _server-example: - -Here's a WebSocket server example. - -It reads a name from the client, sends a greeting, and closes the connection. - -.. literalinclude:: ../../example/server.py - :emphasize-lines: 8,18 - -.. _client-example: - -On the server side, websockets executes the handler coroutine ``hello()`` once -for each WebSocket connection. It closes the connection when the handler -coroutine returns. - -Here's a corresponding WebSocket client example. - -.. literalinclude:: ../../example/client.py - :emphasize-lines: 10 - -Using :func:`~client.connect` as an asynchronous context manager ensures the -connection is closed before exiting the ``hello()`` coroutine. - -.. _secure-server-example: - -Secure example --------------- - -Secure WebSocket connections improve confidentiality and also reliability -because they reduce the risk of interference by bad proxies. - -The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. The -connection is encrypted with TLS_ (Transport Layer Security). ``wss`` -requires certificates like ``https``. - -.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security - -.. admonition:: TLS vs. SSL - :class: tip - - TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an - earlier encryption protocol; the name stuck. - -Here's how to adapt the server example to provide secure connections. See the -documentation of the :mod:`ssl` module for configuring the context securely. - -.. literalinclude:: ../../example/secure_server.py - :emphasize-lines: 19-21,24 - -Here's how to adapt the client. - -.. literalinclude:: ../../example/secure_client.py - :emphasize-lines: 10-12,16 - -This client needs a context because the server uses a self-signed certificate. - -A client connecting to a secure WebSocket server with a valid certificate -(i.e. signed by a CA that your Python installation trusts) can simply pass -``ssl=True`` to :func:`~client.connect` instead of building a context. - -Browser-based example ---------------------- - -Here's an example of how to run a WebSocket server and connect from a browser. - -Run this script in a console: - -.. literalinclude:: ../../example/show_time.py - -Then open this HTML file in a browser. - -.. literalinclude:: ../../example/show_time.html - :language: html - -Synchronization example ------------------------ - -A WebSocket server can receive events from clients, process them to update the -application state, and synchronize the resulting state across clients. - -Here's an example where any client can increment or decrement a counter. -Updates are propagated to all connected clients. - -The concurrency model of :mod:`asyncio` guarantees that updates are -serialized. +Wheels are available for all platforms. -Run this script in a console: +First steps +----------- -.. literalinclude:: ../../example/counter.py +If you're in a hurry, check out these examples. -Then open this HTML file in several browsers. +.. toctree:: -.. literalinclude:: ../../example/counter.html - :language: html + quickstart diff --git a/docs/intro/quickstart.rst b/docs/intro/quickstart.rst new file mode 100644 index 000000000..8c1221126 --- /dev/null +++ b/docs/intro/quickstart.rst @@ -0,0 +1,116 @@ +Quick start +=========== + +.. currentmodule:: websockets + +Here are a few examples to get you started quickly with websockets. + +Hello world! +------------ + +Here's a WebSocket server. + +It receives a name from the client, sends a greeting, and closes the connection. + +.. literalinclude:: ../../example/quickstart/server.py + +:func:`~server.serve` executes the connection handler coroutine ``hello()`` +once for each WebSocket connection. It closes the WebSocket connection when +the handler returns. + +Here's a corresponding WebSocket client. + +It sends a name to the server, receives a greeting, and closes the connection. + +.. literalinclude:: ../../example/quickstart/client.py + +Using :func:`~client.connect` as an asynchronous context manager ensures the +WebSocket connection is closed. + +.. _secure-server-example: + +Encryption +---------- + +Secure WebSocket connections improve confidentiality and also reliability +because they reduce the risk of interference by bad proxies. + +The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. The +connection is encrypted with TLS_ (Transport Layer Security). ``wss`` +requires certificates like ``https``. + +.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security + +.. admonition:: TLS vs. SSL + :class: tip + + TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an + earlier encryption protocol; the name stuck. + +Here's how to adapt the server to encrypt connections. See the documentation +of the :mod:`ssl` module for configuring the context securely. + +.. literalinclude:: ../../example/quickstart/server_secure.py + +Here's how to adapt the client similarly. + +.. literalinclude:: ../../example/quickstart/client_secure.py + +This client needs a context because the server uses a self-signed certificate. + +When connecting to a secure WebSocket server with a valid certificate — any +certificate signed by a CA that your Python installation trusts — you can +simply pass ``ssl=True`` to :func:`~client.connect`. + +In a browser +------------ + +The WebSocket protocol was invented for the web — as the name says! + +Here's how to connect to a WebSocket server in a browser. + +Run this script in a console: + +.. literalinclude:: ../../example/quickstart/show_time.py + +Save this file as ``show_time.html``: + +.. literalinclude:: ../../example/quickstart/show_time.html + :language: html + +Save this file as ``show_time.js``: + +.. literalinclude:: ../../example/quickstart/show_time.js + :language: js + +Then open ``show_time.html`` in a browser and see the clock tick irregularly. + +Broadcast +--------- + +A WebSocket server can receive events from clients, process them to update the +application state, and broadcast the updated state to all connected clients. + +Here's an example where any client can increment or decrement a counter. The +concurrency model of :mod:`asyncio` guarantees that updates are serialized. + +Run this script in a console: + +.. literalinclude:: ../../example/quickstart/counter.py + +Save this file as ``counter.html``: + +.. literalinclude:: ../../example/quickstart/counter.html + :language: html + +Save this file as ``counter.css``: + +.. literalinclude:: ../../example/quickstart/counter.css + :language: css + +Save this file as ``counter.js``: + +.. literalinclude:: ../../example/quickstart/counter.js + :language: js + +Then open ``counter.html`` file in several browsers and play with [+] and [-]. diff --git a/example/counter.html b/example/counter.html deleted file mode 100644 index 6310c4a16..000000000 --- a/example/counter.html +++ /dev/null @@ -1,80 +0,0 @@ - - - - WebSocket demo - - - -
-
-
-
?
-
+
-
-
- ? online -
- - - diff --git a/example/client.py b/example/quickstart/client.py similarity index 86% rename from example/client.py rename to example/quickstart/client.py index 062540202..8d588c2b0 100755 --- a/example/client.py +++ b/example/quickstart/client.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WS client example - import asyncio import websockets @@ -16,4 +14,5 @@ async def hello(): greeting = await websocket.recv() print(f"<<< {greeting}") -asyncio.run(hello()) +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/secure_client.py b/example/quickstart/client_secure.py similarity index 86% rename from example/secure_client.py rename to example/quickstart/client_secure.py index 8a1551e29..f4b39f2b8 100755 --- a/example/secure_client.py +++ b/example/quickstart/client_secure.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WSS (WS over TLS) client example, with a self-signed certificate - import asyncio import pathlib import ssl @@ -22,4 +20,5 @@ async def hello(): greeting = await websocket.recv() print(f"<<< {greeting}") -asyncio.run(hello()) +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/quickstart/counter.css b/example/quickstart/counter.css new file mode 100644 index 000000000..e1f4b7714 --- /dev/null +++ b/example/quickstart/counter.css @@ -0,0 +1,33 @@ +body { + font-family: "Courier New", sans-serif; + text-align: center; +} +.buttons { + font-size: 4em; + display: flex; + justify-content: center; +} +.button, .value { + line-height: 1; + padding: 2rem; + margin: 2rem; + border: medium solid; + min-height: 1em; + min-width: 1em; +} +.button { + cursor: pointer; + user-select: none; +} +.minus { + color: red; +} +.plus { + color: green; +} +.value { + min-width: 2em; +} +.state { + font-size: 2em; +} diff --git a/example/quickstart/counter.html b/example/quickstart/counter.html new file mode 100644 index 000000000..2e3433bd2 --- /dev/null +++ b/example/quickstart/counter.html @@ -0,0 +1,18 @@ + + + + WebSocket demo + + + +
+
-
+
?
+
+
+
+
+ ? online +
+ + + diff --git a/example/quickstart/counter.js b/example/quickstart/counter.js new file mode 100644 index 000000000..37d892a28 --- /dev/null +++ b/example/quickstart/counter.js @@ -0,0 +1,26 @@ +window.addEventListener("DOMContentLoaded", () => { + const websocket = new WebSocket("ws://localhost:6789/"); + + document.querySelector(".minus").addEventListener("click", () => { + websocket.send(JSON.stringify({ action: "minus" })); + }); + + document.querySelector(".plus").addEventListener("click", () => { + websocket.send(JSON.stringify({ action: "plus" })); + }); + + websocket.onmessage = ({ data }) => { + const event = JSON.parse(data); + switch (event.type) { + case "value": + document.querySelector(".value").textContent = event.value; + break; + case "users": + const users = `${event.count} user${event.count == 1 ? "" : "s"}`; + document.querySelector(".users").textContent = users; + break; + default: + console.error("unsupported event", event); + } + }; +}); diff --git a/example/counter.py b/example/quickstart/counter.py similarity index 57% rename from example/counter.py rename to example/quickstart/counter.py index 6e33b3afc..566e12965 100755 --- a/example/counter.py +++ b/example/quickstart/counter.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WS server example that synchronizes state across clients - import asyncio import json import logging @@ -9,47 +7,43 @@ logging.basicConfig() -STATE = {"value": 0} - USERS = set() - -def state_event(): - return json.dumps({"type": "state", **STATE}) - +VALUE = 0 def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) +def value_event(): + return json.dumps({"type": "value", "value": VALUE}) async def counter(websocket): + global USERS, VALUE try: # Register user USERS.add(websocket) websockets.broadcast(USERS, users_event()) # Send current state to user - await websocket.send(state_event()) + await websocket.send(value_event()) # Manage state changes async for message in websocket: - data = json.loads(message) - if data["action"] == "minus": - STATE["value"] -= 1 - websockets.broadcast(USERS, state_event()) - elif data["action"] == "plus": - STATE["value"] += 1 - websockets.broadcast(USERS, state_event()) + event = json.loads(message) + if event["action"] == "minus": + VALUE -= 1 + websockets.broadcast(USERS, value_event()) + elif event["action"] == "plus": + VALUE += 1 + websockets.broadcast(USERS, value_event()) else: - logging.error("unsupported event: %s", data) + logging.error("unsupported event: %s", event) finally: # Unregister user USERS.remove(websocket) websockets.broadcast(USERS, users_event()) - async def main(): async with websockets.serve(counter, "localhost", 6789): await asyncio.Future() # run forever - if __name__ == "__main__": asyncio.run(main()) diff --git a/example/localhost.pem b/example/quickstart/localhost.pem similarity index 100% rename from example/localhost.pem rename to example/quickstart/localhost.pem diff --git a/example/server.py b/example/quickstart/server.py similarity index 87% rename from example/server.py rename to example/quickstart/server.py index 7fd7bdf4c..31b182972 100755 --- a/example/server.py +++ b/example/quickstart/server.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WS server example - import asyncio import websockets @@ -18,4 +16,5 @@ async def main(): async with websockets.serve(hello, "localhost", 8765): await asyncio.Future() # run forever -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/secure_server.py b/example/quickstart/server_secure.py similarity index 86% rename from example/secure_server.py rename to example/quickstart/server_secure.py index f0231bc16..de41d30dc 100755 --- a/example/secure_server.py +++ b/example/quickstart/server_secure.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WSS (WS over TLS) server example, with a self-signed certificate - import asyncio import pathlib import ssl @@ -24,4 +22,5 @@ async def main(): async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): await asyncio.Future() # run forever -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/quickstart/show_time.html b/example/quickstart/show_time.html new file mode 100644 index 000000000..b1c93b141 --- /dev/null +++ b/example/quickstart/show_time.html @@ -0,0 +1,9 @@ + + + + WebSocket demo + + + + + diff --git a/example/quickstart/show_time.js b/example/quickstart/show_time.js new file mode 100644 index 000000000..26bed7ec9 --- /dev/null +++ b/example/quickstart/show_time.js @@ -0,0 +1,12 @@ +window.addEventListener("DOMContentLoaded", () => { + const messages = document.createElement("ul"); + document.body.appendChild(messages); + + const websocket = new WebSocket("ws://localhost:5678/"); + websocket.onmessage = ({ data }) => { + const message = document.createElement("li"); + const content = document.createTextNode(data); + message.appendChild(content); + messages.appendChild(message); + }; +}); diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py new file mode 100755 index 000000000..facd56b00 --- /dev/null +++ b/example/quickstart/show_time.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import random +import websockets + +async def show_time(websocket): + while websocket.open: + await websocket.send(datetime.datetime.utcnow().isoformat() + "Z") + await asyncio.sleep(random.random() * 2 + 1) + +async def main(): + async with websockets.serve(show_time, "localhost", 5678): + await asyncio.Future() # run forever + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/show_time.html b/example/show_time.html deleted file mode 100644 index 721f44264..000000000 --- a/example/show_time.html +++ /dev/null @@ -1,20 +0,0 @@ - - - - WebSocket demo - - - - - diff --git a/example/show_time.py b/example/show_time.py deleted file mode 100755 index b5a153b71..000000000 --- a/example/show_time.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python - -# WS server that sends messages at random intervals - -import asyncio -import datetime -import random -import websockets - -async def time(websocket): - while True: - now = datetime.datetime.utcnow().isoformat() + "Z" - await websocket.send(now) - await asyncio.sleep(random.random() * 3) - -async def main(): - async with websockets.serve(time, "localhost", 5678): - await asyncio.Future() # run forever - -asyncio.run(main()) From ce611d79b3c849f66da6645513f6a28195e298e1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Sep 2021 21:20:13 +0200 Subject: [PATCH 0989/1539] Add tutorial. --- docs/conf.py | 6 +- docs/howto/heroku.rst | 4 +- docs/intro/index.rst | 23 +- docs/intro/tutorial1.rst | 591 ++++++++++++++++++++++++ docs/intro/tutorial2.rst | 565 ++++++++++++++++++++++ docs/intro/tutorial3.rst | 290 ++++++++++++ docs/spelling_wordlist.txt | 1 + example/tutorial/start/connect4.css | 105 +++++ example/tutorial/start/connect4.js | 45 ++ example/tutorial/start/connect4.py | 62 +++ example/tutorial/start/favicon.ico | Bin 0 -> 5430 bytes example/tutorial/step1/app.py | 65 +++ example/tutorial/step1/connect4.css | 1 + example/tutorial/step1/connect4.js | 1 + example/tutorial/step1/connect4.py | 1 + example/tutorial/step1/favicon.ico | 1 + example/tutorial/step1/index.html | 10 + example/tutorial/step1/main.js | 53 +++ example/tutorial/step2/app.py | 190 ++++++++ example/tutorial/step2/connect4.css | 1 + example/tutorial/step2/connect4.js | 1 + example/tutorial/step2/connect4.py | 1 + example/tutorial/step2/favicon.ico | 1 + example/tutorial/step2/index.html | 15 + example/tutorial/step2/main.js | 83 ++++ example/tutorial/step3/Procfile | 1 + example/tutorial/step3/app.py | 198 ++++++++ example/tutorial/step3/connect4.css | 1 + example/tutorial/step3/connect4.js | 1 + example/tutorial/step3/connect4.py | 1 + example/tutorial/step3/favicon.ico | 1 + example/tutorial/step3/index.html | 15 + example/tutorial/step3/main.js | 93 ++++ example/tutorial/step3/requirements.txt | 1 + 34 files changed, 2422 insertions(+), 6 deletions(-) create mode 100644 docs/intro/tutorial1.rst create mode 100644 docs/intro/tutorial2.rst create mode 100644 docs/intro/tutorial3.rst create mode 100644 example/tutorial/start/connect4.css create mode 100644 example/tutorial/start/connect4.js create mode 100644 example/tutorial/start/connect4.py create mode 100644 example/tutorial/start/favicon.ico create mode 100644 example/tutorial/step1/app.py create mode 120000 example/tutorial/step1/connect4.css create mode 120000 example/tutorial/step1/connect4.js create mode 120000 example/tutorial/step1/connect4.py create mode 120000 example/tutorial/step1/favicon.ico create mode 100644 example/tutorial/step1/index.html create mode 100644 example/tutorial/step1/main.js create mode 100644 example/tutorial/step2/app.py create mode 120000 example/tutorial/step2/connect4.css create mode 120000 example/tutorial/step2/connect4.js create mode 120000 example/tutorial/step2/connect4.py create mode 120000 example/tutorial/step2/favicon.ico create mode 100644 example/tutorial/step2/index.html create mode 100644 example/tutorial/step2/main.js create mode 100644 example/tutorial/step3/Procfile create mode 100644 example/tutorial/step3/app.py create mode 120000 example/tutorial/step3/connect4.css create mode 120000 example/tutorial/step3/connect4.js create mode 120000 example/tutorial/step3/connect4.py create mode 120000 example/tutorial/step3/favicon.ico create mode 100644 example/tutorial/step3/index.html create mode 100644 example/tutorial/step3/main.js create mode 100644 example/tutorial/step3/requirements.txt diff --git a/docs/conf.py b/docs/conf.py index 750682573..8d9fcdc8f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -95,7 +95,11 @@ code_url = f"https://github.com/aaugustin/websockets/blob/{commit}" def linkcode_resolve(domain, info): - assert domain == "py" + # Non-linkable objects from the starter kit in the tutorial. + if domain == "js" or info["module"] == "connect4": + return + + assert domain == "py", "expected only Python objects" mod = importlib.import_module(info["module"]) if "." in info["fullname"]: diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 6a7c4d00b..464420e05 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -42,6 +42,7 @@ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/heroku/app.py + :language: text Heroku expects the server to `listen on a specific port`_, which is provided in the ``$PORT`` environment variable. The app reads it and passes it to @@ -62,6 +63,7 @@ In order to build the app, Heroku needs to know that it depends on websockets. Create a ``requirements.txt`` file containing this line: .. literalinclude:: ../../example/deployment/heroku/requirements.txt + :language: text Heroku also needs to know how to run the app. Create a ``Procfile`` with this content: @@ -86,7 +88,7 @@ The app is ready. Let's deploy it! .. code-block:: console - $ git push heroku main + $ git push heroku ... lots of output... diff --git a/docs/intro/index.rst b/docs/intro/index.rst index a5b68bfbf..2c66dea9a 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -16,19 +16,34 @@ websockets requires Python ≥ 3.7. It doesn't have any dependencies. +.. _install: + Installation ------------ -Install websockets with:: +Install websockets with: + +.. code-block:: console - pip install websockets + $ pip install websockets Wheels are available for all platforms. -First steps +Tutorial +-------- + +Learn how to build an real-time web application with websockets. + +.. toctree:: + + tutorial1 + tutorial2 + tutorial3 + +In a hurry? ----------- -If you're in a hurry, check out these examples. +Check out these examples. .. toctree:: diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst new file mode 100644 index 000000000..ab4f39d79 --- /dev/null +++ b/docs/intro/tutorial1.rst @@ -0,0 +1,591 @@ +Part 1 - Send & receive +======================= + +.. currentmodule:: websockets + +In this tutorial, you're going to build a web-based `Connect Four`_ game. + +.. _Connect Four: https://en.wikipedia.org/wiki/Connect_Four + +The web removes the constraint of being in the same room for playing a game. +Two players can connect over of the Internet, regardless of where they are, +and play in their browsers. + +When a player makes a move, it should be reflected immediately on both sides. +This is difficult to implement over HTTP due to the request-response style of +the protocol. + +Indeed, there is no good way to be notified when the other player makes a +move. Workarounds such as polling or long-polling introduce significant +overhead. + +Enter `WebSocket `_. + +The WebSocket protocol provides two-way communication between a browser and a +server over a persistent connection. That's exactly what you need to exchange +moves between players, via a server. + +.. admonition:: This is the first part of the tutorial. + + * In this :doc:`first part `, you will create a server and + connect one browser; you can play if you share the same browser. + * In the :doc:`second part `, you will connect a second + browser; you can play from different browsers on a local network. + * In the :doc:`third part `, you will deploy the game to the + web; you can play from any browser connected to the Internet. + +Prerequisites +------------- + +This tutorial assumes basic knowledge of Python and JavaScript. + +If you're comfortable with :doc:`virtual environments `, +you can use one for this tutorial. Else, don't worry: websockets doesn't have +any dependencies; it shouldn't create trouble in the default environment. + +If you haven't installed websockets yet, do it now: + +.. code-block:: console + + $ pip install websockets + +Confirm that websockets is installed: + +.. code-block:: console + + $ python -m websockets --version + +.. admonition:: This tutorial is written for websockets |release|. + :class: tip + + If you installed another version, you should switch to the corresponding + version of the documentation. + +Download the starter kit +------------------------ + +Create a directory and download these three files: +:download:`connect4.js <../../example/tutorial/start/connect4.js>`, +:download:`connect4.css <../../example/tutorial/start/connect4.css>`, +and :download:`connect4.py <../../example/tutorial/start/connect4.py>`. + +The JavaScript module, along with the CSS file, provides a web-based user +interface. Here's its API. + +.. js:module:: connect4 + +.. js:data:: PLAYER1 + + Color of the first player. + +.. js:data:: PLAYER2 + + Color of the second player. + +.. js:function:: createBoard(board) + + Draw a board. + + :param board: DOM element containing the board; must be initially empty. + +.. js:function:: playMove(board, player, column, row) + + Play a move. + + :param board: DOM element containing the board. + :param player: :js:data:`PLAYER1` or :js:data:`PLAYER2`. + :param column: between ``0`` and ``6``. + :param row: between ``0`` and ``5``. + +The Python module provides a class to record moves and tell when a player +wins. Here's its API. + +.. module:: connect4 + +.. data:: PLAYER1 + :value: "red" + + Color of the first player. + +.. data:: PLAYER2 + :value: "yellow" + + Color of the second player. + +.. class:: Connect4 + + A Connect Four game. + + .. method:: play(player, column) + + Play a move. + + :param player: :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2`. + :param column: between ``0`` and ``6``. + :returns: Row where the checker lands, between ``0`` and ``5``. + :raises RuntimeError: if the move is illegal. + + .. attribute:: moves + + List of moves played during this game, as ``(player, column, row)`` + tuples. + + .. attribute:: winner + + :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2` if they + won; :obj:`None` if the game is still ongoing. + +.. currentmodule:: websockets + +Bootstrap the web UI +-------------------- + +Create an ``index.html`` file next to ``connect4.js`` and ``connect4.css`` +with this content: + +.. literalinclude:: ../../example/tutorial/step1/index.html + :language: html + +This HTML page contains an empty ``
`` element where you will draw the +Connect Four board. It loads a ``main.js`` script where you will write all +your JavaScript code. + +Create a ``main.js`` file next to ``index.html``. In this script, when the +page loads, draw the board: + +.. code-block:: javascript + + import { createBoard, playMove } from "./connect4.js"; + + window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + }); + +Open a shell, navigate to the directory containing these files, and start a +HTTP server: + +.. code-block:: console + + $ python -m http.server + +Open http://localhost:8000/ in a web browser. The page displays an empty board +with seven columns and six rows. You will play moves in this board later. + +Bootstrap the server +-------------------- + +Create an ``app.py`` file next to ``connect4.py`` with this content: + +.. code-block:: python + + #!/usr/bin/env python + + import asyncio + + import websockets + + + async def handler(websocket): + while True: + message = await websocket.recv() + print(message) + + + async def main(): + async with websockets.serve(handler, "", 8001): + await asyncio.Future() # run forever + + + if __name__ == "__main__": + asyncio.run(main()) + +The entry point of this program is ``asyncio.run(main())``. It creates an +asyncio event loop, runs the ``main()`` coroutine, and shuts down the loop. + +The ``main()`` coroutine calls :func:`~server.serve` to start a websockets +server. :func:`~server.serve` takes three positional arguments: + +* ``handler`` is a coroutine that manages a connection. When a client + connects, websockets calls ``handler`` with the connection in argument. + When ``handler`` terminates, websockets closes the connection. +* The second argument defines the network interfaces where the server can be + reached. Here, the server listens on all interfaces, so that other devices + on the same local network can connect. +* The third argument is the port on which the server listens. + +Invoking :func:`~server.serve` as an asynchronous context manager, in an +``async with`` block, ensures that the server shuts down properly when +terminating the program. + +For each connection, the ``handler()`` coroutine runs an infinite loop that +receives messages from the browser and prints them. + +Open a shell, navigate to the directory containing ``app.py``, and start the +server: + +.. code-block:: console + + $ python app.py + +This doesn't display anything. Hopefully the WebSocket server is running. +Let's make sure that it works. You cannot test the WebSocket server with a +web browser like you tested the HTTP server. However, you can test it with +websockets' interactive client. + +Open another shell and run this command: + +.. code-block:: console + + $ python -m websockets ws://localhost:8001/ + +You get a prompt. Type a message and press "Enter". Switch to the shell where +the server is running and check that the server received the message. Good! + +Exit the interactive client with Ctrl-C or Ctrl-D. + +Now, if you look at the console where you started the server, you can see the +stack trace of an exception: + +.. code-block:: pytb + + connection handler failed + Traceback (most recent call last): + ... + File "app.py", line 22, in handler + message = await websocket.recv() + ... + websockets.exceptions.ConnectionClosedOK: received 1000 (OK); then sent 1000 (OK) + +Indeed, the server was waiting for the next message +with :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` when the client +disconnected. When this happens, websockets raises +a :exc:`~exceptions.ConnectionClosedOK` exception to let you know that you +won't receive another message on this connection. + +This exception creates noise in the server logs, making it more difficult to +spot real errors when you add functionality to the server. Catch it in the +``handler()`` coroutine: + +.. code-block:: python + + async def handler(websocket): + while True: + try: + message = await websocket.recv() + except websockets.ConnectionClosedOK: + break + print(message) + +Stop the server with Ctrl-C and start it again: + +.. code-block:: console + + $ python app.py + +.. admonition:: You must restart the WebSocket server when you make changes. + :class: tip + + The WebSocket server loads the Python code in ``app.py`` then serves every + WebSocket request with this version of the code. As a consequence, + changes to ``app.py`` aren't visible until you restart the server. + + This is unlike the HTTP server that you started earlier with ``python -m + http.server``. For every request, this HTTP server reads the target file + and sends it. That's why changes are immediately visible. + + It is possible to :doc:`restart the WebSocket server automatically + <../howto/autoreload>` but this isn't necessary for this tutorial. + +Try connecting and disconnecting the interactive client again. +The :exc:`~exceptions.ConnectionClosedOK` exception doesn't appear anymore. + +This pattern is so common that websockets provides a shortcut for iterating +over messages received on the connection until the client disconnects: + +.. code-block:: python + + async def handler(websocket): + async for message in websocket: + print(message) + +Restart the server and check with the interactive client that its behavior +didn't change. + +At this point, you bootstrapped a web application and a WebSocket server. +Let's connect them. + +Transmit from browser to server +------------------------------- + +In JavaScript, you open a WebSocket connection as follows: + +.. code-block:: javascript + + const websocket = new WebSocket("ws://localhost:8001/"); + +Before you exchange messages with the server, you need to decide their format. +There is no universal convention for this. + +Let's use JSON objects with a ``type`` key identifying the type of the event +and the rest of the object containing properties of the event. + +Here's an event describing a move in the middle slot of the board: + +.. code-block:: javascript + + const event = {type: "play", column: 3}; + +Here's how to serialize this event to JSON and send it to the server: + +.. code-block:: javascript + + websocket.send(JSON.stringify(event)); + +Now you have all the building blocks to send moves to the server. + +Add this function to ``main.js``: + +.. literalinclude:: ../../example/tutorial/step1/main.js + :language: js + :start-at: function sendMoves + :end-before: window.addEventListener + +``sendMoves()`` registers a listener for ``click`` events on the board. The +listener figures out which column was clicked, builds a event of type +``"play"``, serializes it, and sends it to the server. + +Modify the initialization to open the WebSocket connection and call the +``sendMoves()`` function: + +.. code-block:: javascript + + window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + // Open the WebSocket connection and register event handlers. + const websocket = new WebSocket("ws://localhost:8001/"); + sendMoves(board, websocket); + }); + +Check that the HTTP server and the WebSocket server are still running. If you +stopped them, here are the commands to start them again: + +.. code-block:: console + + $ python -m http.server + +.. code-block:: console + + $ python app.py + +Refresh http://localhost:8000/ in your web browser. Click various columns in +the board. The server receives messages with the expected column number. + +There isn't any feedback in the board because you haven't implemented that +yet. Let's do it. + +Transmit from server to browser +------------------------------- + +In JavaScript, you receive WebSocket messages by listening to ``message`` +events. Here's how to receive a message from the server and deserialize it +from JSON: + +.. code-block:: javascript + + websocket.addEventListener("message", ({ data }) => { + const event = JSON.parse(data); + // do something with event + }); + +You're going to need three types of messages from the server to the browser: + +.. code-block:: javascript + + {type: "play", player: "red", column: 3, row: 0} + {type: "win", player: "red"} + {type: "error", message: "This slot is full."} + +The JavaScript code receiving these messages will dispatch events depending on +their type and take appropriate action. For example, it will react to an +event of type ``"play"`` by displaying the move on the board with +the :js:func:`~connect4.playMove` function. + +Add this function to ``main.js``: + +.. literalinclude:: ../../example/tutorial/step1/main.js + :language: js + :start-at: function showMessage + :end-before: function sendMoves + +.. admonition:: Why does ``showMessage`` use ``window.setTimeout``? + :class: hint + + When :js:func:`playMove` modifies the state of the board, the browser + renders changes asynchronously. Conversely, ``window.alert()`` runs + synchronously and blocks rendering while the alert is visible. + + If you called ``window.alert()`` immediately after :js:func:`playMove`, + the browser could display the alert before rendering the move. You could + get a "Player red wins!" alert without seeing red's last move. + + We're using ``window.alert()`` for simplicity in this tutorial. A real + application would display these messages in the user interface instead. + It wouldn't be vulnerable to this problem. + +Modify the initialization to call the ``receiveMoves()`` function: + +.. literalinclude:: ../../example/tutorial/step1/main.js + :language: js + :start-at: window.addEventListener + +At this point, the user interface should receive events properly. Let's test +it by modifying the server to send some events. + +Sending an event from Python is quite similar to JavaScript: + +.. code-block:: python + + event = {"type": "play", "player": "red", "column": 3, "row": 0} + await websocket.send(json.dumps(event)) + +.. admonition:: Don't forget to serialize the event with :func:`json.dumps`. + :class: tip + + Else, websockets raises ``TypeError: data is a dict-like object``. + +Modify the ``handler()`` coroutine in ``app.py`` as follows: + +.. code-block:: python + + import json + + from connect4 import PLAYER1, PLAYER2 + + async def handler(websocket): + for player, column, row in [ + (PLAYER1, 3, 0), + (PLAYER2, 3, 1), + (PLAYER1, 4, 0), + (PLAYER2, 4, 1), + (PLAYER1, 2, 0), + (PLAYER2, 1, 0), + (PLAYER1, 5, 0), + ]: + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + await websocket.send(json.dumps(event)) + await asyncio.sleep(0.5) + event = { + "type": "win", + "player": PLAYER1, + } + await websocket.send(json.dumps(event)) + +Restart the WebSocket server and refresh http://localhost:8000/ in your web +browser. Seven moves appear at 0.5 second intervals. Then an alert announces +the winner. + +Good! Now you know how to communicate both ways. + +Once you plug the game engine to process moves, you will have a fully +functional game. + +Add the game logic +------------------ + +In the ``handler()`` coroutine, you're going to initialize a game: + +.. code-block:: python + + from connect4 import Connect4 + + async def handler(websocket): + # Initialize a Connect Four game. + game = Connect4() + + ... + +Then, you're going to iterate over incoming messages and take these steps: + +* parse an event of type ``"play"``, the only type of event that the user + interface sends; +* play the move in the board with the :meth:`~connect4.Connect4.play` method, + alternating between the two players; +* if :meth:`~connect4.Connect4.play` raises :exc:`RuntimeError` because the + move is illegal, send an event of type ``"error"``; +* else, send an event of type ``"play"`` to tell the user interface where the + checker lands; +* if the move won the game, send an event of type ``"win"``. + +Try to implement this by yourself! + +Keep in mind that you must restart the WebSocket server and reload the page in +the browser when you make changes. + +When it works, you can play the game from a single browser, with players +taking alternate turns. + +.. admonition:: Enable debug logs to see all messages sent and received. + :class: tip + + Here's how to enable debug logs: + + .. code-block:: python + + import logging + + logging.basicConfig(format="%(message)s", level=logging.DEBUG) + +If you're stuck, a solution is available at the bottom of this document. + +Summary +------- + +In this first part of the tutorial, you learned how to: + +* build and run a WebSocket server in Python with :func:`~server.serve`; +* receive a message in a connection handler + with :meth:`~server.WebSocketServerProtocol.recv`; +* send a message in a connection handler + with :meth:`~server.WebSocketServerProtocol.send`; +* iterate over incoming messages with ``async for + message in websocket: ...``; +* open a WebSocket connection in JavaScript with the ``WebSocket`` API; +* send messages in a browser with ``WebSocket.send()``; +* receive messages in a browser by listening to ``message`` events; +* design a set of events to be exchanged between the browser and the server. + +You can now play a Connect Four game in a browser, communicating over a +WebSocket connection with a server where the game logic resides! + +However, the two players share a browser, so the constraint of being in the +same room still applies. + +Move on to the :doc:`second part ` of the tutorial to break this +constraint and play from separate browsers. + +Solution +-------- + +.. literalinclude:: ../../example/tutorial/step1/app.py + :caption: app.py + :language: python + :linenos: + +.. literalinclude:: ../../example/tutorial/step1/index.html + :caption: index.html + :language: html + :linenos: + +.. literalinclude:: ../../example/tutorial/step1/main.js + :caption: main.js + :language: js + :linenos: diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst new file mode 100644 index 000000000..669d46cde --- /dev/null +++ b/docs/intro/tutorial2.rst @@ -0,0 +1,565 @@ +Part 2 - Route & broadcast +========================== + +.. currentmodule:: websockets + +.. admonition:: This is the second part of the tutorial. + + * In the :doc:`first part `, you created a server and + connected one browser; you could play if you shared the same browser. + * In this :doc:`second part `, you will connect a second + browser; you can play from different browsers on a local network. + * In the :doc:`third part `, you will deploy the game to the + web; you can play from any browser connected to the Internet. + +In the first part of the tutorial, you opened a WebSocket connection from a +browser to a server and exchanged events to play moves. The state of the game +was stored in an instance of the :class:`~connect4.Connect4` class, +referenced as a local variable in the connection handler coroutine. + +Now you want to open two WebSocket connections from two separate browsers, one +for each player, to the same server in order to play the same game. This +requires moving the state of the game to a place where both connections can +access it. + +Share game state +---------------- + +As long as you're running a single server process, you can share state by +storing it in a global variable. + +.. admonition:: What if you need to scale to multiple server processes? + :class: hint + + In that case, you must design a way for the process that handles a given + connection to be aware of relevant events for that client. This is often + achieved with a publish / subscribe mechanism. + +How can you make two connection handlers agree on which game they're playing? +When the first player starts a game, you give it an identifier. Then, you +communicate the identifier to the second player. When the second player joins +the game, you look it up with the identifier. + +In addition to the game itself, you need to keep track of the WebSocket +connections of the two players. Since both players receive the same events, +you don't need to treat the two connections differently; you can store both +in the same set. + +Let's sketch this in code. + +A module-level :class:`dict` enables lookups by identifier: + +.. code-block:: python + + JOIN = {} + +When the first player starts the game, initialize and store it: + +.. code-block:: python + + import secrets + + async def handler(websocket): + ... + + # Initialize a Connect Four game, the set of WebSocket connections + # receiving moves from this game, and secret access token. + game = Connect4() + connected = {websocket} + + join_key = secrets.token_urlsafe(12) + JOIN[join_key] = game, connected + + try: + + ... + + finally: + del JOIN[join_key] + +When the second player joins the game, look it up: + +.. code-block:: python + + async def handler(websocket): + ... + + join_key = ... # TODO + + # Find the Connect Four game. + game, connected = JOIN[join_key] + + # Register to receive moves from this game. + connected.add(websocket) + try: + + ... + + finally: + connected.remove(websocket) + +Notice how we're carefully cleaning up global state with ``try: ... +finally: ...`` blocks. Else, we could leave references to games or +connections in global state, which would cause a memory leak. + +In both connection handlers, you have a ``game`` pointing to the same +:class:`~connect4.Connect4` instance, so you can interact with the game, +and a ``connected`` set of connections, so you can send game events to +both players as follows: + +.. code-block:: python + + async def handler(websocket): + + ... + + for connection in connected: + await connection.send(json.dumps(event)) + + ... + +Perhaps you spotted a major piece missing from the puzzle. How does the second +player obtain ``join_key``? Let's design new events to carry this information. + +To start a game, the first player sends an ``"init"`` event: + +.. code-block:: javascript + + {type: "init"} + +The connection handler for the first player creates a game as shown above and +responds with: + +.. code-block:: javascript + + {type: "init", join: ""} + +With this information, the user interface of the first player can create a +link to ``http://localhost:8000/?join=``. For the sake of simplicity, +we will assume that the first player shares this link with the second player +outside of the application, for example via an instant messaging service. + +To join the game, the second player sends a different ``"init"`` event: + +.. code-block:: javascript + + {type: "init", join: ""} + +The connection handler for the second player can look up the game with the +join key as shown above. There is no need to respond. + +Let's dive into the details of implementing this design. + +Start a game +------------ + +We'll start with the initialization sequence for the first player. + +In ``main.js``, define a function to send an initialization event when the +WebSocket connection is established, which triggers an ``open`` event: + +.. code-block:: javascript + + function initGame(websocket) { + websocket.addEventListener("open", () => { + // Send an "init" event for the first player. + const event = { type: "init" }; + websocket.send(JSON.stringify(event)); + }); + } + +Update the initialization sequence to call ``initGame()``: + +.. literalinclude:: ../../example/tutorial/step2/main.js + :language: js + :start-at: window.addEventListener + +In ``app.py``, define a new ``handler`` coroutine — keep a copy of the +previous one to reuse it later: + +.. code-block:: python + + import secrets + + + JOIN = {} + + + async def start(websocket): + # Initialize a Connect Four game, the set of WebSocket connections + # receiving moves from this game, and secret access token. + game = Connect4() + connected = {websocket} + + join_key = secrets.token_urlsafe(12) + JOIN[join_key] = game, connected + + try: + # Send the secret access token to the browser of the first player, + # where it'll be used for building a "join" link. + event = { + "type": "init", + "join": join_key, + } + await websocket.send(json.dumps(event)) + + # Temporary - for testing. + print("first player started game", id(game)) + async for message in websocket: + print("first player sent", message) + + finally: + del JOIN[join_key] + + + async def handler(websocket, path): + # Receive and parse the "init" event from the UI. + message = await websocket.recv() + event = json.loads(message) + assert event["type"] == "init" + + # First player starts a new game. + await start(websocket) + +In ``index.html``, add an ```` element to display the link to share with +the other player. + +.. code-block:: html + + + + + + +In ``main.js``, modify ``receiveMoves()`` to handle the ``"init"`` message and +set the target of that link: + +.. code-block:: javascript + + switch (event.type) { + case "init": + // Create link for inviting the second player. + document.querySelector(".join").href = "?join=" + event.join; + break; + // ... + } + +Restart the WebSocket server and reload http://localhost:8000/ in the browser. +There's a link labeled JOIN below the board with a target that looks like +http://localhost:8000/?join=95ftAaU5DJVP1zvb. + +The server logs say ``first player started game ...``. If you click the board, +you see ``"play"`` events. There is no feedback in the UI, though, because +you haven't restored the game logic yet. + +Before we get there, let's handle links with a ``join`` query parameter. + +Join a game +----------- + +We'll now update the initialization sequence to account for the second +player. + +In ``main.js``, update ``initGame()`` to send the join key in the ``"init"`` +message when it's in the URL: + +.. code-block:: javascript + + function initGame(websocket) { + websocket.addEventListener("open", () => { + // Send an "init" event according to who is connecting. + const params = new URLSearchParams(window.location.search); + let event = { type: "init" }; + if (params.has("join")) { + // Second player joins an existing game. + event.join = params.get("join"); + } else { + // First player starts a new game. + } + websocket.send(JSON.stringify(event)); + }); + } + +In ``app.py``, update the ``handler`` coroutine to look for the join key in +the ``"init"`` message, then load that game: + +.. code-block:: python + + async def error(websocket, message): + event = { + "type": "error", + "message": message, + } + await websocket.send(json.dumps(event)) + + + async def join(websocket, join_key): + # Find the Connect Four game. + try: + game, connected = JOIN[join_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + + # Temporary - for testing. + print("second player joined game", id(game)) + async for message in websocket: + print("second player sent", message) + + finally: + connected.remove(websocket) + + + async def handler(websocket): + # Receive and parse the "init" event from the UI. + message = await websocket.recv() + event = json.loads(message) + assert event["type"] == "init" + + if "join" in event: + # Second player joins an existing game. + await join(websocket, event["join"]) + else: + # First player starts a new game. + await start(websocket) + +Restart the WebSocket server and reload http://localhost:8000/ in the browser. + +Copy the link labeled JOIN and open it in another browser. You may also open +it in another tab or another window of the same browser; however, that makes +it a bit tricky to remember which one is the first or second player. + +.. admonition:: You must start a new game when you restart the server. + :class: tip + + Since games are stored in the memory of the Python process, they're lost + when you stop the server. + + Whenever you make changes to ``app.py``, you must restart the server, + create a new game in a browser, and join it in another browser. + +The server logs say ``first player started game ...`` and ``second player +joined game ...``. The numbers match, proving that the ``game`` local +variable in both connection handlers points to same object in the memory of +the Python process. + +Click the board in either browser. The server receives ``"play"`` events from +the corresponding player. + +In the initialization sequence, you're routing connections to ``start()`` or +``join()`` depending on the first message received by the server. This is a +common pattern in servers that handle different clients. + +.. admonition:: Why not use different URIs for ``start()`` and ``join()``? + :class: hint + + Instead of sending an initialization event, you could encode the join key + in the WebSocket URI e.g. ``ws://localhost:8001/join/``. The + WebSocket server would parse ``websocket.path`` and route the connection, + similar to how HTTP servers route requests. + + When you need to send sensitive data like authentication credentials to + the server, sending it an event is considered more secure than encoding + it in the URI because URIs end up in logs. + + For the purposes of this tutorial, both approaches are equivalent because + the join key comes from a HTTP URL. There isn't much at risk anyway! + +Now you can restore the logic for playing moves and you'll have a fully +functional two-player game. + +Add the game logic +------------------ + +Once the initialization is done, the game is symmetrical, so you can write a +single coroutine to process the moves of both players: + +.. code-block:: python + + async def play(game, player, connected): + ... + +With such a coroutine, you can replace the temporary code for testing in +``start()`` by: + +.. code-block:: python + + await play(game, PLAYER1, connected) + +and in ``join()`` by: + +.. code-block:: python + + await play(game, PLAYER2, connected) + +The ``play()`` coroutine will reuse much of the code you wrote in the first +part of the tutorial. + +Try to implement this by yourself! + +Keep in mind that you must restart the WebSocket server, reload the page to +start a new game with the first player, copy the JOIN link, and join the game +with the second player when you make changes. + +When ``play()`` works, you can play the game from two separate browsers, +possibly running on separate computers on the same local network. + +A complete solution is available at the bottom of this document. + +Watch a game +------------ + +Let's add one more feature: allow spectators to watch the game. + +The process for inviting a spectator can be the same as for inviting the +second player. You will have to duplicate all the initialization logic: + +- declare a ``WATCH`` global variable similar to ``JOIN``; +- generate a watch key when creating a game; it must be different from the + join key, or else a spectator could hijack a game by tweaking the URL; +- include the watch key in the ``"init"`` event sent to the first player; +- generate a WATCH link in the UI with a ``watch`` query parameter; +- update the ``initGame()`` function to handle such links; +- update the ``handler()`` coroutine to invoke a ``watch()`` coroutine for + spectators; +- prevent ``sendMoves()`` from sending ``"play"`` events for spectators. + +Once the initialization sequence is done, watching a game is as simple as +registering the WebSocket connection in the ``connected`` set in order to +receive game events and doing nothing until the spectator disconnects. You +can wait for a connection to terminate with +:meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed`: + +.. code-block:: python + + async def watch(websocket, watch_key): + + ... + + connected.add(websocket) + try: + await websocket.wait_closed() + finally: + connected.remove(websocket) + +The connection can terminate because the ``receiveMoves()`` function closed it +explicitly after receiving a ``"win"`` event, because the spectator closed +their browser, or because the network failed. + +Again, try to implement this by yourself. + +When ``watch()`` works, you can invite spectators to watch the game from other +browsers, as long as they're on the same local network. + +As a further improvement, you may support adding spectators while a game is +already in progress. This requires replaying moves that were played before +the spectator was added to the ``connected`` set. Past moves are available in +the :attr:`~connect4.Connect4.moves` attribute of the game. + +This feature is included in the solution proposed below. + +Broadcast +--------- + +When you need to send a message to the two players and to all spectators, +you're using this pattern: + +.. code-block:: python + + async def handler(websocket): + + ... + + for connection in connected: + await connection.send(json.dumps(event)) + + ... + +Since this is a very common pattern in WebSocket servers, websockets provides +the :func:`broadcast` helper for this purpose: + +.. code-block:: python + + async def handler(websocket): + + ... + + websockets.broadcast(connected, json.dumps(event)) + + ... + +Calling :func:`broadcast` once is more efficient than +calling :meth:`~legacy.protocol.WebSocketCommonProtocol.send` in a loop. + +However, there's a subtle difference in behavior. Did you notice that there's +no ``await`` in the second version? Indeed, :func:`broadcast` is a function, +not a coroutine like :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +or :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. + +It's quite obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` +is a coroutine. When you want to receive the next message, you have to wait +until the client sends it and the network transmits it. + +It's less obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.send` is +a coroutine. If you send many messages or large messages, you could write +data faster than the network can transmit it or the client can read it. Then, +outgoing data will pile up in buffers, which will consume memory and may +crash your application. + +To avoid this problem, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +waits until the write buffer drains. By slowing down the application as +necessary, this ensures that the server doesn't send data too quickly. This +is called backpressure and it's useful for building robust systems. + +That said, when you're sending the same messages to many clients in a loop, +applying backpressure in this way can become counterproductive. When you're +broadcasting, you don't want to slow down everyone to the pace of the slowest +clients; you want to drop clients that cannot keep up with the data stream. +That's why :func:`broadcast` doesn't wait until write buffers drain. + +For our Connect Four game, there's no difference in practice: the total amount +of data sent on a connection for a game of Connect Four is less than 64 KB, +so the write buffer never fills up and backpressure never kicks in anyway. + +Summary +------- + +In this second part of the tutorial, you learned how to: + +* configure a connection by exchanging initialization messages; +* keep track of connections within a single server process; +* wait until a client disconnects in a connection handler; +* broadcast a message to many connections efficiently. + +You can now play a Connect Four game from separate browser, communicating over +WebSocket connections with a server that synchronizes the game logic! + +However, the two players have to be on the same local network as the server, +so the constraint of being in the same place still mostly applies. + +Head over to the :doc:`third part ` of the tutorial to deploy the +game to the web and remove this constraint. + +Solution +-------- + +.. literalinclude:: ../../example/tutorial/step2/app.py + :caption: app.py + :language: python + :linenos: + +.. literalinclude:: ../../example/tutorial/step2/index.html + :caption: index.html + :language: html + :linenos: + +.. literalinclude:: ../../example/tutorial/step2/main.js + :caption: main.js + :language: js + :linenos: diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst new file mode 100644 index 000000000..9a447f39b --- /dev/null +++ b/docs/intro/tutorial3.rst @@ -0,0 +1,290 @@ +Part 3 - Deploy to the web +========================== + +.. currentmodule:: websockets + +.. admonition:: This is the third part of the tutorial. + + * In the :doc:`first part `, you created a server and + connected one browser; you could play if you shared the same browser. + * In this :doc:`second part `, you connected a second browser; + you could play from different browsers on a local network. + * In this :doc:`third part `, you will deploy the game to the + web; you can play from any browser connected to the Internet. + +In the first and second parts of the tutorial, for local development, you ran +a HTTP server on ``http://localhost:8000/`` with: + +.. code-block:: console + + $ python -m http.server + +and a WebSocket server on ``ws://localhost:8001/`` with: + +.. code-block:: console + + $ python app.py + +Now you want to deploy these servers on the Internet. There's a vast range of +hosting providers to choose from. For the sake of simplicity, we'll rely on: + +* GitHub Pages for the HTTP server; +* Heroku for the WebSocket server. + +Commit project to git +--------------------- + +Perhaps you committed your work to git while you were progressing through the +tutorial. If you didn't, now is a good time, because GitHub and Heroku offer +git-based deployment workflows. + +Initialize a git repository: + +.. code-block:: console + + $ git init -b main + Initialized empty Git repository in websockets-tutorial/.git/ + $ git commit --allow-empty -m "Initial commit." + [main (root-commit) ...] Initial commit. + +Add all files and commit: + +.. code-block:: console + + $ git add . + $ git commit -m "Initial implementation of Connect Four game." + [main ...] Initial implementation of Connect Four game. + 6 files changed, 500 insertions(+) + create mode 100644 app.py + create mode 100644 connect4.css + create mode 100644 connect4.js + create mode 100644 connect4.py + create mode 100644 index.html + create mode 100644 main.js + +Prepare the WebSocket server +---------------------------- + +Before you deploy the server, you must adapt it to meet requirements of +Heroku's runtime. This involves two small changes: + +1. Heroku expects the server to `listen on a specific port`_, provided in the + ``$PORT`` environment variable. + +2. Heroku sends a ``SIGTERM`` signal when `shutting down a dyno`_, which + should trigger a clean exit. + +.. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port + +.. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown + +Adapt the ``main()`` coroutine accordingly: + +.. code-block:: python + + import os + import signal + +.. literalinclude:: ../../example/tutorial/step3/app.py + :pyobject: main + +To catch the ``SIGTERM`` signal, ``main()`` creates a :class:`~asyncio.Future` +called ``stop`` and registers a signal handler that sets the result of this +future. The value of the future doesn't matter; it's only for waiting for +``SIGTERM``. + +Then, by using :func:`~server.serve` as a context manager and exiting the +context when ``stop`` has a result, ``main()`` ensures that the server closes +connections cleanly and exits on ``SIGTERM``. + +The app is now fully compatible with Heroku. + +Deploy the WebSocket server +--------------------------- + +Create a ``requirements.txt`` file with this content to install ``websockets`` +when building the image: + +.. literalinclude:: ../../example/tutorial/step3/requirements.txt + :language: text + +.. admonition:: Heroku treats ``requirements.txt`` as a signal to `detect a Python app`_. + :class: tip + + That's why you don't need to declare that you need a Python runtime. + +.. _detect a Python app: https://devcenter.heroku.com/articles/python-support#recognizing-a-python-app + +Create a ``Procfile`` file with this content to configure the command for +running the server: + +.. literalinclude:: ../../example/tutorial/step3/Procfile + :language: text + +Commit your changes: + +.. code-block:: console + + $ git add . + $ git commit -m "Deploy to Heroku." + [main ...] Deploy to Heroku. + 3 files changed, 12 insertions(+), 2 deletions(-) + create mode 100644 Procfile + create mode 100644 requirements.txt + +Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if +you haven't done that yet. + +.. _set-up instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up + +Create a Heroku app. You must choose a unique name and replace +``websockets-tutorial`` by this name in the following command: + +.. code-block:: console + + $ heroku create websockets-tutorial + Creating ⬢ websockets-tutorial... done + https://websockets-tutorial.herokuapp.com/ | https://git.heroku.com/websockets-tutorial.git + +If you reuse a name that someone else already uses, you will receive this +error; if this happens, try another name: + +.. code-block:: console + + $ heroku create websockets-tutorial + Creating ⬢ websockets-tutorial... ! + ▸ Name websockets-tutorial is already taken + +Deploy by pushing the code to Heroku: + +.. code-block:: console + + $ git push heroku + + ... lots of output... + + remote: Released v1 + remote: https://websockets-tutorial.herokuapp.com/ deployed to Heroku + remote: + remote: Verifying deploy... done. + To https://git.heroku.com/websockets-tutorial.git + * [new branch] main -> main + +You can test the WebSocket server with the interactive client exactly like you +did in the first part of the tutorial. Replace ``websockets-tutorial`` by the +name of your app in the following command: + +.. code-block:: console + + $ python -m websockets wss://websockets-tutorial.herokuapp.com/ + Connected to wss://websockets-tutorial.herokuapp.com/. + > {"type": "init"} + < {"type": "init", "join": "54ICxFae_Ip7TJE2", "watch": "634w44TblL5Dbd9a"} + Connection closed: 1000 (OK). + +It works! + +Prepare the web application +--------------------------- + +Before you deploy the web application, perhaps you're wondering how it will +locate the WebSocket server? Indeed, at this point, its address is hard-coded +in ``main.js``: + +.. code-block:: javascript + + const websocket = new WebSocket("ws://localhost:8001/"); + +You can take this strategy one step further by checking the address of the +HTTP server and determining the address of the WebSocket server accordingly. + +Add this function to ``main.js``; replace ``aaugustin`` by your GitHub +username and ``websockets-tutorial`` by the name of your app on Heroku: + +.. literalinclude:: ../../example/tutorial/step3/main.js + :language: js + :start-at: function getWebSocketServer + :end-before: function initGame + +Then, update the initialization to connect to this address instead: + +.. code-block:: javascript + + const websocket = new WebSocket(getWebSocketServer()); + +Commit your changes: + +.. code-block:: console + + $ git add . + $ git commit -m "Configure WebSocket server address." + [main ...] Configure WebSocket server address. + 1 file changed, 11 insertions(+), 1 deletion(-) + +Deploy the web application +-------------------------- + +Go to GitHub and create a new repository called ``websockets-tutorial``. + +Push your code to this repository. You must replace ``aaugustin`` by your +GitHub username in the following command: + +.. code-block:: console + + $ git remote add origin git@github.com:aaugustin/websockets-tutorial.git + $ git push -u origin main + Enumerating objects: 11, done. + Counting objects: 100% (11/11), done. + Delta compression using up to 8 threads + Compressing objects: 100% (10/10), done. + Writing objects: 100% (11/11), 5.90 KiB | 2.95 MiB/s, done. + Total 11 (delta 0), reused 0 (delta 0), pack-reused 0 + To github.com:/websockets-tutorial.git + * [new branch] main -> main + Branch 'main' set up to track remote branch 'main' from 'origin'. + +Go back to GitHub, open the Settings tab of the repository and select Pages in +the menu. Select the main branch as source and click Save. GitHub tells you +that your site is published. + +Follow the link and start a game! + +Summary +------- + +In this third part of the tutorial, you learned how to deploy a WebSocket +application with Heroku. + +You can start a Connect Four game, send the JOIN link to a friend, and play +over the Internet! + +Congratulations for completing the tutorial. Enjoy building real-time web +applications with websockets! + +Solution +-------- + +.. literalinclude:: ../../example/tutorial/step3/app.py + :caption: app.py + :language: python + :linenos: + +.. literalinclude:: ../../example/tutorial/step3/index.html + :caption: index.html + :language: html + :linenos: + +.. literalinclude:: ../../example/tutorial/step3/main.js + :caption: main.js + :language: js + :linenos: + +.. literalinclude:: ../../example/tutorial/step3/Procfile + :caption: Procfile + :language: text + :linenos: + +.. literalinclude:: ../../example/tutorial/step3/requirements.txt + :caption: requirements.txt + :language: text + :linenos: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b57d3c77f..1d5ae527d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -18,6 +18,7 @@ coroutines cryptocurrencies cryptocurrency ctrl +deserialize django dyno fractalideas diff --git a/example/tutorial/start/connect4.css b/example/tutorial/start/connect4.css new file mode 100644 index 000000000..27f0baf6e --- /dev/null +++ b/example/tutorial/start/connect4.css @@ -0,0 +1,105 @@ +/* General layout */ + +body { + background-color: white; + display: flex; + flex-direction: column-reverse; + justify-content: center; + align-items: center; + margin: 0; + min-height: 100vh; +} + +/* Action buttons */ + +.actions { + display: flex; + flex-direction: row; + justify-content: space-evenly; + align-items: flex-end; + width: 720px; + height: 100px; +} + +.action { + color: darkgray; + font-family: "Helvetica Neue", sans-serif; + font-size: 20px; + line-height: 20px; + font-weight: 300; + text-align: center; + text-decoration: none; + text-transform: uppercase; + padding: 20px; + width: 120px; +} + +.action:hover { + background-color: darkgray; + color: white; + font-weight: 700; +} + +.action[href=""] { + display: none; +} + +/* Connect Four board */ + +.board { + background-color: blue; + display: flex; + flex-direction: row; + padding: 0 10px; + position: relative; +} + +.board::before, +.board::after { + background-color: blue; + content: ""; + height: 720px; + width: 20px; + position: absolute; +} + +.board::before { + left: -20px; +} + +.board::after { + right: -20px; +} + +.column { + display: flex; + flex-direction: column-reverse; + padding: 10px; +} + +.cell { + border-radius: 50%; + width: 80px; + height: 80px; + margin: 10px 0; +} + +.empty { + background-color: white; +} + +.column:hover .empty { + background-color: lightgray; +} + +.column:hover .empty ~ .empty { + background-color: white; +} + +.red { + background-color: red; +} + +.yellow { + background-color: yellow; +} diff --git a/example/tutorial/start/connect4.js b/example/tutorial/start/connect4.js new file mode 100644 index 000000000..cb5eb9fa2 --- /dev/null +++ b/example/tutorial/start/connect4.js @@ -0,0 +1,45 @@ +const PLAYER1 = "red"; + +const PLAYER2 = "yellow"; + +function createBoard(board) { + // Inject stylesheet. + const linkElement = document.createElement("link"); + linkElement.href = import.meta.url.replace(".js", ".css"); + linkElement.rel = "stylesheet"; + document.head.append(linkElement); + // Generate board. + for (let column = 0; column < 7; column++) { + const columnElement = document.createElement("div"); + columnElement.className = "column"; + columnElement.dataset.column = column; + for (let row = 0; row < 6; row++) { + const cellElement = document.createElement("div"); + cellElement.className = "cell empty"; + cellElement.dataset.column = column; + columnElement.append(cellElement); + } + board.append(columnElement); + } +} + +function playMove(board, player, column, row) { + // Check values of arguments. + if (player !== PLAYER1 && player !== PLAYER2) { + throw new Error(`player must be ${PLAYER1} or ${PLAYER2}.`); + } + const columnElement = board.querySelectorAll(".column")[column]; + if (columnElement === undefined) { + throw new RangeError("column must be between 0 and 6."); + } + const cellElement = columnElement.querySelectorAll(".cell")[row]; + if (cellElement === undefined) { + throw new RangeError("row must be between 0 and 5."); + } + // Place checker in cell. + if (!cellElement.classList.replace("empty", player)) { + throw new Error("cell must be empty."); + } +} + +export { PLAYER1, PLAYER2, createBoard, playMove }; diff --git a/example/tutorial/start/connect4.py b/example/tutorial/start/connect4.py new file mode 100644 index 000000000..0a61e7c7e --- /dev/null +++ b/example/tutorial/start/connect4.py @@ -0,0 +1,62 @@ +__all__ = ["PLAYER1", "PLAYER2", "Connect4"] + +PLAYER1, PLAYER2 = "red", "yellow" + + +class Connect4: + """ + A Connect Four game. + + Play moves with :meth:`play`. + + Get past moves with :attr:`moves`. + + Check for a victory with :attr:`winner`. + + """ + + def __init__(self): + self.moves = [] + self.top = [0 for _ in range(7)] + self.winner = None + + @property + def last_player(self): + """ + Player who played the last move. + + """ + return PLAYER1 if len(self.moves) % 2 else PLAYER2 + + @property + def last_player_won(self): + """ + Whether the last move is winning. + + """ + b = sum(1 << (8 * column + row) for _, column, row in self.moves[::-2]) + return any(b & b >> v & b >> 2 * v & b >> 3 * v for v in [1, 7, 8, 9]) + + def play(self, player, column): + """ + Play a move in a column. + + Returns the row where the checker lands. + + Raises :exc:`RuntimeError` if the move is illegal. + + """ + if player == self.last_player: + raise RuntimeError("It isn't your turn.") + + row = self.top[column] + if row == 6: + raise RuntimeError("This slot is full.") + + self.moves.append((player, column, row)) + self.top[column] += 1 + + if self.winner is None and self.last_player_won: + self.winner = self.last_player + + return row diff --git a/example/tutorial/start/favicon.ico b/example/tutorial/start/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..36e855029d705e72d44428bda6e8cb6d3dd317ed GIT binary patch literal 5430 zcmeH~eNdFw6~-?im3Ahb_(wY9G?@uAnl{s!P5}k6EMLo)Am6YM4N*)C21rB%KSG3E zv;hMxiW&u@(MF93wOCQ3rD-sRXbhdyP7o3K#!S*q0u${7l(0J*XQZVahM#RhT&Mw%^BqMk{U@J1~1H2BJEU64r)0!B+s<_@|*2 zqhiBw1#eC^Kd~FBvDYyg*$#wV#Xwjm9t&v2p)n2r@@vI8+WgsgcG9|#V(P|dLI)7j zfdSf?v#0~lgtfyurUjuNe*Ta2+T*8ca~J%;(ZdETkY^gva%XNcqcj2kdN-G|0yeVF}ZFP=4Z!J2R#i0#02+Ir+*zVG@y z*5YS?yJuBTDvF9 z+e4qHUo$?zisgL(ZFDT{$HRZz5|LeB6ok5h0Mx8E{O!%Wpod<2-phQZDi;R}!Uptt z_Sffbz;E>+X1=#0_Q!9(68Y81O(FP$*?76Z67toHt(cKzH39p~78|EMbIvF7dZI9CxZHU=B0 z<9Jap>Rt-q+)$@;bvu`KC6(K3*mRuQ9MaslWomPE8-LktGQPhv{>AfKV-~kmXM9f| z|2D6HQbVN2tW*=RJWA^l;TTP>949?wSa4Hnl)qW|Kbb4f8Fvyy2h!?Rc2a%__sk=;L*m&%aOk`2^Q3=D%J$`lrna32wt&`ua?J=IYOX`)ey| zsDH*}F--J3_a!wE;q=XC^%sGy0H6D|y~p17*k>&_PGWEtANNEx9#Nf`$LIWDXd7VU z-lc0`F21g#6;|0_CyK+_IPED%Vmjd+bq#|NSGn)ee~4%xL9&NlK@0!?spEQZeA9Sw zOg84YE{XKby>e*LH9+>d2-$me4ei*@T2#ST@l)pX-aNEoMD}B9TkIH*)9-fc#{JDF zsS`tE`y^thLD#X6zKIf~ACX*0P0;qeI{2O1>Ze+hd$YWhu%sV8DP0&!?gSFKPsfuu z=^dyc`Wh;U66imQy~wci5ZZc7tYMe4y3>E-zOC46crNM2&=cI_Q|MnbFb{~Qf9lvq zl)zjpqW|G=)`eCR4jO{~dh5&OHAgQLZpnV9ydbWbam~@=o9J4d>8Y0WqGf3>kS6_3 zH=ygFmaRky%tJKeAJJ-p{_kqA)#Yh(*$efw^_AAoNKZthCz1G^u_xP8J>82TYJl`x zPtw0=VBb|il)!vMb4^CrJ8A0?#hMfy%!Slv9#Q+&`23lKphoC2`RmZF?C`@YT|BPQis$fd)qZc3H#T~U(G%W)mfglo{sU;_HV?#Giog+ zWdZKb7(gC1)GY7CGNOdEKE$SWVT5(5)r3};^sUup2XfhuD&bJR@3U{@`dIhp%!9oK z`hCKgZ~0C9j|Y3#aMr|#eN9k{rX#rbA2Ga;A`w_nTG-C ztBDe-vb<91XNIG%>P!!=&(~otdZi|$F7=v_x=f{SqJ-l`K`Vb(5MVvJ!Jze!ng;g7 z>?QBK{=AowD1llEwR+6*IO(WiJm0l|D{Ep{uL)g4S}$^l5>9OnTX~|`Xtnjy{y6%g zO~ax51$Wq&ClMvEKQ9vBhc`yq?ukr~>U(@1IJr6Gc0;i-fhgfV}k+ptmI=o*Qx0@RDGY}eIiR4)*!i%Vr#UY#Q>p{*{nuh%O zw-{rDA=?=E^vn-xf;^a)yc&4S#>DB;t~+h17%G6M9Y7Y%ttpeEt)~ z*kB1&)1jq0b@m6ll1AkWCmP=Q^&;&&DfRHy%i&r*$hok-P}6XevH3fT@AR)C)pTj8 zN-n>!BLP|-wu%PrJgTk5T@A${HyM{c{o!N z{uyKPol0D`gW)%O|F9}G?&6Mwv&}o=zc^nRJKR(eeeKWX(PfP1M^laIn|c0j*0+@b zZDY{i*oZ&{!b(QfqZlSb-*-dLEEhC+xWJOGU};h)wgh3bYC?&N1*5L&qtO{IrNa-n mx{c17;Wp~=fE$Kp5swEk + + + Connect Four + + +
+ + + diff --git a/example/tutorial/step1/main.js b/example/tutorial/step1/main.js new file mode 100644 index 000000000..dd28f9a6a --- /dev/null +++ b/example/tutorial/step1/main.js @@ -0,0 +1,53 @@ +import { createBoard, playMove } from "./connect4.js"; + +function showMessage(message) { + window.setTimeout(() => window.alert(message), 50); +} + +function receiveMoves(board, websocket) { + websocket.addEventListener("message", ({ data }) => { + const event = JSON.parse(data); + switch (event.type) { + case "play": + // Update the UI with the move. + playMove(board, event.player, event.column, event.row); + break; + case "win": + showMessage(`Player ${event.player} wins!`); + // No further messages are expected; close the WebSocket connection. + websocket.close(1000); + break; + case "error": + showMessage(event.message); + break; + default: + throw new Error(`Unsupported event type: ${event.type}.`); + } + }); +} + +function sendMoves(board, websocket) { + // When clicking a column, send a "play" event for a move in that column. + board.addEventListener("click", ({ target }) => { + const column = target.dataset.column; + // Ignore clicks outside a column. + if (column === undefined) { + return; + } + const event = { + type: "play", + column: parseInt(column, 10), + }; + websocket.send(JSON.stringify(event)); + }); +} + +window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + // Open the WebSocket connection and register event handlers. + const websocket = new WebSocket("ws://localhost:8001/"); + receiveMoves(board, websocket); + sendMoves(board, websocket); +}); diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py new file mode 100644 index 000000000..bac2b6f27 --- /dev/null +++ b/example/tutorial/step2/app.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python + +import asyncio +import json +import secrets + +import websockets + +from connect4 import PLAYER1, PLAYER2, Connect4 + + +JOIN = {} + +WATCH = {} + + +async def error(websocket, message): + """ + Send an error message. + + """ + event = { + "type": "error", + "message": message, + } + await websocket.send(json.dumps(event)) + + +async def replay(websocket, game): + """ + Send previous moves. + + """ + # Make a copy to avoid an exception if game.moves changes while iteration + # is in progress. If a move is played while replay is running, moves will + # be sent out of order but each move will be sent once and eventually the + # UI will be consistent. + for player, column, row in game.moves.copy(): + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + await websocket.send(json.dumps(event)) + + +async def play(websocket, game, player, connected): + """ + Receive and process moves from a player. + + """ + async for message in websocket: + # Parse a "play" event from the UI. + event = json.loads(message) + assert event["type"] == "play" + column = event["column"] + + try: + # Play the move. + row = game.play(player, column) + except RuntimeError as exc: + # Send an "error" event if the move was illegal. + await error(websocket, str(exc)) + continue + + # Send a "play" event to update the UI. + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + websockets.broadcast(connected, json.dumps(event)) + + # If move is winning, send a "win" event. + if game.winner is not None: + event = { + "type": "win", + "player": game.winner, + } + websockets.broadcast(connected, json.dumps(event)) + + +async def start(websocket): + """ + Handle a connection from the first player: start a new game. + + """ + # Initialize a Connect Four game, the set of WebSocket connections + # receiving moves from this game, and secret access tokens. + game = Connect4() + connected = {websocket} + + join_key = secrets.token_urlsafe(12) + JOIN[join_key] = game, connected + + watch_key = secrets.token_urlsafe(12) + WATCH[watch_key] = game, connected + + try: + # Send the secret access tokens to the browser of the first player, + # where they'll be used for building "join" and "watch" links. + event = { + "type": "init", + "join": join_key, + "watch": watch_key, + } + await websocket.send(json.dumps(event)) + # Receive and process moves from the first player. + await play(websocket, game, PLAYER1, connected) + finally: + del JOIN[join_key] + del WATCH[watch_key] + + +async def join(websocket, join_key): + """ + Handle a connection from the second player: join an existing game. + + """ + # Find the Connect Four game. + try: + game, connected = JOIN[join_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + # Send the first move, in case the first player already played it. + await replay(websocket, game) + # Receive and process moves from the second player. + await play(websocket, game, PLAYER2, connected) + finally: + connected.remove(websocket) + + +async def watch(websocket, watch_key): + """ + Handle a connection from a spectator: watch an existing game. + + """ + # Find the Connect Four game. + try: + game, connected = WATCH[watch_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + # Send previous moves, in case the game already started. + await replay(websocket, game) + # Keep the connection open, but don't receive any messages. + await websocket.wait_closed() + finally: + connected.remove(websocket) + + +async def handler(websocket, path): + """ + Handle a connection and dispatch it according to who is connecting. + + """ + # Receive and parse the "init" event from the UI. + message = await websocket.recv() + event = json.loads(message) + assert event["type"] == "init" + + if "join" in event: + # Second player joins an existing game. + await join(websocket, event["join"]) + elif "watch" in event: + # Spectator watches an existing game. + await watch(websocket, event["watch"]) + else: + # First player starts a new game. + await start(websocket) + + +async def main(): + async with websockets.serve(handler, "", 8001): + await asyncio.Future() # run forever + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/tutorial/step2/connect4.css b/example/tutorial/step2/connect4.css new file mode 120000 index 000000000..55a9977ca --- /dev/null +++ b/example/tutorial/step2/connect4.css @@ -0,0 +1 @@ +../start/connect4.css \ No newline at end of file diff --git a/example/tutorial/step2/connect4.js b/example/tutorial/step2/connect4.js new file mode 120000 index 000000000..7c4ed2f3e --- /dev/null +++ b/example/tutorial/step2/connect4.js @@ -0,0 +1 @@ +../start/connect4.js \ No newline at end of file diff --git a/example/tutorial/step2/connect4.py b/example/tutorial/step2/connect4.py new file mode 120000 index 000000000..eab6b7dc0 --- /dev/null +++ b/example/tutorial/step2/connect4.py @@ -0,0 +1 @@ +../start/connect4.py \ No newline at end of file diff --git a/example/tutorial/step2/favicon.ico b/example/tutorial/step2/favicon.ico new file mode 120000 index 000000000..76da1c2fb --- /dev/null +++ b/example/tutorial/step2/favicon.ico @@ -0,0 +1 @@ +../../../logo/favicon.ico \ No newline at end of file diff --git a/example/tutorial/step2/index.html b/example/tutorial/step2/index.html new file mode 100644 index 000000000..1a16f72a2 --- /dev/null +++ b/example/tutorial/step2/index.html @@ -0,0 +1,15 @@ + + + + Connect Four + + +
+ New + Join + Watch +
+
+ + + diff --git a/example/tutorial/step2/main.js b/example/tutorial/step2/main.js new file mode 100644 index 000000000..d38a0140a --- /dev/null +++ b/example/tutorial/step2/main.js @@ -0,0 +1,83 @@ +import { createBoard, playMove } from "./connect4.js"; + +function initGame(websocket) { + websocket.addEventListener("open", () => { + // Send an "init" event according to who is connecting. + const params = new URLSearchParams(window.location.search); + let event = { type: "init" }; + if (params.has("join")) { + // Second player joins an existing game. + event.join = params.get("join"); + } else if (params.has("watch")) { + // Spectator watches an existing game. + event.watch = params.get("watch"); + } else { + // First player starts a new game. + } + websocket.send(JSON.stringify(event)); + }); +} + +function showMessage(message) { + window.setTimeout(() => window.alert(message), 50); +} + +function receiveMoves(board, websocket) { + websocket.addEventListener("message", ({ data }) => { + const event = JSON.parse(data); + switch (event.type) { + case "init": + // Create links for inviting the second player and spectators. + document.querySelector(".join").href = "?join=" + event.join; + document.querySelector(".watch").href = "?watch=" + event.watch; + break; + case "play": + // Update the UI with the move. + playMove(board, event.player, event.column, event.row); + break; + case "win": + showMessage(`Player ${event.player} wins!`); + // No further messages are expected; close the WebSocket connection. + websocket.close(1000); + break; + case "error": + showMessage(event.message); + break; + default: + throw new Error(`Unsupported event type: ${event.type}.`); + } + }); +} + +function sendMoves(board, websocket) { + // Don't send moves for a spectator watching a game. + const params = new URLSearchParams(window.location.search); + if (params.has("watch")) { + return; + } + + // When clicking a column, send a "play" event for a move in that column. + board.addEventListener("click", ({ target }) => { + const column = target.dataset.column; + // Ignore clicks outside a column. + if (column === undefined) { + return; + } + const event = { + type: "play", + column: parseInt(column, 10), + }; + websocket.send(JSON.stringify(event)); + }); +} + +window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + // Open the WebSocket connection and register event handlers. + const websocket = new WebSocket("ws://localhost:8001/"); + initGame(websocket); + receiveMoves(board, websocket); + sendMoves(board, websocket); +}); diff --git a/example/tutorial/step3/Procfile b/example/tutorial/step3/Procfile new file mode 100644 index 000000000..2e35818f6 --- /dev/null +++ b/example/tutorial/step3/Procfile @@ -0,0 +1 @@ +web: python app.py diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py new file mode 100644 index 000000000..6fff79c95 --- /dev/null +++ b/example/tutorial/step3/app.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python + +import asyncio +import json +import os +import secrets +import signal + +import websockets + +from connect4 import PLAYER1, PLAYER2, Connect4 + + +JOIN = {} + +WATCH = {} + + +async def error(websocket, message): + """ + Send an error message. + + """ + event = { + "type": "error", + "message": message, + } + await websocket.send(json.dumps(event)) + + +async def replay(websocket, game): + """ + Send previous moves. + + """ + # Make a copy to avoid an exception if game.moves changes while iteration + # is in progress. If a move is played while replay is running, moves will + # be sent out of order but each move will be sent once and eventually the + # UI will be consistent. + for player, column, row in game.moves.copy(): + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + await websocket.send(json.dumps(event)) + + +async def play(websocket, game, player, connected): + """ + Receive and process moves from a player. + + """ + async for message in websocket: + # Parse a "play" event from the UI. + event = json.loads(message) + assert event["type"] == "play" + column = event["column"] + + try: + # Play the move. + row = game.play(player, column) + except RuntimeError as exc: + # Send an "error" event if the move was illegal. + await error(websocket, str(exc)) + continue + + # Send a "play" event to update the UI. + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + websockets.broadcast(connected, json.dumps(event)) + + # If move is winning, send a "win" event. + if game.winner is not None: + event = { + "type": "win", + "player": game.winner, + } + websockets.broadcast(connected, json.dumps(event)) + + +async def start(websocket): + """ + Handle a connection from the first player: start a new game. + + """ + # Initialize a Connect Four game, the set of WebSocket connections + # receiving moves from this game, and secret access tokens. + game = Connect4() + connected = {websocket} + + join_key = secrets.token_urlsafe(12) + JOIN[join_key] = game, connected + + watch_key = secrets.token_urlsafe(12) + WATCH[watch_key] = game, connected + + try: + # Send the secret access tokens to the browser of the first player, + # where they'll be used for building "join" and "watch" links. + event = { + "type": "init", + "join": join_key, + "watch": watch_key, + } + await websocket.send(json.dumps(event)) + # Receive and process moves from the first player. + await play(websocket, game, PLAYER1, connected) + finally: + del JOIN[join_key] + del WATCH[watch_key] + + +async def join(websocket, join_key): + """ + Handle a connection from the second player: join an existing game. + + """ + # Find the Connect Four game. + try: + game, connected = JOIN[join_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + # Send the first move, in case the first player already played it. + await replay(websocket, game) + # Receive and process moves from the second player. + await play(websocket, game, PLAYER2, connected) + finally: + connected.remove(websocket) + + +async def watch(websocket, watch_key): + """ + Handle a connection from a spectator: watch an existing game. + + """ + # Find the Connect Four game. + try: + game, connected = WATCH[watch_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + # Send previous moves, in case the game already started. + await replay(websocket, game) + # Keep the connection open, but don't receive any messages. + await websocket.wait_closed() + finally: + connected.remove(websocket) + + +async def handler(websocket, path): + """ + Handle a connection and dispatch it according to who is connecting. + + """ + # Receive and parse the "init" event from the UI. + message = await websocket.recv() + event = json.loads(message) + assert event["type"] == "init" + + if "join" in event: + # Second player joins an existing game. + await join(websocket, event["join"]) + elif "watch" in event: + # Spectator watches an existing game. + await watch(websocket, event["watch"]) + else: + # First player starts a new game. + await start(websocket) + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + port = int(os.environ.get("PORT", "8001")) + async with websockets.serve(handler, "", port): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/tutorial/step3/connect4.css b/example/tutorial/step3/connect4.css new file mode 120000 index 000000000..55a9977ca --- /dev/null +++ b/example/tutorial/step3/connect4.css @@ -0,0 +1 @@ +../start/connect4.css \ No newline at end of file diff --git a/example/tutorial/step3/connect4.js b/example/tutorial/step3/connect4.js new file mode 120000 index 000000000..7c4ed2f3e --- /dev/null +++ b/example/tutorial/step3/connect4.js @@ -0,0 +1 @@ +../start/connect4.js \ No newline at end of file diff --git a/example/tutorial/step3/connect4.py b/example/tutorial/step3/connect4.py new file mode 120000 index 000000000..eab6b7dc0 --- /dev/null +++ b/example/tutorial/step3/connect4.py @@ -0,0 +1 @@ +../start/connect4.py \ No newline at end of file diff --git a/example/tutorial/step3/favicon.ico b/example/tutorial/step3/favicon.ico new file mode 120000 index 000000000..76da1c2fb --- /dev/null +++ b/example/tutorial/step3/favicon.ico @@ -0,0 +1 @@ +../../../logo/favicon.ico \ No newline at end of file diff --git a/example/tutorial/step3/index.html b/example/tutorial/step3/index.html new file mode 100644 index 000000000..1a16f72a2 --- /dev/null +++ b/example/tutorial/step3/index.html @@ -0,0 +1,15 @@ + + + + Connect Four + + +
+ New + Join + Watch +
+
+ + + diff --git a/example/tutorial/step3/main.js b/example/tutorial/step3/main.js new file mode 100644 index 000000000..15afd4163 --- /dev/null +++ b/example/tutorial/step3/main.js @@ -0,0 +1,93 @@ +import { createBoard, playMove } from "./connect4.js"; + +function getWebSocketServer() { + if (window.location.host === "aaugustin.github.io") { + return "wss://websockets-tutorial.herokuapp.com/"; + } else if (window.location.host === "localhost:8000") { + return "ws://localhost:8001/"; + } else { + throw new Error(`Unsupported host: ${window.location.host}`); + } +} + +function initGame(websocket) { + websocket.addEventListener("open", () => { + // Send an "init" event according to who is connecting. + const params = new URLSearchParams(window.location.search); + let event = { type: "init" }; + if (params.has("join")) { + // Second player joins an existing game. + event.join = params.get("join"); + } else if (params.has("watch")) { + // Spectator watches an existing game. + event.watch = params.get("watch"); + } else { + // First player starts a new game. + } + websocket.send(JSON.stringify(event)); + }); +} + +function showMessage(message) { + window.setTimeout(() => window.alert(message), 50); +} + +function receiveMoves(board, websocket) { + websocket.addEventListener("message", ({ data }) => { + const event = JSON.parse(data); + switch (event.type) { + case "init": + // Create links for inviting the second player and spectators. + document.querySelector(".join").href = "?join=" + event.join; + document.querySelector(".watch").href = "?watch=" + event.watch; + break; + case "play": + // Update the UI with the move. + playMove(board, event.player, event.column, event.row); + break; + case "win": + showMessage(`Player ${event.player} wins!`); + // No further messages are expected; close the WebSocket connection. + websocket.close(1000); + break; + case "error": + showMessage(event.message); + break; + default: + throw new Error(`Unsupported event type: ${event.type}.`); + } + }); +} + +function sendMoves(board, websocket) { + // Don't send moves for a spectator watching a game. + const params = new URLSearchParams(window.location.search); + if (params.has("watch")) { + return; + } + + // When clicking a column, send a "play" event for a move in that column. + board.addEventListener("click", ({ target }) => { + const column = target.dataset.column; + // Ignore clicks outside a column. + if (column === undefined) { + return; + } + const event = { + type: "play", + column: parseInt(column, 10), + }; + websocket.send(JSON.stringify(event)); + }); +} + +window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + // Open the WebSocket connection and register event handlers. + const websocket = new WebSocket(getWebSocketServer()); + initGame(websocket); + receiveMoves(board, websocket); + sendMoves(board, websocket); +}); diff --git a/example/tutorial/step3/requirements.txt b/example/tutorial/step3/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/tutorial/step3/requirements.txt @@ -0,0 +1 @@ +websockets From 448e777adc7d8cf0cbc0bd585959bbdfad682926 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Nov 2021 09:01:11 +0100 Subject: [PATCH 0990/1539] Avoid shutdown on closed socket. Fix #1072. --- src/websockets/legacy/protocol.py | 2 +- tests/legacy/test_protocol.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 3340c33be..23d46d018 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1303,7 +1303,7 @@ async def close_connection(self) -> None: self.logger.debug("! timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). - if self.transport.can_write_eof(): + if self.transport.can_write_eof() and not self.transport.is_closing(): if self.debug: self.logger.debug("x half-closing TCP connection") self.transport.write_eof() diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 1672ab1ed..22e72a696 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -59,6 +59,9 @@ def setup_mock(self, loop, protocol): def can_write_eof(self): return True + def is_closing(self): + return False + def write_eof(self): # When the protocol half-closes the TCP connection, it expects the # other end to close it. Simulate that. From d7320ca11e47499ae8cbbaddfe8f8bd69802bc48 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 20:52:21 +0100 Subject: [PATCH 0991/1539] Add missing items to changelog. --- docs/project/changelog.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5b190fb94..a8b77c034 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -33,6 +33,8 @@ They may change at any time. New features ............ +* Added a tutorial. + * Made the second parameter of connection handlers optional. It will be deprecated in the next major release. The request path is available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of @@ -49,6 +51,8 @@ New features path = request.path # if handler() uses the path argument ... +* Added ``python -m websockets --version``. + Improvements ............ @@ -66,6 +70,11 @@ Improvements * Documented how to auto-reload on code changes in development. +Bug fixes +......... + +* Avoided half-closing TCP connections that are already closed. + 10.0 ---- From 5f63545517270cd3218c2ed19df190f5ce88f795 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 20:45:50 +0100 Subject: [PATCH 0992/1539] Release 10.1. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a8b77c034..ea9079166 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.1 ---- -*In development* +*November 14, 2021* New features ............ diff --git a/src/websockets/version.py b/src/websockets/version.py index 3a6e6aa08..cb76be5d2 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "10.1" From b835315f465124e7fedf1a753d07cfbce95327af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 21:29:44 +0100 Subject: [PATCH 0993/1539] Start 10.2. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index ea9079166..b58ff15b1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +10.2 +---- + +*In development* + 10.1 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index cb76be5d2..605c8264a 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "10.1" +tag = version = commit = "10.2" if not released: # pragma: no cover From 67b19e6e9ac96722a584e35d7e9dc3a2173af6b8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Nov 2021 20:33:47 +0100 Subject: [PATCH 0994/1539] Clean up mypy cache. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index b15cd13c9..6f8130840 100644 --- a/Makefile +++ b/Makefile @@ -27,4 +27,4 @@ build: clean: find . -name '*.pyc' -o -name '*.so' -delete find . -name __pycache__ -delete - rm -rf .coverage build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info + rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 14891ffa20f6acc0f6ec24a96b6bd4d3504ea3eb Mon Sep 17 00:00:00 2001 From: David Sanders Date: Tue, 21 Dec 2021 15:54:27 -0800 Subject: [PATCH 0995/1539] Fix some typos in comments --- src/websockets/client.py | 2 +- src/websockets/connection.py | 2 +- src/websockets/datastructures.py | 2 +- src/websockets/http11.py | 2 +- src/websockets/server.py | 2 +- src/websockets/speedups.c | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 34732b3a6..9b86b4d0a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -201,7 +201,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: client configuration. If no match is found, an exception is raised. If several variants of the same extension are accepted by the server, - it may be configured severel times, which won't make sense in general. + it may be configured several times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 8661a148b..0a4d3c7bc 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -106,7 +106,7 @@ def __init__( # Connection side. CLIENT or SERVER. self.side = side - # Connnection state. Initially OPEN because subclasses handle CONNECTING. + # Connection state. Initially OPEN because subclasses handle CONNECTING. self.state = state # Maximum size of incoming messages in bytes. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index d5c061cf8..37d3b5f86 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -118,7 +118,7 @@ def __setitem__(self, key: str, value: str) -> None: def __delitem__(self, key: str) -> None: key_lower = key.lower() self._dict.__delitem__(key_lower) - # This is inefficent. Fortunately deleting HTTP headers is uncommon. + # This is inefficient. Fortunately deleting HTTP headers is uncommon. self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] def __eq__(self, other: Any) -> bool: diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 052719c67..a2fd22dd2 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -185,7 +185,7 @@ def parse( read_exact: generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. read_to_eof: generator-based coroutine that reads until the end - of the strem. + of the stream. Raises: EOFError: if the connection is closed without a full HTTP response. diff --git a/src/websockets/server.py b/src/websockets/server.py index 2d9b4f9a8..a94c0b629 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -174,7 +174,7 @@ def process_request( self, request: Request ) -> Tuple[str, Optional[str], Optional[str]]: """ - Check a handshake request and negociate extensions and subprotocol. + Check a handshake request and negotiate extensions and subprotocol. This function doesn't verify that the request is an HTTP/1.1 or higher GET request and doesn't check the ``Host`` header. These controls are diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index f8d24ec7a..a19590419 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -121,7 +121,7 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) goto exit; } - // Since we juste created result, we don't need error checks. + // Since we just created result, we don't need error checks. output = PyBytes_AS_STRING(result); // Perform the masking operation. From 668f320e0547d80afe6529528e1ecc6088955cdc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 24 Dec 2021 18:51:53 +0100 Subject: [PATCH 0996/1539] Update for the latest version fo mypy. --- src/websockets/legacy/protocol.py | 11 +++++------ src/websockets/legacy/server.py | 10 +++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 23d46d018..c1809e20d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -750,7 +750,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.write_close_frame(Close(code, reason)), self.close_timeout, **loop_if_py_lt_38(self.loop), - ) # type: ignore # remove when removing loop_if_py_lt_38 + ) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. @@ -771,7 +771,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.transfer_data_task, self.close_timeout, **loop_if_py_lt_38(self.loop), - ) # type: ignore # remove when removing loop_if_py_lt_38 + ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1072,8 +1072,7 @@ def append(frame: Frame) -> None: raise ProtocolError("unexpected opcode") append(frame) - # mypy cannot figure out that chunks have the proper type. - return ("" if text else b"").join(chunks) # type: ignore + return ("" if text else b"").join(chunks) async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: """ @@ -1250,7 +1249,7 @@ async def keepalive_ping(self) -> None: pong_waiter, self.ping_timeout, **loop_if_py_lt_38(self.loop), - ) # type: ignore # remove when removing loop_if_py_lt_38 + ) self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1365,7 +1364,7 @@ async def wait_for_connection_lost(self) -> bool: asyncio.shield(self.connection_lost_waiter), self.close_timeout, **loop_if_py_lt_38(self.loop), - ) # type: ignore # remove when removing loop_if_py_lt_38 + ) except asyncio.TimeoutError: pass # Re-check self.connection_lost_waiter.done() synchronously because diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 98712ff86..3172059d2 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -14,6 +14,7 @@ Awaitable, Callable, Generator, + Iterable, List, Optional, Sequence, @@ -361,7 +362,7 @@ async def process_request( warnings.warn( "declare process_request as a coroutine", DeprecationWarning ) - return response # type: ignore + return response return None @staticmethod @@ -589,7 +590,7 @@ async def handshake( else: # For backwards compatibility with 7.0. warnings.warn("declare process_request as a coroutine", DeprecationWarning) - early_response = early_response_awaitable # type: ignore + early_response = early_response_awaitable # The connection may drop while process_request is running. if self.state is State.CLOSED: @@ -677,7 +678,7 @@ def __init__(self, logger: Optional[LoggerLike] = None): # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] - def wrap(self, server: asyncio.AbstractServer) -> None: + def wrap(self, server: asyncio.base_events.Server) -> None: """ Attach to a given :class:`~asyncio.Server`. @@ -692,7 +693,6 @@ def wrap(self, server: asyncio.AbstractServer) -> None: """ self.server = server - assert server.sockets is not None for sock in server.sockets: if sock.family == socket.AF_INET: name = "%s:%d" % sock.getsockname() @@ -842,7 +842,7 @@ async def serve_forever(self) -> None: await self.server.serve_forever() # pragma: no cover @property - def sockets(self) -> Optional[List[socket.socket]]: + def sockets(self) -> Iterable[socket.socket]: """ See :attr:`asyncio.Server.sockets`. From 2b965a146eb41d05073264ddd296542992fb327f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 09:50:26 +0100 Subject: [PATCH 0997/1539] Document that convenience imports break IDEs. Fix #1124. --- docs/project/changelog.rst | 8 +++++--- docs/reference/index.rst | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b58ff15b1..14945b7bf 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -259,10 +259,12 @@ Backwards-incompatible changes .. admonition:: Convenience imports from ``websockets`` are performed lazily. :class: note - While Python supports this, static code analysis tools such as mypy are - unable to understand the behavior. + While Python supports this, tools relying on static code analysis don't. + This breaks autocompletion in an IDE or type checking with mypy_. - If you depend on such tools, use the real import path, which can be found + .. _mypy: https://github.com/python/mypy + + If you depend on such tools, use the real import paths, which can be found in the API documentation, for example:: from websockets.client import connect diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 385beab29..a5ee57843 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -59,8 +59,8 @@ Anything that isn't listed in the API reference is a private API. There's no guarantees of behavior or backwards-compatibility for private APIs. For convenience, many public APIs can be imported from the ``websockets`` -package. This feature is incompatible with static code analysis tools such as -mypy_, though. If you're using such tools, use the full import path. +package. However, this feature is incompatible with static code analysis. It +breaks autocompletion in an IDE or type checking with mypy_. If you're using +such tools, use the real import paths. .. _mypy: https://github.com/python/mypy - From 1e44065f49647abb6bd58bdb9a435dd064e40c74 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 13:41:05 +0100 Subject: [PATCH 0998/1539] Document OSError: [Errno 99]. Fix #1102. --- docs/howto/faq.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 20b957745..128d1dbd0 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -124,6 +124,16 @@ Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. :func:`~server.serve` accepts the same arguments as :meth:`~asyncio.loop.create_server`. +What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? +.......................................................................................................................... + +You are calling :func:`~server.serve` without a ``host`` argument in a context +where IPv6 isn't available. + +To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. + +Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. + How do I close a connection properly? ..................................... From cd8659e25444c14f4b1e67b5a612a297d0a00cf4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 13:49:48 +0100 Subject: [PATCH 0999/1539] Standardize style. (black has gotten better at preserving multi-line style.) --- tests/extensions/test_permessage_deflate.py | 116 +++++++++++++++----- 1 file changed, 91 insertions(+), 25 deletions(-) diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index bcd08a7ef..0d56917b4 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -27,16 +27,20 @@ class ExtensionTestsMixin: def assertExtensionEqual(self, extension1, extension2): self.assertEqual( - extension1.remote_no_context_takeover, extension2.remote_no_context_takeover + extension1.remote_no_context_takeover, + extension2.remote_no_context_takeover, ) self.assertEqual( - extension1.local_no_context_takeover, extension2.local_no_context_takeover + extension1.local_no_context_takeover, + extension2.local_no_context_takeover, ) self.assertEqual( - extension1.remote_max_window_bits, extension2.remote_max_window_bits + extension1.remote_max_window_bits, + extension2.remote_max_window_bits, ) self.assertEqual( - extension1.local_max_window_bits, extension2.local_max_window_bits + extension1.local_max_window_bits, + extension2.local_max_window_bits, ) @@ -84,7 +88,8 @@ def test_encode_decode_text_frame(self): enc_frame = self.extension.encode(frame) self.assertEqual( - enc_frame, dataclasses.replace(frame, rsv1=True, data=b"JNL;\xbc\x12\x00") + enc_frame, + dataclasses.replace(frame, rsv1=True, data=b"JNL;\xbc\x12\x00"), ) dec_frame = self.extension.decode(enc_frame) @@ -97,7 +102,8 @@ def test_encode_decode_binary_frame(self): enc_frame = self.extension.encode(frame) self.assertEqual( - enc_frame, dataclasses.replace(frame, rsv1=True, data=b"*IM\x04\x00") + enc_frame, + dataclasses.replace(frame, rsv1=True, data=b"*IM\x04\x00"), ) dec_frame = self.extension.decode(enc_frame) @@ -120,10 +126,12 @@ def test_encode_decode_fragmented_text_frame(self): ), ) self.assertEqual( - enc_frame2, dataclasses.replace(frame2, data=b"RPS\x00\x00\x00\x00\xff\xff") + enc_frame2, + dataclasses.replace(frame2, data=b"RPS\x00\x00\x00\x00\xff\xff"), ) self.assertEqual( - enc_frame3, dataclasses.replace(frame3, data=b"J.\xca\xcf,.N\xcc+)\x06\x00") + enc_frame3, + dataclasses.replace(frame3, data=b"J.\xca\xcf,.N\xcc+)\x06\x00"), ) dec_frame1 = self.extension.decode(enc_frame1) @@ -304,16 +312,34 @@ def test_init_error(self): def test_get_request_params(self): for config, result in [ # Test without any parameter - ((False, False, None, None), []), + ( + (False, False, None, None), + [], + ), # Test server_no_context_takeover - ((True, False, None, None), [("server_no_context_takeover", None)]), + ( + (True, False, None, None), + [("server_no_context_takeover", None)], + ), # Test client_no_context_takeover - ((False, True, None, None), [("client_no_context_takeover", None)]), + ( + (False, True, None, None), + [("client_no_context_takeover", None)], + ), # Test server_max_window_bits - ((False, False, 10, None), [("server_max_window_bits", "10")]), + ( + (False, False, 10, None), + [("server_max_window_bits", "10")], + ), # Test client_max_window_bits - ((False, False, None, 10), [("client_max_window_bits", "10")]), - ((False, False, None, True), [("client_max_window_bits", None)]), + ( + (False, False, None, 10), + [("client_max_window_bits", "10")], + ), + ( + (False, False, None, True), + [("client_max_window_bits", None)], + ), # Test all parameters together ( (True, True, 12, 12), @@ -332,15 +358,27 @@ def test_get_request_params(self): def test_process_response_params(self): for config, response_params, result in [ # Test without any parameter - ((False, False, None, None), [], (False, False, 15, 15)), - ((False, False, None, None), [("unknown", None)], InvalidParameterName), + ( + (False, False, None, None), + [], + (False, False, 15, 15), + ), + ( + (False, False, None, None), + [("unknown", None)], + InvalidParameterName, + ), # Test server_no_context_takeover ( (False, False, None, None), [("server_no_context_takeover", None)], (True, False, 15, 15), ), - ((True, False, None, None), [], NegotiationError), + ( + (True, False, None, None), + [], + NegotiationError, + ), ( (True, False, None, None), [("server_no_context_takeover", None)], @@ -362,7 +400,11 @@ def test_process_response_params(self): [("client_no_context_takeover", None)], (False, True, 15, 15), ), - ((False, True, None, None), [], (False, True, 15, 15)), + ( + (False, True, None, None), + [], + (False, True, 15, 15), + ), ( (False, True, None, None), [("client_no_context_takeover", None)], @@ -394,7 +436,11 @@ def test_process_response_params(self): [("server_max_window_bits", "16")], NegotiationError, ), - ((False, False, 12, None), [], NegotiationError), + ( + (False, False, 12, None), + [], + NegotiationError, + ), ( (False, False, 12, None), [("server_max_window_bits", "10")], @@ -426,7 +472,11 @@ def test_process_response_params(self): [("client_max_window_bits", "10")], NegotiationError, ), - ((False, False, None, True), [], (False, False, 15, 15)), + ( + (False, False, None, True), + [], + (False, False, 15, 15), + ), ( (False, False, None, True), [("client_max_window_bits", "7")], @@ -442,7 +492,11 @@ def test_process_response_params(self): [("client_max_window_bits", "16")], NegotiationError, ), - ((False, False, None, 12), [], (False, False, 15, 12)), + ( + (False, False, None, 12), + [], + (False, False, 15, 12), + ), ( (False, False, None, 12), [("client_max_window_bits", "10")], @@ -558,7 +612,8 @@ def test_enable_client_permessage_deflate(self): extension = extensions[expected_position] self.assertIsInstance(extension, ClientPerMessageDeflateFactory) self.assertEqual( - extension.compress_settings, expected_compress_settings + extension.compress_settings, + expected_compress_settings, ) @@ -597,7 +652,12 @@ def test_process_request_params(self): # (remote, local) vs. (server, client). for config, request_params, response_params, result in [ # Test without any parameter - ((False, False, None, None), [], [], (False, False, 15, 15)), + ( + (False, False, None, None), + [], + [], + (False, False, 15, 15), + ), ( (False, False, None, None), [("unknown", None)], @@ -746,7 +806,12 @@ def test_process_request_params(self): None, InvalidParameterValue, ), - ((False, False, None, 12), [], None, NegotiationError), + ( + (False, False, None, 12), + [], + None, + NegotiationError, + ), ( (False, False, None, 12), [("client_max_window_bits", None)], @@ -895,5 +960,6 @@ def test_enable_server_permessage_deflate(self): extension = extensions[expected_position] self.assertIsInstance(extension, ServerPerMessageDeflateFactory) self.assertEqual( - extension.compress_settings, expected_compress_settings + extension.compress_settings, + expected_compress_settings, ) From 9535c2137bdcdc0d34cf8367d2bb16c91a6fc083 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 15:24:14 +0100 Subject: [PATCH 1000/1539] Support building docs without pyenchant. --- docs/conf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index 8d9fcdc8f..fe6282b5a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -66,6 +66,11 @@ "sphinxcontrib_trio", "sphinxext.opengraph", ] +# It is currently inconvenient to install PyEnchant on Apple Silicon. +try: + import sphinxcontrib.spelling +except ImportError: + extensions.remove("sphinxcontrib.spelling") autodoc_typehints = "description" From c3560c9f6ff33d058459c2e626c0bd0b8c2c8e1a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 15:24:31 +0100 Subject: [PATCH 1001/1539] Made compression negotation more lax. This change means connections from Firefox get compression by default, while they didn't use to since the compression "optimizations" in 10.0. Fix #1109. --- docs/project/changelog.rst | 5 +++++ docs/topics/compression.rst | 21 +++++++++++++++++-- .../extensions/permessage_deflate.py | 19 +++++++++++++---- tests/extensions/test_permessage_deflate.py | 10 ++++++++- tests/legacy/test_client_server.py | 5 ++++- 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 14945b7bf..f888aeb2c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,11 @@ They may change at any time. *In development* +Improvements +............ + +* Made compression negotiation more lax for compatibility with Firefox. + 10.1 ---- diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index 0f264dc66..f0b7ce898 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -85,8 +85,13 @@ and memory usage for both sides. an integer between 9 (lowest memory usage) and 15 (best compression). Setting it to 8 is possible but rejected by some versions of zlib. - On the server side, websockets defaults to 12. On the client side, it lets - the server pick a suitable value, which is the same as defaulting to 15. + On the server side, websockets defaults to 12. Specifically, the compression + window size (server to client) is always 12 while the decompression window + (client to server) size may be 12 or 15 depending on whether the client + supports configuring it. + + On the client side, websockets lets the server pick a suitable value, which + has the same effect as defaulting to 15. :mod:`zlib` offers additional parameters for tuning compression. They control the trade-off between compression rate, memory usage, and CPU usage only for @@ -164,6 +169,18 @@ usage is: CPU usage is also higher for compression than decompression. +While it's always possible for a server to use a smaller window size for +compressing outgoing messages, using a smaller window size for decompressing +incoming messages requires collaboration from clients. + +When a client doesn't support configuring the size of its compression window, +websockets enables compression with the largest possible decompression window. +In most use cases, this is more efficient than disabling compression both ways. + +If you are very sensitive to memory usage, you can reverse this behavior by +setting the ``require_client_max_window_bits`` parameter of +:class:`ServerPerMessageDeflateFactory` to ``True``. + For clients ........... diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index fefa55643..017e3d843 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -471,10 +471,13 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): server_max_window_bits: maximum size of the server's LZ77 sliding window in bits, between 8 and 15. client_max_window_bits: maximum size of the client's LZ77 sliding window - in bits, between 8 and 15, or :obj:`True` to indicate support without - setting a limit. + in bits, between 8 and 15. compress_settings: additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. + require_client_max_window_bits: do not enable compression at all if + client doesn't advertise support for ``client_max_window_bits``; + the default behavior is to enable compression without enforcing + ``client_max_window_bits``. """ @@ -487,6 +490,7 @@ def __init__( server_max_window_bits: Optional[int] = None, client_max_window_bits: Optional[int] = None, compress_settings: Optional[Dict[str, Any]] = None, + require_client_max_window_bits: bool = False, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -501,12 +505,18 @@ def __init__( "compress_settings must not include wbits, " "set server_max_window_bits instead" ) + if client_max_window_bits is None and require_client_max_window_bits: + raise ValueError( + "require_client_max_window_bits is enabled, " + "but client_max_window_bits isn't configured" + ) self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings + self.require_client_max_window_bits = require_client_max_window_bits def process_request_params( self, @@ -587,7 +597,7 @@ def process_request_params( # None None None # None True None - must change value # None 8≤M≤15 M (or None) - # 8≤N≤15 None Error! + # 8≤N≤15 None None or Error! # 8≤N≤15 True N - must change value # 8≤N≤15 8≤M≤N M (or None) # 8≤N≤15 N Date: Sat, 19 Feb 2022 17:35:01 +0100 Subject: [PATCH 1002/1539] Work around coverage issue in tests. --- tests/legacy/test_client_server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index eee50be1d..142e4099e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1438,7 +1438,8 @@ async def run_client(): connect_inst.BACKOFF_MIN = 10 * MS connect_inst.BACKOFF_MAX = 99 * MS connect_inst.BACKOFF_INITIAL = 0 - async for ws in connect_inst: + # coverage has a hard time dealing with this code - I give up. + async for ws in connect_inst: # pragma: no cover await ws.send("spam") msg = await ws.recv() self.assertEqual(msg, "spam") @@ -1457,10 +1458,10 @@ async def run_client(): await server_ws.close() with self.assertRaises(ConnectionClosed): await ws.recv() - pass # work around bug in coverage else: # Exit block with an exception. raise Exception("BOOM!") + pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: with self.assertRaisesRegex(Exception, "BOOM!"): From 88d2e2fb51c9154d603861249d26efb0b5e55d80 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 08:03:41 +0100 Subject: [PATCH 1003/1539] Refactor nested try/except for clarity. --- src/websockets/legacy/client.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 6704d16ce..7253f4c59 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -655,23 +655,24 @@ async def __await_impl__(self) -> WebSocketClientProtocol: protocol = cast(WebSocketClientProtocol, protocol) try: - try: - await protocol.handshake( - self._wsuri, - origin=protocol.origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - except Exception: - protocol.fail_connection() - await protocol.wait_closed() - raise - else: - self.protocol = protocol - return protocol + await protocol.handshake( + self._wsuri, + origin=protocol.origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) except RedirectHandshake as exc: + protocol.fail_connection() + await protocol.wait_closed() self.handle_redirect(exc.uri) + except Exception: + protocol.fail_connection() + await protocol.wait_closed() + raise + else: + self.protocol = protocol + return protocol else: raise SecurityError("too many redirects") From 8516801f7f41b91cc491b3466f57a7bd4c8ee544 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 10:52:48 +0100 Subject: [PATCH 1004/1539] Avoid leaking sockets when connect() is canceled. Fix #1113. --- docs/project/changelog.rst | 5 ++++ src/websockets/legacy/client.py | 3 ++- tests/legacy/test_client_server.py | 37 +++++++++++++++++++++++++----- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f888aeb2c..3c264a31b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,11 @@ Improvements * Made compression negotiation more lax for compatibility with Firefox. +Bug fixes +......... + +* Avoided leaking open sockets when :func:`~client.connect` is canceled. + 10.1 ---- diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 7253f4c59..0bb2cb690 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -666,7 +666,8 @@ async def __await_impl__(self) -> WebSocketClientProtocol: protocol.fail_connection() await protocol.wait_closed() self.handle_redirect(exc.uri) - except Exception: + # Avoid leaking a connected socket when the handshake fails. + except (Exception, asyncio.CancelledError): protocol.fail_connection() await protocol.wait_closed() raise diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 142e4099e..2275ecdf3 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -147,15 +147,11 @@ def with_client(*args, **kwargs): return with_manager(temp_test_client, *args, **kwargs) -def get_server_uri(server, secure=False, resource_name="/", user_info=None): +def get_server_address(server): """ - Return a WebSocket URI for connecting to the given server. + Return an address on which the given server listens. """ - proto = "wss" if secure else "ws" - - user_info = ":".join(user_info) + "@" if user_info else "" - # Pick a random socket in order to test both IPv4 and IPv6 on systems # where both are available. Randomizing tests is usually a bad idea. If # needed, either use the first socket, or test separately IPv4 and IPv6. @@ -169,6 +165,17 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): else: # pragma: no cover raise ValueError("expected an IPv6, IPv4, or Unix socket") + return host, port + + +def get_server_uri(server, secure=False, resource_name="/", user_info=None): + """ + Return a WebSocket URI for connecting to the given server. + + """ + proto = "wss" if secure else "ws" + user_info = ":".join(user_info) + "@" if user_info else "" + host, port = get_server_address(server) return f"{proto}://{user_info}{host}:{port}{resource_name}" @@ -1067,6 +1074,21 @@ def test_server_error_in_handshake(self, _process_request): with self.assertRaises(InvalidHandshake): self.start_client() + @with_server(create_protocol=SlowOpeningHandshakeProtocol) + def test_client_connect_canceled_during_handshake(self): + sock = socket.create_connection(get_server_address(self.server)) + sock.send(b"") # socket is connected + + async def cancelled_client(): + start_client = connect(get_server_uri(self.server), sock=sock) + await asyncio.wait_for(start_client, 5 * MS) + + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(cancelled_client()) + + with self.assertRaises(OSError): + sock.send(b"") # socket is closed + @with_server() @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.send") def test_server_handler_crashes(self, send): @@ -1199,6 +1221,9 @@ class SecureClientServerTests( CommonClientServerTests, SecureClientServerTestsMixin, AsyncioTestCase ): + # The implementation of this test makes it hard to run it over TLS. + test_client_connect_canceled_during_handshake = None + # TLS over Unix sockets doesn't make sense. test_unix_socket = None From b3794a9fe1e042be38515c9ae4d22922d2d35382 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 10:54:46 +0100 Subject: [PATCH 1005/1539] Fail connection when send() is cancelled in a fragmented message. Ref #1129. --- src/websockets/legacy/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c1809e20d..aa737ca26 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -664,7 +664,7 @@ async def send( # Final fragment. await self.write_frame(True, OP_CONT, b"") - except Exception: + except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. self.fail_connection(1011) @@ -708,7 +708,7 @@ async def send( # Final fragment. await self.write_frame(True, OP_CONT, b"") - except Exception: + except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. self.fail_connection(1011) From c6f05b6a594bf5ef5857b2ef23402105d98f3a7a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 10:58:26 +0100 Subject: [PATCH 1006/1539] Update to latest black version. --- src/websockets/client.py | 2 +- src/websockets/connection.py | 2 +- src/websockets/http11.py | 2 +- src/websockets/legacy/client.py | 8 ++++---- src/websockets/legacy/protocol.py | 8 ++++---- src/websockets/legacy/server.py | 8 ++++---- src/websockets/server.py | 2 +- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 9b86b4d0a..7a904b151 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -74,7 +74,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, state: State = CONNECTING, - max_size: Optional[int] = 2 ** 20, + max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, ): super().__init__( diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 0a4d3c7bc..15a40f80c 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -87,7 +87,7 @@ def __init__( self, side: Side, state: State = OPEN, - max_size: Optional[int] = 2 ** 20, + max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, ) -> None: # Unique identifier. For logs. diff --git a/src/websockets/http11.py b/src/websockets/http11.py index a2fd22dd2..502ca64f8 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -18,7 +18,7 @@ # Support for HTTP response bodies is intended to read an error message # returned by a server. It isn't designed to perform large file transfers. -MAX_BODY = 2 ** 20 # 1 MiB +MAX_BODY = 2**20 # 1 MiB def d(value: bytes) -> str: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 0bb2cb690..fadc3efe8 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -442,10 +442,10 @@ def __init__( ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index aa737ca26..bbcc19664 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -171,10 +171,10 @@ def __init__( ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, # The following arguments are kept only for backwards compatibility. host: Optional[str] = None, port: Optional[int] = None, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3172059d2..8bc466a38 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -969,10 +969,10 @@ def __init__( ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. diff --git a/src/websockets/server.py b/src/websockets/server.py index a94c0b629..bb0e0d7e2 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -72,7 +72,7 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, state: State = CONNECTING, - max_size: Optional[int] = 2 ** 20, + max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, ): super().__init__( From fe946ef0d1fb6dac982879a20f584ef66bc0a879 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 11:11:30 +0100 Subject: [PATCH 1007/1539] Fix flaky test. Ref #1113. --- tests/legacy/test_client_server.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2275ecdf3..e6fb05d57 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -158,15 +158,12 @@ def get_server_address(server): server_socket = random.choice(server.sockets) if server_socket.family == socket.AF_INET6: # pragma: no cover - host, port = server_socket.getsockname()[:2] # (no IPv6 on CI) - host = f"[{host}]" + return server_socket.getsockname()[:2] # (no IPv6 on CI) elif server_socket.family == socket.AF_INET: - host, port = server_socket.getsockname() + return server_socket.getsockname() else: # pragma: no cover raise ValueError("expected an IPv6, IPv4, or Unix socket") - return host, port - def get_server_uri(server, secure=False, resource_name="/", user_info=None): """ @@ -176,6 +173,8 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): proto = "wss" if secure else "ws" user_info = ":".join(user_info) + "@" if user_info else "" host, port = get_server_address(server) + if ":" in host: # IPv6 address + host = f"[{host}]" return f"{proto}://{user_info}{host}:{port}{resource_name}" From 5b3a6d26c493f1aa54d267223ba3a908f0793fc8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 11:12:31 +0100 Subject: [PATCH 1008/1539] Avoid OSError: [Errno 107] noise in logs. Fix #1117. Ref #1072. --- src/websockets/legacy/protocol.py | 10 ++++++++-- tests/legacy/test_protocol.py | 3 --- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index bbcc19664..49c81da4f 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1302,10 +1302,16 @@ async def close_connection(self) -> None: self.logger.debug("! timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). - if self.transport.can_write_eof() and not self.transport.is_closing(): + if self.transport.can_write_eof(): if self.debug: self.logger.debug("x half-closing TCP connection") - self.transport.write_eof() + # write_eof() doesn't document which exceptions it raises. + # "[Errno 107] Transport endpoint is not connected" happens + # but it isn't completely clear under which circumstances. + try: + self.transport.write_eof() + except OSError: # pragma: no cover + pass if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 22e72a696..1672ab1ed 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -59,9 +59,6 @@ def setup_mock(self, loop, protocol): def can_write_eof(self): return True - def is_closing(self): - return False - def write_eof(self): # When the protocol half-closes the TCP connection, it expects the # other end to close it. Simulate that. From 72e9b37d422191faef3e939905a4edaee1145c24 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 14:24:28 +0100 Subject: [PATCH 1009/1539] Expand discussion of keepalives. Also mention pitfall in browsers. Fix #1084. --- docs/topics/timeouts.rst | 84 +++++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 15 deletions(-) diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 815a29b3f..dcf0322a4 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -3,34 +3,88 @@ Timeouts .. currentmodule:: websockets +Long-lived connections +---------------------- + Since the WebSocket protocol is intended for real-time communications over long-lived connections, it is desirable to ensure that connections don't break, and if they do, to report the problem quickly. -WebSocket is built on top of HTTP/1.1 where connections are short-lived, even -with ``Connection: keep-alive``. Typically, HTTP/1.1 infrastructure closes -idle connections after 30 to 120 seconds. +Connections can drop as a consequence of temporary network connectivity issues, +which are very common, even within datacenters. + +Furthermore, WebSocket builds on top of HTTP/1.1 where connections are +short-lived, even with ``Connection: keep-alive``. Typically, HTTP/1.1 +infrastructure closes idle connections after 30 to 120 seconds. -As a consequence, proxies may terminate WebSocket connections prematurely, -when no message was exchanged in 30 seconds. +As a consequence, proxies may terminate WebSocket connections prematurely when +no message was exchanged in 30 seconds. -In order to avoid this problem, websockets implements a keepalive mechanism -based on WebSocket Ping_ and Pong_ frames. Ping and Pong are designed for this -purpose. +Keepalive in websockets +----------------------- + +To avoid these problems, websockets runs a keepalive and heartbeat mechanism +based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 -By default, websockets waits 20 seconds, then sends a Ping frame, and expects -to receive the corresponding Pong frame within 20 seconds. Else, it considers -the connection broken and closes it. +It loops through these steps: + +1. Wait 20 seconds. +2. Send a Ping frame. +3. Receive a corresponding Pong frame within 20 seconds. + +If the Pong frame isn't received, websockets considers the connection broken and +closes it. + +This mechanism serves two purposes: + +1. It creates a trickle of traffic so that the TCP connection isn't idle and + network infrastructure along the path keeps it open ("keepalive"). +2. It detects if the connection drops or becomes so slow that it's unusable in + practice ("heartbeat"). In that case, it terminates the connection and your + application gets a :exc:`~exceptions.ConnectionClosed` exception. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` -arguments of :func:`~client.connect` and :func:`~server.serve`. +arguments of :func:`~client.connect` and :func:`~server.serve`. Shorter values +will detect connection drops faster but they will increase network traffic and +they will be more sensitive to latency. + +Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and +heartbeat mechanism. + +Setting ``ping_timeout`` to :obj:`None` disables only timeouts. This enables +keepalive, to keep idle connections open, and disables heartbeat, to support large +latency spikes. + +.. admonition:: Why doesn't websockets rely on TCP keepalive? + :class: hint + + TCP keepalive is disabled by default on most operating systems. When + enabled, the default interval is two hours or more, which is far too much. + +Keepalive in browsers +--------------------- + +Browsers don't enable a keepalive mechanism like websockets by default. As a +consequence, they can fail to notice that a WebSocket connection is broken for +an extended period of time, until the TCP connection times out. + +In this scenario, the ``WebSocket`` object in the browser doesn't fire a +``close`` event. If you have a reconnection mechanism, it doesn't kick in +because it believes that the connection is still working. + +If your browser-based app mysteriously and randomly fails to receive events, +this is a likely cause. You need a keepalive mechanism in the browser to avoid +this scenario. + +Unfortunately, the WebSocket API in browsers doesn't expose the native Ping and +Pong functionality in the WebSocket protocol. You have to roll your own in the +application layer. -While WebSocket runs on top of TCP, websockets doesn't rely on TCP keepalive -because it's disabled by default and, if enabled, the default interval is no -less than two hours, which doesn't meet requirements. +Latency issues +-------------- Latency between a client and a server may increase for two reasons: From 1167e1d61d9fa4916a70a212441dbcc3e6624268 Mon Sep 17 00:00:00 2001 From: Carlos Sobrinho Date: Tue, 1 Feb 2022 11:56:34 -0800 Subject: [PATCH 1010/1539] Update version.py Allow version check to ignore `git` if the command is not available on the path. --- src/websockets/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 605c8264a..324f43168 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -42,7 +42,7 @@ def get_version(tag: str) -> str: check=True, text=True, ).stdout.strip() - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): pass else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7}(?:-dirty)?)" From 46bc5e4dba19aa7052c5777657f9608fd6245c43 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 14:46:27 +0100 Subject: [PATCH 1011/1539] Add timeout to get_version(). --- src/websockets/version.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 324f43168..a7a12a61e 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -39,10 +39,12 @@ def get_version(tag: str) -> str: ["git", "describe", "--dirty", "--tags", "--long"], capture_output=True, cwd=root_dir, + timeout=1, check=True, text=True, ).stdout.strip() - except (subprocess.CalledProcessError, FileNotFoundError): + # subprocess.run raises FileNotFoundError if git isn't on $PATH. + except (FileNotFoundError, subprocess.CalledProcessError): pass else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7}(?:-dirty)?)" From a2ca001a06ebc0d4af7ae66776cb6212e62f8e24 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 16:22:00 +0100 Subject: [PATCH 1012/1539] Add to FAQ how to message one, several, or all users. Fix #1083. --- docs/howto/faq.rst | 94 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 86 insertions(+), 8 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 128d1dbd0..03f85802e 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -73,6 +73,92 @@ See also Python's documentation about `running blocking code`_. .. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code +.. _send-message-to-all-users: + +How do I send a message to all users? +..................................... + +Record all connections in a global variable:: + + CONNECTIONS = set() + + async def handler(websocket): + CONNECTIONS.add(websocket) + try: + await websocket.wait_closed() + finally: + CONNECTIONS.remove(websocket) + +Then, call :func:`~websockets.broadcast`:: + + import websockets + + def message_all(message): + websockets.broadcast(CONNECTIONS, message) + +If you're running multiple server processes, make sure you call ``message_all`` +in each process. + +.. _send-message-to-single-user: + +How do I send a message to a single user? +......................................... + +Record connections in a global variable, keyed by user identifier:: + + CONNECTIONS = {} + + async def handler(websocket): + user_id = ... # identify user in your app's context + CONNECTIONS[user_id] = websocket + try: + await websocket.wait_closed() + finally: + del CONNECTIONS[user_id] + +Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: + + async def message_user(user_id, message): + websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected + await websocket.send(message) # may raise websockets.ConnectionClosed + +Add error handling according to the behavior you want if the user disconnected +before the message could be sent. + +This example supports only one connection per user. To support concurrent +connects by the same user, you can change ``CONNECTIONS`` to store a set of +connections for each user. + +If you're running multiple server processes, call ``message_user`` in each +process. The process managing the user's connection sends the message; other +processes do nothing. + +When you reach a scale where server processes cannot keep up with the stream of +all messages, you need a better architecture. For example, you could deploy an +external publish / subscribe system such as Redis_. Server processes would +subscribe their clients. Then, they would receive messages only for the +connections that they're managing. + +.. _Redis: https://redis.io/ + +How do I send a message to a channel, a topic, or a subset of users? +.................................................................... + +websockets doesn't provide built-in publish / subscribe functionality. + +Record connections in a global variable, keyed by user identifier, as shown in +:ref:`How do I send a message to a single user?` + +Then, build the set of recipients and broadcast the message to them, as shown in +:ref:`How do I send a message to all users?` + +:doc:`django` contains a complete implementation of this pattern. + +Again, as you scale, you may reach the performance limits of a basic in-process +implementation. You may need an external publish / subscribe system like Redis_. + +.. _Redis: https://redis.io/ + How can I pass additional arguments to the connection handler? .............................................................. @@ -417,14 +503,6 @@ websockets takes care of responding to pings with pongs. Miscellaneous ------------- -How do I create channels or topics? -................................... - -websockets doesn't have built-in publish / subscribe for these use cases. - -Depending on the scale of your service, a simple in-memory implementation may -do the job or you may need an external publish / subscribe component. - Can I use websockets synchronously, without ``async`` / ``await``? .................................................................. From 731ad8c127bae60b59d622b87d9406333432cfbd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 19:26:27 +0100 Subject: [PATCH 1013/1539] Add broadcast example to quick start guide. Also review and improve that guide. Fix #1070. --- docs/intro/quickstart.rst | 84 ++++++++++++++++++++++++++------- example/quickstart/show_time.py | 5 +- 2 files changed, 69 insertions(+), 20 deletions(-) diff --git a/docs/intro/quickstart.rst b/docs/intro/quickstart.rst index 8c1221126..da3c3999e 100644 --- a/docs/intro/quickstart.rst +++ b/docs/intro/quickstart.rst @@ -5,14 +5,17 @@ Quick start Here are a few examples to get you started quickly with websockets. -Hello world! ------------- +Say "Hello world!" +------------------ Here's a WebSocket server. It receives a name from the client, sends a greeting, and closes the connection. .. literalinclude:: ../../example/quickstart/server.py + :caption: server.py + :language: python + :linenos: :func:`~server.serve` executes the connection handler coroutine ``hello()`` once for each WebSocket connection. It closes the WebSocket connection when @@ -23,14 +26,17 @@ Here's a corresponding WebSocket client. It sends a name to the server, receives a greeting, and closes the connection. .. literalinclude:: ../../example/quickstart/client.py + :caption: client.py + :language: python + :linenos: Using :func:`~client.connect` as an asynchronous context manager ensures the WebSocket connection is closed. .. _secure-server-example: -Encryption ----------- +Encrypt connections +------------------- Secure WebSocket connections improve confidentiality and also reliability because they reduce the risk of interference by bad proxies. @@ -47,46 +53,79 @@ requires certificates like ``https``. TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an earlier encryption protocol; the name stuck. -Here's how to adapt the server to encrypt connections. See the documentation -of the :mod:`ssl` module for configuring the context securely. +Here's how to adapt the server to encrypt connections. You must download +:download:`localhost.pem <../../example/quickstart/localhost.pem>` and save it +in the same directory as ``server_secure.py``. + +See the documentation of the :mod:`ssl` module for details on configuring the +TLS context securely. .. literalinclude:: ../../example/quickstart/server_secure.py + :caption: server_secure.py + :language: python + :linenos: Here's how to adapt the client similarly. .. literalinclude:: ../../example/quickstart/client_secure.py + :caption: client_secure.py + :language: python + :linenos: -This client needs a context because the server uses a self-signed certificate. +In this example, the client needs a TLS context because the server uses a +self-signed certificate. When connecting to a secure WebSocket server with a valid certificate — any certificate signed by a CA that your Python installation trusts — you can simply pass ``ssl=True`` to :func:`~client.connect`. -In a browser ------------- +Connect from a browser +---------------------- The WebSocket protocol was invented for the web — as the name says! -Here's how to connect to a WebSocket server in a browser. +Here's how to connect to a WebSocket server from a browser. Run this script in a console: .. literalinclude:: ../../example/quickstart/show_time.py + :caption: show_time.py + :language: python + :linenos: Save this file as ``show_time.html``: .. literalinclude:: ../../example/quickstart/show_time.html - :language: html + :caption: show_time.html + :language: html + :linenos: Save this file as ``show_time.js``: .. literalinclude:: ../../example/quickstart/show_time.js - :language: js + :caption: show_time.js + :language: js + :linenos: + +Then, open ``show_time.html`` in several browsers. Clocks tick irregularly. + +Broadcast messages +------------------ + +Let's change the previous example to send the same timestamps to all browsers, +instead of generating independent sequences for each client. + +Stop the previous script if it's still running and run this script in a console: + +.. literalinclude:: ../../example/quickstart/show_time_2.py + :caption: show_time_2.py + :language: python + :linenos: -Then open ``show_time.html`` in a browser and see the clock tick irregularly. +Refresh ``show_time.html`` in all browsers. Clocks tick in sync. -Broadcast ---------- +Manage application state +------------------------ A WebSocket server can receive events from clients, process them to update the application state, and broadcast the updated state to all connected clients. @@ -97,20 +136,29 @@ concurrency model of :mod:`asyncio` guarantees that updates are serialized. Run this script in a console: .. literalinclude:: ../../example/quickstart/counter.py + :caption: counter.py + :language: python + :linenos: Save this file as ``counter.html``: .. literalinclude:: ../../example/quickstart/counter.html - :language: html + :caption: counter.html + :language: html + :linenos: Save this file as ``counter.css``: .. literalinclude:: ../../example/quickstart/counter.css - :language: css + :caption: counter.css + :language: css + :linenos: Save this file as ``counter.js``: .. literalinclude:: ../../example/quickstart/counter.js - :language: js + :caption: counter.js + :language: js + :linenos: Then open ``counter.html`` file in several browsers and play with [+] and [-]. diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index facd56b00..a83078e8a 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -6,8 +6,9 @@ import websockets async def show_time(websocket): - while websocket.open: - await websocket.send(datetime.datetime.utcnow().isoformat() + "Z") + while True: + message = datetime.datetime.utcnow().isoformat() + "Z" + await websocket.send(message) await asyncio.sleep(random.random() * 2 + 1) async def main(): From 330dd3dab7af30fbda404e0b0c03deb7d3cd9dd9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 19:36:41 +0100 Subject: [PATCH 1014/1539] Update FAQ on reconnection for clients. --- docs/howto/faq.rst | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 03f85802e..262551664 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -276,9 +276,16 @@ The easiest is to use :func:`~client.connect` as a context manager:: How do I reconnect automatically when the connection drops? ........................................................... -See `issue 414`_. +Use :func:`connect` as an asynchronous iterator:: -.. _issue 414: https://github.com/aaugustin/websockets/issues/414 + async for websocket in websockets.connect(...): + try: + ... + except websockets.ConnectionClosed: + continue + +Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions +will break out of the loop. How do I stop a client that is continuously processing messages? ................................................................ From a113adc08a97b1689acd5d79fbf9ef7f00805c0a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 19:39:05 +0100 Subject: [PATCH 1015/1539] Add FAQ on stopping a server. --- docs/howto/faq.rst | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 262551664..44ae889da 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -220,11 +220,22 @@ To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. -How do I close a connection properly? -..................................... +How do I close a connection? +............................ websockets takes care of closing the connection when the handler exits. +How do I stop a server? +....................... + +Exit the :func:`~server.serve` context manager. + +Here's an example that terminates cleanly when it receives SIGTERM on Unix: + +.. literalinclude:: ../../example/shutdown_server.py + :emphasize-lines: 12-15,18 + + How do I run a HTTP server and WebSocket server on the same port? ................................................................. @@ -265,8 +276,8 @@ change it to:: async with connect(...) as websocket: await do_some_work() -How do I close a connection properly? -..................................... +How do I close a connection? +............................ The easiest is to use :func:`~client.connect` as a context manager:: @@ -287,8 +298,8 @@ Use :func:`connect` as an asynchronous iterator:: Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions will break out of the loop. -How do I stop a client that is continuously processing messages? -................................................................ +How do I stop a client that is processing messages in a loop? +............................................................. You can close the connection. From c7d8b587c1d2124ca04ffb618f15566a64b9eaa6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 20:09:43 +0100 Subject: [PATCH 1016/1539] Shorten questions in FAQ. --- docs/howto/faq.rst | 70 ++++++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 44ae889da..30259e37e 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -15,8 +15,8 @@ FAQ Server side ----------- -Why does my server close the connection prematurely? -.................................................... +Why does the server close the connection prematurely? +..................................................... Your connection handler exits prematurely. Wait for the work to be finished before returning. @@ -31,8 +31,8 @@ change it to:: async def handler(websocket): await do_some_work() -Why does the server close the connection after processing one message? -...................................................................... +Why does the server close the connection after one message? +........................................................... Your connection handler exits after processing one message. Write a loop to process multiple messages. @@ -141,8 +141,8 @@ connections that they're managing. .. _Redis: https://redis.io/ -How do I send a message to a channel, a topic, or a subset of users? -.................................................................... +How do I send a message to a channel, a topic, or some users? +............................................................. websockets doesn't provide built-in publish / subscribe functionality. @@ -159,8 +159,8 @@ implementation. You may need an external publish / subscribe system like Redis_. .. _Redis: https://redis.io/ -How can I pass additional arguments to the connection handler? -.............................................................. +How do I pass arguments to the connection handler? +.................................................. You can bind additional arguments to the connection handler with :func:`functools.partial`:: @@ -179,8 +179,8 @@ Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it through an argument. -How do I get access HTTP headers, for example cookies? -...................................................... +How do I access HTTP headers, like cookies? +........................................... To access HTTP headers during the WebSocket handshake, you can override :attr:`~server.WebSocketServerProtocol.process_request`:: @@ -194,16 +194,16 @@ Once the connection is established, they're available in async def handler(websocket): cookies = websocket.request_headers["Cookie"] -How do I get the IP address of the client connecting to my server? -.................................................................. +How do I get the IP address of the client? +.......................................... It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: async def handler(websocket): remote_ip = websocket.remote_address[0] -How do I set which IP addresses my server listens to? -..................................................... +How do I set the IP addresses my server listens on? +................................................... Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. @@ -236,13 +236,13 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: :emphasize-lines: 12-15,18 -How do I run a HTTP server and WebSocket server on the same port? -................................................................. +How do I run HTTP and WebSocket servers on the same port? +......................................................... You don't. -HTTP and WebSockets have widely different operational characteristics. -Running them with the same server becomes inconvenient when you scale. +HTTP and WebSocket have widely different operational characteristics. Running +them with the same server becomes inconvenient when you scale. Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. @@ -260,8 +260,8 @@ support WebSocket connections, like Sanic_. Client side ----------- -Why does my client close the connection prematurely? -.................................................... +Why does the client close the connection prematurely? +..................................................... You're exiting the context manager prematurely. Wait for the work to be finished before exiting. @@ -284,8 +284,8 @@ The easiest is to use :func:`~client.connect` as a context manager:: async with connect(...) as websocket: ... -How do I reconnect automatically when the connection drops? -........................................................... +How do I reconnect when the connection drops? +............................................. Use :func:`connect` as an asynchronous iterator:: @@ -319,8 +319,8 @@ Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. asyncio usage ------------- -How do I do two things in parallel? How do I integrate with another coroutine? -.............................................................................. +How do I run two coroutines in parallel? +........................................ You must start two tasks, which the event loop will run concurrently. You can achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. @@ -359,8 +359,8 @@ See `issue 867`_. .. _issue 867: https://github.com/aaugustin/websockets/issues/867 -Why does my very simple program misbehave mysteriously? -....................................................... +Why does my simple program misbehave mysteriously? +.................................................. You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which blocks the event loop and prevents asyncio from operating normally. @@ -484,8 +484,8 @@ See `issue 574`_. .. _issue 574: https://github.com/aaugustin/websockets/issues/574 -How can I pass additional arguments to a custom protocol subclass? -.................................................................. +How can I pass arguments to a custom protocol subclass? +....................................................... You can bind additional arguments to the protocol factory with :func:`functools.partial`:: @@ -513,16 +513,18 @@ It closes the connection if it doesn't get a pong within 20 seconds. You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. +See :doc:`../topics/timeouts` for details. + How do I respond to pings? .......................... -websockets takes care of responding to pings with pongs. +Don't bother; websockets takes care of responding to pings with pongs. Miscellaneous ------------- -Can I use websockets synchronously, without ``async`` / ``await``? -.................................................................. +Can I use websockets without ``async`` and ``await``? +..................................................... You can convert every asynchronous call to a synchronous call by wrapping it in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this @@ -547,9 +549,9 @@ Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. -I'm having problems with threads -................................ +Why am I having problems with threads? +...................................... You shouldn't use threads. Use tasks instead. -:meth:`~asyncio.loop.call_soon_threadsafe` may help. +If you have to, :meth:`~asyncio.loop.call_soon_threadsafe` may help. From 5407db58d5a840a78dbc3572545030dc2d5036fe Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 20:10:20 +0100 Subject: [PATCH 1017/1539] Clarify answer on async/await. --- docs/howto/faq.rst | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 30259e37e..1eda6adad 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -526,11 +526,7 @@ Miscellaneous Can I use websockets without ``async`` and ``await``? ..................................................... -You can convert every asynchronous call to a synchronous call by wrapping it -in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this -is deprecated as of Python 3.10. - -If this turns out to be impractical, you should use another library. +No, there is no convenient way to do this. You should use another library. Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ From 695d43a90819d817562b5d51bb2eb8c1b641a324 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 20:31:23 +0100 Subject: [PATCH 1018/1539] Improve FAQ about HTTP headers. Fix #798. --- docs/howto/faq.rst | 47 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 1eda6adad..bc742f458 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -179,20 +179,30 @@ Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it through an argument. -How do I access HTTP headers, like cookies? -........................................... +How do I access HTTP headers? +............................. To access HTTP headers during the WebSocket handshake, you can override :attr:`~server.WebSocketServerProtocol.process_request`:: async def process_request(self, path, request_headers): - cookies = request_header["Cookie"] + authorization = request_headers["Authorization"] -Once the connection is established, they're available in -:attr:`~server.WebSocketServerProtocol.request_headers`:: +Once the connection is established, HTTP headers are available in +:attr:`~server.WebSocketServerProtocol.request_headers` and +:attr:`~server.WebSocketServerProtocol.response_headers`:: async def handler(websocket): - cookies = websocket.request_headers["Cookie"] + authorization = websocket.request_headers["Authorization"] + +How do I set HTTP headers? +.......................... + +To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in +the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` +arguments of :func:`~server.serve`. + +To set other HTTP headers, use the ``extra_headers`` argument. How do I get the IP address of the client? .......................................... @@ -276,6 +286,27 @@ change it to:: async with connect(...) as websocket: await do_some_work() +How do I access HTTP headers? +............................. + +Once the connection is established, HTTP headers are available in +:attr:`~client.WebSocketClientProtocol.request_headers` and +:attr:`~client.WebSocketClientProtocol.response_headers`. + +How do I set HTTP headers? +.......................... + +To set the ``Origin``, ``Sec-WebSocket-Extensions``, or +``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the +``origin``, ``extensions``, or ``subprotocols`` arguments of +:func:`~client.connect`. + +To set other HTTP headers, for example the ``Authorization`` header, use the +``extra_headers`` argument:: + + async with connect(..., extra_headers={"Authorization": ...}) as websocket: + ... + How do I close a connection? ............................ @@ -284,10 +315,12 @@ The easiest is to use :func:`~client.connect` as a context manager:: async with connect(...) as websocket: ... +The connection is closed when exiting the context manager. + How do I reconnect when the connection drops? ............................................. -Use :func:`connect` as an asynchronous iterator:: +Use :func:`~client.connect` as an asynchronous iterator:: async for websocket in websockets.connect(...): try: From f605e86aeab7ac045c21e2eac0d677916031022e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 21:28:28 +0100 Subject: [PATCH 1019/1539] Restore backwards-compatibility for partial handlers. Fix #1095. --- docs/project/changelog.rst | 3 +++ src/websockets/legacy/server.py | 28 +++++++++++++++++++--------- tests/legacy/test_client_server.py | 17 +++++++++++++++++ 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 3c264a31b..b47ae2341 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,9 @@ Improvements Bug fixes ......... +* Fixed backwards-incompatibility in 10.1 for connection handlers created with + :func:`functools.partial`. + * Avoided leaking open sockets when :func:`~client.connect` is canceled. 10.1 diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 8bc466a38..ea5b0d1fa 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1126,17 +1126,27 @@ def remove_path_argument( Callable[[WebSocketServerProtocol, str], Awaitable[Any]], ] ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: - if len(inspect.signature(ws_handler).parameters) == 2: - # Enable deprecation warning and announce deprecation in 11.0. - # warnings.warn("remove second argument of ws_handler", DeprecationWarning) + try: + inspect.signature(ws_handler).bind(None) + except TypeError: + try: + inspect.signature(ws_handler).bind(None, "") + except TypeError: # pragma: no cover + # ws_handler accepts neither one nor two arguments; leave it alone. + pass + else: + # ws_handler accepts two arguments; activate backwards compatibility. + + # Enable deprecation warning and announce deprecation in 11.0. + # warnings.warn("remove second argument of ws_handler", DeprecationWarning) - async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: - return await cast( - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], - ws_handler, - )(websocket, websocket.path) + async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: + return await cast( + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler, + )(websocket, websocket.path) - return _ws_handler + return _ws_handler return cast( Callable[[WebSocketServerProtocol], Awaitable[Any]], diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index e6fb05d57..acc231d76 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -497,6 +497,23 @@ async def handler_with_path(ws, path): "/path", ) + def test_ws_handler_argument_backwards_compatibility_partial(self): + async def handler_with_path(ws, path, extra): + await ws.send(path) + + bound_handler_with_path = functools.partial(handler_with_path, extra=None) + + with self.temp_server( + handler=bound_handler_with_path, + # Enable deprecation warning and announce deprecation in 11.0. + # deprecation_warnings=["remove second argument of ws_handler"], + ): + with self.temp_client("/path"): + self.assertEqual( + self.loop.run_until_complete(self.client.recv()), + "/path", + ) + async def process_request_OK(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" From a1d16ba97344e7f98c3dfce9295231c3862fde10 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Feb 2022 07:40:14 +0100 Subject: [PATCH 1020/1539] Update changelog for 10.2. --- docs/project/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b47ae2341..027b2d8bf 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,8 @@ Improvements * Made compression negotiation more lax for compatibility with Firefox. +* Improved FAQ and quick start guide. + Bug fixes ......... From 498cc8c061e53f0001cb2e3ade22ee8ce5ff11a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Feb 2022 07:41:08 +0100 Subject: [PATCH 1021/1539] Release 10.2. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 027b2d8bf..019aac324 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.2 ---- -*In development* +*February 21, 2022* Improvements ............ diff --git a/src/websockets/version.py b/src/websockets/version.py index a7a12a61e..ca379afe2 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "10.2" From 778a1ca6936ac67e7a3fe1bbe585db2eafeaa515 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Feb 2022 07:42:14 +0100 Subject: [PATCH 1022/1539] Start 10.3. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 019aac324..81dee9e10 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +10.3 +---- + +*In development* + 10.2 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index ca379afe2..605a0eeef 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "10.2" +tag = version = commit = "10.3" if not released: # pragma: no cover From c60df611023ac47345d9201b0a4785c4d8dbdbd9 Mon Sep 17 00:00:00 2001 From: Erik van Raalte Date: Sat, 26 Feb 2022 21:33:26 +0100 Subject: [PATCH 1023/1539] Update documentation Fix misleading function definition for `play` function in intro/tutorial2 --- docs/intro/tutorial2.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 669d46cde..5f4be2047 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -382,7 +382,7 @@ single coroutine to process the moves of both players: .. code-block:: python - async def play(game, player, connected): + async def play(websocket, game, player, connected): ... With such a coroutine, you can replace the temporary code for testing in @@ -390,13 +390,13 @@ With such a coroutine, you can replace the temporary code for testing in .. code-block:: python - await play(game, PLAYER1, connected) + await play(websocket, game, PLAYER1, connected) and in ``join()`` by: .. code-block:: python - await play(game, PLAYER2, connected) + await play(websocket, game, PLAYER2, connected) The ``play()`` coroutine will reuse much of the code you wrote in the first part of the tutorial. From a7b9860538c50e58a06f751b5f9eecde575fae2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Mar 2022 08:19:26 +0100 Subject: [PATCH 1024/1539] Add file forgotten in 731ad8c. --- example/quickstart/show_time_2.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100755 example/quickstart/show_time_2.py diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py new file mode 100755 index 000000000..08e87f593 --- /dev/null +++ b/example/quickstart/show_time_2.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import random +import websockets + +CONNECTIONS = set() + +async def register(websocket): + CONNECTIONS.add(websocket) + try: + await websocket.wait_closed() + finally: + CONNECTIONS.remove(websocket) + +async def show_time(): + while True: + message = datetime.datetime.utcnow().isoformat() + "Z" + websockets.broadcast(CONNECTIONS, message) + await asyncio.sleep(random.random() * 2 + 1) + +async def main(): + async with websockets.serve(register, "localhost", 5678): + await show_time() + +if __name__ == "__main__": + asyncio.run(main()) From db600e51a0edd6d1fbc18891775182cbc00f5f62 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Mar 2022 22:37:05 +0100 Subject: [PATCH 1025/1539] Work around problematic change in typeshed. Refs https://github.com/python/typeshed/pull/6653. --- src/websockets/datastructures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 37d3b5f86..77b68ba9a 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -134,7 +134,7 @@ def clear(self) -> None: self._dict = {} self._list = [] - def update(self, *args: HeadersLike, **kwargs: str) -> None: + def update(self, *args: HeadersLike, **kwargs: str) -> None: # type: ignore """ Update from a :class:`Headers` instance and/or keyword arguments. From 747df3d174acdbcda5506ed1588bd635fedd7078 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Mar 2022 22:35:24 +0100 Subject: [PATCH 1026/1539] Widen HeadersLike type. Restore compatibility with the latest mypy. Refs https://github.com/python/typeshed/pull/6653. --- docs/reference/types.rst | 2 ++ src/websockets/datastructures.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 4b7952553..d86429be4 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -18,3 +18,5 @@ Types .. autodata:: websockets.connection.Event .. autodata:: websockets.datastructures.HeadersLike + +.. autodata:: websockets.datastructures.SupportsKeysAndGetItem diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 77b68ba9a..a0a648463 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -8,6 +8,7 @@ List, Mapping, MutableMapping, + Protocol, Tuple, Union, ) @@ -134,7 +135,7 @@ def clear(self) -> None: self._dict = {} self._list = [] - def update(self, *args: HeadersLike, **kwargs: str) -> None: # type: ignore + def update(self, *args: HeadersLike, **kwargs: str) -> None: """ Update from a :class:`Headers` instance and/or keyword arguments. @@ -164,5 +165,30 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: return iter(self._list) -HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] -"""Types accepted where :class:`Headers` is expected.""" +# copy of _typeshed.SupportsKeysAndGetItem. +class SupportsKeysAndGetItem(Protocol): # pragma: no cover + """ + Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods. + + """ + + def keys(self) -> Iterable[str]: + ... + + def __getitem__(self, key: str) -> str: + ... + + +HeadersLike = Union[ + Headers, + Mapping[str, str], + Iterable[Tuple[str, str]], + SupportsKeysAndGetItem, +] +""" +Types accepted where :class:`Headers` is expected. + +In addition to :class:`Headers` itself, this includes dict-like types where both +keys and values are :class:`str`. + +""" From f7478be1321e5ccc9cd8995e2e95676faff57944 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Mar 2022 07:42:36 +0100 Subject: [PATCH 1027/1539] Restore compatibility with Python 3.7. --- src/websockets/datastructures.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index a0a648463..36a2cbaf9 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from typing import ( Any, Dict, @@ -8,12 +9,17 @@ List, Mapping, MutableMapping, - Protocol, Tuple, Union, ) +if sys.version_info[:2] >= (3, 8): + from typing import Protocol +else: # pragma: no cover + Protocol = object # mypy will report errors on Python 3.7. + + __all__ = ["Headers", "HeadersLike", "MultipleValuesError"] From 286768512b0c2bd671cae0ae3e64c1545632b6d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Mar 2022 08:51:27 +0100 Subject: [PATCH 1028/1539] Remove path parameters in connection handlers. 9b8a8d1c wasn't reflected properly before the tutorial was merged. Fix #1154. --- docs/intro/tutorial2.rst | 2 +- example/tutorial/step2/app.py | 2 +- example/tutorial/step3/app.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 5f4be2047..53f295d3f 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -212,7 +212,7 @@ previous one to reuse it later: del JOIN[join_key] - async def handler(websocket, path): + async def handler(websocket): # Receive and parse the "init" event from the UI. message = await websocket.recv() event = json.loads(message) diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index bac2b6f27..2693d4304 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -160,7 +160,7 @@ async def watch(websocket, watch_key): connected.remove(websocket) -async def handler(websocket, path): +async def handler(websocket): """ Handle a connection and dispatch it according to who is connecting. diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 6fff79c95..c2ee020d2 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -162,7 +162,7 @@ async def watch(websocket, watch_key): connected.remove(websocket) -async def handler(websocket, path): +async def handler(websocket): """ Handle a connection and dispatch it according to who is connecting. From 0796c43c5c88e385d11471264eb519d421c7232d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Apr 2022 09:25:55 +0200 Subject: [PATCH 1029/1539] Simplify implementation of close_expected(). --- src/websockets/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 15a40f80c..ca996ba62 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -481,7 +481,6 @@ def close_expected(self) -> bool: bool: Whether the TCP connection is expected to close soon. """ - # We already got a TCP Close if and only if the state is CLOSED. # We expect a TCP close if and only if we sent a close frame: # * Normal closure: once we send a close frame, we expect a TCP close: # server waits for client to complete the TCP closing handshake; @@ -489,7 +488,8 @@ def close_expected(self) -> bool: # * Abnormal closure: we always send a close frame and the same logic # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. - return self.state is not CLOSED and self.close_sent is not None + # We already got a TCP Close if and only if the state is CLOSED. + return self.state is CLOSING # Private methods for receiving data. From 5dc16c27001a9465b22832e5b02036d2f10862bb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Apr 2022 10:28:12 +0200 Subject: [PATCH 1030/1539] Store handshake exceptions in connection objects. Storing them in request/response objects was probably legacy. --- docs/howto/sansio.rst | 14 ++++++++------ docs/project/changelog.rst | 11 +++++++++++ docs/reference/client.rst | 2 ++ docs/reference/server.rst | 2 ++ src/websockets/client.py | 3 ++- src/websockets/connection.py | 9 +++++++++ src/websockets/http11.py | 27 +++++++++++++++++++++------ src/websockets/server.py | 12 ++++++++---- tests/test_client.py | 26 +++++++++++++------------- tests/test_server.py | 32 ++++++++++++++++---------------- 10 files changed, 92 insertions(+), 46 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index 83496bff2..1373c81a5 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -52,10 +52,11 @@ the network, as described in `Send data`_ below. The first event returned by :meth:`~connection.Connection.events_received` is the WebSocket handshake response. -When the handshake fails, the reason is available in ``response.exception``:: +When the handshake fails, the reason is available in +:attr:`~client.ClientConnection.handshake_exc`:: - if response.exception is not None: - raise response.exception + if connection.handshake_exc is not None: + raise connection.handshake_exc Else, the WebSocket connection is open. @@ -96,10 +97,11 @@ the network, as described in `Send data`_ below. Even when you call :meth:`~server.ServerConnection.accept`, the WebSocket handshake may fail if the request is incorrect or unsupported. -When the handshake fails, the reason is available in ``request.exception``:: +When the handshake fails, the reason is available in +:attr:`~server.ServerConnection.handshake_exc`:: - if request.exception is not None: - raise request.exception + if connection.handshake_exc is not None: + raise connection.handshake_exc Else, the WebSocket connection is open. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 81dee9e10..26e9a5cdc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,17 @@ They may change at any time. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated. + :class: note + + Use the ``handshake_exc`` attribute of :class:`~server.ServerConnection` and + :class:`~client.ClientConnection` instead. + + See :doc:`../howto/sansio` for details. + 10.2 ---- diff --git a/docs/reference/client.rst b/docs/reference/client.rst index daf01ef58..379765397 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -103,6 +103,8 @@ Sans-I/O .. autoproperty:: state + .. autoattribute:: handshake_exc + .. autoproperty:: close_code .. autoproperty:: close_reason diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 97bf320b6..0e446f382 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -151,6 +151,8 @@ Sans-I/O .. autoproperty:: state + .. autoattribute:: handshake_exc + .. autoproperty:: close_code .. autoproperty:: close_reason diff --git a/src/websockets/client.py b/src/websockets/client.py index 7a904b151..8d826feaa 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -334,7 +334,8 @@ def parse(self) -> Generator[None, None, None]: try: self.process_response(response) except InvalidHandshake as exc: - response.exception = exc + response._exception = exc + self.handshake_exc = exc else: assert self.state is CONNECTING self.state = OPEN diff --git a/src/websockets/connection.py b/src/websockets/connection.py index ca996ba62..967bd8fa5 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -130,6 +130,15 @@ def __init__( self.close_sent: Optional[Close] = None self.close_rcvd_then_sent: Optional[bool] = None + # Track if an exception happened during the handshake. + self.handshake_exc: Optional[Exception] = None + """ + Exception to raise if the opening handshake failed. + + :obj:`None` if the opening handshake succeeded. + + """ + # Track if send_eof() was called. self.eof_sent = False diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 502ca64f8..84048fa47 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -2,6 +2,7 @@ import dataclasses import re +import warnings from typing import Callable, Generator, Optional from . import datastructures, exceptions @@ -57,15 +58,22 @@ class Request: Attributes: path: Request path, including optional query. headers: Request headers. - exception: If processing the response triggers an exception, - the exception is stored in this attribute. """ path: str headers: datastructures.Headers # body isn't useful is the context of this library. - exception: Optional[Exception] = None + _exception: Optional[Exception] = None + + @property + def exception(self) -> Optional[Exception]: # pragma: no cover + warnings.warn( + "Request.exception is deprecated; " + "use ServerConnection.handshake_exc instead", + DeprecationWarning, + ) + return self._exception @classmethod def parse( @@ -152,8 +160,6 @@ class Response: reason_phrase: Response reason. headers: Response headers. body: Response body, if any. - exception: if processing the response triggers an exception, - the exception is stored in this attribute. """ @@ -162,7 +168,16 @@ class Response: headers: datastructures.Headers body: Optional[bytes] = None - exception: Optional[Exception] = None + _exception: Optional[Exception] = None + + @property + def exception(self) -> Optional[Exception]: # pragma: no cover + warnings.warn( + "Response.exception is deprecated; " + "use ClientConnection.handshake_exc instead", + DeprecationWarning, + ) + return self._exception @classmethod def parse( diff --git a/src/websockets/server.py b/src/websockets/server.py index bb0e0d7e2..214417ad0 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -110,7 +110,8 @@ def accept(self, request: Request) -> Response: protocol_header, ) = self.process_request(request) except InvalidOrigin as exc: - request.exception = exc + request._exception = exc + self.handshake_exc = exc if self.debug: self.logger.debug("! invalid origin", exc_info=True) return self.reject( @@ -118,7 +119,8 @@ def accept(self, request: Request) -> Response: f"Failed to open a WebSocket connection: {exc}.\n", ) except InvalidUpgrade as exc: - request.exception = exc + request._exception = exc + self.handshake_exc = exc if self.debug: self.logger.debug("! invalid upgrade", exc_info=True) response = self.reject( @@ -133,7 +135,8 @@ def accept(self, request: Request) -> Response: response.headers["Upgrade"] = "websocket" return response except InvalidHandshake as exc: - request.exception = exc + request._exception = exc + self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) return self.reject( @@ -141,7 +144,8 @@ def accept(self, request: Request) -> Response: f"Failed to open a WebSocket connection: {exc}.\n", ) except Exception as exc: - request.exception = exc + request._exception = exc + self.handshake_exc = exc self.logger.error("opening handshake failed", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, diff --git a/tests/test_client.py b/tests/test_client.py index 1c1452d41..12fd8726f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -255,7 +255,7 @@ def test_missing_connection(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): @@ -268,7 +268,7 @@ def test_invalid_connection(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): @@ -280,7 +280,7 @@ def test_missing_upgrade(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): @@ -293,7 +293,7 @@ def test_invalid_upgrade(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_accept(self): @@ -305,7 +305,7 @@ def test_missing_accept(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") def test_multiple_accept(self): @@ -317,7 +317,7 @@ def test_multiple_accept(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Accept header: " @@ -334,7 +334,7 @@ def test_invalid_accept(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), f"invalid Sec-WebSocket-Accept header: {ACCEPT}" ) @@ -383,7 +383,7 @@ def test_unexpected_extension(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "no extensions supported") def test_unsupported_extension(self): @@ -398,7 +398,7 @@ def test_unsupported_extension(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), "Unsupported extension: name = x-op, params = [('op', None)]", @@ -429,7 +429,7 @@ def test_unsupported_extension_parameters(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), "Unsupported extension: name = x-op, params = [('op', 'that')]", @@ -520,7 +520,7 @@ def test_unexpected_subprotocol(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "no subprotocols supported") def test_multiple_subprotocols(self): @@ -536,7 +536,7 @@ def test_multiple_subprotocols(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), "multiple subprotocols: superchat, chat" ) @@ -566,7 +566,7 @@ def test_unsupported_subprotocol(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") diff --git a/tests/test_server.py b/tests/test_server.py index 54699c3ef..43bc03e14 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -188,7 +188,7 @@ def test_unexpected_exception(self): self.assertEqual(response.status_code, 500) with self.assertRaises(Exception) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "BOOM") def test_missing_connection(self): @@ -200,7 +200,7 @@ def test_missing_connection(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): @@ -213,7 +213,7 @@ def test_invalid_connection(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): @@ -225,7 +225,7 @@ def test_missing_upgrade(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): @@ -238,7 +238,7 @@ def test_invalid_upgrade(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_key(self): @@ -249,7 +249,7 @@ def test_missing_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") def test_multiple_key(self): @@ -260,7 +260,7 @@ def test_multiple_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: " @@ -276,7 +276,7 @@ def test_invalid_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" ) @@ -292,7 +292,7 @@ def test_truncated_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" ) @@ -305,7 +305,7 @@ def test_missing_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") def test_multiple_version(self): @@ -316,7 +316,7 @@ def test_multiple_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: " @@ -332,7 +332,7 @@ def test_invalid_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: 11" ) @@ -344,7 +344,7 @@ def test_no_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Origin header") def test_origin(self): @@ -364,7 +364,7 @@ def test_unexpected_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Origin header: https://other.example.com" ) @@ -382,7 +382,7 @@ def test_multiple_origin(self): # 400 Bad Request rather than 403 Forbidden. self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Origin header: more than one Origin header found", @@ -409,7 +409,7 @@ def test_unsupported_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Origin header: https://original.example.com" ) From 4034d8dbc81742dc2c0688dc9f29c5161bec66e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Apr 2022 10:43:21 +0200 Subject: [PATCH 1031/1539] Expect TCP close after a failed opening handshake. --- src/websockets/client.py | 2 ++ src/websockets/connection.py | 2 +- src/websockets/server.py | 11 ++++++++++- tests/test_client.py | 5 +++++ tests/test_server.py | 4 ++++ 5 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 8d826feaa..df8e53429 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -336,6 +336,8 @@ def parse(self) -> Generator[None, None, None]: except InvalidHandshake as exc: response._exception = exc self.handshake_exc = exc + self.parser = self.discard() + next(self.parser) # start coroutine else: assert self.state is CONNECTING self.state = OPEN diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 967bd8fa5..db8b53699 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -498,7 +498,7 @@ def close_expected(self) -> bool: # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. # We already got a TCP Close if and only if the state is CLOSED. - return self.state is CLOSING + return self.state is CLOSING or self.handshake_exc is not None # Private methods for receiving data. diff --git a/src/websockets/server.py b/src/websockets/server.py index 214417ad0..5dad50b6a 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -13,6 +13,7 @@ InvalidHeader, InvalidHeaderValue, InvalidOrigin, + InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -471,8 +472,14 @@ def reject( ("Server", USER_AGENT), ] ) + response = Response(status.value, status.phrase, headers, body) + # When reject() is called from accept(), handshake_exc is already set. + # If a user calls reject(), set handshake_exc to guarantee invariant: + # "handshake_exc is None if and only if opening handshake succeded." + if self.handshake_exc is None: + self.handshake_exc = InvalidStatus(response) self.logger.info("connection failed (%d %s)", status.value, status.phrase) - return Response(status.value, status.phrase, headers, body) + return response def send_response(self, response: Response) -> None: """ @@ -497,6 +504,8 @@ def send_response(self, response: Response) -> None: self.state = OPEN else: self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: diff --git a/tests/test_client.py b/tests/test_client.py index 12fd8726f..a843d3272 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -42,6 +42,7 @@ def test_send_connect(self): f"\r\n".encode() ], ) + self.assertFalse(client.close_expected()) def test_connect_request(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): @@ -135,6 +136,8 @@ def test_receive_accept(self): ) [response] = client.events_received() self.assertIsInstance(response, Response) + self.assertEqual(client.data_to_send(), []) + self.assertFalse(client.close_expected()) self.assertEqual(client.state, OPEN) def test_receive_reject(self): @@ -155,6 +158,8 @@ def test_receive_reject(self): ) [response] = client.events_received() self.assertIsInstance(response, Response) + self.assertEqual(client.data_to_send(), []) + self.assertTrue(client.close_expected()) self.assertEqual(client.state, CONNECTING) def test_accept_response(self): diff --git a/tests/test_server.py b/tests/test_server.py index 43bc03e14..e3e802239 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -38,6 +38,8 @@ def test_receive_connect(self): ) [request] = server.events_received() self.assertIsInstance(request, Request) + self.assertEqual(server.data_to_send(), []) + self.assertFalse(server.close_expected()) def test_connect_request(self): server = ServerConnection() @@ -104,6 +106,7 @@ def test_send_accept(self): f"\r\n".encode() ], ) + self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) def test_send_reject(self): @@ -126,6 +129,7 @@ def test_send_reject(self): b"", ], ) + self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) def test_accept_response(self): From 5e78f15929569e5557930089613a5760557a3a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=AE=AE=E0=AE=A9=E0=AF=8B=E0=AE=9C=E0=AF=8D=E0=AE=95?= =?UTF-8?q?=E0=AF=81=E0=AE=AE=E0=AE=BE=E0=AE=B0=E0=AF=8D=20=E0=AE=AA?= =?UTF-8?q?=E0=AE=B4=E0=AE=A9=E0=AE=BF=E0=AE=9A=E0=AF=8D=E0=AE=9A=E0=AE=BE?= =?UTF-8?q?=E0=AE=AE=E0=AE=BF?= Date: Fri, 15 Apr 2022 09:51:55 +0530 Subject: [PATCH 1032/1539] Update README.rst --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f8df94ba4..2b9a445ea 100644 --- a/README.rst +++ b/README.rst @@ -125,7 +125,7 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. - If you want do to both in the same server, look at HTTP frameworks that + If you want to do both in the same server, look at HTTP frameworks that build on top of ``websockets`` to support WebSocket connections, like Sanic_. From 8a58de259381f82036ab624255d01a8fbc6234de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 6 Apr 2022 08:10:53 +0200 Subject: [PATCH 1033/1539] Update Tidelift security link. --- SECURITY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SECURITY.md b/SECURITY.md index 556217a4d..82024b485 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -2,4 +2,4 @@ Only the latest version receives security updates. -Please report vulnerabilities [via Tidelift](https://tidelift.com/docs/security). +Please report vulnerabilities [via Tidelift](https://tidelift.com/security). From 81cb6f55e50d54640f9a0cdff31e3eb0d079434d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Fri, 8 Apr 2022 15:57:51 +0200 Subject: [PATCH 1034/1539] Skip non-contiguous buffer tests on NotImplementedError (pypy) PyPy3.9 7.3.9 does not implement creating contiguous buffers from non-contiguous, causing the respective tests to fail due to NotImplementedError. Catch it and skip the tests appropriately as discussed in #1157. Fixes #1158 --- tests/test_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index a9ea8dcbd..528eeaf24 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -88,4 +88,10 @@ def test_apply_mask_check_mask_length(self): class SpeedupsTests(ApplyMaskTests): @staticmethod def apply_mask(*args, **kwargs): - return c_apply_mask(*args, **kwargs) + try: + return c_apply_mask(*args, **kwargs) + except NotImplementedError as e: + # PyPy3.9 as of 7.3.9 does not implement creating + # contiguous buffers from non-contiguous and raises + # NotImplementedError. Catch it and skip the test. + raise unittest.SkipTest(str(e)) From 72abbd3d61801f4f45ed7babdfe14ebbd7c2ccf1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 09:07:57 +0200 Subject: [PATCH 1035/1539] Narrow down exception for PyPy To avoid accidentally masking real test failures on other platforms. --- tests/test_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 528eeaf24..acd60edfc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ import base64 import itertools +import platform import unittest from websockets.utils import accept_key, apply_mask as py_apply_mask, generate_key @@ -90,8 +91,13 @@ class SpeedupsTests(ApplyMaskTests): def apply_mask(*args, **kwargs): try: return c_apply_mask(*args, **kwargs) - except NotImplementedError as e: - # PyPy3.9 as of 7.3.9 does not implement creating - # contiguous buffers from non-contiguous and raises - # NotImplementedError. Catch it and skip the test. - raise unittest.SkipTest(str(e)) + except NotImplementedError as exc: # pragma: no cover + # PyPy doesn't implement creating contiguous readonly buffer + # from non-contiguous. We don't care about this edge case. + if ( + platform.python_implementation() == "PyPy" + and "not implemented yet" in str(exc) + ): + raise unittest.SkipTest(str(exc)) + else: + raise From 79182b66798ed96c622462dbe9e28457655f887b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 09:43:42 +0200 Subject: [PATCH 1036/1539] Update regex matching git hashes. A branch (test-on-pypy) now uses 8 characters for git hashes. --- src/websockets/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 605a0eeef..65123b3fa 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -47,7 +47,7 @@ def get_version(tag: str) -> str: except (FileNotFoundError, subprocess.CalledProcessError): pass else: - description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7}(?:-dirty)?)" + description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" match = re.fullmatch(description_re, description) assert match is not None distance, remainder = match.groups() @@ -69,7 +69,7 @@ def get_version(tag: str) -> str: def get_commit(tag: str, version: str) -> str: # Extract commit from version, falling back to tag if not available. - version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7}|unknown)(?:\.dirty)?" + version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?" match = re.fullmatch(version_re, version) assert match is not None (commit,) = match.groups() From f761e23614c4d8b35ce461e70896e7c6c5bc1cfa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 10:42:33 +0200 Subject: [PATCH 1037/1539] Prevent AttributeError on server shutdown. close() shouldn't be called on a CONNECTING connection because the transfer_data_task attribute isn't initialized. --- src/websockets/legacy/server.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index ea5b0d1fa..3e51db1b7 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -768,13 +768,15 @@ async def _close(self) -> None: # closed, handshake() closes OPENING connections with a HTTP 503 # error. Wait until all connections are closed. - # asyncio.wait doesn't accept an empty first argument - if self.websockets: + close_tasks = [ + asyncio.create_task(websocket.close(1001)) + for websocket in self.websockets + if websocket.state is not State.CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: await asyncio.wait( - [ - asyncio.create_task(websocket.close(1001)) - for websocket in self.websockets - ], + close_tasks, **loop_if_py_lt_38(self.get_loop()), ) From 9e2503a5c5cc4c48f223ec8aa780dfa7268622a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 09:14:58 +0200 Subject: [PATCH 1038/1539] Add PyPy to testing matrix. --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3aa579aa4..9412c0ea5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ["3.7", "3.8", "3.9", "3.10"] + python: ["3.7", "3.8", "3.9", "3.10", "pypy-3.9"] steps: - name: Check out repository uses: actions/checkout@v2 From 8dd8e410431408cf33fe761d8b273a95cf45da5f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 10:36:32 +0200 Subject: [PATCH 1039/1539] Ignore failing test on PyPy. This has to do with closing TLS connections: transport.close() is enough on CPython while PyPy requires transport.abort() (which websockets will call eventually, but not before that test fails.) --- tests/legacy/test_client_server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index acc231d76..f9de70c9c 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -4,6 +4,7 @@ import http import logging import pathlib +import platform import random import socket import ssl @@ -1243,6 +1244,10 @@ class SecureClientServerTests( # TLS over Unix sockets doesn't make sense. test_unix_socket = None + # This test fails under PyPy due to a difference with CPython. + if platform.python_implementation() == "PyPy": # pragma: no cover + test_http_request_ws_endpoint = None + @with_server() def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): From 1e8e8ce4bcccbbf922d941e2fe61a28a4e370a46 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 11:01:10 +0200 Subject: [PATCH 1040/1539] Bump timings when testing on PyPy. --- tests/legacy/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 4d4306232..1fa2b53c8 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -3,6 +3,7 @@ import functools import logging import os +import platform import time import unittest @@ -90,5 +91,9 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover MS *= 10 +# PyPy has a performance penalty for this test suite. +if platform.python_implementation() == "PyPy": # pragma: no cover + MS *= 5 + # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From 19367485bbdda2eb7c55ad3a43b5b1296743b1f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 11:57:29 +0200 Subject: [PATCH 1041/1539] Handle SSLError when receiving messages. Refs #1160. --- src/websockets/legacy/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 49c81da4f..66751a477 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -5,6 +5,7 @@ import collections import logging import random +import ssl import struct import uuid import warnings @@ -974,13 +975,14 @@ async def transfer_data(self) -> None: self.transfer_data_exc = exc self.fail_connection(1002) - except (ConnectionError, TimeoutError, EOFError) as exc: + except (ConnectionError, TimeoutError, EOFError, ssl.SSLError) as exc: # Reading data with self.reader.readexactly may raise: # - most subclasses of ConnectionError if the TCP connection # breaks, is reset, or is aborted; # - TimeoutError if the TCP connection times out; # - IncompleteReadError, a subclass of EOFError, if fewer - # bytes are available than requested. + # bytes are available than requested; + # - ssl.SSLError if the other side infringes the TLS protocol. self.transfer_data_exc = exc self.fail_connection(1006) From a446fb78d9994c30a479719ea579c349de4f372f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 12:03:30 +0200 Subject: [PATCH 1042/1539] Handle zlib.error when receiving compressed messages. Refs #1160, #665. --- src/websockets/extensions/permessage_deflate.py | 5 ++++- tests/extensions/test_permessage_deflate.py | 5 ++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 017e3d843..e0de5e8f8 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -125,7 +125,10 @@ def decode( if frame.fin: data += _EMPTY_UNCOMPRESSED_BLOCK max_length = 0 if max_size is None else max_size - data = self.decoder.decompress(data, max_length) + try: + data = self.decoder.decompress(data, max_length) + except zlib.error as exc: + raise exceptions.ProtocolError("decompression failed") from exc if self.decoder.unconsumed_tail: raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)") diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index a50685587..3cc9172df 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,6 +1,5 @@ import dataclasses import unittest -import zlib from websockets.exceptions import ( DuplicateParameter, @@ -8,6 +7,7 @@ InvalidParameterValue, NegotiationError, PayloadTooBig, + ProtocolError, ) from websockets.extensions.permessage_deflate import * from websockets.frames import ( @@ -225,9 +225,8 @@ def test_remote_no_context_takeover(self): dec_frame1 = self.extension.decode(enc_frame1) self.assertEqual(dec_frame1, frame) - with self.assertRaises(zlib.error) as exc: + with self.assertRaises(ProtocolError): self.extension.decode(enc_frame2) - self.assertIn("invalid distance too far back", str(exc.exception)) def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. From 8850eb06c7955462dd641acf50d1dcfcfc95b0ee Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 28 Feb 2022 14:10:16 +0800 Subject: [PATCH 1043/1539] Fix logging error when sending a memoryview. See also: https://bugs.python.org/issue15945 --- src/websockets/frames.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 82b4a1403..07aabac8d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -139,7 +139,7 @@ def __str__(self) -> str: # Encode just what we need, plus two dummy bytes to elide later. binary = self.data if len(binary) > 25: - binary = binary[:16] + b"\x00\x00" + binary[-8:] + binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: data = str(Close.parse(self.data)) @@ -153,7 +153,7 @@ def __str__(self) -> str: except UnicodeDecodeError: binary = self.data if len(binary) > 25: - binary = binary[:16] + b"\x00\x00" + binary[-8:] + binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: From 778879eb0144a14a2a406fb2c3fa45f80afcd421 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 12:29:21 +0200 Subject: [PATCH 1044/1539] Add tests for previous commit. --- src/websockets/frames.py | 8 +++++--- tests/test_frames.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 07aabac8d..043b688b5 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -145,12 +145,14 @@ def __str__(self) -> str: data = str(Close.parse(self.data)) elif self.data: # We don't know if a Continuation frame contains text or binary. - # Ping and Pong frames could contain UTF-8. Attempt to decode as - # UTF-8 and display it as text; fallback to binary. + # Ping and Pong frames could contain UTF-8. + # Attempt to decode as UTF-8 and display it as text; fallback to + # binary. If self.data is a memoryview, it has no decode() method, + # which raises AttributeError. try: data = repr(self.data.decode()) coding = "text" - except UnicodeDecodeError: + except (UnicodeDecodeError, AttributeError): binary = self.data if len(binary) > 25: binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) diff --git a/tests/test_frames.py b/tests/test_frames.py index c8f9867d4..e7c48b930 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -207,6 +207,12 @@ def test_cont_binary(self): "CONT fc fd fe ff [binary, 4 bytes, continued]", ) + def test_cont_binary_from_memoryview(self): + self.assertEqual( + str(Frame(OP_CONT, memoryview(b"\xfc\xfd\xfe\xff"), fin=False)), + "CONT fc fd fe ff [binary, 4 bytes, continued]", + ) + def test_cont_final_text(self): self.assertEqual( str(Frame(OP_CONT, b" cr\xc3\xa8me")), @@ -219,6 +225,12 @@ def test_cont_final_binary(self): "CONT fc fd fe ff [binary, 4 bytes]", ) + def test_cont_final_binary_from_memoryview(self): + self.assertEqual( + str(Frame(OP_CONT, memoryview(b"\xfc\xfd\xfe\xff"))), + "CONT fc fd fe ff [binary, 4 bytes]", + ) + def test_cont_text_truncated(self): self.assertEqual( str(Frame(OP_CONT, b"caf\xc3\xa9 " * 16, fin=False)), @@ -233,6 +245,13 @@ def test_cont_binary_truncated(self): " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", ) + def test_cont_binary_truncated_from_memoryview(self): + self.assertEqual( + str(Frame(OP_CONT, memoryview(bytes(range(256))), fin=False)), + "CONT 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." + " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", + ) + def test_text(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9")), @@ -264,12 +283,24 @@ def test_binary(self): "BINARY 00 01 02 03 [4 bytes]", ) + def test_binary_from_memoryview(self): + self.assertEqual( + str(Frame(OP_BINARY, memoryview(b"\x00\x01\x02\x03"))), + "BINARY 00 01 02 03 [4 bytes]", + ) + def test_binary_non_final(self): self.assertEqual( str(Frame(OP_BINARY, b"\x00\x01\x02\x03", fin=False)), "BINARY 00 01 02 03 [4 bytes, continued]", ) + def test_binary_non_final_from_memoryview(self): + self.assertEqual( + str(Frame(OP_BINARY, memoryview(b"\x00\x01\x02\x03"), fin=False)), + "BINARY 00 01 02 03 [4 bytes, continued]", + ) + def test_binary_truncated(self): self.assertEqual( str(Frame(OP_BINARY, bytes(range(256)))), @@ -277,6 +308,13 @@ def test_binary_truncated(self): " f8 f9 fa fb fc fd fe ff [256 bytes]", ) + def test_binary_truncated_from_memoryview(self): + self.assertEqual( + str(Frame(OP_BINARY, memoryview(bytes(range(256))))), + "BINARY 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." + " f8 f9 fa fb fc fd fe ff [256 bytes]", + ) + def test_close(self): self.assertEqual( str(Frame(OP_CLOSE, b"\x03\xe8")), From a96fffa85907a392711c75ad3039c2949531a268 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 12:57:01 +0200 Subject: [PATCH 1045/1539] Add FAQ on request path. Fix #1154. --- docs/howto/faq.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index bc742f458..e9aeae881 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -179,6 +179,30 @@ Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it through an argument. +How do I access the request path? +................................. + +It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. + +You may route a connection to different handlers depending on the request path:: + + async def handler(websocket): + if websocket.path == "/blue": + await blue_handler(websocket) + elif websocket.path == "/green": + await green_handler(websocket) + else: + # No handler for this path; close the connection. + return + +You may also route the connection based on the first message received from the +client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to +authenticate the connection before routing it, this is usually more convenient. + +Generally speaking, there is far less emphasis on the request path in WebSocket +servers than in HTTP servers. When a WebSockt server provides a single endpoint, +it may ignore the request path entirely. + How do I access HTTP headers? ............................. From f9fd2cebcd42633ed917cd64e805bea17879c2d7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 13:27:32 +0200 Subject: [PATCH 1046/1539] Catch RuntimeError to cater to uvloop. Fix #1138, #1072 (again). --- src/websockets/legacy/protocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 66751a477..d1d52bfac 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1310,9 +1310,10 @@ async def close_connection(self) -> None: # write_eof() doesn't document which exceptions it raises. # "[Errno 107] Transport endpoint is not connected" happens # but it isn't completely clear under which circumstances. + # uvloop can raise RuntimeError here. try: self.transport.write_eof() - except OSError: # pragma: no cover + except (OSError, RuntimeError): # pragma: no cover pass if await self.wait_for_connection_lost(): From 42e33f17cdaaa258f3035618fb806c74744406aa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 15:39:18 +0200 Subject: [PATCH 1047/1539] Clarify answer about threads. Fix #1156. Ref #1162. --- docs/howto/faq.rst | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index e9aeae881..a103f677c 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -607,4 +607,15 @@ Why am I having problems with threads? You shouldn't use threads. Use tasks instead. -If you have to, :meth:`~asyncio.loop.call_soon_threadsafe` may help. +Indeed, when you chose websockets, you chose :mod:`asyncio` as the primary +framework to handle concurrency. This choice is mutually exclusive with +:mod:`threading`. + +If you believe that you need to run websockets in a thread and some logic in +another thread, you should run that logic in a :class:`~asyncio.Task` instead. + +If you believe that you cannot run that logic in the same event loop because it +will block websockets, :meth:`~asyncio.loop.run_in_executor` may help. + +This question is really about :mod:`asyncio`. Please review the advice about +:ref:`asyncio-multithreading` in the Python documentation. From 3fb06c347a6039959c38ff41a7d1057971ca17c5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 16:12:30 +0200 Subject: [PATCH 1048/1539] Updated changelog for 10.3. --- docs/project/changelog.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 26e9a5cdc..9bdc99db5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,11 @@ Backwards-incompatible changes See :doc:`../howto/sansio` for details. +Improvements +............ + +* Reduced noise in logs when :mod:`ssl` or :mod:`zlib` raise exceptions. + 10.2 ---- From 0d9a251a338f549b91e13cc346b0f73fd5964493 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 16:12:42 +0200 Subject: [PATCH 1049/1539] Release version 10.3 --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9bdc99db5..398fa65a9 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.3 ---- -*In development* +*April 17, 2022* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 65123b3fa..c30bfd68f 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "10.3" From 3742b429a25d5f51511b626435c6a1acdd9027a3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 16:14:11 +0200 Subject: [PATCH 1050/1539] Start version 10.4. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 398fa65a9..105065c8c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +10.4 +---- + +*In development* + 10.3 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index c30bfd68f..29d658ce2 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "10.3" +tag = version = commit = "10.4" if not released: # pragma: no cover From c590dc8a69ca9f3052aa34f7db017b9a253a48a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 Apr 2022 19:28:43 +0200 Subject: [PATCH 1051/1539] Tweak coverage setup. This makes it possible to run coverage with the --include option without creating a conflict with the --source option. --- Makefile | 3 +-- setup.cfg | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 6f8130840..a38634aee 100644 --- a/Makefile +++ b/Makefile @@ -16,8 +16,7 @@ test: python -m unittest coverage: - coverage erase - coverage run -m unittest + coverage run --source websockets,tests -m unittest coverage html coverage report --show-missing --fail-under=100 diff --git a/setup.cfg b/setup.cfg index 9ff1939c7..91f769620 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,9 +22,6 @@ lines_after_imports = 2 branch = True omit = */__main__.py -source = - websockets - tests [coverage:paths] source = From 9c87d43f1d7bbf6847350087aae74fd35f73a642 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 24 Apr 2022 10:24:30 +0200 Subject: [PATCH 1052/1539] Fix two typos. Found by codereview.doctor (among many false positives). --- example/deployment/kubernetes/benchmark.py | 2 +- tests/test_headers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/deployment/kubernetes/benchmark.py b/example/deployment/kubernetes/benchmark.py index 600c47316..22ee4c5bd 100755 --- a/example/deployment/kubernetes/benchmark.py +++ b/example/deployment/kubernetes/benchmark.py @@ -11,7 +11,7 @@ async def run(client_id, messages): async with websockets.connect(URI) as websocket: for message_id in range(messages): - await websocket.send("{client_id}:{message_id}") + await websocket.send(f"{client_id}:{message_id}") await websocket.recv() diff --git a/tests/test_headers.py b/tests/test_headers.py index a2d51fc6a..4ebd8b90c 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -140,7 +140,7 @@ def test_parse_subprotocol_invalid_header(self): for header in [ # Truncated examples "", - ",\t," + ",\t,", # Wrong delimiter "foo; bar", ]: From ac45051ecb34c07407c4315a08729e9967be8e51 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 13:53:03 +0200 Subject: [PATCH 1053/1539] Fix typing in client connection initialization. --- src/websockets/legacy/client.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index fadc3efe8..29550b694 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -651,9 +651,8 @@ async def __await_impl_timeout__(self) -> WebSocketClientProtocol: async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): - transport, protocol = await self._create_connection() - protocol = cast(WebSocketClientProtocol, protocol) - + _transport, _protocol = await self._create_connection() + protocol = cast(WebSocketClientProtocol, _protocol) try: await protocol.handshake( self._wsuri, From ebb591811e9626272168ed1ee1cf653d0760d92a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 14:30:18 +0200 Subject: [PATCH 1054/1539] Move type: ignore comment. It needs to be in a new location to be recognized. --- src/websockets/legacy/protocol.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index d1d52bfac..9a7ad1d35 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -684,9 +684,7 @@ async def send( try: # message_chunk = anext(aiter_message) without anext # https://github.com/python/mypy/issues/5738 - message_chunk = await type(aiter_message).__anext__( # type: ignore - aiter_message - ) + message_chunk = await type(aiter_message).__anext__(aiter_message) # type: ignore # noqa except StopAsyncIteration: return opcode, data = prepare_data(message_chunk) From 670f7c5e8c934c29f46f24089a84dd65709c1d13 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 14:33:56 +0200 Subject: [PATCH 1055/1539] Rename chunks to fragments. --- src/websockets/legacy/protocol.py | 38 +++++++++++++++---------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 9a7ad1d35..b5cd64618 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -645,10 +645,10 @@ async def send( iter_message = iter(message) try: - message_chunk = next(iter_message) + fragment = next(iter_message) except StopIteration: return - opcode, data = prepare_data(message_chunk) + opcode, data = prepare_data(fragment) self._fragmented_message_waiter = asyncio.Future() try: @@ -656,8 +656,8 @@ async def send( await self.write_frame(False, opcode, data) # Other fragments. - for message_chunk in iter_message: - confirm_opcode, data = prepare_data(message_chunk) + for fragment in iter_message: + confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") await self.write_frame(False, OP_CONT, data) @@ -682,12 +682,12 @@ async def send( # https://github.com/python/mypy/issues/5738 aiter_message = type(message).__aiter__(message) # type: ignore try: - # message_chunk = anext(aiter_message) without anext + # fragment = anext(aiter_message) without anext # https://github.com/python/mypy/issues/5738 - message_chunk = await type(aiter_message).__anext__(aiter_message) # type: ignore # noqa + fragment = await type(aiter_message).__anext__(aiter_message) # type: ignore # noqa except StopAsyncIteration: return - opcode, data = prepare_data(message_chunk) + opcode, data = prepare_data(fragment) self._fragmented_message_waiter = asyncio.Future() try: @@ -698,8 +698,8 @@ async def send( # https://github.com/python/mypy/issues/5738 # coverage reports this code as not covered, but it is # exercised by tests - changing it breaks the tests! - async for message_chunk in aiter_message: # type: ignore # pragma: no cover # noqa - confirm_opcode, data = prepare_data(message_chunk) + async for fragment in aiter_message: # type: ignore # pragma: no cover # noqa + confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") await self.write_frame(False, OP_CONT, data) @@ -1028,7 +1028,7 @@ async def read_message(self) -> Optional[Data]: return frame.data.decode("utf-8") if text else frame.data # 5.4. Fragmentation - chunks: List[Data] = [] + fragments: List[Data] = [] max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") @@ -1036,14 +1036,14 @@ async def read_message(self) -> Optional[Data]: if max_size is None: def append(frame: Frame) -> None: - nonlocal chunks - chunks.append(decoder.decode(frame.data, frame.fin)) + nonlocal fragments + fragments.append(decoder.decode(frame.data, frame.fin)) else: def append(frame: Frame) -> None: - nonlocal chunks, max_size - chunks.append(decoder.decode(frame.data, frame.fin)) + nonlocal fragments, max_size + fragments.append(decoder.decode(frame.data, frame.fin)) assert isinstance(max_size, int) max_size -= len(frame.data) @@ -1051,14 +1051,14 @@ def append(frame: Frame) -> None: if max_size is None: def append(frame: Frame) -> None: - nonlocal chunks - chunks.append(frame.data) + nonlocal fragments + fragments.append(frame.data) else: def append(frame: Frame) -> None: - nonlocal chunks, max_size - chunks.append(frame.data) + nonlocal fragments, max_size + fragments.append(frame.data) assert isinstance(max_size, int) max_size -= len(frame.data) @@ -1072,7 +1072,7 @@ def append(frame: Frame) -> None: raise ProtocolError("unexpected opcode") append(frame) - return ("" if text else b"").join(chunks) + return ("" if text else b"").join(fragments) async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: """ From 50483798cdbdf0ab3edee5cd8d4a9b4069d7dfc4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 07:29:59 +0200 Subject: [PATCH 1056/1539] Bump cibuildwheel version. Fix #1169. --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index a6df67743..6f529cfdb 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -47,7 +47,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.2.2 + uses: pypa/cibuildwheel@v2.5.0 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From 3e6f0c474fbb89909988cdfdfa8dbee7ac9cb84d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 07:32:21 +0200 Subject: [PATCH 1057/1539] Use the latest version of each OS. --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 6f529cfdb..6933c941d 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -30,7 +30,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-20.04, windows-2019, macOS-10.15] + os: [ubuntu-latest, windows-latest, macOS-latest] steps: - name: Check out repository From ef8a4de1e1a97ad2ed15637b4abaab89c819a2d6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 08:06:39 +0200 Subject: [PATCH 1058/1539] Run a subset of tests on branches. --- .github/workflows/tests.yml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9412c0ea5..09ec750df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,7 +38,17 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ["3.7", "3.8", "3.9", "3.10", "pypy-3.9"] + python: + - "3.7" + - "3.8" + - "3.9" + - "3.10" + - "pypy-3.9" + is_main: + - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} + exclude: + - python: "pypy-3.9" + is_main: false steps: - name: Check out repository uses: actions/checkout@v2 From 8b603d824786de82330d21b89a109e720082440d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 08:26:51 +0200 Subject: [PATCH 1059/1539] Test on all PyPy versions. Refs #1169. --- .github/workflows/tests.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 09ec750df..03ce3aff9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,10 +43,16 @@ jobs: - "3.8" - "3.9" - "3.10" + - "pypy-3.7" + - "pypy-3.8" - "pypy-3.9" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: + - python: "pypy-3.7" + is_main: false + - python: "pypy-3.8" + is_main: false - python: "pypy-3.9" is_main: false steps: From 71dd4f881365a910972c6dc1326638e90eda5cf5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 18:51:04 +0200 Subject: [PATCH 1060/1539] Stop building wheels with every push to main. No need to heat the planet. It's working. Keep the possibility to do it manually. --- .github/workflows/wheels.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 6933c941d..6966b647f 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -2,10 +2,9 @@ name: Build wheels on: push: - branches: - - main tags: - '*' + workflow_dispatch: jobs: sdist: From 6de9572076500959d7c8fa1d5889acc1fa06edac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 08:27:28 +0200 Subject: [PATCH 1061/1539] Standardize on multiline list syntax. --- .github/workflows/wheels.yml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 6966b647f..0322055b1 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -29,8 +29,10 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest, macOS-latest] - + os: + - ubuntu-latest + - windows-latest + - macOS-latest steps: - name: Check out repository uses: actions/checkout@v2 @@ -58,7 +60,9 @@ jobs: upload_pypi: name: Upload to PyPI - needs: [sdist, wheels] + needs: + - sdist + - wheels runs-on: ubuntu-latest if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: From c4a4b6f45af607431c5707d56420bbaf471bbb6e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 5 May 2022 08:32:07 +0200 Subject: [PATCH 1062/1539] Remove all # type: ignore but one. Prefer cast() instead. The only remaining # type: ignore is for the legacy_recv behavior. --- src/websockets/legacy/framing.py | 4 ++-- src/websockets/legacy/protocol.py | 22 ++++++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index c4de7eb28..04cddc0e0 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import struct from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple @@ -163,7 +162,8 @@ def parse_close(data: bytes) -> Tuple[int, str]: UnicodeDecodeError: if the reason isn't valid UTF-8. """ - return dataclasses.astuple(Close.parse(data)) # type: ignore + close = Close.parse(data) + return close.code, close.reason def serialize_close(code: int, reason: str) -> bytes: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b5cd64618..651ff824d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -14,6 +14,7 @@ AsyncIterable, AsyncIterator, Awaitable, + Callable, Deque, Dict, Iterable, @@ -678,13 +679,19 @@ async def send( # Fragmented message -- asynchronous iterator elif isinstance(message, AsyncIterable): - # aiter_message = aiter(message) without aiter - # https://github.com/python/mypy/issues/5738 - aiter_message = type(message).__aiter__(message) # type: ignore + # Implement aiter_message = aiter(message) without aiter + # Work around https://github.com/python/mypy/issues/5738 + aiter_message = cast( + Callable[[AsyncIterable[Data]], AsyncIterator[Data]], + type(message).__aiter__, + )(message) try: - # fragment = anext(aiter_message) without anext - # https://github.com/python/mypy/issues/5738 - fragment = await type(aiter_message).__anext__(aiter_message) # type: ignore # noqa + # Implement fragment = anext(aiter_message) without anext + # Work around https://github.com/python/mypy/issues/5738 + fragment = await cast( + Callable[[AsyncIterator[Data]], Awaitable[Data]], + type(aiter_message).__anext__, + )(aiter_message) except StopAsyncIteration: return opcode, data = prepare_data(fragment) @@ -695,10 +702,9 @@ async def send( await self.write_frame(False, opcode, data) # Other fragments. - # https://github.com/python/mypy/issues/5738 # coverage reports this code as not covered, but it is # exercised by tests - changing it breaks the tests! - async for fragment in aiter_message: # type: ignore # pragma: no cover # noqa + async for fragment in aiter_message: # pragma: no cover confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") From 225a57c450a4064e04dcbcc50dc85dc7e9452628 Mon Sep 17 00:00:00 2001 From: Shoaib Date: Sat, 25 Jun 2022 03:33:42 +0500 Subject: [PATCH 1063/1539] Replaced unsafe reference to self.pings with a variable --- src/websockets/legacy/protocol.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 651ff824d..2d1391a01 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -846,11 +846,12 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: while data is None or data in self.pings: data = struct.pack("!I", random.getrandbits(32)) - self.pings[data] = self.loop.create_future() + ping_future = self.loop.create_future() + self.pings[data] = ping_future await self.write_frame(True, OP_PING, data) - return asyncio.shield(self.pings[data]) + return asyncio.shield(ping_future) async def pong(self, data: Data = b"") -> None: """ From 57a1325795cc7b9166a5b4599c34c4869d8b2ab6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Jul 2022 09:47:16 +0200 Subject: [PATCH 1064/1539] Update link to django-sesame documentation. --- docs/howto/django.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 34cce58e9..5bb2e296b 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -113,10 +113,10 @@ your settings module. The connection handler reads the first message received from the client, which is expected to contain a django-sesame token. Then it authenticates the user -with ``get_user()``, the API for `authentication outside views`_. If +with ``get_user()``, the API for `authentication outside a view`_. If authentication fails, it closes the connection and exits. -.. _authentication outside views: https://github.com/aaugustin/django-sesame#authentication-outside-views +.. _authentication outside a view: https://django-sesame.readthedocs.io/en/stable/howto.html#outside-a-view When we call an API that makes a database query such as ``get_user()``, we wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't From fad12d475e4c65aff9040d9e5f34a4924c470774 Mon Sep 17 00:00:00 2001 From: Tim Gates Date: Sun, 31 Jul 2022 13:17:18 +1000 Subject: [PATCH 1065/1539] docs: Fix a few typos There are small typos in: - src/websockets/legacy/protocol.py - src/websockets/legacy/server.py - tests/legacy/test_protocol.py - tests/legacy/utils.py Fixes: - Should read `suspect` rather than `supect`. - Should read `acknowledged` rather than `acknowleged`. - Should read `attempts` rather than `attemps`. - Should read `asynchronous` rather than `asychronous`. Signed-off-by: Tim Gates --- src/websockets/legacy/protocol.py | 6 +++--- src/websockets/legacy/server.py | 2 +- tests/legacy/test_protocol.py | 4 ++-- tests/legacy/utils.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 2d1391a01..04726033e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1303,7 +1303,7 @@ async def close_connection(self) -> None: if self.is_client and hasattr(self, "transfer_data_task"): if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. + # I suspect a bug in coverage. Ignore it for now. return # pragma: no cover if self.debug: self.logger.debug("! timed out waiting for TCP close") @@ -1323,7 +1323,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. + # I suspect a bug in coverage. Ignore it for now. return # pragma: no cover if self.debug: self.logger.debug("! timed out waiting for TCP close") @@ -1361,7 +1361,7 @@ async def close_transport(self) -> None: # connection_lost() is called quickly after aborting. # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. + # I suspect a bug in coverage. Ignore it for now. await self.wait_for_connection_lost() # pragma: no cover async def wait_for_connection_lost(self) -> bool: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3e51db1b7..6b4eccd02 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -153,7 +153,7 @@ async def handler(self) -> None: Handle the lifecycle of a WebSocket connection. Since this method doesn't have a caller able to handle exceptions, it - attemps to log relevant ones and guarantees that the TCP connection is + attempts to log relevant ones and guarantees that the TCP connection is closed before exiting. """ diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 1672ab1ed..ed6424694 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1188,7 +1188,7 @@ def test_keepalive_ping(self): def test_keepalive_ping_not_acknowledged_closes_connection(self): self.restart_protocol_with_keepalive_ping() - # Ping is sent at 3ms and not acknowleged. + # Ping is sent at 3ms and not acknowledged. self.loop.run_until_complete(asyncio.sleep(4 * MS)) (ping_1,) = tuple(self.protocol.pings) self.assertOneFrameSent(True, OP_PING, ping_1) @@ -1257,7 +1257,7 @@ def test_keepalive_ping_with_no_ping_interval(self): def test_keepalive_ping_with_no_ping_timeout(self): self.restart_protocol_with_keepalive_ping(ping_timeout=None) - # Ping is sent at 3ms and not acknowleged. + # Ping is sent at 3ms and not acknowledged. self.loop.run_until_complete(asyncio.sleep(4 * MS)) (ping_1,) = tuple(self.protocol.pings) self.assertOneFrameSent(True, OP_PING, ping_1) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 1fa2b53c8..fd5dfc294 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -18,7 +18,7 @@ def __init_subclass__(cls, **kwargs): """ Convert test coroutines to test functions. - This supports asychronous tests transparently. + This supports asynchronous tests transparently. """ super().__init_subclass__(**kwargs) From d84108b2bf2256187509276a7cb2a905f39af392 Mon Sep 17 00:00:00 2001 From: Dmitry Lavrentev Date: Wed, 10 Aug 2022 19:31:29 +0300 Subject: [PATCH 1066/1539] Added default None values for Optional fields in WebSocketURI for compatibility with Dataclass syntax. --- src/websockets/uri.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index fff0c3806..385090f66 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -33,8 +33,8 @@ class WebSocketURI: port: int path: str query: str - username: Optional[str] - password: Optional[str] + username: Optional[str] = None + password: Optional[str] = None @property def resource_name(self) -> str: From 5d2a04ae65b48c8d02e37318bc6e8d43be9ae18e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 09:05:34 +0200 Subject: [PATCH 1067/1539] Add usage docs for start_serving and serve_forever. The only example for Python I found was https://bugs.python.org/issue32662: async def main(): srv = await asyncio.start_server(...) async with srv: await srv.serve_forever() asyncio.run(main()) It looks pretty bad to me: it starts the server thrice and stops it twice, relying on idempotency to avoid issues. I dislike this style of "pile it up to be sure that it works" so I'm showing a subset of possibilities only. Ref #1197. --- src/websockets/legacy/server.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 6b4eccd02..2ee469825 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -833,6 +833,13 @@ async def start_serving(self) -> None: """ See :meth:`asyncio.Server.start_serving`. + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + """ await self.server.start_serving() # pragma: no cover @@ -840,6 +847,17 @@ async def serve_forever(self) -> None: """ See :meth:`asyncio.Server.serve_forever`. + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + """ await self.server.serve_forever() # pragma: no cover From c7c8ecc12323855c37f3c77874ade5fdd082890f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 09:22:26 +0200 Subject: [PATCH 1068/1539] Escape square brackets. They're special characters in some shells. --- docs/howto/autoreload.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst index edd87d0fd..fc736a591 100644 --- a/docs/howto/autoreload.rst +++ b/docs/howto/autoreload.rst @@ -16,7 +16,7 @@ Install watchdog_ with the ``watchmedo`` shell utility: .. code-block:: console - $ pip install watchdog[watchmedo] + $ pip install 'watchdog[watchmedo]' .. _watchdog: https://pypi.org/project/watchdog/ From 7980c00812c683ddbfbf244f6ec0ca6faca8c023 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 09:27:58 +0200 Subject: [PATCH 1069/1539] Standardize title style. --- docs/howto/extensions.rst | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index 2baead3f0..3c8a7d72a 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -1,5 +1,5 @@ -Writing an extension -==================== +Write an extension +================== .. currentmodule:: websockets.extensions @@ -28,5 +28,3 @@ As a consequence, writing an extension requires implementing several classes: websockets provides base classes for extension factories and extensions. See :class:`ClientExtensionFactory`, :class:`ServerExtensionFactory`, and :class:`Extension` for details. - - From 28e1c7d60214f85d058d4f63629dc94a7e686dc4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 11:58:57 +0200 Subject: [PATCH 1070/1539] Move FAQ to top-level. This should make it more discoverable. Also: * split FAQ in multiple pages * move quick start to the howto section --- docs/faq/asyncio.rst | 73 ++++ docs/faq/client.rst | 83 ++++ docs/faq/common.rst | 151 +++++++ docs/faq/index.rst | 21 + docs/faq/misc.rst | 26 ++ docs/faq/server.rst | 279 ++++++++++++ docs/howto/faq.rst | 621 --------------------------- docs/howto/index.rst | 8 +- docs/{intro => howto}/quickstart.rst | 0 docs/index.rst | 3 +- docs/intro/index.rst | 6 +- docs/reference/index.rst | 2 +- 12 files changed, 644 insertions(+), 629 deletions(-) create mode 100644 docs/faq/asyncio.rst create mode 100644 docs/faq/client.rst create mode 100644 docs/faq/common.rst create mode 100644 docs/faq/index.rst create mode 100644 docs/faq/misc.rst create mode 100644 docs/faq/server.rst delete mode 100644 docs/howto/faq.rst rename docs/{intro => howto}/quickstart.rst (100%) diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst new file mode 100644 index 000000000..7db43e1be --- /dev/null +++ b/docs/faq/asyncio.rst @@ -0,0 +1,73 @@ +asyncio usage +============= + +.. currentmodule:: websockets + +How do I run two coroutines in parallel? +---------------------------------------- + +You must start two tasks, which the event loop will run concurrently. You can +achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. + +Keep track of the tasks and make sure they terminate or you cancel them when +the connection terminates. + +Why does my program never receive any messages? +----------------------------------------------- + +Your program runs a coroutine that never yields control to the event loop. The +coroutine that receives messages never gets a chance to run. + +Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough +to yield control. Awaiting a coroutine may yield control, but there's no +guarantee that it will. + +For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields +control when send buffers are full, which never happens in most practical +cases. + +If you run a loop that contains only synchronous operations and +a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield +control explicitly with :func:`asyncio.sleep`:: + + async def producer(websocket): + message = generate_next_message() + await websocket.send(message) + await asyncio.sleep(0) # yield control to the event loop + +:func:`asyncio.sleep` always suspends the current task, allowing other tasks +to run. This behavior is documented precisely because it isn't expected from +every coroutine. + +See `issue 867`_. + +.. _issue 867: https://github.com/aaugustin/websockets/issues/867 + +Why am I having problems with threads? +-------------------------------------- + +You shouldn't use threads. Use tasks instead. + +Indeed, when you chose websockets, you chose :mod:`asyncio` as the primary +framework to handle concurrency. This choice is mutually exclusive with +:mod:`threading`. + +If you believe that you need to run websockets in a thread and some logic in +another thread, you should run that logic in a :class:`~asyncio.Task` instead. + +If you believe that you cannot run that logic in the same event loop because it +will block websockets, :meth:`~asyncio.loop.run_in_executor` may help. + +This question is really about :mod:`asyncio`. Please review the advice about +:ref:`asyncio-multithreading` in the Python documentation. + +Why does my simple program misbehave mysteriously? +-------------------------------------------------- + +You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which +blocks the event loop and prevents asyncio from operating normally. + +This may lead to messages getting send but not received, to connection +timeouts, and to unexpected results of shotgun debugging e.g. adding an +unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +makes the program functional. diff --git a/docs/faq/client.rst b/docs/faq/client.rst new file mode 100644 index 000000000..5b39cf1ec --- /dev/null +++ b/docs/faq/client.rst @@ -0,0 +1,83 @@ +Client +====== + +.. currentmodule:: websockets + +Why does the client close the connection prematurely? +----------------------------------------------------- + +You're exiting the context manager prematurely. Wait for the work to be +finished before exiting. + +For example, if your code has a structure similar to:: + + async with connect(...) as websocket: + asyncio.create_task(do_some_work()) + +change it to:: + + async with connect(...) as websocket: + await do_some_work() + +How do I access HTTP headers? +----------------------------- + +Once the connection is established, HTTP headers are available in +:attr:`~client.WebSocketClientProtocol.request_headers` and +:attr:`~client.WebSocketClientProtocol.response_headers`. + +How do I set HTTP headers? +-------------------------- + +To set the ``Origin``, ``Sec-WebSocket-Extensions``, or +``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the +``origin``, ``extensions``, or ``subprotocols`` arguments of +:func:`~client.connect`. + +To set other HTTP headers, for example the ``Authorization`` header, use the +``extra_headers`` argument:: + + async with connect(..., extra_headers={"Authorization": ...}) as websocket: + ... + +How do I close a connection? +---------------------------- + +The easiest is to use :func:`~client.connect` as a context manager:: + + async with connect(...) as websocket: + ... + +The connection is closed when exiting the context manager. + +How do I reconnect when the connection drops? +--------------------------------------------- + +Use :func:`~client.connect` as an asynchronous iterator:: + + async for websocket in websockets.connect(...): + try: + ... + except websockets.ConnectionClosed: + continue + +Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions +will break out of the loop. + +How do I stop a client that is processing messages in a loop? +------------------------------------------------------------- + +You can close the connection. + +Here's an example that terminates cleanly when it receives SIGTERM on Unix: + +.. literalinclude:: ../../example/shutdown_client.py + :emphasize-lines: 10-13 + +How do I disable TLS/SSL certificate verification? +-------------------------------------------------- + +Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. + +:func:`~client.connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection`. diff --git a/docs/faq/common.rst b/docs/faq/common.rst new file mode 100644 index 000000000..dff64f67c --- /dev/null +++ b/docs/faq/common.rst @@ -0,0 +1,151 @@ +Both sides +========== + +.. currentmodule:: websockets + +What does ``ConnectionClosedError: no close frame received or sent`` mean? +-------------------------------------------------------------------------- + +If you're seeing this traceback in the logs of a server: + +.. code-block:: pytb + + connection handler failed + Traceback (most recent call last): + ... + asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: no close frame received or sent + +or if a client crashes with this traceback: + +.. code-block:: pytb + + Traceback (most recent call last): + ... + ConnectionResetError: [Errno 54] Connection reset by peer + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: no close frame received or sent + +it means that the TCP connection was lost. As a consequence, the WebSocket +connection was closed without receiving and sending a close frame, which is +abnormal. + +You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it +from being logged. + +There are several reasons why long-lived connections may be lost: + +* End-user devices tend to lose network connectivity often and unpredictably + because they can move out of wireless network coverage, get unplugged from + a wired network, enter airplane mode, be put to sleep, etc. +* HTTP load balancers or proxies that aren't configured for long-lived + connections may terminate connections after a short amount of time, usually + 30 seconds, despite websockets' keepalive mechanism. + +If you're facing a reproducible issue, :ref:`enable debug logs ` to +see when and how connections are closed. + +What does ``ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received`` mean? +----------------------------------------------------------------------------------------------------------------------- + +If you're seeing this traceback in the logs of a server: + +.. code-block:: pytb + + connection handler failed + Traceback (most recent call last): + ... + asyncio.exceptions.CancelledError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + +or if a client crashes with this traceback: + +.. code-block:: pytb + + Traceback (most recent call last): + ... + asyncio.exceptions.CancelledError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + +it means that the WebSocket connection suffered from excessive latency and was +closed after reaching the timeout of websockets' keepalive mechanism. + +You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it +from being logged. + +There are two main reasons why latency may increase: + +* Poor network connectivity. +* More traffic than the recipient can handle. + +See the discussion of :doc:`timeouts <../topics/timeouts>` for details. + +If websockets' default timeout of 20 seconds is too short for your use case, +you can adjust it with the ``ping_timeout`` argument. + +How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? +-------------------------------------------------------------------------------- + +Use :func:`~asyncio.wait_for`:: + + await asyncio.wait_for(websocket.recv(), timeout=10) + +This technique works for most APIs, except for asynchronous context managers. +See `issue 574`_. + +.. _issue 574: https://github.com/aaugustin/websockets/issues/574 + +How can I pass arguments to a custom protocol subclass? +------------------------------------------------------- + +You can bind additional arguments to the protocol factory with +:func:`functools.partial`:: + + import asyncio + import functools + import websockets + + class MyServerProtocol(websockets.WebSocketServerProtocol): + def __init__(self, extra_argument, *args, **kwargs): + super().__init__(*args, **kwargs) + # do something with extra_argument + + create_protocol = functools.partial(MyServerProtocol, extra_argument='spam') + start_server = websockets.serve(..., create_protocol=create_protocol) + +This example was for a server. The same pattern applies on a client. + +How do I keep idle connections open? +------------------------------------ + +websockets sends pings at 20 seconds intervals to keep the connection open. + +It closes the connection if it doesn't get a pong within 20 seconds. + +You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. + +See :doc:`../topics/timeouts` for details. + +How do I respond to pings? +-------------------------- + +Don't bother; websockets takes care of responding to pings with pongs. diff --git a/docs/faq/index.rst b/docs/faq/index.rst new file mode 100644 index 000000000..9d5b0d538 --- /dev/null +++ b/docs/faq/index.rst @@ -0,0 +1,21 @@ +Frequently asked questions +========================== + +.. currentmodule:: websockets + +.. admonition:: Many questions asked in websockets' issue tracker are really + about :mod:`asyncio`. + :class: seealso + + Python's documentation about `developing with asyncio`_ is a good + complement. + + .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html + +.. toctree:: + + server + client + common + asyncio + misc diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst new file mode 100644 index 000000000..3606937b8 --- /dev/null +++ b/docs/faq/misc.rst @@ -0,0 +1,26 @@ +Miscellaneous +============= + +.. currentmodule:: websockets + +Can I use websockets without ``async`` and ``await``? +..................................................... + +No, there is no convenient way to do this. You should use another library. + +Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? +............................................................................ + +No, there aren't. + +websockets provides high-level, coroutine-based APIs. Compared to callbacks, +coroutines make it easier to manage control flow in concurrent code. + +If you prefer callback-based APIs, you should use another library. + +Why do I get the error: ``module 'websockets' has no attribute '...'``? +....................................................................... + +Often, this is because you created a script called ``websockets.py`` in your +current working directory. Then ``import websockets`` imports this module +instead of the websockets library. diff --git a/docs/faq/server.rst b/docs/faq/server.rst new file mode 100644 index 000000000..a83f879b9 --- /dev/null +++ b/docs/faq/server.rst @@ -0,0 +1,279 @@ +Server +====== + +.. currentmodule:: websockets + +Why does the server close the connection prematurely? +----------------------------------------------------- + +Your connection handler exits prematurely. Wait for the work to be finished +before returning. + +For example, if your handler has a structure similar to:: + + async def handler(websocket): + asyncio.create_task(do_some_work()) + +change it to:: + + async def handler(websocket): + await do_some_work() + +Why does the server close the connection after one message? +----------------------------------------------------------- + +Your connection handler exits after processing one message. Write a loop to +process multiple messages. + +For example, if your handler looks like this:: + + async def handler(websocket): + print(websocket.recv()) + +change it like this:: + + async def handler(websocket): + async for message in websocket: + print(message) + +*Don't feel bad if this happens to you — it's the most common question in +websockets' issue tracker :-)* + +Why can only one client connect at a time? +------------------------------------------ + +Your connection handler blocks the event loop. Look for blocking calls. +Any call that may take some time must be asynchronous. + +For example, if you have:: + + async def handler(websocket): + time.sleep(1) + +change it to:: + + async def handler(websocket): + await asyncio.sleep(1) + +This is part of learning asyncio. It isn't specific to websockets. + +See also Python's documentation about `running blocking code`_. + +.. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code + +.. _send-message-to-all-users: + +How do I send a message to all users? +------------------------------------- + +Record all connections in a global variable:: + + CONNECTIONS = set() + + async def handler(websocket): + CONNECTIONS.add(websocket) + try: + await websocket.wait_closed() + finally: + CONNECTIONS.remove(websocket) + +Then, call :func:`~websockets.broadcast`:: + + import websockets + + def message_all(message): + websockets.broadcast(CONNECTIONS, message) + +If you're running multiple server processes, make sure you call ``message_all`` +in each process. + +.. _send-message-to-single-user: + +How do I send a message to a single user? +----------------------------------------- + +Record connections in a global variable, keyed by user identifier:: + + CONNECTIONS = {} + + async def handler(websocket): + user_id = ... # identify user in your app's context + CONNECTIONS[user_id] = websocket + try: + await websocket.wait_closed() + finally: + del CONNECTIONS[user_id] + +Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: + + async def message_user(user_id, message): + websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected + await websocket.send(message) # may raise websockets.ConnectionClosed + +Add error handling according to the behavior you want if the user disconnected +before the message could be sent. + +This example supports only one connection per user. To support concurrent +connects by the same user, you can change ``CONNECTIONS`` to store a set of +connections for each user. + +If you're running multiple server processes, call ``message_user`` in each +process. The process managing the user's connection sends the message; other +processes do nothing. + +When you reach a scale where server processes cannot keep up with the stream of +all messages, you need a better architecture. For example, you could deploy an +external publish / subscribe system such as Redis_. Server processes would +subscribe their clients. Then, they would receive messages only for the +connections that they're managing. + +.. _Redis: https://redis.io/ + +How do I send a message to a channel, a topic, or some users? +------------------------------------------------------------- + +websockets doesn't provide built-in publish / subscribe functionality. + +Record connections in a global variable, keyed by user identifier, as shown in +:ref:`How do I send a message to a single user?` + +Then, build the set of recipients and broadcast the message to them, as shown in +:ref:`How do I send a message to all users?` + +:doc:`../howto/django` contains a complete implementation of this pattern. + +Again, as you scale, you may reach the performance limits of a basic in-process +implementation. You may need an external publish / subscribe system like Redis_. + +.. _Redis: https://redis.io/ + +How do I pass arguments to the connection handler? +-------------------------------------------------- + +You can bind additional arguments to the connection handler with +:func:`functools.partial`:: + + import asyncio + import functools + import websockets + + async def handler(websocket, extra_argument): + ... + + bound_handler = functools.partial(handler, extra_argument='spam') + start_server = websockets.serve(bound_handler, ...) + +Another way to achieve this result is to define the ``handler`` coroutine in +a scope where the ``extra_argument`` variable exists instead of injecting it +through an argument. + +How do I access the request path? +--------------------------------- + +It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. + +You may route a connection to different handlers depending on the request path:: + + async def handler(websocket): + if websocket.path == "/blue": + await blue_handler(websocket) + elif websocket.path == "/green": + await green_handler(websocket) + else: + # No handler for this path; close the connection. + return + +You may also route the connection based on the first message received from the +client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to +authenticate the connection before routing it, this is usually more convenient. + +Generally speaking, there is far less emphasis on the request path in WebSocket +servers than in HTTP servers. When a WebSockt server provides a single endpoint, +it may ignore the request path entirely. + +How do I access HTTP headers? +----------------------------- + +To access HTTP headers during the WebSocket handshake, you can override +:attr:`~server.WebSocketServerProtocol.process_request`:: + + async def process_request(self, path, request_headers): + authorization = request_headers["Authorization"] + +Once the connection is established, HTTP headers are available in +:attr:`~server.WebSocketServerProtocol.request_headers` and +:attr:`~server.WebSocketServerProtocol.response_headers`:: + + async def handler(websocket): + authorization = websocket.request_headers["Authorization"] + +How do I set HTTP headers? +-------------------------- + +To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in +the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` +arguments of :func:`~server.serve`. + +To set other HTTP headers, use the ``extra_headers`` argument. + +How do I get the IP address of the client? +------------------------------------------ + +It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: + + async def handler(websocket): + remote_ip = websocket.remote_address[0] + +How do I set the IP addresses my server listens on? +--------------------------------------------------- + +Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. + +:func:`~server.serve` accepts the same arguments as +:meth:`~asyncio.loop.create_server`. + +What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? +-------------------------------------------------------------------------------------------------------------------------- + +You are calling :func:`~server.serve` without a ``host`` argument in a context +where IPv6 isn't available. + +To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. + +Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. + +How do I close a connection? +---------------------------- + +websockets takes care of closing the connection when the handler exits. + +How do I stop a server? +----------------------- + +Exit the :func:`~server.serve` context manager. + +Here's an example that terminates cleanly when it receives SIGTERM on Unix: + +.. literalinclude:: ../../example/shutdown_server.py + :emphasize-lines: 12-15,18 + +How do I run HTTP and WebSocket servers on the same port? +--------------------------------------------------------- + +You don't. + +HTTP and WebSocket have widely different operational characteristics. Running +them with the same server becomes inconvenient when you scale. + +Providing a HTTP server is out of scope for websockets. It only aims at +providing a WebSocket server. + +There's limited support for returning HTTP responses with the +:attr:`~server.WebSocketServerProtocol.process_request` hook. + +If you need more, pick a HTTP server and run it separately. + +Alternatively, pick a HTTP framework that builds on top of ``websockets`` to +support WebSocket connections, like Sanic_. + +.. _Sanic: https://sanicframework.org/en/ diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst deleted file mode 100644 index a103f677c..000000000 --- a/docs/howto/faq.rst +++ /dev/null @@ -1,621 +0,0 @@ -FAQ -=== - -.. currentmodule:: websockets - -.. admonition:: Many questions asked in websockets' issue tracker are really - about :mod:`asyncio`. - :class: seealso - - Python's documentation about `developing with asyncio`_ is a good - complement. - - .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html - -Server side ------------ - -Why does the server close the connection prematurely? -..................................................... - -Your connection handler exits prematurely. Wait for the work to be finished -before returning. - -For example, if your handler has a structure similar to:: - - async def handler(websocket): - asyncio.create_task(do_some_work()) - -change it to:: - - async def handler(websocket): - await do_some_work() - -Why does the server close the connection after one message? -........................................................... - -Your connection handler exits after processing one message. Write a loop to -process multiple messages. - -For example, if your handler looks like this:: - - async def handler(websocket): - print(websocket.recv()) - -change it like this:: - - async def handler(websocket): - async for message in websocket: - print(message) - -*Don't feel bad if this happens to you — it's the most common question in -websockets' issue tracker :-)* - -Why can only one client connect at a time? -.......................................... - -Your connection handler blocks the event loop. Look for blocking calls. -Any call that may take some time must be asynchronous. - -For example, if you have:: - - async def handler(websocket): - time.sleep(1) - -change it to:: - - async def handler(websocket): - await asyncio.sleep(1) - -This is part of learning asyncio. It isn't specific to websockets. - -See also Python's documentation about `running blocking code`_. - -.. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code - -.. _send-message-to-all-users: - -How do I send a message to all users? -..................................... - -Record all connections in a global variable:: - - CONNECTIONS = set() - - async def handler(websocket): - CONNECTIONS.add(websocket) - try: - await websocket.wait_closed() - finally: - CONNECTIONS.remove(websocket) - -Then, call :func:`~websockets.broadcast`:: - - import websockets - - def message_all(message): - websockets.broadcast(CONNECTIONS, message) - -If you're running multiple server processes, make sure you call ``message_all`` -in each process. - -.. _send-message-to-single-user: - -How do I send a message to a single user? -......................................... - -Record connections in a global variable, keyed by user identifier:: - - CONNECTIONS = {} - - async def handler(websocket): - user_id = ... # identify user in your app's context - CONNECTIONS[user_id] = websocket - try: - await websocket.wait_closed() - finally: - del CONNECTIONS[user_id] - -Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: - - async def message_user(user_id, message): - websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected - await websocket.send(message) # may raise websockets.ConnectionClosed - -Add error handling according to the behavior you want if the user disconnected -before the message could be sent. - -This example supports only one connection per user. To support concurrent -connects by the same user, you can change ``CONNECTIONS`` to store a set of -connections for each user. - -If you're running multiple server processes, call ``message_user`` in each -process. The process managing the user's connection sends the message; other -processes do nothing. - -When you reach a scale where server processes cannot keep up with the stream of -all messages, you need a better architecture. For example, you could deploy an -external publish / subscribe system such as Redis_. Server processes would -subscribe their clients. Then, they would receive messages only for the -connections that they're managing. - -.. _Redis: https://redis.io/ - -How do I send a message to a channel, a topic, or some users? -............................................................. - -websockets doesn't provide built-in publish / subscribe functionality. - -Record connections in a global variable, keyed by user identifier, as shown in -:ref:`How do I send a message to a single user?` - -Then, build the set of recipients and broadcast the message to them, as shown in -:ref:`How do I send a message to all users?` - -:doc:`django` contains a complete implementation of this pattern. - -Again, as you scale, you may reach the performance limits of a basic in-process -implementation. You may need an external publish / subscribe system like Redis_. - -.. _Redis: https://redis.io/ - -How do I pass arguments to the connection handler? -.................................................. - -You can bind additional arguments to the connection handler with -:func:`functools.partial`:: - - import asyncio - import functools - import websockets - - async def handler(websocket, extra_argument): - ... - - bound_handler = functools.partial(handler, extra_argument='spam') - start_server = websockets.serve(bound_handler, ...) - -Another way to achieve this result is to define the ``handler`` coroutine in -a scope where the ``extra_argument`` variable exists instead of injecting it -through an argument. - -How do I access the request path? -................................. - -It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. - -You may route a connection to different handlers depending on the request path:: - - async def handler(websocket): - if websocket.path == "/blue": - await blue_handler(websocket) - elif websocket.path == "/green": - await green_handler(websocket) - else: - # No handler for this path; close the connection. - return - -You may also route the connection based on the first message received from the -client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to -authenticate the connection before routing it, this is usually more convenient. - -Generally speaking, there is far less emphasis on the request path in WebSocket -servers than in HTTP servers. When a WebSockt server provides a single endpoint, -it may ignore the request path entirely. - -How do I access HTTP headers? -............................. - -To access HTTP headers during the WebSocket handshake, you can override -:attr:`~server.WebSocketServerProtocol.process_request`:: - - async def process_request(self, path, request_headers): - authorization = request_headers["Authorization"] - -Once the connection is established, HTTP headers are available in -:attr:`~server.WebSocketServerProtocol.request_headers` and -:attr:`~server.WebSocketServerProtocol.response_headers`:: - - async def handler(websocket): - authorization = websocket.request_headers["Authorization"] - -How do I set HTTP headers? -.......................... - -To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in -the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` -arguments of :func:`~server.serve`. - -To set other HTTP headers, use the ``extra_headers`` argument. - -How do I get the IP address of the client? -.......................................... - -It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: - - async def handler(websocket): - remote_ip = websocket.remote_address[0] - -How do I set the IP addresses my server listens on? -................................................... - -Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. - -:func:`~server.serve` accepts the same arguments as -:meth:`~asyncio.loop.create_server`. - -What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? -.......................................................................................................................... - -You are calling :func:`~server.serve` without a ``host`` argument in a context -where IPv6 isn't available. - -To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. - -Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. - -How do I close a connection? -............................ - -websockets takes care of closing the connection when the handler exits. - -How do I stop a server? -....................... - -Exit the :func:`~server.serve` context manager. - -Here's an example that terminates cleanly when it receives SIGTERM on Unix: - -.. literalinclude:: ../../example/shutdown_server.py - :emphasize-lines: 12-15,18 - - -How do I run HTTP and WebSocket servers on the same port? -......................................................... - -You don't. - -HTTP and WebSocket have widely different operational characteristics. Running -them with the same server becomes inconvenient when you scale. - -Providing a HTTP server is out of scope for websockets. It only aims at -providing a WebSocket server. - -There's limited support for returning HTTP responses with the -:attr:`~server.WebSocketServerProtocol.process_request` hook. - -If you need more, pick a HTTP server and run it separately. - -Alternatively, pick a HTTP framework that builds on top of ``websockets`` to -support WebSocket connections, like Sanic_. - -.. _Sanic: https://sanicframework.org/en/ - -Client side ------------ - -Why does the client close the connection prematurely? -..................................................... - -You're exiting the context manager prematurely. Wait for the work to be -finished before exiting. - -For example, if your code has a structure similar to:: - - async with connect(...) as websocket: - asyncio.create_task(do_some_work()) - -change it to:: - - async with connect(...) as websocket: - await do_some_work() - -How do I access HTTP headers? -............................. - -Once the connection is established, HTTP headers are available in -:attr:`~client.WebSocketClientProtocol.request_headers` and -:attr:`~client.WebSocketClientProtocol.response_headers`. - -How do I set HTTP headers? -.......................... - -To set the ``Origin``, ``Sec-WebSocket-Extensions``, or -``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the -``origin``, ``extensions``, or ``subprotocols`` arguments of -:func:`~client.connect`. - -To set other HTTP headers, for example the ``Authorization`` header, use the -``extra_headers`` argument:: - - async with connect(..., extra_headers={"Authorization": ...}) as websocket: - ... - -How do I close a connection? -............................ - -The easiest is to use :func:`~client.connect` as a context manager:: - - async with connect(...) as websocket: - ... - -The connection is closed when exiting the context manager. - -How do I reconnect when the connection drops? -............................................. - -Use :func:`~client.connect` as an asynchronous iterator:: - - async for websocket in websockets.connect(...): - try: - ... - except websockets.ConnectionClosed: - continue - -Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions -will break out of the loop. - -How do I stop a client that is processing messages in a loop? -............................................................. - -You can close the connection. - -Here's an example that terminates cleanly when it receives SIGTERM on Unix: - -.. literalinclude:: ../../example/shutdown_client.py - :emphasize-lines: 10-13 - -How do I disable TLS/SSL certificate verification? -.................................................. - -Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. - -:func:`~client.connect` accepts the same arguments as -:meth:`~asyncio.loop.create_connection`. - -asyncio usage -------------- - -How do I run two coroutines in parallel? -........................................ - -You must start two tasks, which the event loop will run concurrently. You can -achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. - -Keep track of the tasks and make sure they terminate or you cancel them when -the connection terminates. - -Why does my program never receive any messages? -............................................... - -Your program runs a coroutine that never yields control to the event loop. The -coroutine that receives messages never gets a chance to run. - -Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough -to yield control. Awaiting a coroutine may yield control, but there's no -guarantee that it will. - -For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields -control when send buffers are full, which never happens in most practical -cases. - -If you run a loop that contains only synchronous operations and -a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield -control explicitly with :func:`asyncio.sleep`:: - - async def producer(websocket): - message = generate_next_message() - await websocket.send(message) - await asyncio.sleep(0) # yield control to the event loop - -:func:`asyncio.sleep` always suspends the current task, allowing other tasks -to run. This behavior is documented precisely because it isn't expected from -every coroutine. - -See `issue 867`_. - -.. _issue 867: https://github.com/aaugustin/websockets/issues/867 - -Why does my simple program misbehave mysteriously? -.................................................. - -You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which -blocks the event loop and prevents asyncio from operating normally. - -This may lead to messages getting send but not received, to connection -timeouts, and to unexpected results of shotgun debugging e.g. adding an -unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` -makes the program functional. - -Both sides ----------- - -What does ``ConnectionClosedError: no close frame received or sent`` mean? -.......................................................................... - -If you're seeing this traceback in the logs of a server: - -.. code-block:: pytb - - connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - websockets.exceptions.ConnectionClosedError: no close frame received or sent - -or if a client crashes with this traceback: - -.. code-block:: pytb - - Traceback (most recent call last): - ... - ConnectionResetError: [Errno 54] Connection reset by peer - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - websockets.exceptions.ConnectionClosedError: no close frame received or sent - -it means that the TCP connection was lost. As a consequence, the WebSocket -connection was closed without receiving and sending a close frame, which is -abnormal. - -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. - -There are several reasons why long-lived connections may be lost: - -* End-user devices tend to lose network connectivity often and unpredictably - because they can move out of wireless network coverage, get unplugged from - a wired network, enter airplane mode, be put to sleep, etc. -* HTTP load balancers or proxies that aren't configured for long-lived - connections may terminate connections after a short amount of time, usually - 30 seconds, despite websockets' keepalive mechanism. - -If you're facing a reproducible issue, :ref:`enable debug logs ` to -see when and how connections are closed. - -What does ``ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received`` mean? -....................................................................................................................... - -If you're seeing this traceback in the logs of a server: - -.. code-block:: pytb - - connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received - -or if a client crashes with this traceback: - -.. code-block:: pytb - - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received - -it means that the WebSocket connection suffered from excessive latency and was -closed after reaching the timeout of websockets' keepalive mechanism. - -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. - -There are two main reasons why latency may increase: - -* Poor network connectivity. -* More traffic than the recipient can handle. - -See the discussion of :doc:`timeouts <../topics/timeouts>` for details. - -If websockets' default timeout of 20 seconds is too short for your use case, -you can adjust it with the ``ping_timeout`` argument. - -How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? -................................................................................ - -Use :func:`~asyncio.wait_for`:: - - await asyncio.wait_for(websocket.recv(), timeout=10) - -This technique works for most APIs, except for asynchronous context managers. -See `issue 574`_. - -.. _issue 574: https://github.com/aaugustin/websockets/issues/574 - -How can I pass arguments to a custom protocol subclass? -....................................................... - -You can bind additional arguments to the protocol factory with -:func:`functools.partial`:: - - import asyncio - import functools - import websockets - - class MyServerProtocol(websockets.WebSocketServerProtocol): - def __init__(self, extra_argument, *args, **kwargs): - super().__init__(*args, **kwargs) - # do something with extra_argument - - create_protocol = functools.partial(MyServerProtocol, extra_argument='spam') - start_server = websockets.serve(..., create_protocol=create_protocol) - -This example was for a server. The same pattern applies on a client. - -How do I keep idle connections open? -.................................... - -websockets sends pings at 20 seconds intervals to keep the connection open. - -It closes the connection if it doesn't get a pong within 20 seconds. - -You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. - -See :doc:`../topics/timeouts` for details. - -How do I respond to pings? -.......................... - -Don't bother; websockets takes care of responding to pings with pongs. - -Miscellaneous -------------- - -Can I use websockets without ``async`` and ``await``? -..................................................... - -No, there is no convenient way to do this. You should use another library. - -Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? -............................................................................ - -No, there aren't. - -websockets provides high-level, coroutine-based APIs. Compared to callbacks, -coroutines make it easier to manage control flow in concurrent code. - -If you prefer callback-based APIs, you should use another library. - -Why do I get the error: ``module 'websockets' has no attribute '...'``? -....................................................................... - -Often, this is because you created a script called ``websockets.py`` in your -current working directory. Then ``import websockets`` imports this module -instead of the websockets library. - -Why am I having problems with threads? -...................................... - -You shouldn't use threads. Use tasks instead. - -Indeed, when you chose websockets, you chose :mod:`asyncio` as the primary -framework to handle concurrency. This choice is mutually exclusive with -:mod:`threading`. - -If you believe that you need to run websockets in a thread and some logic in -another thread, you should run that logic in a :class:`~asyncio.Task` instead. - -If you believe that you cannot run that logic in the same event loop because it -will block websockets, :meth:`~asyncio.loop.run_in_executor` may help. - -This question is really about :mod:`asyncio`. Please review the advice about -:ref:`asyncio-multithreading` in the Python documentation. diff --git a/docs/howto/index.rst b/docs/howto/index.rst index d399cebc2..dafb72391 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -1,12 +1,18 @@ How-to guides ============= +In a hurry? Check out these examples. + +.. toctree:: + :titlesonly: + + quickstart + If you're stuck, perhaps you'll find the answer here. .. toctree:: :titlesonly: - faq cheatsheet patterns autoreload diff --git a/docs/intro/quickstart.rst b/docs/howto/quickstart.rst similarity index 100% rename from docs/intro/quickstart.rst rename to docs/howto/quickstart.rst diff --git a/docs/index.rst b/docs/index.rst index 07835a81c..00a9b5999 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,13 +52,14 @@ Also, websockets provides an interactive client: < Hello world! Connection closed: 1000 (OK). -Do you like it? Let's dive in! +Do you like it? :doc:`Let's dive in! ` .. toctree:: :hidden: intro/index howto/index + faq/index reference/index topics/index project/index diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 2c66dea9a..fe4e704d6 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -43,8 +43,4 @@ Learn how to build an real-time web application with websockets. In a hurry? ----------- -Check out these examples. - -.. toctree:: - - quickstart +Look at the :doc:`quick start guide <../howto/quickstart>`. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index a5ee57843..5f51a1c1c 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -43,8 +43,8 @@ and server connections, common methods are documented in a "Both sides" page. .. toctree:: :titlesonly: - client server + client common utilities exceptions From 8e096fc5a2a49b3c42fb2543547867cab11297fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 12:39:33 +0200 Subject: [PATCH 1071/1539] Add FAQ on benchmarking. Fix #1189. --- docs/faq/misc.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 3606937b8..9ef07ef2d 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -24,3 +24,16 @@ Why do I get the error: ``module 'websockets' has no attribute '...'``? Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. + +Why is websockets slower than another Python library in my benchmark? +..................................................................... + +Not all libraries are as feature-complete as websockets. For a fair benchmark, +you should disable features that the other library doesn't provide. Typically, +you may need to disable: + +* Compression: set ``compression=None`` +* Keepalive: set ``ping_interval=None`` +* UTF-8 decoding: send ``bytes`` rather than ``str`` + +If websockets is still slower than another Python library, please file a bug. From 2a07325cecf8be4732cd991ead0314c530dd7cdd Mon Sep 17 00:00:00 2001 From: Irfanuddin Date: Wed, 17 Aug 2022 10:09:05 +0200 Subject: [PATCH 1072/1539] Support removing Server and User-Agent headers. Fix #1193. --- docs/faq/client.rst | 3 ++ docs/faq/server.rst | 3 ++ docs/reference/client.rst | 6 ++-- docs/reference/server.rst | 6 ++-- src/websockets/client.py | 8 ++++- src/websockets/legacy/client.py | 13 ++++++-- src/websockets/legacy/server.py | 16 +++++++-- src/websockets/server.py | 12 +++++-- tests/legacy/test_client_server.py | 52 ++++++++++++++++++++++++++---- tests/test_client.py | 16 +++++++++ tests/test_server.py | 22 +++++++++++++ 11 files changed, 137 insertions(+), 20 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 5b39cf1ec..5bbbd6ded 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -34,6 +34,9 @@ To set the ``Origin``, ``Sec-WebSocket-Extensions``, or ``origin``, ``extensions``, or ``subprotocols`` arguments of :func:`~client.connect`. +To override the ``User-Agent`` header, use the ``user_agent_header`` argument. +Set it to :obj:`None` to remove the header. + To set other HTTP headers, for example the ``Authorization`` header, use the ``extra_headers`` argument:: diff --git a/docs/faq/server.rst b/docs/faq/server.rst index a83f879b9..68490d755 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -214,6 +214,9 @@ To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` arguments of :func:`~server.serve`. +To override the ``Server`` header, use the ``server_header`` argument. Set it to +:obj:`None` to remove the header. + To set other HTTP headers, use the ``extra_headers`` argument. How do I get the IP address of the client? diff --git a/docs/reference/client.rst b/docs/reference/client.rst index 379765397..3016e85d7 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -9,16 +9,16 @@ asyncio Opening a connection .................... -.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: -.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Using a connection .................. -.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 0e446f382..65f98842a 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -9,10 +9,10 @@ asyncio Starting a server ................. -.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: -.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Stopping a server @@ -37,7 +37,7 @@ Stopping a server Using a connection .................. -.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv diff --git a/src/websockets/client.py b/src/websockets/client.py index df8e53429..1c33fae0b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -64,6 +64,9 @@ class ClientConnection(Connection): logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. + user_agent_header: value of the ``User-Agent`` request header; + defauts to ``"Python/x.y.z websockets/X.Y"``; + :obj:`None` removes the header. """ @@ -73,6 +76,7 @@ def __init__( origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, + user_agent_header: Optional[str] = USER_AGENT, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -87,6 +91,7 @@ def __init__( self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols + self.user_agent_header = user_agent_header self.key = generate_key() def connect(self) -> Request: # noqa: F811 @@ -131,7 +136,8 @@ def connect(self) -> Request: # noqa: F811 protocol_header = build_subprotocol(self.available_subprotocols) headers["Sec-WebSocket-Protocol"] = protocol_header - headers["User-Agent"] = USER_AGENT + if self.user_agent_header is not None: + headers["User-Agent"] = self.user_agent_header return Request(self.wsuri.resource_name, headers) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 29550b694..93566b87e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -70,7 +70,8 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): is closed with any other code. See :func:`connect` for the documentation of ``logger``, ``origin``, - ``extensions``, ``subprotocols``, and ``extra_headers``. + ``extensions``, ``subprotocols``, ``extra_headers``, and + ``user_agent_header``. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -89,6 +90,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, **kwargs: Any, ) -> None: if logger is None: @@ -98,6 +100,7 @@ def __init__( self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers + self.user_agent_header = user_agent_header def write_http_request(self, path: str, headers: Headers) -> None: """ @@ -315,7 +318,8 @@ async def handshake( if self.extra_headers is not None: request_headers.update(self.extra_headers) - request_headers.setdefault("User-Agent", USER_AGENT) + if self.user_agent_header is not None: + request_headers.setdefault("User-Agent", self.user_agent_header) self.write_http_request(wsuri.resource_name, request_headers) @@ -393,6 +397,9 @@ class Connect: subprotocols: list of supported subprotocols, in order of decreasing preference. extra_headers: arbitrary HTTP headers to add to the request. + user_agent_header: value of the ``User-Agent`` request header; + defauts to ``"Python/x.y.z websockets/X.Y"``; + :obj:`None` removes the header. open_timeout: timeout for opening the connection in seconds; :obj:`None` to disable the timeout @@ -438,6 +445,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -503,6 +511,7 @@ def __init__( extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, + user_agent_header=user_agent_header, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 2ee469825..836496b6e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -84,7 +84,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): ws_server: WebSocket server that created this connection. See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``, - ``extensions``, ``subprotocols``, and ``extra_headers``. + ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -108,6 +108,7 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, + server_header: Optional[str] = USER_AGENT, process_request: Optional[ Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] ] = None, @@ -132,6 +133,7 @@ def __init__( self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers + self.server_header = server_header self._process_request = process_request self._select_subprotocol = select_subprotocol @@ -216,7 +218,9 @@ async def handler(self) -> None: ) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Server", USER_AGENT) + if self.server_header is not None: + headers.setdefault("Server", self.server_header) + headers.setdefault("Content-Length", str(len(body))) headers.setdefault("Content-Type", "text/plain") headers.setdefault("Connection", "close") @@ -635,7 +639,8 @@ async def handshake( response_headers.update(extra_headers) response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - response_headers.setdefault("Server", USER_AGENT) + if self.server_header is not None: + response_headers.setdefault("Server", self.server_header) self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) @@ -938,6 +943,9 @@ class Serve: a :data:`~websockets.datastructures.HeadersLike` or a callable taking the request path and headers in arguments and returning a :data:`~websockets.datastructures.HeadersLike`. + server_header: value of the ``Server`` response header; + defauts to ``"Python/x.y.z websockets/X.Y"``; + :obj:`None` removes the header. process_request (Optional[Callable[[str, Headers], \ Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): intercept HTTP request before the opening handshake; @@ -980,6 +988,7 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, + server_header: Optional[str] = USER_AGENT, process_request: Optional[ Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] ] = None, @@ -1061,6 +1070,7 @@ def __init__( extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, + server_header=server_header, process_request=process_request, select_subprotocol=select_subprotocol, logger=logger, diff --git a/src/websockets/server.py b/src/websockets/server.py index 5dad50b6a..2057d9fb9 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -64,6 +64,9 @@ class ServerConnection(Connection): logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. + server_header: value of the ``Server`` response header; + defauts to ``"Python/x.y.z websockets/X.Y"``; + :obj:`None` removes the header. """ @@ -72,6 +75,7 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, + server_header: Optional[str] = USER_AGENT, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -85,6 +89,7 @@ def __init__( self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols + self.server_header = server_header def accept(self, request: Request) -> Response: """ @@ -170,7 +175,8 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - headers["Server"] = USER_AGENT + if self.server_header is not None: + headers["Server"] = self.server_header self.logger.info("connection open") return Response(101, "Switching Protocols", headers) @@ -469,9 +475,11 @@ def reject( ("Connection", "close"), ("Content-Length", str(len(body))), ("Content-Type", "text/plain; charset=utf-8"), - ("Server", USER_AGENT), ] ) + if self.server_header is not None: + headers["Server"] = self.server_header + response = Response(status.value, status.phrase, headers, body) # When reject() is called from accept(), handshake_exc is already set. # If a user calls reject(), set handshake_exc to guarantee invariant: diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f9de70c9c..22d72d1f7 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -655,12 +655,27 @@ def test_protocol_custom_request_headers(self): self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client("/headers", extra_headers={"User-Agent": "Eggs"}) - def test_protocol_custom_request_user_agent(self): + @with_client("/headers", extra_headers={"User-Agent": "websockets"}) + def test_protocol_custom_user_agent_header_legacy(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) - self.assertIn("('User-Agent', 'Eggs')", req_headers) + self.assertIn("('User-Agent', 'websockets')", req_headers) + + @with_server() + @with_client("/headers", user_agent_header=None) + def test_protocol_no_user_agent_header(self): + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertNotIn("User-Agent", req_headers) + + @with_server() + @with_client("/headers", user_agent_header="websockets") + def test_protocol_custom_user_agent_header(self): + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertEqual(req_headers.count("User-Agent"), 1) + self.assertIn("('User-Agent', 'websockets')", req_headers) @with_server(extra_headers=lambda p, r: {"X-Spam": "Eggs"}) @with_client("/headers") @@ -682,13 +697,28 @@ def test_protocol_custom_response_headers(self): resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers={"Server": "Eggs"}) + @with_server(extra_headers={"Server": "websockets"}) + @with_client("/headers") + def test_protocol_custom_server_header_legacy(self): + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(resp_headers.count("Server"), 1) + self.assertIn("('Server', 'websockets')", resp_headers) + + @with_server(server_header=None) @with_client("/headers") - def test_protocol_custom_response_user_agent(self): + def test_protocol_no_server_header(self): + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertNotIn("Server", resp_headers) + + @with_server(server_header="websockets") + @with_client("/headers") + def test_protocol_custom_server_header(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertEqual(resp_headers.count("Server"), 1) - self.assertIn("('Server', 'Eggs')", resp_headers) + self.assertIn("('Server', 'websockets')", resp_headers) @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): @@ -724,6 +754,16 @@ def test_ws_connection_ws_endpoint(self): self.loop.run_until_complete(self.client.recv()) self.stop_client() + @with_server(create_protocol=HealthCheckServerProtocol, server_header=None) + def test_http_request_no_server_header(self): + response = self.loop.run_until_complete(self.make_http_request("/__health__/")) + self.assertNotIn("Server", response.headers) + + @with_server(create_protocol=HealthCheckServerProtocol, server_header="websockets") + def test_http_request_custom_server_header(self): + response = self.loop.run_until_complete(self.make_http_request("/__health__/")) + self.assertEqual(response.headers["Server"], "websockets") + def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() diff --git a/tests/test_client.py b/tests/test_client.py index a843d3272..0504f79e6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -574,6 +574,22 @@ def test_unsupported_subprotocol(self): raise client.handshake_exc self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") + def test_no_user_agent_header(self): + client = ClientConnection( + parse_uri("wss://example.com/"), + user_agent_header=None, + ) + request = client.connect() + self.assertNotIn("User-Agent", request.headers) + + def test_custom_user_agent_header(self): + client = ClientConnection( + parse_uri("wss://example.com/"), + user_agent_header="websockets", + ) + request = client.connect() + self.assertEqual(request.headers["User-Agent"], "websockets") + class MiscTests(unittest.TestCase): def test_bypass_handshake(self): diff --git a/tests/test_server.py b/tests/test_server.py index e3e802239..c7f398cdf 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -608,6 +608,28 @@ def test_unsupported_subprotocol(self): self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) + def test_no_server_header(self): + server = ServerConnection(server_header=None) + request = self.make_request() + response = server.accept(request) + self.assertNotIn("Server", response.headers) + + def test_custom_server_header(self): + server = ServerConnection(server_header="websockets") + request = self.make_request() + response = server.accept(request) + self.assertEqual(response.headers["Server"], "websockets") + + def test_reject_response_no_server_header(self): + server = ServerConnection(server_header=None) + response = server.reject(http.HTTPStatus.OK, "Hello world!\n") + self.assertNotIn("Server", response.headers) + + def test_reject_response_custom_server_header(self): + server = ServerConnection(server_header="websockets") + response = server.reject(http.HTTPStatus.OK, "Hello world!\n") + self.assertEqual(response.headers["Server"], "websockets") + class MiscTests(unittest.TestCase): def test_bypass_handshake(self): From b5a2f37f8b2d0149e81ab380b84441a37a2db83e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Sun, 15 May 2022 10:27:09 +0200 Subject: [PATCH 1073/1539] Wrap recv_into() in test_explicit_socket to fix py3.11 Extend TrackedSocket class in test_explicit_socket to wrap the recv_into() method. The Python 3.11 implementation of asyncio is calling it rather than recv(), therefore causing used_for_read not to be set if it's not wrapped. --- tests/legacy/test_client_server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 22d72d1f7..c7e2e1cae 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -422,10 +422,14 @@ def __init__(self, *args, **kwargs): self.used_for_write = False super().__init__(*args, **kwargs) - def recv(self, *args, **kwargs): + def recv(self, *args, **kwargs): # pragma: no cover self.used_for_read = True return super().recv(*args, **kwargs) + def recv_into(self, *args, **kwargs): # pragma: no cover + self.used_for_read = True + return super().recv_into(*args, **kwargs) + def send(self, *args, **kwargs): self.used_for_write = True return super().send(*args, **kwargs) From bebbf40b5a14a75cd06e58be79dc1eed4255d8ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Sun, 15 May 2022 10:32:03 +0200 Subject: [PATCH 1074/1539] Skip YieldFromTests in Python 3.11+ asyncio.coroutine has been removed in Python 3.11, so skip it if Python is newer than that. Fixes #1175 --- tests/legacy/test_client_server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c7e2e1cae..8d0f8725f 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1343,6 +1343,9 @@ def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") +@unittest.skipIf( + sys.version_info[:2] >= (3, 11), "asyncio.coroutine has been removed in Python 3.11" +) class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): @with_server() def test_client(self): From 0ea2818b370ae572d52ec5d7063de2f8b64faa38 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 17 Aug 2022 10:19:59 +0200 Subject: [PATCH 1075/1539] Add changelog for 2a07325c. --- docs/project/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 105065c8c..b52943b8d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,12 @@ They may change at any time. *In development* +New features +............ + +* Supported overriding or removing the ``User-Agent`` header in clients and the + ``Server`` header in servers. + 10.3 ---- From 653466389996b466e7b0a6a83ec9e573ca5b09c5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 Aug 2022 12:45:04 +0200 Subject: [PATCH 1076/1539] Fix tests with PyPy. --- tests/legacy/test_client_server.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 8d0f8725f..f13ef6882 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -761,12 +761,16 @@ def test_ws_connection_ws_endpoint(self): @with_server(create_protocol=HealthCheckServerProtocol, server_header=None) def test_http_request_no_server_header(self): response = self.loop.run_until_complete(self.make_http_request("/__health__/")) - self.assertNotIn("Server", response.headers) + + with contextlib.closing(response): + self.assertNotIn("Server", response.headers) @with_server(create_protocol=HealthCheckServerProtocol, server_header="websockets") def test_http_request_custom_server_header(self): response = self.loop.run_until_complete(self.make_http_request("/__health__/")) - self.assertEqual(response.headers["Server"], "websockets") + + with contextlib.closing(response): + self.assertEqual(response.headers["Server"], "websockets") def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: From f7f8313f4078a75db4213d4340234b0715567349 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 22 Aug 2022 22:19:12 +0200 Subject: [PATCH 1077/1539] Enable dependabot --- .github/dependabot.yml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..123014908 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" From e36f5b673bc32885d3c08102a19f5194951e7940 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:19:42 +0000 Subject: [PATCH 1078/1539] Bump actions/setup-python from 2 to 4 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 2 to 4. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v2...v4) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/tests.yml | 4 ++-- .github/workflows/wheels.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 03ce3aff9..2e6e0f10e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: - name: Check out repository uses: actions/checkout@v2 - name: Install Python 3.x - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: "3.x" - name: Install tox @@ -59,7 +59,7 @@ jobs: - name: Check out repository uses: actions/checkout@v2 - name: Install Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install tox diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 0322055b1..aff7dcec9 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -14,7 +14,7 @@ jobs: - name: Check out repository uses: actions/checkout@v2 - name: Install Python 3.x - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.x - name: Build sdist @@ -39,7 +39,7 @@ jobs: - name: Make extension build mandatory run: touch .cibuildwheel - name: Install Python 3.x - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.x - name: Set up QEMU From e22a19a0666dfe6e0deeb0fd72adbdbbabfe5fea Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:19:46 +0000 Subject: [PATCH 1079/1539] Bump docker/setup-qemu-action from 1 to 2 Bumps [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) from 1 to 2. - [Release notes](https://github.com/docker/setup-qemu-action/releases) - [Commits](https://github.com/docker/setup-qemu-action/compare/v1...v2) --- updated-dependencies: - dependency-name: docker/setup-qemu-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index aff7dcec9..110e6add8 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -44,7 +44,7 @@ jobs: python-version: 3.x - name: Set up QEMU if: runner.os == 'Linux' - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 with: platforms: all - name: Build wheels From 8836c4e7e1b70ca41752e0dfb65f777062264f54 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:19:50 +0000 Subject: [PATCH 1080/1539] Bump pypa/cibuildwheel from 2.5.0 to 2.9.0 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.5.0 to 2.9.0. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.5.0...v2.9.0) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 110e6add8..fa1c660b5 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -48,7 +48,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.5.0 + uses: pypa/cibuildwheel@v2.9.0 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From fdb4b68d0dc5c4b439f1e2d81bf91fe70b8c8c95 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:19:53 +0000 Subject: [PATCH 1081/1539] Bump actions/download-artifact from 2 to 3 Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 2 to 3. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index fa1c660b5..7b07d9b31 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -66,7 +66,7 @@ jobs: runs-on: ubuntu-latest if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v3 with: name: artifact path: dist From 83204e26354be8f574e4c339f6d013899cc773ac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:20:00 +0000 Subject: [PATCH 1082/1539] Bump actions/upload-artifact from 2 to 3 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 2 to 3. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 7b07d9b31..2ce8a36fe 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -20,7 +20,7 @@ jobs: - name: Build sdist run: python setup.py sdist - name: Save sdist - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: path: dist/*.tar.gz @@ -54,7 +54,7 @@ jobs: CIBW_ARCHS_LINUX: "auto aarch64" CIBW_SKIP: cp36-* - name: Save wheels - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: path: wheelhouse/*.whl From b839d36b493a49e6da2a72a8883e5b4d48c3aff6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 22 Aug 2022 22:37:57 +0200 Subject: [PATCH 1083/1539] Configure dependabot schedule. --- .github/dependabot.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 123014908..ad1e824b4 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -3,4 +3,7 @@ updates: - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "daily" + interval: "weekly" + day: "saturday" + time: "07:00" + timezone: "Europe/Paris" From ff0e13010684e08f81c2534b315d228aac0acfae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:38:59 +0000 Subject: [PATCH 1084/1539] Bump actions/checkout from 2 to 3 Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 3. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/tests.yml | 4 ++-- .github/workflows/wheels.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2e6e0f10e..4785e8d20 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -57,7 +57,7 @@ jobs: is_main: false steps: - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Install Python ${{ matrix.python }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 2ce8a36fe..fe80b3431 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -35,7 +35,7 @@ jobs: - macOS-latest steps: - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Make extension build mandatory run: touch .cibuildwheel - name: Install Python 3.x From cfa70cac5214c5eb2ba28077b53da037ee51d9a0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 23 Aug 2022 00:07:34 +0200 Subject: [PATCH 1085/1539] Pin action to a released version --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index fe80b3431..e7b9dc431 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -70,6 +70,6 @@ jobs: with: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@master + - uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} From fd17f4684ef41750bdad09e09e22ee9a653f1238 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 23 Aug 2022 22:39:03 +0200 Subject: [PATCH 1086/1539] Add OSS-Fuzz fuzz targets (experimental). --- fuzzing/fuzz_http11_request_parser.py | 35 ++++++++++++++++++++++++ fuzzing/fuzz_http11_response_parser.py | 38 ++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 fuzzing/fuzz_http11_request_parser.py create mode 100644 fuzzing/fuzz_http11_response_parser.py diff --git a/fuzzing/fuzz_http11_request_parser.py b/fuzzing/fuzz_http11_request_parser.py new file mode 100644 index 000000000..4879899e1 --- /dev/null +++ b/fuzzing/fuzz_http11_request_parser.py @@ -0,0 +1,35 @@ +import sys + +import atheris + + +with atheris.instrument_imports(): + from websockets.exceptions import SecurityError + from websockets.http11 import Request + from websockets.streams import StreamReader + + +def test_one_input(data): + reader = StreamReader() + reader.feed_data(data) + reader.feed_eof() + + try: + Request.parse( + reader.read_line, + ) + except ( + EOFError, # connection is closed without a full HTTP request + SecurityError, # request exceeds a security limit + ValueError, # request isn't well formatted + ): + pass + + +def main(): + atheris.Setup(sys.argv, test_one_input) + atheris.Fuzz() + + +if __name__ == "__main__": + main() diff --git a/fuzzing/fuzz_http11_response_parser.py b/fuzzing/fuzz_http11_response_parser.py new file mode 100644 index 000000000..2f0bcd3ed --- /dev/null +++ b/fuzzing/fuzz_http11_response_parser.py @@ -0,0 +1,38 @@ +import sys + +import atheris + + +with atheris.instrument_imports(): + from websockets.exceptions import SecurityError + from websockets.http11 import Response + from websockets.streams import StreamReader + + +def test_one_input(data): + reader = StreamReader() + reader.feed_data(data) + reader.feed_eof() + + try: + Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + ) + except ( + EOFError, # connection is closed without a full HTTP response. + SecurityError, # response exceeds a security limit. + LookupError, # response isn't well formatted. + ValueError, # response isn't well formatted. + ): + pass + + +def main(): + atheris.Setup(sys.argv, test_one_input) + atheris.Fuzz() + + +if __name__ == "__main__": + main() From a26cec226e63894b2cff3c5ea1ad3a643fb1d889 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 24 Aug 2022 07:50:33 +0200 Subject: [PATCH 1087/1539] Make fuzz targets actually run. --- fuzzing/fuzz_http11_request_parser.py | 10 +++++++--- fuzzing/fuzz_http11_response_parser.py | 21 ++++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/fuzzing/fuzz_http11_request_parser.py b/fuzzing/fuzz_http11_request_parser.py index 4879899e1..148785385 100644 --- a/fuzzing/fuzz_http11_request_parser.py +++ b/fuzzing/fuzz_http11_request_parser.py @@ -14,10 +14,14 @@ def test_one_input(data): reader.feed_data(data) reader.feed_eof() + parser = Request.parse( + reader.read_line, + ) + try: - Request.parse( - reader.read_line, - ) + next(parser) + except StopIteration: + pass # request is available in exc.value except ( EOFError, # connection is closed without a full HTTP request SecurityError, # request exceeds a security limit diff --git a/fuzzing/fuzz_http11_response_parser.py b/fuzzing/fuzz_http11_response_parser.py index 2f0bcd3ed..0f783f6fd 100644 --- a/fuzzing/fuzz_http11_response_parser.py +++ b/fuzzing/fuzz_http11_response_parser.py @@ -14,17 +14,20 @@ def test_one_input(data): reader.feed_data(data) reader.feed_eof() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + ) try: - Response.parse( - reader.read_line, - reader.read_exact, - reader.read_to_eof, - ) + next(parser) + except StopIteration: + pass # response is available in exc.value except ( - EOFError, # connection is closed without a full HTTP response. - SecurityError, # response exceeds a security limit. - LookupError, # response isn't well formatted. - ValueError, # response isn't well formatted. + EOFError, # connection is closed without a full HTTP response + SecurityError, # response exceeds a security limit + LookupError, # response isn't well formatted + ValueError, # response isn't well formatted ): pass From 5d1bad7ebb1121349d08260554368553c02d1a37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 24 Aug 2022 07:50:48 +0200 Subject: [PATCH 1088/1539] Add fuzz target for WebSocket parser. --- fuzzing/fuzz_websocket_parser.py | 46 ++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 fuzzing/fuzz_websocket_parser.py diff --git a/fuzzing/fuzz_websocket_parser.py b/fuzzing/fuzz_websocket_parser.py new file mode 100644 index 000000000..7569d0b61 --- /dev/null +++ b/fuzzing/fuzz_websocket_parser.py @@ -0,0 +1,46 @@ +import sys + +import atheris + + +with atheris.instrument_imports(): + from websockets.exceptions import PayloadTooBig, ProtocolError + from websockets.frames import Frame + from websockets.streams import StreamReader + + +def test_one_input(data): + fdp = atheris.FuzzedDataProvider(data) + mask = fdp.ConsumeBool() + max_size_enabled = fdp.ConsumeBool() + max_size = fdp.ConsumeInt(4) + payload = fdp.ConsumeBytes(atheris.ALL_REMAINING) + + reader = StreamReader() + reader.feed_data(payload) + reader.feed_eof() + + parser = Frame.parse( + reader.read_exact, + mask=mask, + max_size=max_size if max_size_enabled else None, + ) + + try: + next(parser) + except StopIteration: + pass # response is available in exc.value + except ( + PayloadTooBig, # frame's payload size exceeds ``max_size`` + ProtocolError, # frame contains incorrect values + ): + pass + + +def main(): + atheris.Setup(sys.argv, test_one_input) + atheris.Fuzz() + + +if __name__ == "__main__": + main() From 61e0e1c10f9895a15b34f0d43f160c6b4861e18b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 24 Aug 2022 08:04:54 +0200 Subject: [PATCH 1089/1539] Ensure fuzz targets work as expected. --- fuzzing/fuzz_http11_request_parser.py | 9 ++++++--- fuzzing/fuzz_http11_response_parser.py | 9 ++++++--- fuzzing/fuzz_websocket_parser.py | 9 ++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/fuzzing/fuzz_http11_request_parser.py b/fuzzing/fuzz_http11_request_parser.py index 148785385..59e0cea0f 100644 --- a/fuzzing/fuzz_http11_request_parser.py +++ b/fuzzing/fuzz_http11_request_parser.py @@ -20,14 +20,17 @@ def test_one_input(data): try: next(parser) - except StopIteration: - pass # request is available in exc.value + except StopIteration as exc: + assert isinstance(exc.value, Request) + return # input accepted except ( EOFError, # connection is closed without a full HTTP request SecurityError, # request exceeds a security limit ValueError, # request isn't well formatted ): - pass + return # input rejected with a documented exception + + raise RuntimeError("parsing didn't complete") def main(): diff --git a/fuzzing/fuzz_http11_response_parser.py b/fuzzing/fuzz_http11_response_parser.py index 0f783f6fd..6906720a4 100644 --- a/fuzzing/fuzz_http11_response_parser.py +++ b/fuzzing/fuzz_http11_response_parser.py @@ -21,15 +21,18 @@ def test_one_input(data): ) try: next(parser) - except StopIteration: - pass # response is available in exc.value + except StopIteration as exc: + assert isinstance(exc.value, Response) + return # input accepted except ( EOFError, # connection is closed without a full HTTP response SecurityError, # response exceeds a security limit LookupError, # response isn't well formatted ValueError, # response isn't well formatted ): - pass + return # input rejected with a documented exception + + raise RuntimeError("parsing didn't complete") def main(): diff --git a/fuzzing/fuzz_websocket_parser.py b/fuzzing/fuzz_websocket_parser.py index 7569d0b61..ab9c1dd2e 100644 --- a/fuzzing/fuzz_websocket_parser.py +++ b/fuzzing/fuzz_websocket_parser.py @@ -28,13 +28,16 @@ def test_one_input(data): try: next(parser) - except StopIteration: - pass # response is available in exc.value + except StopIteration as exc: + assert isinstance(exc.value, Frame) + return # input accepted except ( PayloadTooBig, # frame's payload size exceeds ``max_size`` ProtocolError, # frame contains incorrect values ): - pass + return # input rejected with a documented exception + + raise RuntimeError("parsing didn't complete") def main(): From 973edf67f2956c5fc6e5bd11c779f267224d08e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 24 Aug 2022 08:28:45 +0200 Subject: [PATCH 1090/1539] Add expected exceptions in Frame.parse. --- fuzzing/fuzz_websocket_parser.py | 2 ++ src/websockets/frames.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/fuzzing/fuzz_websocket_parser.py b/fuzzing/fuzz_websocket_parser.py index ab9c1dd2e..1509a3549 100644 --- a/fuzzing/fuzz_websocket_parser.py +++ b/fuzzing/fuzz_websocket_parser.py @@ -32,6 +32,8 @@ def test_one_input(data): assert isinstance(exc.value, Frame) return # input accepted except ( + EOFError, # connection is closed without a full WebSocket frame + UnicodeDecodeError, # frame contains invalid UTF-8 PayloadTooBig, # frame's payload size exceeds ``max_size`` ProtocolError, # frame contains incorrect values ): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 043b688b5..ec6b8547d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -191,6 +191,8 @@ def parse( extensions: list of extensions, applied in reverse order. Raises: + EOFError: if the connection is closed without a full WebSocket frame. + UnicodeDecodeError: if the frame contains invalid UTF-8. PayloadTooBig: if the frame's payload size exceeds ``max_size``. ProtocolError: if the frame contains incorrect values. From eeedb71b1b2f88633d3b0cc8715f2ce8e62f77de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 10 Sep 2022 08:43:32 +0200 Subject: [PATCH 1091/1539] Add CII Best Practices badge. Remove "wheel" badge which has become standard practice and isn't particularly interesting. --- README.rst | 10 +++++----- docs/index.rst | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.rst b/README.rst index 2b9a445ea..7cbccfe13 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ :width: 480px :alt: websockets -|licence| |version| |pyversions| |wheel| |tests| |docs| +|licence| |version| |pyversions| |tests| |docs| |openssf| .. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets @@ -13,15 +13,15 @@ .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg - :target: https://pypi.python.org/pypi/websockets - -.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main +.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main?label=tests :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ +.. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge + :target: https://bestpractices.coreinfrastructure.org/projects/6475 + What is ``websockets``? ----------------------- diff --git a/docs/index.rst b/docs/index.rst index 00a9b5999..d64d80ac3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,7 @@ websockets ========== -|licence| |version| |pyversions| |wheel| |tests| |docs| +|licence| |version| |pyversions| |tests| |docs| |openssf| .. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets @@ -12,15 +12,15 @@ websockets .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg - :target: https://pypi.python.org/pypi/websockets - -.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main +.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main?label=tests :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ +.. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge + :target: https://bestpractices.coreinfrastructure.org/projects/6475 + websockets is a library for building WebSocket_ servers and clients in Python with a focus on correctness, simplicity, robustness, and performance. From eabb4b6a218eaa4be908e1038ebe7eb0724031de Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 17 Sep 2022 05:29:33 +0000 Subject: [PATCH 1092/1539] Bump pypa/cibuildwheel from 2.9.0 to 2.10.0 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.9.0 to 2.10.0. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.9.0...v2.10.0) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index e7b9dc431..cc22cd4c6 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -48,7 +48,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.9.0 + uses: pypa/cibuildwheel@v2.10.0 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From 569ba1e04aa78313fd4bceb1cdae59c822682add Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 26 Sep 2022 21:59:13 +0200 Subject: [PATCH 1093/1539] Add disclaimer to Heroku tutorial. I wouldn't write it today :-( --- docs/howto/heroku.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 464420e05..2ac5a8c1e 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -4,6 +4,14 @@ Deploy to Heroku This guide describes how to deploy a websockets server to Heroku_. The same principles should apply to other Platform as a Service providers. +.. admonition:: Heroku no longer offers a free tier. + :class: attention + + When this tutorial was written, in September 2021, Heroku offered a free + tier where a websockets app could run at no cost. In November 2022, Heroku + removed the free tier, making it impossible to maintain this document. As a + consequence, it isn't updated anymore and may be removed in the future. + We're going to deploy a very simple app. The process would be identical for a more realistic app. @@ -42,7 +50,7 @@ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/heroku/app.py - :language: text + :language: python Heroku expects the server to `listen on a specific port`_, which is provided in the ``$PORT`` environment variable. The app reads it and passes it to From 7eedf7aab316cd26b8db2db2d7113d479d2f4d0b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Oct 2022 17:59:29 +0200 Subject: [PATCH 1094/1539] Add guide for deploying on Render. --- docs/howto/index.rst | 1 + docs/howto/render.rst | 173 +++++++++++++++++++++ example/deployment/render/app.py | 36 +++++ example/deployment/render/requirements.txt | 1 + 4 files changed, 211 insertions(+) create mode 100644 docs/howto/render.rst create mode 100644 example/deployment/render/app.py create mode 100644 example/deployment/render/requirements.txt diff --git a/docs/howto/index.rst b/docs/howto/index.rst index dafb72391..23fe823e1 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -39,6 +39,7 @@ Once your application is ready, learn how to deploy it on various platforms. .. toctree:: :titlesonly: + render heroku kubernetes supervisor diff --git a/docs/howto/render.rst b/docs/howto/render.rst new file mode 100644 index 000000000..b8c417b66 --- /dev/null +++ b/docs/howto/render.rst @@ -0,0 +1,173 @@ +Deploy to Render +================ + +This guide describes how to deploy a websockets server to Render_. + +.. _Render: https://render.com/ + +.. admonition:: The free plan of Render is sufficient for trying this guide. + :class: tip + + However, on a `free plan`__, connections are dropped after five minutes, + which is quite short for WebSocket application. + + __ https://render.com/docs/free + +We're going to deploy a very simple app. The process would be identical for a +more realistic app. + +Create repository +----------------- + +Deploying to Render requires a git repository. Let's initialize one: + +.. code-block:: console + + $ mkdir websockets-echo + $ cd websockets-echo + $ git init -b main + Initialized empty Git repository in websockets-echo/.git/ + $ git commit --allow-empty -m "Initial commit." + [main (root-commit) 816c3b1] Initial commit. + +Render requires the git repository to be hosted at GitHub or GitLab. + +Sign up or log in to GitHub. Create a new repository named ``websockets-echo``. +Don't enable any of the initialization options offered by GitHub. Then, follow +instructions for pushing an existing repository from the command line. + +After pushing, refresh your repository's homepage on GitHub. You should see an +empty repository with an empty initial commit. + +Create application +------------------ + +Here's the implementation of the app, an echo server. Save it in a file called +``app.py``: + +.. literalinclude:: ../../example/deployment/render/app.py + :language: python + +This app implements requirements for `zero downtime deploys`_: + +* it provides a health check at ``/healthz``; +* it closes connections and exits cleanly when it receives a ``SIGTERM`` signal. + +.. _zero downtime deploys: https://render.com/docs/deploys#zero-downtime-deploys + +Create a ``requirements.txt`` file containing this line to declare a dependency +on websockets: + +.. literalinclude:: ../../example/deployment/render/requirements.txt + :language: text + +Confirm that you created the correct files and commit them to git: + +.. code-block:: console + + $ ls + app.py requirements.txt + $ git add . + $ git commit -m "Initial implementation." + [main f26bf7f] Initial implementation. + 2 files changed, 37 insertions(+) + create mode 100644 app.py + create mode 100644 requirements.txt + +Push the changes to GitHub: + +.. code-block:: console + + $ git push + ... + To github.com:/websockets-echo.git + 816c3b1..f26bf7f main -> main + +The app is ready. Let's deploy it! + +Deploy application +------------------ + +Sign up or log in to Render. + +Create a new web service. Connect the git repository that you just created. + +Then, finalize the configuration of your app as follows: + +* **Name**: websockets-echo +* **Start Command**: ``python app.py`` + +If you're just experimenting, select the free plan. Create the web service. + +To configure the health check, go to Settings, scroll down to Health & Alerts, +and set: + +* **Health Check Path**: /healthz + +This triggers a new deployment. + +Validate deployment +------------------- + +Let's confirm that your application is running as expected. + +Since it's a WebSocket server, you need a WebSocket client, such as the +interactive client that comes with websockets. + +If you're currently building a websockets server, perhaps you're already in a +virtualenv where websockets is installed. If not, you can install it in a new +virtualenv as follows: + +.. code-block:: console + + $ python -m venv websockets-client + $ . websockets-client/bin/activate + $ pip install websockets + +Connect the interactive client — you must replace ``websockets-echo`` with the +name of your Render app in this command: + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.onrender.com/ + Connected to wss://websockets-echo.onrender.com/. + > + +Great! Your app is running! + +Once you're connected, you can send any message and the server will echo it, +or press Ctrl-D to terminate the connection: + +.. code-block:: console + + > Hello! + < Hello! + Connection closed: 1000 (OK). + +You can also confirm that your application shuts down gracefully when you deploy +a new version. Due to limitations of Render's free plan, you must upgrade to a +paid plan before you perform this test. + +Connect an interactive client again — remember to replace ``websockets-echo`` +with your app: + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.onrender.com/ + Connected to wss://websockets-echo.onrender.com/. + > + +Trigger a new deployment with Manual Deploy > Deploy latest commit. When the +deployment completes, the connection is closed with code 1001 (going away). + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.onrender.com/ + Connected to wss://websockets-echo.onrender.com/. + Connection closed: 1001 (going away). + +If graceful shutdown wasn't working, the server wouldn't perform a closing +handshake and the connection would be closed with code 1006 (connection closed +abnormally). + +Remember to downgrade to a free plan if you upgraded just for testing this feature. diff --git a/example/deployment/render/app.py b/example/deployment/render/app.py new file mode 100644 index 000000000..4ca34d23b --- /dev/null +++ b/example/deployment/render/app.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +import asyncio +import http +import signal + +import websockets + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +async def health_check(path, request_headers): + if path == "/healthz": + return http.HTTPStatus.OK, [], b"OK\n" + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, + host="", + port=8080, + process_request=health_check, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/render/requirements.txt b/example/deployment/render/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/deployment/render/requirements.txt @@ -0,0 +1 @@ +websockets From 98999b66d502a067cdb0d241dcd7003f6f185da4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Oct 2022 20:39:35 +0200 Subject: [PATCH 1095/1539] Add guide for deploying on Fly. --- docs/howto/fly.rst | 178 ++++++++++++++++++++++++ docs/howto/index.rst | 1 + docs/howto/kubernetes.rst | 2 + docs/spelling_wordlist.txt | 4 + example/deployment/fly/Procfile | 1 + example/deployment/fly/app.py | 36 +++++ example/deployment/fly/fly.toml | 16 +++ example/deployment/fly/requirements.txt | 1 + 8 files changed, 239 insertions(+) create mode 100644 docs/howto/fly.rst create mode 100644 example/deployment/fly/Procfile create mode 100644 example/deployment/fly/app.py create mode 100644 example/deployment/fly/fly.toml create mode 100644 example/deployment/fly/requirements.txt diff --git a/docs/howto/fly.rst b/docs/howto/fly.rst new file mode 100644 index 000000000..7e404de61 --- /dev/null +++ b/docs/howto/fly.rst @@ -0,0 +1,178 @@ +Deploy to Fly +================ + +This guide describes how to deploy a websockets server to Fly_. + +.. _Fly: https://fly.io/ + +.. admonition:: The free tier of Fly is sufficient for trying this guide. + :class: tip + + The `free tier`__ include up to three small VMs. This guide uses only one. + + __ https://fly.io/docs/about/pricing/ + +We're going to deploy a very simple app. The process would be identical for a +more realistic app. + +Create application +------------------ + +Here's the implementation of the app, an echo server. Save it in a file called +``app.py``: + +.. literalinclude:: ../../example/deployment/fly/app.py + :language: python + +This app implements typical requirements for running on a Platform as a Service: + +* it provides a health check at ``/healthz``; +* it closes connections and exits cleanly when it receives a ``SIGTERM`` signal. + +Create a ``requirements.txt`` file containing this line to declare a dependency +on websockets: + +.. literalinclude:: ../../example/deployment/fly/requirements.txt + :language: text + +The app is ready. Let's deploy it! + +Deploy application +------------------ + +Follow the instructions__ to install the Fly CLI, if you haven't done that yet. + +__ https://fly.io/docs/hands-on/install-flyctl/ + +Sign up or log in to Fly. + +Launch the app — you'll have to pick a different name because I'm already using +``websockets-echo``: + +.. code-block:: console + + $ fly launch + Creating app in ... + Scanning source code + Detected a Python app + Using the following build configuration: + Builder: paketobuildpacks/builder:base + ? App Name (leave blank to use an auto-generated name): websockets-echo + ? Select organization: ... + ? Select region: ... + Created app websockets-echo in organization ... + Wrote config file fly.toml + ? Would you like to set up a Postgresql database now? No + We have generated a simple Procfile for you. Modify it to fit your needs and run "fly deploy" to deploy your application. + +.. admonition:: This will build the image with a generic buildpack. + :class: tip + + Fly can `build images`__ with a Dockerfile or a buildpack. Here, ``fly + launch`` configures a generic Paketo buildpack. + + If you'd rather package the app with a Dockerfile, check out the guide to + :ref:`containerize an application `. + + __ https://fly.io/docs/reference/builders/ + +Replace the auto-generated ``fly.toml`` with: + +.. literalinclude:: ../../example/deployment/fly/fly.toml + :language: toml + +This configuration: + +* listens on port 443, terminates TLS, and forwards to the app on port 8080; +* declares a health check at ``/healthz``; +* requests a ``SIGTERM`` for terminating the app. + +Replace the auto-generated ``Procfile`` with: + +.. literalinclude:: ../../example/deployment/fly/Procfile + :language: text + +This tells Fly how to run the app. + +Now you can deploy it: + +.. code-block:: console + + $ fly deploy + + ... lots of output... + + ==> Monitoring deployment + + 1 desired, 1 placed, 1 healthy, 0 unhealthy [health checks: 1 total, 1 passing] + --> v0 deployed successfully + +Validate deployment +------------------- + +Let's confirm that your application is running as expected. + +Since it's a WebSocket server, you need a WebSocket client, such as the +interactive client that comes with websockets. + +If you're currently building a websockets server, perhaps you're already in a +virtualenv where websockets is installed. If not, you can install it in a new +virtualenv as follows: + +.. code-block:: console + + $ python -m venv websockets-client + $ . websockets-client/bin/activate + $ pip install websockets + +Connect the interactive client — you must replace ``websockets-echo`` with the +name of your Fly app in this command: + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.fly.dev/ + Connected to wss://websockets-echo.fly.dev/. + > + +Great! Your app is running! + +Once you're connected, you can send any message and the server will echo it, +or press Ctrl-D to terminate the connection: + +.. code-block:: console + + > Hello! + < Hello! + Connection closed: 1000 (OK). + +You can also confirm that your application shuts down gracefully. + +Connect an interactive client again — remember to replace ``websockets-echo`` +with your app: + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.fly.dev/ + Connected to wss://websockets-echo.fly.dev/. + > + +In another shell, restart the app — again, replace ``websockets-echo`` with your +app: + +.. code-block:: console + + $ fly restart websockets-echo + websockets-echo is being restarted + +Go back to the first shell. The connection is closed with code 1001 (going +away). + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.fly.dev/ + Connected to wss://websockets-echo.fly.dev/. + Connection closed: 1001 (going away). + +If graceful shutdown wasn't working, the server wouldn't perform a closing +handshake and the connection would be closed with code 1006 (connection closed +abnormally). diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 23fe823e1..ddbe67d3a 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -40,6 +40,7 @@ Once your application is ready, learn how to deploy it on various platforms. :titlesonly: render + fly heroku kubernetes supervisor diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst index 26dbf8a94..c217e5946 100644 --- a/docs/howto/kubernetes.rst +++ b/docs/howto/kubernetes.rst @@ -13,6 +13,8 @@ websockets is concerned. .. _Kubernetes: https://kubernetes.io/ +.. _containerize-application: + Containerize application ------------------------ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1d5ae527d..d5b093e2d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -10,6 +10,7 @@ balancers bottlenecked bufferbloat bugfix +buildpack bytestring bytestrings changelog @@ -20,9 +21,11 @@ cryptocurrency ctrl deserialize django +Dockerfile dyno fractalideas gunicorn +healthz hypercorn iframe IPv @@ -38,6 +41,7 @@ lookups MiB mypy nginx +Paketo permessage pid proxying diff --git a/example/deployment/fly/Procfile b/example/deployment/fly/Procfile new file mode 100644 index 000000000..2e35818f6 --- /dev/null +++ b/example/deployment/fly/Procfile @@ -0,0 +1 @@ +web: python app.py diff --git a/example/deployment/fly/app.py b/example/deployment/fly/app.py new file mode 100644 index 000000000..4ca34d23b --- /dev/null +++ b/example/deployment/fly/app.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +import asyncio +import http +import signal + +import websockets + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +async def health_check(path, request_headers): + if path == "/healthz": + return http.HTTPStatus.OK, [], b"OK\n" + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, + host="", + port=8080, + process_request=health_check, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/fly/fly.toml b/example/deployment/fly/fly.toml new file mode 100644 index 000000000..5290072ed --- /dev/null +++ b/example/deployment/fly/fly.toml @@ -0,0 +1,16 @@ +app = "websockets-echo" +kill_signal = "SIGTERM" + +[build] + builder = "paketobuildpacks/builder:base" + +[[services]] + internal_port = 8080 + protocol = "tcp" + + [[services.http_checks]] + path = "/healthz" + + [[services.ports]] + handlers = ["tls", "http"] + port = 443 diff --git a/example/deployment/fly/requirements.txt b/example/deployment/fly/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/deployment/fly/requirements.txt @@ -0,0 +1 @@ +websockets From c77b3087a25b2eb459b1049b5352b6a4e912970e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Oct 2022 20:49:35 +0200 Subject: [PATCH 1096/1539] Uniformize deployment guides. --- docs/howto/heroku.rst | 80 +++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 2ac5a8c1e..2b3a44819 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -4,6 +4,8 @@ Deploy to Heroku This guide describes how to deploy a websockets server to Heroku_. The same principles should apply to other Platform as a Service providers. +.. _Heroku: https://www.heroku.com/ + .. admonition:: Heroku no longer offers a free tier. :class: attention @@ -15,10 +17,8 @@ principles should apply to other Platform as a Service providers. We're going to deploy a very simple app. The process would be identical for a more realistic app. -.. _Heroku: https://www.heroku.com/ - -Create application ------------------- +Create repository +----------------- Deploying to Heroku requires a git repository. Let's initialize one: @@ -31,20 +31,8 @@ Deploying to Heroku requires a git repository. Let's initialize one: $ git commit --allow-empty -m "Initial commit." [main (root-commit) 1e7947d] Initial commit. -Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if -you haven't done that yet. - -.. _set-up instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up - -Then, create a Heroku app — if you follow these instructions step-by-step, -you'll have to pick a different name because I'm already using -``websockets-echo`` on Heroku: - -.. code-block:: console - - $ heroku create websockets-echo - Creating ⬢ websockets-echo... done - https://websockets-echo.herokuapp.com/ | https://git.heroku.com/websockets-echo.git +Create application +------------------ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: @@ -64,20 +52,18 @@ cleanly. .. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown -Deploy application ------------------- - -In order to build the app, Heroku needs to know that it depends on websockets. -Create a ``requirements.txt`` file containing this line: +Create a ``requirements.txt`` file containing this line to declare a dependency +on websockets: .. literalinclude:: ../../example/deployment/heroku/requirements.txt :language: text -Heroku also needs to know how to run the app. Create a ``Procfile`` with this -content: +Create a ``Procfile``. .. literalinclude:: ../../example/deployment/heroku/Procfile +This tells Heroku how to run the app. + Confirm that you created the correct files and commit them to git: .. code-block:: console @@ -85,8 +71,8 @@ Confirm that you created the correct files and commit them to git: $ ls Procfile app.py requirements.txt $ git add . - $ git commit -m "Deploy echo server to Heroku." - [main 8418c62] Deploy echo server to Heroku. + $ git commit -m "Initial implementation." + [main 8418c62] Initial implementation.  3 files changed, 32 insertions(+)  create mode 100644 Procfile  create mode 100644 app.py @@ -94,6 +80,25 @@ Confirm that you created the correct files and commit them to git: The app is ready. Let's deploy it! +Deploy application +------------------ + +Follow the instructions_ to install the Heroku CLI, if you haven't done that +yet. + +.. _instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up + +Sign up or log in to Heroku. + +Create a Heroku app — you'll have to pick a different name because I'm already +using ``websockets-echo``: + +.. code-block:: console + + $ heroku create websockets-echo + Creating ⬢ websockets-echo... done + https://websockets-echo.herokuapp.com/ | https://git.heroku.com/websockets-echo.git + .. code-block:: console $ git push heroku @@ -111,7 +116,7 @@ The app is ready. Let's deploy it! Validate deployment ------------------- -Of course you'd like to confirm that your application is running as expected! +Let's confirm that your application is running as expected. Since it's a WebSocket server, you need a WebSocket client, such as the interactive client that comes with websockets. @@ -126,8 +131,8 @@ virtualenv as follows: $ . websockets-client/bin/activate $ pip install websockets -Connect the interactive client — using the name of your Heroku app instead of -``websockets-echo``: +Connect the interactive client — you must replace ``websockets-echo`` with the +name of your Heroku app in this command: .. code-block:: console @@ -137,12 +142,8 @@ Connect the interactive client — using the name of your Heroku app instead of Great! Your app is running! -In this example, I used a secure connection (``wss://``). It worked because -Heroku served a valid TLS certificate for ``websockets-echo.herokuapp.com``. -An insecure connection (``ws://``) would also work. - Once you're connected, you can send any message and the server will echo it, -then press Ctrl-D to terminate the connection: +or press Ctrl-D to terminate the connection: .. code-block:: console @@ -150,8 +151,10 @@ then press Ctrl-D to terminate the connection: < Hello! Connection closed: 1000 (OK). -You can also confirm that your application shuts down gracefully. Connect an -interactive client again — remember to replace ``websockets-echo`` with your app: +You can also confirm that your application shuts down gracefully. + +Connect an interactive client again — remember to replace ``websockets-echo`` +with your app: .. code-block:: console @@ -159,7 +162,8 @@ interactive client again — remember to replace ``websockets-echo`` with your a Connected to wss://websockets-echo.herokuapp.com/. > -In another shell, restart the dyno — again, replace ``websockets-echo`` with your app: +In another shell, restart the app — again, replace ``websockets-echo`` with your +app: .. code-block:: console From 16a85b8362738851fc622b8ffd8092f7f1bb3e87 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Oct 2022 20:52:52 +0200 Subject: [PATCH 1097/1539] Run spell check. --- docs/project/changelog.rst | 2 +- docs/reference/index.rst | 2 +- docs/spelling_wordlist.txt | 1 + docs/topics/timeouts.rst | 2 +- src/websockets/client.py | 2 +- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- 8 files changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b52943b8d..dd19228be 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -307,7 +307,7 @@ Backwards-incompatible changes :class: note While Python supports this, tools relying on static code analysis don't. - This breaks autocompletion in an IDE or type checking with mypy_. + This breaks auto-completion in an IDE or type checking with mypy_. .. _mypy: https://github.com/python/mypy diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 5f51a1c1c..8147c4cf3 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -60,7 +60,7 @@ guarantees of behavior or backwards-compatibility for private APIs. For convenience, many public APIs can be imported from the ``websockets`` package. However, this feature is incompatible with static code analysis. It -breaks autocompletion in an IDE or type checking with mypy_. If you're using +breaks auto-completion in an IDE or type checking with mypy_. If you're using such tools, use the real import paths. .. _mypy: https://github.com/python/mypy diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index d5b093e2d..1d342515b 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -23,6 +23,7 @@ deserialize django Dockerfile dyno +formatter fractalideas gunicorn healthz diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index dcf0322a4..23e8020c4 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -11,7 +11,7 @@ long-lived connections, it is desirable to ensure that connections don't break, and if they do, to report the problem quickly. Connections can drop as a consequence of temporary network connectivity issues, -which are very common, even within datacenters. +which are very common, even within data centers. Furthermore, WebSocket builds on top of HTTP/1.1 where connections are short-lived, even with ``Connection: keep-alive``. Typically, HTTP/1.1 diff --git a/src/websockets/client.py b/src/websockets/client.py index 1c33fae0b..c18c76f22 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -65,7 +65,7 @@ class ClientConnection(Connection): defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. user_agent_header: value of the ``User-Agent`` request header; - defauts to ``"Python/x.y.z websockets/X.Y"``; + defaults to ``"Python/x.y.z websockets/X.Y"``; :obj:`None` removes the header. """ diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 93566b87e..1e3c6e741 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -398,7 +398,7 @@ class Connect: preference. extra_headers: arbitrary HTTP headers to add to the request. user_agent_header: value of the ``User-Agent`` request header; - defauts to ``"Python/x.y.z websockets/X.Y"``; + defaults to ``"Python/x.y.z websockets/X.Y"``; :obj:`None` removes the header. open_timeout: timeout for opening the connection in seconds; :obj:`None` to disable the timeout diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 836496b6e..e01548896 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -944,7 +944,7 @@ class Serve: taking the request path and headers in arguments and returning a :data:`~websockets.datastructures.HeadersLike`. server_header: value of the ``Server`` response header; - defauts to ``"Python/x.y.z websockets/X.Y"``; + defaults to ``"Python/x.y.z websockets/X.Y"``; :obj:`None` removes the header. process_request (Optional[Callable[[str, Headers], \ Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): diff --git a/src/websockets/server.py b/src/websockets/server.py index 2057d9fb9..ca48449d6 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -65,7 +65,7 @@ class ServerConnection(Connection): defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. server_header: value of the ``Server`` response header; - defauts to ``"Python/x.y.z websockets/X.Y"``; + defaults to ``"Python/x.y.z websockets/X.Y"``; :obj:`None` removes the header. """ From 270d5dae8f87afa4f0a340fc0045e4cc980e5d44 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Oct 2022 13:06:01 +0200 Subject: [PATCH 1098/1539] Don't require src to be on PYTHONPATH. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index a38634aee..578f6b1ae 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ export PYTHONWARNINGS=default default: coverage style style: - isort src tests + isort --project websockets src tests black src tests flake8 src tests mypy --strict src From a899d5a1e5453767cda984324c61caa83ef0654c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:03:48 +0200 Subject: [PATCH 1099/1539] Avoid triggerring deprecation warnings in tests. This also avoids having to account for them. --- tests/legacy/test_client_server.py | 50 ++++++++++++++++-------------- tests/legacy/test_protocol.py | 46 ++++++++++++++++----------- 2 files changed, 54 insertions(+), 42 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f13ef6882..62cddd2b4 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -226,16 +226,16 @@ def start_server(self, deprecation_warnings=None, **kwargs): # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) + # This logic is encapsulated in a coroutine to prevent it from executing + # before the event loop is running which causes asyncio.get_event_loop() + # to raise a DeprecationWarning on Python ≥ 3.10. + async def start_server(): + return await serve(handler, "localhost", 0, **kwargs) + with warnings.catch_warnings(record=True) as recorded_warnings: - start_server = serve(handler, "localhost", 0, **kwargs) - self.server = self.loop.run_until_complete(start_server) + self.server = self.loop.run_until_complete(start_server()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if ( - sys.version_info[:2] >= (3, 10) - and "remove loop argument" not in expected_warnings - ): # pragma: no cover - expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_client( @@ -252,16 +252,16 @@ def start_client( except KeyError: server_uri = get_server_uri(self.server, secure, resource_name, user_info) + # This logic is encapsulated in a coroutine to prevent it from executing + # before the event loop is running which causes asyncio.get_event_loop() + # to raise a DeprecationWarning on Python ≥ 3.10. + async def start_client(): + return await connect(server_uri, **kwargs) + with warnings.catch_warnings(record=True) as recorded_warnings: - start_client = connect(server_uri, **kwargs) - self.client = self.loop.run_until_complete(start_client) + self.client = self.loop.run_until_complete(start_client()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if ( - sys.version_info[:2] >= (3, 10) - and "remove loop argument" not in expected_warnings - ): # pragma: no cover - expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): @@ -465,25 +465,26 @@ def test_unix_socket(self): path = bytes(pathlib.Path(temp_dir) / "websockets") # Like self.start_server() but with unix_serve(). - with warnings.catch_warnings(record=True) as recorded_warnings: - unix_server = unix_serve(default_handler, path, loop=self.loop) - self.server = self.loop.run_until_complete(unix_server) - self.assertDeprecationWarnings(recorded_warnings, ["remove loop argument"]) + async def start_server(): + return await unix_serve(default_handler, path) + + self.server = self.loop.run_until_complete(start_server()) try: # Like self.start_client() but with unix_connect() - with warnings.catch_warnings(record=True) as recorded_warnings: - unix_client = unix_connect(path, loop=self.loop) - self.client = self.loop.run_until_complete(unix_client) - self.assertDeprecationWarnings( - recorded_warnings, ["remove loop argument"] - ) + async def start_client(): + return await unix_connect(path) + + self.client = self.loop.run_until_complete(start_client()) + try: self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") + finally: self.stop_client() + finally: self.stop_server() @@ -1341,6 +1342,7 @@ def test_checking_lack_of_origin_succeeds(self): self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") @with_server(origins=[""]) + # The deprecation warning is raised when a client connects to the server. @with_client(deprecation_warnings=["use None instead of '' in origins"]) def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): self.loop.run_until_complete(self.client.send("Hello!")) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ed6424694..6e7f3727b 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import sys import unittest import unittest.mock import warnings @@ -88,9 +87,16 @@ class CommonTests: def setUp(self): super().setUp() - with warnings.catch_warnings(record=True): + + # This logic is encapsulated in a coroutine to prevent it from executing + # before the event loop is running which causes asyncio.get_event_loop() + # to raise a DeprecationWarning on Python ≥ 3.10. + + async def create_protocol(): # Disable pings to make it easier to test what frames are sent exactly. - self.protocol = WebSocketCommonProtocol(ping_interval=None) + return WebSocketCommonProtocol(ping_interval=None) + + self.protocol = self.loop.run_until_complete(create_protocol()) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) @@ -312,29 +318,24 @@ def assertCompletesWithin(self, min_time, max_time): # Test constructor. def test_timeout_backwards_compatibility(self): + async def create_protocol(): + return WebSocketCommonProtocol(ping_interval=None, timeout=5) + with warnings.catch_warnings(record=True) as recorded: - protocol = WebSocketCommonProtocol(timeout=5) + protocol = self.loop.run_until_complete(create_protocol()) self.assertEqual(protocol.close_timeout, 5) - - expected = ["rename timeout to close_timeout"] - if sys.version_info[:2] >= (3, 10): # pragma: no cover - expected += ["There is no current event loop"] - - self.assertDeprecationWarnings(recorded, expected) + self.assertDeprecationWarnings(recorded, ["rename timeout to close_timeout"]) def test_loop_backwards_compatibility(self): loop = asyncio.new_event_loop() self.addCleanup(loop.close) with warnings.catch_warnings(record=True) as recorded: - protocol = WebSocketCommonProtocol(loop=loop) + protocol = WebSocketCommonProtocol(ping_interval=None, loop=loop) self.assertEqual(protocol.loop, loop) - - expected = ["remove loop argument"] - - self.assertDeprecationWarnings(recorded, expected) + self.assertDeprecationWarnings(recorded, ["remove loop argument"]) # Test public attributes. @@ -1151,18 +1152,27 @@ def test_connection_closed_attributes(self): # Test the protocol logic for sending keepalive pings. def restart_protocol_with_keepalive_ping( - self, ping_interval=3 * MS, ping_timeout=3 * MS + self, + ping_interval=3 * MS, + ping_timeout=3 * MS, ): initial_protocol = self.protocol + # copied from tearDown + self.transport.close() self.loop.run_until_complete(self.protocol.close()) + # copied from setUp, but enables keepalive pings - with warnings.catch_warnings(record=True): - self.protocol = WebSocketCommonProtocol( + + async def create_protocol(): + return WebSocketCommonProtocol( ping_interval=ping_interval, ping_timeout=ping_timeout, ) + + self.protocol = self.loop.run_until_complete(create_protocol()) + self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) self.protocol.is_client = initial_protocol.is_client From 15791c57c28b334bf548e401d037b51f16620604 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:20:17 +0200 Subject: [PATCH 1100/1539] Clean deprecation warnings in test suite. --- tests/legacy/test_client_server.py | 12 ++++++++++-- tests/legacy/test_framing.py | 6 ++++-- tests/legacy/test_protocol.py | 2 ++ tests/test_imports.py | 1 + tox.ini | 2 +- 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 62cddd2b4..d02daede1 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -233,6 +233,7 @@ async def start_server(): return await serve(handler, "localhost", 0, **kwargs) with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") self.server = self.loop.run_until_complete(start_server()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings @@ -259,6 +260,7 @@ async def start_client(): return await connect(server_uri, **kwargs) with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") self.client = self.loop.run_until_complete(start_client()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings @@ -536,6 +538,7 @@ def legacy_process_request_OK(path, request_headers): @with_server(process_request=legacy_process_request_OK) def test_process_request_argument_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): @@ -563,6 +566,7 @@ def process_request(self, path, request_headers): @with_server(create_protocol=LegacyProcessRequestOKServerProtocol) def test_process_request_override_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): @@ -607,6 +611,7 @@ def test_protocol_deprecated_attributes(self): for server_socket in self.server.sockets ] with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") client_attrs = (self.client.host, self.client.port, self.client.secure) self.assertDeprecationWarnings( recorded_warnings, @@ -620,6 +625,7 @@ def test_protocol_deprecated_attributes(self): expected_server_attrs = ("localhost", 0, self.secure) with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") self.loop.run_until_complete(self.client.send("")) server_attrs = self.loop.run_until_complete(self.client.recv()) self.assertDeprecationWarnings( @@ -1356,7 +1362,8 @@ class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): @with_server() def test_client(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 - with warnings.catch_warnings(record=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") @asyncio.coroutine def run_client(): @@ -1370,7 +1377,8 @@ def run_client(): def test_server(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 - with warnings.catch_warnings(record=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") @asyncio.coroutine def run_server(): diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 4646817a8..035f0f03c 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -17,7 +17,8 @@ def decode(self, message, mask=False, max_size=None, extensions=None): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(message) stream.feed_eof() - with warnings.catch_warnings(record=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") frame = self.loop.run_until_complete( Frame.read( stream.readexactly, @@ -32,7 +33,8 @@ def decode(self, message, mask=False, max_size=None, extensions=None): def encode(self, frame, mask=False, extensions=None): write = unittest.mock.Mock() - with warnings.catch_warnings(record=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") frame.write(write, mask=mask, extensions=extensions) # Ensure the entire frame is sent with a single call to write(). # Multiple calls cause TCP fragmentation and degrade performance. diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 6e7f3727b..4fe0092a5 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -322,6 +322,7 @@ async def create_protocol(): return WebSocketCommonProtocol(ping_interval=None, timeout=5) with warnings.catch_warnings(record=True) as recorded: + warnings.simplefilter("always") protocol = self.loop.run_until_complete(create_protocol()) self.assertEqual(protocol.close_timeout, 5) @@ -332,6 +333,7 @@ def test_loop_backwards_compatibility(self): self.addCleanup(loop.close) with warnings.catch_warnings(record=True) as recorded: + warnings.simplefilter("always") protocol = WebSocketCommonProtocol(ping_interval=None, loop=loop) self.assertEqual(protocol.loop, loop) diff --git a/tests/test_imports.py b/tests/test_imports.py index 8f1625a9b..b69ed9316 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -30,6 +30,7 @@ def test_get_deprecated_alias(self): ) with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") self.assertEqual(self.mod.bar, bar) self.assertEqual(len(recorded_warnings), 1) diff --git a/tox.ini b/tox.ini index c243e9880..20c5320c9 100644 --- a/tox.ini +++ b/tox.ini @@ -2,7 +2,7 @@ envlist = py37,py38,py39,py310,coverage,black,flake8,isort,mypy [testenv] -commands = python -W default -m unittest {posargs} +commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} [testenv:coverage] commands = From 3a4ea9d270b5af431179adf193e1357a46e5fc95 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:22:49 +0200 Subject: [PATCH 1101/1539] Confirm support for Python 3.11. --- docs/project/changelog.rst | 2 ++ tox.ini | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index dd19228be..6f44bb36b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -33,6 +33,8 @@ They may change at any time. New features ............ +* Validated compatibility with Python 3.11. + * Supported overriding or removing the ``User-Agent`` header in clients and the ``Server`` header in servers. diff --git a/tox.ini b/tox.ini index 20c5320c9..3a284ed31 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,15 @@ [tox] -envlist = py37,py38,py39,py310,coverage,black,flake8,isort,mypy +envlist = + py37 + py38 + py39 + py310 + py311 + coverage + black + flake8 + isort + mypy [testenv] commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} From f461295aefaa517054ff374bdff6362ad1dfac7b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:43:14 +0200 Subject: [PATCH 1102/1539] Add Python 3.11 for the next release. --- setup.cfg | 2 +- setup.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 91f769620..a300ce628 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py37.py38.py39.py310 +python-tag = py37.py38.py39.py310.py311 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index b2d07737d..492b1597c 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', ], package_dir = {'': 'src'}, package_data = {'websockets': ['py.typed']}, From ee54c4db1ad0d7a0701bad90e44950cc51c73ce9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:45:45 +0200 Subject: [PATCH 1103/1539] Clarify comment. --- src/websockets/legacy/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index e01548896..0c69d1317 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -712,7 +712,7 @@ def wrap(self, server: asyncio.base_events.Server) -> None: self.logger.info("server listening on %s", name) # Initialized here because we need a reference to the event loop. - # This should be moved back to __init__ in Python 3.10. + # This should be moved back to __init__ when dropping Python < 3.10. self.closed_waiter = server.get_loop().create_future() def register(self, protocol: WebSocketServerProtocol) -> None: From 8ce4739b7efed3ac78b287da7fb5e537f78e72aa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 19:31:23 +0200 Subject: [PATCH 1104/1539] Increase maximum header length (again). Fix #1239. --- src/websockets/http11.py | 14 ++++++-------- src/websockets/legacy/http.py | 2 +- tests/test_http11.py | 6 +++--- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 84048fa47..68249192c 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -8,14 +8,12 @@ from . import datastructures, exceptions -# Maximum total size of headers is around 256 * 4 KiB = 1 MiB -MAX_HEADERS = 256 - -# We can use the same limit for the request line and header lines: -# "GET <4096 bytes> HTTP/1.1\r\n" = 4111 bytes -# "Set-Cookie: <4097 bytes>\r\n" = 4111 bytes -# (RFC requires 4096 bytes; for some reason Firefox supports 4097 bytes.) -MAX_LINE = 4111 +# Maximum total size of headers is around 128 * 8 KiB = 1 MiB. +MAX_HEADERS = 128 + +# Limit request line and header lines. 8KiB is the most common default +# configuration of popular HTTP servers. +MAX_LINE = 8192 # Support for HTTP response bodies is intended to read an error message # returned by a server. It isn't designed to perform large file transfers. diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index cc2ef1f06..d9e44cc28 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -192,7 +192,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: """ # Security: this is bounded by the StreamReader's limit (default = 32 KiB). line = await stream.readline() - # Security: this guarantees header values are small (hard-coded = 4 KiB) + # Security: this guarantees header values are small (hard-coded = 8 KiB) if len(line) > MAX_LINE: raise SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 diff --git a/tests/test_http11.py b/tests/test_http11.py index 61d377925..d2e5e0462 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -328,13 +328,13 @@ def test_parse_invalid_value(self): next(self.parse_headers()) def test_parse_too_long_value(self): - self.reader.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") + self.reader.feed_data(b"foo: bar\r\n" * 129 + b"\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) def test_parse_too_long_line(self): - # Header line contains 5 + 4105 + 2 = 4112 bytes. - self.reader.feed_data(b"foo: " + b"a" * 4105 + b"\r\n\r\n") + # Header line contains 5 + 8186 + 2 = 8193 bytes. + self.reader.feed_data(b"foo: " + b"a" * 8186 + b"\r\n\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) From 86961582f40596efe81d616ac9daa5369e0e17e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Oct 2022 14:53:10 +0200 Subject: [PATCH 1105/1539] Add API for connection latency. Fix #1195. --- docs/project/changelog.rst | 6 ++++ docs/reference/client.rst | 2 ++ docs/reference/common.rst | 2 ++ docs/reference/server.rst | 2 ++ docs/topics/timeouts.rst | 28 ++++++----------- src/websockets/legacy/protocol.py | 52 +++++++++++++++++++++---------- tests/legacy/test_protocol.py | 49 +++++++++++++++++++---------- 7 files changed, 90 insertions(+), 51 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 6f44bb36b..8ad7bc5ab 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,12 @@ New features * Validated compatibility with Python 3.11. +* Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property to + protocols. + +* Changed :attr:`~legacy.protocol.WebSocketCommonProtocol.ping` to return the + latency of the connection. + * Supported overriding or removing the ``User-Agent`` header in clients and the ``Server`` header in servers. diff --git a/docs/reference/client.rst b/docs/reference/client.rst index 3016e85d7..b72f49f5d 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -46,6 +46,8 @@ Using a connection .. autoproperty:: closed + .. autoattribute:: latency + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/common.rst b/docs/reference/common.rst index f2683bc77..6ba11bff5 100644 --- a/docs/reference/common.rst +++ b/docs/reference/common.rst @@ -34,6 +34,8 @@ asyncio .. autoproperty:: closed + .. autoattribute:: latency + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 65f98842a..12fe1f806 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -71,6 +71,8 @@ Using a connection .. autoproperty:: closed + .. autoattribute:: latency + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 23e8020c4..633fc1ab4 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -20,6 +20,8 @@ infrastructure closes idle connections after 30 to 120 seconds. As a consequence, proxies may terminate WebSocket connections prematurely when no message was exchanged in 30 seconds. +.. _keepalive: + Keepalive in websockets ----------------------- @@ -101,26 +103,14 @@ Latency between a client and a server may increase for two reasons: the default timeout elapses. As a consequence, it closes the connection. This is a reasonable choice to prevent overload. - If traffic spikes cause unwanted timeouts and you're confident that the - server will catch up eventually, you can increase ``ping_timeout`` or you - can disable keepalive entirely with ``ping_interval=None``. + If traffic spikes cause unwanted timeouts and you're confident that the server + will catch up eventually, you can increase ``ping_timeout`` or you can set it + to :obj:`None` to disable heartbeat entirely. The same reasoning applies to situations where the server sends more traffic than the client can accept. -You can monitor latency as follows: - -.. code-block:: python - - import asyncio - import logging - import time - - async def log_latency(websocket, logger): - t0 = time.perf_counter() - pong_waiter = await websocket.ping() - await pong_waiter - t1 = time.perf_counter() - logger.info("Connection latency: %.3f seconds", t1 - t0) - - asyncio.create_task(log_latency(websocket, logging.getLogger())) +The latency measured during the last exchange of Ping and Pong frames is +available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.latency` +attribute. Alternatively, you can measure the latency at any time with the +:attr:`~legacy.protocol.WebSocketCommonProtocol.ping` method. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 04726033e..1b6e58efa 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -7,6 +7,7 @@ import random import ssl import struct +import time import uuid import warnings from typing import ( @@ -21,6 +22,7 @@ List, Mapping, Optional, + Tuple, Union, cast, ) @@ -286,7 +288,19 @@ def __init__( self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, asyncio.Future[None]] = {} + self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + This value is updated after sending a ping frame and receiving a + matching pong frame. Before the first ping, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends ping frames automatically at regular intervals. You can also + send ping frames and measure latency with :meth:`ping`. + """ # Task running the data transfer. self.transfer_data_task: asyncio.Task[None] @@ -802,8 +816,8 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 - A ping may serve as a keepalive or as a check that the remote endpoint - received all messages up to this point + A ping may serve as a keepalive, as a check that the remote endpoint + received all messages up to this point, or to measure :attr:`latency`. Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return immediately, it means the write buffer is full. If you don't want to @@ -818,14 +832,16 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: containing four random bytes. Returns: - ~asyncio.Future: A future that will be completed when the - corresponding pong is received. You can ignore it if you - don't intend to wait. + ~asyncio.Future[float]: A future that will be completed when the + corresponding pong is received. You can ignore it if you don't + intend to wait. The result of the future is the latency of the + connection in seconds. :: pong_waiter = await ws.ping() - await pong_waiter # only if you want to wait for the pong + # only if you want to wait for the corresponding pong + latency = await pong_waiter Raises: ConnectionClosed: when the connection is closed. @@ -846,12 +862,14 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: while data is None or data in self.pings: data = struct.pack("!I", random.getrandbits(32)) - ping_future = self.loop.create_future() - self.pings[data] = ping_future + pong_waiter = self.loop.create_future() + # Resolution of time.monotonic() may be too low on Windows. + ping_timestamp = time.perf_counter() + self.pings[data] = (pong_waiter, ping_timestamp) await self.write_frame(True, OP_PING, data) - return asyncio.shield(ping_future) + return asyncio.shield(pong_waiter) async def pong(self, data: Data = b"") -> None: """ @@ -1122,15 +1140,17 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PONG: if frame.data in self.pings: + pong_timestamp = time.perf_counter() # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pings.items(): + for ping_id, (pong_waiter, ping_timestamp) in self.pings.items(): ping_ids.append(ping_id) - if not ping.done(): - ping.set_result(None) + if not pong_waiter.done(): + pong_waiter.set_result(pong_timestamp - ping_timestamp) if ping_id == frame.data: + self.latency = pong_timestamp - ping_timestamp break else: # pragma: no cover assert False, "ping_id is in self.pings" @@ -1454,13 +1474,13 @@ def abort_pings(self) -> None: assert self.state is State.CLOSED exc = self.connection_closed_exc() - for ping in self.pings.values(): - ping.set_exception(exc) + for pong_waiter, _ping_timestamp in self.pings.values(): + pong_waiter.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does # nothing, but it prevents logging the exception. - ping.cancel() + pong_waiter.cancel() # asyncio.Protocol methods diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 4fe0092a5..1f830ebee 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -938,33 +938,33 @@ def test_ignore_pong(self): self.assertNoFrameSent() def test_acknowledge_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) - self.assertFalse(ping.done()) + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) + self.assertFalse(pong_waiter.done()) ping_frame = self.last_sent_frame() pong_frame = Frame(True, OP_PONG, ping_frame.data) self.receive_frame(pong_frame) self.run_loop_once() self.run_loop_once() - self.assertTrue(ping.done()) + self.assertTrue(pong_waiter.done()) def test_abort_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) # Remove the frame from the buffer, else close_connection() complains. self.last_sent_frame() - self.assertFalse(ping.done()) + self.assertFalse(pong_waiter.done()) self.close_connection() - self.assertTrue(ping.done()) - self.assertIsInstance(ping.exception(), ConnectionClosed) + self.assertTrue(pong_waiter.done()) + self.assertIsInstance(pong_waiter.exception(), ConnectionClosed) def test_abort_ping_does_not_log_exception_if_not_retreived(self): self.loop.run_until_complete(self.protocol.ping()) # Get the internal Future, which isn't directly returned by ping(). - (ping,) = self.protocol.pings.values() + ((pong_waiter, _timestamp),) = self.protocol.pings.values() # Remove the frame from the buffer, else close_connection() complains. self.last_sent_frame() self.close_connection() # Check a private attribute, for lack of a better solution. - self.assertFalse(ping._log_traceback) + self.assertFalse(pong_waiter._log_traceback) def test_acknowledge_previous_pings(self): pings = [ @@ -987,7 +987,7 @@ def test_acknowledge_previous_pings(self): self.assertFalse(pings[2][0].done()) def test_acknowledge_aborted_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) ping_frame = self.last_sent_frame() # Clog incoming queue. This lets connection_lost() abort pending pings # with a ConnectionClosed exception before transfer_data_task @@ -1003,7 +1003,7 @@ def test_acknowledge_aborted_ping(self): self.loop.run_until_complete(self.protocol.wait_closed()) # Ping receives a ConnectionClosed exception. with self.assertRaises(ConnectionClosed): - ping.result() + pong_waiter.result() # transfer_data doesn't crash, which would be logged. with self.assertNoLogs(): @@ -1012,14 +1012,14 @@ def test_acknowledge_aborted_ping(self): self.loop.run_until_complete(self.protocol.recv()) def test_canceled_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) ping_frame = self.last_sent_frame() - ping.cancel() + pong_waiter.cancel() pong_frame = Frame(True, OP_PONG, ping_frame.data) self.receive_frame(pong_frame) self.run_loop_once() self.run_loop_once() - self.assertTrue(ping.cancelled()) + self.assertTrue(pong_waiter.cancelled()) def test_duplicate_ping(self): self.loop.run_until_complete(self.protocol.ping(b"foobar")) @@ -1028,6 +1028,23 @@ def test_duplicate_ping(self): self.loop.run_until_complete(self.protocol.ping(b"foobar")) self.assertNoFrameSent() + # Test the protocol's logic for measuring latency + + def test_record_latency_on_pong(self): + self.assertEqual(self.protocol.latency, 0) + self.loop.run_until_complete(self.protocol.ping(b"test")) + self.receive_frame(Frame(True, OP_PONG, b"test")) + self.run_loop_once() + self.assertGreater(self.protocol.latency, 0) + + def test_return_latency_on_pong(self): + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) + ping_frame = self.last_sent_frame() + pong_frame = Frame(True, OP_PONG, ping_frame.data) + self.receive_frame(pong_frame) + latency = self.loop.run_until_complete(pong_waiter) + self.assertGreater(latency, 0) + # Test the protocol's logic for rebuilding fragmented messages. def test_fragmented_text(self): @@ -1244,14 +1261,14 @@ def test_keepalive_ping_does_not_crash_when_connection_lost(self): self.receive_frame(Frame(True, OP_TEXT, b"2")) # Ping is sent at 3ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - (ping_waiter,) = tuple(self.protocol.pings.values()) + ((pong_waiter, _timestamp),) = self.protocol.pings.values() # Connection drops. self.receive_eof() self.loop.run_until_complete(self.protocol.wait_closed()) # The ping waiter receives a ConnectionClosed exception. with self.assertRaises(ConnectionClosed): - ping_waiter.result() + pong_waiter.result() # The keepalive ping task terminated properly. self.assertIsNone(self.protocol.keepalive_ping_task.result()) From d5cf4a94e599dd678255a2d23712cb82a43ec41a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Oct 2022 15:32:09 +0200 Subject: [PATCH 1106/1539] Document workaround for bug in Python < 3.10. Fix #1182. --- docs/topics/logging.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 393efeb71..95acf57ff 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -126,7 +126,8 @@ Here's how to include them in logs, assuming they're in the async with websockets.serve( ..., - logger=LoggerAdapter(logging.getLogger("websockets.server")), + # Python < 3.10 requires passing None as the second argument. + logger=LoggerAdapter(logging.getLogger("websockets.server"), None), ): ... @@ -167,7 +168,8 @@ a :class:`~logging.LoggerAdapter`:: async with websockets.serve( ..., - logger=LoggerAdapter(logging.getLogger("websockets.server")), + # Python < 3.10 requires passing None as the second argument. + logger=LoggerAdapter(logging.getLogger("websockets.server"), None), ): ... From a07f4e8190755de47f9915ef89535de18c02f88e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Oct 2022 15:45:48 +0200 Subject: [PATCH 1107/1539] Highlight issue with convenience imports in FAQ. Fix #1183. --- docs/faq/misc.rst | 51 +++++++++++++++++++++++++++++----------- docs/reference/index.rst | 15 ++++++++---- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 9ef07ef2d..681c5e45a 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -3,27 +3,35 @@ Miscellaneous .. currentmodule:: websockets -Can I use websockets without ``async`` and ``await``? -..................................................... +Why do I get the error: ``module 'websockets' has no attribute '...'``? +....................................................................... -No, there is no convenient way to do this. You should use another library. +Often, this is because you created a script called ``websockets.py`` in your +current working directory. Then ``import websockets`` imports this module +instead of the websockets library. -Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? -............................................................................ +Why does my IDE fail to show documentation for websockets APIs? +............................................................... -No, there aren't. +You are probably using the convenience imports e.g.:: -websockets provides high-level, coroutine-based APIs. Compared to callbacks, -coroutines make it easier to manage control flow in concurrent code. + import websockets -If you prefer callback-based APIs, you should use another library. + websockets.connect(...) + websockets.serve(...) -Why do I get the error: ``module 'websockets' has no attribute '...'``? -....................................................................... +This is incompatible with static code analysis. It may break auto-completion and +contextual documentation in IDEs, type checking with mypy_, etc. -Often, this is because you created a script called ``websockets.py`` in your -current working directory. Then ``import websockets`` imports this module -instead of the websockets library. +.. _mypy: https://github.com/python/mypy + +Instead, use the real import paths e.g.:: + + import websockets.client + import websockets.server + + websockets.client.connect(...) + websockets.server.serve(...) Why is websockets slower than another Python library in my benchmark? ..................................................................... @@ -37,3 +45,18 @@ you may need to disable: * UTF-8 decoding: send ``bytes`` rather than ``str`` If websockets is still slower than another Python library, please file a bug. + +Can I use websockets without ``async`` and ``await``? +..................................................... + +No, there is no convenient way to do this. You should use another library. + +Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? +............................................................................ + +No, there aren't. + +websockets provides high-level, coroutine-based APIs. Compared to callbacks, +coroutines make it easier to manage control flow in concurrent code. + +If you prefer callback-based APIs, you should use another library. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 8147c4cf3..51af8e622 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -58,9 +58,14 @@ Public API documented in the API reference are subject to the Anything that isn't listed in the API reference is a private API. There's no guarantees of behavior or backwards-compatibility for private APIs. -For convenience, many public APIs can be imported from the ``websockets`` -package. However, this feature is incompatible with static code analysis. It -breaks auto-completion in an IDE or type checking with mypy_. If you're using -such tools, use the real import paths. +.. admonition:: Convenience imports are incompatible with some development tools. + :class: caution -.. _mypy: https://github.com/python/mypy + For convenience, most public APIs can be imported from the ``websockets`` + package. However, this is incompatible with static code analysis. + + It may break auto-completion and contextual documentation in IDEs, type + checking with mypy_, etc. If you're using such tools, stick to the full + import paths. + + .. _mypy: https://github.com/python/mypy From 580c41744330a83ff0e4eda728e2aab1a10a4e79 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Oct 2022 15:46:31 +0200 Subject: [PATCH 1108/1539] Remove obsolete justification. --- docs/reference/index.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 51af8e622..f164dde91 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -37,9 +37,6 @@ are available on client connections created with :func:`~client.connect` and on server connections received in argument by the connection handler of :func:`~server.serve`. -Since websockets provides the same API — and uses the same code — for client -and server connections, common methods are documented in a "Both sides" page. - .. toctree:: :titlesonly: From ad797212ce45bcab7c4cf57d21095a12e8f284ba Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 10 Oct 2022 08:52:29 +0200 Subject: [PATCH 1109/1539] Don't log when connection drops during handshake. Fix #1237. Refs #984. --- src/websockets/legacy/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 0c69d1317..c32c85612 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -598,7 +598,8 @@ async def handshake( # The connection may drop while process_request is running. if self.state is State.CLOSED: - raise self.connection_closed_exc() # pragma: no cover + # This subclass of ConnectionError is silently ignored in handler(). + raise BrokenPipeError("connection closed during opening handshake") # Change the response to a 503 error if the server is shutting down. if not self.ws_server.is_serving(): From cdd30e453f421d65d9650ee10425996ed12912b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 14 Oct 2022 22:27:50 +0200 Subject: [PATCH 1110/1539] Make a note to use a new API. --- tests/legacy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index fd5dfc294..bc37ee7bf 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -12,6 +12,8 @@ class AsyncioTestCase(unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. + Replace with IsolatedAsyncioTestCase when dropping Python < 3.8. + """ def __init_subclass__(cls, **kwargs): From eb86f67bce6942b8ec5e3fba68a5f468d01b6eaf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Oct 2022 08:48:20 +0200 Subject: [PATCH 1111/1539] Remove User-Agent/Server from the Sans-I/O layer. It doesn't make sense to set it at the library level. It should be set by the embedding program. Partially reverts 2a07325c. --- src/websockets/client.py | 9 --------- src/websockets/server.py | 12 ------------ tests/test_client.py | 25 ------------------------- tests/test_server.py | 31 ------------------------------- 4 files changed, 77 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index c18c76f22..373e6b751 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -23,7 +23,6 @@ parse_subprotocol, parse_upgrade, ) -from .http import USER_AGENT from .http11 import Request, Response from .typing import ( ConnectionOption, @@ -64,9 +63,6 @@ class ClientConnection(Connection): logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. - user_agent_header: value of the ``User-Agent`` request header; - defaults to ``"Python/x.y.z websockets/X.Y"``; - :obj:`None` removes the header. """ @@ -76,7 +72,6 @@ def __init__( origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - user_agent_header: Optional[str] = USER_AGENT, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -91,7 +86,6 @@ def __init__( self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols - self.user_agent_header = user_agent_header self.key = generate_key() def connect(self) -> Request: # noqa: F811 @@ -136,9 +130,6 @@ def connect(self) -> Request: # noqa: F811 protocol_header = build_subprotocol(self.available_subprotocols) headers["Sec-WebSocket-Protocol"] = protocol_header - if self.user_agent_header is not None: - headers["User-Agent"] = self.user_agent_header - return Request(self.wsuri.resource_name, headers) def process_response(self, response: Response) -> None: diff --git a/src/websockets/server.py b/src/websockets/server.py index ca48449d6..edd1764c3 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -25,7 +25,6 @@ parse_subprotocol, parse_upgrade, ) -from .http import USER_AGENT from .http11 import Request, Response from .typing import ( ConnectionOption, @@ -64,9 +63,6 @@ class ServerConnection(Connection): logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. - server_header: value of the ``Server`` response header; - defaults to ``"Python/x.y.z websockets/X.Y"``; - :obj:`None` removes the header. """ @@ -75,7 +71,6 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - server_header: Optional[str] = USER_AGENT, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -89,7 +84,6 @@ def __init__( self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols - self.server_header = server_header def accept(self, request: Request) -> Response: """ @@ -175,9 +169,6 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - if self.server_header is not None: - headers["Server"] = self.server_header - self.logger.info("connection open") return Response(101, "Switching Protocols", headers) @@ -477,9 +468,6 @@ def reject( ("Content-Type", "text/plain; charset=utf-8"), ] ) - if self.server_header is not None: - headers["Server"] = self.server_header - response = Response(status.value, status.phrase, headers, body) # When reject() is called from accept(), handshake_exc is already set. # If a user calls reject(), set handshake_exc to guarantee invariant: diff --git a/tests/test_client.py b/tests/test_client.py index 0504f79e6..9ed36c1d4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,7 +7,6 @@ from websockets.datastructures import Headers from websockets.exceptions import InvalidHandshake, InvalidHeader from websockets.frames import OP_TEXT, Frame -from websockets.http import USER_AGENT from websockets.http11 import Request, Response from websockets.uri import parse_uri from websockets.utils import accept_key @@ -38,7 +37,6 @@ def test_send_connect(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" - f"User-Agent: {USER_AGENT}\r\n" f"\r\n".encode() ], ) @@ -58,7 +56,6 @@ def test_connect_request(self): "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": "13", - "User-Agent": USER_AGENT, } ), ) @@ -130,7 +127,6 @@ def test_receive_accept(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" f"\r\n" ).encode(), ) @@ -148,7 +144,6 @@ def test_receive_reject(self): ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" @@ -173,7 +168,6 @@ def test_accept_response(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" f"\r\n" ).encode(), ) @@ -188,7 +182,6 @@ def test_accept_response(self): "Connection": "Upgrade", "Sec-WebSocket-Accept": ACCEPT, "Date": DATE, - "Server": USER_AGENT, } ), ) @@ -202,7 +195,6 @@ def test_reject_response(self): ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" @@ -218,7 +210,6 @@ def test_reject_response(self): Headers( { "Date": DATE, - "Server": USER_AGENT, "Content-Length": "13", "Content-Type": "text/plain; charset=utf-8", "Connection": "close", @@ -574,22 +565,6 @@ def test_unsupported_subprotocol(self): raise client.handshake_exc self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") - def test_no_user_agent_header(self): - client = ClientConnection( - parse_uri("wss://example.com/"), - user_agent_header=None, - ) - request = client.connect() - self.assertNotIn("User-Agent", request.headers) - - def test_custom_user_agent_header(self): - client = ClientConnection( - parse_uri("wss://example.com/"), - user_agent_header="websockets", - ) - request = client.connect() - self.assertEqual(request.headers["User-Agent"], "websockets") - class MiscTests(unittest.TestCase): def test_bypass_handshake(self): diff --git a/tests/test_server.py b/tests/test_server.py index c7f398cdf..f1404499b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,7 +7,6 @@ from websockets.datastructures import Headers from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade from websockets.frames import OP_TEXT, Frame -from websockets.http import USER_AGENT from websockets.http11 import Request, Response from websockets.server import * @@ -32,7 +31,6 @@ def test_receive_connect(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" - f"User-Agent: {USER_AGENT}\r\n" f"\r\n" ).encode(), ) @@ -51,7 +49,6 @@ def test_connect_request(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" - f"User-Agent: {USER_AGENT}\r\n" f"\r\n" ).encode(), ) @@ -66,7 +63,6 @@ def test_connect_request(self): "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": "13", - "User-Agent": USER_AGENT, } ), ) @@ -83,7 +79,6 @@ def make_request(self): "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": "13", - "User-Agent": USER_AGENT, } ), ) @@ -102,7 +97,6 @@ def test_send_accept(self): f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" - f"Server: {USER_AGENT}\r\n" f"\r\n".encode() ], ) @@ -123,7 +117,6 @@ def test_send_reject(self): f"Connection: close\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" - f"Server: {USER_AGENT}\r\n" f"\r\n" f"Sorry folks.\n".encode(), b"", @@ -147,7 +140,6 @@ def test_accept_response(self): "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": ACCEPT, - "Server": USER_AGENT, } ), ) @@ -168,7 +160,6 @@ def test_reject_response(self): "Connection": "close", "Content-Length": "13", "Content-Type": "text/plain; charset=utf-8", - "Server": USER_AGENT, } ), ) @@ -608,28 +599,6 @@ def test_unsupported_subprotocol(self): self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) - def test_no_server_header(self): - server = ServerConnection(server_header=None) - request = self.make_request() - response = server.accept(request) - self.assertNotIn("Server", response.headers) - - def test_custom_server_header(self): - server = ServerConnection(server_header="websockets") - request = self.make_request() - response = server.accept(request) - self.assertEqual(response.headers["Server"], "websockets") - - def test_reject_response_no_server_header(self): - server = ServerConnection(server_header=None) - response = server.reject(http.HTTPStatus.OK, "Hello world!\n") - self.assertNotIn("Server", response.headers) - - def test_reject_response_custom_server_header(self): - server = ServerConnection(server_header="websockets") - response = server.reject(http.HTTPStatus.OK, "Hello world!\n") - self.assertEqual(response.headers["Server"], "websockets") - class MiscTests(unittest.TestCase): def test_bypass_handshake(self): From 4ae6d7b2a21b02bffa52f2653d344b6da07413cb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 17:58:49 +0200 Subject: [PATCH 1112/1539] Include correct files in make coverage & tox -e coverage. --- Makefile | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 578f6b1ae..cd8095ae1 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ test: python -m unittest coverage: - coverage run --source websockets,tests -m unittest + coverage run --source src/websockets,tests -m unittest coverage html coverage report --show-missing --fail-under=100 diff --git a/tox.ini b/tox.ini index 3a284ed31..2f37cdcbf 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,7 @@ commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarni [testenv:coverage] commands = python -m coverage erase - python -W default -m coverage run -m unittest {posargs} + python -m coverage run --source {envsitepackagesdir}/websockets,tests -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 deps = coverage From 781f5b8c46070b74d1761c709b67820348eb978b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Oct 2022 09:09:33 +0200 Subject: [PATCH 1113/1539] Standardize access to State and Side enum values. Access them directly rather than as attributes of the State class. --- tests/test_connection.py | 395 +++++++++++++++++++-------------------- 1 file changed, 196 insertions(+), 199 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 3d4d98436..44f5bc35e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,6 +1,7 @@ import unittest.mock from websockets.connection import * +from websockets.connection import CLIENT, CLOSED, CLOSING, SERVER from websockets.exceptions import ( ConnectionClosedError, ConnectionClosedOK, @@ -38,7 +39,7 @@ def assertFrameSent(self, connection, frame, eof=False): if write is SEND_EOF else self.parse( write, - mask=connection.side is Side.CLIENT, + mask=connection.side is CLIENT, extensions=connection.extensions, ) for write in connection.data_to_send() @@ -71,9 +72,7 @@ def assertConnectionClosing(self, connection, code=None, reason=""): # A close frame was received. self.assertFrameReceived(connection, close_frame) # A close frame and possibly the end of stream were sent. - self.assertFrameSent( - connection, close_frame, eof=connection.side is Side.SERVER - ) + self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) def assertConnectionFailing(self, connection, code=None, reason=""): """ @@ -87,9 +86,7 @@ def assertConnectionFailing(self, connection, code=None, reason=""): # No frame was received. self.assertFrameReceived(connection, None) # A close frame and possibly the end of stream were sent. - self.assertFrameSent( - connection, close_frame, eof=connection.side is Side.SERVER - ) + self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) class MaskingTests(ConnectionTestCase): @@ -104,18 +101,18 @@ class MaskingTests(ConnectionTestCase): masked_text_frame_data = b"\x81\x84\x00\xff\x00\xff\x53\x8f\x61\x92" def test_client_sends_masked_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\xff\x00\xff"): client.send_text(b"Spam", True) self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) def test_server_sends_unmasked_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text(b"Spam", True) self.assertEqual(server.data_to_send(), [self.unmasked_text_frame_date]) def test_client_receives_unmasked_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(self.unmasked_text_frame_date) self.assertFrameReceived( client, @@ -123,7 +120,7 @@ def test_client_receives_unmasked_frame(self): ) def test_server_receives_masked_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(self.masked_text_frame_data) self.assertFrameReceived( server, @@ -131,14 +128,14 @@ def test_server_receives_masked_frame(self): ) def test_client_receives_masked_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(self.masked_text_frame_data) self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incorrect masking") self.assertConnectionFailing(client, 1002, "incorrect masking") def test_server_receives_unmasked_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(self.unmasked_text_frame_date) self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incorrect masking") @@ -152,33 +149,33 @@ class ContinuationTests(ConnectionTestCase): """ def test_client_sends_unexpected_continuation(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.assertRaises(ProtocolError) as raised: client.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_server_sends_unexpected_continuation(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_client_receives_unexpected_continuation(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x00\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "unexpected continuation frame") self.assertConnectionFailing(client, 1002, "unexpected continuation frame") def test_server_receives_unexpected_continuation(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x00\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "unexpected continuation frame") self.assertConnectionFailing(server, 1002, "unexpected continuation frame") def test_client_sends_continuation_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) # Since it isn't possible to send a close frame in a fragmented # message (see test_client_send_close_in_fragmented_message), in fact, # this is the same test as test_client_sends_unexpected_continuation. @@ -193,7 +190,7 @@ def test_server_sends_continuation_after_sending_close(self): # Since it isn't possible to send a close frame in a fragmented # message (see test_server_send_close_in_fragmented_message), in fact, # this is the same test as test_server_sends_unexpected_continuation. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(ProtocolError) as raised: @@ -201,7 +198,7 @@ def test_server_sends_continuation_after_sending_close(self): self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_client_receives_continuation_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x00\x00") @@ -209,7 +206,7 @@ def test_client_receives_continuation_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_continuation_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x00\x80\x00\xff\x00\xff") @@ -224,7 +221,7 @@ class TextTests(ConnectionTestCase): """ def test_client_sends_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_text("😀".encode()) self.assertEqual( @@ -232,12 +229,12 @@ def test_client_sends_text(self): ) def test_server_sends_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text("😀".encode()) self.assertEqual(server.data_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) def test_client_receives_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, @@ -245,7 +242,7 @@ def test_client_receives_text(self): ) def test_server_receives_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, @@ -253,21 +250,21 @@ def test_server_receives_text(self): ) def test_client_receives_text_over_size_limit(self): - client = Connection(Side.CLIENT, max_size=3) + client = Connection(CLIENT, max_size=3) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") def test_server_receives_text_over_size_limit(self): - server = Connection(Side.SERVER, max_size=3) + server = Connection(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") def test_client_receives_text_without_size_limit(self): - client = Connection(Side.CLIENT, max_size=None) + client = Connection(CLIENT, max_size=None) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, @@ -275,7 +272,7 @@ def test_client_receives_text_without_size_limit(self): ) def test_server_receives_text_without_size_limit(self): - server = Connection(Side.SERVER, max_size=None) + server = Connection(SERVER, max_size=None) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, @@ -283,7 +280,7 @@ def test_server_receives_text_without_size_limit(self): ) def test_client_sends_fragmented_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_text("😀".encode()[:2], fin=False) self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) @@ -297,7 +294,7 @@ def test_client_sends_fragmented_text(self): self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) def test_server_sends_fragmented_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text("😀".encode()[:2], fin=False) self.assertEqual(server.data_to_send(), [b"\x01\x02\xf0\x9f"]) server.send_continuation("😀😀".encode()[2:6], fin=False) @@ -306,7 +303,7 @@ def test_server_sends_fragmented_text(self): self.assertEqual(server.data_to_send(), [b"\x80\x02\x98\x80"]) def test_client_receives_fragmented_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, @@ -324,7 +321,7 @@ def test_client_receives_fragmented_text(self): ) def test_server_receives_fragmented_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, @@ -342,7 +339,7 @@ def test_server_receives_fragmented_text(self): ) def test_client_receives_fragmented_text_over_size_limit(self): - client = Connection(Side.CLIENT, max_size=3) + client = Connection(CLIENT, max_size=3) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, @@ -354,7 +351,7 @@ def test_client_receives_fragmented_text_over_size_limit(self): self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") def test_server_receives_fragmented_text_over_size_limit(self): - server = Connection(Side.SERVER, max_size=3) + server = Connection(SERVER, max_size=3) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, @@ -366,7 +363,7 @@ def test_server_receives_fragmented_text_over_size_limit(self): self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") def test_client_receives_fragmented_text_without_size_limit(self): - client = Connection(Side.CLIENT, max_size=None) + client = Connection(CLIENT, max_size=None) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, @@ -384,7 +381,7 @@ def test_client_receives_fragmented_text_without_size_limit(self): ) def test_server_receives_fragmented_text_without_size_limit(self): - server = Connection(Side.SERVER, max_size=None) + server = Connection(SERVER, max_size=None) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, @@ -402,21 +399,21 @@ def test_server_receives_fragmented_text_without_size_limit(self): ) def test_client_sends_unexpected_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_text(b"", fin=False) with self.assertRaises(ProtocolError) as raised: client.send_text(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_server_sends_unexpected_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text(b"", fin=False) with self.assertRaises(ProtocolError) as raised: server.send_text(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receives_unexpected_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x01\x00") self.assertFrameReceived( client, @@ -428,7 +425,7 @@ def test_client_receives_unexpected_text(self): self.assertConnectionFailing(client, 1002, "expected a continuation frame") def test_server_receives_unexpected_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertFrameReceived( server, @@ -440,7 +437,7 @@ def test_server_receives_unexpected_text(self): self.assertConnectionFailing(server, 1002, "expected a continuation frame") def test_client_sends_text_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) @@ -448,14 +445,14 @@ def test_client_sends_text_after_sending_close(self): client.send_text(b"") def test_server_sends_text_after_sending_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_text(b"") def test_client_receives_text_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x81\x00") @@ -463,7 +460,7 @@ def test_client_receives_text_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_text_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x81\x80\x00\xff\x00\xff") @@ -478,7 +475,7 @@ class BinaryTests(ConnectionTestCase): """ def test_client_sends_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02\xfe\xff") self.assertEqual( @@ -486,12 +483,12 @@ def test_client_sends_binary(self): ) def test_server_sends_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_binary(b"\x01\x02\xfe\xff") self.assertEqual(server.data_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) def test_client_receives_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertFrameReceived( client, @@ -499,7 +496,7 @@ def test_client_receives_binary(self): ) def test_server_receives_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertFrameReceived( server, @@ -507,21 +504,21 @@ def test_server_receives_binary(self): ) def test_client_receives_binary_over_size_limit(self): - client = Connection(Side.CLIENT, max_size=3) + client = Connection(CLIENT, max_size=3) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") def test_server_receives_binary_over_size_limit(self): - server = Connection(Side.SERVER, max_size=3) + server = Connection(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") def test_client_sends_fragmented_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02", fin=False) self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) @@ -535,7 +532,7 @@ def test_client_sends_fragmented_binary(self): self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) def test_server_sends_fragmented_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_binary(b"\x01\x02", fin=False) self.assertEqual(server.data_to_send(), [b"\x02\x02\x01\x02"]) server.send_continuation(b"\xee\xff\x01\x02", fin=False) @@ -544,7 +541,7 @@ def test_server_sends_fragmented_binary(self): self.assertEqual(server.data_to_send(), [b"\x80\x02\xee\xff"]) def test_client_receives_fragmented_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, @@ -562,7 +559,7 @@ def test_client_receives_fragmented_binary(self): ) def test_server_receives_fragmented_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, @@ -580,7 +577,7 @@ def test_server_receives_fragmented_binary(self): ) def test_client_receives_fragmented_binary_over_size_limit(self): - client = Connection(Side.CLIENT, max_size=3) + client = Connection(CLIENT, max_size=3) client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, @@ -592,7 +589,7 @@ def test_client_receives_fragmented_binary_over_size_limit(self): self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") def test_server_receives_fragmented_binary_over_size_limit(self): - server = Connection(Side.SERVER, max_size=3) + server = Connection(SERVER, max_size=3) server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, @@ -604,21 +601,21 @@ def test_server_receives_fragmented_binary_over_size_limit(self): self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") def test_client_sends_unexpected_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_binary(b"", fin=False) with self.assertRaises(ProtocolError) as raised: client.send_binary(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_server_sends_unexpected_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_binary(b"", fin=False) with self.assertRaises(ProtocolError) as raised: server.send_binary(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receives_unexpected_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x02\x00") self.assertFrameReceived( client, @@ -630,7 +627,7 @@ def test_client_receives_unexpected_binary(self): self.assertConnectionFailing(client, 1002, "expected a continuation frame") def test_server_receives_unexpected_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertFrameReceived( server, @@ -642,7 +639,7 @@ def test_server_receives_unexpected_binary(self): self.assertConnectionFailing(server, 1002, "expected a continuation frame") def test_client_sends_binary_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) @@ -650,14 +647,14 @@ def test_client_sends_binary_after_sending_close(self): client.send_binary(b"") def test_server_sends_binary_after_sending_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_binary(b"") def test_client_receives_binary_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x82\x00") @@ -665,7 +662,7 @@ def test_client_receives_binary_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_binary_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x82\x80\x00\xff\x00\xff") @@ -686,78 +683,78 @@ class CloseTests(ConnectionTestCase): """ def test_close_code(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") client.receive_eof() self.assertEqual(client.close_code, 1000) def test_close_reason(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") server.receive_eof() self.assertEqual(server.close_reason, "OK") def test_close_code_not_provided(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x00\x00\x00\x00") server.receive_eof() self.assertEqual(server.close_code, 1005) def test_close_reason_not_provided(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") client.receive_eof() self.assertEqual(client.close_reason, "") def test_close_code_not_available(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_eof() self.assertEqual(client.close_code, 1006) def test_close_reason_not_available(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_eof() self.assertEqual(server.close_reason, "") def test_close_code_not_available_yet(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) self.assertIsNone(server.close_code) def test_close_reason_not_available_yet(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) self.assertIsNone(client.close_reason) def test_client_sends_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_sends_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close() self.assertEqual(server.data_to_send(), [b"\x88\x00"]) - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_receives_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): client.receive_data(b"\x88\x00") self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_sends_close_then_receives_close(self): # Client-initiated close handshake on the client side. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_close() self.assertFrameReceived(client, None) @@ -773,7 +770,7 @@ def test_client_sends_close_then_receives_close(self): def test_server_sends_close_then_receives_close(self): # Server-initiated close handshake on the server side. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close() self.assertFrameReceived(server, None) @@ -789,7 +786,7 @@ def test_server_sends_close_then_receives_close(self): def test_client_receives_close_then_sends_close(self): # Server-initiated close handshake on the client side. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) @@ -801,7 +798,7 @@ def test_client_receives_close_then_sends_close(self): def test_server_receives_close_then_sends_close(self): # Client-initiated close handshake on the server side. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) @@ -812,87 +809,87 @@ def test_server_receives_close_then_sends_close(self): self.assertFrameSent(server, None) def test_client_sends_close_with_code(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_sends_close_with_code(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_receives_close_with_code(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000, "") - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close_with_code(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001, "") - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_sends_close_with_code_and_reason(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001, "going away") self.assertEqual( client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] ) - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_sends_close_with_code_and_reason(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000, "OK") self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_receives_close_with_code_and_reason(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") self.assertConnectionClosing(client, 1000, "OK") - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close_with_code_and_reason(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") self.assertConnectionClosing(server, 1001, "going away") - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_sends_close_with_reason_only(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.assertRaises(ProtocolError) as raised: client.send_close(reason="going away") self.assertEqual(str(raised.exception), "cannot send a reason without a code") def test_server_sends_close_with_reason_only(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) with self.assertRaises(ProtocolError) as raised: server.send_close(reason="OK") self.assertEqual(str(raised.exception), "cannot send a reason without a code") def test_client_receives_close_with_truncated_code(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x01\x03") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "close frame too short") self.assertConnectionFailing(client, 1002, "close frame too short") - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close_with_truncated_code(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "close frame too short") self.assertConnectionFailing(server, 1002, "close frame too short") - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_receives_close_with_non_utf8_reason(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x04\x03\xe8\xff\xff") self.assertIsInstance(client.parser_exc, UnicodeDecodeError) @@ -901,10 +898,10 @@ def test_client_receives_close_with_non_utf8_reason(self): "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close_with_non_utf8_reason(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") self.assertIsInstance(server.parser_exc, UnicodeDecodeError) @@ -913,7 +910,7 @@ def test_server_receives_close_with_non_utf8_reason(self): "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) class PingTests(ConnectionTestCase): @@ -923,18 +920,18 @@ class PingTests(ConnectionTestCase): """ def test_client_sends_ping(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_ping(b"") self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x89\x00") self.assertFrameReceived( client, @@ -946,7 +943,7 @@ def test_client_receives_ping(self): ) def test_server_receives_ping(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x89\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, @@ -958,7 +955,7 @@ def test_server_receives_ping(self): ) def test_client_sends_ping_with_data(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"\x22\x66\xaa\xee") self.assertEqual( @@ -966,12 +963,12 @@ def test_client_sends_ping_with_data(self): ) def test_server_sends_ping_with_data(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_ping(b"\x22\x66\xaa\xee") self.assertEqual(server.data_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) def test_client_receives_ping_with_data(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, @@ -983,7 +980,7 @@ def test_client_receives_ping_with_data(self): ) def test_server_receives_ping_with_data(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, @@ -995,35 +992,35 @@ def test_server_receives_ping_with_data(self): ) def test_client_sends_fragmented_ping_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: client.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_ping_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: server.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_ping_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x09\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing(client, 1002, "fragmented control frame") def test_server_receives_fragmented_ping_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing(server, 1002, "fragmented control frame") def test_client_sends_ping_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) @@ -1038,7 +1035,7 @@ def test_client_sends_ping_after_sending_close(self): ) def test_server_sends_ping_after_sending_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) @@ -1052,7 +1049,7 @@ def test_server_sends_ping_after_sending_close(self): ) def test_client_receives_ping_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") @@ -1060,7 +1057,7 @@ def test_client_receives_ping_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_ping_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") @@ -1075,18 +1072,18 @@ class PongTests(ConnectionTestCase): """ def test_client_sends_pong(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"") self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_pong(b"") self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x8a\x00") self.assertFrameReceived( client, @@ -1094,7 +1091,7 @@ def test_client_receives_pong(self): ) def test_server_receives_pong(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, @@ -1102,7 +1099,7 @@ def test_server_receives_pong(self): ) def test_client_sends_pong_with_data(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"\x22\x66\xaa\xee") self.assertEqual( @@ -1110,12 +1107,12 @@ def test_client_sends_pong_with_data(self): ) def test_server_sends_pong_with_data(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_pong(b"\x22\x66\xaa\xee") self.assertEqual(server.data_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) def test_client_receives_pong_with_data(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, @@ -1123,7 +1120,7 @@ def test_client_receives_pong_with_data(self): ) def test_server_receives_pong_with_data(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, @@ -1131,35 +1128,35 @@ def test_server_receives_pong_with_data(self): ) def test_client_sends_fragmented_pong_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: client.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_pong_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: server.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_pong_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x0a\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing(client, 1002, "fragmented control frame") def test_server_receives_fragmented_pong_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing(server, 1002, "fragmented control frame") def test_client_sends_pong_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) @@ -1168,7 +1165,7 @@ def test_client_sends_pong_after_sending_close(self): client.send_pong(b"") def test_server_sends_pong_after_sending_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # websockets doesn't support sending a Pong frame after a Close frame. @@ -1176,7 +1173,7 @@ def test_server_sends_pong_after_sending_close(self): server.send_pong(b"") def test_client_receives_pong_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") @@ -1184,7 +1181,7 @@ def test_client_receives_pong_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_pong_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") @@ -1201,14 +1198,14 @@ class FailTests(ConnectionTestCase): """ def test_client_stops_processing_frames_after_fail(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.fail(1002) self.assertConnectionFailing(client, 1002) client.receive_data(b"\x88\x02\x03\xea") self.assertFrameReceived(client, None) def test_server_stops_processing_frames_after_fail(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.fail(1002) self.assertConnectionFailing(server, 1002) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") @@ -1224,7 +1221,7 @@ class FragmentationTests(ConnectionTestCase): """ def test_client_send_ping_pong_in_fragmented_message(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) client.send_ping(b"Ping") @@ -1237,7 +1234,7 @@ def test_client_send_ping_pong_in_fragmented_message(self): self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) def test_server_send_ping_pong_in_fragmented_message(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) server.send_ping(b"Ping") @@ -1250,7 +1247,7 @@ def test_server_send_ping_pong_in_fragmented_message(self): self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) def test_client_receive_ping_pong_in_fragmented_message(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, @@ -1282,7 +1279,7 @@ def test_client_receive_ping_pong_in_fragmented_message(self): ) def test_server_receive_ping_pong_in_fragmented_message(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, @@ -1314,7 +1311,7 @@ def test_server_receive_ping_pong_in_fragmented_message(self): ) def test_client_send_close_in_fragmented_message(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) # The spec says: "An endpoint MUST be capable of handling control @@ -1327,7 +1324,7 @@ def test_client_send_close_in_fragmented_message(self): client.send_continuation(b"Eggs", fin=True) def test_server_send_close_in_fragmented_message(self): - server = Connection(Side.CLIENT) + server = Connection(CLIENT) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) # The spec says: "An endpoint MUST be capable of handling control @@ -1339,7 +1336,7 @@ def test_server_send_close_in_fragmented_message(self): self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receive_close_in_fragmented_message(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, @@ -1355,7 +1352,7 @@ def test_client_receive_close_in_fragmented_message(self): self.assertConnectionFailing(client, 1002, "incomplete fragmented message") def test_server_receive_close_in_fragmented_message(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, @@ -1378,35 +1375,35 @@ class EOFTests(ConnectionTestCase): """ def test_client_receives_eof(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() - self.assertIs(client.state, State.CLOSED) + self.assertIs(client.state, CLOSED) def test_server_receives_eof(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() - self.assertIs(server.state, State.CLOSED) + self.assertIs(server.state, CLOSED) def test_client_receives_eof_between_frames(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual(str(client.parser_exc), "unexpected end of stream") - self.assertIs(client.state, State.CLOSED) + self.assertIs(client.state, CLOSED) def test_server_receives_eof_between_frames(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual(str(server.parser_exc), "unexpected end of stream") - self.assertIs(server.state, State.CLOSED) + self.assertIs(server.state, CLOSED) def test_client_receives_eof_inside_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x81") client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) @@ -1414,10 +1411,10 @@ def test_client_receives_eof_inside_frame(self): str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) - self.assertIs(client.state, State.CLOSED) + self.assertIs(client.state, CLOSED) def test_server_receives_eof_inside_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x81") server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) @@ -1425,38 +1422,38 @@ def test_server_receives_eof_inside_frame(self): str(server.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) - self.assertIs(server.state, State.CLOSED) + self.assertIs(server.state, CLOSED) def test_client_receives_data_after_exception(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_data(b"\x00\x00") self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_data(b"\x00\x00") self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_eof() self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_eof() self.assertFrameSent(server, None) def test_client_receives_data_and_eof_after_exception(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_data(b"\x00\x00") @@ -1464,7 +1461,7 @@ def test_client_receives_data_and_eof_after_exception(self): self.assertFrameSent(client, None, eof=True) def test_server_receives_data_and_eof_after_exception(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_data(b"\x00\x00") @@ -1472,7 +1469,7 @@ def test_server_receives_data_and_eof_after_exception(self): self.assertFrameSent(server, None) def test_client_receives_data_after_eof(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() @@ -1481,7 +1478,7 @@ def test_client_receives_data_after_eof(self): self.assertEqual(str(raised.exception), "stream ended") def test_server_receives_data_after_eof(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() @@ -1490,7 +1487,7 @@ def test_server_receives_data_after_eof(self): self.assertEqual(str(raised.exception), "stream ended") def test_client_receives_eof_after_eof(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() @@ -1499,7 +1496,7 @@ def test_client_receives_eof_after_eof(self): self.assertEqual(str(raised.exception), "stream ended") def test_server_receives_eof_after_eof(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() @@ -1515,52 +1512,52 @@ class TCPCloseTests(ConnectionTestCase): """ def test_client_default(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) self.assertFalse(client.close_expected()) def test_server_default(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) self.assertFalse(server.close_expected()) def test_client_sends_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_close() self.assertTrue(client.close_expected()) def test_server_sends_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close() self.assertTrue(server.close_expected()) def test_client_receives_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertTrue(client.close_expected()) def test_client_receives_close_then_eof(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") client.receive_eof() self.assertFalse(client.close_expected()) def test_server_receives_close_then_eof(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") server.receive_eof() self.assertFalse(server.close_expected()) def test_server_receives_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertTrue(server.close_expected()) def test_client_fails_connection(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.fail(1002) self.assertTrue(client.close_expected()) def test_server_fails_connection(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.fail(1002) self.assertTrue(server.close_expected()) @@ -1573,7 +1570,7 @@ class ConnectionClosedTests(ConnectionTestCase): def test_client_sends_close_then_receives_close(self): # Client-initiated close handshake on the client side complete. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_close(1000, "") client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() @@ -1585,7 +1582,7 @@ def test_client_sends_close_then_receives_close(self): def test_server_sends_close_then_receives_close(self): # Server-initiated close handshake on the server side complete. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000, "") server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() @@ -1597,7 +1594,7 @@ def test_server_sends_close_then_receives_close(self): def test_client_receives_close_then_sends_close(self): # Server-initiated close handshake on the client side complete. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() exc = client.close_exc @@ -1608,7 +1605,7 @@ def test_client_receives_close_then_sends_close(self): def test_server_receives_close_then_sends_close(self): # Client-initiated close handshake on the server side complete. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() exc = server.close_exc @@ -1619,7 +1616,7 @@ def test_server_receives_close_then_sends_close(self): def test_client_sends_close_then_receives_eof(self): # Client-initiated close handshake on the client side times out. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_close(1000, "") client.receive_eof() exc = client.close_exc @@ -1630,7 +1627,7 @@ def test_client_sends_close_then_receives_eof(self): def test_server_sends_close_then_receives_eof(self): # Server-initiated close handshake on the server side times out. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000, "") server.receive_eof() exc = server.close_exc @@ -1641,7 +1638,7 @@ def test_server_sends_close_then_receives_eof(self): def test_client_receives_eof(self): # Server-initiated close handshake on the client side times out. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) @@ -1651,7 +1648,7 @@ def test_client_receives_eof(self): def test_server_receives_eof(self): # Client-initiated close handshake on the server side times out. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) @@ -1667,7 +1664,7 @@ class ErrorTests(ConnectionTestCase): """ def test_client_hits_internal_error_reading_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) # This isn't supposed to happen, so we're simulating it. with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): client.receive_data(b"\x81\x00") @@ -1676,7 +1673,7 @@ def test_client_hits_internal_error_reading_frame(self): self.assertConnectionFailing(client, 1011, "") def test_server_hits_internal_error_reading_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) # This isn't supposed to happen, so we're simulating it. with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): server.receive_data(b"\x81\x80\x00\x00\x00\x00") @@ -1692,26 +1689,26 @@ class ExtensionsTests(ConnectionTestCase): """ def test_client_extension_encodes_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.extensions = [Rsv2Extension()] with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) def test_server_extension_encodes_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.extensions = [Rsv2Extension()] server.send_ping(b"") self.assertEqual(server.data_to_send(), [b"\xa9\x00"]) def test_client_extension_decodes_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.extensions = [Rsv2Extension()] client.receive_data(b"\xaa\x00") self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) def test_server_extension_decodes_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.extensions = [Rsv2Extension()] server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) From 413ae6ee4e75c03a790400f24d691e3e6badeb45 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Oct 2022 11:24:32 +0200 Subject: [PATCH 1114/1539] Mark function for future removal. --- tests/legacy/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index bc37ee7bf..302fd70be 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -57,6 +57,7 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() + # Remove when dropping Python < 3.10 @contextlib.contextmanager def assertNoLogs(self, logger="websockets", level=logging.ERROR): """ From 42ce29ab423233c936e59c0502c44549ec2bdf87 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 17 Aug 2022 20:17:34 +0200 Subject: [PATCH 1115/1539] Add Python 3.11. --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4785e8d20..4c1b1e899 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,6 +43,7 @@ jobs: - "3.8" - "3.9" - "3.10" + - "3.11" - "pypy-3.7" - "pypy-3.8" - "pypy-3.9" From d160c1bc51ea83700d6de2d7da928aba3a034b52 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 15 Oct 2022 05:18:59 +0000 Subject: [PATCH 1116/1539] Bump pypa/cibuildwheel from 2.10.0 to 2.11.1 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.10.0 to 2.11.1. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.10.0...v2.11.1) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index cc22cd4c6..0013ad103 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -48,7 +48,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.10.0 + uses: pypa/cibuildwheel@v2.11.1 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From 9230cca4b5e6e100a57407ac61dbf20008b7225a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Oct 2022 21:48:26 +0200 Subject: [PATCH 1117/1539] Complete changelog for 10.4. --- docs/project/changelog.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8ad7bc5ab..c641a050e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -44,6 +44,13 @@ New features * Supported overriding or removing the ``User-Agent`` header in clients and the ``Server`` header in servers. +* Added deployment guides for more Platform as a Service providers. + +Improvements +............ + +* Improved FAQ. + 10.3 ---- From d8c17625857797db029110d74417ae5f840aeb75 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Oct 2022 21:39:56 +0200 Subject: [PATCH 1118/1539] Release version 10.4 --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c641a050e..8ad509a08 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.4 ---- -*In development* +*October 25, 2022* New features ............ diff --git a/src/websockets/version.py b/src/websockets/version.py index 29d658ce2..3fdc4fad9 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "10.4" From b2847c2786a88c8c6f6017081f828032fb42ded3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Oct 2022 22:04:09 +0200 Subject: [PATCH 1119/1539] Start version 11.0 --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8ad509a08..61136a36c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +11.0 +---- + +*In development* + 10.4 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index 3fdc4fad9..1a3d884d8 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "10.4" +tag = version = commit = "11.0" if not released: # pragma: no cover From 06ffba57f9460ff577eada5ad3fe2558593be4b9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Oct 2022 22:42:28 +0200 Subject: [PATCH 1120/1539] Fix deprecation warning. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a300ce628..9126f55d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ python-tag = py37.py38.py39.py310.py311 [metadata] -license_file = LICENSE +license_files = LICENSE project_urls = Changelog = https://websockets.readthedocs.io/en/stable/project/changelog.html Documentation = https://websockets.readthedocs.io/ From 2cbb8134acfa8b56aba5c40f0479f12434104a62 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 22 Aug 2022 21:52:48 +0200 Subject: [PATCH 1121/1539] Add option keep connections open when closing server. Fix #1174. --- docs/faq/server.rst | 15 +++++++ docs/project/changelog.rst | 5 +++ src/websockets/legacy/server.py | 69 +++++++++++++++++------------- tests/legacy/test_client_server.py | 17 +++++++- 4 files changed, 74 insertions(+), 32 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 68490d755..22a9d3a4c 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -260,6 +260,21 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/shutdown_server.py :emphasize-lines: 12-15,18 +How do I stop a server while keeping existing connections open? +--------------------------------------------------------------- + +Call the server's :meth:`~server.WebSocketServer.close` method with +``close_connections=False``. + +Here's how to adapt the example just above:: + + async def server(): + ... + + server = await websockets.serve(echo, "localhost", 8765) + await stop + await server.close(close_connections=False) + How do I run HTTP and WebSocket servers on the same port? --------------------------------------------------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 61136a36c..7851db80d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,11 @@ They may change at any time. *In development* +New features +............ + +* Made it possible to close a server without closing existing connections. + 10.4 ---- diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c32c85612..89e5322bc 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -730,26 +730,30 @@ def unregister(self, protocol: WebSocketServerProtocol) -> None: """ self.websockets.remove(protocol) - def close(self) -> None: + def close(self, close_connections: bool = True) -> None: """ Close the server. - This method: + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: - * closes the underlying :class:`~asyncio.Server`; - * rejects new WebSocket connections with an HTTP 503 (service - unavailable) error; this happens when the server accepted the TCP - connection but didn't complete the WebSocket opening handshake prior - to closing; - * closes open WebSocket connections with close code 1001 (going away). + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. :meth:`close` is idempotent. """ if self.close_task is None: - self.close_task = self.get_loop().create_task(self._close()) + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) - async def _close(self) -> None: + async def _close(self, close_connections: bool) -> None: """ Implementation of :meth:`close`. @@ -770,21 +774,22 @@ async def _close(self) -> None: # register(). See https://bugs.python.org/issue34852 for details. await asyncio.sleep(0, **loop_if_py_lt_38(self.get_loop())) - # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING connections with a HTTP 503 - # error. Wait until all connections are closed. - - close_tasks = [ - asyncio.create_task(websocket.close(1001)) - for websocket in self.websockets - if websocket.state is not State.CONNECTING - ] - # asyncio.wait doesn't accept an empty first argument. - if close_tasks: - await asyncio.wait( - close_tasks, - **loop_if_py_lt_38(self.get_loop()), - ) + if close_connections: + # Close OPEN connections with status code 1001. Since the server was + # closed, handshake() closes OPENING connections with a HTTP 503 + # error. Wait until all connections are closed. + + close_tasks = [ + asyncio.create_task(websocket.close(1001)) + for websocket in self.websockets + if websocket.state is not State.CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait( + close_tasks, + **loop_if_py_lt_38(self.get_loop()), + ) # Wait until all connection handlers are complete. @@ -903,18 +908,22 @@ class Serve: server performs the closing handshake and closes the connection. Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object - provides :meth:`~WebSocketServer.close` and - :meth:`~WebSocketServer.wait_closed` methods for shutting down the server. + provides a :meth:`~WebSocketServer.close` method to shut down the server:: + + stop = asyncio.Future() # set this future to exit the server + + server = await serve(...) + await stop + await server.close() - :func:`serve` can be used as an asynchronous context manager:: + :func:`serve` can be used as an asynchronous context manager. Then, the + server is shut down automatically when exiting the context:: stop = asyncio.Future() # set this future to exit the server async with serve(...): await stop - The server is shut down automatically when exiting the context. - Args: ws_handler: connection handler. It receives the WebSocket connection, which is a :class:`WebSocketServerProtocol`, in argument. diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index d02daede1..9db15c0f1 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1214,14 +1214,27 @@ def test_server_shuts_down_during_connection_handling(self): server_ws = next(iter(self.server.websockets)) self.server.close() with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) - # Websocket connection closes properly with 1001 Going Away. + # Server closed the connection with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) self.assertEqual(server_ws.close_code, 1001) @with_server() - def test_server_shuts_down_waits_until_handlers_terminate(self): + def test_server_shuts_down_gracefully_during_connection_handling(self): + with self.temp_client(): + server_ws = next(iter(self.server.websockets)) + self.server.close(close_connections=False) + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + # Client closed the connection with 1000 OK. + self.assertEqual(self.client.close_code, 1000) + self.assertEqual(server_ws.close_code, 1000) + + @with_server() + def test_server_shuts_down_and_waits_until_handlers_terminate(self): # This handler waits a bit after the connection is closed in order # to test that wait_closed() really waits for handlers to complete. self.start_client("/slow_stop") From 7a398921d8b41094ad4e13940c91a69db2bcfb1f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 26 Oct 2022 16:39:42 +0200 Subject: [PATCH 1122/1539] Reduce usage of # pragma: no cover. --- setup.cfg | 1 + src/websockets/connection.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 9126f55d0..74efc8042 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,3 +32,4 @@ source = exclude_lines = if self.debug: pragma: no cover + raise AssertionError diff --git a/src/websockets/connection.py b/src/websockets/connection.py index db8b53699..7a9be9f2e 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -568,7 +568,7 @@ def parse(self) -> Generator[None, None, None]: # During an abnormal closure, execution ends here after catching an # exception. At this point, fail() replaced parse() by discard(). yield - raise AssertionError("parse() shouldn't step after error") # pragma: no cover + raise AssertionError("parse() shouldn't step after error") def discard(self) -> Generator[None, None, None]: """ @@ -598,7 +598,7 @@ def discard(self) -> Generator[None, None, None]: yield # Once the reader reaches EOF, its feed_data/eof() methods raise an # error, so our receive_data/eof() methods don't step the generator. - raise AssertionError("discard() shouldn't step after EOF") # pragma: no cover + raise AssertionError("discard() shouldn't step after EOF") def recv_frame(self, frame: Frame) -> None: """ @@ -674,7 +674,7 @@ def recv_frame(self, frame: Frame) -> None: self.parser = self.discard() next(self.parser) # start coroutine - else: # pragma: no cover + else: # This can't happen because Frame.parse() validates opcodes. raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") From c19506b34dbf2bcb796dbe370ed41c269436ce92 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 26 Oct 2022 16:40:23 +0200 Subject: [PATCH 1123/1539] Ignore tests that no longer run on Python 3.11. --- tests/legacy/test_client_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 9db15c0f1..3004c685d 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1371,7 +1371,7 @@ def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): @unittest.skipIf( sys.version_info[:2] >= (3, 11), "asyncio.coroutine has been removed in Python 3.11" ) -class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): +class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): # pragma: no cover @with_server() def test_client(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 From d8d2ad5a824f155b1b44c9750ebf0ece792df1cc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Oct 2022 22:42:44 +0100 Subject: [PATCH 1124/1539] Raise coverage of websockets.exceptions to 100%. Add tests missing from 62eb267c. --- src/websockets/exceptions.py | 8 ++++++-- tests/test_exceptions.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 0c4fc5185..291b6d2cd 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -120,11 +120,15 @@ def __str__(self) -> str: @property def code(self) -> int: - return 1006 if self.rcvd is None else self.rcvd.code + if self.rcvd is None: + return 1006 + return self.rcvd.code @property def reason(self) -> str: - return "" if self.rcvd is None else self.rcvd.reason + if self.rcvd is None: + return "" + return self.rcvd.reason class ConnectionClosedError(ConnectionClosed): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3ede25fdb..a0f9dfcd2 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -157,3 +157,13 @@ def test_str(self): ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) + + def test_connection_closed_attributes_backwards_compatibility(self): + exception = ConnectionClosed(Close(1000, "OK"), None, None) + self.assertEqual(exception.code, 1000) + self.assertEqual(exception.reason, "OK") + + def test_connection_closed_attributes_backwards_compatibility_defaults(self): + exception = ConnectionClosed(None, None, None) + self.assertEqual(exception.code, 1006) + self.assertEqual(exception.reason, "") From ca4968eb607a3a1b763e82d3b683a12b0eba67d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Oct 2022 22:43:28 +0100 Subject: [PATCH 1125/1539] Raise coverage of websockets.connection to 100%. Add tests for the logger argument. --- tests/test_connection.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_connection.py b/tests/test_connection.py index 44f5bc35e..3858d2521 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,3 +1,4 @@ +import logging import unittest.mock from websockets.connection import * @@ -1712,3 +1713,25 @@ def test_server_extension_decodes_frame(self): server.extensions = [Rsv2Extension()] server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) + + +class MiscTests(unittest.TestCase): + def test_client_default_logger(self): + client = Connection(CLIENT) + logger = logging.getLogger("websockets.client") + self.assertIs(client.logger, logger) + + def test_server_default_logger(self): + server = Connection(SERVER) + logger = logging.getLogger("websockets.server") + self.assertIs(server.logger, logger) + + def test_client_custom_logger(self): + logger = logging.getLogger("test") + client = Connection(CLIENT, logger=logger) + self.assertIs(client.logger, logger) + + def test_server_custom_logger(self): + logger = logging.getLogger("test") + server = Connection(SERVER, logger=logger) + self.assertIs(server.logger, logger) From 892c86a017379cb57dd49ea03c142561917d353b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Oct 2022 22:44:28 +0100 Subject: [PATCH 1126/1539] Move all test extensions to the same module. Some were in test_base.py, others in utils.py. --- tests/extensions/test_base.py | 36 --------------------- tests/extensions/test_permessage_deflate.py | 2 +- tests/extensions/utils.py | 35 ++++++++++++++++++++ tests/legacy/test_client_server.py | 2 +- 4 files changed, 37 insertions(+), 38 deletions(-) diff --git a/tests/extensions/test_base.py b/tests/extensions/test_base.py index 0daa34211..ba8657b65 100644 --- a/tests/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,40 +1,4 @@ -from websockets.exceptions import NegotiationError from websockets.extensions.base import * # noqa # Abstract classes don't provide any behavior to test. - - -class ClientNoOpExtensionFactory: - name = "x-no-op" - - def get_request_params(self): - return [] - - def process_response_params(self, params, accepted_extensions): - if params: - raise NegotiationError() - return NoOpExtension() - - -class ServerNoOpExtensionFactory: - name = "x-no-op" - - def __init__(self, params=None): - self.params = params or [] - - def process_request_params(self, params, accepted_extensions): - return self.params, NoOpExtension() - - -class NoOpExtension: - name = "x-no-op" - - def __repr__(self): - return "NoOpExtension()" - - def decode(self, frame, *, max_size=None): - return frame - - def encode(self, frame): - return frame diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 3cc9172df..d433762c5 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -21,7 +21,7 @@ Frame, ) -from .test_base import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory +from .utils import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory class ExtensionTestsMixin: diff --git a/tests/extensions/utils.py b/tests/extensions/utils.py index 1eabc163f..24fb74b4e 100644 --- a/tests/extensions/utils.py +++ b/tests/extensions/utils.py @@ -46,6 +46,41 @@ def process_request_params(self, params, accepted_extensions): return [("op", self.op)], OpExtension(self.op) +class NoOpExtension: + name = "x-no-op" + + def __repr__(self): + return "NoOpExtension()" + + def decode(self, frame, *, max_size=None): + return frame + + def encode(self, frame): + return frame + + +class ClientNoOpExtensionFactory: + name = "x-no-op" + + def get_request_params(self): + return [] + + def process_response_params(self, params, accepted_extensions): + if params: + raise NegotiationError() + return NoOpExtension() + + +class ServerNoOpExtensionFactory: + name = "x-no-op" + + def __init__(self, params=None): + self.params = params or [] + + def process_request_params(self, params, accepted_extensions): + return self.params, NoOpExtension() + + class Rsv2Extension: name = "x-rsv2" diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 3004c685d..8ee7187eb 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -37,7 +37,7 @@ from websockets.legacy.server import * from websockets.uri import parse_uri -from ..extensions.test_base import ( +from ..extensions.utils import ( ClientNoOpExtensionFactory, NoOpExtension, ServerNoOpExtensionFactory, From 1948936b1af6aba31fa551dd93248f6b6f4db60e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Oct 2022 22:12:21 +0100 Subject: [PATCH 1127/1539] Add a script to measure coverage per module. This makes it possible to increase coverage threshold to "each module has 100% branch coverage from its own tests". This implementation supports `make maxi_cov` and `tox -e maxi_cov`. It is also enabled in CI. --- .github/workflows/tests.yml | 23 +++++- Makefile | 7 +- setup.cfg | 4 +- tests/maxi_cov.py | 148 ++++++++++++++++++++++++++++++++++++ tox.ini | 6 ++ 5 files changed, 183 insertions(+), 5 deletions(-) create mode 100755 tests/maxi_cov.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4c1b1e899..4d5cc3cd0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,8 +9,8 @@ on: - main jobs: - main: - name: Run code quality checks + coverage: + name: Run test coverage checks runs-on: ubuntu-latest steps: - name: Check out repository @@ -23,6 +23,21 @@ jobs: run: pip install tox - name: Run tests with coverage run: tox -e coverage + - name: Run tests with per-module coverage + run: tox -e maxi_cov + + quality: + name: Run code quality checks + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v3 + - name: Install Python 3.x + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install tox + run: pip install tox - name: Check code formatting run: tox -e black - name: Check code style @@ -34,7 +49,9 @@ jobs: matrix: name: Run tests on Python ${{ matrix.python }} - needs: main + needs: + - coverage + - quality runs-on: ubuntu-latest strategy: matrix: diff --git a/Makefile b/Makefile index cd8095ae1..ac5d6a4aa 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: default style test coverage build clean +.PHONY: default style test coverage maxi_cov build clean export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src @@ -20,6 +20,11 @@ coverage: coverage html coverage report --show-missing --fail-under=100 +maxi_cov: + python tests/maxi_cov.py + coverage html + coverage report --show-missing --fail-under=100 + build: python setup.py build_ext --inplace diff --git a/setup.cfg b/setup.cfg index 74efc8042..4c1f6091f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,7 +21,9 @@ lines_after_imports = 2 [coverage:run] branch = True omit = - */__main__.py + # */websockets matches src/websockets and .tox/**/site-packages/websockets + */websockets/__main__.py + tests/maxi_cov.py [coverage:paths] source = diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py new file mode 100755 index 000000000..5acb979de --- /dev/null +++ b/tests/maxi_cov.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python + +"""Measure coverage of each module by its test module.""" + +import glob +import os.path +import subprocess +import sys + + +UNMAPPED_SRC_FILES = ["websockets/version.py"] +UNMAPPED_TEST_FILES = ["tests/test_exports.py"] + +IGNORED_FILES = [ + # */websockets matches src/websockets and .tox/**/site-packages/websockets. + # There are no tests for the __main__ module. + "*/websockets/__main__.py", + # This approach isn't applicable to the test suite of the legacy + # implementation, due to the huge test_client_server test module. + "*/websockets/legacy/*", + "tests/legacy/*", + # Test utilities don't fit anywhere because they are shared. + "tests/extensions/utils.py", + "tests/utils.py", + # There is no point measure the coverage of this script. + "tests/maxi_cov.py", +] + + +def check_environment(): + """Check that prerequisites for running this script are met.""" + try: + import websockets # noqa: F401 + except ImportError: + print("failed to import websockets; is src on PYTHONPATH?") + return False + try: + import coverage # noqa: F401 + except ImportError: + print("failed to locate Coverage.py; is it installed?") + return False + return True + + +def get_mapping(src_dir="src"): + """Return a dict mapping each source file to its test file.""" + + # List source and test files. + + src_files = glob.glob( + os.path.join(src_dir, "websockets/**/*.py"), + recursive=True, + ) + + test_files = glob.glob( + "tests/**/*.py", + recursive=True, + ) + + src_files = [ + os.path.relpath(src_file, src_dir) + for src_file in sorted(src_files) + if os.path.basename(src_file) != "__init__.py" + and os.path.basename(src_file) != "__main__.py" + and "legacy" not in os.path.dirname(src_file) + ] + test_files = [ + test_file + for test_file in sorted(test_files) + if os.path.basename(test_file) != "__init__.py" + and os.path.basename(test_file).startswith("test_") + and "legacy" not in os.path.dirname(test_file) + ] + + # Map source files to test files. + + mapping = {} + unmapped_test_files = [] + + for test_file in test_files: + dir_name, file_name = os.path.split(test_file) + assert dir_name.startswith("tests") + assert file_name.startswith("test_") + src_file = os.path.join( + "websockets" + dir_name[len("tests") :], + file_name[len("test_") :], + ) + if src_file in src_files: + mapping[src_file] = test_file + else: + unmapped_test_files.append(test_file) + + unmapped_src_files = list(set(src_files) - set(mapping)) + + # Ensure that all files are mapped. + + assert unmapped_src_files == UNMAPPED_SRC_FILES + assert unmapped_test_files == UNMAPPED_TEST_FILES + + return mapping + + +def run_coverage(mapping, src_dir="src"): + # Initialize a new coverage measurement session. The --source option + # includes all files in the report, even if they're never imported. + print("\nInitializing session\n", flush=True) + subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + "--source", + ",".join([os.path.join(src_dir, "websockets"), "tests"]), + "--omit", + ",".join(IGNORED_FILES), + "-m", + "unittest", + ] + + UNMAPPED_TEST_FILES, + check=True, + ) + # Append coverage of each source module by the corresponding test module. + for src_file, test_file in mapping.items(): + print(f"\nTesting {src_file} with {test_file}\n", flush=True) + subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + "--append", + "--include", + ",".join([os.path.join(src_dir, src_file), test_file]), + "-m", + "unittest", + test_file, + ], + check=True, + ) + + +if __name__ == "__main__": + if not check_environment(): + sys.exit(1) + src_dir = sys.argv[1] if len(sys.argv) == 2 else "src" + mapping = get_mapping(src_dir) + run_coverage(mapping, src_dir) diff --git a/tox.ini b/tox.ini index 2f37cdcbf..0fcab4d79 100644 --- a/tox.ini +++ b/tox.ini @@ -21,6 +21,12 @@ commands = python -m coverage report --show-missing --fail-under=100 deps = coverage +[testenv:maxi_cov] +commands = + python tests/maxi_cov.py {envsitepackagesdir} + python -m coverage report --show-missing --fail-under=100 +deps = coverage + [testenv:black] commands = black --check src tests deps = black From 17d5d41140aad185e543092f98d55ab87e7c2f29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Nov 2022 07:50:40 +0100 Subject: [PATCH 1128/1539] Determine excluded test files automatically. --- tests/maxi_cov.py | 47 +++++++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 5acb979de..77d374f96 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -11,21 +11,6 @@ UNMAPPED_SRC_FILES = ["websockets/version.py"] UNMAPPED_TEST_FILES = ["tests/test_exports.py"] -IGNORED_FILES = [ - # */websockets matches src/websockets and .tox/**/site-packages/websockets. - # There are no tests for the __main__ module. - "*/websockets/__main__.py", - # This approach isn't applicable to the test suite of the legacy - # implementation, due to the huge test_client_server test module. - "*/websockets/legacy/*", - "tests/legacy/*", - # Test utilities don't fit anywhere because they are shared. - "tests/extensions/utils.py", - "tests/utils.py", - # There is no point measure the coverage of this script. - "tests/maxi_cov.py", -] - def check_environment(): """Check that prerequisites for running this script are met.""" @@ -51,7 +36,6 @@ def get_mapping(src_dir="src"): os.path.join(src_dir, "websockets/**/*.py"), recursive=True, ) - test_files = glob.glob( "tests/**/*.py", recursive=True, @@ -60,16 +44,16 @@ def get_mapping(src_dir="src"): src_files = [ os.path.relpath(src_file, src_dir) for src_file in sorted(src_files) + if "legacy" not in os.path.dirname(src_file) if os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" - and "legacy" not in os.path.dirname(src_file) ] test_files = [ test_file for test_file in sorted(test_files) - if os.path.basename(test_file) != "__init__.py" + if "legacy" not in os.path.dirname(test_file) + and os.path.basename(test_file) != "__init__.py" and os.path.basename(test_file).startswith("test_") - and "legacy" not in os.path.dirname(test_file) ] # Map source files to test files. @@ -100,7 +84,30 @@ def get_mapping(src_dir="src"): return mapping +def get_ignored_files(src_dir="src"): + """Return the list of files to exclude from coverage measurement.""" + + return [ + # */websockets matches src/websockets and .tox/**/site-packages/websockets. + # There are no tests for the __main__ module. + "*/websockets/__main__.py", + # This approach isn't applicable to the test suite of the legacy + # implementation, due to the huge test_client_server test module. + "*/websockets/legacy/*", + "tests/legacy/*", + ] + [ + # Exclude test utilities that are shared between several test modules. + # Also excludes this script. + test_file + for test_file in sorted(glob.glob("tests/**/*.py", recursive=True)) + if "legacy" not in os.path.dirname(test_file) + and os.path.basename(test_file) != "__init__.py" + and not os.path.basename(test_file).startswith("test_") + ] + + def run_coverage(mapping, src_dir="src"): + print(get_ignored_files(src_dir)) # Initialize a new coverage measurement session. The --source option # includes all files in the report, even if they're never imported. print("\nInitializing session\n", flush=True) @@ -113,7 +120,7 @@ def run_coverage(mapping, src_dir="src"): "--source", ",".join([os.path.join(src_dir, "websockets"), "tests"]), "--omit", - ",".join(IGNORED_FILES), + ",".join(get_ignored_files(src_dir)), "-m", "unittest", ] From 077e6429df312f549c9ecf348ec63abfd249da28 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Nov 2022 08:29:02 +0100 Subject: [PATCH 1129/1539] Perform version check at compile time. This is a small performance optimization. --- src/websockets/legacy/compatibility.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index df81de9db..296e9c584 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -5,9 +5,20 @@ from typing import Any, Dict -def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: - """ - Helper for the removal of the loop argument in Python 3.10. +if sys.version_info[:2] >= (3, 8): - """ - return {"loop": loop} if sys.version_info[:2] < (3, 8) else {} + def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: + """ + Helper for the removal of the loop argument in Python 3.10. + + """ + return {} + +else: # pragma: no cover + + def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: + """ + Helper for the removal of the loop argument in Python 3.10. + + """ + return {"loop": loop} From 68c041ec43a6a4d093d2316bf75c2ca0c51f4893 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Nov 2022 13:17:26 +0100 Subject: [PATCH 1130/1539] Reduce usage of pragma: no cover. --- setup.cfg | 1 + tests/legacy/test_client_server.py | 24 +++++++++--------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/setup.cfg b/setup.cfg index 4c1f6091f..47acf312c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,3 +35,4 @@ exclude_lines = if self.debug: pragma: no cover raise AssertionError + self.fail\(".*"\) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 8ee7187eb..1a668a431 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -267,21 +267,15 @@ async def start_client(): self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): - try: - self.loop.run_until_complete( - asyncio.wait_for(self.client.close_connection_task, timeout=1) - ) - except asyncio.TimeoutError: # pragma: no cover - self.fail("Client failed to stop") + self.loop.run_until_complete( + asyncio.wait_for(self.client.close_connection_task, timeout=1) + ) def stop_server(self): self.server.close() - try: - self.loop.run_until_complete( - asyncio.wait_for(self.server.wait_closed(), timeout=1) - ) - except asyncio.TimeoutError: # pragma: no cover - self.fail("Server failed to stop") + self.loop.run_until_complete( + asyncio.wait_for(self.server.wait_closed(), timeout=1) + ) @contextlib.contextmanager def temp_server(self, **kwargs): @@ -380,13 +374,13 @@ def test_infinite_redirect(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): with self.temp_client("/infinite"): - self.fail("Did not raise") # pragma: no cover + self.fail("did not raise") def test_redirect_missing_location(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHeader): with self.temp_client("/missing_location"): - self.fail("Did not raise") # pragma: no cover + self.fail("did not raise") def test_loop_backwards_compatibility(self): with self.temp_server( @@ -1327,7 +1321,7 @@ def test_redirect_insecure(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): with self.temp_client("/force_insecure"): - self.fail("Did not raise") # pragma: no cover + self.fail("did not raise") class ClientServerOriginTests(ClientServerTestsMixin, AsyncioTestCase): From 591d047884d08d938324c4a484e8138a88a7a4a9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 2 Nov 2022 07:52:29 +0100 Subject: [PATCH 1131/1539] Remove debug statement. --- tests/maxi_cov.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 77d374f96..b7c07b698 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -107,7 +107,6 @@ def get_ignored_files(src_dir="src"): def run_coverage(mapping, src_dir="src"): - print(get_ignored_files(src_dir)) # Initialize a new coverage measurement session. The --source option # includes all files in the report, even if they're never imported. print("\nInitializing session\n", flush=True) From 9e960b508988c4049eb9f3377c505f506a3af060 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 11 Nov 2022 14:26:45 +0100 Subject: [PATCH 1132/1539] Fix example of shutting down a client. Fix #1261. --- example/shutdown_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/shutdown_client.py b/example/shutdown_client.py index 539dd0304..65cfca41b 100755 --- a/example/shutdown_client.py +++ b/example/shutdown_client.py @@ -10,7 +10,7 @@ async def client(): # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() loop.add_signal_handler( - signal.SIGTERM, loop.create_task, websocket.close()) + signal.SIGTERM, loop.create_task, websocket.close) # Process messages received on the connection. async for message in websocket: From f199a31361118f58bc7cf5f928df8fc64575a824 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Nov 2022 15:08:48 +0100 Subject: [PATCH 1133/1539] Format setup.py with black. --- setup.py | 54 +++++++++++++++++++++++++++--------------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index 492b1597c..86b8bf9b8 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -long_description = (root_dir / 'README.rst').read_text(encoding='utf-8') +long_description = (root_dir / "README.rst").read_text(encoding="utf-8") # PyPI disables the "raw" directive. long_description = re.sub( @@ -18,47 +18,47 @@ flags=re.DOTALL | re.MULTILINE, ) -exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) +exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) -packages = ['websockets', 'websockets/legacy', 'websockets/extensions'] +packages = ["websockets", "websockets/legacy", "websockets/extensions"] ext_modules = [ setuptools.Extension( - 'websockets.speedups', - sources=['src/websockets/speedups.c'], - optional=not (root_dir / '.cibuildwheel').exists(), + "websockets.speedups", + sources=["src/websockets/speedups.c"], + optional=not (root_dir / ".cibuildwheel").exists(), ) ] setuptools.setup( - name='websockets', + name="websockets", version=version, description=description, long_description=long_description, - url='https://github.com/aaugustin/websockets', - author='Aymeric Augustin', - author_email='aymeric.augustin@m4x.org', - license='BSD', + url="https://github.com/aaugustin/websockets", + author="Aymeric Augustin", + author_email="aymeric.augustin@m4x.org", + license="BSD", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], - package_dir = {'': 'src'}, - package_data = {'websockets': ['py.typed']}, + package_dir={"": "src"}, + package_data={"websockets": ["py.typed"]}, packages=packages, ext_modules=ext_modules, include_package_data=True, zip_safe=False, - python_requires='>=3.7', - test_loader='unittest:TestLoader', + python_requires=">=3.7", + test_loader="unittest:TestLoader", ) From 26e1946f4d02357c71ca909a8e347f38892900d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Nov 2022 15:43:05 +0100 Subject: [PATCH 1134/1539] Update for the latest version of mypy. https://github.com/python/mypy/issues/2350 --- setup.cfg | 1 + src/websockets/extensions/base.py | 5 +++++ src/websockets/legacy/protocol.py | 2 -- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 47acf312c..48703df87 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,4 +35,5 @@ exclude_lines = if self.debug: pragma: no cover raise AssertionError + raise NotImplementedError self.fail\(".*"\) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 060967618..6c481a46c 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -38,6 +38,7 @@ def decode( PayloadTooBig: if decoding the payload exceeds ``max_size``. """ + raise NotImplementedError def encode(self, frame: frames.Frame) -> frames.Frame: """ @@ -50,6 +51,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: Frame: Encoded frame. """ + raise NotImplementedError class ClientExtensionFactory: @@ -69,6 +71,7 @@ def get_request_params(self) -> List[ExtensionParameter]: List[ExtensionParameter]: Parameters to send to the server. """ + raise NotImplementedError def process_response_params( self, @@ -91,6 +94,7 @@ def process_response_params( NegotiationError: if parameters aren't acceptable. """ + raise NotImplementedError class ServerExtensionFactory: @@ -126,3 +130,4 @@ def process_request_params( the client aren't acceptable. """ + raise NotImplementedError diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 1b6e58efa..21784eb55 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -220,8 +220,6 @@ def __init__( # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger("websockets.protocol") - # https://github.com/python/typeshed/issues/5561 - logger = cast(logging.Logger, logger) self.logger: LoggerLike = logging.LoggerAdapter(logger, {"websocket": self}) """Logger for this connection.""" From 0f4ecfc3d0abe4fdc2bc5104ea98e134be399c9a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Nov 2022 22:39:49 +0100 Subject: [PATCH 1135/1539] Don't treat close code 1005 as an error. Fix #1260. --- docs/project/changelog.rst | 14 ++++++++++++++ src/websockets/exceptions.py | 7 ++++--- src/websockets/frames.py | 6 +++++- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/protocol.py | 6 +++--- src/websockets/legacy/server.py | 2 +- 6 files changed, 28 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7851db80d..40e090c31 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,20 @@ They may change at any time. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: Closing a connection without an empty close frame is OK. + :class: note + + Receiving an empty close frame now results in + :exc:`~exceptions.ConnectionClosedOK` instead of + :exc:`~exceptions.ConnectionClosedError`. + + As a consequence, calling ``WebSocket.close()`` without arguments in a + browser isn't reported as an error anymore. + + New features ............ diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 291b6d2cd..46f314d9e 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -135,8 +135,8 @@ class ConnectionClosedError(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated with an error. - A close code other than 1000 (OK) or 1001 (going away) was received or - sent, or the closing handshake didn't complete properly. + A close frame with a code other than 1000 (OK) or 1001 (going away) was + received or sent, or the closing handshake didn't complete properly. """ @@ -145,7 +145,8 @@ class ConnectionClosedOK(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated properly. - A close code 1000 (OK) or 1001 (going away) was received and sent. + A close code with code 1000 (OK) or 1001 (going away) or without a code was + received and sent. """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index ec6b8547d..45d006e3f 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -90,7 +90,11 @@ class Opcode(enum.IntEnum): 1014, } -OK_CLOSE_CODES = {1000, 1001} +OK_CLOSE_CODES = { + 1000, + 1001, + 1005, +} BytesLike = bytes, bytearray, memoryview diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 1e3c6e741..9b953df8f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -65,7 +65,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): await process(message) The iterator exits normally when the connection is closed with close code - 1000 (OK) or 1001 (going away). It raises + 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 21784eb55..9f2bda1ab 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -485,9 +485,9 @@ async def __aiter__(self) -> AsyncIterator[Data]: Iterate on incoming messages. The iterator exits normally when the connection is closed with the - close code 1000 (OK) or 1001(going away). It raises - a :exc:`~websockets.exceptions.ConnectionClosedError` exception when - the connection is closed with any other code. + close code 1000 (OK) or 1001(going away) or without a close code. It + raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception + when the connection is closed with any other code. """ try: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 89e5322bc..9359472fe 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -73,7 +73,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): await process(message) The iterator exits normally when the connection is closed with close code - 1000 (OK) or 1001 (going away). It raises + 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. From 9c1430379bcb42120e45ee0e9f9dca142b1d7560 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 22 Nov 2022 07:26:59 +0100 Subject: [PATCH 1136/1539] Use standard licence text and SPDX identifier. --- LICENSE | 9 ++++----- setup.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/LICENSE b/LICENSE index 119b29ef3..5d61ece22 100644 --- a/LICENSE +++ b/LICENSE @@ -1,5 +1,4 @@ -Copyright (c) 2013-2021 Aymeric Augustin and contributors. -All rights reserved. +Copyright (c) Aymeric Augustin and contributors Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -9,9 +8,9 @@ modification, are permitted provided that the following conditions are met: * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - * Neither the name of websockets nor the names of its contributors may - be used to endorse or promote products derived from this software without - specific prior written permission. + * Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED diff --git a/setup.py b/setup.py index 86b8bf9b8..564ada85c 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ url="https://github.com/aaugustin/websockets", author="Aymeric Augustin", author_email="aymeric.augustin@m4x.org", - license="BSD", + license="BSD-3-Clause", classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", From f1a18247e78f9efcdca34a4b5a616d9733e92ff6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 22 Nov 2022 08:05:49 +0100 Subject: [PATCH 1137/1539] Rename test class with a more specific name. --- tests/extensions/test_permessage_deflate.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index d433762c5..c341fdb32 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -24,7 +24,7 @@ from .utils import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory -class ExtensionTestsMixin: +class PerMessageDeflateTestsMixin: def assertExtensionEqual(self, extension1, extension2): self.assertEqual( extension1.remote_no_context_takeover, @@ -44,7 +44,7 @@ def assertExtensionEqual(self, extension1, extension2): ) -class PerMessageDeflateTests(unittest.TestCase, ExtensionTestsMixin): +class PerMessageDeflateTests(unittest.TestCase, PerMessageDeflateTestsMixin): def setUp(self): # Set up an instance of the permessage-deflate extension with the most # common settings. Since the extension is symmetrical, this instance @@ -278,7 +278,9 @@ def test_decompress_max_size(self): self.extension.decode(enc_frame, max_size=10) -class ClientPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): +class ClientPerMessageDeflateFactoryTests( + unittest.TestCase, PerMessageDeflateTestsMixin +): def test_name(self): assert ClientPerMessageDeflateFactory.name == "permessage-deflate" @@ -616,7 +618,9 @@ def test_enable_client_permessage_deflate(self): ) -class ServerPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): +class ServerPerMessageDeflateFactoryTests( + unittest.TestCase, PerMessageDeflateTestsMixin +): def test_name(self): assert ServerPerMessageDeflateFactory.name == "permessage-deflate" From 0b884ed68f2c4b482f9eadbf38adc01f7d869f1a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 22 Nov 2022 08:07:22 +0100 Subject: [PATCH 1138/1539] Rename test class consistently with others. --- tests/test_exports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_exports.py b/tests/test_exports.py index 568c50c54..978b1d0e7 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -25,7 +25,7 @@ ) -class TestExportsAllSubmodules(unittest.TestCase): +class ExportsTests(unittest.TestCase): def test_top_level_module_reexports_all_submodule_exports(self): self.assertEqual(set(combined_exports), set(websockets.__all__)) From 173aac8e536346155b6bb9edac6ac5b5a351b7e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Nov 2022 07:52:56 +0100 Subject: [PATCH 1139/1539] Organize the examples directory. --- docs/faq/client.rst | 2 +- docs/faq/server.rst | 14 +++++++++++++- docs/topics/deployment.rst | 4 ++-- docs/topics/logging.rst | 2 +- example/{ => faq}/health_check_server.py | 0 example/{ => faq}/shutdown_client.py | 0 example/{ => faq}/shutdown_server.py | 0 example/{ => legacy}/basic_auth_client.py | 0 example/{ => legacy}/basic_auth_server.py | 0 example/{ => legacy}/unix_client.py | 0 example/{ => legacy}/unix_server.py | 0 example/{ => logging}/json_log_formatter.py | 0 12 files changed, 17 insertions(+), 5 deletions(-) rename example/{ => faq}/health_check_server.py (100%) rename example/{ => faq}/shutdown_client.py (100%) rename example/{ => faq}/shutdown_server.py (100%) rename example/{ => legacy}/basic_auth_client.py (100%) rename example/{ => legacy}/basic_auth_server.py (100%) rename example/{ => legacy}/unix_client.py (100%) rename example/{ => legacy}/unix_server.py (100%) rename example/{ => logging}/json_log_formatter.py (100%) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 5bbbd6ded..73825e480 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -74,7 +74,7 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: -.. literalinclude:: ../../example/shutdown_client.py +.. literalinclude:: ../../example/faq/shutdown_client.py :emphasize-lines: 10-13 How do I disable TLS/SSL certificate verification? diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 22a9d3a4c..feec65a58 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -257,7 +257,7 @@ Exit the :func:`~server.serve` context manager. Here's an example that terminates cleanly when it receives SIGTERM on Unix: -.. literalinclude:: ../../example/shutdown_server.py +.. literalinclude:: ../../example/faq/shutdown_server.py :emphasize-lines: 12-15,18 How do I stop a server while keeping existing connections open? @@ -275,6 +275,18 @@ Here's how to adapt the example just above:: await stop await server.close(close_connections=False) +How do I implement a health check? +---------------------------------- + +Intercept WebSocket handshake requests with the +:meth:`~server.WebSocketServerProtocol.process_request` hook. + +When a request is sent to the health check endpoint, treat is as an HTTP request +and return a ``(status, headers, body)`` tuple, as in this example: + +.. literalinclude:: ../../example/faq/health_check_server.py + :emphasize-lines: 7-9,18 + How do I run HTTP and WebSocket servers on the same port? --------------------------------------------------------- diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index ac0a8ed4c..2a1fe9a78 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -97,7 +97,7 @@ signal and exit the server to ensure a graceful shutdown. Here's an example: -.. literalinclude:: ../../example/shutdown_server.py +.. literalinclude:: ../../example/faq/shutdown_server.py :emphasize-lines: 12-15,18 When exiting the context manager, :func:`~server.serve` closes all connections @@ -177,5 +177,5 @@ websockets provide minimal support for responding to HTTP requests with the Here's an example: -.. literalinclude:: ../../example/health_check_server.py +.. literalinclude:: ../../example/faq/health_check_server.py :emphasize-lines: 7-9,18 diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 95acf57ff..e2b4a7be1 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -139,7 +139,7 @@ output logs as JSON with a bit of effort. First, we need a :class:`~logging.Formatter` that renders JSON: -.. literalinclude:: ../../example/json_log_formatter.py +.. literalinclude:: ../../example/logging/json_log_formatter.py Then, we configure logging to apply this formatter:: diff --git a/example/health_check_server.py b/example/faq/health_check_server.py similarity index 100% rename from example/health_check_server.py rename to example/faq/health_check_server.py diff --git a/example/shutdown_client.py b/example/faq/shutdown_client.py similarity index 100% rename from example/shutdown_client.py rename to example/faq/shutdown_client.py diff --git a/example/shutdown_server.py b/example/faq/shutdown_server.py similarity index 100% rename from example/shutdown_server.py rename to example/faq/shutdown_server.py diff --git a/example/basic_auth_client.py b/example/legacy/basic_auth_client.py similarity index 100% rename from example/basic_auth_client.py rename to example/legacy/basic_auth_client.py diff --git a/example/basic_auth_server.py b/example/legacy/basic_auth_server.py similarity index 100% rename from example/basic_auth_server.py rename to example/legacy/basic_auth_server.py diff --git a/example/unix_client.py b/example/legacy/unix_client.py similarity index 100% rename from example/unix_client.py rename to example/legacy/unix_client.py diff --git a/example/unix_server.py b/example/legacy/unix_server.py similarity index 100% rename from example/unix_server.py rename to example/legacy/unix_server.py diff --git a/example/json_log_formatter.py b/example/logging/json_log_formatter.py similarity index 100% rename from example/json_log_formatter.py rename to example/logging/json_log_formatter.py From cac6e8575e7b3c6339817c74fcbba846a4f935dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Nov 2022 08:04:15 +0100 Subject: [PATCH 1140/1539] Standardize to "an HTTP". --- README.rst | 2 +- docs/faq/server.rst | 6 +++--- docs/howto/django.rst | 2 +- docs/howto/sansio.rst | 2 +- docs/intro/tutorial1.rst | 2 +- docs/intro/tutorial2.rst | 2 +- docs/intro/tutorial3.rst | 2 +- docs/reference/limitations.rst | 2 +- docs/topics/authentication.rst | 8 ++++---- docs/topics/design.rst | 10 +++++----- docs/topics/logging.rst | 2 +- src/websockets/exceptions.py | 8 ++++---- src/websockets/legacy/auth.py | 4 ++-- src/websockets/legacy/server.py | 6 +++--- tests/legacy/test_client_server.py | 6 +++--- 15 files changed, 32 insertions(+), 32 deletions(-) diff --git a/README.rst b/README.rst index 7cbccfe13..fa1d91061 100644 --- a/README.rst +++ b/README.rst @@ -123,7 +123,7 @@ Why shouldn't I use ``websockets``? * If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP - is minimal — just enough for a HTTP health check. + is minimal — just enough for an HTTP health check. If you want to do both in the same server, look at HTTP frameworks that build on top of ``websockets`` to support WebSocket connections, like diff --git a/docs/faq/server.rst b/docs/faq/server.rst index feec65a58..ab665f2b8 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -295,15 +295,15 @@ You don't. HTTP and WebSocket have widely different operational characteristics. Running them with the same server becomes inconvenient when you scale. -Providing a HTTP server is out of scope for websockets. It only aims at +Providing an HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the :attr:`~server.WebSocketServerProtocol.process_request` hook. -If you need more, pick a HTTP server and run it separately. +If you need more, pick an HTTP server and run it separately. -Alternatively, pick a HTTP framework that builds on top of ``websockets`` to +Alternatively, pick an HTTP framework that builds on top of ``websockets`` to support WebSocket connections, like Sanic_. .. _Sanic: https://sanicframework.org/en/ diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 5bb2e296b..c955a5ec1 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -10,7 +10,7 @@ WebSocket, you have two main options. 2. Deploying a separate WebSocket server next to your Django project. This technique is well suited when you need to add a small set of real-time - features — maybe a notification service — to a HTTP application. + features — maybe a notification service — to an HTTP application. .. _Channels: https://channels.readthedocs.io/ diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index 1373c81a5..197e29dc9 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -85,7 +85,7 @@ with :meth:`~server.ServerConnection.send_response`:: response = connection.accept(request) connection.send_response(response) -Alternatively, you may reject the WebSocket handshake and return a HTTP +Alternatively, you may reject the WebSocket handshake and return an HTTP response with :meth:`~server.ServerConnection.reject`:: response = connection.reject(status, explanation) diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index ab4f39d79..ff85003b5 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -163,7 +163,7 @@ page loads, draw the board: createBoard(board); }); -Open a shell, navigate to the directory containing these files, and start a +Open a shell, navigate to the directory containing these files, and start an HTTP server: .. code-block:: console diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 53f295d3f..5ac4ae9dd 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -369,7 +369,7 @@ common pattern in servers that handle different clients. it in the URI because URIs end up in logs. For the purposes of this tutorial, both approaches are equivalent because - the join key comes from a HTTP URL. There isn't much at risk anyway! + the join key comes from an HTTP URL. There isn't much at risk anyway! Now you can restore the logic for playing moves and you'll have a fully functional two-player game. diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 9a447f39b..4d42447b7 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -13,7 +13,7 @@ Part 3 - Deploy to the web web; you can play from any browser connected to the Internet. In the first and second parts of the tutorial, for local development, you ran -a HTTP server on ``http://localhost:8000/`` with: +an HTTP server on ``http://localhost:8000/`` with: .. code-block:: console diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst index 3304bdb8c..696aa38fd 100644 --- a/docs/reference/limitations.rst +++ b/docs/reference/limitations.rst @@ -13,7 +13,7 @@ right layer for enforcing this constraint. It's the caller's responsibility. .. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 -The client doesn't support connecting through a HTTP proxy (`issue 364`_) or a +The client doesn't support connecting through an HTTP proxy (`issue 364`_) or a SOCKS proxy (`issue 475`_). .. _issue 364: https://github.com/aaugustin/websockets/issues/364 diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 31bfd6465..1849d635a 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -78,7 +78,7 @@ WebSocket server. 3. **Setting a cookie on the domain of the WebSocket URI.** Cookies are undoubtedly the most common and hardened mechanism for sending - credentials from a web application to a server. In a HTTP application, + credentials from a web application to a server. In an HTTP application, credentials would be a session identifier or a serialized, signed session. Unfortunately, when the WebSocket server runs on a different domain from @@ -208,7 +208,7 @@ opening the connection: // ... The server intercepts the HTTP request, extracts the token and authenticates -the user. If authentication fails, it returns a HTTP 401: +the user. If authentication fails, it returns an HTTP 401: .. code-block:: python @@ -254,7 +254,7 @@ This sequence must be synchronized between the main window and the iframe. This involves several events. Look at the full implementation for details. The server intercepts the HTTP request, extracts the token and authenticates -the user. If authentication fails, it returns a HTTP 401: +the user. If authentication fails, it returns an HTTP 401: .. code-block:: python @@ -295,7 +295,7 @@ Since HTTP Basic Auth is designed to accept a username and a password rather than a token, we send ``token`` as username and the token as password. The server intercepts the HTTP request, extracts the token and authenticates -the user. If authentication fails, it returns a HTTP 401: +the user. If authentication fails, it returns an HTTP 401: .. code-block:: python diff --git a/docs/topics/design.rst b/docs/topics/design.rst index b5c55afc9..33dd187b9 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -125,23 +125,23 @@ one another. On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: -- builds a HTTP request based on the ``uri`` and parameters passed to +- builds an HTTP request based on the ``uri`` and parameters passed to :meth:`~client.connect`; - writes the HTTP request to the network; -- reads a HTTP response from the network; +- reads an HTTP response from the network; - checks the HTTP response, validates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - moves to the ``OPEN`` state. On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: -- reads a HTTP request from the network; +- reads an HTTP request from the network; - calls :meth:`~server.WebSocketServerProtocol.process_request` which may - abort the WebSocket handshake and return a HTTP response instead; this + abort the WebSocket handshake and return an HTTP response instead; this hook only makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; -- builds a HTTP response based on the above and parameters passed to +- builds an HTTP response based on the above and parameters passed to :meth:`~server.serve`; - writes the HTTP response to the network; - moves to the ``OPEN`` state; diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index e2b4a7be1..294a6cda8 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -104,7 +104,7 @@ can set ``logger`` to a :class:`~logging.LoggerAdapter` that enriches logs. For example, if the server is behind a reverse proxy, :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` gives the IP address of the proxy, which isn't useful. IP addresses of clients are -provided in a HTTP header set by the proxy. +provided in an HTTP header set by the proxy. Here's how to include them in logs, assuming they're in the ``X-Forwarded-For`` header:: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 46f314d9e..1f4b9265c 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -176,7 +176,7 @@ class InvalidMessage(InvalidHandshake): class InvalidHeader(InvalidHandshake): """ - Raised when a HTTP header doesn't have a valid format or value. + Raised when an HTTP header doesn't have a valid format or value. """ @@ -195,7 +195,7 @@ def __str__(self) -> str: class InvalidHeaderFormat(InvalidHeader): """ - Raised when a HTTP header cannot be parsed. + Raised when an HTTP header cannot be parsed. The format of the header doesn't match the grammar for that header. @@ -207,7 +207,7 @@ def __init__(self, name: str, error: str, header: str, pos: int) -> None: class InvalidHeaderValue(InvalidHeader): """ - Raised when a HTTP header has a wrong value. + Raised when an HTTP header has a wrong value. The format of the header is correct but a value isn't acceptable. @@ -315,7 +315,7 @@ def __str__(self) -> str: class AbortHandshake(InvalidHandshake): """ - Raised to abort the handshake on purpose and return a HTTP response. + Raised to abort the handshake on purpose and return an HTTP response. This exception is an implementation detail. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8825c14ec..ac24c179e 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -67,7 +67,7 @@ async def check_credentials(self, username: str, password: str) -> bool: Returns: bool: :obj:`True` if the handshake should continue; - :obj:`False` if it should fail with a HTTP 401 error. + :obj:`False` if it should fail with an HTTP 401 error. """ if self._check_credentials is not None: @@ -81,7 +81,7 @@ async def process_request( request_headers: Headers, ) -> Optional[HTTPResponse]: """ - Check HTTP Basic Auth and return a HTTP 401 response if needed. + Check HTTP Basic Auth and return an HTTP 401 response if needed. """ try: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9359472fe..f37a16c20 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -329,9 +329,9 @@ async def process_request( You may override this method in a :class:`WebSocketServerProtocol` subclass, for example: - * to return a HTTP 200 OK response on a given path; then a load + * to return an HTTP 200 OK response on a given path; then a load balancer can use this path for a health check; - * to authenticate the request and return a HTTP 401 Unauthorized or a + * to authenticate the request and return an HTTP 401 Unauthorized or an HTTP 403 Forbidden when authentication fails. You may also override this method with the ``process_request`` @@ -776,7 +776,7 @@ async def _close(self, close_connections: bool) -> None: if close_connections: # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING connections with a HTTP 503 + # closed, handshake() closes OPENING connections with an HTTP 503 # error. Wait until all connections are closed. close_tasks = [ diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 1a668a431..f8a79027e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -727,7 +727,7 @@ def test_protocol_custom_server_header(self): @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): - # Making a HTTP request to a HTTP endpoint succeeds. + # Making an HTTP request to an HTTP endpoint succeeds. response = self.loop.run_until_complete(self.make_http_request("/__health__/")) with contextlib.closing(response): @@ -736,7 +736,7 @@ def test_http_request_http_endpoint(self): @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_ws_endpoint(self): - # Making a HTTP request to a WS endpoint fails. + # Making an HTTP request to a WS endpoint fails. with self.assertRaises(urllib.error.HTTPError) as raised: self.loop.run_until_complete(self.make_http_request()) @@ -745,7 +745,7 @@ def test_http_request_ws_endpoint(self): @with_server(create_protocol=HealthCheckServerProtocol) def test_ws_connection_http_endpoint(self): - # Making a WS connection to a HTTP endpoint fails. + # Making a WS connection to an HTTP endpoint fails. with self.assertRaises(InvalidStatusCode) as raised: self.start_client("/__health__/") From 39c53fb67d6815cab0a72be03b0bba60df504831 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Nov 2022 19:02:44 +0100 Subject: [PATCH 1141/1539] websockets CAN handle multiple clients. Ref #1268. --- docs/faq/server.rst | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index ab665f2b8..4e4622dce 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -43,19 +43,44 @@ Why can only one client connect at a time? ------------------------------------------ Your connection handler blocks the event loop. Look for blocking calls. + Any call that may take some time must be asynchronous. -For example, if you have:: +For example, this connection handler prevents the event loop from running during +one second:: async def handler(websocket): time.sleep(1) + ... -change it to:: +Change it to:: async def handler(websocket): await asyncio.sleep(1) + ... + +In addition, calling a coroutine doesn't guarantee that it will yield control to +the event loop. + +For example, this connection handler blocks the event loop by sending messages +continuously:: + + async def handler(websocket): + while True: + await websocket.send("firehose!") + +:meth:`~legacy.protocol.WebSocketCommonProtocol.send` completes synchronously as +long as there's space in send buffers. The event loop never runs. (This pattern +is uncommon in real-world applications. It occurs mostly in toy programs.) + +You can avoid the issue by yielding control to the event loop explicitly:: + + async def handler(websocket): + while True: + await websocket.send("firehose!") + await asyncio.sleep(0) -This is part of learning asyncio. It isn't specific to websockets. +All this is part of learning asyncio. It isn't specific to websockets. See also Python's documentation about `running blocking code`_. From f5ea94ab818d873664ffb76274b675fe307e289b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 09:25:50 +0200 Subject: [PATCH 1142/1539] Rename Connection to Protocol. This makes the following naming possible: * Connection = TCP/TLS connection + pointer to protocol; concerned with opening the network connection, moving bytes, and closing it. * Protocol = Sans-I/O protocol; concerned with parsing and serializing messages and with the state of the connection. Previously, these names were reversed in the sockets & threads branch. The previous choice was influenced by the legacy implementation using "protocol" to describe the two layers together, itself influenced by asyncio using the same word. --- docs/howto/sansio.rst | 130 ++- docs/project/changelog.rst | 17 +- docs/reference/client.rst | 2 +- docs/reference/common.rst | 4 +- docs/reference/server.rst | 2 +- docs/reference/types.rst | 2 +- src/websockets/__init__.py | 8 +- src/websockets/client.py | 18 +- src/websockets/connection.py | 705 +---------- src/websockets/http11.py | 4 +- src/websockets/legacy/protocol.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/protocol.py | 702 +++++++++++ src/websockets/server.py | 18 +- tests/legacy/test_client_server.py | 2 +- tests/legacy/test_protocol.py | 2 +- tests/test_client.py | 98 +- tests/test_connection.py | 1743 +--------------------------- tests/test_protocol.py | 1737 +++++++++++++++++++++++++++ tests/test_server.py | 108 +- tests/utils.py | 29 + 21 files changed, 2722 insertions(+), 2613 deletions(-) create mode 100644 src/websockets/protocol.py create mode 100644 tests/test_protocol.py diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index 197e29dc9..08b09f7ce 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -33,35 +33,36 @@ If you're building a client, parse the URI you'd like to connect to:: Open a TCP connection to ``(wsuri.host, wsuri.port)`` and perform a TLS handshake if ``wsuri.secure`` is :obj:`True`. -Initialize a :class:`~client.ClientConnection`:: +Initialize a :class:`~client.ClientProtocol`:: - from websockets.client import ClientConnection + from websockets.client import ClientProtocol - connection = ClientConnection(wsuri) + protocol = ClientProtocol(wsuri) Create a WebSocket handshake request -with :meth:`~client.ClientConnection.connect` and send it -with :meth:`~client.ClientConnection.send_request`:: +with :meth:`~client.ClientProtocol.connect` and send it +with :meth:`~client.ClientProtocol.send_request`:: - request = connection.connect() - connection.send_request(request) + request = protocol.connect() + protocol.send_request(request) -Then, call :meth:`~connection.Connection.data_to_send` and send its output to +Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as described in `Send data`_ below. -The first event returned by :meth:`~connection.Connection.events_received` is -the WebSocket handshake response. +Once you receive enough data, as explained in `Receive data`_ below, the first +event returned by :meth:`~protocol.Protocol.events_received` is the WebSocket +handshake response. When the handshake fails, the reason is available in -:attr:`~client.ClientConnection.handshake_exc`:: +:attr:`~client.ClientProtocol.handshake_exc`:: - if connection.handshake_exc is not None: - raise connection.handshake_exc + if protocol.handshake_exc is not None: + raise protocol.handshake_exc Else, the WebSocket connection is open. A WebSocket client API usually performs the handshake then returns a wrapper -around the network connection and the :class:`~client.ClientConnection`. +around the network socket and the :class:`~client.ClientProtocol`. Server-side ........... @@ -69,45 +70,46 @@ Server-side If you're building a server, accept network connections from clients and perform a TLS handshake if desired. -For each connection, initialize a :class:`~server.ServerConnection`:: +For each connection, initialize a :class:`~server.ServerProtocol`:: - from websockets.server import ServerConnection + from websockets.server import ServerProtocol - connection = ServerConnection() + protocol = ServerProtocol() -The first event returned by :meth:`~connection.Connection.events_received` is -the WebSocket handshake request. +Once you receive enough data, as explained in `Receive data`_ below, the first +event returned by :meth:`~protocol.Protocol.events_received` is the WebSocket +handshake request. Create a WebSocket handshake response -with :meth:`~server.ServerConnection.accept` and send it -with :meth:`~server.ServerConnection.send_response`:: +with :meth:`~server.ServerProtocol.accept` and send it +with :meth:`~server.ServerProtocol.send_response`:: - response = connection.accept(request) - connection.send_response(response) + response = protocol.accept(request) + protocol.send_response(response) Alternatively, you may reject the WebSocket handshake and return an HTTP -response with :meth:`~server.ServerConnection.reject`:: +response with :meth:`~server.ServerProtocol.reject`:: - response = connection.reject(status, explanation) - connection.send_response(response) + response = protocol.reject(status, explanation) + protocol.send_response(response) -Then, call :meth:`~connection.Connection.data_to_send` and send its output to +Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as described in `Send data`_ below. -Even when you call :meth:`~server.ServerConnection.accept`, the WebSocket +Even when you call :meth:`~server.ServerProtocol.accept`, the WebSocket handshake may fail if the request is incorrect or unsupported. When the handshake fails, the reason is available in -:attr:`~server.ServerConnection.handshake_exc`:: +:attr:`~server.ServerProtocol.handshake_exc`:: - if connection.handshake_exc is not None: - raise connection.handshake_exc + if protocol.handshake_exc is not None: + raise protocol.handshake_exc Else, the WebSocket connection is open. -A WebSocket server API usually builds a wrapper around the network connection -and the :class:`~server.ServerConnection`. Then it invokes a connection -handler that accepts the wrapper in argument. +A WebSocket server API usually builds a wrapper around the network socket and +the :class:`~server.ServerProtocol`. Then it invokes a connection handler that +accepts the wrapper in argument. It may also provide a way to close all connections and to shut down the server gracefully. @@ -122,11 +124,11 @@ Go through the five steps below until you reach the end of the data stream. Receive data ............ -When receiving data from the network, feed it to the connection's -:meth:`~connection.Connection.receive_data` method. +When receiving data from the network, feed it to the protocol's +:meth:`~protocol.Protocol.receive_data` method. -When reaching the end of the data stream, call the connection's -:meth:`~connection.Connection.receive_eof` method. +When reaching the end of the data stream, call the protocol's +:meth:`~protocol.Protocol.receive_eof` method. For example, if ``sock`` is a :obj:`~socket.socket`:: @@ -135,21 +137,21 @@ For example, if ``sock`` is a :obj:`~socket.socket`:: except OSError: # socket closed data = b"" if data: - connection.receive_data(data) + protocol.receive_data(data) else: - connection.receive_eof() + protocol.receive_eof() These methods aren't expected to raise exceptions — unless you call them again -after calling :meth:`~connection.Connection.receive_eof`, which is an error. +after calling :meth:`~protocol.Protocol.receive_eof`, which is an error. (If you get an exception, please file a bug!) Send data ......... -Then, call :meth:`~connection.Connection.data_to_send` and send its output to +Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network:: - for data in connection.data_to_send(): + for data in protocol.data_to_send(): if data: sock.sendall(data) else: @@ -170,7 +172,7 @@ server starts the four-way TCP closing handshake. If the network fails at the wrong point, you can end up waiting until the TCP timeout, which is very long. To prevent dangling TCP connections when you expect the end of the data stream -but you never reach it, call :meth:`~connection.Connection.close_expected` +but you never reach it, call :meth:`~protocol.Protocol.close_expected` and, if it returns :obj:`True`, schedule closing the TCP connection after a short timeout:: @@ -185,11 +187,11 @@ data stream, possibly with an exception. Close TCP connection .................... -If you called :meth:`~connection.Connection.receive_eof`, close the TCP +If you called :meth:`~protocol.Protocol.receive_eof`, close the TCP connection now. This is a clean closure because the receive buffer is empty. -After :meth:`~connection.Connection.receive_eof` signals the end of the read -stream, :meth:`~connection.Connection.data_to_send` always signals the end of +After :meth:`~protocol.Protocol.receive_eof` signals the end of the read +stream, :meth:`~protocol.Protocol.data_to_send` always signals the end of the write stream, unless it already ended. So, at this point, the TCP connection is already half-closed. The only reason for closing it now is to release resources related to the socket. @@ -199,8 +201,8 @@ Now you can exit the loop relaying data from the network to the application. Receive events .............. -Finally, call :meth:`~connection.Connection.events_received` to obtain events -parsed from the data provided to :meth:`~connection.Connection.receive_data`:: +Finally, call :meth:`~protocol.Protocol.events_received` to obtain events +parsed from the data provided to :meth:`~protocol.Protocol.receive_data`:: events = connection.events_received() @@ -224,9 +226,9 @@ The connection object provides one method for each type of WebSocket frame. For sending a data frame: -* :meth:`~connection.Connection.send_continuation` -* :meth:`~connection.Connection.send_text` -* :meth:`~connection.Connection.send_binary` +* :meth:`~protocol.Protocol.send_continuation` +* :meth:`~protocol.Protocol.send_text` +* :meth:`~protocol.Protocol.send_binary` These methods raise :exc:`~exceptions.ProtocolError` if you don't set the :attr:`FIN ` bit correctly in fragmented @@ -234,21 +236,21 @@ messages. For sending a control frame: -* :meth:`~connection.Connection.send_close` -* :meth:`~connection.Connection.send_ping` -* :meth:`~connection.Connection.send_pong` +* :meth:`~protocol.Protocol.send_close` +* :meth:`~protocol.Protocol.send_ping` +* :meth:`~protocol.Protocol.send_pong` -:meth:`~connection.Connection.send_close` initiates the closing handshake. +:meth:`~protocol.Protocol.send_close` initiates the closing handshake. See `Closing a connection`_ below for details. If you encounter an unrecoverable error and you must fail the WebSocket -connection, call :meth:`~connection.Connection.fail`. +connection, call :meth:`~protocol.Protocol.fail`. -After any of the above, call :meth:`~connection.Connection.data_to_send` and +After any of the above, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as shown in `Send data`_ above. -If you called :meth:`~connection.Connection.send_close` -or :meth:`~connection.Connection.fail`, you expect the end of the data +If you called :meth:`~protocol.Protocol.send_close` +or :meth:`~protocol.Protocol.fail`, you expect the end of the data stream. You should follow the process described in `Close TCP connection`_ above in order to prevent dangling TCP connections. @@ -272,10 +274,10 @@ When a client wants to close the TCP connection: Applying the rules described earlier in this document gives the intended result. As a reminder, the rules are: -* When :meth:`~connection.Connection.data_to_send` returns the empty +* When :meth:`~protocol.Protocol.data_to_send` returns the empty bytestring, close the write side of the TCP connection. * When you reach the end of the read stream, close the TCP connection. -* When :meth:`~connection.Connection.close_expected` returns :obj:`True`, if +* When :meth:`~protocol.Protocol.close_expected` returns :obj:`True`, if you don't reach the end of the read stream quickly, close the TCP connection. Fragmentation @@ -306,14 +308,14 @@ should happen automatically in a cooperative multitasking environment. However, you still have to make sure you don't break this property by accident. For example, serialize writes to the network -when :meth:`~connection.Connection.data_to_send` returns multiple values to +when :meth:`~protocol.Protocol.data_to_send` returns multiple values to prevent concurrent writes from interleaving incorrectly. Avoid buffers ............. The Sans-I/O layer doesn't do any buffering. It makes events available in -:meth:`~connection.Connection.events_received` as soon as they're received. +:meth:`~protocol.Protocol.events_received` as soon as they're received. You should make incoming messages available to the application immediately and stop further processing until the application fetches them. This will usually diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 40e090c31..532eeffb2 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -33,6 +33,18 @@ They may change at any time. Backwards-incompatible changes .............................. +.. admonition:: The Sans-I/O implementation was moved. + :class: caution + + Aliases provide compatibility for all previously public APIs according to + the `backwards-compatibility policy`_ + + * The ``connection`` module was renamed to ``protocol``. + + * The ``connection.Connection``, ``server.ServerConnection``, and + ``client.ClientConnection`` classes were renamed to ``protocol.Protocol``, + ``server.ServerProtocol``, and ``client.ClientProtocol``. + .. admonition:: Closing a connection without an empty close frame is OK. :class: note @@ -43,7 +55,6 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. - New features ............ @@ -86,8 +97,8 @@ Backwards-incompatible changes .. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated. :class: note - Use the ``handshake_exc`` attribute of :class:`~server.ServerConnection` and - :class:`~client.ClientConnection` instead. + Use the ``handshake_exc`` attribute of :class:`~server.ServerProtocol` and + :class:`~client.ClientProtocol` instead. See :doc:`../howto/sansio` for details. diff --git a/docs/reference/client.rst b/docs/reference/client.rst index b72f49f5d..44f053b1e 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -69,7 +69,7 @@ Using a connection Sans-I/O -------- -.. autoclass:: ClientConnection(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) .. automethod:: receive_data diff --git a/docs/reference/common.rst b/docs/reference/common.rst index 6ba11bff5..b42f5ea3e 100644 --- a/docs/reference/common.rst +++ b/docs/reference/common.rst @@ -57,9 +57,9 @@ asyncio Sans-I/O -------- -.. automodule:: websockets.connection +.. automodule:: websockets.protocol -.. autoclass:: Connection(side, state=State.OPEN, max_size=2 ** 20, logger=None) +.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) .. automethod:: receive_data diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 12fe1f806..50ef4ee3c 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -115,7 +115,7 @@ websockets supports HTTP Basic Authentication according to Sans-I/O -------- -.. autoclass:: ServerConnection(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) .. automethod:: receive_data diff --git a/docs/reference/types.rst b/docs/reference/types.rst index d86429be4..88550d08d 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -15,7 +15,7 @@ Types .. autodata:: ExtensionParameter -.. autodata:: websockets.connection.Event +.. autodata:: websockets.protocol.Event .. autodata:: websockets.datastructures.HeadersLike diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index ec3484124..826decc48 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -9,7 +9,7 @@ "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", "broadcast", - "ClientConnection", + "ClientProtocol", "connect", "ConnectionClosed", "ConnectionClosedError", @@ -40,7 +40,7 @@ "RedirectHandshake", "SecurityError", "serve", - "ServerConnection", + "ServerProtocol", "Subprotocol", "unix_connect", "unix_serve", @@ -60,7 +60,7 @@ "basic_auth_protocol_factory": ".legacy.auth", "BasicAuthWebSocketServerProtocol": ".legacy.auth", "broadcast": ".legacy.protocol", - "ClientConnection": ".client", + "ClientProtocol": ".client", "connect": ".legacy.client", "unix_connect": ".legacy.client", "WebSocketClientProtocol": ".legacy.client", @@ -93,7 +93,7 @@ "WebSocketProtocolError": ".exceptions", "protocol": ".legacy", "WebSocketCommonProtocol": ".legacy.protocol", - "ServerConnection": ".server", + "ServerProtocol": ".server", "serve": ".legacy.server", "unix_serve": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", diff --git a/src/websockets/client.py b/src/websockets/client.py index 373e6b751..a439ab846 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Generator, List, Optional, Sequence +import warnings +from typing import Any, Generator, List, Optional, Sequence -from .connection import CLIENT, CONNECTING, OPEN, Connection, State from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, @@ -24,6 +24,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State from .typing import ( ConnectionOption, ExtensionHeader, @@ -40,10 +41,10 @@ from .legacy.client import * # isort:skip # noqa -__all__ = ["ClientConnection"] +__all__ = ["ClientProtocol"] -class ClientConnection(Connection): +class ClientProtocol(Protocol): """ Sans-I/O implementation of a WebSocket client connection. @@ -342,3 +343,12 @@ def parse(self) -> Generator[None, None, None]: self.events.append(response) yield from super().parse() + + +class ClientConnection(ClientProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + "ClientConnection was renamed to ClientProtocol", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 7a9be9f2e..5ce4d6a3b 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,702 +1,13 @@ from __future__ import annotations -import enum -import logging -import uuid -from typing import Generator, List, Optional, Type, Union +import warnings -from .exceptions import ( - ConnectionClosed, - ConnectionClosedError, - ConnectionClosedOK, - InvalidState, - PayloadTooBig, - ProtocolError, -) -from .extensions import Extension -from .frames import ( - OK_CLOSE_CODES, - OP_BINARY, - OP_CLOSE, - OP_CONT, - OP_PING, - OP_PONG, - OP_TEXT, - Close, - Frame, -) -from .http11 import Request, Response -from .streams import StreamReader -from .typing import LoggerLike, Origin, Subprotocol - - -__all__ = [ - "Connection", - "Side", - "State", - "SEND_EOF", -] - -Event = Union[Request, Response, Frame] -"""Events that :meth:`~Connection.events_received` may return.""" - - -class Side(enum.IntEnum): - """A WebSocket connection is either a server or a client.""" - - SERVER, CLIENT = range(2) - - -SERVER = Side.SERVER -CLIENT = Side.CLIENT - - -class State(enum.IntEnum): - """A WebSocket connection is in one of these four states.""" - - CONNECTING, OPEN, CLOSING, CLOSED = range(4) - - -CONNECTING = State.CONNECTING -OPEN = State.OPEN -CLOSING = State.CLOSING -CLOSED = State.CLOSED - - -SEND_EOF = b"" -"""Sentinel signaling that the TCP connection must be half-closed.""" - - -class Connection: - """ - Sans-I/O implementation of a WebSocket connection. - - Args: - side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. - logger: logger for this connection; depending on ``side``, - defaults to ``logging.getLogger("websockets.client")`` - or ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../topics/logging>` for details. - - """ - - def __init__( - self, - side: Side, - state: State = OPEN, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, - ) -> None: - # Unique identifier. For logs. - self.id: uuid.UUID = uuid.uuid4() - """Unique identifier of the connection. Useful in logs.""" - - # Logger or LoggerAdapter for this connection. - if logger is None: - logger = logging.getLogger(f"websockets.{side.name.lower()}") - self.logger: LoggerLike = logger - """Logger for this connection.""" - - # Track if DEBUG is enabled. Shortcut logging calls if it isn't. - self.debug = logger.isEnabledFor(logging.DEBUG) - - # Connection side. CLIENT or SERVER. - self.side = side - - # Connection state. Initially OPEN because subclasses handle CONNECTING. - self.state = state - - # Maximum size of incoming messages in bytes. - self.max_size = max_size - - # Current size of incoming message in bytes. Only set while reading a - # fragmented message i.e. a data frames with the FIN bit not set. - self.cur_size: Optional[int] = None - - # True while sending a fragmented message i.e. a data frames with the - # FIN bit not set. - self.expect_continuation_frame = False - - # WebSocket protocol parameters. - self.origin: Optional[Origin] = None - self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None - - # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None - - # Track if an exception happened during the handshake. - self.handshake_exc: Optional[Exception] = None - """ - Exception to raise if the opening handshake failed. - - :obj:`None` if the opening handshake succeeded. - - """ - - # Track if send_eof() was called. - self.eof_sent = False - - # Parser state. - self.reader = StreamReader() - self.events: List[Event] = [] - self.writes: List[bytes] = [] - self.parser = self.parse() - next(self.parser) # start coroutine - self.parser_exc: Optional[Exception] = None - - @property - def state(self) -> State: - """ - WebSocket connection state. - - Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. - - """ - return self._state - - @state.setter - def state(self, state: State) -> None: - if self.debug: - self.logger.debug("= connection is %s", state.name) - self._state = state - - @property - def close_code(self) -> Optional[int]: - """ - `WebSocket close code`_. - - .. _WebSocket close code: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 - - :obj:`None` if the connection isn't closed yet. - - """ - if self.state is not CLOSED: - return None - elif self.close_rcvd is None: - return 1006 - else: - return self.close_rcvd.code - - @property - def close_reason(self) -> Optional[str]: - """ - `WebSocket close reason`_. - - .. _WebSocket close reason: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 - - :obj:`None` if the connection isn't closed yet. - - """ - if self.state is not CLOSED: - return None - elif self.close_rcvd is None: - return "" - else: - return self.close_rcvd.reason - - @property - def close_exc(self) -> ConnectionClosed: - """ - Exception to raise when trying to interact with a closed connection. - - Don't raise this exception while the connection :attr:`state` - is :attr:`~websockets.connection.State.CLOSING`; wait until - it's :attr:`~websockets.connection.State.CLOSED`. - - Indeed, the exception includes the close code and reason, which are - known only once the connection is closed. - - Raises: - AssertionError: if the connection isn't closed yet. - - """ - assert self.state is CLOSED, "connection isn't closed yet" - exc_type: Type[ConnectionClosed] - if ( - self.close_rcvd is not None - and self.close_sent is not None - and self.close_rcvd.code in OK_CLOSE_CODES - and self.close_sent.code in OK_CLOSE_CODES - ): - exc_type = ConnectionClosedOK - else: - exc_type = ConnectionClosedError - exc: ConnectionClosed = exc_type( - self.close_rcvd, - self.close_sent, - self.close_rcvd_then_sent, - ) - # Chain to the exception raised in the parser, if any. - exc.__cause__ = self.parser_exc - return exc - - # Public methods for receiving data. - - def receive_data(self, data: bytes) -> None: - """ - Receive data from the network. - - After calling this method: - - - You must call :meth:`data_to_send` and send this data to the network. - - You should call :meth:`events_received` and process resulting events. - - Raises: - EOFError: if :meth:`receive_eof` was called earlier. - - """ - self.reader.feed_data(data) - next(self.parser) - - def receive_eof(self) -> None: - """ - Receive the end of the data stream from the network. - - After calling this method: - - - You must call :meth:`data_to_send` and send this data to the network. - - You aren't expected to call :meth:`events_received`; it won't return - any new events. - - Raises: - EOFError: if :meth:`receive_eof` was called earlier. - - """ - self.reader.feed_eof() - next(self.parser) - - # Public methods for sending events. - - def send_continuation(self, data: bytes, fin: bool) -> None: - """ - Send a `Continuation frame`_. - - .. _Continuation frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 - - Parameters: - data: payload containing the same kind of data - as the initial frame. - fin: FIN bit; set it to :obj:`True` if this is the last frame - of a fragmented message and to :obj:`False` otherwise. - - Raises: - ProtocolError: if a fragmented message isn't in progress. - - """ - if not self.expect_continuation_frame: - raise ProtocolError("unexpected continuation frame") - self.expect_continuation_frame = not fin - self.send_frame(Frame(OP_CONT, data, fin)) - - def send_text(self, data: bytes, fin: bool = True) -> None: - """ - Send a `Text frame`_. - - .. _Text frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 - - Parameters: - data: payload containing text encoded with UTF-8. - fin: FIN bit; set it to :obj:`False` if this is the first frame of - a fragmented message. - - Raises: - ProtocolError: if a fragmented message is in progress. - - """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") - self.expect_continuation_frame = not fin - self.send_frame(Frame(OP_TEXT, data, fin)) - - def send_binary(self, data: bytes, fin: bool = True) -> None: - """ - Send a `Binary frame`_. +# lazy_import doesn't support this use case. +from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa - .. _Binary frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 - Parameters: - data: payload containing arbitrary binary data. - fin: FIN bit; set it to :obj:`False` if this is the first frame of - a fragmented message. - - Raises: - ProtocolError: if a fragmented message is in progress. - - """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") - self.expect_continuation_frame = not fin - self.send_frame(Frame(OP_BINARY, data, fin)) - - def send_close(self, code: Optional[int] = None, reason: str = "") -> None: - """ - Send a `Close frame`_. - - .. _Close frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 - - Parameters: - code: close code. - reason: close reason. - - Raises: - ProtocolError: if a fragmented message is being sent, if the code - isn't valid, or if a reason is provided without a code - - """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") - if code is None: - if reason != "": - raise ProtocolError("cannot send a reason without a code") - close = Close(1005, "") - data = b"" - else: - close = Close(code, reason) - data = close.serialize() - # send_frame() guarantees that self.state is OPEN at this point. - # 7.1.3. The WebSocket Closing Handshake is Started - self.send_frame(Frame(OP_CLOSE, data)) - self.close_sent = close - self.state = CLOSING - - def send_ping(self, data: bytes) -> None: - """ - Send a `Ping frame`_. - - .. _Ping frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 - - Parameters: - data: payload containing arbitrary binary data. - - """ - self.send_frame(Frame(OP_PING, data)) - - def send_pong(self, data: bytes) -> None: - """ - Send a `Pong frame`_. - - .. _Pong frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - - Parameters: - data: payload containing arbitrary binary data. - - """ - self.send_frame(Frame(OP_PONG, data)) - - def fail(self, code: int, reason: str = "") -> None: - """ - `Fail the WebSocket connection`_. - - .. _Fail the WebSocket connection: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 - - Parameters: - code: close code - reason: close reason - - Raises: - ProtocolError: if the code isn't valid. - """ - # 7.1.7. Fail the WebSocket Connection - - # Send a close frame when the state is OPEN (a close frame was already - # sent if it's CLOSING), except when failing the connection because - # of an error reading from or writing to the network. - if self.state is OPEN: - if code != 1006: - close = Close(code, reason) - data = close.serialize() - self.send_frame(Frame(OP_CLOSE, data)) - self.close_sent = close - self.state = CLOSING - - # When failing the connection, a server closes the TCP connection - # without waiting for the client to complete the handshake, while a - # client waits for the server to close the TCP connection, possibly - # after sending a close frame that the client will ignore. - if self.side is SERVER and not self.eof_sent: - self.send_eof() - - # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue - # to attempt to process data(including a responding Close frame) from - # the remote endpoint after being instructed to _Fail the WebSocket - # Connection_." - self.parser = self.discard() - next(self.parser) # start coroutine - - # Public method for getting incoming events after receiving data. - - def events_received(self) -> List[Event]: - """ - Fetch events generated from data received from the network. - - Call this method immediately after any of the ``receive_*()`` methods. - - Process resulting events, likely by passing them to the application. - - Returns: - List[Event]: Events read from the connection. - """ - events, self.events = self.events, [] - return events - - # Public method for getting outgoing data after receiving data or sending events. - - def data_to_send(self) -> List[bytes]: - """ - Obtain data to send to the network. - - Call this method immediately after any of the ``receive_*()``, - ``send_*()``, or :meth:`fail` methods. - - Write resulting data to the connection. - - The empty bytestring :data:`~websockets.connection.SEND_EOF` signals - the end of the data stream. When you receive it, half-close the TCP - connection. - - Returns: - List[bytes]: Data to write to the connection. - - """ - writes, self.writes = self.writes, [] - return writes - - def close_expected(self) -> bool: - """ - Tell if the TCP connection is expected to close soon. - - Call this method immediately after any of the ``receive_*()`` or - :meth:`fail` methods. - - If it returns :obj:`True`, schedule closing the TCP connection after a - short timeout if the other side hasn't already closed it. - - Returns: - bool: Whether the TCP connection is expected to close soon. - - """ - # We expect a TCP close if and only if we sent a close frame: - # * Normal closure: once we send a close frame, we expect a TCP close: - # server waits for client to complete the TCP closing handshake; - # client waits for server to initiate the TCP closing handshake. - # * Abnormal closure: we always send a close frame and the same logic - # applies, except on EOFError where we don't send a close frame - # because we already received the TCP close, so we don't expect it. - # We already got a TCP Close if and only if the state is CLOSED. - return self.state is CLOSING or self.handshake_exc is not None - - # Private methods for receiving data. - - def parse(self) -> Generator[None, None, None]: - """ - Parse incoming data into frames. - - :meth:`receive_data` and :meth:`receive_eof` run this generator - coroutine until it needs more data or reaches EOF. - - """ - try: - while True: - if (yield from self.reader.at_eof()): - if self.debug: - self.logger.debug("< EOF") - # If the WebSocket connection is closed cleanly, with a - # closing handhshake, recv_frame() substitutes parse() - # with discard(). This branch is reached only when the - # connection isn't closed cleanly. - raise EOFError("unexpected end of stream") - - if self.max_size is None: - max_size = None - elif self.cur_size is None: - max_size = self.max_size - else: - max_size = self.max_size - self.cur_size - - # During a normal closure, execution ends here on the next - # iteration of the loop after receiving a close frame. At - # this point, recv_frame() replaced parse() by discard(). - frame = yield from Frame.parse( - self.reader.read_exact, - mask=self.side is SERVER, - max_size=max_size, - extensions=self.extensions, - ) - - if self.debug: - self.logger.debug("< %s", frame) - - self.recv_frame(frame) - - except ProtocolError as exc: - self.fail(1002, str(exc)) - self.parser_exc = exc - - except EOFError as exc: - self.fail(1006, str(exc)) - self.parser_exc = exc - - except UnicodeDecodeError as exc: - self.fail(1007, f"{exc.reason} at position {exc.start}") - self.parser_exc = exc - - except PayloadTooBig as exc: - self.fail(1009, str(exc)) - self.parser_exc = exc - - except Exception as exc: - self.logger.error("parser failed", exc_info=True) - # Don't include exception details, which may be security-sensitive. - self.fail(1011) - self.parser_exc = exc - - # During an abnormal closure, execution ends here after catching an - # exception. At this point, fail() replaced parse() by discard(). - yield - raise AssertionError("parse() shouldn't step after error") - - def discard(self) -> Generator[None, None, None]: - """ - Discard incoming data. - - This coroutine replaces :meth:`parse`: - - - after receiving a close frame, during a normal closure (1.4); - - after sending a close frame, during an abnormal closure (7.1.7). - - """ - # The server close the TCP connection in the same circumstances where - # discard() replaces parse(). The client closes the connection later, - # after the server closes the connection or a timeout elapses. - # (The latter case cannot be handled in this Sans-I/O layer.) - assert (self.side is SERVER) == (self.eof_sent) - while not (yield from self.reader.at_eof()): - self.reader.discard() - if self.debug: - self.logger.debug("< EOF") - # A server closes the TCP connection immediately, while a client - # waits for the server to close the TCP connection. - if self.side is CLIENT: - self.send_eof() - self.state = CLOSED - # If discard() completes normally, execution ends here. - yield - # Once the reader reaches EOF, its feed_data/eof() methods raise an - # error, so our receive_data/eof() methods don't step the generator. - raise AssertionError("discard() shouldn't step after EOF") - - def recv_frame(self, frame: Frame) -> None: - """ - Process an incoming frame. - - """ - if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: - if self.cur_size is not None: - raise ProtocolError("expected a continuation frame") - if frame.fin: - self.cur_size = None - else: - self.cur_size = len(frame.data) - - elif frame.opcode is OP_CONT: - if self.cur_size is None: - raise ProtocolError("unexpected continuation frame") - if frame.fin: - self.cur_size = None - else: - self.cur_size += len(frame.data) - - elif frame.opcode is OP_PING: - # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST - # send a Pong frame in response" - pong_frame = Frame(OP_PONG, frame.data) - self.send_frame(pong_frame) - - elif frame.opcode is OP_PONG: - # 5.5.3 Pong: "A response to an unsolicited Pong frame is not - # expected." - pass - - elif frame.opcode is OP_CLOSE: - # 7.1.5. The WebSocket Connection Close Code - # 7.1.6. The WebSocket Connection Close Reason - self.close_rcvd = Close.parse(frame.data) - if self.state is CLOSING: - assert self.close_sent is not None - self.close_rcvd_then_sent = False - - if self.cur_size is not None: - raise ProtocolError("incomplete fragmented message") - - # 5.5.1 Close: "If an endpoint receives a Close frame and did - # not previously send a Close frame, the endpoint MUST send a - # Close frame in response. (When sending a Close frame in - # response, the endpoint typically echos the status code it - # received.)" - - if self.state is OPEN: - # Echo the original data instead of re-serializing it with - # Close.serialize() because that fails when the close frame - # is empty and Close.parse() synthetizes a 1005 close code. - # The rest is identical to send_close(). - self.send_frame(Frame(OP_CLOSE, frame.data)) - self.close_sent = self.close_rcvd - self.close_rcvd_then_sent = True - self.state = CLOSING - - # 7.1.2. Start the WebSocket Closing Handshake: "Once an - # endpoint has both sent and received a Close control frame, - # that endpoint SHOULD _Close the WebSocket Connection_" - - # A server closes the TCP connection immediately, while a client - # waits for the server to close the TCP connection. - if self.side is SERVER: - self.send_eof() - - # 1.4. Closing Handshake: "after receiving a control frame - # indicating the connection should be closed, a peer discards - # any further data received." - self.parser = self.discard() - next(self.parser) # start coroutine - - else: - # This can't happen because Frame.parse() validates opcodes. - raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") - - self.events.append(frame) - - # Private methods for sending events. - - def send_frame(self, frame: Frame) -> None: - if self.state is not OPEN: - raise InvalidState( - f"cannot write to a WebSocket in the {self.state.name} state" - ) - - if self.debug: - self.logger.debug("> %s", frame) - self.writes.append( - frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) - ) - - def send_eof(self) -> None: - assert not self.eof_sent - self.eof_sent = True - if self.debug: - self.logger.debug("> EOF") - self.writes.append(SEND_EOF) +warnings.warn( + "websockets.connection was renamed to websockets.protocol " + "and Connection was renamed to Protocol", + DeprecationWarning, +) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 68249192c..ec4e3b8b7 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -68,7 +68,7 @@ class Request: def exception(self) -> Optional[Exception]: # pragma: no cover warnings.warn( "Request.exception is deprecated; " - "use ServerConnection.handshake_exc instead", + "use ServerProtocol.handshake_exc instead", DeprecationWarning, ) return self._exception @@ -172,7 +172,7 @@ class Response: def exception(self) -> Optional[Exception]: # pragma: no cover warnings.warn( "Response.exception is deprecated; " - "use ClientConnection.handshake_exc instead", + "use ClientProtocol.handshake_exc instead", DeprecationWarning, ) return self._exception diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 9f2bda1ab..7881b947d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -27,7 +27,6 @@ cast, ) -from ..connection import State from ..datastructures import Headers from ..exceptions import ( ConnectionClosed, @@ -51,6 +50,7 @@ prepare_ctrl, prepare_data, ) +from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol from .compatibility import loop_if_py_lt_38 from .framing import Frame diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index f37a16c20..eabeb8e96 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -25,7 +25,6 @@ cast, ) -from ..connection import State from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( AbortHandshake, @@ -45,6 +44,7 @@ validate_subprotocols, ) from ..http import USER_AGENT +from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from .compatibility import loop_if_py_lt_38 from .handshake import build_response, check_request diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py new file mode 100644 index 000000000..29d7e1596 --- /dev/null +++ b/src/websockets/protocol.py @@ -0,0 +1,702 @@ +from __future__ import annotations + +import enum +import logging +import uuid +from typing import Generator, List, Optional, Type, Union + +from .exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from .extensions import Extension +from .frames import ( + OK_CLOSE_CODES, + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Close, + Frame, +) +from .http11 import Request, Response +from .streams import StreamReader +from .typing import LoggerLike, Origin, Subprotocol + + +__all__ = [ + "Protocol", + "Side", + "State", + "SEND_EOF", +] + +Event = Union[Request, Response, Frame] +"""Events that :meth:`~Protocol.events_received` may return.""" + + +class Side(enum.IntEnum): + """A WebSocket connection is either a server or a client.""" + + SERVER, CLIENT = range(2) + + +SERVER = Side.SERVER +CLIENT = Side.CLIENT + + +class State(enum.IntEnum): + """A WebSocket connection is in one of these four states.""" + + CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + +CONNECTING = State.CONNECTING +OPEN = State.OPEN +CLOSING = State.CLOSING +CLOSED = State.CLOSED + + +SEND_EOF = b"" +"""Sentinel signaling that the TCP connection must be half-closed.""" + + +class Protocol: + """ + Sans-I/O implementation of a WebSocket connection. + + Args: + side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. + state: initial state of the WebSocket connection. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + logger: logger for this connection; depending on ``side``, + defaults to ``logging.getLogger("websockets.client")`` + or ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../topics/logging>` for details. + + """ + + def __init__( + self, + side: Side, + state: State = OPEN, + max_size: Optional[int] = 2**20, + logger: Optional[LoggerLike] = None, + ) -> None: + # Unique identifier. For logs. + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" + + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger(f"websockets.{side.name.lower()}") + self.logger: LoggerLike = logger + """Logger for this connection.""" + + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) + + # Connection side. CLIENT or SERVER. + self.side = side + + # Connection state. Initially OPEN because subclasses handle CONNECTING. + self.state = state + + # Maximum size of incoming messages in bytes. + self.max_size = max_size + + # Current size of incoming message in bytes. Only set while reading a + # fragmented message i.e. a data frames with the FIN bit not set. + self.cur_size: Optional[int] = None + + # True while sending a fragmented message i.e. a data frames with the + # FIN bit not set. + self.expect_continuation_frame = False + + # WebSocket protocol parameters. + self.origin: Optional[Origin] = None + self.extensions: List[Extension] = [] + self.subprotocol: Optional[Subprotocol] = None + + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Optional[Close] = None + self.close_sent: Optional[Close] = None + self.close_rcvd_then_sent: Optional[bool] = None + + # Track if an exception happened during the handshake. + self.handshake_exc: Optional[Exception] = None + """ + Exception to raise if the opening handshake failed. + + :obj:`None` if the opening handshake succeeded. + + """ + + # Track if send_eof() was called. + self.eof_sent = False + + # Parser state. + self.reader = StreamReader() + self.events: List[Event] = [] + self.writes: List[bytes] = [] + self.parser = self.parse() + next(self.parser) # start coroutine + self.parser_exc: Optional[Exception] = None + + @property + def state(self) -> State: + """ + WebSocket connection state. + + Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. + + """ + return self._state + + @state.setter + def state(self, state: State) -> None: + if self.debug: + self.logger.debug("= connection is %s", state.name) + self._state = state + + @property + def close_code(self) -> Optional[int]: + """ + `WebSocket close code`_. + + .. _WebSocket close code: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not CLOSED: + return None + elif self.close_rcvd is None: + return 1006 + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> Optional[str]: + """ + `WebSocket close reason`_. + + .. _WebSocket close reason: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + + @property + def close_exc(self) -> ConnectionClosed: + """ + Exception to raise when trying to interact with a closed connection. + + Don't raise this exception while the connection :attr:`state` + is :attr:`~websockets.protocol.State.CLOSING`; wait until + it's :attr:`~websockets.protocol.State.CLOSED`. + + Indeed, the exception includes the close code and reason, which are + known only once the connection is closed. + + Raises: + AssertionError: if the connection isn't closed yet. + + """ + assert self.state is CLOSED, "connection isn't closed yet" + exc_type: Type[ConnectionClosed] + if ( + self.close_rcvd is not None + and self.close_sent is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent.code in OK_CLOSE_CODES + ): + exc_type = ConnectionClosedOK + else: + exc_type = ConnectionClosedError + exc: ConnectionClosed = exc_type( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + # Chain to the exception raised in the parser, if any. + exc.__cause__ = self.parser_exc + return exc + + # Public methods for receiving data. + + def receive_data(self, data: bytes) -> None: + """ + Receive data from the network. + + After calling this method: + + - You must call :meth:`data_to_send` and send this data to the network. + - You should call :meth:`events_received` and process resulting events. + + Raises: + EOFError: if :meth:`receive_eof` was called earlier. + + """ + self.reader.feed_data(data) + next(self.parser) + + def receive_eof(self) -> None: + """ + Receive the end of the data stream from the network. + + After calling this method: + + - You must call :meth:`data_to_send` and send this data to the network. + - You aren't expected to call :meth:`events_received`; it won't return + any new events. + + Raises: + EOFError: if :meth:`receive_eof` was called earlier. + + """ + self.reader.feed_eof() + next(self.parser) + + # Public methods for sending events. + + def send_continuation(self, data: bytes, fin: bool) -> None: + """ + Send a `Continuation frame`_. + + .. _Continuation frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing the same kind of data + as the initial frame. + fin: FIN bit; set it to :obj:`True` if this is the last frame + of a fragmented message and to :obj:`False` otherwise. + + Raises: + ProtocolError: if a fragmented message isn't in progress. + + """ + if not self.expect_continuation_frame: + raise ProtocolError("unexpected continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_CONT, data, fin)) + + def send_text(self, data: bytes, fin: bool = True) -> None: + """ + Send a `Text frame`_. + + .. _Text frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing text encoded with UTF-8. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: if a fragmented message is in progress. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_TEXT, data, fin)) + + def send_binary(self, data: bytes, fin: bool = True) -> None: + """ + Send a `Binary frame`_. + + .. _Binary frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing arbitrary binary data. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: if a fragmented message is in progress. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_BINARY, data, fin)) + + def send_close(self, code: Optional[int] = None, reason: str = "") -> None: + """ + Send a `Close frame`_. + + .. _Close frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + + Parameters: + code: close code. + reason: close reason. + + Raises: + ProtocolError: if a fragmented message is being sent, if the code + isn't valid, or if a reason is provided without a code + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + if code is None: + if reason != "": + raise ProtocolError("cannot send a reason without a code") + close = Close(1005, "") + data = b"" + else: + close = Close(code, reason) + data = close.serialize() + # send_frame() guarantees that self.state is OPEN at this point. + # 7.1.3. The WebSocket Closing Handshake is Started + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close + self.state = CLOSING + + def send_ping(self, data: bytes) -> None: + """ + Send a `Ping frame`_. + + .. _Ping frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + Parameters: + data: payload containing arbitrary binary data. + + """ + self.send_frame(Frame(OP_PING, data)) + + def send_pong(self, data: bytes) -> None: + """ + Send a `Pong frame`_. + + .. _Pong frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + Parameters: + data: payload containing arbitrary binary data. + + """ + self.send_frame(Frame(OP_PONG, data)) + + def fail(self, code: int, reason: str = "") -> None: + """ + `Fail the WebSocket connection`_. + + .. _Fail the WebSocket connection: + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 + + Parameters: + code: close code + reason: close reason + + Raises: + ProtocolError: if the code isn't valid. + """ + # 7.1.7. Fail the WebSocket Connection + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because + # of an error reading from or writing to the network. + if self.state is OPEN: + if code != 1006: + close = Close(code, reason) + data = close.serialize() + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close + self.state = CLOSING + + # When failing the connection, a server closes the TCP connection + # without waiting for the client to complete the handshake, while a + # client waits for the server to close the TCP connection, possibly + # after sending a close frame that the client will ignore. + if self.side is SERVER and not self.eof_sent: + self.send_eof() + + # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue + # to attempt to process data(including a responding Close frame) from + # the remote endpoint after being instructed to _Fail the WebSocket + # Connection_." + self.parser = self.discard() + next(self.parser) # start coroutine + + # Public method for getting incoming events after receiving data. + + def events_received(self) -> List[Event]: + """ + Fetch events generated from data received from the network. + + Call this method immediately after any of the ``receive_*()`` methods. + + Process resulting events, likely by passing them to the application. + + Returns: + List[Event]: Events read from the connection. + """ + events, self.events = self.events, [] + return events + + # Public method for getting outgoing data after receiving data or sending events. + + def data_to_send(self) -> List[bytes]: + """ + Obtain data to send to the network. + + Call this method immediately after any of the ``receive_*()``, + ``send_*()``, or :meth:`fail` methods. + + Write resulting data to the connection. + + The empty bytestring :data:`~websockets.protocol.SEND_EOF` signals + the end of the data stream. When you receive it, half-close the TCP + connection. + + Returns: + List[bytes]: Data to write to the connection. + + """ + writes, self.writes = self.writes, [] + return writes + + def close_expected(self) -> bool: + """ + Tell if the TCP connection is expected to close soon. + + Call this method immediately after any of the ``receive_*()`` or + :meth:`fail` methods. + + If it returns :obj:`True`, schedule closing the TCP connection after a + short timeout if the other side hasn't already closed it. + + Returns: + bool: Whether the TCP connection is expected to close soon. + + """ + # We expect a TCP close if and only if we sent a close frame: + # * Normal closure: once we send a close frame, we expect a TCP close: + # server waits for client to complete the TCP closing handshake; + # client waits for server to initiate the TCP closing handshake. + # * Abnormal closure: we always send a close frame and the same logic + # applies, except on EOFError where we don't send a close frame + # because we already received the TCP close, so we don't expect it. + # We already got a TCP Close if and only if the state is CLOSED. + return self.state is CLOSING or self.handshake_exc is not None + + # Private methods for receiving data. + + def parse(self) -> Generator[None, None, None]: + """ + Parse incoming data into frames. + + :meth:`receive_data` and :meth:`receive_eof` run this generator + coroutine until it needs more data or reaches EOF. + + """ + try: + while True: + if (yield from self.reader.at_eof()): + if self.debug: + self.logger.debug("< EOF") + # If the WebSocket connection is closed cleanly, with a + # closing handhshake, recv_frame() substitutes parse() + # with discard(). This branch is reached only when the + # connection isn't closed cleanly. + raise EOFError("unexpected end of stream") + + if self.max_size is None: + max_size = None + elif self.cur_size is None: + max_size = self.max_size + else: + max_size = self.max_size - self.cur_size + + # During a normal closure, execution ends here on the next + # iteration of the loop after receiving a close frame. At + # this point, recv_frame() replaced parse() by discard(). + frame = yield from Frame.parse( + self.reader.read_exact, + mask=self.side is SERVER, + max_size=max_size, + extensions=self.extensions, + ) + + if self.debug: + self.logger.debug("< %s", frame) + + self.recv_frame(frame) + + except ProtocolError as exc: + self.fail(1002, str(exc)) + self.parser_exc = exc + + except EOFError as exc: + self.fail(1006, str(exc)) + self.parser_exc = exc + + except UnicodeDecodeError as exc: + self.fail(1007, f"{exc.reason} at position {exc.start}") + self.parser_exc = exc + + except PayloadTooBig as exc: + self.fail(1009, str(exc)) + self.parser_exc = exc + + except Exception as exc: + self.logger.error("parser failed", exc_info=True) + # Don't include exception details, which may be security-sensitive. + self.fail(1011) + self.parser_exc = exc + + # During an abnormal closure, execution ends here after catching an + # exception. At this point, fail() replaced parse() by discard(). + yield + raise AssertionError("parse() shouldn't step after error") + + def discard(self) -> Generator[None, None, None]: + """ + Discard incoming data. + + This coroutine replaces :meth:`parse`: + + - after receiving a close frame, during a normal closure (1.4); + - after sending a close frame, during an abnormal closure (7.1.7). + + """ + # The server close the TCP connection in the same circumstances where + # discard() replaces parse(). The client closes the connection later, + # after the server closes the connection or a timeout elapses. + # (The latter case cannot be handled in this Sans-I/O layer.) + assert (self.side is SERVER) == (self.eof_sent) + while not (yield from self.reader.at_eof()): + self.reader.discard() + if self.debug: + self.logger.debug("< EOF") + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. + if self.side is CLIENT: + self.send_eof() + self.state = CLOSED + # If discard() completes normally, execution ends here. + yield + # Once the reader reaches EOF, its feed_data/eof() methods raise an + # error, so our receive_data/eof() methods don't step the generator. + raise AssertionError("discard() shouldn't step after EOF") + + def recv_frame(self, frame: Frame) -> None: + """ + Process an incoming frame. + + """ + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: + if self.cur_size is not None: + raise ProtocolError("expected a continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size = len(frame.data) + + elif frame.opcode is OP_CONT: + if self.cur_size is None: + raise ProtocolError("unexpected continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size += len(frame.data) + + elif frame.opcode is OP_PING: + # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST + # send a Pong frame in response" + pong_frame = Frame(OP_PONG, frame.data) + self.send_frame(pong_frame) + + elif frame.opcode is OP_PONG: + # 5.5.3 Pong: "A response to an unsolicited Pong frame is not + # expected." + pass + + elif frame.opcode is OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_rcvd = Close.parse(frame.data) + if self.state is CLOSING: + assert self.close_sent is not None + self.close_rcvd_then_sent = False + + if self.cur_size is not None: + raise ProtocolError("incomplete fragmented message") + + # 5.5.1 Close: "If an endpoint receives a Close frame and did + # not previously send a Close frame, the endpoint MUST send a + # Close frame in response. (When sending a Close frame in + # response, the endpoint typically echos the status code it + # received.)" + + if self.state is OPEN: + # Echo the original data instead of re-serializing it with + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthetizes a 1005 close code. + # The rest is identical to send_close(). + self.send_frame(Frame(OP_CLOSE, frame.data)) + self.close_sent = self.close_rcvd + self.close_rcvd_then_sent = True + self.state = CLOSING + + # 7.1.2. Start the WebSocket Closing Handshake: "Once an + # endpoint has both sent and received a Close control frame, + # that endpoint SHOULD _Close the WebSocket Connection_" + + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. + if self.side is SERVER: + self.send_eof() + + # 1.4. Closing Handshake: "after receiving a control frame + # indicating the connection should be closed, a peer discards + # any further data received." + self.parser = self.discard() + next(self.parser) # start coroutine + + else: + # This can't happen because Frame.parse() validates opcodes. + raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") + + self.events.append(frame) + + # Private methods for sending events. + + def send_frame(self, frame: Frame) -> None: + if self.state is not OPEN: + raise InvalidState( + f"cannot write to a WebSocket in the {self.state.name} state" + ) + + if self.debug: + self.logger.debug("> %s", frame) + self.writes.append( + frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) + ) + + def send_eof(self) -> None: + assert not self.eof_sent + self.eof_sent = True + if self.debug: + self.logger.debug("> EOF") + self.writes.append(SEND_EOF) diff --git a/src/websockets/server.py b/src/websockets/server.py index edd1764c3..548048c92 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,9 +4,9 @@ import binascii import email.utils import http -from typing import Generator, List, Optional, Sequence, Tuple, cast +import warnings +from typing import Any, Generator, List, Optional, Sequence, Tuple, cast -from .connection import CONNECTING, OPEN, SERVER, Connection, State from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, @@ -26,6 +26,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .protocol import CONNECTING, OPEN, SERVER, Protocol, State from .typing import ( ConnectionOption, ExtensionHeader, @@ -41,10 +42,10 @@ from .legacy.server import * # isort:skip # noqa -__all__ = ["ServerConnection"] +__all__ = ["ServerProtocol"] -class ServerConnection(Connection): +class ServerProtocol(Protocol): """ Sans-I/O implementation of a WebSocket server connection. @@ -515,3 +516,12 @@ def parse(self) -> Generator[None, None, None]: self.events.append(request) yield from super().parse() + + +class ServerConnection(ServerProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + "ServerConnection was renamed to ServerProtocol", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f8a79027e..4a4510536 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -16,7 +16,6 @@ import urllib.request import warnings -from websockets.connection import State from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -35,6 +34,7 @@ from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * +from websockets.protocol import State from websockets.uri import parse_uri from ..extensions.utils import ( diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 1f830ebee..e85402a39 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -4,7 +4,6 @@ import unittest.mock import warnings -from websockets.connection import State from websockets.exceptions import ConnectionClosed, InvalidState from websockets.frames import ( OP_BINARY, @@ -18,6 +17,7 @@ from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast +from websockets.protocol import State from .utils import MS, AsyncioTestCase diff --git a/tests/test_client.py b/tests/test_client.py index 9ed36c1d4..718219b9d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,11 +3,11 @@ import unittest.mock from websockets.client import * -from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers from websockets.exceptions import InvalidHandshake, InvalidHeader from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response +from websockets.protocol import CONNECTING, OPEN from websockets.uri import parse_uri from websockets.utils import accept_key @@ -18,13 +18,13 @@ Rsv2Extension, ) from .test_utils import ACCEPT, KEY -from .utils import DATE +from .utils import DATE, DeprecationTestCase class ConnectTests(unittest.TestCase): def test_send_connect(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("wss://example.com/test")) + client = ClientProtocol(parse_uri("wss://example.com/test")) request = client.connect() self.assertIsInstance(request, Request) client.send_request(request) @@ -44,7 +44,7 @@ def test_send_connect(self): def test_connect_request(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("wss://example.com/test")) + client = ClientProtocol(parse_uri("wss://example.com/test")) request = client.connect() self.assertEqual(request.path, "/test") self.assertEqual( @@ -61,7 +61,7 @@ def test_connect_request(self): ) def test_path(self): - client = ClientConnection(parse_uri("wss://example.com/endpoint?test=1")) + client = ClientProtocol(parse_uri("wss://example.com/endpoint?test=1")) request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") @@ -76,19 +76,19 @@ def test_port(self): ("wss://example.com:8443/", "example.com:8443"), ]: with self.subTest(uri=uri): - client = ClientConnection(parse_uri(uri)) + client = ClientProtocol(parse_uri(uri)) request = client.connect() self.assertEqual(request.headers["Host"], host) def test_user_info(self): - client = ClientConnection(parse_uri("wss://hello:iloveyou@example.com/")) + client = ClientProtocol(parse_uri("wss://hello:iloveyou@example.com/")) request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), origin="https://example.com", ) @@ -97,7 +97,7 @@ def test_origin(self): self.assertEqual(request.headers["Origin"], "https://example.com") def test_extensions(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory()], ) @@ -106,7 +106,7 @@ def test_extensions(self): self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), subprotocols=["chat"], ) @@ -118,7 +118,7 @@ def test_subprotocols(self): class AcceptRejectTests(unittest.TestCase): def test_receive_accept(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("ws://example.com/test")) + client = ClientProtocol(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -138,7 +138,7 @@ def test_receive_accept(self): def test_receive_reject(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("ws://example.com/test")) + client = ClientProtocol(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -159,7 +159,7 @@ def test_receive_reject(self): def test_accept_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("ws://example.com/test")) + client = ClientProtocol(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -189,7 +189,7 @@ def test_accept_response(self): def test_reject_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("ws://example.com/test")) + client = ClientProtocol(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -235,7 +235,7 @@ def make_accept_response(self, client): ) def test_basic(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -243,7 +243,7 @@ def test_basic(self): self.assertEqual(client.state, OPEN) def test_missing_connection(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Connection"] client.receive_data(response.serialize()) @@ -255,7 +255,7 @@ def test_missing_connection(self): self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Connection"] response.headers["Connection"] = "close" @@ -268,7 +268,7 @@ def test_invalid_connection(self): self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Upgrade"] client.receive_data(response.serialize()) @@ -280,7 +280,7 @@ def test_missing_upgrade(self): self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Upgrade"] response.headers["Upgrade"] = "h2c" @@ -293,7 +293,7 @@ def test_invalid_upgrade(self): self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_accept(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] client.receive_data(response.serialize()) @@ -305,7 +305,7 @@ def test_missing_accept(self): self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") def test_multiple_accept(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Accept"] = ACCEPT client.receive_data(response.serialize()) @@ -321,7 +321,7 @@ def test_multiple_accept(self): ) def test_invalid_accept(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] response.headers["Sec-WebSocket-Accept"] = ACCEPT @@ -336,7 +336,7 @@ def test_invalid_accept(self): ) def test_no_extensions(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -345,7 +345,7 @@ def test_no_extensions(self): self.assertEqual(client.extensions, []) def test_no_extension(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory()], ) @@ -358,7 +358,7 @@ def test_no_extension(self): self.assertEqual(client.extensions, [OpExtension()]) def test_extension(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientRsv2ExtensionFactory()], ) @@ -371,7 +371,7 @@ def test_extension(self): self.assertEqual(client.extensions, [Rsv2Extension()]) def test_unexpected_extension(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" client.receive_data(response.serialize()) @@ -383,7 +383,7 @@ def test_unexpected_extension(self): self.assertEqual(str(raised.exception), "no extensions supported") def test_unsupported_extension(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientRsv2ExtensionFactory()], ) @@ -401,7 +401,7 @@ def test_unsupported_extension(self): ) def test_supported_extension_parameters(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory("this")], ) @@ -414,7 +414,7 @@ def test_supported_extension_parameters(self): self.assertEqual(client.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory("this")], ) @@ -432,7 +432,7 @@ def test_unsupported_extension_parameters(self): ) def test_multiple_supported_extension_parameters(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ ClientOpExtensionFactory("this"), @@ -448,7 +448,7 @@ def test_multiple_supported_extension_parameters(self): self.assertEqual(client.extensions, [OpExtension("that")]) def test_multiple_extensions(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], ) @@ -462,7 +462,7 @@ def test_multiple_extensions(self): self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], ) @@ -476,7 +476,7 @@ def test_multiple_extensions_order(self): self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -485,9 +485,7 @@ def test_no_subprotocols(self): self.assertIsNone(client.subprotocol) def test_no_subprotocol(self): - client = ClientConnection( - parse_uri("wss://example.com/"), subprotocols=["chat"] - ) + client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -496,9 +494,7 @@ def test_no_subprotocol(self): self.assertIsNone(client.subprotocol) def test_subprotocol(self): - client = ClientConnection( - parse_uri("wss://example.com/"), subprotocols=["chat"] - ) + client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" client.receive_data(response.serialize()) @@ -508,7 +504,7 @@ def test_subprotocol(self): self.assertEqual(client.subprotocol, "chat") def test_unexpected_subprotocol(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" client.receive_data(response.serialize()) @@ -520,7 +516,7 @@ def test_unexpected_subprotocol(self): self.assertEqual(str(raised.exception), "no subprotocols supported") def test_multiple_subprotocols(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), subprotocols=["superchat", "chat"], ) @@ -538,7 +534,7 @@ def test_multiple_subprotocols(self): ) def test_supported_subprotocol(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), subprotocols=["superchat", "chat"], ) @@ -551,7 +547,7 @@ def test_supported_subprotocol(self): self.assertEqual(client.subprotocol, "chat") def test_unsupported_subprotocol(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), subprotocols=["superchat", "chat"], ) @@ -568,7 +564,7 @@ def test_unsupported_subprotocol(self): class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - client = ClientConnection(parse_uri("ws://example.com/test"), state=OPEN) + client = ClientProtocol(parse_uri("ws://example.com/test"), state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) @@ -576,5 +572,17 @@ def test_bypass_handshake(self): def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ClientConnection(parse_uri("wss://example.com/test"), logger=logger) + ClientProtocol(parse_uri("wss://example.com/test"), logger=logger) self.assertEqual(len(logs.records), 1) + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_client_connection_class(self): + with self.assertDeprecationWarning( + "ClientConnection was renamed to ClientProtocol" + ): + from websockets.client import ClientConnection + + client = ClientConnection("ws://localhost/") + + self.assertIsInstance(client, ClientProtocol) diff --git a/tests/test_connection.py b/tests/test_connection.py index 3858d2521..6592d67d0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,1737 +1,14 @@ -import logging -import unittest.mock +from websockets.protocol import Protocol -from websockets.connection import * -from websockets.connection import CLIENT, CLOSED, CLOSING, SERVER -from websockets.exceptions import ( - ConnectionClosedError, - ConnectionClosedOK, - InvalidState, - PayloadTooBig, - ProtocolError, -) -from websockets.frames import ( - OP_BINARY, - OP_CLOSE, - OP_CONT, - OP_PING, - OP_PONG, - OP_TEXT, - Close, - Frame, -) +from .utils import DeprecationTestCase -from .extensions.utils import Rsv2Extension -from .test_frames import FramesTestCase +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_connection_class(self): + with self.assertDeprecationWarning( + "websockets.connection was renamed to websockets.protocol " + "and Connection was renamed to Protocol" + ): + from websockets.connection import Connection -class ConnectionTestCase(FramesTestCase): - def assertFrameSent(self, connection, frame, eof=False): - """ - Outgoing data for ``connection`` contains the given frame. - - ``frame`` may be ``None`` if no frame is expected. - - When ``eof`` is ``True``, the end of the stream is also expected. - - """ - frames_sent = [ - None - if write is SEND_EOF - else self.parse( - write, - mask=connection.side is CLIENT, - extensions=connection.extensions, - ) - for write in connection.data_to_send() - ] - frames_expected = [] if frame is None else [frame] - if eof: - frames_expected += [None] - self.assertEqual(frames_sent, frames_expected) - - def assertFrameReceived(self, connection, frame): - """ - Incoming data for ``connection`` contains the given frame. - - ``frame`` may be ``None`` if no frame is expected. - - """ - frames_received = connection.events_received() - frames_expected = [] if frame is None else [frame] - self.assertEqual(frames_received, frames_expected) - - def assertConnectionClosing(self, connection, code=None, reason=""): - """ - Incoming data caused the "Start the WebSocket Closing Handshake" process. - - """ - close_frame = Frame( - OP_CLOSE, - b"" if code is None else Close(code, reason).serialize(), - ) - # A close frame was received. - self.assertFrameReceived(connection, close_frame) - # A close frame and possibly the end of stream were sent. - self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) - - def assertConnectionFailing(self, connection, code=None, reason=""): - """ - Incoming data caused the "Fail the WebSocket Connection" process. - - """ - close_frame = Frame( - OP_CLOSE, - b"" if code is None else Close(code, reason).serialize(), - ) - # No frame was received. - self.assertFrameReceived(connection, None) - # A close frame and possibly the end of stream were sent. - self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) - - -class MaskingTests(ConnectionTestCase): - """ - Test frame masking. - - 5.1. Overview - - """ - - unmasked_text_frame_date = b"\x81\x04Spam" - masked_text_frame_data = b"\x81\x84\x00\xff\x00\xff\x53\x8f\x61\x92" - - def test_client_sends_masked_frame(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\xff\x00\xff"): - client.send_text(b"Spam", True) - self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) - - def test_server_sends_unmasked_frame(self): - server = Connection(SERVER) - server.send_text(b"Spam", True) - self.assertEqual(server.data_to_send(), [self.unmasked_text_frame_date]) - - def test_client_receives_unmasked_frame(self): - client = Connection(CLIENT) - client.receive_data(self.unmasked_text_frame_date) - self.assertFrameReceived( - client, - Frame(OP_TEXT, b"Spam"), - ) - - def test_server_receives_masked_frame(self): - server = Connection(SERVER) - server.receive_data(self.masked_text_frame_data) - self.assertFrameReceived( - server, - Frame(OP_TEXT, b"Spam"), - ) - - def test_client_receives_masked_frame(self): - client = Connection(CLIENT) - client.receive_data(self.masked_text_frame_data) - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "incorrect masking") - self.assertConnectionFailing(client, 1002, "incorrect masking") - - def test_server_receives_unmasked_frame(self): - server = Connection(SERVER) - server.receive_data(self.unmasked_text_frame_date) - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "incorrect masking") - self.assertConnectionFailing(server, 1002, "incorrect masking") - - -class ContinuationTests(ConnectionTestCase): - """ - Test continuation frames without text or binary frames. - - """ - - def test_client_sends_unexpected_continuation(self): - client = Connection(CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.send_continuation(b"", fin=False) - self.assertEqual(str(raised.exception), "unexpected continuation frame") - - def test_server_sends_unexpected_continuation(self): - server = Connection(SERVER) - with self.assertRaises(ProtocolError) as raised: - server.send_continuation(b"", fin=False) - self.assertEqual(str(raised.exception), "unexpected continuation frame") - - def test_client_receives_unexpected_continuation(self): - client = Connection(CLIENT) - client.receive_data(b"\x00\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "unexpected continuation frame") - self.assertConnectionFailing(client, 1002, "unexpected continuation frame") - - def test_server_receives_unexpected_continuation(self): - server = Connection(SERVER) - server.receive_data(b"\x00\x80\x00\x00\x00\x00") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "unexpected continuation frame") - self.assertConnectionFailing(server, 1002, "unexpected continuation frame") - - def test_client_sends_continuation_after_sending_close(self): - client = Connection(CLIENT) - # Since it isn't possible to send a close frame in a fragmented - # message (see test_client_send_close_in_fragmented_message), in fact, - # this is the same test as test_client_sends_unexpected_continuation. - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(ProtocolError) as raised: - client.send_continuation(b"", fin=False) - self.assertEqual(str(raised.exception), "unexpected continuation frame") - - def test_server_sends_continuation_after_sending_close(self): - # Since it isn't possible to send a close frame in a fragmented - # message (see test_server_send_close_in_fragmented_message), in fact, - # this is the same test as test_server_sends_unexpected_continuation. - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(ProtocolError) as raised: - server.send_continuation(b"", fin=False) - self.assertEqual(str(raised.exception), "unexpected continuation frame") - - def test_client_receives_continuation_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x00\x00") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_continuation_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x00\x80\x00\xff\x00\xff") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class TextTests(ConnectionTestCase): - """ - Test text frames and continuation frames. - - """ - - def test_client_sends_text(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_text("😀".encode()) - self.assertEqual( - client.data_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] - ) - - def test_server_sends_text(self): - server = Connection(SERVER) - server.send_text("😀".encode()) - self.assertEqual(server.data_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) - - def test_client_receives_text(self): - client = Connection(CLIENT) - client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()), - ) - - def test_server_receives_text(self): - server = Connection(SERVER) - server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()), - ) - - def test_client_receives_text_over_size_limit(self): - client = Connection(CLIENT, max_size=3) - client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") - self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") - - def test_server_receives_text_over_size_limit(self): - server = Connection(SERVER, max_size=3) - server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") - self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") - - def test_client_receives_text_without_size_limit(self): - client = Connection(CLIENT, max_size=None) - client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()), - ) - - def test_server_receives_text_without_size_limit(self): - server = Connection(SERVER, max_size=None) - server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()), - ) - - def test_client_sends_fragmented_text(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_text("😀".encode()[:2], fin=False) - self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_continuation("😀😀".encode()[2:6], fin=False) - self.assertEqual( - client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] - ) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_continuation("😀".encode()[2:], fin=True) - self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) - - def test_server_sends_fragmented_text(self): - server = Connection(SERVER) - server.send_text("😀".encode()[:2], fin=False) - self.assertEqual(server.data_to_send(), [b"\x01\x02\xf0\x9f"]) - server.send_continuation("😀😀".encode()[2:6], fin=False) - self.assertEqual(server.data_to_send(), [b"\x00\x04\x98\x80\xf0\x9f"]) - server.send_continuation("😀".encode()[2:], fin=True) - self.assertEqual(server.data_to_send(), [b"\x80\x02\x98\x80"]) - - def test_client_receives_fragmented_text(self): - client = Connection(CLIENT) - client.receive_data(b"\x01\x02\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), - ) - client.receive_data(b"\x80\x02\x98\x80") - self.assertFrameReceived( - client, - Frame(OP_CONT, "😀".encode()[2:]), - ) - - def test_server_receives_fragmented_text(self): - server = Connection(SERVER) - server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") - self.assertFrameReceived( - server, - Frame(OP_CONT, "😀".encode()[2:]), - ) - - def test_client_receives_fragmented_text_over_size_limit(self): - client = Connection(CLIENT, max_size=3) - client.receive_data(b"\x01\x02\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - client.receive_data(b"\x80\x02\x98\x80") - self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") - - def test_server_receives_fragmented_text_over_size_limit(self): - server = Connection(SERVER, max_size=3) - server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") - self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") - - def test_client_receives_fragmented_text_without_size_limit(self): - client = Connection(CLIENT, max_size=None) - client.receive_data(b"\x01\x02\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), - ) - client.receive_data(b"\x80\x02\x98\x80") - self.assertFrameReceived( - client, - Frame(OP_CONT, "😀".encode()[2:]), - ) - - def test_server_receives_fragmented_text_without_size_limit(self): - server = Connection(SERVER, max_size=None) - server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") - self.assertFrameReceived( - server, - Frame(OP_CONT, "😀".encode()[2:]), - ) - - def test_client_sends_unexpected_text(self): - client = Connection(CLIENT) - client.send_text(b"", fin=False) - with self.assertRaises(ProtocolError) as raised: - client.send_text(b"", fin=False) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_server_sends_unexpected_text(self): - server = Connection(SERVER) - server.send_text(b"", fin=False) - with self.assertRaises(ProtocolError) as raised: - server.send_text(b"", fin=False) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_client_receives_unexpected_text(self): - client = Connection(CLIENT) - client.receive_data(b"\x01\x00") - self.assertFrameReceived( - client, - Frame(OP_TEXT, b"", fin=False), - ) - client.receive_data(b"\x01\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(client, 1002, "expected a continuation frame") - - def test_server_receives_unexpected_text(self): - server = Connection(SERVER) - server.receive_data(b"\x01\x80\x00\x00\x00\x00") - self.assertFrameReceived( - server, - Frame(OP_TEXT, b"", fin=False), - ) - server.receive_data(b"\x01\x80\x00\x00\x00\x00") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(server, 1002, "expected a continuation frame") - - def test_client_sends_text_after_sending_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): - client.send_text(b"") - - def test_server_sends_text_after_sending_close(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): - server.send_text(b"") - - def test_client_receives_text_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x81\x00") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_text_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x81\x80\x00\xff\x00\xff") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class BinaryTests(ConnectionTestCase): - """ - Test binary frames and continuation frames. - - """ - - def test_client_sends_binary(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_binary(b"\x01\x02\xfe\xff") - self.assertEqual( - client.data_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] - ) - - def test_server_sends_binary(self): - server = Connection(SERVER) - server.send_binary(b"\x01\x02\xfe\xff") - self.assertEqual(server.data_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) - - def test_client_receives_binary(self): - client = Connection(CLIENT) - client.receive_data(b"\x82\x04\x01\x02\xfe\xff") - self.assertFrameReceived( - client, - Frame(OP_BINARY, b"\x01\x02\xfe\xff"), - ) - - def test_server_receives_binary(self): - server = Connection(SERVER) - server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") - self.assertFrameReceived( - server, - Frame(OP_BINARY, b"\x01\x02\xfe\xff"), - ) - - def test_client_receives_binary_over_size_limit(self): - client = Connection(CLIENT, max_size=3) - client.receive_data(b"\x82\x04\x01\x02\xfe\xff") - self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") - - def test_server_receives_binary_over_size_limit(self): - server = Connection(SERVER, max_size=3) - server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") - self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") - - def test_client_sends_fragmented_binary(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_binary(b"\x01\x02", fin=False) - self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_continuation(b"\xee\xff\x01\x02", fin=False) - self.assertEqual( - client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] - ) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_continuation(b"\xee\xff", fin=True) - self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) - - def test_server_sends_fragmented_binary(self): - server = Connection(SERVER) - server.send_binary(b"\x01\x02", fin=False) - self.assertEqual(server.data_to_send(), [b"\x02\x02\x01\x02"]) - server.send_continuation(b"\xee\xff\x01\x02", fin=False) - self.assertEqual(server.data_to_send(), [b"\x00\x04\xee\xff\x01\x02"]) - server.send_continuation(b"\xee\xff", fin=True) - self.assertEqual(server.data_to_send(), [b"\x80\x02\xee\xff"]) - - def test_client_receives_fragmented_binary(self): - client = Connection(CLIENT) - client.receive_data(b"\x02\x02\x01\x02") - self.assertFrameReceived( - client, - Frame(OP_BINARY, b"\x01\x02", fin=False), - ) - client.receive_data(b"\x00\x04\xfe\xff\x01\x02") - self.assertFrameReceived( - client, - Frame(OP_CONT, b"\xfe\xff\x01\x02", fin=False), - ) - client.receive_data(b"\x80\x02\xfe\xff") - self.assertFrameReceived( - client, - Frame(OP_CONT, b"\xfe\xff"), - ) - - def test_server_receives_fragmented_binary(self): - server = Connection(SERVER) - server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") - self.assertFrameReceived( - server, - Frame(OP_BINARY, b"\x01\x02", fin=False), - ) - server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") - self.assertFrameReceived( - server, - Frame(OP_CONT, b"\xee\xff\x01\x02", fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") - self.assertFrameReceived( - server, - Frame(OP_CONT, b"\xfe\xff"), - ) - - def test_client_receives_fragmented_binary_over_size_limit(self): - client = Connection(CLIENT, max_size=3) - client.receive_data(b"\x02\x02\x01\x02") - self.assertFrameReceived( - client, - Frame(OP_BINARY, b"\x01\x02", fin=False), - ) - client.receive_data(b"\x80\x02\xfe\xff") - self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") - - def test_server_receives_fragmented_binary_over_size_limit(self): - server = Connection(SERVER, max_size=3) - server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") - self.assertFrameReceived( - server, - Frame(OP_BINARY, b"\x01\x02", fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") - self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") - - def test_client_sends_unexpected_binary(self): - client = Connection(CLIENT) - client.send_binary(b"", fin=False) - with self.assertRaises(ProtocolError) as raised: - client.send_binary(b"", fin=False) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_server_sends_unexpected_binary(self): - server = Connection(SERVER) - server.send_binary(b"", fin=False) - with self.assertRaises(ProtocolError) as raised: - server.send_binary(b"", fin=False) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_client_receives_unexpected_binary(self): - client = Connection(CLIENT) - client.receive_data(b"\x02\x00") - self.assertFrameReceived( - client, - Frame(OP_BINARY, b"", fin=False), - ) - client.receive_data(b"\x02\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(client, 1002, "expected a continuation frame") - - def test_server_receives_unexpected_binary(self): - server = Connection(SERVER) - server.receive_data(b"\x02\x80\x00\x00\x00\x00") - self.assertFrameReceived( - server, - Frame(OP_BINARY, b"", fin=False), - ) - server.receive_data(b"\x02\x80\x00\x00\x00\x00") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(server, 1002, "expected a continuation frame") - - def test_client_sends_binary_after_sending_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): - client.send_binary(b"") - - def test_server_sends_binary_after_sending_close(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): - server.send_binary(b"") - - def test_client_receives_binary_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x82\x00") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_binary_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x82\x80\x00\xff\x00\xff") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class CloseTests(ConnectionTestCase): - """ - Test close frames. - - See RFC 6544: - - 5.5.1. Close - 7.1.6. The WebSocket Connection Close Reason - 7.1.7. Fail the WebSocket Connection - - """ - - def test_close_code(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x04\x03\xe8OK") - client.receive_eof() - self.assertEqual(client.close_code, 1000) - - def test_close_reason(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") - server.receive_eof() - self.assertEqual(server.close_reason, "OK") - - def test_close_code_not_provided(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x00\x00\x00\x00") - server.receive_eof() - self.assertEqual(server.close_code, 1005) - - def test_close_reason_not_provided(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - client.receive_eof() - self.assertEqual(client.close_reason, "") - - def test_close_code_not_available(self): - client = Connection(CLIENT) - client.receive_eof() - self.assertEqual(client.close_code, 1006) - - def test_close_reason_not_available(self): - server = Connection(SERVER) - server.receive_eof() - self.assertEqual(server.close_reason, "") - - def test_close_code_not_available_yet(self): - server = Connection(SERVER) - self.assertIsNone(server.close_code) - - def test_close_reason_not_available_yet(self): - client = Connection(CLIENT) - self.assertIsNone(client.close_reason) - - def test_client_sends_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): - client.send_close() - self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) - self.assertIs(client.state, CLOSING) - - def test_server_sends_close(self): - server = Connection(SERVER) - server.send_close() - self.assertEqual(server.data_to_send(), [b"\x88\x00"]) - self.assertIs(server.state, CLOSING) - - def test_client_receives_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): - client.receive_data(b"\x88\x00") - self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) - self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) - self.assertIs(client.state, CLOSING) - - def test_server_receives_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")]) - self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) - self.assertIs(server.state, CLOSING) - - def test_client_sends_close_then_receives_close(self): - # Client-initiated close handshake on the client side. - client = Connection(CLIENT) - - client.send_close() - self.assertFrameReceived(client, None) - self.assertFrameSent(client, Frame(OP_CLOSE, b"")) - - client.receive_data(b"\x88\x00") - self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) - self.assertFrameSent(client, None) - - client.receive_eof() - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None, eof=True) - - def test_server_sends_close_then_receives_close(self): - # Server-initiated close handshake on the server side. - server = Connection(SERVER) - - server.send_close() - self.assertFrameReceived(server, None) - self.assertFrameSent(server, Frame(OP_CLOSE, b"")) - - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) - self.assertFrameSent(server, None, eof=True) - - server.receive_eof() - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - def test_client_receives_close_then_sends_close(self): - # Server-initiated close handshake on the client side. - client = Connection(CLIENT) - - client.receive_data(b"\x88\x00") - self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) - self.assertFrameSent(client, Frame(OP_CLOSE, b"")) - - client.receive_eof() - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None, eof=True) - - def test_server_receives_close_then_sends_close(self): - # Client-initiated close handshake on the server side. - server = Connection(SERVER) - - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) - self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) - - server.receive_eof() - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - def test_client_sends_close_with_code(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - self.assertIs(client.state, CLOSING) - - def test_server_sends_close_with_code(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - self.assertIs(server.state, CLOSING) - - def test_client_receives_close_with_code(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000, "") - self.assertIs(client.state, CLOSING) - - def test_server_receives_close_with_code(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001, "") - self.assertIs(server.state, CLOSING) - - def test_client_sends_close_with_code_and_reason(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001, "going away") - self.assertEqual( - client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] - ) - self.assertIs(client.state, CLOSING) - - def test_server_sends_close_with_code_and_reason(self): - server = Connection(SERVER) - server.send_close(1000, "OK") - self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) - self.assertIs(server.state, CLOSING) - - def test_client_receives_close_with_code_and_reason(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x04\x03\xe8OK") - self.assertConnectionClosing(client, 1000, "OK") - self.assertIs(client.state, CLOSING) - - def test_server_receives_close_with_code_and_reason(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") - self.assertConnectionClosing(server, 1001, "going away") - self.assertIs(server.state, CLOSING) - - def test_client_sends_close_with_reason_only(self): - client = Connection(CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.send_close(reason="going away") - self.assertEqual(str(raised.exception), "cannot send a reason without a code") - - def test_server_sends_close_with_reason_only(self): - server = Connection(SERVER) - with self.assertRaises(ProtocolError) as raised: - server.send_close(reason="OK") - self.assertEqual(str(raised.exception), "cannot send a reason without a code") - - def test_client_receives_close_with_truncated_code(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x01\x03") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "close frame too short") - self.assertConnectionFailing(client, 1002, "close frame too short") - self.assertIs(client.state, CLOSING) - - def test_server_receives_close_with_truncated_code(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "close frame too short") - self.assertConnectionFailing(server, 1002, "close frame too short") - self.assertIs(server.state, CLOSING) - - def test_client_receives_close_with_non_utf8_reason(self): - client = Connection(CLIENT) - - client.receive_data(b"\x88\x04\x03\xe8\xff\xff") - self.assertIsInstance(client.parser_exc, UnicodeDecodeError) - self.assertEqual( - str(client.parser_exc), - "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", - ) - self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") - self.assertIs(client.state, CLOSING) - - def test_server_receives_close_with_non_utf8_reason(self): - server = Connection(SERVER) - - server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") - self.assertIsInstance(server.parser_exc, UnicodeDecodeError) - self.assertEqual( - str(server.parser_exc), - "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", - ) - self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") - self.assertIs(server.state, CLOSING) - - -class PingTests(ConnectionTestCase): - """ - Test ping. See 5.5.2. Ping in RFC 6544. - - """ - - def test_client_sends_ping(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_ping(b"") - self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) - - def test_server_sends_ping(self): - server = Connection(SERVER) - server.send_ping(b"") - self.assertEqual(server.data_to_send(), [b"\x89\x00"]) - - def test_client_receives_ping(self): - client = Connection(CLIENT) - client.receive_data(b"\x89\x00") - self.assertFrameReceived( - client, - Frame(OP_PING, b""), - ) - self.assertFrameSent( - client, - Frame(OP_PONG, b""), - ) - - def test_server_receives_ping(self): - server = Connection(SERVER) - server.receive_data(b"\x89\x80\x00\x44\x88\xcc") - self.assertFrameReceived( - server, - Frame(OP_PING, b""), - ) - self.assertFrameSent( - server, - Frame(OP_PONG, b""), - ) - - def test_client_sends_ping_with_data(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_ping(b"\x22\x66\xaa\xee") - self.assertEqual( - client.data_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] - ) - - def test_server_sends_ping_with_data(self): - server = Connection(SERVER) - server.send_ping(b"\x22\x66\xaa\xee") - self.assertEqual(server.data_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) - - def test_client_receives_ping_with_data(self): - client = Connection(CLIENT) - client.receive_data(b"\x89\x04\x22\x66\xaa\xee") - self.assertFrameReceived( - client, - Frame(OP_PING, b"\x22\x66\xaa\xee"), - ) - self.assertFrameSent( - client, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) - - def test_server_receives_ping_with_data(self): - server = Connection(SERVER) - server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived( - server, - Frame(OP_PING, b"\x22\x66\xaa\xee"), - ) - self.assertFrameSent( - server, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) - - def test_client_sends_fragmented_ping_frame(self): - client = Connection(CLIENT) - # This is only possible through a private API. - with self.assertRaises(ProtocolError) as raised: - client.send_frame(Frame(OP_PING, b"", fin=False)) - self.assertEqual(str(raised.exception), "fragmented control frame") - - def test_server_sends_fragmented_ping_frame(self): - server = Connection(SERVER) - # This is only possible through a private API. - with self.assertRaises(ProtocolError) as raised: - server.send_frame(Frame(OP_PING, b"", fin=False)) - self.assertEqual(str(raised.exception), "fragmented control frame") - - def test_client_receives_fragmented_ping_frame(self): - client = Connection(CLIENT) - client.receive_data(b"\x09\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "fragmented control frame") - self.assertConnectionFailing(client, 1002, "fragmented control frame") - - def test_server_receives_fragmented_ping_frame(self): - server = Connection(SERVER) - server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "fragmented control frame") - self.assertConnectionFailing(server, 1002, "fragmented control frame") - - def test_client_sends_ping_after_sending_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: - client.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) - - def test_server_sends_ping_after_sending_close(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: - server.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) - - def test_client_receives_ping_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x89\x04\x22\x66\xaa\xee") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_ping_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class PongTests(ConnectionTestCase): - """ - Test pong frames. See 5.5.3. Pong in RFC 6544. - - """ - - def test_client_sends_pong(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_pong(b"") - self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) - - def test_server_sends_pong(self): - server = Connection(SERVER) - server.send_pong(b"") - self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) - - def test_client_receives_pong(self): - client = Connection(CLIENT) - client.receive_data(b"\x8a\x00") - self.assertFrameReceived( - client, - Frame(OP_PONG, b""), - ) - - def test_server_receives_pong(self): - server = Connection(SERVER) - server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") - self.assertFrameReceived( - server, - Frame(OP_PONG, b""), - ) - - def test_client_sends_pong_with_data(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_pong(b"\x22\x66\xaa\xee") - self.assertEqual( - client.data_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] - ) - - def test_server_sends_pong_with_data(self): - server = Connection(SERVER) - server.send_pong(b"\x22\x66\xaa\xee") - self.assertEqual(server.data_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) - - def test_client_receives_pong_with_data(self): - client = Connection(CLIENT) - client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") - self.assertFrameReceived( - client, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) - - def test_server_receives_pong_with_data(self): - server = Connection(SERVER) - server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived( - server, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) - - def test_client_sends_fragmented_pong_frame(self): - client = Connection(CLIENT) - # This is only possible through a private API. - with self.assertRaises(ProtocolError) as raised: - client.send_frame(Frame(OP_PONG, b"", fin=False)) - self.assertEqual(str(raised.exception), "fragmented control frame") - - def test_server_sends_fragmented_pong_frame(self): - server = Connection(SERVER) - # This is only possible through a private API. - with self.assertRaises(ProtocolError) as raised: - server.send_frame(Frame(OP_PONG, b"", fin=False)) - self.assertEqual(str(raised.exception), "fragmented control frame") - - def test_client_receives_fragmented_pong_frame(self): - client = Connection(CLIENT) - client.receive_data(b"\x0a\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "fragmented control frame") - self.assertConnectionFailing(client, 1002, "fragmented control frame") - - def test_server_receives_fragmented_pong_frame(self): - server = Connection(SERVER) - server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "fragmented control frame") - self.assertConnectionFailing(server, 1002, "fragmented control frame") - - def test_client_sends_pong_after_sending_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): - client.send_pong(b"") - - def test_server_sends_pong_after_sending_close(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): - server.send_pong(b"") - - def test_client_receives_pong_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_pong_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class FailTests(ConnectionTestCase): - """ - Test failing the connection. - - See 7.1.7. Fail the WebSocket Connection in RFC 6544. - - """ - - def test_client_stops_processing_frames_after_fail(self): - client = Connection(CLIENT) - client.fail(1002) - self.assertConnectionFailing(client, 1002) - client.receive_data(b"\x88\x02\x03\xea") - self.assertFrameReceived(client, None) - - def test_server_stops_processing_frames_after_fail(self): - server = Connection(SERVER) - server.fail(1002) - self.assertConnectionFailing(server, 1002) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") - self.assertFrameReceived(server, None) - - -class FragmentationTests(ConnectionTestCase): - """ - Test message fragmentation. - - See 5.4. Fragmentation in RFC 6544. - - """ - - def test_client_send_ping_pong_in_fragmented_message(self): - client = Connection(CLIENT) - client.send_text(b"Spam", fin=False) - self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - client.send_ping(b"Ping") - self.assertFrameSent(client, Frame(OP_PING, b"Ping")) - client.send_continuation(b"Ham", fin=False) - self.assertFrameSent(client, Frame(OP_CONT, b"Ham", fin=False)) - client.send_pong(b"Pong") - self.assertFrameSent(client, Frame(OP_PONG, b"Pong")) - client.send_continuation(b"Eggs", fin=True) - self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) - - def test_server_send_ping_pong_in_fragmented_message(self): - server = Connection(SERVER) - server.send_text(b"Spam", fin=False) - self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) - server.send_ping(b"Ping") - self.assertFrameSent(server, Frame(OP_PING, b"Ping")) - server.send_continuation(b"Ham", fin=False) - self.assertFrameSent(server, Frame(OP_CONT, b"Ham", fin=False)) - server.send_pong(b"Pong") - self.assertFrameSent(server, Frame(OP_PONG, b"Pong")) - server.send_continuation(b"Eggs", fin=True) - self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) - - def test_client_receive_ping_pong_in_fragmented_message(self): - client = Connection(CLIENT) - client.receive_data(b"\x01\x04Spam") - self.assertFrameReceived( - client, - Frame(OP_TEXT, b"Spam", fin=False), - ) - client.receive_data(b"\x89\x04Ping") - self.assertFrameReceived( - client, - Frame(OP_PING, b"Ping"), - ) - self.assertFrameSent( - client, - Frame(OP_PONG, b"Ping"), - ) - client.receive_data(b"\x00\x03Ham") - self.assertFrameReceived( - client, - Frame(OP_CONT, b"Ham", fin=False), - ) - client.receive_data(b"\x8a\x04Pong") - self.assertFrameReceived( - client, - Frame(OP_PONG, b"Pong"), - ) - client.receive_data(b"\x80\x04Eggs") - self.assertFrameReceived( - client, - Frame(OP_CONT, b"Eggs"), - ) - - def test_server_receive_ping_pong_in_fragmented_message(self): - server = Connection(SERVER) - server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") - self.assertFrameReceived( - server, - Frame(OP_TEXT, b"Spam", fin=False), - ) - server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") - self.assertFrameReceived( - server, - Frame(OP_PING, b"Ping"), - ) - self.assertFrameSent( - server, - Frame(OP_PONG, b"Ping"), - ) - server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") - self.assertFrameReceived( - server, - Frame(OP_CONT, b"Ham", fin=False), - ) - server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") - self.assertFrameReceived( - server, - Frame(OP_PONG, b"Pong"), - ) - server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") - self.assertFrameReceived( - server, - Frame(OP_CONT, b"Eggs"), - ) - - def test_client_send_close_in_fragmented_message(self): - client = Connection(CLIENT) - client.send_text(b"Spam", fin=False) - self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - client.send_close(1001) - self.assertEqual(str(raised.exception), "expected a continuation frame") - client.send_continuation(b"Eggs", fin=True) - - def test_server_send_close_in_fragmented_message(self): - server = Connection(CLIENT) - server.send_text(b"Spam", fin=False) - self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - server.send_close(1000) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_client_receive_close_in_fragmented_message(self): - client = Connection(CLIENT) - client.receive_data(b"\x01\x04Spam") - self.assertFrameReceived( - client, - Frame(OP_TEXT, b"Spam", fin=False), - ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - client.receive_data(b"\x88\x02\x03\xe8") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "incomplete fragmented message") - self.assertConnectionFailing(client, 1002, "incomplete fragmented message") - - def test_server_receive_close_in_fragmented_message(self): - server = Connection(SERVER) - server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") - self.assertFrameReceived( - server, - Frame(OP_TEXT, b"Spam", fin=False), - ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "incomplete fragmented message") - self.assertConnectionFailing(server, 1002, "incomplete fragmented message") - - -class EOFTests(ConnectionTestCase): - """ - Test half-closes on connection termination. - - """ - - def test_client_receives_eof(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - self.assertConnectionClosing(client) - client.receive_eof() - self.assertIs(client.state, CLOSED) - - def test_server_receives_eof(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertConnectionClosing(server) - server.receive_eof() - self.assertIs(server.state, CLOSED) - - def test_client_receives_eof_between_frames(self): - client = Connection(CLIENT) - client.receive_eof() - self.assertIsInstance(client.parser_exc, EOFError) - self.assertEqual(str(client.parser_exc), "unexpected end of stream") - self.assertIs(client.state, CLOSED) - - def test_server_receives_eof_between_frames(self): - server = Connection(SERVER) - server.receive_eof() - self.assertIsInstance(server.parser_exc, EOFError) - self.assertEqual(str(server.parser_exc), "unexpected end of stream") - self.assertIs(server.state, CLOSED) - - def test_client_receives_eof_inside_frame(self): - client = Connection(CLIENT) - client.receive_data(b"\x81") - client.receive_eof() - self.assertIsInstance(client.parser_exc, EOFError) - self.assertEqual( - str(client.parser_exc), - "stream ends after 1 bytes, expected 2 bytes", - ) - self.assertIs(client.state, CLOSED) - - def test_server_receives_eof_inside_frame(self): - server = Connection(SERVER) - server.receive_data(b"\x81") - server.receive_eof() - self.assertIsInstance(server.parser_exc, EOFError) - self.assertEqual( - str(server.parser_exc), - "stream ends after 1 bytes, expected 2 bytes", - ) - self.assertIs(server.state, CLOSED) - - def test_client_receives_data_after_exception(self): - client = Connection(CLIENT) - client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") - client.receive_data(b"\x00\x00") - self.assertFrameSent(client, None) - - def test_server_receives_data_after_exception(self): - server = Connection(SERVER) - server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") - server.receive_data(b"\x00\x00") - self.assertFrameSent(server, None) - - def test_client_receives_eof_after_exception(self): - client = Connection(CLIENT) - client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") - client.receive_eof() - self.assertFrameSent(client, None, eof=True) - - def test_server_receives_eof_after_exception(self): - server = Connection(SERVER) - server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") - server.receive_eof() - self.assertFrameSent(server, None) - - def test_client_receives_data_and_eof_after_exception(self): - client = Connection(CLIENT) - client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") - client.receive_data(b"\x00\x00") - client.receive_eof() - self.assertFrameSent(client, None, eof=True) - - def test_server_receives_data_and_eof_after_exception(self): - server = Connection(SERVER) - server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") - server.receive_data(b"\x00\x00") - server.receive_eof() - self.assertFrameSent(server, None) - - def test_client_receives_data_after_eof(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - self.assertConnectionClosing(client) - client.receive_eof() - with self.assertRaises(EOFError) as raised: - client.receive_data(b"\x88\x00") - self.assertEqual(str(raised.exception), "stream ended") - - def test_server_receives_data_after_eof(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertConnectionClosing(server) - server.receive_eof() - with self.assertRaises(EOFError) as raised: - server.receive_data(b"\x88\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "stream ended") - - def test_client_receives_eof_after_eof(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - self.assertConnectionClosing(client) - client.receive_eof() - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") - - def test_server_receives_eof_after_eof(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertConnectionClosing(server) - server.receive_eof() - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") - - -class TCPCloseTests(ConnectionTestCase): - """ - Test expectation of TCP close on connection termination. - - """ - - def test_client_default(self): - client = Connection(CLIENT) - self.assertFalse(client.close_expected()) - - def test_server_default(self): - server = Connection(SERVER) - self.assertFalse(server.close_expected()) - - def test_client_sends_close(self): - client = Connection(CLIENT) - client.send_close() - self.assertTrue(client.close_expected()) - - def test_server_sends_close(self): - server = Connection(SERVER) - server.send_close() - self.assertTrue(server.close_expected()) - - def test_client_receives_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - self.assertTrue(client.close_expected()) - - def test_client_receives_close_then_eof(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - client.receive_eof() - self.assertFalse(client.close_expected()) - - def test_server_receives_close_then_eof(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - server.receive_eof() - self.assertFalse(server.close_expected()) - - def test_server_receives_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertTrue(server.close_expected()) - - def test_client_fails_connection(self): - client = Connection(CLIENT) - client.fail(1002) - self.assertTrue(client.close_expected()) - - def test_server_fails_connection(self): - server = Connection(SERVER) - server.fail(1002) - self.assertTrue(server.close_expected()) - - -class ConnectionClosedTests(ConnectionTestCase): - """ - Test connection closed exception. - - """ - - def test_client_sends_close_then_receives_close(self): - # Client-initiated close handshake on the client side complete. - client = Connection(CLIENT) - client.send_close(1000, "") - client.receive_data(b"\x88\x02\x03\xe8") - client.receive_eof() - exc = client.close_exc - self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertFalse(exc.rcvd_then_sent) - - def test_server_sends_close_then_receives_close(self): - # Server-initiated close handshake on the server side complete. - server = Connection(SERVER) - server.send_close(1000, "") - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") - server.receive_eof() - exc = server.close_exc - self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertFalse(exc.rcvd_then_sent) - - def test_client_receives_close_then_sends_close(self): - # Server-initiated close handshake on the client side complete. - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - client.receive_eof() - exc = client.close_exc - self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertTrue(exc.rcvd_then_sent) - - def test_server_receives_close_then_sends_close(self): - # Client-initiated close handshake on the server side complete. - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") - server.receive_eof() - exc = server.close_exc - self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertTrue(exc.rcvd_then_sent) - - def test_client_sends_close_then_receives_eof(self): - # Client-initiated close handshake on the client side times out. - client = Connection(CLIENT) - client.send_close(1000, "") - client.receive_eof() - exc = client.close_exc - self.assertIsInstance(exc, ConnectionClosedError) - self.assertIsNone(exc.rcvd) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertIsNone(exc.rcvd_then_sent) - - def test_server_sends_close_then_receives_eof(self): - # Server-initiated close handshake on the server side times out. - server = Connection(SERVER) - server.send_close(1000, "") - server.receive_eof() - exc = server.close_exc - self.assertIsInstance(exc, ConnectionClosedError) - self.assertIsNone(exc.rcvd) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertIsNone(exc.rcvd_then_sent) - - def test_client_receives_eof(self): - # Server-initiated close handshake on the client side times out. - client = Connection(CLIENT) - client.receive_eof() - exc = client.close_exc - self.assertIsInstance(exc, ConnectionClosedError) - self.assertIsNone(exc.rcvd) - self.assertIsNone(exc.sent) - self.assertIsNone(exc.rcvd_then_sent) - - def test_server_receives_eof(self): - # Client-initiated close handshake on the server side times out. - server = Connection(SERVER) - server.receive_eof() - exc = server.close_exc - self.assertIsInstance(exc, ConnectionClosedError) - self.assertIsNone(exc.rcvd) - self.assertIsNone(exc.sent) - self.assertIsNone(exc.rcvd_then_sent) - - -class ErrorTests(ConnectionTestCase): - """ - Test other error cases. - - """ - - def test_client_hits_internal_error_reading_frame(self): - client = Connection(CLIENT) - # This isn't supposed to happen, so we're simulating it. - with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): - client.receive_data(b"\x81\x00") - self.assertIsInstance(client.parser_exc, RuntimeError) - self.assertEqual(str(client.parser_exc), "BOOM") - self.assertConnectionFailing(client, 1011, "") - - def test_server_hits_internal_error_reading_frame(self): - server = Connection(SERVER) - # This isn't supposed to happen, so we're simulating it. - with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): - server.receive_data(b"\x81\x80\x00\x00\x00\x00") - self.assertIsInstance(server.parser_exc, RuntimeError) - self.assertEqual(str(server.parser_exc), "BOOM") - self.assertConnectionFailing(server, 1011, "") - - -class ExtensionsTests(ConnectionTestCase): - """ - Test how extensions affect frames. - - """ - - def test_client_extension_encodes_frame(self): - client = Connection(CLIENT) - client.extensions = [Rsv2Extension()] - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_ping(b"") - self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) - - def test_server_extension_encodes_frame(self): - server = Connection(SERVER) - server.extensions = [Rsv2Extension()] - server.send_ping(b"") - self.assertEqual(server.data_to_send(), [b"\xa9\x00"]) - - def test_client_extension_decodes_frame(self): - client = Connection(CLIENT) - client.extensions = [Rsv2Extension()] - client.receive_data(b"\xaa\x00") - self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) - - def test_server_extension_decodes_frame(self): - server = Connection(SERVER) - server.extensions = [Rsv2Extension()] - server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") - self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) - - -class MiscTests(unittest.TestCase): - def test_client_default_logger(self): - client = Connection(CLIENT) - logger = logging.getLogger("websockets.client") - self.assertIs(client.logger, logger) - - def test_server_default_logger(self): - server = Connection(SERVER) - logger = logging.getLogger("websockets.server") - self.assertIs(server.logger, logger) - - def test_client_custom_logger(self): - logger = logging.getLogger("test") - client = Connection(CLIENT, logger=logger) - self.assertIs(client.logger, logger) - - def test_server_custom_logger(self): - logger = logging.getLogger("test") - server = Connection(SERVER, logger=logger) - self.assertIs(server.logger, logger) + self.assertIs(Connection, Protocol) diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 000000000..7321d2594 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,1737 @@ +import logging +import unittest.mock + +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from websockets.frames import ( + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Close, + Frame, +) +from websockets.protocol import * +from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER + +from .extensions.utils import Rsv2Extension +from .test_frames import FramesTestCase + + +class ProtocolTestCase(FramesTestCase): + def assertFrameSent(self, connection, frame, eof=False): + """ + Outgoing data for ``connection`` contains the given frame. + + ``frame`` may be ``None`` if no frame is expected. + + When ``eof`` is ``True``, the end of the stream is also expected. + + """ + frames_sent = [ + None + if write is SEND_EOF + else self.parse( + write, + mask=connection.side is CLIENT, + extensions=connection.extensions, + ) + for write in connection.data_to_send() + ] + frames_expected = [] if frame is None else [frame] + if eof: + frames_expected += [None] + self.assertEqual(frames_sent, frames_expected) + + def assertFrameReceived(self, connection, frame): + """ + Incoming data for ``connection`` contains the given frame. + + ``frame`` may be ``None`` if no frame is expected. + + """ + frames_received = connection.events_received() + frames_expected = [] if frame is None else [frame] + self.assertEqual(frames_received, frames_expected) + + def assertConnectionClosing(self, connection, code=None, reason=""): + """ + Incoming data caused the "Start the WebSocket Closing Handshake" process. + + """ + close_frame = Frame( + OP_CLOSE, + b"" if code is None else Close(code, reason).serialize(), + ) + # A close frame was received. + self.assertFrameReceived(connection, close_frame) + # A close frame and possibly the end of stream were sent. + self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) + + def assertConnectionFailing(self, connection, code=None, reason=""): + """ + Incoming data caused the "Fail the WebSocket Connection" process. + + """ + close_frame = Frame( + OP_CLOSE, + b"" if code is None else Close(code, reason).serialize(), + ) + # No frame was received. + self.assertFrameReceived(connection, None) + # A close frame and possibly the end of stream were sent. + self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) + + +class MaskingTests(ProtocolTestCase): + """ + Test frame masking. + + 5.1. Overview + + """ + + unmasked_text_frame_date = b"\x81\x04Spam" + masked_text_frame_data = b"\x81\x84\x00\xff\x00\xff\x53\x8f\x61\x92" + + def test_client_sends_masked_frame(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\xff\x00\xff"): + client.send_text(b"Spam", True) + self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) + + def test_server_sends_unmasked_frame(self): + server = Protocol(SERVER) + server.send_text(b"Spam", True) + self.assertEqual(server.data_to_send(), [self.unmasked_text_frame_date]) + + def test_client_receives_unmasked_frame(self): + client = Protocol(CLIENT) + client.receive_data(self.unmasked_text_frame_date) + self.assertFrameReceived( + client, + Frame(OP_TEXT, b"Spam"), + ) + + def test_server_receives_masked_frame(self): + server = Protocol(SERVER) + server.receive_data(self.masked_text_frame_data) + self.assertFrameReceived( + server, + Frame(OP_TEXT, b"Spam"), + ) + + def test_client_receives_masked_frame(self): + client = Protocol(CLIENT) + client.receive_data(self.masked_text_frame_data) + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "incorrect masking") + self.assertConnectionFailing(client, 1002, "incorrect masking") + + def test_server_receives_unmasked_frame(self): + server = Protocol(SERVER) + server.receive_data(self.unmasked_text_frame_date) + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "incorrect masking") + self.assertConnectionFailing(server, 1002, "incorrect masking") + + +class ContinuationTests(ProtocolTestCase): + """ + Test continuation frames without text or binary frames. + + """ + + def test_client_sends_unexpected_continuation(self): + client = Protocol(CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_server_sends_unexpected_continuation(self): + server = Protocol(SERVER) + with self.assertRaises(ProtocolError) as raised: + server.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_client_receives_unexpected_continuation(self): + client = Protocol(CLIENT) + client.receive_data(b"\x00\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "unexpected continuation frame") + self.assertConnectionFailing(client, 1002, "unexpected continuation frame") + + def test_server_receives_unexpected_continuation(self): + server = Protocol(SERVER) + server.receive_data(b"\x00\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "unexpected continuation frame") + self.assertConnectionFailing(server, 1002, "unexpected continuation frame") + + def test_client_sends_continuation_after_sending_close(self): + client = Protocol(CLIENT) + # Since it isn't possible to send a close frame in a fragmented + # message (see test_client_send_close_in_fragmented_message), in fact, + # this is the same test as test_client_sends_unexpected_continuation. + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(ProtocolError) as raised: + client.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_server_sends_continuation_after_sending_close(self): + # Since it isn't possible to send a close frame in a fragmented + # message (see test_server_send_close_in_fragmented_message), in fact, + # this is the same test as test_server_sends_unexpected_continuation. + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(ProtocolError) as raised: + server.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_client_receives_continuation_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x00\x00") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_continuation_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x00\x80\x00\xff\x00\xff") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class TextTests(ProtocolTestCase): + """ + Test text frames and continuation frames. + + """ + + def test_client_sends_text(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_text("😀".encode()) + self.assertEqual( + client.data_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] + ) + + def test_server_sends_text(self): + server = Protocol(SERVER) + server.send_text("😀".encode()) + self.assertEqual(server.data_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) + + def test_client_receives_text(self): + client = Protocol(CLIENT) + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()), + ) + + def test_server_receives_text(self): + server = Protocol(SERVER) + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()), + ) + + def test_client_receives_text_over_size_limit(self): + client = Protocol(CLIENT, max_size=3) + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + + def test_server_receives_text_over_size_limit(self): + server = Protocol(SERVER, max_size=3) + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + + def test_client_receives_text_without_size_limit(self): + client = Protocol(CLIENT, max_size=None) + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()), + ) + + def test_server_receives_text_without_size_limit(self): + server = Protocol(SERVER, max_size=None) + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()), + ) + + def test_client_sends_fragmented_text(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_text("😀".encode()[:2], fin=False) + self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation("😀😀".encode()[2:6], fin=False) + self.assertEqual( + client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] + ) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation("😀".encode()[2:], fin=True) + self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) + + def test_server_sends_fragmented_text(self): + server = Protocol(SERVER) + server.send_text("😀".encode()[:2], fin=False) + self.assertEqual(server.data_to_send(), [b"\x01\x02\xf0\x9f"]) + server.send_continuation("😀😀".encode()[2:6], fin=False) + self.assertEqual(server.data_to_send(), [b"\x00\x04\x98\x80\xf0\x9f"]) + server.send_continuation("😀".encode()[2:], fin=True) + self.assertEqual(server.data_to_send(), [b"\x80\x02\x98\x80"]) + + def test_client_receives_fragmented_text(self): + client = Protocol(CLIENT) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), + ) + client.receive_data(b"\x80\x02\x98\x80") + self.assertFrameReceived( + client, + Frame(OP_CONT, "😀".encode()[2:]), + ) + + def test_server_receives_fragmented_text(self): + server = Protocol(SERVER) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertFrameReceived( + server, + Frame(OP_CONT, "😀".encode()[2:]), + ) + + def test_client_receives_fragmented_text_over_size_limit(self): + client = Protocol(CLIENT, max_size=3) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + client.receive_data(b"\x80\x02\x98\x80") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + + def test_server_receives_fragmented_text_over_size_limit(self): + server = Protocol(SERVER, max_size=3) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + + def test_client_receives_fragmented_text_without_size_limit(self): + client = Protocol(CLIENT, max_size=None) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), + ) + client.receive_data(b"\x80\x02\x98\x80") + self.assertFrameReceived( + client, + Frame(OP_CONT, "😀".encode()[2:]), + ) + + def test_server_receives_fragmented_text_without_size_limit(self): + server = Protocol(SERVER, max_size=None) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertFrameReceived( + server, + Frame(OP_CONT, "😀".encode()[2:]), + ) + + def test_client_sends_unexpected_text(self): + client = Protocol(CLIENT) + client.send_text(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + client.send_text(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_server_sends_unexpected_text(self): + server = Protocol(SERVER) + server.send_text(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + server.send_text(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receives_unexpected_text(self): + client = Protocol(CLIENT) + client.receive_data(b"\x01\x00") + self.assertFrameReceived( + client, + Frame(OP_TEXT, b"", fin=False), + ) + client.receive_data(b"\x01\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "expected a continuation frame") + self.assertConnectionFailing(client, 1002, "expected a continuation frame") + + def test_server_receives_unexpected_text(self): + server = Protocol(SERVER) + server.receive_data(b"\x01\x80\x00\x00\x00\x00") + self.assertFrameReceived( + server, + Frame(OP_TEXT, b"", fin=False), + ) + server.receive_data(b"\x01\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "expected a continuation frame") + self.assertConnectionFailing(server, 1002, "expected a continuation frame") + + def test_client_sends_text_after_sending_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState): + client.send_text(b"") + + def test_server_sends_text_after_sending_close(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(InvalidState): + server.send_text(b"") + + def test_client_receives_text_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x81\x00") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_text_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x81\x80\x00\xff\x00\xff") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class BinaryTests(ProtocolTestCase): + """ + Test binary frames and continuation frames. + + """ + + def test_client_sends_binary(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_binary(b"\x01\x02\xfe\xff") + self.assertEqual( + client.data_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] + ) + + def test_server_sends_binary(self): + server = Protocol(SERVER) + server.send_binary(b"\x01\x02\xfe\xff") + self.assertEqual(server.data_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) + + def test_client_receives_binary(self): + client = Protocol(CLIENT) + client.receive_data(b"\x82\x04\x01\x02\xfe\xff") + self.assertFrameReceived( + client, + Frame(OP_BINARY, b"\x01\x02\xfe\xff"), + ) + + def test_server_receives_binary(self): + server = Protocol(SERVER) + server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") + self.assertFrameReceived( + server, + Frame(OP_BINARY, b"\x01\x02\xfe\xff"), + ) + + def test_client_receives_binary_over_size_limit(self): + client = Protocol(CLIENT, max_size=3) + client.receive_data(b"\x82\x04\x01\x02\xfe\xff") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + + def test_server_receives_binary_over_size_limit(self): + server = Protocol(SERVER, max_size=3) + server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + + def test_client_sends_fragmented_binary(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_binary(b"\x01\x02", fin=False) + self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation(b"\xee\xff\x01\x02", fin=False) + self.assertEqual( + client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] + ) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation(b"\xee\xff", fin=True) + self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) + + def test_server_sends_fragmented_binary(self): + server = Protocol(SERVER) + server.send_binary(b"\x01\x02", fin=False) + self.assertEqual(server.data_to_send(), [b"\x02\x02\x01\x02"]) + server.send_continuation(b"\xee\xff\x01\x02", fin=False) + self.assertEqual(server.data_to_send(), [b"\x00\x04\xee\xff\x01\x02"]) + server.send_continuation(b"\xee\xff", fin=True) + self.assertEqual(server.data_to_send(), [b"\x80\x02\xee\xff"]) + + def test_client_receives_fragmented_binary(self): + client = Protocol(CLIENT) + client.receive_data(b"\x02\x02\x01\x02") + self.assertFrameReceived( + client, + Frame(OP_BINARY, b"\x01\x02", fin=False), + ) + client.receive_data(b"\x00\x04\xfe\xff\x01\x02") + self.assertFrameReceived( + client, + Frame(OP_CONT, b"\xfe\xff\x01\x02", fin=False), + ) + client.receive_data(b"\x80\x02\xfe\xff") + self.assertFrameReceived( + client, + Frame(OP_CONT, b"\xfe\xff"), + ) + + def test_server_receives_fragmented_binary(self): + server = Protocol(SERVER) + server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") + self.assertFrameReceived( + server, + Frame(OP_BINARY, b"\x01\x02", fin=False), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") + self.assertFrameReceived( + server, + Frame(OP_CONT, b"\xee\xff\x01\x02", fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") + self.assertFrameReceived( + server, + Frame(OP_CONT, b"\xfe\xff"), + ) + + def test_client_receives_fragmented_binary_over_size_limit(self): + client = Protocol(CLIENT, max_size=3) + client.receive_data(b"\x02\x02\x01\x02") + self.assertFrameReceived( + client, + Frame(OP_BINARY, b"\x01\x02", fin=False), + ) + client.receive_data(b"\x80\x02\xfe\xff") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + + def test_server_receives_fragmented_binary_over_size_limit(self): + server = Protocol(SERVER, max_size=3) + server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") + self.assertFrameReceived( + server, + Frame(OP_BINARY, b"\x01\x02", fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + + def test_client_sends_unexpected_binary(self): + client = Protocol(CLIENT) + client.send_binary(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + client.send_binary(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_server_sends_unexpected_binary(self): + server = Protocol(SERVER) + server.send_binary(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + server.send_binary(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receives_unexpected_binary(self): + client = Protocol(CLIENT) + client.receive_data(b"\x02\x00") + self.assertFrameReceived( + client, + Frame(OP_BINARY, b"", fin=False), + ) + client.receive_data(b"\x02\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "expected a continuation frame") + self.assertConnectionFailing(client, 1002, "expected a continuation frame") + + def test_server_receives_unexpected_binary(self): + server = Protocol(SERVER) + server.receive_data(b"\x02\x80\x00\x00\x00\x00") + self.assertFrameReceived( + server, + Frame(OP_BINARY, b"", fin=False), + ) + server.receive_data(b"\x02\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "expected a continuation frame") + self.assertConnectionFailing(server, 1002, "expected a continuation frame") + + def test_client_sends_binary_after_sending_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState): + client.send_binary(b"") + + def test_server_sends_binary_after_sending_close(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(InvalidState): + server.send_binary(b"") + + def test_client_receives_binary_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x82\x00") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_binary_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x82\x80\x00\xff\x00\xff") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class CloseTests(ProtocolTestCase): + """ + Test close frames. + + See RFC 6544: + + 5.5.1. Close + 7.1.6. The WebSocket Connection Close Reason + 7.1.7. Fail the WebSocket Connection + + """ + + def test_close_code(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x04\x03\xe8OK") + client.receive_eof() + self.assertEqual(client.close_code, 1000) + + def test_close_reason(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") + server.receive_eof() + self.assertEqual(server.close_reason, "OK") + + def test_close_code_not_provided(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x00\x00\x00\x00") + server.receive_eof() + self.assertEqual(server.close_code, 1005) + + def test_close_reason_not_provided(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + client.receive_eof() + self.assertEqual(client.close_reason, "") + + def test_close_code_not_available(self): + client = Protocol(CLIENT) + client.receive_eof() + self.assertEqual(client.close_code, 1006) + + def test_close_reason_not_available(self): + server = Protocol(SERVER) + server.receive_eof() + self.assertEqual(server.close_reason, "") + + def test_close_code_not_available_yet(self): + server = Protocol(SERVER) + self.assertIsNone(server.close_code) + + def test_close_reason_not_available_yet(self): + client = Protocol(CLIENT) + self.assertIsNone(client.close_reason) + + def test_client_sends_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.send_close() + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, CLOSING) + + def test_server_sends_close(self): + server = Protocol(SERVER) + server.send_close() + self.assertEqual(server.data_to_send(), [b"\x88\x00"]) + self.assertIs(server.state, CLOSING) + + def test_client_receives_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.receive_data(b"\x88\x00") + self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, CLOSING) + + def test_server_receives_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")]) + self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) + self.assertIs(server.state, CLOSING) + + def test_client_sends_close_then_receives_close(self): + # Client-initiated close handshake on the client side. + client = Protocol(CLIENT) + + client.send_close() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, Frame(OP_CLOSE, b"")) + + client.receive_data(b"\x88\x00") + self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) + self.assertFrameSent(client, None) + + client.receive_eof() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None, eof=True) + + def test_server_sends_close_then_receives_close(self): + # Server-initiated close handshake on the server side. + server = Protocol(SERVER) + + server.send_close() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, Frame(OP_CLOSE, b"")) + + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) + self.assertFrameSent(server, None, eof=True) + + server.receive_eof() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + def test_client_receives_close_then_sends_close(self): + # Server-initiated close handshake on the client side. + client = Protocol(CLIENT) + + client.receive_data(b"\x88\x00") + self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) + self.assertFrameSent(client, Frame(OP_CLOSE, b"")) + + client.receive_eof() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_close_then_sends_close(self): + # Client-initiated close handshake on the server side. + server = Protocol(SERVER) + + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) + self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) + + server.receive_eof() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + def test_client_sends_close_with_code(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertIs(client.state, CLOSING) + + def test_server_sends_close_with_code(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + self.assertIs(server.state, CLOSING) + + def test_client_receives_close_with_code(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000, "") + self.assertIs(client.state, CLOSING) + + def test_server_receives_close_with_code(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001, "") + self.assertIs(server.state, CLOSING) + + def test_client_sends_close_with_code_and_reason(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001, "going away") + self.assertEqual( + client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] + ) + self.assertIs(client.state, CLOSING) + + def test_server_sends_close_with_code_and_reason(self): + server = Protocol(SERVER) + server.send_close(1000, "OK") + self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) + self.assertIs(server.state, CLOSING) + + def test_client_receives_close_with_code_and_reason(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x04\x03\xe8OK") + self.assertConnectionClosing(client, 1000, "OK") + self.assertIs(client.state, CLOSING) + + def test_server_receives_close_with_code_and_reason(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") + self.assertConnectionClosing(server, 1001, "going away") + self.assertIs(server.state, CLOSING) + + def test_client_sends_close_with_reason_only(self): + client = Protocol(CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.send_close(reason="going away") + self.assertEqual(str(raised.exception), "cannot send a reason without a code") + + def test_server_sends_close_with_reason_only(self): + server = Protocol(SERVER) + with self.assertRaises(ProtocolError) as raised: + server.send_close(reason="OK") + self.assertEqual(str(raised.exception), "cannot send a reason without a code") + + def test_client_receives_close_with_truncated_code(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x01\x03") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "close frame too short") + self.assertConnectionFailing(client, 1002, "close frame too short") + self.assertIs(client.state, CLOSING) + + def test_server_receives_close_with_truncated_code(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "close frame too short") + self.assertConnectionFailing(server, 1002, "close frame too short") + self.assertIs(server.state, CLOSING) + + def test_client_receives_close_with_non_utf8_reason(self): + client = Protocol(CLIENT) + + client.receive_data(b"\x88\x04\x03\xe8\xff\xff") + self.assertIsInstance(client.parser_exc, UnicodeDecodeError) + self.assertEqual( + str(client.parser_exc), + "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", + ) + self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") + self.assertIs(client.state, CLOSING) + + def test_server_receives_close_with_non_utf8_reason(self): + server = Protocol(SERVER) + + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") + self.assertIsInstance(server.parser_exc, UnicodeDecodeError) + self.assertEqual( + str(server.parser_exc), + "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", + ) + self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") + self.assertIs(server.state, CLOSING) + + +class PingTests(ProtocolTestCase): + """ + Test ping. See 5.5.2. Ping in RFC 6544. + + """ + + def test_client_sends_ping(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"") + self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) + + def test_server_sends_ping(self): + server = Protocol(SERVER) + server.send_ping(b"") + self.assertEqual(server.data_to_send(), [b"\x89\x00"]) + + def test_client_receives_ping(self): + client = Protocol(CLIENT) + client.receive_data(b"\x89\x00") + self.assertFrameReceived( + client, + Frame(OP_PING, b""), + ) + self.assertFrameSent( + client, + Frame(OP_PONG, b""), + ) + + def test_server_receives_ping(self): + server = Protocol(SERVER) + server.receive_data(b"\x89\x80\x00\x44\x88\xcc") + self.assertFrameReceived( + server, + Frame(OP_PING, b""), + ) + self.assertFrameSent( + server, + Frame(OP_PONG, b""), + ) + + def test_client_sends_ping_with_data(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"\x22\x66\xaa\xee") + self.assertEqual( + client.data_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] + ) + + def test_server_sends_ping_with_data(self): + server = Protocol(SERVER) + server.send_ping(b"\x22\x66\xaa\xee") + self.assertEqual(server.data_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) + + def test_client_receives_ping_with_data(self): + client = Protocol(CLIENT) + client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + self.assertFrameReceived( + client, + Frame(OP_PING, b"\x22\x66\xaa\xee"), + ) + self.assertFrameSent( + client, + Frame(OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_server_receives_ping_with_data(self): + server = Protocol(SERVER) + server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived( + server, + Frame(OP_PING, b"\x22\x66\xaa\xee"), + ) + self.assertFrameSent( + server, + Frame(OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_client_sends_fragmented_ping_frame(self): + client = Protocol(CLIENT) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + client.send_frame(Frame(OP_PING, b"", fin=False)) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_server_sends_fragmented_ping_frame(self): + server = Protocol(SERVER) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + server.send_frame(Frame(OP_PING, b"", fin=False)) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_client_receives_fragmented_ping_frame(self): + client = Protocol(CLIENT) + client.receive_data(b"\x09\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "fragmented control frame") + self.assertConnectionFailing(client, 1002, "fragmented control frame") + + def test_server_receives_fragmented_ping_frame(self): + server = Protocol(SERVER) + server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "fragmented control frame") + self.assertConnectionFailing(server, 1002, "fragmented control frame") + + def test_client_sends_ping_after_sending_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + # The spec says: "An endpoint MAY send a Ping frame any time (...) + # before the connection is closed" but websockets doesn't support + # sending a Ping frame after a Close frame. + with self.assertRaises(InvalidState) as raised: + client.send_ping(b"") + self.assertEqual( + str(raised.exception), + "cannot write to a WebSocket in the CLOSING state", + ) + + def test_server_sends_ping_after_sending_close(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + # The spec says: "An endpoint MAY send a Ping frame any time (...) + # before the connection is closed" but websockets doesn't support + # sending a Ping frame after a Close frame. + with self.assertRaises(InvalidState) as raised: + server.send_ping(b"") + self.assertEqual( + str(raised.exception), + "cannot write to a WebSocket in the CLOSING state", + ) + + def test_client_receives_ping_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_ping_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class PongTests(ProtocolTestCase): + """ + Test pong frames. See 5.5.3. Pong in RFC 6544. + + """ + + def test_client_sends_pong(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_pong(b"") + self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) + + def test_server_sends_pong(self): + server = Protocol(SERVER) + server.send_pong(b"") + self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) + + def test_client_receives_pong(self): + client = Protocol(CLIENT) + client.receive_data(b"\x8a\x00") + self.assertFrameReceived( + client, + Frame(OP_PONG, b""), + ) + + def test_server_receives_pong(self): + server = Protocol(SERVER) + server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") + self.assertFrameReceived( + server, + Frame(OP_PONG, b""), + ) + + def test_client_sends_pong_with_data(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_pong(b"\x22\x66\xaa\xee") + self.assertEqual( + client.data_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] + ) + + def test_server_sends_pong_with_data(self): + server = Protocol(SERVER) + server.send_pong(b"\x22\x66\xaa\xee") + self.assertEqual(server.data_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) + + def test_client_receives_pong_with_data(self): + client = Protocol(CLIENT) + client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + self.assertFrameReceived( + client, + Frame(OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_server_receives_pong_with_data(self): + server = Protocol(SERVER) + server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived( + server, + Frame(OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_client_sends_fragmented_pong_frame(self): + client = Protocol(CLIENT) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + client.send_frame(Frame(OP_PONG, b"", fin=False)) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_server_sends_fragmented_pong_frame(self): + server = Protocol(SERVER) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + server.send_frame(Frame(OP_PONG, b"", fin=False)) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_client_receives_fragmented_pong_frame(self): + client = Protocol(CLIENT) + client.receive_data(b"\x0a\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "fragmented control frame") + self.assertConnectionFailing(client, 1002, "fragmented control frame") + + def test_server_receives_fragmented_pong_frame(self): + server = Protocol(SERVER) + server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "fragmented control frame") + self.assertConnectionFailing(server, 1002, "fragmented control frame") + + def test_client_sends_pong_after_sending_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + # websockets doesn't support sending a Pong frame after a Close frame. + with self.assertRaises(InvalidState): + client.send_pong(b"") + + def test_server_sends_pong_after_sending_close(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + # websockets doesn't support sending a Pong frame after a Close frame. + with self.assertRaises(InvalidState): + server.send_pong(b"") + + def test_client_receives_pong_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_pong_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class FailTests(ProtocolTestCase): + """ + Test failing the connection. + + See 7.1.7. Fail the WebSocket Connection in RFC 6544. + + """ + + def test_client_stops_processing_frames_after_fail(self): + client = Protocol(CLIENT) + client.fail(1002) + self.assertConnectionFailing(client, 1002) + client.receive_data(b"\x88\x02\x03\xea") + self.assertFrameReceived(client, None) + + def test_server_stops_processing_frames_after_fail(self): + server = Protocol(SERVER) + server.fail(1002) + self.assertConnectionFailing(server, 1002) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") + self.assertFrameReceived(server, None) + + +class FragmentationTests(ProtocolTestCase): + """ + Test message fragmentation. + + See 5.4. Fragmentation in RFC 6544. + + """ + + def test_client_send_ping_pong_in_fragmented_message(self): + client = Protocol(CLIENT) + client.send_text(b"Spam", fin=False) + self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) + client.send_ping(b"Ping") + self.assertFrameSent(client, Frame(OP_PING, b"Ping")) + client.send_continuation(b"Ham", fin=False) + self.assertFrameSent(client, Frame(OP_CONT, b"Ham", fin=False)) + client.send_pong(b"Pong") + self.assertFrameSent(client, Frame(OP_PONG, b"Pong")) + client.send_continuation(b"Eggs", fin=True) + self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) + + def test_server_send_ping_pong_in_fragmented_message(self): + server = Protocol(SERVER) + server.send_text(b"Spam", fin=False) + self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) + server.send_ping(b"Ping") + self.assertFrameSent(server, Frame(OP_PING, b"Ping")) + server.send_continuation(b"Ham", fin=False) + self.assertFrameSent(server, Frame(OP_CONT, b"Ham", fin=False)) + server.send_pong(b"Pong") + self.assertFrameSent(server, Frame(OP_PONG, b"Pong")) + server.send_continuation(b"Eggs", fin=True) + self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) + + def test_client_receive_ping_pong_in_fragmented_message(self): + client = Protocol(CLIENT) + client.receive_data(b"\x01\x04Spam") + self.assertFrameReceived( + client, + Frame(OP_TEXT, b"Spam", fin=False), + ) + client.receive_data(b"\x89\x04Ping") + self.assertFrameReceived( + client, + Frame(OP_PING, b"Ping"), + ) + self.assertFrameSent( + client, + Frame(OP_PONG, b"Ping"), + ) + client.receive_data(b"\x00\x03Ham") + self.assertFrameReceived( + client, + Frame(OP_CONT, b"Ham", fin=False), + ) + client.receive_data(b"\x8a\x04Pong") + self.assertFrameReceived( + client, + Frame(OP_PONG, b"Pong"), + ) + client.receive_data(b"\x80\x04Eggs") + self.assertFrameReceived( + client, + Frame(OP_CONT, b"Eggs"), + ) + + def test_server_receive_ping_pong_in_fragmented_message(self): + server = Protocol(SERVER) + server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") + self.assertFrameReceived( + server, + Frame(OP_TEXT, b"Spam", fin=False), + ) + server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") + self.assertFrameReceived( + server, + Frame(OP_PING, b"Ping"), + ) + self.assertFrameSent( + server, + Frame(OP_PONG, b"Ping"), + ) + server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") + self.assertFrameReceived( + server, + Frame(OP_CONT, b"Ham", fin=False), + ) + server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") + self.assertFrameReceived( + server, + Frame(OP_PONG, b"Pong"), + ) + server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") + self.assertFrameReceived( + server, + Frame(OP_CONT, b"Eggs"), + ) + + def test_client_send_close_in_fragmented_message(self): + client = Protocol(CLIENT) + client.send_text(b"Spam", fin=False) + self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + with self.assertRaises(ProtocolError) as raised: + client.send_close(1001) + self.assertEqual(str(raised.exception), "expected a continuation frame") + client.send_continuation(b"Eggs", fin=True) + + def test_server_send_close_in_fragmented_message(self): + server = Protocol(CLIENT) + server.send_text(b"Spam", fin=False) + self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + with self.assertRaises(ProtocolError) as raised: + server.send_close(1000) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receive_close_in_fragmented_message(self): + client = Protocol(CLIENT) + client.receive_data(b"\x01\x04Spam") + self.assertFrameReceived( + client, + Frame(OP_TEXT, b"Spam", fin=False), + ) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + client.receive_data(b"\x88\x02\x03\xe8") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "incomplete fragmented message") + self.assertConnectionFailing(client, 1002, "incomplete fragmented message") + + def test_server_receive_close_in_fragmented_message(self): + server = Protocol(SERVER) + server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") + self.assertFrameReceived( + server, + Frame(OP_TEXT, b"Spam", fin=False), + ) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "incomplete fragmented message") + self.assertConnectionFailing(server, 1002, "incomplete fragmented message") + + +class EOFTests(ProtocolTestCase): + """ + Test half-closes on connection termination. + + """ + + def test_client_receives_eof(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() + self.assertIs(client.state, CLOSED) + + def test_server_receives_eof(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() + self.assertIs(server.state, CLOSED) + + def test_client_receives_eof_between_frames(self): + client = Protocol(CLIENT) + client.receive_eof() + self.assertIsInstance(client.parser_exc, EOFError) + self.assertEqual(str(client.parser_exc), "unexpected end of stream") + self.assertIs(client.state, CLOSED) + + def test_server_receives_eof_between_frames(self): + server = Protocol(SERVER) + server.receive_eof() + self.assertIsInstance(server.parser_exc, EOFError) + self.assertEqual(str(server.parser_exc), "unexpected end of stream") + self.assertIs(server.state, CLOSED) + + def test_client_receives_eof_inside_frame(self): + client = Protocol(CLIENT) + client.receive_data(b"\x81") + client.receive_eof() + self.assertIsInstance(client.parser_exc, EOFError) + self.assertEqual( + str(client.parser_exc), + "stream ends after 1 bytes, expected 2 bytes", + ) + self.assertIs(client.state, CLOSED) + + def test_server_receives_eof_inside_frame(self): + server = Protocol(SERVER) + server.receive_data(b"\x81") + server.receive_eof() + self.assertIsInstance(server.parser_exc, EOFError) + self.assertEqual( + str(server.parser_exc), + "stream ends after 1 bytes, expected 2 bytes", + ) + self.assertIs(server.state, CLOSED) + + def test_client_receives_data_after_exception(self): + client = Protocol(CLIENT) + client.receive_data(b"\xff\xff") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_data(b"\x00\x00") + self.assertFrameSent(client, None) + + def test_server_receives_data_after_exception(self): + server = Protocol(SERVER) + server.receive_data(b"\xff\xff") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_data(b"\x00\x00") + self.assertFrameSent(server, None) + + def test_client_receives_eof_after_exception(self): + client = Protocol(CLIENT) + client.receive_data(b"\xff\xff") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_eof() + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_eof_after_exception(self): + server = Protocol(SERVER) + server.receive_data(b"\xff\xff") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_eof() + self.assertFrameSent(server, None) + + def test_client_receives_data_and_eof_after_exception(self): + client = Protocol(CLIENT) + client.receive_data(b"\xff\xff") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_data(b"\x00\x00") + client.receive_eof() + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_data_and_eof_after_exception(self): + server = Protocol(SERVER) + server.receive_data(b"\xff\xff") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_data(b"\x00\x00") + server.receive_eof() + self.assertFrameSent(server, None) + + def test_client_receives_data_after_eof(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() + with self.assertRaises(EOFError) as raised: + client.receive_data(b"\x88\x00") + self.assertEqual(str(raised.exception), "stream ended") + + def test_server_receives_data_after_eof(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() + with self.assertRaises(EOFError) as raised: + server.receive_data(b"\x88\x80\x00\x00\x00\x00") + self.assertEqual(str(raised.exception), "stream ended") + + def test_client_receives_eof_after_eof(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() + with self.assertRaises(EOFError) as raised: + client.receive_eof() + self.assertEqual(str(raised.exception), "stream ended") + + def test_server_receives_eof_after_eof(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() + with self.assertRaises(EOFError) as raised: + server.receive_eof() + self.assertEqual(str(raised.exception), "stream ended") + + +class TCPCloseTests(ProtocolTestCase): + """ + Test expectation of TCP close on connection termination. + + """ + + def test_client_default(self): + client = Protocol(CLIENT) + self.assertFalse(client.close_expected()) + + def test_server_default(self): + server = Protocol(SERVER) + self.assertFalse(server.close_expected()) + + def test_client_sends_close(self): + client = Protocol(CLIENT) + client.send_close() + self.assertTrue(client.close_expected()) + + def test_server_sends_close(self): + server = Protocol(SERVER) + server.send_close() + self.assertTrue(server.close_expected()) + + def test_client_receives_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + self.assertTrue(client.close_expected()) + + def test_client_receives_close_then_eof(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + client.receive_eof() + self.assertFalse(client.close_expected()) + + def test_server_receives_close_then_eof(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + server.receive_eof() + self.assertFalse(server.close_expected()) + + def test_server_receives_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertTrue(server.close_expected()) + + def test_client_fails_connection(self): + client = Protocol(CLIENT) + client.fail(1002) + self.assertTrue(client.close_expected()) + + def test_server_fails_connection(self): + server = Protocol(SERVER) + server.fail(1002) + self.assertTrue(server.close_expected()) + + +class ConnectionClosedTests(ProtocolTestCase): + """ + Test connection closed exception. + + """ + + def test_client_sends_close_then_receives_close(self): + # Client-initiated close handshake on the client side complete. + client = Protocol(CLIENT) + client.send_close(1000, "") + client.receive_data(b"\x88\x02\x03\xe8") + client.receive_eof() + exc = client.close_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertFalse(exc.rcvd_then_sent) + + def test_server_sends_close_then_receives_close(self): + # Server-initiated close handshake on the server side complete. + server = Protocol(SERVER) + server.send_close(1000, "") + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") + server.receive_eof() + exc = server.close_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertFalse(exc.rcvd_then_sent) + + def test_client_receives_close_then_sends_close(self): + # Server-initiated close handshake on the client side complete. + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + client.receive_eof() + exc = client.close_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertTrue(exc.rcvd_then_sent) + + def test_server_receives_close_then_sends_close(self): + # Client-initiated close handshake on the server side complete. + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") + server.receive_eof() + exc = server.close_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertTrue(exc.rcvd_then_sent) + + def test_client_sends_close_then_receives_eof(self): + # Client-initiated close handshake on the client side times out. + client = Protocol(CLIENT) + client.send_close(1000, "") + client.receive_eof() + exc = client.close_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertIsNone(exc.rcvd_then_sent) + + def test_server_sends_close_then_receives_eof(self): + # Server-initiated close handshake on the server side times out. + server = Protocol(SERVER) + server.send_close(1000, "") + server.receive_eof() + exc = server.close_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertIsNone(exc.rcvd_then_sent) + + def test_client_receives_eof(self): + # Server-initiated close handshake on the client side times out. + client = Protocol(CLIENT) + client.receive_eof() + exc = client.close_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertIsNone(exc.sent) + self.assertIsNone(exc.rcvd_then_sent) + + def test_server_receives_eof(self): + # Client-initiated close handshake on the server side times out. + server = Protocol(SERVER) + server.receive_eof() + exc = server.close_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertIsNone(exc.sent) + self.assertIsNone(exc.rcvd_then_sent) + + +class ErrorTests(ProtocolTestCase): + """ + Test other error cases. + + """ + + def test_client_hits_internal_error_reading_frame(self): + client = Protocol(CLIENT) + # This isn't supposed to happen, so we're simulating it. + with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + client.receive_data(b"\x81\x00") + self.assertIsInstance(client.parser_exc, RuntimeError) + self.assertEqual(str(client.parser_exc), "BOOM") + self.assertConnectionFailing(client, 1011, "") + + def test_server_hits_internal_error_reading_frame(self): + server = Protocol(SERVER) + # This isn't supposed to happen, so we're simulating it. + with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + server.receive_data(b"\x81\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, RuntimeError) + self.assertEqual(str(server.parser_exc), "BOOM") + self.assertConnectionFailing(server, 1011, "") + + +class ExtensionsTests(ProtocolTestCase): + """ + Test how extensions affect frames. + + """ + + def test_client_extension_encodes_frame(self): + client = Protocol(CLIENT) + client.extensions = [Rsv2Extension()] + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"") + self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) + + def test_server_extension_encodes_frame(self): + server = Protocol(SERVER) + server.extensions = [Rsv2Extension()] + server.send_ping(b"") + self.assertEqual(server.data_to_send(), [b"\xa9\x00"]) + + def test_client_extension_decodes_frame(self): + client = Protocol(CLIENT) + client.extensions = [Rsv2Extension()] + client.receive_data(b"\xaa\x00") + self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) + + def test_server_extension_decodes_frame(self): + server = Protocol(SERVER) + server.extensions = [Rsv2Extension()] + server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") + self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) + + +class MiscTests(unittest.TestCase): + def test_client_default_logger(self): + client = Protocol(CLIENT) + logger = logging.getLogger("websockets.client") + self.assertIs(client.logger, logger) + + def test_server_default_logger(self): + server = Protocol(SERVER) + logger = logging.getLogger("websockets.server") + self.assertIs(server.logger, logger) + + def test_client_custom_logger(self): + logger = logging.getLogger("test") + client = Protocol(CLIENT, logger=logger) + self.assertIs(client.logger, logger) + + def test_server_custom_logger(self): + logger = logging.getLogger("test") + server = Protocol(SERVER, logger=logger) + self.assertIs(server.logger, logger) diff --git a/tests/test_server.py b/tests/test_server.py index f1404499b..ba7e05df2 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,11 +3,11 @@ import unittest import unittest.mock -from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response +from websockets.protocol import CONNECTING, OPEN from websockets.server import * from .extensions.utils import ( @@ -17,12 +17,12 @@ ServerRsv2ExtensionFactory, ) from .test_utils import ACCEPT, KEY -from .utils import DATE +from .utils import DATE, DeprecationTestCase class ConnectTests(unittest.TestCase): def test_receive_connect(self): - server = ServerConnection() + server = ServerProtocol() server.receive_data( ( f"GET /test HTTP/1.1\r\n" @@ -40,7 +40,7 @@ def test_receive_connect(self): self.assertFalse(server.close_expected()) def test_connect_request(self): - server = ServerConnection() + server = ServerProtocol() server.receive_data( ( f"GET /test HTTP/1.1\r\n" @@ -84,7 +84,7 @@ def make_request(self): ) def test_send_accept(self): - server = ServerConnection() + server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.accept(self.make_request()) self.assertIsInstance(response, Response) @@ -104,7 +104,7 @@ def test_send_accept(self): self.assertEqual(server.state, OPEN) def test_send_reject(self): - server = ServerConnection() + server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") self.assertIsInstance(response, Response) @@ -126,7 +126,7 @@ def test_send_reject(self): self.assertEqual(server.state, CONNECTING) def test_accept_response(self): - server = ServerConnection() + server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.accept(self.make_request()) self.assertIsInstance(response, Response) @@ -146,7 +146,7 @@ def test_accept_response(self): self.assertIsNone(response.body) def test_reject_response(self): - server = ServerConnection() + server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") self.assertIsInstance(response, Response) @@ -166,17 +166,17 @@ def test_reject_response(self): self.assertEqual(response.body, b"Sorry folks.\n") def test_basic(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() response = server.accept(request) self.assertEqual(response.status_code, 101) def test_unexpected_exception(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() with unittest.mock.patch( - "websockets.server.ServerConnection.process_request", + "websockets.server.ServerProtocol.process_request", side_effect=Exception("BOOM"), ): response = server.accept(request) @@ -187,7 +187,7 @@ def test_unexpected_exception(self): self.assertEqual(str(raised.exception), "BOOM") def test_missing_connection(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Connection"] response = server.accept(request) @@ -199,7 +199,7 @@ def test_missing_connection(self): self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Connection"] request.headers["Connection"] = "close" @@ -212,7 +212,7 @@ def test_invalid_connection(self): self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Upgrade"] response = server.accept(request) @@ -224,7 +224,7 @@ def test_missing_upgrade(self): self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Upgrade"] request.headers["Upgrade"] = "h2c" @@ -237,7 +237,7 @@ def test_invalid_upgrade(self): self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_key(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Key"] response = server.accept(request) @@ -248,7 +248,7 @@ def test_missing_key(self): self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") def test_multiple_key(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() request.headers["Sec-WebSocket-Key"] = KEY response = server.accept(request) @@ -263,7 +263,7 @@ def test_multiple_key(self): ) def test_invalid_key(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Key"] request.headers["Sec-WebSocket-Key"] = "not Base64 data!" @@ -277,7 +277,7 @@ def test_invalid_key(self): ) def test_truncated_key(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Key"] request.headers["Sec-WebSocket-Key"] = KEY[ @@ -293,7 +293,7 @@ def test_truncated_key(self): ) def test_missing_version(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Version"] response = server.accept(request) @@ -304,7 +304,7 @@ def test_missing_version(self): self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") def test_multiple_version(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) @@ -319,7 +319,7 @@ def test_multiple_version(self): ) def test_invalid_version(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Version"] request.headers["Sec-WebSocket-Version"] = "11" @@ -333,7 +333,7 @@ def test_invalid_version(self): ) def test_no_origin(self): - server = ServerConnection(origins=["https://example.com"]) + server = ServerProtocol(origins=["https://example.com"]) request = self.make_request() response = server.accept(request) @@ -343,7 +343,7 @@ def test_no_origin(self): self.assertEqual(str(raised.exception), "missing Origin header") def test_origin(self): - server = ServerConnection(origins=["https://example.com"]) + server = ServerProtocol(origins=["https://example.com"]) request = self.make_request() request.headers["Origin"] = "https://example.com" response = server.accept(request) @@ -352,7 +352,7 @@ def test_origin(self): self.assertEqual(server.origin, "https://example.com") def test_unexpected_origin(self): - server = ServerConnection(origins=["https://example.com"]) + server = ServerProtocol(origins=["https://example.com"]) request = self.make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) @@ -365,7 +365,7 @@ def test_unexpected_origin(self): ) def test_multiple_origin(self): - server = ServerConnection( + server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = self.make_request() @@ -384,7 +384,7 @@ def test_multiple_origin(self): ) def test_supported_origin(self): - server = ServerConnection( + server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = self.make_request() @@ -395,7 +395,7 @@ def test_supported_origin(self): self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): - server = ServerConnection( + server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = self.make_request() @@ -410,7 +410,7 @@ def test_unsupported_origin(self): ) def test_no_origin_accepted(self): - server = ServerConnection(origins=[None]) + server = ServerProtocol(origins=[None]) request = self.make_request() response = server.accept(request) @@ -418,7 +418,7 @@ def test_no_origin_accepted(self): self.assertIsNone(server.origin) def test_no_extensions(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() response = server.accept(request) @@ -427,7 +427,7 @@ def test_no_extensions(self): self.assertEqual(server.extensions, []) def test_no_extension(self): - server = ServerConnection(extensions=[ServerOpExtensionFactory()]) + server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) request = self.make_request() response = server.accept(request) @@ -436,7 +436,7 @@ def test_no_extension(self): self.assertEqual(server.extensions, []) def test_extension(self): - server = ServerConnection(extensions=[ServerOpExtensionFactory()]) + server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) @@ -446,7 +446,7 @@ def test_extension(self): self.assertEqual(server.extensions, [OpExtension()]) def test_unexpected_extension(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) @@ -456,7 +456,7 @@ def test_unexpected_extension(self): self.assertEqual(server.extensions, []) def test_unsupported_extension(self): - server = ServerConnection(extensions=[ServerRsv2ExtensionFactory()]) + server = ServerProtocol(extensions=[ServerRsv2ExtensionFactory()]) request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) @@ -466,7 +466,7 @@ def test_unsupported_extension(self): self.assertEqual(server.extensions, []) def test_supported_extension_parameters(self): - server = ServerConnection(extensions=[ServerOpExtensionFactory("this")]) + server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" response = server.accept(request) @@ -476,7 +476,7 @@ def test_supported_extension_parameters(self): self.assertEqual(server.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): - server = ServerConnection(extensions=[ServerOpExtensionFactory("this")]) + server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) @@ -486,7 +486,7 @@ def test_unsupported_extension_parameters(self): self.assertEqual(server.extensions, []) def test_multiple_supported_extension_parameters(self): - server = ServerConnection( + server = ServerProtocol( extensions=[ ServerOpExtensionFactory("this"), ServerOpExtensionFactory("that"), @@ -501,7 +501,7 @@ def test_multiple_supported_extension_parameters(self): self.assertEqual(server.extensions, [OpExtension("that")]) def test_multiple_extensions(self): - server = ServerConnection( + server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) request = self.make_request() @@ -516,7 +516,7 @@ def test_multiple_extensions(self): self.assertEqual(server.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): - server = ServerConnection( + server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) request = self.make_request() @@ -531,7 +531,7 @@ def test_multiple_extensions_order(self): self.assertEqual(server.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() response = server.accept(request) @@ -540,7 +540,7 @@ def test_no_subprotocols(self): self.assertIsNone(server.subprotocol) def test_no_subprotocol(self): - server = ServerConnection(subprotocols=["chat"]) + server = ServerProtocol(subprotocols=["chat"]) request = self.make_request() response = server.accept(request) @@ -549,7 +549,7 @@ def test_no_subprotocol(self): self.assertIsNone(server.subprotocol) def test_subprotocol(self): - server = ServerConnection(subprotocols=["chat"]) + server = ServerProtocol(subprotocols=["chat"]) request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) @@ -559,7 +559,7 @@ def test_subprotocol(self): self.assertEqual(server.subprotocol, "chat") def test_unexpected_subprotocol(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) @@ -569,7 +569,7 @@ def test_unexpected_subprotocol(self): self.assertIsNone(server.subprotocol) def test_multiple_subprotocols(self): - server = ServerConnection(subprotocols=["superchat", "chat"]) + server = ServerProtocol(subprotocols=["superchat", "chat"]) request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "superchat" request.headers["Sec-WebSocket-Protocol"] = "chat" @@ -580,7 +580,7 @@ def test_multiple_subprotocols(self): self.assertEqual(server.subprotocol, "superchat") def test_supported_subprotocol(self): - server = ServerConnection(subprotocols=["superchat", "chat"]) + server = ServerProtocol(subprotocols=["superchat", "chat"]) request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) @@ -590,7 +590,7 @@ def test_supported_subprotocol(self): self.assertEqual(server.subprotocol, "chat") def test_unsupported_subprotocol(self): - server = ServerConnection(subprotocols=["superchat", "chat"]) + server = ServerProtocol(subprotocols=["superchat", "chat"]) request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) @@ -602,7 +602,7 @@ def test_unsupported_subprotocol(self): class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - server = ServerConnection(state=OPEN) + server = ServerProtocol(state=OPEN) server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") [frame] = server.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) @@ -610,5 +610,17 @@ def test_bypass_handshake(self): def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ServerConnection(logger=logger) + ServerProtocol(logger=logger) self.assertEqual(len(logs.records), 1) + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_server_connection_class(self): + with self.assertDeprecationWarning( + "ServerConnection was renamed to ServerProtocol" + ): + from websockets.server import ServerConnection + + server = ServerConnection("ws://localhost/") + + self.assertIsInstance(server, ServerProtocol) diff --git a/tests/utils.py b/tests/utils.py index ac891a0fd..92c754810 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,11 +1,18 @@ +import contextlib import email.utils import unittest +import warnings DATE = email.utils.formatdate(usegmt=True) class GeneratorTestCase(unittest.TestCase): + """ + Base class for testing generator-based coroutines. + + """ + def assertGeneratorRunning(self, gen): """ Check that a generator-based coroutine hasn't completed yet. @@ -21,3 +28,25 @@ def assertGeneratorReturns(self, gen): with self.assertRaises(StopIteration) as raised: next(gen) return raised.exception.value + + +class DeprecationTestCase(unittest.TestCase): + """ + Base class for testing deprecations. + + """ + + @contextlib.contextmanager + def assertDeprecationWarning(self, message): + """ + Check that a deprecation warning was raised with the given message. + + """ + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") + yield + + self.assertEqual(len(recorded_warnings), 1) + warning = recorded_warnings[0] + self.assertEqual(warning.category, DeprecationWarning) + self.assertEqual(str(warning.message), message) From 35731196de91b91fa79573ad6235a27858027398 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 11:53:16 +0200 Subject: [PATCH 1143/1539] Change Sans-I/O constructors to keyword-only. --- docs/project/changelog.rst | 7 +++++++ src/websockets/client.py | 1 + src/websockets/protocol.py | 1 + src/websockets/server.py | 1 + tests/test_server.py | 2 +- 5 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 532eeffb2..aab9afdb1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -45,6 +45,13 @@ Backwards-incompatible changes ``client.ClientConnection`` classes were renamed to ``protocol.Protocol``, ``server.ServerProtocol``, and ``client.ClientProtocol``. +.. admonition:: Sans-I/O protocol constructors now use keyword-only arguments. + :class: caution + + If you instantiate :class:`~server.ServerProtocol` or + :class:`~client.ClientProtocol` directly, make sure you are using keyword + arguments. + .. admonition:: Closing a connection without an empty close frame is OK. :class: note diff --git a/src/websockets/client.py b/src/websockets/client.py index a439ab846..ad4c1a15d 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -70,6 +70,7 @@ class ClientProtocol(Protocol): def __init__( self, wsuri: WebSocketURI, + *, origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 29d7e1596..1f1af4a84 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -86,6 +86,7 @@ class Protocol: def __init__( self, side: Side, + *, state: State = OPEN, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, diff --git a/src/websockets/server.py b/src/websockets/server.py index 548048c92..cc5ba798a 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -69,6 +69,7 @@ class ServerProtocol(Protocol): def __init__( self, + *, origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, diff --git a/tests/test_server.py b/tests/test_server.py index ba7e05df2..2645dc843 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -621,6 +621,6 @@ def test_server_connection_class(self): ): from websockets.server import ServerConnection - server = ServerConnection("ws://localhost/") + server = ServerConnection() self.assertIsInstance(server, ServerProtocol) From 1363dd5be5b55fb3a16d58ca346c9f09e7519393 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 28 Nov 2022 07:28:09 +0100 Subject: [PATCH 1144/1539] Revert "Fix example of shutting down a client." This reverts commit 9e960b50. The example was correct. --- example/faq/shutdown_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/faq/shutdown_client.py b/example/faq/shutdown_client.py index 65cfca41b..539dd0304 100755 --- a/example/faq/shutdown_client.py +++ b/example/faq/shutdown_client.py @@ -10,7 +10,7 @@ async def client(): # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() loop.add_signal_handler( - signal.SIGTERM, loop.create_task, websocket.close) + signal.SIGTERM, loop.create_task, websocket.close()) # Process messages received on the connection. async for message in websocket: From 23a2d3f6dcd9056c94406b7354d0def89e46a720 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 13:37:22 +0200 Subject: [PATCH 1145/1539] Add API to customize subprotocol selection logic. Also changed the default logic: (1) to reject client connections that don't offer a subprotocol when the server is configured with subprotocols. This is the expected behavior for what I believe to be the default use case: require one particular subprotocol. This change of behavior isn't documented because I don't know anyone embedding the Sans-I/O layer and supporting subprotocol selection at this time. (2) to rely only on the order of preference of the server. Hey, trying to cater to the preferences of clients was nice, but the behavior was so needlessly complex that documentation apologized... This keeps things simple and still falls within the documented behavior. --- docs/project/changelog.rst | 5 +- docs/reference/server.rst | 2 + src/websockets/legacy/server.py | 24 +++---- src/websockets/server.py | 116 +++++++++++++++++++++----------- tests/test_server.py | 46 +++++++++++-- 5 files changed, 136 insertions(+), 57 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index aab9afdb1..34e1f909c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -37,7 +37,7 @@ Backwards-incompatible changes :class: caution Aliases provide compatibility for all previously public APIs according to - the `backwards-compatibility policy`_ + the `backwards-compatibility policy`_. * The ``connection`` module was renamed to ``protocol``. @@ -67,6 +67,9 @@ New features * Made it possible to close a server without closing existing connections. +* Added :attr:`~protocol.ServerProtocol.select_subprotocol` to customize + negotiation of subprotocols in the Sans-I/O layer. + 10.4 ---- diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 50ef4ee3c..08bfe2f57 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -123,6 +123,8 @@ Sans-I/O .. automethod:: accept + .. automethod:: select_subprotocol + .. automethod:: reject .. automethod:: send_response diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index eabeb8e96..5ec5ed1a3 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -520,31 +520,29 @@ def select_subprotocol( server_subprotocols: Sequence[Subprotocol], ) -> Optional[Subprotocol]: """ - Pick a subprotocol among those offered by the client. + Pick a subprotocol among those supported by the client and the server. - If several subprotocols are supported by the client and the server, - the default implementation selects the preferred subprotocol by - giving equal value to the priorities of the client and the server. - If no subprotocol is supported by the client and the server, it - proceeds without a subprotocol. + If several subprotocols are available, select the preferred subprotocol + by giving equal weight to the preferences of the client and the server. - This is unlikely to be the most useful implementation in practice. - Many servers providing a subprotocol will require that the client - uses that subprotocol. Such rules can be implemented in a subclass. + If no subprotocol is available, proceed without a subprotocol. - You may also override this method with the ``select_subprotocol`` - argument of :func:`serve` and :class:`WebSocketServerProtocol`. + You may provide a ``select_subprotocol`` argument to :func:`serve` or + :class:`WebSocketServerProtocol` to override this logic. For example, + you could reject the handshake if the client doesn't support a + particular subprotocol, rather than accept the handshake without that + subprotocol. Args: client_subprotocols: list of subprotocols offered by the client. server_subprotocols: list of subprotocols available on the server. Returns: - Optional[Subprotocol]: Selected subprotocol. + Optional[Subprotocol]: Selected subprotocol, if a common subprotocol + was found. :obj:`None` to continue without a subprotocol. - """ if self._select_subprotocol is not None: return self._select_subprotocol(client_subprotocols, server_subprotocols) diff --git a/src/websockets/server.py b/src/websockets/server.py index cc5ba798a..547c516b5 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Generator, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -58,6 +58,10 @@ class ServerProtocol(Protocol): should be tried. subprotocols: list of supported subprotocols, in order of decreasing preference. + select_subprotocol: callback for selecting a subprotocol among + those supported by the client and the server. It has the same + signature as the :meth:`select_subprotocol` method, including a + :class:`ServerProtocol` instance as first argument. state: initial state of the WebSocket connection. max_size: maximum size of incoming messages in bytes; :obj:`None` to disable the limit. @@ -73,6 +77,7 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, + select_subprotocol: Optional[SelectSubprotocol] = None, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -86,6 +91,14 @@ def __init__( self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols + if select_subprotocol is not None: + # Bind select_subprotocol then shadow self.select_subprotocol. + # Use setattr to work around https://github.com/python/mypy/issues/2427. + setattr( + self, + "select_subprotocol", + select_subprotocol.__get__(self, self.__class__), + ) def accept(self, request: Request) -> Response: """ @@ -96,7 +109,7 @@ def accept(self, request: Request) -> Response: You must send the handshake response with :meth:`send_response`. - You can modify it before sending it, for example to add HTTP headers. + You may modify it before sending it, for example to add HTTP headers. Args: request: WebSocket handshake request event received from the client. @@ -175,7 +188,8 @@ def accept(self, request: Request) -> Response: return Response(101, "Switching Protocols", headers) def process_request( - self, request: Request + self, + request: Request, ) -> Tuple[str, Optional[str], Optional[str]]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -273,6 +287,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: Optional[Origin]: origin, if it is acceptable. Raises: + InvalidHandshake: if the Origin header is invalid. InvalidOrigin: if the origin isn't acceptable. """ @@ -323,7 +338,7 @@ def process_extensions( HTTP response header and list of accepted extensions. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid. """ response_header_value: Optional[str] = None @@ -383,60 +398,79 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: also the value of the ``Sec-WebSocket-Protocol`` response header. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid. """ - subprotocol: Optional[Subprotocol] = None - - header_values = headers.get_all("Sec-WebSocket-Protocol") - - if header_values and self.available_subprotocols: - - parsed_header_values: List[Subprotocol] = sum( - [parse_subprotocol(header_value) for header_value in header_values], [] - ) - - subprotocol = self.select_subprotocol( - parsed_header_values, self.available_subprotocols - ) + subprotocols: Sequence[Subprotocol] = sum( + [ + parse_subprotocol(header_value) + for header_value in headers.get_all("Sec-WebSocket-Protocol") + ], + [], + ) - return subprotocol + return self.select_subprotocol(subprotocols) def select_subprotocol( self, - client_subprotocols: Sequence[Subprotocol], - server_subprotocols: Sequence[Subprotocol], + subprotocols: Sequence[Subprotocol], ) -> Optional[Subprotocol]: """ Pick a subprotocol among those offered by the client. - If several subprotocols are supported by the client and the server, - the default implementation selects the preferred subprotocols by - giving equal value to the priorities of the client and the server. + If several subprotocols are supported by both the client and the server, + pick the first one in the list declared the server. + + If the server doesn't support any subprotocols, continue without a + subprotocol, regardless of what the client offers. - If no common subprotocol is supported by the client and the server, it - proceeds without a subprotocol. + If the server supports at least one subprotocol and the client doesn't + offer any, abort the handshake with an HTTP 400 error. - This is unlikely to be the most useful implementation in practice, as - many servers providing a subprotocol will require that the client uses - that subprotocol. + You provide a ``select_subprotocol`` argument to :class:`ServerProtocol` + to override this logic. For example, you could accept the connection + even if client doesn't offer a subprotocol, rather than reject it. + + Here's how to negotiate the ``chat`` subprotocol if the client supports + it and continue without a subprotocol otherwise:: + + def select_subprotocol(protocol, subprotocols): + if "chat" in subprotocols: + return "chat" Args: - client_subprotocols: list of subprotocols offered by the client. - server_subprotocols: list of subprotocols available on the server. + subprotocols: list of subprotocols offered by the client. Returns: - Optional[Subprotocol]: Subprotocol, if a common subprotocol was - found. + Optional[Subprotocol]: Selected subprotocol, if a common subprotocol + was found. + + :obj:`None` to continue without a subprotocol. + + Raises: + NegotiationError: custom implementations may raise this exception + to abort the handshake with an HTTP 400 error. """ - subprotocols = set(client_subprotocols) & set(server_subprotocols) - if not subprotocols: + # Server doesn't offer any subprotocols. + if not self.available_subprotocols: # None or empty list return None - priority = lambda p: ( - client_subprotocols.index(p) + server_subprotocols.index(p) + + # Server offers at least one subprotocol but client doesn't offer any. + if not subprotocols: + raise NegotiationError("missing subprotocol") + + # Server and client both offer subprotocols. Look for a shared one. + proposed_subprotocols = set(subprotocols) + for subprotocol in self.available_subprotocols: + if subprotocol in proposed_subprotocols: + return subprotocol + + # No common subprotocol was found. + raise NegotiationError( + "invalid subprotocol; expected one of " + + ", ".join(self.available_subprotocols) ) - return sorted(subprotocols, key=priority)[0] def reject( self, @@ -519,6 +553,12 @@ def parse(self) -> Generator[None, None, None]: yield from super().parse() +SelectSubprotocol = Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Optional[Subprotocol], +] + + class ServerConnection(ServerProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( diff --git a/tests/test_server.py b/tests/test_server.py index 2645dc843..1c0bbb292 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,7 +4,12 @@ import unittest.mock from websockets.datastructures import Headers -from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade +from websockets.exceptions import ( + InvalidHeader, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, +) from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN @@ -544,9 +549,12 @@ def test_no_subprotocol(self): request = self.make_request() response = server.accept(request) - self.assertEqual(response.status_code, 101) - self.assertNotIn("Sec-WebSocket-Protocol", response.headers) - self.assertIsNone(server.subprotocol) + self.assertEqual(response.status_code, 400) + with self.assertRaisesRegex( + NegotiationError, + r"missing subprotocol", + ): + raise server.handshake_exc def test_subprotocol(self): server = ServerProtocol(subprotocols=["chat"]) @@ -571,8 +579,8 @@ def test_unexpected_subprotocol(self): def test_multiple_subprotocols(self): server = ServerProtocol(subprotocols=["superchat", "chat"]) request = self.make_request() - request.headers["Sec-WebSocket-Protocol"] = "superchat" request.headers["Sec-WebSocket-Protocol"] = "chat" + request.headers["Sec-WebSocket-Protocol"] = "superchat" response = server.accept(request) self.assertEqual(response.status_code, 101) @@ -595,6 +603,34 @@ def test_unsupported_subprotocol(self): request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) + self.assertEqual(response.status_code, 400) + with self.assertRaisesRegex( + NegotiationError, + r"invalid subprotocol; expected one of superchat, chat", + ): + raise server.handshake_exc + + @staticmethod + def optional_chat(protocol, subprotocols): + if "chat" in subprotocols: + return "chat" + + def test_select_subprotocol(self): + server = ServerProtocol(select_subprotocol=self.optional_chat) + request = self.make_request() + request.headers["Sec-WebSocket-Protocol"] = "chat" + response = server.accept(request) + + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") + self.assertEqual(server.subprotocol, "chat") + + def test_select_no_subprotocol(self): + server = ServerProtocol(select_subprotocol=self.optional_chat) + request = self.make_request() + request.headers["Sec-WebSocket-Protocol"] = "otherchat" + response = server.accept(request) + self.assertEqual(response.status_code, 101) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) From e4fcab16a70344e356268e50dc0c0cf541920c5d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 21:51:39 +0200 Subject: [PATCH 1146/1539] Handle exceptions when parsing opening handshake. --- src/websockets/client.py | 27 +++++++++++++++++---------- src/websockets/server.py | 11 ++++++++++- tests/test_client.py | 26 ++++++++++++++++++++++++++ tests/test_server.py | 18 ++++++++++++++++++ 4 files changed, 71 insertions(+), 11 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index ad4c1a15d..bfc8080a7 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -316,11 +316,17 @@ def send_request(self, request: Request) -> None: def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: - response = yield from Response.parse( - self.reader.read_line, - self.reader.read_exact, - self.reader.read_to_eof, - ) + try: + response = yield from Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + ) + except Exception as exc: + self.handshake_exc = exc + self.parser = self.discard() + next(self.parser) # start coroutine + yield if self.debug: code, phrase = response.status_code, response.reason_phrase @@ -334,14 +340,15 @@ def parse(self) -> Generator[None, None, None]: self.process_response(response) except InvalidHandshake as exc: response._exception = exc + self.events.append(response) self.handshake_exc = exc self.parser = self.discard() next(self.parser) # start coroutine - else: - assert self.state is CONNECTING - self.state = OPEN - finally: - self.events.append(response) + yield + + assert self.state is CONNECTING + self.state = OPEN + self.events.append(response) yield from super().parse() diff --git a/src/websockets/server.py b/src/websockets/server.py index 547c516b5..214b38bfa 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -541,7 +541,16 @@ def send_response(self, response: Response) -> None: def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: - request = yield from Request.parse(self.reader.read_line) + try: + request = yield from Request.parse( + self.reader.read_line, + ) + except Exception as exc: + self.handshake_exc = exc + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + yield if self.debug: self.logger.debug("< GET %s HTTP/1.1", request.path) diff --git a/tests/test_client.py b/tests/test_client.py index 718219b9d..c83c87038 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -218,6 +218,32 @@ def test_reject_response(self): ) self.assertEqual(response.body, b"Sorry folks.\n") + def test_no_response(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientProtocol(parse_uri("ws://example.com/test")) + client.connect() + client.receive_eof() + self.assertEqual(client.events_received(), []) + + def test_partial_response(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientProtocol(parse_uri("ws://example.com/test")) + client.connect() + client.receive_data(b"HTTP/1.1 101 Switching Protocols\r\n") + client.receive_eof() + self.assertEqual(client.events_received(), []) + + def test_random_response(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientProtocol(parse_uri("ws://example.com/test")) + client.connect() + client.receive_data(b"220 smtp.invalid\r\n") + client.receive_data(b"250 Hello relay.invalid\r\n") + client.receive_data(b"250 Ok\r\n") + client.receive_data(b"250 Ok\r\n") + client.receive_eof() + self.assertEqual(client.events_received(), []) + def make_accept_response(self, client): request = client.connect() return Response( diff --git a/tests/test_server.py b/tests/test_server.py index 1c0bbb292..ecf3d4cbe 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -72,6 +72,24 @@ def test_connect_request(self): ), ) + def test_no_request(self): + server = ServerProtocol() + server.receive_eof() + self.assertEqual(server.events_received(), []) + + def test_partial_request(self): + server = ServerProtocol() + server.receive_data(b"GET /test HTTP/1.1\r\n") + server.receive_eof() + self.assertEqual(server.events_received(), []) + + def test_random_request(self): + server = ServerProtocol() + server.receive_data(b"HELO relay.invalid\r\n") + server.receive_data(b"MAIL FROM: \r\n") + server.receive_data(b"RCPT TO: \r\n") + self.assertEqual(server.events_received(), []) + class AcceptRejectTests(unittest.TestCase): def make_request(self): From 3015447f5afbe5e6c913bf0a353777ce3dc45f80 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Jan 2023 08:47:56 +0100 Subject: [PATCH 1147/1539] Attempt to get a 10 on the OpenSSF check. --- SECURITY.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index 82024b485..175b20c58 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,5 +1,12 @@ -# Security policy +# Security + +## Policy Only the latest version receives security updates. -Please report vulnerabilities [via Tidelift](https://tidelift.com/security). +## Contact information + +Please report security vulnerabilities to the +[Tidelift security team](https://tidelift.com/security). + +Tidelift will coordinate the fix and disclosure. From 38b08fb72fb3c7e8358b3bba7cbe467c2a355aa9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Jan 2023 08:51:24 +0100 Subject: [PATCH 1148/1539] Increase max header length in legacy module. Fix #1243. Ref #1239. --- src/websockets/legacy/http.py | 4 ++-- tests/legacy/test_http.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index d9e44cc28..7cc3db844 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -10,8 +10,8 @@ __all__ = ["read_request", "read_response"] -MAX_HEADERS = 256 -MAX_LINE = 4110 +MAX_HEADERS = 128 +MAX_LINE = 8192 def d(value: bytes) -> str: diff --git a/tests/legacy/test_http.py b/tests/legacy/test_http.py index 5c9adc97f..15d53e08d 100644 --- a/tests/legacy/test_http.py +++ b/tests/legacy/test_http.py @@ -119,13 +119,13 @@ async def test_header_value(self): await read_headers(self.stream) async def test_headers_limit(self): - self.stream.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") + self.stream.feed_data(b"foo: bar\r\n" * 129 + b"\r\n") with self.assertRaises(SecurityError): await read_headers(self.stream) async def test_line_limit(self): - # Header line contains 5 + 4104 + 2 = 4111 bytes. - self.stream.feed_data(b"foo: " + b"a" * 4104 + b"\r\n\r\n") + # Header line contains 5 + 8186 + 2 = 8193 bytes. + self.stream.feed_data(b"foo: " + b"a" * 8186 + b"\r\n\r\n") with self.assertRaises(SecurityError): await read_headers(self.stream) From 75bb1cb07a476e899b689c5f50872f90f98a38e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Jan 2023 09:14:09 +0100 Subject: [PATCH 1149/1539] Fix incorrect partial binding pattern in docs. Fix #1275. Supersede #1276. --- docs/faq/common.rst | 4 ++-- docs/faq/server.rst | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index dff64f67c..66e273056 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -125,11 +125,11 @@ You can bind additional arguments to the protocol factory with import websockets class MyServerProtocol(websockets.WebSocketServerProtocol): - def __init__(self, extra_argument, *args, **kwargs): + def __init__(self, *args, extra_argument=None, **kwargs): super().__init__(*args, **kwargs) # do something with extra_argument - create_protocol = functools.partial(MyServerProtocol, extra_argument='spam') + create_protocol = functools.partial(MyServerProtocol, extra_argument=42) start_server = websockets.serve(..., create_protocol=create_protocol) This example was for a server. The same pattern applies on a client. diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 4e4622dce..02a248cb8 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -185,7 +185,7 @@ You can bind additional arguments to the connection handler with async def handler(websocket, extra_argument): ... - bound_handler = functools.partial(handler, extra_argument='spam') + bound_handler = functools.partial(handler, extra_argument=42) start_server = websockets.serve(bound_handler, ...) Another way to achieve this result is to define the ``handler`` coroutine in From f42fd7ba34d40b9d9a800916fe0d27bf21c13656 Mon Sep 17 00:00:00 2001 From: ooliver1 Date: Sat, 10 Dec 2022 20:13:15 +0000 Subject: [PATCH 1150/1539] Refactor create_protocol to actually allow subclasses --- src/websockets/legacy/auth.py | 2 +- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index ac24c179e..def36a39c 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -118,7 +118,7 @@ def basic_auth_protocol_factory( realm: Optional[str] = None, credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, - create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None, + create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, ) -> Callable[[Any], BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 9b953df8f..4981ac9bd 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -438,7 +438,7 @@ def __init__( self, uri: str, *, - create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None, + create_protocol: Optional[Callable[..., WebSocketClientProtocol]] = None, logger: Optional[LoggerLike] = None, compression: Optional[str] = "deflate", origin: Optional[Origin] = None, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 5ec5ed1a3..6f8833a88 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -989,7 +989,7 @@ def __init__( host: Optional[Union[str, Sequence[str]]] = None, port: Optional[int] = None, *, - create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None, + create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None, logger: Optional[LoggerLike] = None, compression: Optional[str] = "deflate", origins: Optional[Sequence[Optional[Origin]]] = None, From 716245215fab4a6937a9514eb0c0e1939dfc4fcd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 18 Jan 2023 07:24:46 +0100 Subject: [PATCH 1151/1539] Follow-up on f42fd7ba. --- src/websockets/legacy/auth.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index def36a39c..3511469e6 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -119,7 +119,7 @@ def basic_auth_protocol_factory( credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, -) -> Callable[[Any], BasicAuthWebSocketServerProtocol]: +) -> Callable[..., BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. @@ -175,11 +175,7 @@ async def check_credentials(username: str, password: str) -> bool: return hmac.compare_digest(expected_password, password) if create_protocol is None: - # Not sure why mypy cannot figure this out. - create_protocol = cast( - Callable[[Any], BasicAuthWebSocketServerProtocol], - BasicAuthWebSocketServerProtocol, - ) + create_protocol = BasicAuthWebSocketServerProtocol return functools.partial( create_protocol, From f2176ebc682742ec6a00646663c108bc400451bb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 18 Jan 2023 07:36:10 +0100 Subject: [PATCH 1152/1539] Fix errors in the documentation of 23a2d3f6. --- docs/project/changelog.rst | 2 +- src/websockets/server.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 34e1f909c..cc3ebc091 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -67,7 +67,7 @@ New features * Made it possible to close a server without closing existing connections. -* Added :attr:`~protocol.ServerProtocol.select_subprotocol` to customize +* Added :attr:`~server.ServerProtocol.select_subprotocol` to customize negotiation of subprotocols in the Sans-I/O layer. 10.4 diff --git a/src/websockets/server.py b/src/websockets/server.py index 214b38bfa..148b4ec11 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -77,7 +77,12 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[SelectSubprotocol] = None, + select_subprotocol: Optional[ + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Optional[Subprotocol], + ] + ] = None, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -562,12 +567,6 @@ def parse(self) -> Generator[None, None, None]: yield from super().parse() -SelectSubprotocol = Callable[ - [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], -] - - class ServerConnection(ServerProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( From a525950c84a60151f261f6282fa80ff310954718 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 13:35:29 +0100 Subject: [PATCH 1153/1539] Fix typos in comments. Fix #1284. Thank you @cclauss! --- src/websockets/frames.py | 4 ++-- src/websockets/legacy/protocol.py | 2 +- src/websockets/protocol.py | 2 +- src/websockets/server.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 45d006e3f..52d81746d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -127,7 +127,7 @@ class Frame: def __str__(self) -> str: """ - Return a human-readable represention of a frame. + Return a human-readable representation of a frame. """ coding = None @@ -389,7 +389,7 @@ class Close: def __str__(self) -> str: """ - Return a human-readable represention of a close code and reason. + Return a human-readable representation of a close code and reason. """ if 3000 <= self.code < 4000: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7881b947d..f6c419c3a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1120,7 +1120,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: try: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame - # is empty and Close.parse() synthetizes a 1005 close code. + # is empty and Close.parse() synthesizes a 1005 close code. await self.write_close_frame(self.close_rcvd, frame.data) except ConnectionClosed: # Connection closed before we could echo the close frame. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1f1af4a84..e5e8826f6 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -653,7 +653,7 @@ def recv_frame(self, frame: Frame) -> None: if self.state is OPEN: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame - # is empty and Close.parse() synthetizes a 1005 close code. + # is empty and Close.parse() synthesizes a 1005 close code. # The rest is identical to send_close(). self.send_frame(Frame(OP_CLOSE, frame.data)) self.close_sent = self.close_rcvd diff --git a/src/websockets/server.py b/src/websockets/server.py index 148b4ec11..5c73d7e07 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -512,7 +512,7 @@ def reject( response = Response(status.value, status.phrase, headers, body) # When reject() is called from accept(), handshake_exc is already set. # If a user calls reject(), set handshake_exc to guarantee invariant: - # "handshake_exc is None if and only if opening handshake succeded." + # "handshake_exc is None if and only if opening handshake succeeded." if self.handshake_exc is None: self.handshake_exc = InvalidStatus(response) self.logger.info("connection failed (%d %s)", status.value, status.phrase) From 87657de0edefa9c1d5b305d48c9cb0d28bdfc9d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 15:09:44 +0100 Subject: [PATCH 1154/1539] Small FAQ updates. --- docs/faq/asyncio.rst | 2 +- docs/faq/common.rst | 6 ++---- docs/faq/misc.rst | 4 ++-- docs/faq/server.rst | 4 ++-- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index 7db43e1be..d00cf3f47 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -1,4 +1,4 @@ -asyncio usage +Using asyncio ============= .. currentmodule:: websockets diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 66e273056..2a512ea90 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -109,10 +109,8 @@ Use :func:`~asyncio.wait_for`:: await asyncio.wait_for(websocket.recv(), timeout=10) -This technique works for most APIs, except for asynchronous context managers. -See `issue 574`_. - -.. _issue 574: https://github.com/aaugustin/websockets/issues/574 +This technique works for most APIs. When it doesn't, for example with +asynchronous context managers, websockets provides an ``open_timeout`` argument. How can I pass arguments to a custom protocol subclass? ------------------------------------------------------- diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 681c5e45a..15e520fdd 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -33,8 +33,8 @@ Instead, use the real import paths e.g.:: websockets.client.connect(...) websockets.server.serve(...) -Why is websockets slower than another Python library in my benchmark? -..................................................................... +Why is websockets slower than another library in my benchmark? +.............................................................. Not all libraries are as feature-complete as websockets. For a fair benchmark, you should disable features that the other library doesn't provide. Typically, diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 02a248cb8..d1388c701 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -139,7 +139,7 @@ Add error handling according to the behavior you want if the user disconnected before the message could be sent. This example supports only one connection per user. To support concurrent -connects by the same user, you can change ``CONNECTIONS`` to store a set of +connections by the same user, you can change ``CONNECTIONS`` to store a set of connections for each user. If you're running multiple server processes, call ``message_user`` in each @@ -213,7 +213,7 @@ client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to authenticate the connection before routing it, this is usually more convenient. Generally speaking, there is far less emphasis on the request path in WebSocket -servers than in HTTP servers. When a WebSockt server provides a single endpoint, +servers than in HTTP servers. When a WebSocket server provides a single endpoint, it may ignore the request path entirely. How do I access HTTP headers? From bc8b3a85fe3d83378257d5ef2630e9c03714da44 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 16:45:29 +0100 Subject: [PATCH 1155/1539] Explain the legacy submodule. Fix #1297. --- docs/faq/misc.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 15e520fdd..72e7d6d56 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -33,6 +33,19 @@ Instead, use the real import paths e.g.:: websockets.client.connect(...) websockets.server.serve(...) +Why is the default implementation located in ``websockets.legacy``? +................................................................... + +This is an artifact of websockets' history. For its first eight years, only the +:mod:`asyncio`-based implementation existed. Then, the Sans-I/O implementation +was added. Moving the code in a ``legacy`` submodule eased this refactoring and +optimized maintainability. + +All public APIs were kept at their original locations. ``websockets.legacy`` +isn't a public API. It's only visible in the source code and in stack traces. +There is no intent to deprecate this implementation — at least until a superior +alternative exists. + Why is websockets slower than another library in my benchmark? .............................................................. From 8f0a33c5fb962036a2eb389a14c5bfaf58dfffa6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 16:45:59 +0100 Subject: [PATCH 1156/1539] Add missing words to spellchecker. --- docs/spelling_wordlist.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1d342515b..9cc2182e1 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -18,6 +18,7 @@ coroutine coroutines cryptocurrencies cryptocurrency +css ctrl deserialize django @@ -27,11 +28,13 @@ formatter fractalideas gunicorn healthz +html hypercorn iframe IPv istio iterable +js keepalive KiB kubernetes @@ -40,12 +43,15 @@ linkerd liveness lookups MiB +mutex mypy nginx Paketo permessage pid +procfile proxying +py pythonic reconnection redis @@ -56,6 +62,7 @@ scalable stateful subclasses subclassing +submodule subpackages subprotocol subprotocols @@ -63,6 +70,7 @@ supervisord tidelift tls tox +txt unregister uple uvicorn From 206d7ef5eea27791684f815f65a01fa23d89d3b2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 16:53:49 +0100 Subject: [PATCH 1157/1539] Restructure API reference. --- docs/faq/misc.rst | 2 + docs/project/changelog.rst | 5 +- docs/reference/{ => asyncio}/client.rst | 58 +--------- docs/reference/{ => asyncio}/common.rst | 68 +----------- docs/reference/{ => asyncio}/server.rst | 69 ++---------- .../{utilities.rst => datastructures.rst} | 9 +- docs/reference/index.rst | 101 +++++++++++------- docs/reference/sansio/client.rst | 58 ++++++++++ docs/reference/sansio/common.rst | 62 +++++++++++ docs/reference/sansio/server.rst | 62 +++++++++++ src/websockets/client.py | 2 +- src/websockets/legacy/client.py | 4 +- src/websockets/legacy/protocol.py | 6 +- src/websockets/legacy/server.py | 6 +- src/websockets/protocol.py | 2 +- src/websockets/server.py | 2 +- 16 files changed, 278 insertions(+), 238 deletions(-) rename docs/reference/{ => asyncio}/client.rst (69%) rename docs/reference/{ => asyncio}/common.rst (52%) rename docs/reference/{ => asyncio}/server.rst (73%) rename docs/reference/{utilities.rst => datastructures.rst} (90%) create mode 100644 docs/reference/sansio/client.rst create mode 100644 docs/reference/sansio/common.rst create mode 100644 docs/reference/sansio/server.rst diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 72e7d6d56..e320cb808 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -10,6 +10,8 @@ Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. +.. _real-import-paths: + Why does my IDE fail to show documentation for websockets APIs? ............................................................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index cc3ebc091..85a4ac92d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -22,8 +22,9 @@ When a release contains backwards-incompatible API changes, the major version is increased, else the minor version is increased. Patch versions are only for fixing regressions shortly after a release. -Only documented APIs are public. Undocumented APIs are considered private. -They may change at any time. +Only documented API are public. Undocumented, private API may change without +notice. + 11.0 ---- diff --git a/docs/reference/client.rst b/docs/reference/asyncio/client.rst similarity index 69% rename from docs/reference/client.rst rename to docs/reference/asyncio/client.rst index 44f053b1e..5086015b7 100644 --- a/docs/reference/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,13 +1,10 @@ -Client -====== +Client (:mod:`asyncio`) +======================= .. automodule:: websockets.client -asyncio -------- - Opening a connection -.................... +-------------------- .. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: @@ -16,7 +13,7 @@ Opening a connection :async: Using a connection -.................. +------------------ .. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) @@ -65,50 +62,3 @@ Using a connection .. autoproperty:: close_code .. autoproperty:: close_reason - -Sans-I/O --------- - -.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) - - .. automethod:: receive_data - - .. automethod:: receive_eof - - .. automethod:: connect - - .. automethod:: send_request - - .. automethod:: send_continuation - - .. automethod:: send_text - - .. automethod:: send_binary - - .. automethod:: send_close - - .. automethod:: send_ping - - .. automethod:: send_pong - - .. automethod:: fail - - .. automethod:: events_received - - .. automethod:: data_to_send - - .. automethod:: close_expected - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: state - - .. autoattribute:: handshake_exc - - .. autoproperty:: close_code - - .. autoproperty:: close_reason - - .. autoproperty:: close_exc diff --git a/docs/reference/common.rst b/docs/reference/asyncio/common.rst similarity index 52% rename from docs/reference/common.rst rename to docs/reference/asyncio/common.rst index b42f5ea3e..ee8dc54ac 100644 --- a/docs/reference/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,8 +1,5 @@ -Both sides -========== - -asyncio -------- +Both sides (:mod:`asyncio`) +=========================== .. automodule:: websockets.legacy.protocol @@ -53,64 +50,3 @@ asyncio .. autoproperty:: close_code .. autoproperty:: close_reason - -Sans-I/O --------- - -.. automodule:: websockets.protocol - -.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) - - .. automethod:: receive_data - - .. automethod:: receive_eof - - .. automethod:: send_continuation - - .. automethod:: send_text - - .. automethod:: send_binary - - .. automethod:: send_close - - .. automethod:: send_ping - - .. automethod:: send_pong - - .. automethod:: fail - - .. automethod:: events_received - - .. automethod:: data_to_send - - .. automethod:: close_expected - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: state - - .. autoproperty:: close_code - - .. autoproperty:: close_reason - - .. autoproperty:: close_exc - -.. autoclass:: Side - - .. autoattribute:: SERVER - - .. autoattribute:: CLIENT - -.. autoclass:: State - - .. autoattribute:: CONNECTING - - .. autoattribute:: OPEN - - .. autoattribute:: CLOSING - - .. autoattribute:: CLOSED - -.. autodata:: SEND_EOF diff --git a/docs/reference/server.rst b/docs/reference/asyncio/server.rst similarity index 73% rename from docs/reference/server.rst rename to docs/reference/asyncio/server.rst index 08bfe2f57..106317916 100644 --- a/docs/reference/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,13 +1,10 @@ -Server -====== +Server (:mod:`asyncio`) +======================= .. automodule:: websockets.server -asyncio -------- - Starting a server -................. +----------------- .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: @@ -16,7 +13,7 @@ Starting a server :async: Stopping a server -................. +----------------- .. autoclass:: WebSocketServer @@ -35,7 +32,7 @@ Stopping a server .. autoattribute:: sockets Using a connection -.................. +------------------ .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) @@ -93,7 +90,7 @@ Using a connection Basic authentication -.................... +-------------------- .. automodule:: websockets.auth @@ -110,55 +107,7 @@ websockets supports HTTP Basic Authentication according to .. automethod:: check_credentials -.. currentmodule:: websockets.server - -Sans-I/O --------- - -.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) - - .. automethod:: receive_data - - .. automethod:: receive_eof - - .. automethod:: accept - - .. automethod:: select_subprotocol - - .. automethod:: reject - - .. automethod:: send_response - - .. automethod:: send_continuation - - .. automethod:: send_text - - .. automethod:: send_binary - - .. automethod:: send_close - - .. automethod:: send_ping - - .. automethod:: send_pong - - .. automethod:: fail - - .. automethod:: events_received - - .. automethod:: data_to_send - - .. automethod:: close_expected - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: state - - .. autoattribute:: handshake_exc - - .. autoproperty:: close_code - - .. autoproperty:: close_reason +Broadcast +--------- - .. autoproperty:: close_exc +.. autofunction:: websockets.broadcast diff --git a/docs/reference/utilities.rst b/docs/reference/datastructures.rst similarity index 90% rename from docs/reference/utilities.rst rename to docs/reference/datastructures.rst index 6b5d402fc..8217052d1 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/datastructures.rst @@ -1,10 +1,5 @@ -Utilities -========= - -Broadcast ---------- - -.. autofunction:: websockets.broadcast +Data structures +=============== WebSocket events ---------------- diff --git a/docs/reference/index.rst b/docs/reference/index.rst index f164dde91..3b708ef91 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -3,66 +3,91 @@ API reference .. currentmodule:: websockets -websockets provides client and server implementations, as shown in -the :doc:`getting started guide <../intro/index>`. +:mod:`asyncio` +-------------- -The process for opening and closing a WebSocket connection depends on which -side you're implementing. +This is the default implementation. It's ideal for servers that handle many +clients concurrently. -* On the client side, connecting to a server with :func:`~client.connect` - yields a connection object that provides methods for interacting with the - connection. Your code can open a connection, then send or receive messages. +.. toctree:: + :titlesonly: + + asyncio/server + asyncio/client + asyncio/common + +`Sans-I/O`_ +----------- + +This layer is designed for integrating in third-party libraries, typically +application servers. + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +.. toctree:: + :titlesonly: - If you use :func:`~client.connect` as an asynchronous context manager, - then websockets closes the connection on exit. If not, then your code is - responsible for closing the connection. + sansio/server + sansio/client + sansio/common -* On the server side, :func:`~server.serve` starts listening for client - connections and yields an server object that you can use to shut down - the server. +Extensions +---------- - Then, when a client connects, the server initializes a connection object and - passes it to a handler coroutine, which is where your code can send or - receive messages. This pattern is called `inversion of control`_. It's - common in frameworks implementing servers. +The Per-Message Deflate extension is built in. You may also define custom +extensions. - When the handler coroutine terminates, websockets closes the connection. You - may also close it in the handler coroutine if you'd like. +.. toctree:: + :titlesonly: -.. _inversion of control: https://en.wikipedia.org/wiki/Inversion_of_control + extensions + +Shared +------ -Once the connection is open, the WebSocket protocol is symmetrical, except for -low-level details that websockets manages under the hood. The same methods -are available on client connections created with :func:`~client.connect` and -on server connections received in argument by the connection handler -of :func:`~server.serve`. +These low-level API are shared by all implementations. .. toctree:: :titlesonly: - server - client - common - utilities + datastructures exceptions types - extensions - limitations -Public API documented in the API reference are subject to the +API stability +------------- + +Public API documented in this API reference are subject to the :ref:`backwards-compatibility policy `. Anything that isn't listed in the API reference is a private API. There's no guarantees of behavior or backwards-compatibility for private APIs. +Convenience imports +------------------- + +For convenience, many public APIs can be imported directly from the +``websockets`` package. + + .. admonition:: Convenience imports are incompatible with some development tools. :class: caution - For convenience, most public APIs can be imported from the ``websockets`` - package. However, this is incompatible with static code analysis. - - It may break auto-completion and contextual documentation in IDEs, type - checking with mypy_, etc. If you're using such tools, stick to the full - import paths. + Specifically, static code analysis tools don't understand them. This breaks + auto-completion and contextual documentation in IDEs, type checking with + mypy_, etc. .. _mypy: https://github.com/python/mypy + + If you're using such tools, stick to the full import paths, as explained in + this FAQ: :ref:`real-import-paths` + +Limitations +----------- + +There are a few known limitations in the current API. + +.. toctree:: + :titlesonly: + + limitations diff --git a/docs/reference/sansio/client.rst b/docs/reference/sansio/client.rst new file mode 100644 index 000000000..09bafc745 --- /dev/null +++ b/docs/reference/sansio/client.rst @@ -0,0 +1,58 @@ +Client (`Sans-I/O`_) +==================== + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +.. currentmodule:: websockets.client + +.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: connect + + .. automethod:: send_request + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + WebSocket protocol objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: handshake_exc + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + .. autoproperty:: close_exc diff --git a/docs/reference/sansio/common.rst b/docs/reference/sansio/common.rst new file mode 100644 index 000000000..2678c1361 --- /dev/null +++ b/docs/reference/sansio/common.rst @@ -0,0 +1,62 @@ +Both sides (`Sans-I/O`_) +========================= + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +.. automodule:: websockets.protocol + +.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: state + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + .. autoproperty:: close_exc + +.. autoclass:: Side + + .. autoattribute:: SERVER + + .. autoattribute:: CLIENT + +.. autoclass:: State + + .. autoattribute:: CONNECTING + + .. autoattribute:: OPEN + + .. autoattribute:: CLOSING + + .. autoattribute:: CLOSED + +.. autodata:: SEND_EOF diff --git a/docs/reference/sansio/server.rst b/docs/reference/sansio/server.rst new file mode 100644 index 000000000..d70df6277 --- /dev/null +++ b/docs/reference/sansio/server.rst @@ -0,0 +1,62 @@ +Server (`Sans-I/O`_) +==================== + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +.. currentmodule:: websockets.server + +.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: accept + + .. automethod:: select_subprotocol + + .. automethod:: reject + + .. automethod:: send_response + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + WebSocket protocol objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: handshake_exc + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + .. autoproperty:: close_exc diff --git a/src/websockets/client.py b/src/websockets/client.py index bfc8080a7..df4f067d3 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -63,7 +63,7 @@ class ClientProtocol(Protocol): :obj:`None` to disable the limit. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. """ diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 4981ac9bd..1e59a56c9 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -385,10 +385,10 @@ class Connect: be set to a wrapper or a subclass to customize connection handling. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. compression: shortcut that enables the "permessage-deflate" extension by default; may be set to :obj:`None` to disable compression; - see the :doc:`compression guide <../topics/compression>` for details. + see the :doc:`compression guide <../../topics/compression>` for details. origin: value of the ``Origin`` header. This is useful when connecting to a server that validates the ``Origin`` header to defend against Cross-Site WebSocket Hijacking attacks. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index f6c419c3a..374aa476d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -98,7 +98,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds for clients and ``4 * close_timeout`` for servers. - See the discussion of :doc:`timeouts <../topics/timeouts>` for details. + See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. ``close_timeout`` needs to be a parameter of the protocol because websockets usually calls :meth:`close` implicitly upon exit: @@ -141,12 +141,12 @@ class WebSocketCommonProtocol(asyncio.Protocol): The default value is 64 KiB, equal to asyncio's default (based on the current implementation of ``FlowControlMixin``). - See the discussion of :doc:`memory usage <../topics/memory>` for details. + See the discussion of :doc:`memory usage <../../topics/memory>` for details. Args: logger: logger for this connection; defaults to ``logging.getLogger("websockets.protocol")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. ping_interval: delay between keepalive pings in seconds; :obj:`None` to disable keepalive pings. ping_timeout: timeout for keepalive pings in seconds; diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 6f8833a88..72f0a8203 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -664,7 +664,7 @@ class WebSocketServer: Args: logger: logger for this server; defaults to ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. """ @@ -934,10 +934,10 @@ class Serve: be set to a wrapper or a subclass to customize connection handling. logger: logger for this server; defaults to ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. compression: shortcut that enables the "permessage-deflate" extension by default; may be set to :obj:`None` to disable compression; - see the :doc:`compression guide <../topics/compression>` for details. + see the :doc:`compression guide <../../topics/compression>` for details. origins: acceptable values of the ``Origin`` header; include :obj:`None` in the list if the lack of an origin is acceptable. This is useful for defending against Cross-Site WebSocket diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index e5e8826f6..7bfa96f8b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -79,7 +79,7 @@ class Protocol: logger: logger for this connection; depending on ``side``, defaults to ``logging.getLogger("websockets.client")`` or ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. """ diff --git a/src/websockets/server.py b/src/websockets/server.py index 5c73d7e07..9415a99a8 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -67,7 +67,7 @@ class ServerProtocol(Protocol): :obj:`None` to disable the limit. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. """ From 4aa91dc90ccd743209b02e308e498ce0f57616f6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:19:50 +0100 Subject: [PATCH 1158/1539] Upgrade to black 2023 style. Changes stem from https://github.com/psf/black/pull/3035. --- src/websockets/client.py | 4 ---- src/websockets/extensions/permessage_deflate.py | 1 - src/websockets/legacy/client.py | 4 ---- src/websockets/legacy/framing.py | 1 - src/websockets/legacy/protocol.py | 2 -- src/websockets/legacy/server.py | 5 ----- src/websockets/server.py | 3 --- tests/legacy/test_auth.py | 1 - tests/legacy/test_client_server.py | 4 ---- 9 files changed, 25 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index df4f067d3..b5f871571 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -222,7 +222,6 @@ def process_extensions(self, headers: Headers) -> List[Extension]: extensions = headers.get_all("Sec-WebSocket-Extensions") if extensions: - if self.available_extensions is None: raise InvalidHandshake("no extensions supported") @@ -231,9 +230,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: ) for name, response_params in parsed_extensions: - for extension_factory in self.available_extensions: - # Skip non-matching extensions based on their name. if extension_factory.name != name: continue @@ -280,7 +277,6 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: subprotocols = headers.get_all("Sec-WebSocket-Protocol") if subprotocols: - if self.available_subprotocols is None: raise InvalidHandshake("no subprotocols supported") diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index e0de5e8f8..b391837c6 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -211,7 +211,6 @@ def _extract_parameters( client_max_window_bits: Optional[Union[int, bool]] = None for name, value in params: - if name == "server_no_context_takeover": if server_no_context_takeover: raise exceptions.DuplicateParameter(name) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 1e59a56c9..f8876f59d 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -188,7 +188,6 @@ def process_extensions( header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values: - if available_extensions is None: raise InvalidHandshake("no extensions supported") @@ -197,9 +196,7 @@ def process_extensions( ) for name, response_params in parsed_header_values: - for extension_factory in available_extensions: - # Skip non-matching extensions based on their name. if extension_factory.name != name: continue @@ -245,7 +242,6 @@ def process_subprotocol( header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values: - if available_subprotocols is None: raise InvalidHandshake("no subprotocols supported") diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 04cddc0e0..29864b136 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -14,7 +14,6 @@ class Frame(NamedTuple): - fin: bool opcode: frames.Opcode data: bytes diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 374aa476d..d31ec19a8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -652,7 +652,6 @@ async def send( # Fragmented message -- regular iterator. elif isinstance(message, Iterable): - # Work around https://github.com/python/mypy/issues/6227 message = cast(Iterable[Data], message) @@ -1519,7 +1518,6 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self.connection_lost_waiter.set_result(None) if True: # pragma: no cover - # Copied from asyncio.StreamReaderProtocol if self.reader is not None: if exc is None: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 72f0a8203..048a270b5 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -160,7 +160,6 @@ async def handler(self) -> None: """ try: - try: await self.handshake( origins=self.origins, @@ -443,15 +442,12 @@ def process_extensions( header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) for name, request_params in parsed_header_values: - for ext_factory in available_extensions: - # Skip non-matching extensions based on their name. if ext_factory.name != name: continue @@ -503,7 +499,6 @@ def process_subprotocol( header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: - parsed_header_values: List[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) diff --git a/src/websockets/server.py b/src/websockets/server.py index 9415a99a8..16979e7e7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -354,15 +354,12 @@ def process_extensions( header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and self.available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) for name, request_params in parsed_header_values: - for ext_factory in self.available_extensions: - # Skip non-matching extensions based on their name. if ext_factory.name != name: continue diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index 2b670c31f..3754bcf3a 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -32,7 +32,6 @@ async def check_credentials(self, username, password): class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): - create_protocol = basic_auth_protocol_factory( realm="auth-tests", credentials=("hello", "iloveyou") ) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 4a4510536..b05d40721 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -212,7 +212,6 @@ class BarClientProtocol(WebSocketClientProtocol): class ClientServerTestsMixin: - secure = False def setUp(self): @@ -309,7 +308,6 @@ def make_http_request(self, path="/", headers=None): class SecureClientServerTestsMixin(ClientServerTestsMixin): - secure = True @property @@ -1299,7 +1297,6 @@ class ClientServerTests( class SecureClientServerTests( CommonClientServerTests, SecureClientServerTestsMixin, AsyncioTestCase ): - # The implementation of this test makes it hard to run it over TLS. test_client_connect_canceled_during_handshake = None @@ -1462,7 +1459,6 @@ async def run_server(path): class AsyncIteratorTests(ClientServerTestsMixin, AsyncioTestCase): - # This is a protocol-level feature, but since it's a high-level API, it is # much easier to exercise at the client or server level. From 1418917e54f56e557b3975a171aada842db8f2e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:59:24 +0100 Subject: [PATCH 1159/1539] Environment variables are always strings. --- tests/legacy/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 302fd70be..6195849bd 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -88,7 +88,7 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) +MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) # asyncio's debug mode has a 10x performance penalty for this test suite. if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover From 8e1628a14e0dd2ca98871c7500484b5d42d16b67 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 20:12:35 +0100 Subject: [PATCH 1160/1539] Ignore files from direnv. --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 8f9e7dc51..324e77069 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ *.pyc *.so .coverage +.direnv +.envrc .idea/ .mypy_cache .tox From ba1ed7a65cc876ff4e0fcd4dd4711402836475e2 Mon Sep 17 00:00:00 2001 From: Sasja Date: Tue, 28 Feb 2023 12:47:07 +0800 Subject: [PATCH 1161/1539] fix small docs typo --- docs/topics/broadcast.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 6c7ced8b0..9a25cbf7d 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -344,5 +344,5 @@ All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower than the built-in :func:`broadcast` function. -There is no major difference between the performance of per-message queues and +There is no major difference between the performance of per-client queues and publish–subscribe. From 17ffb5c9777f21d35ada85155db8c70be490cc3b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 28 Mar 2023 08:12:27 +0200 Subject: [PATCH 1162/1539] Disable social cards in docs. --- docs/conf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index fe6282b5a..58f0c2d55 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,6 +133,11 @@ def linkcode_resolve(domain, info): return f"{code_url}/{file}#L{start}-L{end}" +# Configure opengraph extension + +# Social cards don't support the SVG logo. Also, the text preview looks bad. +ogp_social_cards = {"enable": False} + # -- Options for HTML output ------------------------------------------------- From 3c96411902d34f87bb8e6515387df7fbafc86869 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 18:04:00 +0100 Subject: [PATCH 1163/1539] Move MS constant out of legacy directory. --- tests/legacy/test_client_server.py | 3 ++- tests/legacy/test_protocol.py | 3 ++- tests/legacy/utils.py | 19 ------------------- tests/utils.py | 19 +++++++++++++++++++ 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index b05d40721..752f94270 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -42,7 +42,8 @@ NoOpExtension, ServerNoOpExtensionFactory, ) -from .utils import MS, AsyncioTestCase +from ..utils import MS +from .utils import AsyncioTestCase # Generate TLS certificate with: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index e85402a39..328bc80a2 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -19,7 +19,8 @@ from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast from websockets.protocol import State -from .utils import MS, AsyncioTestCase +from ..utils import MS +from .utils import AsyncioTestCase async def async_iterable(iterable): diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 6195849bd..bb4eebb52 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -2,9 +2,6 @@ import contextlib import functools import logging -import os -import platform -import time import unittest @@ -84,19 +81,3 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): set(str(recorded.message) for recorded in recorded_warnings), set(expected_warnings), ) - - -# Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) - -# asyncio's debug mode has a 10x performance penalty for this test suite. -if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover - MS *= 10 - -# PyPy has a performance penalty for this test suite. -if platform.python_implementation() == "PyPy": # pragma: no cover - MS *= 5 - -# Ensure that timeouts are larger than the clock's resolution (for Windows). -MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) diff --git a/tests/utils.py b/tests/utils.py index 92c754810..5331746f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,8 @@ import contextlib import email.utils +import os +import platform +import time import unittest import warnings @@ -7,6 +10,22 @@ DATE = email.utils.formatdate(usegmt=True) +# Unit for timeouts. May be increased on slow machines by setting the +# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. +MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) + +# PyPy has a performance penalty for this test suite. +if platform.python_implementation() == "PyPy": # pragma: no cover + MS *= 5 + +# asyncio's debug mode has a 10x performance penalty for this test suite. +if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover + MS *= 10 + +# Ensure that timeouts are larger than the clock's resolution (for Windows). +MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) + + class GeneratorTestCase(unittest.TestCase): """ Base class for testing generator-based coroutines. From 445cdf9b05d7f1e2ac9ff9dd5e254cbc57ab6ced Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:54:23 +0100 Subject: [PATCH 1164/1539] Move test certificate out of legacy directory. Also add a second domain name for tests. --- tests/legacy/test_client_server.py | 15 ++---- tests/test_localhost.cnf | 5 +- tests/test_localhost.pem | 84 +++++++++++++++--------------- tests/utils.py | 10 ++++ 4 files changed, 58 insertions(+), 56 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 752f94270..72f0b021b 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -42,19 +42,10 @@ NoOpExtension, ServerNoOpExtensionFactory, ) -from ..utils import MS +from ..utils import CERTIFICATE, MS from .utils import AsyncioTestCase -# Generate TLS certificate with: -# $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ -# -out test_localhost.crt -keyout test_localhost.key -# $ cat test_localhost.key test_localhost.crt > test_localhost.pem -# $ rm test_localhost.key test_localhost.crt - -testcert = bytes(pathlib.Path(__file__).parent.with_name("test_localhost.pem")) - - async def default_handler(ws): if ws.path == "/deprecated_attributes": await ws.recv() # delay that allows catching warnings @@ -314,13 +305,13 @@ class SecureClientServerTestsMixin(ClientServerTestsMixin): @property def server_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(testcert) + ssl_context.load_cert_chain(CERTIFICATE) return ssl_context @property def client_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(testcert) + ssl_context.load_verify_locations(CERTIFICATE) return ssl_context def start_server(self, **kwargs): diff --git a/tests/test_localhost.cnf b/tests/test_localhost.cnf index 6dc331ac6..4069e3967 100644 --- a/tests/test_localhost.cnf +++ b/tests/test_localhost.cnf @@ -22,5 +22,6 @@ subjectAltName = @san [ san ] DNS.1 = localhost -IP.2 = 127.0.0.1 -IP.3 = ::1 +DNS.2 = overridden +IP.3 = 127.0.0.1 +IP.4 = ::1 diff --git a/tests/test_localhost.pem b/tests/test_localhost.pem index b8a9ea9ab..8df63ec8f 100644 --- a/tests/test_localhost.pem +++ b/tests/test_localhost.pem @@ -1,48 +1,48 @@ -----BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCUgrQVkNbAWRlo -zZUj14Ufz7YEp2MXmvmhdlfOGLwjy+xPO98aJRv5/nYF2eWM3llcmLe8FbBSK+QF -To4su7ZVnc6qITOHqcSDUw06WarQUMs94bhHUvQp1u8+b2hNiMeGw6+QiBI6OJRO -iGpLRbkN6Uj3AKwi8SYVoLyMiztuwbNyGf8fF3DDpHZtBitGtMSBCMsQsfB465pl -2UoyBrWa2lsbLt3VvBZZvHqfEuPjpjjKN5USIXnaf0NizaR6ps3EyfftWy4i7zIQ -N5uTExvaPDyPn9nH3q/dkT99mSMSU1AvTTpX8PN7DlqE6wZMbQsBPRGW7GElQ+Ox -IKdKOLk5AgMBAAECggEAd3kqzQqnaTiEs4ZoC9yPUUc1pErQ8iWP27Ar9TZ67MVa -B2ggFJV0C0sFwbFI9WnPNCn77gj4vzJmD0riH+SnS/tXThDFtscBu7BtvNp0C4Bj -8RWMvXxjxuENuQnBPFbkRWtZ6wk8uK/Zx9AAyyt9M07Qjz1wPfAIdm/IH7zHBFMA -gsqjnkLh1r0FvjNEbLiuGqYU/GVxaZYd+xy+JU52IxjHUUL9yD0BPWb+Szar6AM2 -gUpmTX6+BcCZwwZ//DzCoWYZ9JbP8akn6edBeZyuMPqYgLzZkPyQ+hRW46VPPw89 -yg4LR9nzgQiBHlac0laB4NrWa+d9QRRLitl1O3gVAQKBgQDDkptxXu7w9Lpc+HeE -N/pJfpCzUuF7ZC4vatdoDzvfB5Ky6W88Poq+I7bB9m7StXdFAbDyUBxvisjTBMVA -OtYqpAk/rhX8MjSAtjoFe2nH+eEiQriuZmtA5CdKEXS4hNbc/HhEPWhk7Zh8OV5v -y7l4r6l4UHqaN9QyE0vlFdmcmQKBgQDCZZR/trJ2/g2OquaS+Zd2h/3NXw0NBq4z -4OBEWqNa/R35jdK6WlWJH7+tKOacr+xtswLpPeZHGwMdk64/erbYWBuJWAjpH72J -DM9+1H5fFHANWpWTNn94enQxwfzZRvdkxq4IWzGhesptYnHIzoAmaqC3lbn/e3u0 -Flng32hFoQKBgQCF3D4K3hib0lYQtnxPgmUMktWF+A+fflViXTWs4uhu4mcVkFNz -n7clJ5q6reryzAQjtmGfqRedfRex340HRn46V2aBMK2Znd9zzcZu5CbmGnFvGs3/ -iNiWZNNDjike9sV+IkxLIODoW/vH4xhxWrbLFSjg0ezoy5ew4qZK2abF2QKBgQC5 -M5efeQpbjTyTUERtf/aKCZOGZmkDoPq0GCjxVjzNQdqd1z0NJ2TYR/QP36idXIlu -FZ7PYZaS5aw5MGpQtfOe94n8dm++0et7t0WzunRO1yTNxCA+aSxWNquegAcJZa/q -RdKlyWPmSRqzzZdDzWCPuQQ3AyF5wkYfUy/7qjwoIQKBgB2v96BV7+lICviIKzzb -1o3A3VzAX5MGd98uLGjlK4qsBC+s7mk2eQztiNZgbA0W6fhQ5Dz3HcXJ5ppy8Okc -jeAktrNRzz15hvi/XkWdO+VMqiHW4l+sWYukjhCyod1oO1KGHq0LYYvv076syxGw -vRKLq7IJ4WIp1VtfaBlrIogq +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDYOOQyq8yYtn5x +K3yRborFxTFse16JIVb4x/ZhZgGm49eARCi09fmczQxJdQpHz81Ij6z0xi7AUYH7 +9wS8T0Lh3uGFDDS1GzITUVPIqSUi0xim2T6XPzXFVQYI1D/OjUxlHm+3/up+WwbL +sBgBO/lDmzoa3ZN7kt9HQoGc/14oQz1Qsv1QTDQs69r+o7mmBJr/hf/g7S0Csyy3 +iC6aaq+yCUyzDbjXceTI7WJqbTGNnK0/DjdFD/SJS/uSDNEg0AH53eqcCSjm+Ei/ +UF8qR5Pu4sSsNwToOW2MVgjtHFazc+kG3rzD6+3Dp+t6x6uI/npyuudOMCmOtd6z +kX0UPQaNAgMBAAECggEAS4eMBztGC+5rusKTEAZKSY15l0h9HG/d/qdzJFDKsO6T +/8VPZu8pk6F48kwFHFK1hexSYWq9OAcA3fBK4jDZzybZJm2+F6l5U5AsMUMMqt6M +lPP8Tj8RXG433muuIkvvbL82DVLpvNu1Qv+vUvcNOpWFtY7DDv6eKjlMJ3h4/pzh +89MNt26VMCYOlq1NSjuZBzFohL2u9nsFehlOpcVsqNfNfcYCq9+5yoH8fWJP90Op +hqhvqUoGLN7DRKV1f+AWHSA4nmGgvVviV5PQgMhtk5exlN7kG+rDc3LbzhefS1Sp +Tat1qIgm8fK2n+Q/obQPjHOGOGuvE5cIF7E275ZKgQKBgQDt87BqALKWnbkbQnb7 +GS1h6LRcKyZhFbxnO2qbviBWSo15LEF8jPGV33Dj+T56hqufa/rUkbZiUbIR9yOX +dnOwpAVTo+ObAwZfGfHvrnufiIbHFqJBumaYLqjRZ7AC0QtS3G+kjS9dbllrr7ok +fO4JdfKRXzBJKrkQdCn8hR22rQKBgQDon0b49Dxs1EfdSDbDode2TSwE83fI3vmR +SKUkNY8ma6CRbomVRWijhBM458wJeuhpjPZOvjNMsnDzGwrtdAp2VfFlMIDnA8ZC +fEWIAAH2QYKXKGmkoXOcWB2QbvbI154zCm6zFGtzvRKOCGmTXuhFajO8VPwOyJVt +aSJA3bLrYQKBgQDJM2/tAfAAKRdW9GlUwqI8Ep9G+/l0yANJqtTnIemH7XwYhJJO +9YJlPszfB2aMBgliQNSUHy1/jyKpzDYdITyLlPUoFwEilnkxuud2yiuf5rpH51yF +hU6wyWtXvXv3tbkEdH42PmdZcjBMPQeBSN2hxEi6ISncBDL9tau26PwJ9QKBgQCs +cNYl2reoXTzgtpWSNDk6NL769JjJWTFcF6QD0YhKjOI8rNpkw00sWc3+EybXqDr9 +c7dq6+gPZQAB1vwkxi6zRkZqIqiLl+qygnjwtkC+EhYCg7y8g8q2DUPtO7TJcb0e +TQ9+xRZad8B3dZj93A8G1hF//OfU9bB/qL3xo+bsQQKBgC/9YJvgLIWA/UziLcB2 +29Ai0nbPkN5df7z4PifUHHSlbQJHKak8UKbMP+8S064Ul0F7g8UCjZMk2LzSbaNY +XU5+2j0sIOnGUFoSlvcpdowzYrD2LN5PkKBot7AOq/v7HlcOoR8J8RGWAMpCrHsI +a/u/dlZs+/K16RcavQwx8rag -----END PRIVATE KEY----- -----BEGIN CERTIFICATE----- -MIIDTTCCAjWgAwIBAgIJAJ6VG2cQlsepMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV +MIIDWTCCAkGgAwIBAgIJAOL9UKiOOxupMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp -bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTE4MDUwNTE2NTc1NloYDzIwNjAwNTA0 -MTY1NzU2WjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM +bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTIyMTAxNTE5Mjg0MVoYDzIwNjQxMDE0 +MTkyODQxWjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI -hvcNAQEBBQADggEPADCCAQoCggEBAJSCtBWQ1sBZGWjNlSPXhR/PtgSnYxea+aF2 -V84YvCPL7E873xolG/n+dgXZ5YzeWVyYt7wVsFIr5AVOjiy7tlWdzqohM4epxINT -DTpZqtBQyz3huEdS9CnW7z5vaE2Ix4bDr5CIEjo4lE6IaktFuQ3pSPcArCLxJhWg -vIyLO27Bs3IZ/x8XcMOkdm0GK0a0xIEIyxCx8HjrmmXZSjIGtZraWxsu3dW8Flm8 -ep8S4+OmOMo3lRIhedp/Q2LNpHqmzcTJ9+1bLiLvMhA3m5MTG9o8PI+f2cfer92R -P32ZIxJTUC9NOlfw83sOWoTrBkxtCwE9EZbsYSVD47Egp0o4uTkCAwEAAaMwMC4w -LAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0G -CSqGSIb3DQEBCwUAA4IBAQA0imKp/rflfbDCCx78NdsR5rt0jKem2t3YPGT6tbeU -+FQz62SEdeD2OHWxpvfPf+6h3iTXJbkakr2R4lP3z7GHUe61lt3So9VHAvgbtPTH -aB1gOdThA83o0fzQtnIv67jCvE9gwPQInViZLEcm2iQEZLj6AuSvBKmluTR7vNRj -8/f2R4LsDfCWGrzk2W+deGRvSow7irS88NQ8BW8S8otgMiBx4D2UlOmQwqr6X+/r -jYIDuMb6GDKRXtBUGDokfE94hjj9u2mrNRwt8y4tqu8ZNa//yLEQ0Ow2kP3QJPLY -941VZpwRi2v/+JvI7OBYlvbOTFwM8nAk79k+Dgviygd9 +hvcNAQEBBQADggEPADCCAQoCggEBANg45DKrzJi2fnErfJFuisXFMWx7XokhVvjH +9mFmAabj14BEKLT1+ZzNDEl1CkfPzUiPrPTGLsBRgfv3BLxPQuHe4YUMNLUbMhNR +U8ipJSLTGKbZPpc/NcVVBgjUP86NTGUeb7f+6n5bBsuwGAE7+UObOhrdk3uS30dC +gZz/XihDPVCy/VBMNCzr2v6juaYEmv+F/+DtLQKzLLeILppqr7IJTLMNuNdx5Mjt +YmptMY2crT8ON0UP9IlL+5IM0SDQAfnd6pwJKOb4SL9QXypHk+7ixKw3BOg5bYxW +CO0cVrNz6QbevMPr7cOn63rHq4j+enK6504wKY613rORfRQ9Bo0CAwEAAaM8MDow +OAYDVR0RBDEwL4IJbG9jYWxob3N0ggpvdmVycmlkZGVuhwR/AAABhxAAAAAAAAAA +AAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQBPNDGDdl4wsCRlDuyCHBC8o+vW +Vb14thUw9Z6UrlsQRXLONxHOXbNAj1sYQACNwIWuNz36HXu5m8Xw/ID/bOhnIg+b +Y6l/JU/kZQYB7SV1aR3ZdbCK0gjfkE0POBHuKOjUFIOPBCtJ4tIBUX94zlgJrR9v +2rqJC3TIYrR7pVQumHZsI5GZEMpM5NxfreWwxcgltgxmGdm7elcizHfz7k5+szwh +4eZ/rxK9bw1q8BIvVBWelRvUR55mIrCjzfZp5ZObSYQTZlW7PzXBe5Jk+1w31YHM +RSBA2EpPhYlGNqPidi7bg7rnQcsc6+hE0OqzTL/hWxPm9Vbp9dj3HFTik1wa -----END CERTIFICATE----- diff --git a/tests/utils.py b/tests/utils.py index 5331746f8..afc1a460a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,12 +1,22 @@ import contextlib import email.utils import os +import pathlib import platform import time import unittest import warnings +# Generate TLS certificate with: +# $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ +# -out test_localhost.crt -keyout test_localhost.key +# $ cat test_localhost.key test_localhost.crt > test_localhost.pem +# $ rm test_localhost.key test_localhost.crt + +CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) + + DATE = email.utils.formatdate(usegmt=True) From 7b92fa02d88b6d6e807a653329b85127ce79d5e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 18:24:42 +0100 Subject: [PATCH 1165/1539] Add temp_unix_socket_path context manager for tests. --- tests/legacy/test_client_server.py | 11 +++-------- tests/utils.py | 7 +++++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 72f0b021b..d92338585 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -3,13 +3,11 @@ import functools import http import logging -import pathlib import platform import random import socket import ssl import sys -import tempfile import unittest import unittest.mock import urllib.error @@ -42,7 +40,7 @@ NoOpExtension, ServerNoOpExtensionFactory, ) -from ..utils import CERTIFICATE, MS +from ..utils import CERTIFICATE, MS, temp_unix_socket_path from .utils import AsyncioTestCase @@ -447,9 +445,7 @@ def send(self, *args, **kwargs): @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") def test_unix_socket(self): - with tempfile.TemporaryDirectory() as temp_dir: - path = bytes(pathlib.Path(temp_dir) / "websockets") - + with temp_unix_socket_path() as path: # Like self.start_server() but with unix_serve(). async def start_server(): return await unix_serve(default_handler, path) @@ -1445,8 +1441,7 @@ async def run_server(path): # Check that exiting the context manager closed the server. self.assertFalse(server.sockets) - with tempfile.TemporaryDirectory() as temp_dir: - path = bytes(pathlib.Path(temp_dir) / "websockets") + with temp_unix_socket_path() as path: self.loop.run_until_complete(run_server(path)) diff --git a/tests/utils.py b/tests/utils.py index afc1a460a..4e9ac9f0e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import os import pathlib import platform +import tempfile import time import unittest import warnings @@ -79,3 +80,9 @@ def assertDeprecationWarning(self, message): warning = recorded_warnings[0] self.assertEqual(warning.category, DeprecationWarning) self.assertEqual(str(warning.message), message) + + +@contextlib.contextmanager +def temp_unix_socket_path(): + with tempfile.TemporaryDirectory() as temp_dir: + yield str(pathlib.Path(temp_dir) / "websockets") From d9c8694a5cd41287a33f804419c75937420e251b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 18:06:32 +0100 Subject: [PATCH 1166/1539] Add websockets.sync package. --- setup.py | 7 ++++++- src/websockets/sync/__init__.py | 0 tests/sync/__init__.py | 0 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 src/websockets/sync/__init__.py create mode 100644 tests/sync/__init__.py diff --git a/setup.py b/setup.py index 564ada85c..5ed472503 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,12 @@ exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) -packages = ["websockets", "websockets/legacy", "websockets/extensions"] +packages = [ + "websockets", + "websockets/extensions", + "websockets/legacy", + "websockets/sync", +] ext_modules = [ setuptools.Extension( diff --git a/src/websockets/sync/__init__.py b/src/websockets/sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/sync/__init__.py b/tests/sync/__init__.py new file mode 100644 index 000000000..e69de29bb From 2d624fa36ae66fc818e075726f0a4aff683a4ef3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 18:06:50 +0100 Subject: [PATCH 1167/1539] Add thread-safe message reassembler. --- src/websockets/sync/messages.py | 281 +++++++++++++++++++ tests/sync/test_messages.py | 479 ++++++++++++++++++++++++++++++++ tests/sync/utils.py | 26 ++ 3 files changed, 786 insertions(+) create mode 100644 src/websockets/sync/messages.py create mode 100644 tests/sync/test_messages.py create mode 100644 tests/sync/utils.py diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py new file mode 100644 index 000000000..9265e4e9f --- /dev/null +++ b/src/websockets/sync/messages.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import codecs +import queue +import threading +from typing import Iterator, List, Optional, cast + +from ..frames import Frame, Opcode +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class Assembler: + """ + Assemble messages from frames. + + """ + + def __init__(self) -> None: + # Serialize reads and writes -- except for reads via synchronization + # primitives provided by the threading and queue modules. + self.mutex = threading.Lock() + + # We create a latch with two events to ensure proper interleaving of + # writing and reading messages. + # put() sets this event to tell get() that a message can be fetched. + self.message_complete = threading.Event() + # get() sets this event to let put() that the message was fetched. + self.message_fetched = threading.Event() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + # This flag prevents concurrent calls to put() by library code. + self.put_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder: Optional[codecs.IncrementalDecoder] = None + + # Buffer of frames belonging to the same message. + self.chunks: List[Data] = [] + + # When switching from "buffering" to "streaming", we use a thread-safe + # queue for transferring frames from the writing thread (library code) + # to the reading thread (user code). We're buffering when chunks_queue + # is None and streaming when it's a SimpleQueue. None is a sentinel + # value marking the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + # Remove quotes around type when dropping Python < 3.9. + self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None + + # This flag marks the end of the stream. + self.closed = False + + def get(self, timeout: Optional[float] = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + timeout: if a timeout is provided and elapses before a complete + message is received, :meth:`get` raises :exc:`TimeoutError`. + + Raises: + EOFError: if the stream of frames has ended. + RuntimeError: if two threads run :meth:`get` or :meth:``get_iter` + concurrently. + + """ + with self.mutex: + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get or get_iter is already running") + + self.get_in_progress = True + + # If the message_complete event isn't set yet, release the lock to + # allow put() to run and eventually set it. + # Locking with get_in_progress ensures only one thread can get here. + completed = self.message_complete.wait(timeout) + + with self.mutex: + self.get_in_progress = False + + # Waiting for a complete message timed out. + if not completed: + raise TimeoutError(f"timed out in {timeout:.1f}s") + + # get() was unblocked by close() rather than put(). + if self.closed: + raise EOFError("stream of frames ended") + + assert self.message_complete.is_set() + self.message_complete.clear() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + + assert not self.message_fetched.is_set() + self.message_fetched.set() + + self.chunks = [] + assert self.chunks_queue is None + + return message + + def get_iter(self) -> Iterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` yields a :class:`str` or + :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`RuntimeError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Raises: + EOFError: if the stream of frames has ended. + RuntimeError: if two threads run :meth:`get` or :meth:``get_iter` + concurrently. + + """ + with self.mutex: + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get or get_iter is already running") + + chunks = self.chunks + self.chunks = [] + self.chunks_queue = cast( + # Remove quotes around type when dropping Python < 3.9. + "queue.SimpleQueue[Optional[Data]]", + queue.SimpleQueue(), + ) + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.is_set(): + self.chunks_queue.put(None) + + self.get_in_progress = True + + # Locking with get_in_progress ensures only one thread can get here. + yield from chunks + while True: + chunk = self.chunks_queue.get() + if chunk is None: + break + yield chunk + + with self.mutex: + self.get_in_progress = False + + assert self.message_complete.is_set() + self.message_complete.clear() + + # get_iter() was unblocked by close() rather than put(). + if self.closed: + raise EOFError("stream of frames ended") + + assert not self.message_fetched.is_set() + self.message_fetched.set() + + assert self.chunks == [] + self.chunks_queue = None + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + When ``frame`` is the final frame in a message, :meth:`put` waits until + the message is fetched, either by calling :meth:`get` or by fully + consuming the return value of :meth:`get_iter`. + + :meth:`put` assumes that the stream of frames respects the protocol. If + it doesn't, the behavior is undefined. + + Raises: + EOFError: if the stream of frames has ended. + RuntimeError: if two threads run :meth:`put` concurrently. + + """ + with self.mutex: + if self.closed: + raise EOFError("stream of frames ended") + + if self.put_in_progress: + raise RuntimeError("put is already running") + + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + + if self.chunks_queue is None: + self.chunks.append(data) + else: + self.chunks_queue.put(data) + + if not frame.fin: + return + + # Message is complete. Wait until it's fetched to return. + + assert not self.message_complete.is_set() + self.message_complete.set() + + if self.chunks_queue is not None: + self.chunks_queue.put(None) + + assert not self.message_fetched.is_set() + + self.put_in_progress = True + + # Release the lock to allow get() to run and eventually set the event. + self.message_fetched.wait() + + with self.mutex: + self.put_in_progress = False + + assert self.message_fetched.is_set() + self.message_fetched.clear() + + # put() was unblocked by close() rather than get() or get_iter(). + if self.closed: + raise EOFError("stream of frames ended") + + self.decoder = None + + def close(self) -> None: + """ + End the stream of frames. + + Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + with self.mutex: + if self.closed: + return + + self.closed = True + + # Unblock get or get_iter. + if self.get_in_progress: + self.message_complete.set() + if self.chunks_queue is not None: + self.chunks_queue.put(None) + + # Unblock put(). + if self.put_in_progress: + self.message_fetched.set() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py new file mode 100644 index 000000000..069da784b --- /dev/null +++ b/tests/sync/test_messages.py @@ -0,0 +1,479 @@ +import time + +from websockets.frames import OP_BINARY, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame +from websockets.sync.messages import * + +from ..utils import MS +from .utils import ThreadTestCase + + +class AssemblerTests(ThreadTestCase): + """ + Tests in this class interact a lot with hidden synchronization mechanisms: + + - get() / get_iter() and put() must run in separate threads when a final + frame is set because put() waits for get() / get_iter() to fetch the + message before returning. + + - run_in_thread() lets its target run before yielding back control on entry, + which guarantees the intended execution order of test cases. + + - run_in_thread() waits for its target to finish running before yielding + back control on exit, which allows making assertions immediately. + + - When the main thread performs actions that let another thread progress, it + must wait before making assertions, to avoid depending on scheduling. + + """ + + def setUp(self): + self.assembler = Assembler() + + def tearDown(self): + """ + Ensure the assembler goes back to its default state after each test. + + This removes the need for testing various sequences. + + """ + self.assertFalse(self.assembler.mutex.locked()) + self.assertFalse(self.assembler.get_in_progress) + self.assertFalse(self.assembler.put_in_progress) + if not self.assembler.closed: + self.assertFalse(self.assembler.message_complete.is_set()) + self.assertFalse(self.assembler.message_fetched.is_set()) + self.assertIsNone(self.assembler.decoder) + self.assertEqual(self.assembler.chunks, []) + self.assertIsNone(self.assembler.chunks_queue) + + # Test get + + def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, "café") + + def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"tea")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, b"tea") + + def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + self.assertEqual(message, "café") + + def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"tea")) + + self.assertEqual(message, b"tea") + + def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, "café") + + def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, b"tea") + + def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(message, "café") + + def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(message, b"tea") + + def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(message, "café") + + def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(message, b"tea") + + # Test get_iter + + def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + fragments = list(self.assembler.get_iter()) + + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"tea")) + + with self.run_in_thread(putter): + fragments = list(self.assembler.get_iter()) + + self.assertEqual(fragments, [b"tea"]) + + def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"tea")) + + self.assertEqual(fragments, [b"tea"]) + + def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.run_in_thread(putter): + fragments = list(self.assembler.get_iter()) + + self.assertEqual(fragments, ["ca", "f", "é"]) + + def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + with self.run_in_thread(putter): + fragments = list(self.assembler.get_iter()) + + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + with self.run_in_thread(getter): + self.assertEqual(fragments, ["ca"]) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, ["ca", "f"]) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(fragments, ["ca", "f", "é"]) + + def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + with self.run_in_thread(getter): + self.assertEqual(fragments, [b"t"]) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, [b"t", b"e"]) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, ["ca"]) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, ["ca", "f"]) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(fragments, ["ca", "f", "é"]) + + def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, [b"t"]) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, [b"t", b"e"]) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + # Test timeouts + + def test_get_with_timeout_completes(self): + """get returns a message when it is received before the timeout.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + message = self.assembler.get(MS) + + self.assertEqual(message, "café") + + def test_get_with_timeout_times_out(self): + """get raises TimeoutError when no message is received before the timeout.""" + with self.assertRaises(TimeoutError): + self.assembler.get(MS) + + # Test control frames + + def test_control_frame_before_message_is_ignored(self): + """get ignores control frames between messages.""" + + def putter(): + self.assembler.put(Frame(OP_PING, b"")) + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, "café") + + def test_control_frame_in_fragmented_message_is_ignored(self): + """get ignores control frames within fragmented messages.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_PING, b"")) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_PONG, b"")) + self.assembler.put(Frame(OP_CONT, b"a")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, b"tea") + + # Test concurrency + + def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_put_fails_when_put_is_running(self): + """put cannot be called concurrently with itself.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + with self.assertRaises(RuntimeError): + self.assembler.put(Frame(OP_BINARY, b"tea")) + self.assembler.get() # unblock other thread + + # Test termination + + def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + + def closer(): + time.sleep(2 * MS) + self.assembler.close() + + with self.run_in_thread(closer): + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + + def closer(): + time.sleep(2 * MS) + self.assembler.close() + + with self.run_in_thread(closer): + with self.assertRaises(EOFError): + list(self.assembler.get_iter()) + + def test_put_fails_when_interrupted_by_close(self): + """put raises EOFError when close is called.""" + + def closer(): + time.sleep(2 * MS) + self.assembler.close() + + with self.run_in_thread(closer): + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + list(self.assembler.get_iter()) + + def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() diff --git a/tests/sync/utils.py b/tests/sync/utils.py new file mode 100644 index 000000000..8903cd349 --- /dev/null +++ b/tests/sync/utils.py @@ -0,0 +1,26 @@ +import contextlib +import threading +import time +import unittest + +from ..utils import MS + + +class ThreadTestCase(unittest.TestCase): + @contextlib.contextmanager + def run_in_thread(self, target): + """ + Run ``target`` function without arguments in a thread. + + In order to facilitate writing tests, this helper lets the thread run + for 1ms on entry and joins the thread with a 1ms timeout on exit. + + """ + thread = threading.Thread(target=target) + thread.start() + time.sleep(MS) + try: + yield + finally: + thread.join(MS) + self.assertFalse(thread.is_alive()) From 4616405fa50959eedfbe8de6df423f9305f266f8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 24 Apr 2022 08:16:02 +0200 Subject: [PATCH 1168/1539] Add deadline for managing timeouts. --- src/websockets/sync/utils.py | 47 ++++++++++++++++++++++++++++++++++++ tests/sync/test_utils.py | 33 +++++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 src/websockets/sync/utils.py create mode 100644 tests/sync/test_utils.py diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py new file mode 100644 index 000000000..8aab6c0d9 --- /dev/null +++ b/src/websockets/sync/utils.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import time +from typing import Optional + + +__all__ = ["Deadline"] + + +class Deadline: + """ + Manage timeouts across multiple steps. + + Args: + timeout: time available in seconds; :obj:`None` if there is no limit. + + """ + + def __init__(self, timeout: Optional[float]) -> None: + self.deadline: Optional[float] + if timeout is None: + self.deadline = None + else: + self.deadline = time.monotonic() + timeout + + def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: + """ + Calculate a timeout from a deadline. + + Args: + raise_if_elapsed (bool): whether to raise :exc:`TimeoutError` + if the deadline lapsed. + + Raises: + TimeoutError: if the deadline lapsed. + + Returns: + Optional[float]: Time left in seconds; + :obj:`None` if there is no limit. + + """ + if self.deadline is None: + return None + timeout = self.deadline - time.monotonic() + if raise_if_elapsed and timeout <= 0: + raise TimeoutError("timed out") + return timeout diff --git a/tests/sync/test_utils.py b/tests/sync/test_utils.py new file mode 100644 index 000000000..2980a97b4 --- /dev/null +++ b/tests/sync/test_utils.py @@ -0,0 +1,33 @@ +import unittest + +from websockets.sync.utils import * + +from ..utils import MS + + +class DeadlineTests(unittest.TestCase): + def test_timeout_pending(self): + """timeout returns remaining time if deadline is in the future.""" + deadline = Deadline(MS) + timeout = deadline.timeout() + self.assertGreater(timeout, 0) + self.assertLess(timeout, MS) + + def test_timeout_elapsed_exception(self): + """timeout raises TimeoutError if deadline is in the past.""" + deadline = Deadline(-MS) + with self.assertRaises(TimeoutError): + deadline.timeout() + + def test_timeout_elapsed_no_exception(self): + """timeout doesn't raise TimeoutError when raise_if_elapsed is disabled.""" + deadline = Deadline(-MS) + timeout = deadline.timeout(raise_if_elapsed=False) + self.assertGreater(timeout, -2 * MS) + self.assertLess(timeout, -MS) + + def test_no_timeout(self): + """timeout returns None when no deadline is set.""" + deadline = Deadline(None) + timeout = deadline.timeout() + self.assertIsNone(timeout, None) From 125ffe28d0fd1c1a80d6021a7dbe7ee69451bb33 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:55:30 +0100 Subject: [PATCH 1169/1539] Add compatibility shim for socket.create_server. --- src/websockets/sync/compatibility.py | 21 +++++++++++++++++++++ tests/maxi_cov.py | 4 +++- 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 src/websockets/sync/compatibility.py diff --git a/src/websockets/sync/compatibility.py b/src/websockets/sync/compatibility.py new file mode 100644 index 000000000..3064263e9 --- /dev/null +++ b/src/websockets/sync/compatibility.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +try: + from socket import create_server as socket_create_server +except ImportError: # pragma: no cover + import socket + + def socket_create_server(address, family=socket.AF_INET): # type: ignore + """Simplified backport of socket.create_server from Python 3.8.""" + sock = socket.socket(family, socket.SOCK_STREAM) + try: + sock.bind(address) + sock.listen() + return sock + except socket.error: + sock.close() + raise + + +__all__ = ["socket_create_server"] diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index b7c07b698..2568dcf18 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -47,6 +47,7 @@ def get_mapping(src_dir="src"): if "legacy" not in os.path.dirname(src_file) if os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" + and os.path.basename(src_file) != "compatibility.py" ] test_files = [ test_file @@ -89,8 +90,9 @@ def get_ignored_files(src_dir="src"): return [ # */websockets matches src/websockets and .tox/**/site-packages/websockets. - # There are no tests for the __main__ module. + # There are no tests for the __main__ module and for compatibility modules. "*/websockets/__main__.py", + "*/websockets/*/compatibility.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. "*/websockets/legacy/*", From 59fa2e3de50a9c46f4287ac300b7a2c42c7ae0a4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:57:42 +0100 Subject: [PATCH 1170/1539] Add thread-based implementation. --- .github/workflows/tests.yml | 3 + setup.cfg | 1 + src/websockets/server.py | 2 + src/websockets/sync/client.py | 327 +++++++++++++ src/websockets/sync/connection.py | 757 ++++++++++++++++++++++++++++++ src/websockets/sync/messages.py | 14 +- src/websockets/sync/server.py | 525 +++++++++++++++++++++ src/websockets/sync/utils.py | 9 +- tests/protocol.py | 29 ++ tests/sync/client.py | 55 +++ tests/sync/connection.py | 109 +++++ tests/sync/server.py | 67 +++ tests/sync/test_client.py | 271 +++++++++++ tests/sync/test_connection.py | 704 +++++++++++++++++++++++++++ tests/sync/test_messages.py | 2 +- tests/sync/test_server.py | 389 +++++++++++++++ 16 files changed, 3251 insertions(+), 13 deletions(-) create mode 100644 src/websockets/sync/client.py create mode 100644 src/websockets/sync/connection.py create mode 100644 src/websockets/sync/server.py create mode 100644 tests/protocol.py create mode 100644 tests/sync/client.py create mode 100644 tests/sync/connection.py create mode 100644 tests/sync/server.py create mode 100644 tests/sync/test_client.py create mode 100644 tests/sync/test_connection.py create mode 100644 tests/sync/test_server.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4d5cc3cd0..34f8d8c5c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,6 +8,9 @@ on: branches: - main +env: + WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 10 + jobs: coverage: name: Run test coverage checks diff --git a/setup.cfg b/setup.cfg index 48703df87..3a8321c50 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,3 +37,4 @@ exclude_lines = raise AssertionError raise NotImplementedError self.fail\(".*"\) + @unittest.skip diff --git a/src/websockets/server.py b/src/websockets/server.py index 16979e7e7..0dd579052 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -164,6 +164,8 @@ def accept(self, request: Request) -> Response: f"Failed to open a WebSocket connection: {exc}.\n", ) except Exception as exc: + # Handle exceptions raised by user-provided select_subprotocol and + # unexpected errors. request._exception = exc self.handshake_exc = exc self.logger.error("opening handshake failed", exc_info=True) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py new file mode 100644 index 000000000..e5922582d --- /dev/null +++ b/src/websockets/sync/client.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import socket +import ssl +import threading +from typing import Any, Optional, Sequence, Type + +from ..client import ClientProtocol +from ..datastructures import HeadersLike +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Response +from ..protocol import CONNECTING, OPEN, Event +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import parse_uri +from .connection import Connection +from .utils import Deadline + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + + +class ClientConnection(Connection): + """ + Threaded implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports iteration to receive messages:: + + for message in websocket: + process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + socket: Socket connected to a WebSocket server. + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + + """ + + def __init__( + self, + socket: socket.socket, + protocol: ClientProtocol, + *, + close_timeout: Optional[float] = 10, + ) -> None: + self.protocol: ClientProtocol + self.response_rcvd = threading.Event() + super().__init__( + socket, + protocol, + close_timeout=close_timeout, + ) + + def handshake( + self, + additional_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, + timeout: Optional[float] = None, + ) -> None: + """ + Perform the opening handshake. + + """ + with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers["User-Agent"] = user_agent_header + self.protocol.send_request(self.request) + + if not self.response_rcvd.wait(timeout): + self.close_socket() + self.recv_events_thread.join() + raise TimeoutError("timed out during handshake") + + if self.response is None: + self.close_socket() + self.recv_events_thread.join() + raise ConnectionError("connection closed during handshake") + + if self.protocol.state is not OPEN: + self.recv_events_thread.join(self.close_timeout) + self.close_socket() + self.recv_events_thread.join() + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + """ + try: + super().recv_events() + finally: + # If the connection is closed during the handshake, unblock it. + self.response_rcvd.set() + + +def connect( + uri: str, + *, + # TCP/TLS — unix and path are only for unix_connect() + sock: Optional[socket.socket] = None, + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None, + unix: bool = False, + path: Optional[str] = None, + # WebSocket + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + additional_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, + compression: Optional[str] = "deflate", + # Timeouts + open_timeout: Optional[float] = 10, + close_timeout: Optional[float] = 10, + # Limits + max_size: Optional[int] = 2**20, + # Logging + logger: Optional[LoggerLike] = None, + # Escape hatch for advanced customization + create_connection: Optional[Type[ClientConnection]] = None, +) -> ClientConnection: + """ + Connect to the WebSocket server at ``uri``. + + This function returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as a context manager:: + + async with websockets.sync.client.connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + Args: + uri: URI of the WebSocket server. + sock: Preexisting TCP socket. ``sock`` overrides the host and port + from ``uri``. You may call :func:`socket.create_connection` to + create a suitable TCP socket. + ssl_context: Configuration for enabling TLS on the connection. + server_hostname: Hostname for the TLS handshake. ``server_hostname`` + overrides the hostname from ``uri``. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + # Process parameters + + wsuri = parse_uri(uri) + if not wsuri.secure and ssl_context is not None: + raise TypeError("ssl_context argument is incompatible with a ws:// URI") + + if unix: + if path is None and sock is None: + raise TypeError("missing path argument") + elif path is not None and sock is not None: + raise TypeError("path and sock arguments are incompatible") + else: + assert path is None # private argument, only set by unix_connect() + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + # Calculate timeouts on the TCP, TLS, and WebSocket handshakes. + # The TCP and TLS timeouts must be set on the socket, then removed + # to avoid conflicting with the WebSocket timeout in handshake(). + deadline = Deadline(open_timeout) + + if create_connection is None: + create_connection = ClientConnection + + try: + # Connect socket + + if sock is None: + if unix: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(deadline.timeout()) + assert path is not None # validated above -- this is for mpypy + sock.connect(path) + else: + sock = socket.create_connection( + (wsuri.host, wsuri.port), + deadline.timeout(), + ) + sock.settimeout(None) + + # Disable Nagle algorithm + + if not unix: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + + # Initialize TLS wrapper and perform TLS handshake + + if wsuri.secure: + if ssl_context is None: + ssl_context = ssl.create_default_context() + if server_hostname is None: + server_hostname = wsuri.host + sock.settimeout(deadline.timeout()) + sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname) + sock.settimeout(None) + + # Initialize WebSocket connection + + protocol = ClientProtocol( + wsuri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + state=CONNECTING, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket protocol + + connection = create_connection( + sock, + protocol, + close_timeout=close_timeout, + ) + # On failure, handshake() closes the socket and raises an exception. + connection.handshake( + additional_headers, + user_agent_header, + deadline.timeout(), + ) + + except Exception: + if sock is not None: + sock.close() + raise + + return connection + + +def unix_connect( + path: Optional[str] = None, + uri: Optional[str] = None, + **kwargs: Any, +) -> ClientConnection: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function is identical to :func:`connect`, except for the additional + ``path`` argument. It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl_context`` is provided, to + ``wss://localhost/``. + + """ + if uri is None: + if kwargs.get("ssl_context") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py new file mode 100644 index 000000000..64e5c8b44 --- /dev/null +++ b/src/websockets/sync/connection.py @@ -0,0 +1,757 @@ +from __future__ import annotations + +import contextlib +import logging +import random +import socket +import struct +import threading +import uuid +from types import TracebackType +from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union + +from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..frames import DATA_OPCODES, BytesLike, Frame, Opcode, prepare_ctrl +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .messages import Assembler +from .utils import Deadline + + +__all__ = ["Connection"] + +logger = logging.getLogger(__name__) + +BUFSIZE = 65536 + + +class Connection: + """ + Threaded implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.sync.client.ClientConnection` or + :class:`~websockets.sync.server.ServerConnection`. + + """ + + def __init__( + self, + socket: socket.socket, + protocol: Protocol, + *, + close_timeout: Optional[float] = 10, + ) -> None: + self.socket = socket + self.protocol = protocol + self.close_timeout = close_timeout + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Optional[Request] = None + """Opening handshake request.""" + self.response: Optional[Response] = None + """Opening handshake response.""" + + # Mutex serializing interactions with the protocol. + self.protocol_mutex = threading.Lock() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages = Assembler() + + # Whether we are busy sending a fragmented message. + self.send_in_progress = False + + # Deadline for the closing handshake. + self.close_deadline: Optional[Deadline] = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pings: Dict[bytes, threading.Event] = {} + + # Receiving events from the socket. + self.recv_events_thread = threading.Thread(target=self.recv_events) + self.recv_events_thread.start() + + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_events_exc: Optional[BaseException] = None + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.socket.getsockname() + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.socket.getpeername() + + @property + def subprotocol(self) -> Optional[Subprotocol]: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + # Public methods + + def __enter__(self) -> Connection: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close(1000 if exc_type is None else 1011) + + def __iter__(self) -> Iterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages in an infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield self.recv() + except ConnectionClosedOK: + return + + def recv(self, timeout: Optional[float] = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + If ``timeout`` is :obj:`None`, block until a message is received. If + ``timeout`` is set and no message is received within ``timeout`` + seconds, raise :exc:`TimeoutError`. Set ``timeout`` to ``0`` to check if + a message was already received. + + If the message is fragmented, wait until all fragments are received, + reassemble them, and return the whole message. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two threads call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return self.recv_messages.get(timeout) + except EOFError: + raise self.protocol.close_exc from self.recv_events_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv while another thread " + "is already running recv or recv_streaming" + ) from None + + def recv_streaming(self) -> Iterator[Data]: + """ + Receive the next message frame by frame. + + If the message is fragmented, yield each fragment as it is received. + The iterator must be fully consumed, or else the connection will become + unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two threads call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + yield from self.recv_messages.get_iter() + except EOFError: + raise self.protocol.close_exc from self.recv_events_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv_streaming while another thread " + "is already running recv or recv_streaming" + ) from None + + def send(self, message: Union[Data, Iterable[Data]]) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :meth:`send` also accepts an iterable of strings, bytestrings, or + bytes-like objects to enable fragmentation_. Each item is treated as a + message fragment and sent in its own frame. All items must be of the + same type, or else :meth:`send` will raise a :exc:`TypeError` and the + connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If a connection is busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. + + """ + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + with self.send_context(): + if self.send_in_progress: + raise RuntimeError( + "cannot call send while another thread " + "is already running send" + ) + self.protocol.send_text(message.encode("utf-8")) + + elif isinstance(message, BytesLike): + with self.send_context(): + if self.send_in_progress: + raise RuntimeError( + "cannot call send while another thread " + "is already running send" + ) + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + try: + # First fragment. + if isinstance(chunk, str): + text = True + with self.send_context(): + if self.send_in_progress: + raise RuntimeError( + "cannot call send while another thread " + "is already running send" + ) + self.send_in_progress = True + self.protocol.send_text( + chunk.encode("utf-8"), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + with self.send_context(): + if self.send_in_progress: + raise RuntimeError( + "cannot call send while another thread " + "is already running send" + ) + self.send_in_progress = True + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("data iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and text: + with self.send_context(): + assert self.send_in_progress + self.protocol.send_continuation( + chunk.encode("utf-8"), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + with self.send_context(): + assert self.send_in_progress + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("data iterable must contain uniform types") + + # Final fragment. + with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + self.send_in_progress = False + + except RuntimeError: + # We didn't start sending a fragmented message. + raise + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + else: + raise TypeError("data must be bytes, str, or iterable") + + def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + with self.send_context(): + if self.send_in_progress: + self.protocol.fail(1011, "close during fragmented message") + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + def ping(self, data: Optional[Data] = None) -> threading.Event: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + An event that will be set when the corresponding pong is received. + You can ignore it if you don't intend to wait. + + :: + + pong_event = ws.ping() + pong_event.wait() # only if you want to wait for the pong + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if data is not None: + data = prepare_ctrl(data) + + with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = threading.Event() + self.pings[data] = pong_waiter + self.protocol.send_ping(data) + return pong_waiter + + def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + data = prepare_ctrl(data) + + with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + with self.protocol_mutex: + # Ignore unsolicited pong. + if data not in self.pings: + return + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, ping in self.pings.items(): + ping_ids.append(ping_id) + ping.set() + if ping_id == data: + break + else: + raise AssertionError("solicited pong not found in pings") + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + Run this method in a thread as long as the connection is alive. + + ``recv_events()`` exits immediately when the ``self.socket`` is closed. + + """ + try: + while True: + try: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) + data = self.socket.recv(BUFSIZE) + except Exception as exc: + if self.debug: + self.logger.debug("error while receiving data", exc_info=True) + # When the closing handshake is initiated by our side, + # recv() may block until send_context() closes the socket. + # In that case, send_context() already set recv_events_exc. + # Calling set_recv_events_exc() avoids overwriting it. + with self.protocol_mutex: + self.set_recv_events_exc(exc) + break + + if data == b"": + break + + # Acquire the connection lock. + with self.protocol_mutex: + # Feed incoming data to the connection. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the socket. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # Similarly to the above, avoid overriding an exception + # set by send_context(), in case of a race condition + # i.e. send_context() closes the socket after recv() + # returns above but before send_data() calls send(). + self.set_recv_events_exc(exc) + break + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_deadline is None: + self.close_deadline = Deadline(self.close_timeout) + + # Unlock conn_mutex before processing events. Else, the + # application can't send messages in response to events. + + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + # Breaking out of the while True: ... loop means that we believe + # that the socket doesn't work anymore. + with self.protocol_mutex: + # Feed the end of the data stream to the connection. + self.protocol.receive_eof() + + # This isn't expected to generate events. + assert not self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it handles errors itself. + self.send_data() + + except Exception as exc: + # This branch should never run. It's a safety net in case of bugs. + self.logger.error("unexpected internal error", exc_info=True) + with self.protocol_mutex: + self.set_recv_events_exc(exc) + # We don't know where we crashed. Force protocol state to CLOSED. + self.protocol.state = CLOSED + finally: + # This isn't expected to raise an exception. + self.recv_messages.close() + self.close_socket() + + @contextlib.contextmanager + def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> Iterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` acquires the connection lock and checks + that the connection is open; on exit, it writes outgoing data to the + socket:: + + with self.send_context(): + self.protocol.send_text(message.encode("utf-8")) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the socket and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: Optional[BaseException] = None + + # Acquire the protocol lock. + with self.protocol_mutex: + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, RuntimeError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING) and we didn't release protocol_mutex, + # it is certain that self.close_deadline is still None. + assert self.close_deadline is None + self.close_deadline = Deadline(self.close_timeout) + # Write outgoing data to the socket. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + raise_close_exc = True + + # To avoid a deadlock, release the connection lock by exiting the + # context manager before waiting for recv_events() to terminate. + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + if self.close_deadline is None: + timeout = self.close_timeout + else: + # Thread.join() returns immediately if timeout is negative. + timeout = self.close_deadline.timeout(raise_if_elapsed=False) + self.recv_events_thread.join(timeout) + + if self.recv_events_thread.is_alive(): + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_events_exc before closing the socket in order to get + # proper exception reporting. + raise_close_exc = True + with self.protocol_mutex: + self.set_recv_events_exc(original_exc) + + # If an error occurred, close the socket to terminate the connection and + # raise an exception. + if raise_close_exc: + self.close_socket() + self.recv_events_thread.join() + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + This method requires holding protocol_mutex. + + Raises: + OSError: When a socket operations fails. + + """ + assert self.protocol_mutex.locked() + for data in self.protocol.data_to_send(): + if data: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) + self.socket.sendall(data) + else: + try: + self.socket.shutdown(socket.SHUT_WR) + except OSError: # socket already closed + pass + + def set_recv_events_exc(self, exc: Optional[BaseException]) -> None: + """ + Set recv_events_exc, if not set yet. + + This method requires holding protocol_mutex. + + """ + assert self.protocol_mutex.locked() + if self.recv_events_exc is None: + self.recv_events_exc = exc + + def close_socket(self) -> None: + """ + Shutdown and close socket. + + shutdown() is required to interrupt recv() on Linux. + + """ + try: + self.socket.shutdown(socket.SHUT_RDWR) + except OSError: + pass # socket is already closed + self.socket.close() diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 9265e4e9f..67a22313c 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -67,12 +67,12 @@ def get(self, timeout: Optional[float] = None) -> Data: messages frame by frame, use :meth:`get_iter` instead. Args: - timeout: if a timeout is provided and elapses before a complete + timeout: If a timeout is provided and elapses before a complete message is received, :meth:`get` raises :exc:`TimeoutError`. Raises: - EOFError: if the stream of frames has ended. - RuntimeError: if two threads run :meth:`get` or :meth:``get_iter` + EOFError: If the stream of frames has ended. + RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` concurrently. """ @@ -130,8 +130,8 @@ def get_iter(self) -> Iterator[Data]: fragmented, use :meth:`get` instead. Raises: - EOFError: if the stream of frames has ended. - RuntimeError: if two threads run :meth:`get` or :meth:``get_iter` + EOFError: If the stream of frames has ended. + RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` concurrently. """ @@ -194,8 +194,8 @@ def put(self, frame: Frame) -> None: it doesn't, the behavior is undefined. Raises: - EOFError: if the stream of frames has ended. - RuntimeError: if two threads run :meth:`put` concurrently. + EOFError: If the stream of frames has ended. + RuntimeError: If two threads run :meth:`put` concurrently. """ with self.mutex: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py new file mode 100644 index 000000000..a53ae2b25 --- /dev/null +++ b/src/websockets/sync/server.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +import http +import logging +import os +import select +import socket +import ssl +import threading +from types import TracebackType +from typing import Any, Callable, Optional, Sequence, Type + +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Request, Response +from ..protocol import CONNECTING, OPEN, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, Subprotocol +from .compatibility import socket_create_server +from .connection import Connection +from .utils import Deadline + + +__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] + + +class ServerConnection(Connection): + """ + Threaded implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports iteration to receive messages:: + + for message in websocket: + process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + socket: Socket connected to a WebSocket client. + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + + """ + + def __init__( + self, + socket: socket.socket, + protocol: ServerProtocol, + *, + close_timeout: Optional[float] = 10, + ) -> None: + self.protocol: ServerProtocol + self.request_rcvd = threading.Event() + super().__init__( + socket, + protocol, + close_timeout=close_timeout, + ) + + def handshake( + self, + process_request: Optional[ + Callable[ + [ServerConnection, Request], + Optional[Response], + ] + ] = None, + process_response: Optional[ + Callable[ + [ServerConnection, Request, Response], + Optional[Response], + ] + ] = None, + server_header: Optional[str] = USER_AGENT, + timeout: Optional[float] = None, + ) -> None: + """ + Perform the opening handshake. + + """ + if not self.request_rcvd.wait(timeout): + self.close_socket() + self.recv_events_thread.join() + raise TimeoutError("timed out during handshake") + + if self.request is None: + self.close_socket() + self.recv_events_thread.join() + raise ConnectionError("connection closed during handshake") + + with self.send_context(expected_state=CONNECTING): + self.response = None + + if process_request is not None: + try: + self.response = process_request(self, self.request) + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + self.response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if self.response is None: + self.response = self.protocol.accept(self.request) + + if server_header is not None: + self.response.headers["Server"] = server_header + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + self.response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + else: + if response is not None: + self.response = response + + self.protocol.send_response(self.response) + + if self.protocol.state is not OPEN: + self.recv_events_thread.join(self.close_timeout) + self.close_socket() + self.recv_events_thread.join() + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + """ + try: + super().recv_events() + finally: + # If the connection is closed during the handshake, unblock it. + self.request_rcvd.set() + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~socketserver.BaseServer`, notably the + :meth:`~socketserver.BaseServer.serve_forever` and + :meth:`~socketserver.BaseServer.shutdown` methods, as well as the context + manager protocol. + + Args: + socket: Server socket listening for new connections. + handler: Handler for one connection. Receives the socket and address + returned by :meth:`~socket.socket.accept`. + logger: Logger for this server. + + """ + + def __init__( + self, + socket: socket.socket, + handler: Callable[[socket.socket, Any], None], + logger: Optional[LoggerLike] = None, + ): + self.socket = socket + self.handler = handler + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + self.shutdown_watcher, self.shutdown_notifier = os.pipe() + + def serve_forever(self) -> None: + """ + See :meth:`socketserver.BaseServer.serve_forever`. + + This method doesn't return. Calling :meth:`shutdown` from another thread + stops the server. + + Typical use:: + + with serve(...) as server: + server.serve_forever() + + """ + poller = select.poll() + poller.register(self.socket) + poller.register(self.shutdown_watcher) + + while True: + poller.poll() + try: + # If the socket is closed, this will raise an exception and exit + # the loop. So we don't need to check the return value of poll(). + sock, addr = self.socket.accept() + except OSError: + break + thread = threading.Thread(target=self.handler, args=(sock, addr)) + thread.start() + + def shutdown(self) -> None: + """ + See :meth:`socketserver.BaseServer.shutdown`. + + """ + self.socket.close() + os.write(self.shutdown_notifier, b"x") + + def fileno(self) -> int: + """ + See :meth:`socketserver.BaseServer.fileno`. + + """ + return self.socket.fileno() + + def __enter__(self) -> WebSocketServer: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.shutdown() + + +def serve( + handler: Callable[[ServerConnection], None], + host: Optional[str] = None, + port: Optional[int] = None, + *, + # TCP/TLS — unix and path are only for unix_serve() + sock: Optional[socket.socket] = None, + ssl_context: Optional[ssl.SSLContext] = None, + unix: bool = False, + path: Optional[str] = None, + # WebSocket + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + select_subprotocol: Optional[ + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Optional[Subprotocol], + ] + ] = None, + process_request: Optional[ + Callable[ + [ServerConnection, Request], + Optional[Response], + ] + ] = None, + process_response: Optional[ + Callable[ + [ServerConnection, Request, Response], + Optional[Response], + ] + ] = None, + server_header: Optional[str] = USER_AGENT, + compression: Optional[str] = "deflate", + # Timeouts + open_timeout: Optional[float] = 10, + close_timeout: Optional[float] = 10, + # Limits + max_size: Optional[int] = 2**20, + # Logging + logger: Optional[LoggerLike] = None, + # Escape hatch for advanced customization + create_connection: Optional[Type[ServerConnection]] = None, +) -> WebSocketServer: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler``. + + The handler receives a :class:`ServerConnection` instance, which you can use + to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + :class:`WebSocketServer` mirrors the API of + :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure + that it will be closed and call the :meth:`~WebSocketServer.serve_forever` + method to serve requests:: + + def handler(websocket): + ... + + with websockets.sync.server.serve(handler, ...) as server: + server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :func:`~socket.create_server` for details. + port: TCP port the server listens on. + See :func:`~socket.create_server` for details. + sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``. + You may call :func:`socket.create_server` to create a suitable TCP + socket. + ssl_context: Configuration for enabling TLS on the connection. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force a HTTP 101 Continue response, + the handshake is successful. Else, the connection is aborted. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force a HTTP 101 Continue response, + the handshake is successful. Else, the connection is aborted. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + """ + + # Process parameters + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + # Bind socket and listen + + if sock is None: + if unix: + if path is None: + raise TypeError("missing path argument") + sock = socket_create_server(path, family=socket.AF_UNIX) + else: + sock = socket_create_server((host, port)) + else: + if path is not None: + raise TypeError("path and sock arguments are incompatible") + + # Initialize TLS wrapper + + if ssl_context is not None: + sock = ssl_context.wrap_socket( + sock, + server_side=True, + # Delay TLS handshake until after we set a timeout on the socket. + do_handshake_on_connect=False, + ) + + # Define request handler + + def conn_handler(sock: socket.socket, addr: Any) -> None: + # Calculate timeouts on the TLS and WebSocket handshakes. + # The TLS timeout must be set on the socket, then removed + # to avoid conflicting with the WebSocket timeout in handshake(). + deadline = Deadline(open_timeout) + + try: + # Disable Nagle algorithm + + if not unix: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + + # Perform TLS handshake + + if ssl_context is not None: + sock.settimeout(deadline.timeout()) + assert isinstance(sock, ssl.SSLSocket) # mypy cannot figure this out + sock.do_handshake() + sock.settimeout(None) + + # Create a closure so that select_subprotocol has access to self. + + protocol_select_subprotocol: Optional[ + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Optional[Subprotocol], + ] + ] = None + + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Optional[Subprotocol]: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # Initialize WebSocket connection + + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + state=CONNECTING, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket protocol + + assert create_connection is not None # help mypy + connection = create_connection( + sock, + protocol, + close_timeout=close_timeout, + ) + # On failure, handshake() closes the socket, raises an exception, and + # logs it. + connection.handshake( + process_request, + process_response, + server_header, + deadline.timeout(), + ) + + except Exception: + sock.close() + return + + try: + handler(connection) + except Exception: + protocol.logger.error("connection handler failed", exc_info=True) + connection.close(1011) + else: + connection.close() + + # Initialize server + + return WebSocketServer(sock, conn_handler, logger) + + +def unix_serve( + handler: Callable[[ServerConnection], Any], + path: Optional[str] = None, + **kwargs: Any, +) -> WebSocketServer: + """ + Create a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, path=path, unix=True, **kwargs) diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 8aab6c0d9..471f32e19 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -12,7 +12,7 @@ class Deadline: Manage timeouts across multiple steps. Args: - timeout: time available in seconds; :obj:`None` if there is no limit. + timeout: Time available in seconds or :obj:`None` if there is no limit. """ @@ -28,15 +28,14 @@ def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: Calculate a timeout from a deadline. Args: - raise_if_elapsed (bool): whether to raise :exc:`TimeoutError` + raise_if_elapsed (bool): Whether to raise :exc:`TimeoutError` if the deadline lapsed. Raises: - TimeoutError: if the deadline lapsed. + TimeoutError: If the deadline lapsed. Returns: - Optional[float]: Time left in seconds; - :obj:`None` if there is no limit. + Time left in seconds or :obj:`None` if there is no limit. """ if self.deadline is None: diff --git a/tests/protocol.py b/tests/protocol.py new file mode 100644 index 000000000..4e843daab --- /dev/null +++ b/tests/protocol.py @@ -0,0 +1,29 @@ +from websockets.protocol import Protocol + + +class RecordingProtocol(Protocol): + """ + Protocol subclass that records incoming frames. + + By interfacing with this protocol, you can check easily what the component + being testing sends during a test. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.frames_rcvd = [] + + def get_frames_rcvd(self): + """ + Get incoming frames received up to this point. + + Calling this method clears the list. Each frame is returned only once. + + """ + frames_rcvd, self.frames_rcvd = self.frames_rcvd, [] + return frames_rcvd + + def recv_frame(self, frame): + self.frames_rcvd.append(frame) + super().recv_frame(frame) diff --git a/tests/sync/client.py b/tests/sync/client.py new file mode 100644 index 000000000..51bbd4388 --- /dev/null +++ b/tests/sync/client.py @@ -0,0 +1,55 @@ +import contextlib +import ssl +import sys +import warnings + +from websockets.sync.client import * +from websockets.sync.server import WebSocketServer + +from ..utils import CERTIFICATE + + +__all__ = [ + "CLIENT_CONTEXT", + "run_client", + "run_unix_client", +] + + +CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) + +# Work around https://github.com/openssl/openssl/issues/7967 + +# This bug causes connect() to hang in tests for the client. Including this +# workaround acknowledges that the issue could happen outside of the test suite. + +# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it +# happens, we can look for a library-level fix, but it won't be easy. + +if sys.version_info[:2] < (3, 8): # pragma: no cover + # ssl.OP_NO_TLSv1_3 was introduced and deprecated on Python 3.7. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + CLIENT_CONTEXT.options |= ssl.OP_NO_TLSv1_3 + + +@contextlib.contextmanager +def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): + if isinstance(wsuri_or_server, str): + wsuri = wsuri_or_server + else: + assert isinstance(wsuri_or_server, WebSocketServer) + if secure is None: + secure = "ssl_context" in kwargs + protocol = "wss" if secure else "ws" + host, port = wsuri_or_server.socket.getsockname() + wsuri = f"{protocol}://{host}:{port}{resource_name}" + with connect(wsuri, **kwargs) as client: + yield client + + +@contextlib.contextmanager +def run_unix_client(path, **kwargs): + with unix_connect(path, **kwargs) as client: + yield client diff --git a/tests/sync/connection.py b/tests/sync/connection.py new file mode 100644 index 000000000..89d4909ee --- /dev/null +++ b/tests/sync/connection.py @@ -0,0 +1,109 @@ +import contextlib +import time + +from websockets.sync.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, you can simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.socket = InterceptingSocket(self.socket) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.socket.delay_sendall is None + self.socket.delay_sendall = delay + try: + yield + finally: + self.socket.delay_sendall = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.socket.delay_shutdown is None + self.socket.delay_shutdown = delay + try: + yield + finally: + self.socket.delay_shutdown = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.socket.drop_sendall + self.socket.drop_sendall = True + try: + yield + finally: + self.socket.drop_sendall = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.socket.drop_shutdown + self.socket.drop_shutdown = True + try: + yield + finally: + self.socket.drop_shutdown = False + + +class InterceptingSocket: + """ + Socket wrapper that intercepts calls to sendall and shutdown. + + This is coupled to the implementation, which relies on these two methods. + + """ + + def __init__(self, socket): + self.socket = socket + self.delay_sendall = None + self.delay_shutdown = None + self.drop_sendall = False + self.drop_shutdown = False + + def __getattr__(self, name): + return getattr(self.socket, name) + + def sendall(self, bytes, flags=0): + if self.delay_sendall is not None: + time.sleep(self.delay_sendall) + if not self.drop_sendall: + self.socket.sendall(bytes, flags) + + def shutdown(self, how): + if self.delay_shutdown is not None: + time.sleep(self.delay_shutdown) + if not self.drop_shutdown: + self.socket.shutdown(how) diff --git a/tests/sync/server.py b/tests/sync/server.py new file mode 100644 index 000000000..5f0cd3b07 --- /dev/null +++ b/tests/sync/server.py @@ -0,0 +1,67 @@ +import contextlib +import ssl +import sys +import threading + +from websockets.sync.server import * + +from ..utils import CERTIFICATE + + +SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +SERVER_CONTEXT.load_cert_chain(CERTIFICATE) + +# Work around https://github.com/openssl/openssl/issues/7967 + +# This bug causes connect() to hang in tests for the client. Including this +# workaround acknowledges that the issue could happen outside of the test suite. + +# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it +# happens, we can look for a library-level fix, but it won't be easy. + +if sys.version_info[:2] >= (3, 8): # pragma: no cover + SERVER_CONTEXT.num_tickets = 0 + + +def crash(ws): + raise RuntimeError + + +def do_nothing(ws): + pass + + +def eval_shell(ws): + for expr in ws: + value = eval(expr) + ws.send(str(value)) + + +class EvalShellMixin: + def assertEval(self, client, expr, value): + client.send(expr) + self.assertEqual(client.recv(), value) + + +@contextlib.contextmanager +def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs): + with serve(ws_handler, host, port, **kwargs) as server: + thread = threading.Thread(target=server.serve_forever) + thread.start() + try: + yield server + finally: + server.shutdown() + thread.join() + + +@contextlib.contextmanager +def run_unix_server(path, ws_handler=eval_shell, **kwargs): + with unix_serve(ws_handler, path, **kwargs) as server: + thread = threading.Thread(target=server.serve_forever) + thread.start() + try: + yield server + finally: + server.shutdown() + thread.join() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py new file mode 100644 index 000000000..8824ed894 --- /dev/null +++ b/tests/sync/test_client.py @@ -0,0 +1,271 @@ +import socket +import ssl +import threading +import unittest + +from websockets.exceptions import InvalidHandshake +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.sync.client import * + +from ..utils import MS, temp_unix_socket_path +from .client import CLIENT_CONTEXT, run_client, run_unix_client +from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server + + +class ClientTests(unittest.TestCase): + def test_connection(self): + """Client connects to server and the handshake succeeds.""" + with run_server() as server: + with run_client(server) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + def test_connection_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + with run_server(do_nothing, process_response=remove_accept_header) as server: + with self.assertRaisesRegex( + InvalidHandshake, + "missing Sec-WebSocket-Accept header", + ): + with run_client(server, close_timeout=MS): + self.fail("did not raise") + + def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + with run_server() as server: + with socket.create_connection(server.socket.getsockname()) as sock: + # Use a non-existing domain to ensure we connect to the right socket. + with run_client("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + with run_server() as server: + with run_client( + server, additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + with run_server() as server: + with run_client(server, user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + with run_server() as server: + with run_client(server, user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + def test_compression_is_enabled(self): + """Client enables compression by default.""" + with run_server() as server: + with run_client(server) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + def test_disable_compression(self): + """Client disables compression.""" + with run_server() as server: + with run_client(server, compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + with run_server() as server: + with run_client(server, create_connection=create_connection) as client: + self.assertTrue(client.create_connection_ran) + + def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + gate = threading.Event() + + def stall_connection(self, request): + gate.wait() + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + with run_server(do_nothing, process_request=stall_connection) as server: + try: + with self.assertRaisesRegex( + TimeoutError, + "timed out during handshake", + ): + with run_client(server, open_timeout=3 * MS): + self.fail("did not raise") + finally: + gate.set() + + def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + def close_connection(self, request): + self.close_socket() + + with run_server(process_request=close_connection) as server: + with self.assertRaisesRegex( + ConnectionError, + "connection closed during handshake", + ): + with run_client(server): + self.fail("did not raise") + + +class SecureClientTests(unittest.TestCase): + def test_connection(self): + """Client connects to server securely.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + + def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client( + path, + ssl_context=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + + def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client( + path, + ssl_context=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + + def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + with self.assertRaisesRegex( + ssl.SSLCertVerificationError, + r"certificate verify failed: self[ -]signed certificate", + ): + # The test certificate isn't trusted system-wide. + with run_client(server, secure=True): + self.fail("did not raise") + + def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + with self.assertRaisesRegex( + ssl.SSLCertVerificationError, + r"certificate verify failed: Hostname mismatch", + ): + # This hostname isn't included in the test certificate. + with run_client( + server, ssl_context=CLIENT_CONTEXT, server_hostname="invalid" + ): + self.fail("did not raise") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixClientTests(unittest.TestCase): + def test_connection(self): + """Client connects to server over a Unix socket.""" + with temp_unix_socket_path() as path: + with run_unix_server(path): + with run_unix_client(path) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + def test_set_host_header(self): + """Client sets the Host header to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + with run_unix_server(path): + with run_unix_client(path, uri="ws://overridden/") as client: + self.assertEqual(client.request.headers["Host"], "overridden") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixClientTests(unittest.TestCase): + def test_connection(self): + """Client connects to server securely over a Unix socket.""" + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + + def test_set_server_hostname(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client( + path, + ssl_context=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + + +class ClientUsageErrorsTests(unittest.TestCase): + def test_ssl_context_without_secure_uri(self): + """Client rejects ssl_context when URI isn't secure.""" + with self.assertRaisesRegex( + TypeError, + "ssl_context argument is incompatible with a ws:// URI", + ): + connect("ws://localhost/", ssl_context=CLIENT_CONTEXT) + + def test_unix_without_path_or_sock(self): + """Unix client requires path when sock isn't provided.""" + with self.assertRaisesRegex( + TypeError, + "missing path argument", + ): + unix_connect() + + def test_unix_with_path_and_sock(self): + """Unix client rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaisesRegex( + TypeError, + "path and sock arguments are incompatible", + ): + unix_connect(path="/", sock=sock) + + def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaisesRegex( + TypeError, + "subprotocols must be a list", + ): + connect("ws://localhost/", subprotocols="chat") + + def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaisesRegex( + ValueError, + "unsupported compression: False", + ): + connect("ws://localhost/", compression=False) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py new file mode 100644 index 000000000..94850affe --- /dev/null +++ b/tests/sync/test_connection.py @@ -0,0 +1,704 @@ +import contextlib +import logging +import socket +import sys +import threading +import time +import unittest +import uuid +from unittest.mock import patch + +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.frames import Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.sync.connection import * + +from ..protocol import RecordingProtocol +from ..utils import MS +from .connection import InterceptingConnection + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(unittest.TestCase): + LOCAL = CLIENT + REMOTE = SERVER + + def setUp(self): + socket_, remote_socket = socket.socketpair() + protocol = Protocol(self.LOCAL) + remote_protocol = RecordingProtocol(self.REMOTE) + self.connection = Connection(socket_, protocol, close_timeout=2 * MS) + self.remote_connection = InterceptingConnection(remote_socket, remote_protocol) + + def tearDown(self): + self.remote_connection.close() + self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + time.sleep(MS) # let the remote side process messages + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + def assertNoFrameSent(self): + """Check that no frame was sent.""" + time.sleep(MS) # let the remote side process messages + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.contextmanager + def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + time.sleep(MS) # let the remote side process messages + + @contextlib.contextmanager + def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + time.sleep(MS) # let the remote side process messages + + @contextlib.contextmanager + def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + time.sleep(MS) # let the remote side process messages + + @contextlib.contextmanager + def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + time.sleep(MS) # let the remote side process messages + + # Test __enter__ and __exit__. + + def test_enter(self): + """__enter__ returns the connection itself.""" + with self.connection as connection: + self.assertIs(connection, self.connection) + + def test_exit(self): + """__exit__ closes the connection with code 1000.""" + with self.connection: + self.assertNoFrameSent() + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + with self.connection: + raise RuntimeError + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __iter__. + + def test_iter_text(self): + """__iter__ yields text messages.""" + iterator = iter(self.connection) + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + + def test_iter_binary(self): + """__iter__ yields binary messages.""" + iterator = iter(self.connection) + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + + def test_iter_mixed(self): + """__iter__ yields a mix of text and binary messages.""" + iterator = iter(self.connection) + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + + def test_iter_connection_closed_ok(self): + """__iter__ terminates after a normal closure.""" + iterator = iter(self.connection) + self.remote_connection.close() + with self.assertRaises(StopIteration): + next(iterator) + + def test_iter_connection_closed_error(self): + """__iter__ raises ConnnectionClosedError after an error.""" + iterator = iter(self.connection) + self.remote_connection.close(code=1011) + with self.assertRaises(ConnectionClosedError): + next(iterator) + + # Test recv. + + def test_recv_text(self): + """recv receives a text message.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(), "😀") + + def test_recv_binary(self): + """recv receives a binary message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + + def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + self.remote_connection.send(["😀", "😀"]) + self.assertEqual(self.connection.recv(), "😀😀") + + def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + + def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + self.connection.recv() + + def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + self.remote_connection.close(code=1011) + with self.assertRaises(ConnectionClosedError): + self.connection.recv() + + def test_recv_during_recv(self): + """recv raises RuntimeError when called concurrently with itself.""" + recv_thread = threading.Thread(target=self.connection.recv) + recv_thread.start() + + with self.assertRaisesRegex( + RuntimeError, + "cannot call recv while another thread " + "is already running recv or recv_streaming", + ): + self.connection.recv() + + self.remote_connection.send("") + recv_thread.join() + + def test_recv_during_recv_streaming(self): + """recv raises RuntimeError when called concurrently with recv_streaming.""" + recv_streaming_thread = threading.Thread( + target=lambda: list(self.connection.recv_streaming()) + ) + recv_streaming_thread.start() + + with self.assertRaisesRegex( + RuntimeError, + "cannot call recv while another thread " + "is already running recv or recv_streaming", + ): + self.connection.recv() + + self.remote_connection.send("") + recv_streaming_thread.join() + + # Test recv_streaming. + + def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + self.remote_connection.send("😀") + self.assertEqual( + list(self.connection.recv_streaming()), + ["😀"], + ) + + def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + list(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + list(self.connection.recv_streaming()) + + def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + self.remote_connection.close(code=1011) + with self.assertRaises(ConnectionClosedError): + list(self.connection.recv_streaming()) + + def test_recv_streaming_during_recv(self): + """recv_streaming raises RuntimeError when called concurrently with recv.""" + recv_thread = threading.Thread(target=self.connection.recv) + recv_thread.start() + + with self.assertRaisesRegex( + RuntimeError, + "cannot call recv_streaming while another thread " + "is already running recv or recv_streaming", + ): + list(self.connection.recv_streaming()) + + self.remote_connection.send("") + recv_thread.join() + + def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises RuntimeError when called concurrently with itself.""" + recv_streaming_thread = threading.Thread( + target=lambda: list(self.connection.recv_streaming()) + ) + recv_streaming_thread.start() + + with self.assertRaisesRegex( + RuntimeError, + r"cannot call recv_streaming while another thread " + r"is already running recv or recv_streaming", + ): + list(self.connection.recv_streaming()) + + self.remote_connection.send("") + recv_streaming_thread.join() + + # Test send. + + def test_send_text(self): + """send sends a text message.""" + self.connection.send("😀") + self.assertEqual(self.remote_connection.recv(), "😀") + + def test_send_binary(self): + """send sends a binary message.""" + self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + self.connection.send("😀") + + def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + self.remote_connection.close(code=1011) + with self.assertRaises(ConnectionClosedError): + self.connection.send("😀") + + def test_send_during_send(self): + """send raises RuntimeError when called concurrently with itself.""" + recv_thread = threading.Thread(target=self.remote_connection.recv) + recv_thread.start() + + send_gate = threading.Event() + exit_gate = threading.Event() + + def fragments(): + yield "😀" + send_gate.set() + exit_gate.wait() + yield "😀" + + send_thread = threading.Thread( + target=self.connection.send, + args=(fragments(),), + ) + send_thread.start() + + send_gate.wait() + # The check happens in four code paths, depending on the argument. + for message in [ + "😀", + b"\x01\x02\xfe\xff", + ["😀", "😀"], + [b"\x01\x02", b"\xfe\xff"], + ]: + with self.subTest(message=message): + with self.assertRaisesRegex( + RuntimeError, + "cannot call send while another thread is already running send", + ): + self.connection.send(message) + + exit_gate.set() + send_thread.join() + recv_thread.join() + + def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + self.connection.send([]) + self.connection.close() + self.assertEqual(list(iter(self.remote_connection)), []) + + def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + self.connection.send(["😀", b"\xfe\xff"]) + + def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + self.connection.send([None]) + + def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + self.connection.send({"type": "object"}) + + def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + self.connection.send(None) + + # Test close. + + def test_close(self): + """close sends a close frame.""" + self.connection.close() + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + self.connection.close(1001, "bye!") + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + with self.delay_frames_rcvd(MS): + self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + with self.delay_eof_rcvd(MS): + self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + self.connection.close() + + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + with self.drop_eof_rcvd(): + self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + # Remove socket.timeout when dropping Python < 3.10. + self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + + def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + self.connection.close() + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + self.connection.close() + self.assertNoFrameSent() + + def test_close_idempotency_race_condition(self): + """close waits if the connection is already closing.""" + + self.connection.close_timeout = 5 * MS + + def closer(): + with self.delay_frames_rcvd(3 * MS): + self.connection.close() + + close_thread = threading.Thread(target=closer) + close_thread.start() + + # Let closer() initiate the closing handshake and send a close frame. + time.sleep(MS) + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + # Connection isn't closed yet. + with self.assertRaises(TimeoutError): + self.connection.recv(timeout=0) + + self.connection.close() + self.assertNoFrameSent() + + # Connection is closed now. + with self.assertRaises(ConnectionClosedOK): + self.connection.recv(timeout=0) + + close_thread.join() + + def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + close_gate = threading.Event() + exit_gate = threading.Event() + + def closer(): + close_gate.wait() + self.connection.close() + exit_gate.set() + + def fragments(): + yield "😀" + close_gate.set() + exit_gate.wait() + yield "😀" + + close_thread = threading.Thread(target=closer) + close_thread.start() + + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.send(fragments()) + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (unexpected error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + close_thread.join() + + # Test ping. + + @patch("random.getrandbits") + def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.return_value = 1918987876 + self.connection.ping() + getrandbits.assert_called_once_with(32) + self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + self.connection.ping("ping") + self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + self.connection.ping(b"ping") + self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + with self.remote_connection.protocol_mutex: # block response to ping + pong_waiter = self.connection.ping("idem") + with self.assertRaisesRegex( + RuntimeError, + "already waiting for a pong with the same data", + ): + self.connection.ping("idem") + self.assertTrue(pong_waiter.wait(MS)) + self.connection.ping("idem") # doesn't raise an exception + + def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + with self.drop_frames_rcvd(): + pong_waiter = self.connection.ping("this") + self.assertFalse(pong_waiter.wait(MS)) + self.remote_connection.pong("this") + self.assertTrue(pong_waiter.wait(MS)) + + def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + with self.drop_frames_rcvd(): + pong_waiter = self.connection.ping("this") + self.remote_connection.pong("that") + self.assertFalse(pong_waiter.wait(MS)) + + def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong with the same payload as a later ping.""" + with self.drop_frames_rcvd(): + pong_waiter = self.connection.ping("this") + self.connection.ping("that") + self.remote_connection.pong("that") + self.assertTrue(pong_waiter.wait(MS)) + + # Test pong. + + def test_pong(self): + """pong sends a pong frame.""" + self.connection.pong() + self.assertFrameSent(Frame(Opcode.PONG, b"")) + + def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + self.connection.pong("pong") + self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + self.connection.pong(b"pong") + self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + # Test attributes. + + def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + def test_local_address(self): + """Connection has a local_address attribute.""" + self.assertIsNotNone(self.connection.local_address) + + def test_remote_address(self): + """Connection has a remote_address attribute.""" + self.assertIsNotNone(self.connection.remote_address) + + def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + # Test reporting of network errors. + + @unittest.skipUnless(sys.platform == "darwin", "works only on BSD") + def test_reading_in_recv_events_fails(self): + """Error when reading incoming frames is correctly reported.""" + # Inject a fault by closing the socket. This works only on BSD. + # I cannot find a way to achieve the same effect on Linux. + self.connection.socket.close() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, IOError) + + def test_writing_in_recv_events_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the socket for writing — but not by + # closing it because that would terminate the connection. + self.connection.socket.shutdown(socket.SHUT_WR) + # Receive a ping. Responding with a pong will fail. + self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) + + def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the socket for writing — but not by + # closing it because that would terminate the connection. + self.connection.socket.shutdown(socket.SHUT_WR) + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.pong() + self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received") + def test_unexpected_failure_in_recv_events(self, events_received): + """Unexpected internal error in recv_events() is correctly reported.""" + # Inject a fault in a random call in recv_events(). + # This test is tightly coupled to the implementation. + events_received.side_effect = AssertionError + # Receive a message to trigger the fault. + self.remote_connection.send("😀") + + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + @patch("websockets.protocol.Protocol.send_text") + def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + send_text.side_effect = AssertionError + + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.send("😀") + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 069da784b..825eb8797 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -31,7 +31,7 @@ def setUp(self): def tearDown(self): """ - Ensure the assembler goes back to its default state after each test. + Check that the assembler goes back to its default state after each test. This removes the need for testing various sequences. diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py new file mode 100644 index 000000000..536858149 --- /dev/null +++ b/tests/sync/test_server.py @@ -0,0 +1,389 @@ +import dataclasses +import http +import logging +import socket +import threading +import unittest + +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response +from websockets.sync.compatibility import socket_create_server +from websockets.sync.server import * + +from ..utils import MS, temp_unix_socket_path +from .client import CLIENT_CONTEXT, run_client, run_unix_client +from .server import ( + SERVER_CONTEXT, + EvalShellMixin, + crash, + do_nothing, + eval_shell, + run_server, + run_unix_server, +) + + +class ServerTests(EvalShellMixin, unittest.TestCase): + def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + with run_server() as server: + with run_client(server) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + + def test_connection_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + with run_server(process_request=remove_key_header) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 400", + ): + with run_client(server): + self.fail("did not raise") + + def test_connection_handler_returns(self): + """Connection handler returns.""" + with run_server(do_nothing) as server: + with run_client(server) as client: + with self.assertRaisesRegex( + ConnectionClosedOK, + r"received 1000 \(OK\); then sent 1000 \(OK\)", + ): + client.recv() + + def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + with run_server(crash) as server: + with run_client(server) as client: + with self.assertRaisesRegex( + ConnectionClosedError, + r"received 1011 \(unexpected error\); " + r"then sent 1011 \(unexpected error\)", + ): + client.recv() + + def test_existing_socket(self): + """Server receives connection using a pre-existing socket.""" + with socket_create_server(("localhost", 0)) as sock: + with run_server(sock=sock): + # Build WebSocket URI to ensure we connect to the right socket. + with run_client("ws://{}:{}/".format(*sock.getsockname())) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + + def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + with run_server( + subprotocols=["chat"], + select_subprotocol=select_subprotocol, + ) as server: + with run_client(server, subprotocols=["chat"]) as client: + self.assertEval(client, "ws.select_subprotocol_ran", "True") + self.assertEval(client, "ws.subprotocol", "chat") + + def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 400", + ): + with run_client(server): + self.fail("did not raise") + + def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 500", + ): + with run_client(server): + self.fail("did not raise") + + def test_process_request(self): + """Server runs process_request before processing the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + with run_server(process_request=process_request) as server: + with run_client(server) as client: + self.assertEval(client, "ws.process_request_ran", "True") + + def test_process_request_abort_handshake(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + with run_server(process_request=process_request) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 403", + ): + with run_client(server): + self.fail("did not raise") + + def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError + + with run_server(process_request=process_request) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 500", + ): + with run_client(server): + self.fail("did not raise") + + def test_process_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + with run_server(process_response=process_response) as server: + with run_client(server) as client: + self.assertEval(client, "ws.process_response_ran", "True") + + def test_process_response_override_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + with run_server(process_response=process_response) as server: + with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError + + with run_server(process_response=process_response) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 500", + ): + with run_client(server): + self.fail("did not raise") + + def test_override_server(self): + """Server can override Server header with server_header.""" + with run_server(server_header="Neo") as server: + with run_client(server) as client: + self.assertEval(client, "ws.response.headers['Server']", "Neo") + + def test_remove_server(self): + """Server can remove Server header with server_header.""" + with run_server(server_header=None) as server: + with run_client(server) as client: + self.assertEval(client, "'Server' in ws.response.headers", "False") + + def test_compression_is_enabled(self): + """Server enables compression by default.""" + with run_server() as server: + with run_client(server) as client: + self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + def test_disable_compression(self): + """Server disables compression.""" + with run_server(compression=None) as server: + with run_client(server) as client: + self.assertEval(client, "ws.protocol.extensions", "[]") + + def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + with run_server(create_connection=create_connection) as server: + with run_client(server) as client: + self.assertEval(client, "ws.create_connection_ran", "True") + + def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + with run_server(open_timeout=MS) as server: + with socket.create_connection(server.socket.getsockname()) as sock: + self.assertEqual(sock.recv(4096), b"") + + def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + with run_server() as server: + # Patch handler to record a reference to the thread running it. + server_thread = None + conn_received = threading.Event() + original_handler = server.handler + + def handler(sock, addr): + nonlocal server_thread + server_thread = threading.current_thread() + nonlocal conn_received + conn_received.set() + original_handler(sock, addr) + + server.handler = handler + + with socket.create_connection(server.socket.getsockname()): + # Wait for the server to receive the connection, then close it. + conn_received.wait() + + # Wait for the server thread to terminate. + server_thread.join() + + +class SecureServerTests(EvalShellMixin, unittest.TestCase): + def test_connection(self): + """Server receives secure connection from client.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + self.assertEval(client, "ws.socket.version()[:3]", "TLS") + + def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + with run_server(ssl_context=SERVER_CONTEXT, open_timeout=MS) as server: + with socket.create_connection(server.socket.getsockname()) as sock: + self.assertEqual(sock.recv(4096), b"") + + def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + # Patch handler to record a reference to the thread running it. + server_thread = None + conn_received = threading.Event() + original_handler = server.handler + + def handler(sock, addr): + nonlocal server_thread + server_thread = threading.current_thread() + nonlocal conn_received + conn_received.set() + original_handler(sock, addr) + + server.handler = handler + + with socket.create_connection(server.socket.getsockname()): + # Wait for the server to receive the connection, then close it. + conn_received.wait() + + # Wait for the server thread to terminate. + server_thread.join() + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixServerTests(EvalShellMixin, unittest.TestCase): + def test_connection(self): + """Server receives connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + with run_unix_server(path): + with run_unix_client(path) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixServerTests(EvalShellMixin, unittest.TestCase): + def test_connection(self): + """Server receives secure connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + self.assertEval(client, "ws.socket.version()[:3]", "TLS") + + +class ServerUsageErrorsTests(unittest.TestCase): + def test_unix_without_path_or_sock(self): + """Unix server requires path when sock isn't provided.""" + with self.assertRaisesRegex( + TypeError, + "missing path argument", + ): + unix_serve(eval_shell) + + def test_unix_with_path_and_sock(self): + """Unix server rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaisesRegex( + TypeError, + "path and sock arguments are incompatible", + ): + unix_serve(eval_shell, path="/", sock=sock) + + def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaisesRegex( + TypeError, + "subprotocols must be a list", + ): + serve(eval_shell, subprotocols="chat") + + def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaisesRegex( + ValueError, + "unsupported compression: False", + ): + serve(eval_shell, compression=False) + + +class WebSocketServerTests(unittest.TestCase): + def test_logger(self): + """WebSocketServer accepts a logger argument.""" + logger = logging.getLogger("test") + with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + + def test_fileno(self): + """WebSocketServer provides a fileno attribute.""" + with run_server() as server: + self.assertIsInstance(server.fileno(), int) + + def test_shutdown(self): + """WebSocketServer provides a shutdown method.""" + with run_server() as server: + server.shutdown() + # Check that the server socket is closed. + with self.assertRaises(OSError): + server.socket.accept() From 6df903cee5b91e1c7aa6c3f9716537a55d0b11cd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:27:25 +0100 Subject: [PATCH 1171/1539] Document thread-based implementation. --- README.rst | 48 +++++++++------- docs/faq/asyncio.rst | 12 ++-- docs/faq/client.rst | 5 ++ docs/faq/misc.rst | 9 +-- docs/index.rst | 30 ++++++---- docs/project/changelog.rst | 11 ++++ docs/reference/index.rst | 12 ++++ docs/reference/sync/client.rst | 49 ++++++++++++++++ docs/reference/sync/common.rst | 39 +++++++++++++ docs/reference/sync/server.rst | 60 +++++++++++++++++++ docs/spelling_wordlist.txt | 1 + example/echo.py | 4 +- example/hello.py | 13 +++-- src/websockets/client.py | 2 +- src/websockets/legacy/auth.py | 16 +++--- src/websockets/legacy/client.py | 67 ++++++++++------------ src/websockets/legacy/framing.py | 28 ++++----- src/websockets/legacy/handshake.py | 18 +++--- src/websockets/legacy/http.py | 16 +++--- src/websockets/legacy/protocol.py | 92 +++++++++++++++--------------- src/websockets/legacy/server.py | 74 ++++++++++++------------ src/websockets/protocol.py | 12 ++-- src/websockets/server.py | 6 +- src/websockets/sync/client.py | 4 +- 24 files changed, 405 insertions(+), 223 deletions(-) create mode 100644 docs/reference/sync/client.rst create mode 100644 docs/reference/sync/common.rst create mode 100644 docs/reference/sync/server.rst diff --git a/README.rst b/README.rst index fa1d91061..5ba523e8f 100644 --- a/README.rst +++ b/README.rst @@ -30,47 +30,52 @@ with a focus on correctness, simplicity, robustness, and performance. .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API -Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it -provides an elegant coroutine-based API. +Built on top of ``asyncio``, Python's standard asynchronous I/O framework, the +default implementation provides an elegant coroutine-based API. -`Documentation is available on Read the Docs. `_ +An implementation on top of ``threading`` and a Sans-I/O implementation are also +available. -Here's how a client sends and receives messages: +`Documentation is available on Read the Docs. `_ .. copy-pasted because GitHub doesn't support the include directive +Here's an echo server with the ``asyncio`` API: + .. code:: python #!/usr/bin/env python import asyncio - from websockets import connect + from websockets.server import serve - async def hello(uri): - async with connect(uri) as websocket: - await websocket.send("Hello world!") - await websocket.recv() + async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + async def main(): + async with serve(echo, "localhost", 8765): + await asyncio.Future() # run forever - asyncio.run(hello("ws://localhost:8765")) + asyncio.run(main()) -And here's an echo server: +Here's how a client sends and receives messages with the ``threading`` API: .. code:: python #!/usr/bin/env python import asyncio - from websockets import serve + from websockets.sync.client import connect - async def echo(websocket): - async for message in websocket: - await websocket.send(message) + def hello(): + with connect("ws://localhost:8765") as websocket: + websocket.send("Hello world!") + message = websocket.recv() + print(f"Received: {message}") - async def main(): - async with serve(echo, "localhost", 8765): - await asyncio.Future() # run forever + hello() - asyncio.run(main()) Does that look good? @@ -91,9 +96,8 @@ Why should I use ``websockets``? The development of ``websockets`` is shaped by four principles: -1. **Correctness**: ``websockets`` is heavily tested for compliance - with :rfc:`6455`. Continuous integration fails under 100% branch - coverage. +1. **Correctness**: ``websockets`` is heavily tested for compliance with + :rfc:`6455`. Continuous integration fails under 100% branch coverage. 2. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and ``await ws.send(msg)``. ``websockets`` takes care of managing connections diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index d00cf3f47..e56a42d36 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -46,17 +46,13 @@ See `issue 867`_. Why am I having problems with threads? -------------------------------------- -You shouldn't use threads. Use tasks instead. - -Indeed, when you chose websockets, you chose :mod:`asyncio` as the primary -framework to handle concurrency. This choice is mutually exclusive with -:mod:`threading`. +If you choose websockets' default implementation based on :mod:`asyncio`, then +you shouldn't use threads. Indeed, choosing :mod:`asyncio` to handle concurrency +is mutually exclusive with :mod:`threading`. If you believe that you need to run websockets in a thread and some logic in another thread, you should run that logic in a :class:`~asyncio.Task` instead. - -If you believe that you cannot run that logic in the same event loop because it -will block websockets, :meth:`~asyncio.loop.run_in_executor` may help. +If it blocks the event loop, :meth:`~asyncio.loop.run_in_executor` will help. This question is really about :mod:`asyncio`. Please review the advice about :ref:`asyncio-multithreading` in the Python documentation. diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 73825e480..c4f5a35b9 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -43,6 +43,11 @@ To set other HTTP headers, for example the ``Authorization`` header, use the async with connect(..., extra_headers={"Authorization": ...}) as websocket: ... +In the :mod:`threading` API, this argument is named ``additional_headers``:: + + with connect(..., additional_headers={"Authorization": ...}) as websocket: + ... + How do I close a connection? ---------------------------- diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index e320cb808..4fc271322 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -39,8 +39,8 @@ Why is the default implementation located in ``websockets.legacy``? ................................................................... This is an artifact of websockets' history. For its first eight years, only the -:mod:`asyncio`-based implementation existed. Then, the Sans-I/O implementation -was added. Moving the code in a ``legacy`` submodule eased this refactoring and +:mod:`asyncio` implementation existed. Then, the Sans-I/O implementation was +added. Moving the code in a ``legacy`` submodule eased this refactoring and optimized maintainability. All public APIs were kept at their original locations. ``websockets.legacy`` @@ -61,11 +61,6 @@ you may need to disable: If websockets is still slower than another Python library, please file a bug. -Can I use websockets without ``async`` and ``await``? -..................................................... - -No, there is no convenient way to do this. You should use another library. - Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ diff --git a/docs/index.rst b/docs/index.rst index d64d80ac3..be6a0da05 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,26 +21,36 @@ websockets .. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge :target: https://bestpractices.coreinfrastructure.org/projects/6475 -websockets is a library for building WebSocket_ servers and -clients in Python with a focus on correctness, simplicity, robustness, and -performance. +websockets is a library for building WebSocket_ servers and clients in Python +with a focus on correctness, simplicity, robustness, and performance. .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API -Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, -it provides an elegant coroutine-based API. +It supports several network I/O and control flow paradigms: -Here's how a client sends and receives messages: +1. The default implementation builds upon :mod:`asyncio`, Python's standard + asynchronous I/O framework. It provides an elegant coroutine-based API. It's + ideal for servers that handle many clients concurrently. +2. The :mod:`threading` implementation is a good alternative for clients, + especially if you aren't familiar with :mod:`asyncio`. It may also be used + for servers that don't need to serve many clients. +3. The `Sans-I/O`_ implementation is designed for integrating in third-party + libraries, typically application servers, in addition being used internally + by websockets. -.. literalinclude:: ../example/hello.py +.. _Sans-I/O: https://sans-io.readthedocs.io/ -And here's an echo server: +Here's an echo server with the :mod:`asyncio` API: .. literalinclude:: ../example/echo.py +Here's how a client sends and receives messages with the :mod:`threading` API: + +.. literalinclude:: ../example/hello.py + Don't worry about the opening and closing handshakes, pings and pongs, or any -other behavior described in the specification. websockets takes care of this -under the hood so you can focus on your application! +other behavior described in the WebSocket specification. websockets takes care +of this under the hood so you can focus on your application! Also, websockets provides an interactive client: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 85a4ac92d..95193e780 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -66,6 +66,17 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 10.0 introduces a implementation on top of :mod:`threading`. + :class: important + + It may be more convenient if you don't need to manage many connections and + you're more comfortable with :mod:`threading` than :mod:`asyncio`. + + It is particularly suited to client applications that establish only one + connection. It may be used for servers handling few connections. + + See :func:`~sync.client.connect` and :func:`~sync.server.serve` for details. + * Made it possible to close a server without closing existing connections. * Added :attr:`~server.ServerProtocol.select_subprotocol` to customize diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 3b708ef91..fa8047c3d 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -16,6 +16,18 @@ clients concurrently. asyncio/client asyncio/common +:mod:`threading` +---------------- + +This alternative implementation can be a good choice for clients. + +.. toctree:: + :titlesonly: + + sync/server + sync/client + sync/common + `Sans-I/O`_ ----------- diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst new file mode 100644 index 000000000..6cccd6ec4 --- /dev/null +++ b/docs/reference/sync/client.rst @@ -0,0 +1,49 @@ +Client (:mod:`threading`) +========================= + +.. automodule:: websockets.sync.client + +Opening a connection +-------------------- + +.. autofunction:: connect(uri, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) + +.. autofunction:: unix_connect(path, uri=None, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __iter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst new file mode 100644 index 000000000..8d97ab3c1 --- /dev/null +++ b/docs/reference/sync/common.rst @@ -0,0 +1,39 @@ +Both sides (:mod:`threading`) +============================= + +.. automodule:: websockets.sync.connection + +.. autoclass:: Connection + + .. automethod:: __iter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst new file mode 100644 index 000000000..35c112046 --- /dev/null +++ b/docs/reference/sync/server.rst @@ -0,0 +1,60 @@ +Server (:mod:`threading`) +========================= + +.. automodule:: websockets.sync.server + +Creating a server +----------------- + +.. autofunction:: serve(handler, host=None, port=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) + +.. autofunction:: unix_serve(handler, path=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) + +Running a server +---------------- + +.. autoclass:: WebSocketServer + + .. automethod:: serve_forever + + .. automethod:: shutdown + + .. automethod:: fileno + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __iter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 9cc2182e1..dfa7065e7 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -22,6 +22,7 @@ css ctrl deserialize django +dev Dockerfile dyno formatter diff --git a/example/echo.py b/example/echo.py index 4b673cb17..2e47e52d9 100755 --- a/example/echo.py +++ b/example/echo.py @@ -1,14 +1,14 @@ #!/usr/bin/env python import asyncio -import websockets +from websockets.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): - async with websockets.serve(echo, "localhost", 8765): + async with serve(echo, "localhost", 8765): await asyncio.Future() # run forever asyncio.run(main()) diff --git a/example/hello.py b/example/hello.py index 84f55dc52..a3ce0699e 100755 --- a/example/hello.py +++ b/example/hello.py @@ -1,11 +1,12 @@ #!/usr/bin/env python import asyncio -import websockets +from websockets.sync.client import connect -async def hello(): - async with websockets.connect("ws://localhost:8765") as websocket: - await websocket.send("Hello world!") - await websocket.recv() +def hello(): + with connect("ws://localhost:8765") as websocket: + websocket.send("Hello world!") + message = websocket.recv() + print(f"Received: {message}") -asyncio.run(hello()) +hello() diff --git a/src/websockets/client.py b/src/websockets/client.py index b5f871571..a0d077fc2 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -60,7 +60,7 @@ class ClientProtocol(Protocol): preference. state: initial state of the WebSocket connection. max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. + :obj:`None` disables the limit. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 3511469e6..d3425836e 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -135,20 +135,20 @@ def basic_auth_protocol_factory( ) Args: - realm: indicates the scope of protection. It should contain only ASCII - characters because the encoding of non-ASCII characters is - undefined. Refer to section 2.2 of :rfc:`7235` for details. - credentials: defines hard coded authorized credentials. It can be a + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. + Refer to section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a ``(username, password)`` pair or a list of such pairs. - check_credentials: defines a coroutine that verifies credentials. - This coroutine receives ``username`` and ``password`` arguments + check_credentials: Coroutine that verifies credentials. + It receives ``username`` and ``password`` arguments and returns a :class:`bool`. One of ``credentials`` or ``check_credentials`` must be provided but not both. - create_protocol: factory that creates the protocol. By default, this + create_protocol: Factory that creates the protocol. By default, this is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced by a subclass. Raises: - TypeError: if the ``credentials`` or ``check_credentials`` argument is + TypeError: If the ``credentials`` or ``check_credentials`` argument is wrong. """ diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index f8876f59d..aa71ddb6e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -130,7 +130,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: after this coroutine returns. Raises: - InvalidMessage: if the HTTP message is malformed or isn't an + InvalidMessage: If the HTTP message is malformed or isn't an HTTP/1.1 GET response. """ @@ -273,15 +273,15 @@ async def handshake( Args: wsuri: URI of the WebSocket server. - origin: value of the ``Origin`` header. - available_extensions: list of supported extensions, in order in - which they should be tried. - available_subprotocols: list of supported subprotocols, in order - of decreasing preference. - extra_headers: arbitrary HTTP headers to add to the request. + origin: Value of the ``Origin`` header. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + extra_headers: Arbitrary HTTP headers to add to the handshake request. Raises: - InvalidHandshake: if the handshake fails. + InvalidHandshake: If the handshake fails. """ request_headers = Headers() @@ -376,28 +376,26 @@ class Connect: Args: uri: URI of the WebSocket server. - create_protocol: factory for the :class:`asyncio.Protocol` managing - the connection; defaults to :class:`WebSocketClientProtocol`; may - be set to a wrapper or a subclass to customize connection handling. - logger: logger for this connection; - defaults to ``logging.getLogger("websockets.client")``; - see the :doc:`logging guide <../../topics/logging>` for details. - compression: shortcut that enables the "permessage-deflate" extension - by default; may be set to :obj:`None` to disable compression; - see the :doc:`compression guide <../../topics/compression>` for details. - origin: value of the ``Origin`` header. This is useful when connecting - to a server that validates the ``Origin`` header to defend against - Cross-Site WebSocket Hijacking attacks. - extensions: list of supported extensions, in order in which they - should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + create_protocol: Factory for the :class:`asyncio.Protocol` managing + the connection. It defaults to :class:`WebSocketClientProtocol`. + Set it to a wrapper or a subclass to customize connection handling. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing preference. - extra_headers: arbitrary HTTP headers to add to the request. - user_agent_header: value of the ``User-Agent`` request header; - defaults to ``"Python/x.y.z websockets/X.Y"``; - :obj:`None` removes the header. - open_timeout: timeout for opening the connection in seconds; - :obj:`None` to disable the timeout + extra_headers: Arbitrary HTTP headers to add to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -418,13 +416,10 @@ class Connect: the TCP connection. The host name from ``uri`` is still used in the TLS handshake for secure connections and in the ``Host`` header. - Returns: - WebSocketClientProtocol: WebSocket connection. - Raises: - InvalidURI: if ``uri`` isn't a valid WebSocket URI. - InvalidHandshake: if the opening handshake fails. - ~asyncio.TimeoutError: if the opening handshake times out. + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidHandshake: If the opening handshake fails. + ~asyncio.TimeoutError: If the opening handshake times out. """ @@ -705,7 +700,7 @@ def unix_connect( It's mainly useful for debugging servers listening on Unix sockets. Args: - path: file system path to the Unix socket. + path: File system path to the Unix socket. uri: URI of the WebSocket server; the host is used in the TLS handshake for secure connections and in the ``Host`` header. diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 29864b136..4836eb284 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -51,16 +51,16 @@ async def read( Read a WebSocket frame. Args: - reader: coroutine that reads exactly the requested number of + reader: Coroutine that reads exactly the requested number of bytes, unless the end of file is reached. - mask: whether the frame should be masked i.e. whether the read + mask: Whether the frame should be masked i.e. whether the read happens on the server side. - max_size: maximum payload size in bytes. - extensions: list of extensions, applied in reverse order. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. Raises: - PayloadTooBig: if the frame exceeds ``max_size``. - ProtocolError: if the frame contains incorrect values. + PayloadTooBig: If the frame exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. """ @@ -128,14 +128,14 @@ def write( Write a WebSocket frame. Args: - frame: frame to write. - write: function that writes bytes. - mask: whether the frame should be masked i.e. whether the write + frame: Frame to write. + write: Function that writes bytes. + mask: Whether the frame should be masked i.e. whether the write happens on the client side. - extensions: list of extensions, applied in order. + extensions: List of extensions, applied in order. Raises: - ProtocolError: if the frame contains incorrect values. + ProtocolError: If the frame contains incorrect values. """ # The frame is written in a single call to write in order to prevent @@ -154,11 +154,11 @@ def parse_close(data: bytes) -> Tuple[int, str]: Parse the payload from a close frame. Returns: - Tuple[int, str]: close code and reason. + Close code and reason. Raises: - ProtocolError: if data is ill-formed. - UnicodeDecodeError: if the reason isn't valid UTF-8. + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. """ close = Close.parse(data) diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 569937bb9..ad8faf040 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -21,7 +21,7 @@ def build_request(headers: Headers) -> str: Update request headers passed in argument. Args: - headers: handshake request headers. + headers: Handshake request headers. Returns: str: ``key`` that must be passed to :func:`check_response`. @@ -45,14 +45,14 @@ def check_request(headers: Headers) -> str: the responsibility of the caller. Args: - headers: handshake request headers. + headers: Handshake request headers. Returns: str: ``key`` that must be passed to :func:`build_response`. Raises: - InvalidHandshake: if the handshake request is invalid; - then the server must return 400 Bad Request error. + InvalidHandshake: If the handshake request is invalid. + Then, the server must return a 400 Bad Request error. """ connection: List[ConnectionOption] = sum( @@ -110,8 +110,8 @@ def build_response(headers: Headers, key: str) -> None: Update response headers passed in argument. Args: - headers: handshake response headers. - key: returned by :func:`check_request`. + headers: Handshake response headers. + key: Returned by :func:`check_request`. """ headers["Upgrade"] = "websocket" @@ -128,11 +128,11 @@ def check_response(headers: Headers, key: str) -> None: the caller. Args: - headers: handshake response headers. - key: returned by :func:`build_request`. + headers: Handshake response headers. + key: Returned by :func:`build_request`. Raises: - InvalidHandshake: if the handshake response is invalid. + InvalidHandshake: If the handshake response is invalid. """ connection: List[ConnectionOption] = sum( diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 7cc3db844..2ac7f7092 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -56,12 +56,12 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: body, it may be read from ``stream`` after this coroutine returns. Args: - stream: input to read the request from + stream: Input to read the request from. Raises: - EOFError: if the connection is closed without a full HTTP request - SecurityError: if the request exceeds a security limit - ValueError: if the request isn't well formatted + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -103,12 +103,12 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers body, it may be read from ``stream`` after this coroutine returns. Args: - stream: input to read the response from + stream: Input to read the response from. Raises: - EOFError: if the connection is closed without a full HTTP response - SecurityError: if the response exceeds a security limit - ValueError: if the response isn't well formatted + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + ValueError: If the response isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index d31ec19a8..d1979cd12 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -79,33 +79,32 @@ class WebSocketCommonProtocol(asyncio.Protocol): simplicity. Once the connection is open, a Ping_ frame is sent every ``ping_interval`` - seconds. This serves as a keepalive. It helps keeping the connection - open, especially in the presence of proxies with short timeouts on - inactive connections. Set ``ping_interval`` to :obj:`None` to disable - this behavior. + seconds. This serves as a keepalive. It helps keeping the connection open, + especially in the presence of proxies with short timeouts on inactive + connections. Set ``ping_interval`` to :obj:`None` to disable this behavior. .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 If the corresponding Pong_ frame isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with code - 1011. This ensures that the remote endpoint remains responsive. Set + seconds, the connection is considered unusable and is closed with code 1011. + This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to :obj:`None` to disable this behavior. .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. + The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds for clients and ``4 * close_timeout`` for servers. - See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. - - ``close_timeout`` needs to be a parameter of the protocol because - websockets usually calls :meth:`close` implicitly upon exit: + ``close_timeout`` is a parameter of the protocol because websockets usually + calls :meth:`close` implicitly upon exit: - * on the client side, when :func:`~websockets.client.connect` is used as a + * on the client side, when using :func:`~websockets.client.connect` as a context manager; - * on the server side, when the connection handler terminates; + * on the server side, when the connection handler terminates. To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. @@ -144,21 +143,21 @@ class WebSocketCommonProtocol(asyncio.Protocol): See the discussion of :doc:`memory usage <../../topics/memory>` for details. Args: - logger: logger for this connection; - defaults to ``logging.getLogger("websockets.protocol")``; - see the :doc:`logging guide <../../topics/logging>` for details. - ping_interval: delay between keepalive pings in seconds; - :obj:`None` to disable keepalive pings. - ping_timeout: timeout for keepalive pings in seconds; - :obj:`None` to disable timeouts. - close_timeout: timeout for closing the connection in seconds; - for legacy reasons, the actual timeout is 4 or 5 times larger. - max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. - max_queue: maximum number of incoming messages in receive buffer; - :obj:`None` to disable the limit. - read_limit: high-water mark of read buffer in bytes. - write_limit: high-water mark of write buffer in bytes. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.protocol")``. + See the :doc:`logging guide <../../topics/logging>` for details. + ping_interval: Delay between keepalive pings in seconds. + :obj:`None` disables keepalive pings. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + For legacy reasons, the actual timeout is 4 or 5 times larger. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + max_queue: Maximum number of incoming messages in receive buffer. + :obj:`None` disables the limit. + read_limit: High-water mark of read buffer in bytes. + write_limit: High-water mark of write buffer in bytes. """ @@ -484,10 +483,11 @@ async def __aiter__(self) -> AsyncIterator[Data]: """ Iterate on incoming messages. - The iterator exits normally when the connection is closed with the - close code 1000 (OK) or 1001(going away) or without a close code. It - raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other code. + The iterator exits normally when the connection is closed with the close + code 1000 (OK) or 1001 (going away) or without a close code. + + It raises a :exc:`~websockets.exceptions.ConnectionClosedError` + exception when the connection is closed with any other code. """ try: @@ -501,8 +501,8 @@ async def recv(self) -> Data: Receive the next message. When the connection is closed, :meth:`recv` raises - :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it - raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal connection closure and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol error or a network failure. This is how you detect the end of the @@ -511,8 +511,8 @@ async def recv(self) -> Data: Canceling :meth:`recv` is safe. There's no risk of losing the next message. The next invocation of :meth:`recv` will return it. - This makes it possible to enforce a timeout by wrapping :meth:`recv` - in :func:`~asyncio.wait_for`. + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.wait_for`. Returns: Data: A string (:class:`str`) for a Text_ frame. A bytestring @@ -522,8 +522,8 @@ async def recv(self) -> Data: .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 Raises: - ConnectionClosed: when the connection is closed. - RuntimeError: if two coroutines call :meth:`recv` concurrently. + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` concurrently. """ if self._pop_message_waiter is not None: @@ -626,8 +626,8 @@ async def send( to send. Raises: - ConnectionClosed: when the connection is closed. - TypeError: if ``message`` doesn't have a supported type. + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. """ await self.ensure_open() @@ -841,8 +841,8 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: latency = await pong_waiter Raises: - ConnectionClosed: when the connection is closed. - RuntimeError: if another ping was sent with the same data and + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and the corresponding pong wasn't received yet. """ @@ -881,11 +881,11 @@ async def pong(self, data: Data = b"") -> None: wait, you should close the connection. Args: - data (Data): payload of the pong; a string will be encoded to + data (Data): Payload of the pong. A string will be encoded to UTF-8. Raises: - ConnectionClosed: when the connection is closed. + ConnectionClosed: When the connection is closed. """ await self.ensure_open() @@ -1604,11 +1604,11 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N Args: websockets (Iterable[WebSocketCommonProtocol]): WebSocket connections to which the message will be sent. - message (Data): message to send. + message (Data): Message to send. Raises: - RuntimeError: if a connection is busy sending a fragmented message. - TypeError: if ``message`` doesn't have a supported type. + RuntimeError: If a connection is busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. """ if not isinstance(message, (str, bytes, bytearray, memoryview)): diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 048a270b5..399df85d3 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -657,9 +657,9 @@ class WebSocketServer: when shutting down. Args: - logger: logger for this server; - defaults to ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../../topics/logging>` for details. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. """ @@ -918,43 +918,42 @@ class Serve: await stop Args: - ws_handler: connection handler. It receives the WebSocket connection, + ws_handler: Connection handler. It receives the WebSocket connection, which is a :class:`WebSocketServerProtocol`, in argument. - host: network interfaces the server is bound to; - see :meth:`~asyncio.loop.create_server` for details. - port: TCP port the server listens on; - see :meth:`~asyncio.loop.create_server` for details. - create_protocol: factory for the :class:`asyncio.Protocol` managing - the connection; defaults to :class:`WebSocketServerProtocol`; may - be set to a wrapper or a subclass to customize connection handling. - logger: logger for this server; - defaults to ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../../topics/logging>` for details. - compression: shortcut that enables the "permessage-deflate" extension - by default; may be set to :obj:`None` to disable compression; - see the :doc:`compression guide <../../topics/compression>` for details. - origins: acceptable values of the ``Origin`` header; include - :obj:`None` in the list if the lack of an origin is acceptable. - This is useful for defending against Cross-Site WebSocket - Hijacking attacks. - extensions: list of supported extensions, in order in which they - should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + create_protocol: Factory for the :class:`asyncio.Protocol` managing + the connection. It defaults to :class:`WebSocketServerProtocol`. + Set it to a wrapper or a subclass to customize connection handling. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing preference. extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]): - arbitrary HTTP headers to add to the request; this can be + Arbitrary HTTP headers to add to the response. This can be a :data:`~websockets.datastructures.HeadersLike` or a callable taking the request path and headers in arguments and returning a :data:`~websockets.datastructures.HeadersLike`. - server_header: value of the ``Server`` response header; - defaults to ``"Python/x.y.z websockets/X.Y"``; - :obj:`None` removes the header. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. process_request (Optional[Callable[[str, Headers], \ Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): - intercept HTTP request before the opening handshake; - see :meth:`~WebSocketServerProtocol.process_request` for details. - select_subprotocol: select a subprotocol supported by the client; - see :meth:`~WebSocketServerProtocol.select_subprotocol` for details. + Intercept HTTP request before the opening handshake. + See :meth:`~WebSocketServerProtocol.process_request` for details. + select_subprotocol: Select a subprotocol supported by the client. + See :meth:`~WebSocketServerProtocol.select_subprotocol` for details. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -1137,17 +1136,18 @@ def unix_serve( **kwargs: Any, ) -> Serve: """ - Similar to :func:`serve`, but for listening on Unix sockets. + Start a WebSocket server listening on a Unix socket. - This function builds upon the event - loop's :meth:`~asyncio.loop.create_unix_server` method. + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It is only available on Unix. - It is only available on Unix. + Unrecognized keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_unix_server` method. It's useful for deploying a server behind a reverse proxy such as nginx. Args: - path: file system path to the Unix socket. + path: File system path to the Unix socket. """ return serve(ws_handler, path=path, unix=True, **kwargs) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 7bfa96f8b..3fdd3881c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -75,7 +75,7 @@ class Protocol: side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. state: initial state of the WebSocket connection. max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. + :obj:`None` disables the limit. logger: logger for this connection; depending on ``side``, defaults to ``logging.getLogger("websockets.client")`` or ``logging.getLogger("websockets.server")``; @@ -263,7 +263,8 @@ def receive_eof(self) -> None: After calling this method: - - You must call :meth:`data_to_send` and send this data to the network. + - You must call :meth:`data_to_send` and send this data to the network; + it will return ``[b""]``, signaling the end of the stream, or ``[]``. - You aren't expected to call :meth:`events_received`; it won't return any new events. @@ -481,8 +482,8 @@ def close_expected(self) -> bool: """ Tell if the TCP connection is expected to close soon. - Call this method immediately after any of the ``receive_*()`` or - :meth:`fail` methods. + Call this method immediately after any of the ``receive_*()``, + ``send_close()``, or :meth:`fail` methods. If it returns :obj:`True`, schedule closing the TCP connection after a short timeout if the other side hasn't already closed it. @@ -510,6 +511,9 @@ def parse(self) -> Generator[None, None, None]: :meth:`receive_data` and :meth:`receive_eof` run this generator coroutine until it needs more data or reaches EOF. + :meth:`parse` never raises an exception. Instead, it sets the + :attr:`parser_exc` and yields control. + """ try: while True: diff --git a/src/websockets/server.py b/src/websockets/server.py index 0dd579052..e49f30213 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -58,13 +58,13 @@ class ServerProtocol(Protocol): should be tried. subprotocols: list of supported subprotocols, in order of decreasing preference. - select_subprotocol: callback for selecting a subprotocol among + select_subprotocol: Callback for selecting a subprotocol among those supported by the client and the server. It has the same signature as the :meth:`select_subprotocol` method, including a :class:`ServerProtocol` instance as first argument. state: initial state of the WebSocket connection. max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. + :obj:`None` disables the limit. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. @@ -120,7 +120,7 @@ def accept(self, request: Request) -> Response: request: WebSocket handshake request event received from the client. Returns: - Response: WebSocket handshake response event to send to the client. + WebSocket handshake response event to send to the client. """ try: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e5922582d..ec1167115 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -168,8 +168,8 @@ def connect( from ``uri``. You may call :func:`socket.create_connection` to create a suitable TCP socket. ssl_context: Configuration for enabling TLS on the connection. - server_hostname: Hostname for the TLS handshake. ``server_hostname`` - overrides the hostname from ``uri``. + server_hostname: Host name for the TLS handshake. ``server_hostname`` + overrides the host name from ``uri``. origin: Value of the ``Origin`` header, for servers that require it. extensions: List of supported extensions, in order in which they should be negotiated and run. From 399db68dfa31768569b5257e862cd08987286be1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 07:18:02 +0200 Subject: [PATCH 1172/1539] Ignore compatibility module in coverage measurement. --- setup.cfg | 1 + src/websockets/legacy/compatibility.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 3a8321c50..c2ca2d5dd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,7 @@ branch = True omit = # */websockets matches src/websockets and .tox/**/site-packages/websockets */websockets/__main__.py + */websockets/legacy/compatibility.py tests/maxi_cov.py [coverage:paths] diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 296e9c584..303e203b4 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -14,7 +14,7 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ return {} -else: # pragma: no cover +else: def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ From 45f3b6047a3b09f7685fa1dc585ee34097480e58 Mon Sep 17 00:00:00 2001 From: shafemtol Date: Fri, 24 Feb 2023 18:32:26 +0100 Subject: [PATCH 1173/1539] Set server_hostname automatically when needed --- src/websockets/legacy/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index aa71ddb6e..f2d2c72b7 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -530,6 +530,8 @@ def __init__( else: # If sock is given, host and port shouldn't be specified. host, port = None, None + if kwargs.get("ssl"): + kwargs.setdefault("server_hostname", wsuri.host) # If host and port are given, override values from the URI. host = kwargs.pop("host", host) port = kwargs.pop("port", port) From b0876be3b7770854f1dd933e190205ee884405b1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 07:51:17 +0200 Subject: [PATCH 1174/1539] Simplify tests thanks to previous commit --- tests/legacy/test_client_server.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index d92338585..2da203029 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -428,11 +428,7 @@ def send(self, *args, **kwargs): self.assertFalse(client_socket.used_for_read) self.assertFalse(client_socket.used_for_write) - with self.temp_client( - sock=client_socket, - # "You must set server_hostname when using ssl without a host" - server_hostname="localhost" if self.secure else None, - ): + with self.temp_client(sock=client_socket): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") From 132bf7fa44bc9e0ed3ff29aa6bfd7cb9303e49a6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 07:52:41 +0200 Subject: [PATCH 1175/1539] Add changelog entry for previous commits --- docs/project/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 95193e780..5a3afcc43 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -82,6 +82,12 @@ New features * Added :attr:`~server.ServerProtocol.select_subprotocol` to customize negotiation of subprotocols in the Sans-I/O layer. +Improvements +............ + +* Set ``server_hostname`` automatically on TLS connections when providing a + ``sock`` argument to :func:`~sync.client.connect`. + 10.4 ---- From f0e547965a53582b45df45c5b6202dc3a1284240 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 29 Mar 2023 05:47:41 +0000 Subject: [PATCH 1176/1539] Bump pypa/cibuildwheel from 2.11.1 to 2.12.1 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.11.1 to 2.12.1. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.11.1...v2.12.1) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 0013ad103..90a075843 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -48,7 +48,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.11.1 + uses: pypa/cibuildwheel@v2.12.1 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From 09875a619f53fa7fb566d640943e7ea0c28dffae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 08:02:43 +0200 Subject: [PATCH 1177/1539] Attempt to fix #1317. --- tests/sync/test_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 8824ed894..c900f3b0f 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -111,7 +111,10 @@ def stall_connection(self, request): TimeoutError, "timed out during handshake", ): - with run_client(server, open_timeout=3 * MS): + # While it shouldn't take 50ms to open a connection, this + # test becomes flaky in CI when setting a smaller timeout, + # even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR. + with run_client(server, open_timeout=5 * MS): self.fail("did not raise") finally: gate.set() From 25a5252c385d867add96b9bc5df2d537b49d636a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 08:13:20 +0200 Subject: [PATCH 1178/1539] Skip test that fails randomly on PyPy. Fix #1314. --- tests/sync/test_connection.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 94850affe..f26ec3f95 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -1,5 +1,6 @@ import contextlib import logging +import platform import socket import sys import threading @@ -465,6 +466,10 @@ def test_close_idempotency(self): self.connection.close() self.assertNoFrameSent() + @unittest.skipIf( + platform.python_implementation() == "PyPy", + "this test fails randomly due to a bug in PyPy", # see #1314 for details + ) def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" From ce06dd6e15159dbfb83e33c16ec9afcb05ff1ec7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 08:49:19 +0200 Subject: [PATCH 1179/1539] Rewrite interactive client with synchronous API. Fix #1312. --- src/websockets/__main__.py | 139 +++++++++---------------------------- 1 file changed, 34 insertions(+), 105 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index c562d21b5..a7dd1aaef 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -1,16 +1,18 @@ from __future__ import annotations import argparse -import asyncio import os import signal import sys import threading -from typing import Any, Set -from .exceptions import ConnectionClosed -from .frames import Close -from .legacy.client import connect + +try: + import readline # noqa +except ImportError: # Windows has no `readline` normally + pass + +from .sync.client import ClientConnection, connect from .version import version as websockets_version @@ -46,21 +48,6 @@ def win_enable_vt100() -> None: raise RuntimeError("unable to set console mode") -def exit_from_event_loop_thread( - loop: asyncio.AbstractEventLoop, - stop: asyncio.Future[None], -) -> None: - loop.stop() - if not stop.done(): - # When exiting the thread that runs the event loop, raise - # KeyboardInterrupt in the main thread to exit the program. - if sys.platform == "win32": - ctrl_c = signal.CTRL_C_EVENT - else: - ctrl_c = signal.SIGINT - os.kill(os.getpid(), ctrl_c) - - def print_during_input(string: str) -> None: sys.stdout.write( # Save cursor position @@ -93,63 +80,20 @@ def print_over_input(string: str) -> None: sys.stdout.flush() -async def run_client( - uri: str, - loop: asyncio.AbstractEventLoop, - inputs: asyncio.Queue[str], - stop: asyncio.Future[None], -) -> None: - try: - websocket = await connect(uri) - except Exception as exc: - print_over_input(f"Failed to connect to {uri}: {exc}.") - exit_from_event_loop_thread(loop, stop) - return - else: - print_during_input(f"Connected to {uri}.") - - try: - while True: - incoming: asyncio.Future[Any] = asyncio.create_task(websocket.recv()) - outgoing: asyncio.Future[Any] = asyncio.create_task(inputs.get()) - done: Set[asyncio.Future[Any]] - pending: Set[asyncio.Future[Any]] - done, pending = await asyncio.wait( - [incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED - ) - - # Cancel pending tasks to avoid leaking them. - if incoming in pending: - incoming.cancel() - if outgoing in pending: - outgoing.cancel() - - if incoming in done: - try: - message = incoming.result() - except ConnectionClosed: - break - else: - if isinstance(message, str): - print_during_input("< " + message) - else: - print_during_input("< (binary) " + message.hex()) - - if outgoing in done: - message = outgoing.result() - await websocket.send(message) - - if stop in done: - break - - finally: - await websocket.close() - assert websocket.close_code is not None and websocket.close_reason is not None - close_status = Close(websocket.close_code, websocket.close_reason) - - print_over_input(f"Connection closed: {close_status}.") - - exit_from_event_loop_thread(loop, stop) +def print_incoming_messages(websocket: ClientConnection, stop: threading.Event) -> None: + for message in websocket: + if isinstance(message, str): + print_during_input("< " + message) + else: + print_during_input("< (binary) " + message.hex()) + if not stop.is_set(): + # When the server closes the connection, raise KeyboardInterrupt + # in the main thread to exit the program. + if sys.platform == "win32": + ctrl_c = signal.CTRL_C_EVENT + else: + ctrl_c = signal.SIGINT + os.kill(os.getpid(), ctrl_c) def main() -> None: @@ -184,29 +128,17 @@ def main() -> None: sys.stderr.flush() try: - import readline # noqa - except ImportError: # Windows has no `readline` normally - pass - - # Create an event loop that will run in a background thread. - loop = asyncio.new_event_loop() - - # Due to zealous removal of the loop parameter in the Queue constructor, - # we need a factory coroutine to run in the freshly created event loop. - async def queue_factory() -> asyncio.Queue[str]: - return asyncio.Queue() - - # Create a queue of user inputs. There's no need to limit its size. - inputs: asyncio.Queue[str] = loop.run_until_complete(queue_factory()) - - # Create a stop condition when receiving SIGINT or SIGTERM. - stop: asyncio.Future[None] = loop.create_future() + websocket = connect(args.uri) + except Exception as exc: + print(f"Failed to connect to {args.uri}: {exc}.") + sys.exit(1) + else: + print(f"Connected to {args.uri}.") - # Schedule the task that will manage the connection. - loop.create_task(run_client(args.uri, loop, inputs, stop)) + stop = threading.Event() - # Start the event loop in a background thread. - thread = threading.Thread(target=loop.run_forever) + # Start the thread that reads messages from the connection. + thread = threading.Thread(target=print_incoming_messages, args=(websocket, stop)) thread.start() # Read from stdin in the main thread in order to receive signals. @@ -214,17 +146,14 @@ async def queue_factory() -> asyncio.Queue[str]: while True: # Since there's no size limit, put_nowait is identical to put. message = input("> ") - loop.call_soon_threadsafe(inputs.put_nowait, message) + websocket.send(message) except (KeyboardInterrupt, EOFError): # ^C, ^D - loop.call_soon_threadsafe(stop.set_result, None) + stop.set() + websocket.close() + print_over_input("Connection closed.") - # Wait for the event loop to terminate. thread.join() - # For reasons unclear, even though the loop is closed in the thread, - # it still thinks it's running here. - loop.close() - if __name__ == "__main__": main() From 2fcc4837faca2a05f395805455070dc0347e9ab1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 08:29:51 +0200 Subject: [PATCH 1180/1539] Improve error handling in broadcast(). Fix #1319. --- docs/project/changelog.rst | 2 + docs/topics/logging.rst | 5 ++ src/websockets/legacy/protocol.py | 74 ++++++++++++++++++++++-------- tests/legacy/test_client_server.py | 4 +- tests/legacy/test_protocol.py | 60 ++++++++++++++++++++++-- 5 files changed, 119 insertions(+), 26 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5a3afcc43..5a0ea423d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -85,6 +85,8 @@ New features Improvements ............ +* Improved error handling in :func:`~websockets.broadcast`. + * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 294a6cda8..e7abd96ce 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -213,6 +213,11 @@ Here's what websockets logs at each level. * Exceptions raised by connection handler coroutines in servers * Exceptions resulting from bugs in websockets +``WARNING`` +........... + +* Failures in :func:`~websockets.broadcast` + ``INFO`` ........ diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index d1979cd12..67bca0ef4 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -7,6 +7,7 @@ import random import ssl import struct +import sys import time import uuid import warnings @@ -1573,13 +1574,17 @@ def eof_received(self) -> None: self.reader.feed_eof() -def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> None: +def broadcast( + websockets: Iterable[WebSocketCommonProtocol], + message: Data, + raise_exceptions: bool = False, +) -> None: """ Broadcast a message to several WebSocket connections. - A string (:class:`str`) is sent as a Text_ frame. A bytestring or - bytes-like object (:class:`bytes`, :class:`bytearray`, or - :class:`memoryview`) is sent as a Binary_ frame. + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent + as a Binary_ frame. .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 @@ -1587,33 +1592,42 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N :func:`broadcast` pushes the message synchronously to all connections even if their write buffers are overflowing. There's no backpressure. - :func:`broadcast` skips silently connections that aren't open in order to - avoid errors on connections where the closing handshake is in progress. - - If you broadcast messages faster than a connection can handle them, - messages will pile up in its write buffer until the connection times out. - Keep low values for ``ping_interval`` and ``ping_timeout`` to prevent - excessive memory usage by slow connections when you use :func:`broadcast`. + If you broadcast messages faster than a connection can handle them, messages + will pile up in its write buffer until the connection times out. Keep + ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage + from slow connections. Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, :func:`broadcast` doesn't support sending fragmented messages. Indeed, - fragmentation is useful for sending large messages without buffering - them in memory, while :func:`broadcast` buffers one copy per connection - as fast as possible. + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. + + :func:`broadcast` skips connections that aren't open in order to avoid + errors on connections where the closing handshake is in progress. + + :func:`broadcast` ignores failures to write the message on some connections. + It continues writing to other connections. On Python 3.11 and above, you + may set ``raise_exceptions`` to :obj:`True` to record failures and raise all + exceptions in a :pep:`654` :exc:`ExceptionGroup`. Args: - websockets (Iterable[WebSocketCommonProtocol]): WebSocket connections - to which the message will be sent. - message (Data): Message to send. + websockets: WebSocket connections to which the message will be sent. + message: Message to send. + raise_exceptions: Whether to raise an exception in case of failures. Raises: - RuntimeError: If a connection is busy sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ if not isinstance(message, (str, bytes, bytearray, memoryview)): raise TypeError("data must be str or bytes-like") + if raise_exceptions: + if sys.version_info[:2] < (3, 11): # pragma: no cover + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions = [] + opcode, data = prepare_data(message) for websocket in websockets: @@ -1621,6 +1635,26 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N continue if websocket._fragmented_message_waiter is not None: - raise RuntimeError("busy sending a fragmented message") + if raise_exceptions: + exception = RuntimeError("sending a fragmented message") + exceptions.append(exception) + else: + websocket.logger.warning( + "skipped broadcast: sending a fragmented message", + ) + + try: + websocket.write_frame_sync(True, opcode, data) + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + websocket.logger.warning( + "skipped broadcast: failed to write message", + exc_info=True, + ) - websocket.write_frame_sync(True, opcode, data) + if raise_exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2da203029..b8fd259c9 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1553,11 +1553,11 @@ async def run_client(): await ws.recv() else: # Exit block with an exception. - raise Exception("BOOM!") + raise Exception("BOOM") pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: - with self.assertRaisesRegex(Exception, "BOOM!"): + with self.assertRaisesRegex(Exception, "BOOM"): self.loop.run_until_complete(run_client()) # Iteration 1 diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 328bc80a2..ab8155b9b 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1,5 +1,7 @@ import asyncio import contextlib +import logging +import sys import unittest import unittest.mock import warnings @@ -1468,26 +1470,76 @@ def test_broadcast_two_clients(self): def test_broadcast_skips_closed_connection(self): self.close_connection() - broadcast([self.protocol], "café") + with self.assertNoLogs(): + broadcast([self.protocol], "café") self.assertNoFrameSent() def test_broadcast_skips_closing_connection(self): close_task = self.half_close_connection_local() - broadcast([self.protocol], "café") + with self.assertNoLogs(): + broadcast([self.protocol], "café") self.assertNoFrameSent() self.loop.run_until_complete(close_task) # cleanup - def test_broadcast_within_fragmented_text(self): + def test_broadcast_skips_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) - with self.assertRaises(RuntimeError): + with self.assertLogs("websockets", logging.WARNING) as logs: + broadcast([self.protocol], "café") + + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: sending a fragmented message"], + ) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_reports_connection_sending_fragmented_text(self): + self.make_drain_slow() + self.loop.create_task(self.protocol.send(["ca", "fé"])) + self.run_loop_once() + self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.protocol], "café", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + self.assertEqual( + str(raised.exception.exceptions[0]), "sending a fragmented message" + ) + + def test_broadcast_skips_connection_failing_to_send(self): + # Configure mock to raise an exception when writing to the network. + self.protocol.transport.write.side_effect = RuntimeError + + with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: failed to write message"], + ) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_reports_connection_failing_to_send(self): + # Configure mock to raise an exception when writing to the network. + self.protocol.transport.write.side_effect = RuntimeError("BOOM") + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.protocol], "café", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + self.assertEqual(str(raised.exception.exceptions[0]), "failed to write message") + self.assertEqual(str(raised.exception.exceptions[0].__cause__), "BOOM") + class ServerTests(CommonTests, AsyncioTestCase): def setUp(self): From f269e1e9704ff2776dc5ed54d11e120aa06d62e3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 10:11:26 +0200 Subject: [PATCH 1181/1539] Document that connect() can raise OSError. Fix #1265. --- src/websockets/legacy/client.py | 1 + src/websockets/sync/client.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index f2d2c72b7..b79bbab8c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -418,6 +418,7 @@ class Connect: Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. ~asyncio.TimeoutError: If the opening handshake times out. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index ec1167115..087ff5f56 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -198,6 +198,7 @@ def connect( Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. TimeoutError: If the opening handshake times out. From 0a58739a54c9de69b23f57e1183e9864e02d5513 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 14:19:02 +0200 Subject: [PATCH 1182/1539] It's an HTTP, but a URI. --- src/websockets/exceptions.py | 2 +- src/websockets/sync/server.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 1f4b9265c..22a3b583f 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -374,7 +374,7 @@ class InvalidState(WebSocketException, AssertionError): class InvalidURI(WebSocketException): """ - Raised when connecting to an URI that isn't a valid WebSocket URI. + Raised when connecting to a URI that isn't a valid WebSocket URI. """ diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index a53ae2b25..9284c6188 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -352,11 +352,11 @@ def handler(websocket): ` method. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force a HTTP 101 Continue response, + continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. process_response: Intercept the response during the opening handshake. Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force a HTTP 101 Continue response, + continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to From 8a6665892051b26d65344a573c1aea89f9c0e64c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 14:41:58 +0200 Subject: [PATCH 1183/1539] Pluralize API consistently. --- docs/project/changelog.rst | 2 +- docs/reference/index.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5a0ea423d..fc723a626 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -22,7 +22,7 @@ When a release contains backwards-incompatible API changes, the major version is increased, else the minor version is increased. Patch versions are only for fixing regressions shortly after a release. -Only documented API are public. Undocumented, private API may change without +Only documented APIs are public. Undocumented, private APIs may change without notice. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index fa8047c3d..cc4658dab 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -57,7 +57,7 @@ extensions. Shared ------ -These low-level API are shared by all implementations. +These low-level APIs are shared by all implementations. .. toctree:: :titlesonly: @@ -69,7 +69,7 @@ These low-level API are shared by all implementations. API stability ------------- -Public API documented in this API reference are subject to the +Public APIs documented in this API reference are subject to the :ref:`backwards-compatibility policy `. Anything that isn't listed in the API reference is a private API. There's no From d0a292c5a25fcf4c27deb031e1b7e814494aac0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 17:58:29 +0200 Subject: [PATCH 1184/1539] Fix typo in docs. --- src/websockets/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index e49f30213..dcdf3b71e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -319,8 +319,8 @@ def process_extensions( Accept or reject each extension proposed in the client request. Negotiate parameters for accepted extensions. - :rfc:`6455` leaves the rules up to the specification of each - :extension. + Per :rfc:`6455`, negotiation rules are defined by the specification of + each extension. To provide this level of flexibility, for each extension proposed by the client, we check for a match with each extension available in the From 1c9bdceccd9023e88e1cab5a1044f42c341930ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 18:05:34 +0200 Subject: [PATCH 1185/1539] Add open_timeout to serve(). --- docs/project/changelog.rst | 8 ++++++++ src/websockets/legacy/server.py | 21 ++++++++++++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index fc723a626..68fb2709f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -63,6 +63,12 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. +.. admonition:: :func:`~server.serve` times out on the opening handshake after 10 seconds by default. + :class: note + + You can adjust the timeout with the ``open_timeout`` parameter. Set it to + :obj:`None` to disable the timeout entirely. + New features ............ @@ -77,6 +83,8 @@ New features See :func:`~sync.client.connect` and :func:`~sync.server.serve` for details. +* Added ``open_timeout`` to :func:`~server.serve`. + * Made it possible to close a server without closing existing connections. * Added :attr:`~server.ServerProtocol.select_subprotocol` to customize diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 399df85d3..3be86e45c 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -115,6 +115,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + open_timeout: Optional[float] = 10, **kwargs: Any, ) -> None: if logger is None: @@ -136,6 +137,7 @@ def __init__( self.server_header = server_header self._process_request = process_request self._select_subprotocol = select_subprotocol + self.open_timeout = open_timeout def connection_made(self, transport: asyncio.BaseTransport) -> None: """ @@ -161,16 +163,21 @@ async def handler(self) -> None: """ try: try: - await self.handshake( - origins=self.origins, - available_extensions=self.available_extensions, - available_subprotocols=self.available_subprotocols, - extra_headers=self.extra_headers, + await asyncio.wait_for( + self.handshake( + origins=self.origins, + available_extensions=self.available_extensions, + available_subprotocols=self.available_subprotocols, + extra_headers=self.extra_headers, + ), + self.open_timeout, ) # Remove this branch when dropping support for Python < 3.8 # because CancelledError no longer inherits Exception. except asyncio.CancelledError: # pragma: no cover raise + except asyncio.TimeoutError: # pragma: no cover + raise except ConnectionError: raise except Exception as exc: @@ -954,6 +961,8 @@ class Serve: See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. See :meth:`~WebSocketServerProtocol.select_subprotocol` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -997,6 +1006,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, @@ -1059,6 +1069,7 @@ def __init__( host=host, port=port, secure=secure, + open_timeout=open_timeout, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, From db6b1a892ce8508cf362d541b45430966ab2bc02 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 18:30:51 +0200 Subject: [PATCH 1186/1539] Add feature support matrices. Fix #1313. --- docs/reference/features.rst | 182 +++++++++++++++++++++++++++++++++ docs/reference/index.rst | 21 ++-- docs/reference/limitations.rst | 39 ------- 3 files changed, 193 insertions(+), 49 deletions(-) create mode 100644 docs/reference/features.rst delete mode 100644 docs/reference/limitations.rst diff --git a/docs/reference/features.rst b/docs/reference/features.rst new file mode 100644 index 000000000..f4b592a4b --- /dev/null +++ b/docs/reference/features.rst @@ -0,0 +1,182 @@ +Features +======== + +.. currentmodule:: websockets + +Feature support matrices summarize which implementations support which features. + +.. raw:: html + + + +.. |aio| replace:: :mod:`asyncio` +.. |sync| replace:: :mod:`threading` +.. |sans| replace:: `Sans-I/O`_ +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +Both sides +---------- + +.. table:: + :class: support-matrix-table + + +------------------------------------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | + +====================================+========+========+========+ + | Perfom the opening handshake | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Send a message | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Receive a message | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Iterate over received messages | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+ + | Send a fragmented message | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Receive a fragmented message after | ✅ | ✅ | ❌ | + | reassembly | | | | + +------------------------------------+--------+--------+--------+ + | Receive a fragmented message frame | ❌ | ✅ | ✅ | + | by frame (`#479`_) | | | | + +------------------------------------+--------+--------+--------+ + | Send a ping | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Respond to pings automatically | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Send a pong | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Perfom the closing handshake | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Report close codes and reasons | ❌ | ✅ | ✅ | + | from both sides | | | | + +------------------------------------+--------+--------+--------+ + | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Tune memory usage for compression | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Negotiate extensions | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Implement custom extensions | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Negotiate a subprotocol | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Enforce security limits | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Log events | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Enforce opening timeout | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Enforce closing timeout | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Keepalive | ✅ | ❌ | — | + +------------------------------------+--------+--------+--------+ + | Heartbeat | ✅ | ❌ | — | + +------------------------------------+--------+--------+--------+ + +.. _#479: https://github.com/aaugustin/websockets/issues/479 + +Server +------ + +.. table:: + :class: support-matrix-table + + +------------------------------------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | + +====================================+========+========+========+ + | Listen on a TCP socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Listen on a Unix socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Listen using a preexisting socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Close server on context exit | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Close connection on handler exit | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Shut down server gracefully | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Check ``Origin`` header | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Customize subprotocol selection | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Configure ``Server`` header | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Alter opening handshake request | ❌ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Alter opening handshake response | ❌ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ❌ | ❌ | + +------------------------------------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | + +------------------------------------+--------+--------+--------+ + | Force HTTP response | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + +Client +------ + +.. table:: + :class: support-matrix-table + + +------------------------------------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | + +====================================+========+========+========+ + | Connect to a TCP socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Connect to a Unix socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Connect using a preexisting socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Close connection on context exit | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Reconnect automatically | ✅ | ❌ | — | + +------------------------------------+--------+--------+--------+ + | Configure ``Origin`` header | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Alter opening handshake request | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | + | (`#784`_) | | | | + +------------------------------------+--------+--------+--------+ + | Follow HTTP redirects | ✅ | ❌ | — | + +------------------------------------+--------+--------+--------+ + | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | + +------------------------------------+--------+--------+--------+ + | Connect via a SOCKS5 proxy | ❌ | ❌ | — | + | (`#475`_) | | | | + +------------------------------------+--------+--------+--------+ + +.. _#364: https://github.com/aaugustin/websockets/issues/364 +.. _#475: https://github.com/aaugustin/websockets/issues/475 +.. _#784: https://github.com/aaugustin/websockets/issues/784 + +Known limitations +----------------- + +There is no way to control compression of outgoing frames on a per-frame basis +(`#538`_). If compression is enabled, all frames are compressed. + +.. _#538: https://github.com/aaugustin/websockets/issues/538 + +The client API doesn't attempt to guarantee that there is no more than one +connection to a given IP address in a CONNECTING state. This behavior is +`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right +layer for enforcing this constraint. It's the caller's responsibility. + +.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 diff --git a/docs/reference/index.rst b/docs/reference/index.rst index cc4658dab..9364bc887 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -3,6 +3,17 @@ API reference .. currentmodule:: websockets +Features +-------- + +Check which implementations support which features and known limitations. + +.. toctree:: + :titlesonly: + + features + + :mod:`asyncio` -------------- @@ -93,13 +104,3 @@ For convenience, many public APIs can be imported directly from the If you're using such tools, stick to the full import paths, as explained in this FAQ: :ref:`real-import-paths` - -Limitations ------------ - -There are a few known limitations in the current API. - -.. toctree:: - :titlesonly: - - limitations diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst deleted file mode 100644 index 696aa38fd..000000000 --- a/docs/reference/limitations.rst +++ /dev/null @@ -1,39 +0,0 @@ -Limitations -=========== - -.. currentmodule:: websockets - -Client ------- - -The client doesn't attempt to guarantee that there is no more than one -connection to a given IP address in a CONNECTING state. This behavior is -`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the -right layer for enforcing this constraint. It's the caller's responsibility. - -.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 - -The client doesn't support connecting through an HTTP proxy (`issue 364`_) or a -SOCKS proxy (`issue 475`_). - -.. _issue 364: https://github.com/aaugustin/websockets/issues/364 -.. _issue 475: https://github.com/aaugustin/websockets/issues/475 - -Server ------- - -At this time, there are no known limitations affecting only the server. - -Both sides ----------- - -There is no way to control compression of outgoing frames on a per-frame basis -(`issue 538`_). If compression is enabled, all frames are compressed. - -.. _issue 538: https://github.com/aaugustin/websockets/issues/538 - -There is no way to receive each fragment of a fragmented messages as it -arrives (`issue 479`_). websockets always reassembles fragmented messages -before returning them. - -.. _issue 479: https://github.com/aaugustin/websockets/issues/479 From 6dcac04c10c61dde9aa5be0374965c31f16e77f4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 18:32:46 +0200 Subject: [PATCH 1187/1539] Hide "Both sides" API references from navigation. The only reason for their existence is the need to hyperlink to functions that may be used on both sides. --- docs/reference/asyncio/common.rst | 2 ++ docs/reference/index.rst | 3 --- docs/reference/sansio/common.rst | 2 ++ docs/reference/sync/common.rst | 2 ++ 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index ee8dc54ac..dc7a54ee1 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,3 +1,5 @@ +:orphan: + Both sides (:mod:`asyncio`) =========================== diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 9364bc887..2a9556dd9 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -25,7 +25,6 @@ clients concurrently. asyncio/server asyncio/client - asyncio/common :mod:`threading` ---------------- @@ -37,7 +36,6 @@ This alternative implementation can be a good choice for clients. sync/server sync/client - sync/common `Sans-I/O`_ ----------- @@ -52,7 +50,6 @@ application servers. sansio/server sansio/client - sansio/common Extensions ---------- diff --git a/docs/reference/sansio/common.rst b/docs/reference/sansio/common.rst index 2678c1361..cd1ef3c63 100644 --- a/docs/reference/sansio/common.rst +++ b/docs/reference/sansio/common.rst @@ -1,3 +1,5 @@ +:orphan: + Both sides (`Sans-I/O`_) ========================= diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 8d97ab3c1..3dc6d4a50 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -1,3 +1,5 @@ +:orphan: + Both sides (:mod:`threading`) ============================= From d3d4cf4a2baf7362b004e9bffe5344287dbb9a51 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 18:43:44 +0200 Subject: [PATCH 1188/1539] Build architecture independent wheels. Fix #1300. --- .github/workflows/wheels.yml | 18 ++++++++++++++---- setup.py | 7 +++++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 90a075843..68bfbdef4 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -8,8 +8,10 @@ on: jobs: sdist: - name: Build source distribution + name: Build source distribution and architecture-independent wheel runs-on: ubuntu-latest + env: + BUILD_EXTENSION: no steps: - name: Check out repository uses: actions/checkout@v3 @@ -23,10 +25,20 @@ jobs: uses: actions/upload-artifact@v3 with: path: dist/*.tar.gz + - name: Install wheel + run: pip install wheel + - name: Build wheel + run: python setup.py bdist_wheel + - name: Save wheel + uses: actions/upload-artifact@v3 + with: + path: dist/*.whl wheels: - name: Build wheels on ${{ matrix.os }} + name: Build architecture-specific wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} + env: + BUILD_EXTENSION: yes strategy: matrix: os: @@ -36,8 +48,6 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v3 - - name: Make extension build mandatory - run: touch .cibuildwheel - name: Install Python 3.x uses: actions/setup-python@v4 with: diff --git a/setup.py b/setup.py index 5ed472503..c8e01f24b 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import os import pathlib import re @@ -27,11 +28,13 @@ "websockets/sync", ] -ext_modules = [ +# Set BUILD_EXTENSION to yes or no to force building or not building the +# speedups extension. If unset, the extension is built only if possible. +ext_modules = [] if os.environ.get("BUILD_EXTENSION") == "no" else [ setuptools.Extension( "websockets.speedups", sources=["src/websockets/speedups.c"], - optional=not (root_dir / ".cibuildwheel").exists(), + optional=os.environ.get("BUILD_EXTENSION") != "yes", ) ] From d81a4cebe033e0c8933fff12581444234e8bb4db Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 21:48:56 +0200 Subject: [PATCH 1189/1539] Move build configuration to pyproject.toml. Keep only dynamic configuration in setup.py. Remove configuration that setuptools no longer requires. --- pyproject.toml | 35 +++++++++++++++++++++++++++ setup.cfg | 11 --------- setup.py | 64 ++++++++++++-------------------------------------- 3 files changed, 50 insertions(+), 60 deletions(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..87538c5e0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "websockets" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +requires-python = ">=3.7" +license = { text = "BSD-3-Clause" } +authors = [ + { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, +] +keywords = ["WebSocket"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dynamic = ["version", "readme"] + +[project.urls] +homepage = "https://github.com/aaugustin/websockets" +changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" +documentation = "https://websockets.readthedocs.io/" +funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" +tracker = "https://github.com/aaugustin/websockets/issues" diff --git a/setup.cfg b/setup.cfg index c2ca2d5dd..28ea12c12 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,14 +1,3 @@ -[bdist_wheel] -python-tag = py37.py38.py39.py310.py311 - -[metadata] -license_files = LICENSE -project_urls = - Changelog = https://websockets.readthedocs.io/en/stable/project/changelog.html - Documentation = https://websockets.readthedocs.io/ - Funding = https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme - Tracker = https://github.com/aaugustin/websockets/issues - [flake8] ignore = E203,E731,F403,F405,W503 max-line-length = 88 diff --git a/setup.py b/setup.py index c8e01f24b..ae0aaa65d 100644 --- a/setup.py +++ b/setup.py @@ -7,66 +7,32 @@ root_dir = pathlib.Path(__file__).parent -description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" - -long_description = (root_dir / "README.rst").read_text(encoding="utf-8") +exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) -# PyPI disables the "raw" directive. +# PyPI disables the "raw" directive. Remove this section of the README. long_description = re.sub( r"^\.\. raw:: html.*?^(?=\w)", "", - long_description, + (root_dir / "README.rst").read_text(encoding="utf-8"), flags=re.DOTALL | re.MULTILINE, ) -exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) - -packages = [ - "websockets", - "websockets/extensions", - "websockets/legacy", - "websockets/sync", -] - # Set BUILD_EXTENSION to yes or no to force building or not building the # speedups extension. If unset, the extension is built only if possible. -ext_modules = [] if os.environ.get("BUILD_EXTENSION") == "no" else [ - setuptools.Extension( - "websockets.speedups", - sources=["src/websockets/speedups.c"], - optional=os.environ.get("BUILD_EXTENSION") != "yes", - ) -] - +if os.environ.get("BUILD_EXTENSION") == "no": + ext_modules = [] +else: + ext_modules = [ + setuptools.Extension( + "websockets.speedups", + sources=["src/websockets/speedups.c"], + optional=os.environ.get("BUILD_EXTENSION") != "yes", + ) + ] + +# Static values are declared in pyproject.toml. setuptools.setup( - name="websockets", version=version, - description=description, long_description=long_description, - url="https://github.com/aaugustin/websockets", - author="Aymeric Augustin", - author_email="aymeric.augustin@m4x.org", - license="BSD-3-Clause", - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - ], - package_dir={"": "src"}, - package_data={"websockets": ["py.typed"]}, - packages=packages, ext_modules=ext_modules, - include_package_data=True, - zip_safe=False, - python_requires=">=3.7", - test_loader="unittest:TestLoader", ) From 0924e9cdd9e74065aa23d160414187f202771d61 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 22:14:37 +0200 Subject: [PATCH 1190/1539] Move coverage configuration to pyproject.toml. --- pyproject.toml | 25 +++++++++++++++++++++++++ setup.cfg | 22 ---------------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 87538c5e0..fe0602b98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,3 +33,28 @@ changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" documentation = "https://websockets.readthedocs.io/" funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" tracker = "https://github.com/aaugustin/websockets/issues" + +[tool.coverage.run] +branch = true +omit = [ + # */websockets matches src/websockets and .tox/**/site-packages/websockets + "*/websockets/__main__.py", + "*/websockets/legacy/compatibility.py", + "tests/maxi_cov.py", +] + +[tool.coverage.paths] +source = [ + "src/websockets", + ".tox/*/lib/python*/site-packages/websockets", +] + +[tool.coverage.report] +exclude_lines = [ + "if self.debug:", + "pragma: no cover", + "raise AssertionError", + "raise NotImplementedError", + "self.fail\\(\".*\"\\)", + "@unittest.skip", +] diff --git a/setup.cfg b/setup.cfg index 28ea12c12..ca211fc63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,25 +6,3 @@ max-line-length = 88 profile = black combine_as_imports = True lines_after_imports = 2 - -[coverage:run] -branch = True -omit = - # */websockets matches src/websockets and .tox/**/site-packages/websockets - */websockets/__main__.py - */websockets/legacy/compatibility.py - tests/maxi_cov.py - -[coverage:paths] -source = - src/websockets - .tox/*/lib/python*/site-packages/websockets - -[coverage:report] -exclude_lines = - if self.debug: - pragma: no cover - raise AssertionError - raise NotImplementedError - self.fail\(".*"\) - @unittest.skip From af9173790a620c5f3687ea61f17e58c7466fc020 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 22:33:30 +0200 Subject: [PATCH 1191/1539] Replace flake8 and isort by ruff. Configure in pyproject.toml instead of setup.cfg. Fix #1306. --- .github/workflows/tests.yml | 4 +--- Makefile | 11 ++++++----- pyproject.toml | 16 ++++++++++++++++ setup.cfg | 8 -------- src/websockets/legacy/server.py | 8 ++++---- tox.ini | 13 ++++--------- 6 files changed, 31 insertions(+), 29 deletions(-) delete mode 100644 setup.cfg diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 34f8d8c5c..163d1bbae 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,9 +44,7 @@ jobs: - name: Check code formatting run: tox -e black - name: Check code style - run: tox -e flake8 - - name: Check imports ordering - run: tox -e isort + run: tox -e ruff - name: Check types statically run: tox -e mypy diff --git a/Makefile b/Makefile index ac5d6a4aa..cf3b53393 100644 --- a/Makefile +++ b/Makefile @@ -1,18 +1,19 @@ -.PHONY: default style test coverage maxi_cov build clean +.PHONY: default style types tests coverage maxi_cov build clean export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src export PYTHONWARNINGS=default -default: coverage style +default: style types tests style: - isort --project websockets src tests black src tests - flake8 src tests + ruff --fix src tests + +types: mypy --strict src -test: +tests: python -m unittest coverage: diff --git a/pyproject.toml b/pyproject.toml index fe0602b98..99d0fa08c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,19 @@ exclude_lines = [ "self.fail\\(\".*\"\\)", "@unittest.skip", ] + +[tool.ruff] +select = [ + "E", # pycodestyle + "F", # Pyflakes + "W", # pycodestyle + "I", # isort +] +ignore = [ + "F403", + "F405", +] + +[tool.ruff.isort] +combine-as-imports = true +lines-after-imports = 2 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index ca211fc63..000000000 --- a/setup.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[flake8] -ignore = E203,E731,F403,F405,W503 -max-line-length = 88 - -[isort] -profile = black -combine_as_imports = True -lines_after_imports = 2 diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3be86e45c..92252b136 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -552,10 +552,10 @@ def select_subprotocol( subprotocols = set(client_subprotocols) & set(server_subprotocols) if not subprotocols: return None - priority = lambda p: ( - client_subprotocols.index(p) + server_subprotocols.index(p) - ) - return sorted(subprotocols, key=priority)[0] + return sorted( + subprotocols, + key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p), + )[0] async def handshake( self, diff --git a/tox.ini b/tox.ini index 0fcab4d79..939d8c0cd 100644 --- a/tox.ini +++ b/tox.ini @@ -7,8 +7,7 @@ envlist = py311 coverage black - flake8 - isort + ruff mypy [testenv] @@ -31,13 +30,9 @@ deps = coverage commands = black --check src tests deps = black -[testenv:flake8] -commands = flake8 src tests -deps = flake8 - -[testenv:isort] -commands = isort --check-only src tests -deps = isort +[testenv:ruff] +commands = ruff src tests +deps = ruff [testenv:mypy] commands = mypy --strict src From 5b376aa4fc97424a10ad0b095c6cdbc5af81fd24 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 08:58:55 +0200 Subject: [PATCH 1192/1539] Reduce reliance on # pragma: no cover. --- pyproject.toml | 1 + src/websockets/frames.py | 2 +- src/websockets/legacy/framing.py | 2 +- src/websockets/legacy/protocol.py | 20 ++++++-------------- src/websockets/legacy/server.py | 18 +++++++++--------- src/websockets/sync/compatibility.py | 2 +- tests/legacy/test_client_server.py | 4 ++-- tests/test_utils.py | 2 +- 8 files changed, 22 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 99d0fa08c..989b6b5e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ source = [ [tool.coverage.report] exclude_lines = [ + "except ImportError:", "if self.debug:", "pragma: no cover", "raise AssertionError", diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 52d81746d..8e0e6d873 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -13,7 +13,7 @@ try: from .speedups import apply_mask -except ImportError: # pragma: no cover +except ImportError: from .utils import apply_mask diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4836eb284..dab501d2a 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -9,7 +9,7 @@ try: from ..speedups import apply_mask -except ImportError: # pragma: no cover +except ImportError: from ..utils import apply_mask diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 67bca0ef4..7f9ab2bd8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -714,9 +714,7 @@ async def send( await self.write_frame(False, opcode, data) # Other fragments. - # coverage reports this code as not covered, but it is - # exercised by tests - changing it breaks the tests! - async for fragment in aiter_message: # pragma: no cover + async for fragment in aiter_message: confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") @@ -1150,8 +1148,8 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: if ping_id == frame.data: self.latency = pong_timestamp - ping_timestamp break - else: # pragma: no cover - assert False, "ping_id is in self.pings" + else: + raise AssertionError("solicited pong not found in pings") # Remove acknowledged pings from self.pings. for ping_id in ping_ids: del self.pings[ping_id] @@ -1320,9 +1318,7 @@ async def close_connection(self) -> None: # A client should wait for a TCP close from the server. if self.is_client and hasattr(self, "transfer_data_task"): if await self.wait_for_connection_lost(): - # Coverage marks this line as a partially executed branch. - # I suspect a bug in coverage. Ignore it for now. - return # pragma: no cover + return if self.debug: self.logger.debug("! timed out waiting for TCP close") @@ -1340,9 +1336,7 @@ async def close_connection(self) -> None: pass if await self.wait_for_connection_lost(): - # Coverage marks this line as a partially executed branch. - # I suspect a bug in coverage. Ignore it for now. - return # pragma: no cover + return if self.debug: self.logger.debug("! timed out waiting for TCP close") @@ -1378,9 +1372,7 @@ async def close_transport(self) -> None: self.transport.abort() # connection_lost() is called quickly after aborting. - # Coverage marks this line as a partially executed branch. - # I suspect a bug in coverage. Ignore it for now. - await self.wait_for_connection_lost() # pragma: no cover + await self.wait_for_connection_lost() async def wait_for_connection_lost(self) -> bool: """ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 92252b136..3506276b9 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -840,7 +840,7 @@ def is_serving(self) -> bool: """ return self.server.is_serving() - async def start_serving(self) -> None: + async def start_serving(self) -> None: # pragma: no cover """ See :meth:`asyncio.Server.start_serving`. @@ -852,9 +852,9 @@ async def start_serving(self) -> None: await server.start_serving() """ - await self.server.start_serving() # pragma: no cover + await self.server.start_serving() - async def serve_forever(self) -> None: + async def serve_forever(self) -> None: # pragma: no cover """ See :meth:`asyncio.Server.serve_forever`. @@ -870,7 +870,7 @@ async def serve_forever(self) -> None: instead of exiting a :func:`serve` context. """ - await self.server.serve_forever() # pragma: no cover + await self.server.serve_forever() @property def sockets(self) -> Iterable[socket.socket]: @@ -880,17 +880,17 @@ def sockets(self) -> Iterable[socket.socket]: """ return self.server.sockets - async def __aenter__(self) -> WebSocketServer: - return self # pragma: no cover + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], - ) -> None: - self.close() # pragma: no cover - await self.wait_closed() # pragma: no cover + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() class Serve: diff --git a/src/websockets/sync/compatibility.py b/src/websockets/sync/compatibility.py index 3064263e9..38d2ab668 100644 --- a/src/websockets/sync/compatibility.py +++ b/src/websockets/sync/compatibility.py @@ -3,7 +3,7 @@ try: from socket import create_server as socket_create_server -except ImportError: # pragma: no cover +except ImportError: import socket def socket_create_server(address, family=socket.AF_INET): # type: ignore diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index b8fd259c9..3dd06d01e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -406,11 +406,11 @@ def __init__(self, *args, **kwargs): self.used_for_write = False super().__init__(*args, **kwargs) - def recv(self, *args, **kwargs): # pragma: no cover + def recv(self, *args, **kwargs): self.used_for_read = True return super().recv(*args, **kwargs) - def recv_into(self, *args, **kwargs): # pragma: no cover + def recv_into(self, *args, **kwargs): self.used_for_read = True return super().recv_into(*args, **kwargs) diff --git a/tests/test_utils.py b/tests/test_utils.py index acd60edfc..678fcfe79 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -82,7 +82,7 @@ def test_apply_mask_check_mask_length(self): try: from websockets.speedups import apply_mask as c_apply_mask -except ImportError: # pragma: no cover +except ImportError: pass else: From 5c7a44266ff63a486a2ab6dd8fa994136fbbbaff Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 09:01:46 +0200 Subject: [PATCH 1193/1539] Reduce usage of # noqa. Specify which warnings are ignored. --- src/websockets/__init__.py | 4 ++-- src/websockets/__main__.py | 2 +- src/websockets/auth.py | 2 +- src/websockets/client.py | 4 ++-- src/websockets/connection.py | 2 +- src/websockets/legacy/framing.py | 7 +++++-- src/websockets/server.py | 2 +- tests/extensions/test_base.py | 2 +- tests/test_auth.py | 2 +- tests/test_http.py | 2 +- tests/test_typing.py | 2 +- 11 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 826decc48..dcf3d8150 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,10 +1,10 @@ from __future__ import annotations from .imports import lazy_import -from .version import version as __version__ # noqa +from .version import version as __version__ # noqa: F401 -__all__ = [ # noqa +__all__ = [ "AbortHandshake", "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index a7dd1aaef..f2ea5cf4e 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -8,7 +8,7 @@ try: - import readline # noqa + import readline # noqa: F401 except ImportError: # Windows has no `readline` normally pass diff --git a/src/websockets/auth.py b/src/websockets/auth.py index afcb38cff..5292e4f7f 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,4 +1,4 @@ from __future__ import annotations # See #940 for why lazy_import isn't used here for backwards compatibility. -from .legacy.auth import * # noqa +from .legacy.auth import * diff --git a/src/websockets/client.py b/src/websockets/client.py index a0d077fc2..bf8427c37 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -38,7 +38,7 @@ # See #940 for why lazy_import isn't used here for backwards compatibility. -from .legacy.client import * # isort:skip # noqa +from .legacy.client import * # isort:skip # noqa: I001 __all__ = ["ClientProtocol"] @@ -90,7 +90,7 @@ def __init__( self.available_subprotocols = subprotocols self.key = generate_key() - def connect(self) -> Request: # noqa: F811 + def connect(self) -> Request: """ Create a handshake request to open a connection. diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 5ce4d6a3b..88bcda1aa 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -3,7 +3,7 @@ import warnings # lazy_import doesn't support this use case. -from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa +from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 warnings.warn( diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index dab501d2a..b77b869e3 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -145,8 +145,11 @@ def write( # Backwards compatibility with previously documented public APIs - -from ..frames import Close, prepare_ctrl as encode_data, prepare_data # noqa +from ..frames import ( # noqa: E402, F401, I001 + Close, + prepare_ctrl as encode_data, + prepare_data, +) def parse_close(data: bytes) -> Tuple[int, str]: diff --git a/src/websockets/server.py b/src/websockets/server.py index dcdf3b71e..ecb0f74a6 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -39,7 +39,7 @@ # See #940 for why lazy_import isn't used here for backwards compatibility. -from .legacy.server import * # isort:skip # noqa +from .legacy.server import * # isort:skip # noqa: I001 __all__ = ["ServerProtocol"] diff --git a/tests/extensions/test_base.py b/tests/extensions/test_base.py index ba8657b65..b18ffb6fb 100644 --- a/tests/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,4 +1,4 @@ -from websockets.extensions.base import * # noqa +from websockets.extensions.base import * # Abstract classes don't provide any behavior to test. diff --git a/tests/test_auth.py b/tests/test_auth.py index d5a8bd9ad..28db93155 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1 +1 @@ -from websockets.auth import * # noqa +from websockets.auth import * diff --git a/tests/test_http.py b/tests/test_http.py index 16bec9468..036bc1410 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1 +1 @@ -from websockets.http import * # noqa +from websockets.http import * diff --git a/tests/test_typing.py b/tests/test_typing.py index 6eb1fe6c5..202de840f 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -1 +1 @@ -from websockets.typing import * # noqa +from websockets.typing import * From 5113cd3afe7cbf5f740d80db8449670c27c90101 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 7 Mar 2023 11:56:19 -1000 Subject: [PATCH 1194/1539] Use asyncio.timeout instead of asyncio.wait_for. asyncio.wait_for creates a task whereas asyncio.timeout doesn't. Fallback to a vendored version of async_timeout on Python < 3.11. async.timeout will become the underlying implementation for async.wait_for in Python 3.12: https://github.com/python/cpython/pull/98518 --- pyproject.toml | 1 + src/websockets/legacy/async_timeout.py | 225 +++++++++++++++++++++++++ src/websockets/legacy/compatibility.py | 9 + src/websockets/legacy/protocol.py | 34 ++-- 4 files changed, 246 insertions(+), 23 deletions(-) create mode 100644 src/websockets/legacy/async_timeout.py diff --git a/pyproject.toml b/pyproject.toml index 989b6b5e0..0707c6442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ branch = true omit = [ # */websockets matches src/websockets and .tox/**/site-packages/websockets "*/websockets/__main__.py", + "*/websockets/legacy/async_timeout.py", "*/websockets/legacy/compatibility.py", "tests/maxi_cov.py", ] diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/legacy/async_timeout.py new file mode 100644 index 000000000..0a2208927 --- /dev/null +++ b/src/websockets/legacy/async_timeout.py @@ -0,0 +1,225 @@ +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License, Version 2.0. + +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Optional, Type + + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + + +__version__ = "4.0.2" + + +__all__ = ("timeout", "timeout_at", "Timeout") + + +def timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + if delay is not None: + deadline = loop.time() + delay # type: Optional[float] + else: + deadline = None + return Timeout(deadline, loop) + + +def timeout_at(deadline: Optional[float]) -> "Timeout": + """Schedule the timeout at absolute time. + + deadline argument points on the time in the same clock system + as loop.time(). + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + + >>> async with timeout_at(loop.time() + 10): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + """ + loop = asyncio.get_running_loop() + return Timeout(deadline, loop) + + +class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + +@final +class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__( + self, deadline: Optional[float], loop: asyncio.AbstractEventLoop + ) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout() instead", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + + The delay can be negative. + + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError("cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + + deadline argument points on the time in the same clock system + as loop.time(). + + If new deadline is in the past the timeout is raised immediately. + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError("cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon(self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "asyncio.Task[None]") -> None: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 303e203b4..cb9b02c86 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -5,6 +5,9 @@ from typing import Any, Dict +__all__ = ["asyncio_timeout", "loop_if_py_lt_38"] + + if sys.version_info[:2] >= (3, 8): def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: @@ -22,3 +25,9 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ return {"loop": loop} + + +if sys.version_info[:2] >= (3, 11): + from asyncio import timeout as asyncio_timeout # noqa: F401 +else: + from .async_timeout import timeout as asyncio_timeout # noqa: F401 diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7f9ab2bd8..78b59ee81 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -53,7 +53,7 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import loop_if_py_lt_38 +from .compatibility import asyncio_timeout, loop_if_py_lt_38 from .framing import Frame @@ -761,19 +761,16 @@ async def close(self, code: int = 1000, reason: str = "") -> None: """ try: - await asyncio.wait_for( - self.write_close_frame(Close(code, reason)), - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await self.write_close_frame(Close(code, reason)) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. # Fail the connection to shut down faster. self.fail_connection() - # If no close frame is received within the timeout, wait_for() cancels - # the data transfer task and raises TimeoutError. + # If no close frame is received within the timeout, asyncio_timeout() + # cancels the data transfer task and raises TimeoutError. # If close() is called multiple times concurrently and one of these # calls hits the timeout, the data transfer task will be canceled. @@ -782,11 +779,8 @@ async def close(self, code: int = 1000, reason: str = "") -> None: try: # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses. - await asyncio.wait_for( - self.transfer_data_task, - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await self.transfer_data_task except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1268,11 +1262,8 @@ async def keepalive_ping(self) -> None: if self.ping_timeout is not None: try: - await asyncio.wait_for( - pong_waiter, - self.ping_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.ping_timeout): + await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1384,11 +1375,8 @@ async def wait_for_connection_lost(self) -> bool: """ if not self.connection_lost_waiter.done(): try: - await asyncio.wait_for( - asyncio.shield(self.connection_lost_waiter), - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await asyncio.shield(self.connection_lost_waiter) except asyncio.TimeoutError: pass # Re-check self.connection_lost_waiter.done() synchronously because From 5a6f74e2248181ccd25a638304e8959d3ea90f91 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 07:44:51 +0200 Subject: [PATCH 1195/1539] Backport typing.final for compatibility with Python 3.7. --- src/websockets/legacy/async_timeout.py | 46 ++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/legacy/async_timeout.py index 0a2208927..8264094f5 100644 --- a/src/websockets/legacy/async_timeout.py +++ b/src/websockets/legacy/async_timeout.py @@ -1,5 +1,5 @@ # From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py -# Licensed under the Apache License, Version 2.0. +# Licensed under the Apache License (Apache-2.0) import asyncio import enum @@ -9,11 +9,48 @@ from typing import Optional, Type -if sys.version_info >= (3, 8): +# From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py +# Licensed under the Python Software Foundation License (PSF-2.0) + +if sys.version_info >= (3, 11): from typing import final else: - from typing_extensions import final + # @final exists in 3.8+, but we backport it for all versions + # before 3.11 to keep support for the __final__ attribute. + # See https://bugs.python.org/issue46342 + def final(f): + """This decorator can be used to indicate to type checkers that + the decorated method cannot be overridden, and decorated class + cannot be subclassed. For example: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker + ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... + + There is no runtime checking of these properties. The decorator + sets the ``__final__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return f + +# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py __version__ = "4.0.2" @@ -223,3 +260,6 @@ def _on_timeout(self, task: "asyncio.Task[None]") -> None: self._state = _State.TIMEOUT # drop the reference early self._timeout_handler = None + + +# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py From 808d8540af05e6bbfd74be8ae501b200ce7c966e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 08:14:27 +0200 Subject: [PATCH 1196/1539] Replace asyncio.wait_for with asyncio.timeout. Some instances were missed in a previous commit. Also update documentation. --- docs/faq/common.rst | 9 +++++++-- docs/topics/design.rst | 2 +- src/websockets/legacy/client.py | 4 +++- src/websockets/legacy/protocol.py | 5 +++-- src/websockets/legacy/server.py | 10 ++++------ tests/legacy/test_client_server.py | 4 +++- 6 files changed, 21 insertions(+), 13 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 2a512ea90..505149a64 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -105,9 +105,14 @@ you can adjust it with the ``ping_timeout`` argument. How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? -------------------------------------------------------------------------------- -Use :func:`~asyncio.wait_for`:: +On Python ≥ 3.11, use :func:`asyncio.timeout`:: - await asyncio.wait_for(websocket.recv(), timeout=10) + async with asyncio.timeout(timeout=10): + message = await websocket.recv() + +On older versions of Python, use :func:`asyncio.wait_for`:: + + message = await asyncio.wait_for(websocket.recv(), timeout=10) This technique works for most APIs. When it doesn't, for example with asynchronous context managers, websockets provides an ``open_timeout`` argument. diff --git a/docs/topics/design.rst b/docs/topics/design.rst index 33dd187b9..f164d2990 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -437,7 +437,7 @@ propagate cancellation to them. prevent cancellation. :meth:`~legacy.protocol.WebSocketCommonProtocol.close` waits for the data transfer -task to terminate with :func:`~asyncio.wait_for`. If it's canceled or if the +task to terminate with :func:`~asyncio.timeout`. If it's canceled or if the timeout elapses, :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` is canceled, which is correct at this point. :meth:`~legacy.protocol.WebSocketCommonProtocol.close` then waits for diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b79bbab8c..c5e9d0d52 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -44,6 +44,7 @@ from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri +from .compatibility import asyncio_timeout from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol @@ -650,7 +651,8 @@ def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: return self.__await_impl_timeout__().__await__() async def __await_impl_timeout__(self) -> WebSocketClientProtocol: - return await asyncio.wait_for(self.__await_impl__(), self.open_timeout) + async with asyncio_timeout(self.open_timeout): + return await self.__await_impl__() async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 78b59ee81..8b921e6fe 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -107,7 +107,8 @@ class WebSocketCommonProtocol(asyncio.Protocol): context manager; * on the server side, when the connection handler terminates. - To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. + To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or + :func:`~asyncio.wait_for`. The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1 MiB. If a larger message is received, @@ -513,7 +514,7 @@ async def recv(self) -> Data: message. The next invocation of :meth:`recv` will return it. This makes it possible to enforce a timeout by wrapping :meth:`recv` in - :func:`~asyncio.wait_for`. + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. Returns: Data: A string (:class:`str`) for a Text_ frame. A bytestring diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3506276b9..25d5a7144 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -46,7 +46,7 @@ from ..http import USER_AGENT from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol -from .compatibility import loop_if_py_lt_38 +from .compatibility import asyncio_timeout, loop_if_py_lt_38 from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol @@ -163,15 +163,13 @@ async def handler(self) -> None: """ try: try: - await asyncio.wait_for( - self.handshake( + async with asyncio_timeout(self.open_timeout): + await self.handshake( origins=self.origins, available_extensions=self.available_extensions, available_subprotocols=self.available_subprotocols, extra_headers=self.extra_headers, - ), - self.open_timeout, - ) + ) # Remove this branch when dropping support for Python < 3.8 # because CancelledError no longer inherits Exception. except asyncio.CancelledError: # pragma: no cover diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 3dd06d01e..133af0536 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -29,6 +29,7 @@ ) from websockets.http import USER_AGENT from websockets.legacy.client import * +from websockets.legacy.compatibility import asyncio_timeout from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * @@ -1129,7 +1130,8 @@ def test_client_connect_canceled_during_handshake(self): async def cancelled_client(): start_client = connect(get_server_uri(self.server), sock=sock) - await asyncio.wait_for(start_client, 5 * MS) + async with asyncio_timeout(5 * MS): + await start_client with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(cancelled_client()) From f075aac67e15cdf4bc06078e23b82eac5fb2d758 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 08:38:39 +0200 Subject: [PATCH 1197/1539] Restore semantics of tests. They relied (accidentally) on wait_for() creating a task, causing the event loop to run once when calling close(). --- tests/legacy/test_protocol.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ab8155b9b..409aef901 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -190,7 +190,6 @@ def half_close_connection_local(self, code=1000, reason="close"): close_frame_data = Close(code, reason).serialize() # Trigger the closing handshake from the local endpoint. close_task = self.loop.create_task(self.protocol.close(code, reason)) - self.run_loop_once() # wait_for executes self.run_loop_once() # write_frame executes # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) @@ -919,6 +918,7 @@ def test_answer_ping_does_not_crash_if_connection_closing(self): close_task = self.half_close_connection_local() self.receive_frame(Frame(True, OP_PING, b"test")) + self.run_loop_once() with self.assertNoLogs(): self.loop.run_until_complete(self.protocol.close()) @@ -931,6 +931,7 @@ def test_answer_ping_does_not_crash_if_connection_closed(self): # which prevents responding with a pong frame properly. self.receive_frame(Frame(True, OP_PING, b"test")) self.receive_eof() + self.run_loop_once() with self.assertNoLogs(): self.loop.run_until_complete(self.protocol.close()) @@ -1362,6 +1363,7 @@ def test_remote_close_and_connection_lost(self): # which prevents echoing the close frame properly. self.receive_frame(self.close_frame) self.receive_eof() + self.run_loop_once() with self.assertNoLogs(): self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) @@ -1375,6 +1377,7 @@ def test_simultaneous_close(self): # https://github.com/aaugustin/websockets/issues/339 self.loop.call_soon(self.receive_frame, self.remote_close) self.loop.call_soon(self.receive_eof_if_client) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="local")) @@ -1386,6 +1389,7 @@ def test_simultaneous_close(self): def test_close_preserves_incoming_frames(self): self.receive_frame(Frame(True, OP_TEXT, b"hello")) + self.run_loop_once() self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(MS, self.receive_eof_if_client) @@ -1573,6 +1577,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True self.receive_frame(self.close_frame) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(1000, "close") @@ -1589,6 +1594,7 @@ def test_local_close_connection_lost_timeout_after_close(self): # HACK: disable close => other end drops connection emulation. self.transport._closing = True self.receive_frame(self.close_frame) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(1000, "close") @@ -1631,6 +1637,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True self.receive_frame(self.close_frame) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(1000, "close") @@ -1650,5 +1657,6 @@ def test_local_close_connection_lost_timeout_after_close(self): # HACK: disable close => other end drops connection emulation. self.transport._closing = True self.receive_frame(self.close_frame) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(1000, "close") From 901e434fac7bf60018c950bdaf85b9946cc4309d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 09:30:00 +0200 Subject: [PATCH 1198/1539] Work around bug in coverage. --- src/websockets/legacy/protocol.py | 3 ++- tests/legacy/test_protocol.py | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 8b921e6fe..733abb3b9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1361,7 +1361,8 @@ async def close_transport(self) -> None: # Abort the TCP connection. Buffers are discarded. if self.debug: self.logger.debug("x aborting TCP connection") - self.transport.abort() + # Due to a bug in coverage, this is erroneously reported as not covered. + self.transport.abort() # pragma: no cover # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 409aef901..a05dcc6f6 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1579,7 +1579,8 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1000, "close") # pragma: no cover def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1596,7 +1597,8 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1000, "close") # pragma: no cover class ClientTests(CommonTests, AsyncioTestCase): @@ -1614,7 +1616,8 @@ def test_local_close_send_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1006, "") # pragma: no cover def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1624,7 +1627,8 @@ def test_local_close_receive_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1006, "") # pragma: no cover def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS @@ -1639,7 +1643,8 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1000, "close") # pragma: no cover def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1659,4 +1664,5 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1000, "close") # pragma: no cover From 00835ccf2bc9bb483b6bf3a69dd487d3745fbb27 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:07:35 +0200 Subject: [PATCH 1199/1539] Fix typo. --- docs/reference/features.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index f4b592a4b..7e6e262dc 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -28,7 +28,7 @@ Both sides +------------------------------------+--------+--------+--------+ | | |aio| | |sync| | |sans| | +====================================+========+========+========+ - | Perfom the opening handshake | ✅ | ✅ | ✅ | + | Perform the opening handshake | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+ | Send a message | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+ @@ -50,7 +50,7 @@ Both sides +------------------------------------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+ - | Perfom the closing handshake | ✅ | ✅ | ✅ | + | Perform the closing handshake | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+ | Report close codes and reasons | ❌ | ✅ | ✅ | | from both sides | | | | From 7dd4ede471f95b59b2d15c669b927773a2371183 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:06:06 +0200 Subject: [PATCH 1200/1539] Add changelog for d3d4cf4a. --- docs/project/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 68fb2709f..8e9be789e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -93,6 +93,8 @@ New features Improvements ............ +* Added platform-independent wheels. + * Improved error handling in :func:`~websockets.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a From f516cf51e166ec0cc797fa6bb68b559e4c1fed8b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:00:26 +0200 Subject: [PATCH 1201/1539] Release version 11.0 --- docs/project/changelog.rst | 3 +-- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8e9be789e..14caf2df9 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,11 +25,10 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. - 11.0 ---- -*In development* +*April 2, 2023* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 1a3d884d8..375112512 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "11.0" From fe1879fc22ced71f3b9b23f9fd7ac5a727ebb6bd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:01:02 +0200 Subject: [PATCH 1202/1539] Start version 11.1 --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 14caf2df9..bdb7d7f7d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +11.1 +---- + +*In development* + 11.0 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index 375112512..802dba546 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "11.0" +tag = version = commit = "11.1" if not released: # pragma: no cover From 1bf73423044dedc325435d261a39473337f5ddcf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:48:19 +0200 Subject: [PATCH 1203/1539] Drop support for Python 3.7. --- .github/workflows/tests.yml | 4 ---- docs/intro/index.rst | 2 +- docs/project/changelog.rst | 11 ++++++++++- pyproject.toml | 3 +-- src/websockets/datastructures.py | 8 +------- src/websockets/legacy/client.py | 8 -------- src/websockets/legacy/compatibility.py | 23 +---------------------- src/websockets/legacy/protocol.py | 17 ++++------------- src/websockets/legacy/server.py | 16 ++++------------ src/websockets/sync/compatibility.py | 21 --------------------- src/websockets/sync/server.py | 5 ++--- src/websockets/version.py | 6 +++--- tests/legacy/test_protocol.py | 3 +-- tests/legacy/utils.py | 3 ++- tests/sync/client.py | 16 ---------------- tests/sync/server.py | 4 +--- tests/sync/test_server.py | 3 +-- 17 files changed, 32 insertions(+), 121 deletions(-) delete mode 100644 src/websockets/sync/compatibility.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 163d1bbae..603426412 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -57,19 +57,15 @@ jobs: strategy: matrix: python: - - "3.7" - "3.8" - "3.9" - "3.10" - "3.11" - - "pypy-3.7" - "pypy-3.8" - "pypy-3.9" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.7" - is_main: false - python: "pypy-3.8" is_main: false - python: "pypy-3.9" diff --git a/docs/intro/index.rst b/docs/intro/index.rst index fe4e704d6..095262a20 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -websockets requires Python ≥ 3.7. +websockets requires Python ≥ 3.8. .. admonition:: Use the most recent Python release :class: tip diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index bdb7d7f7d..7b6972fd0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,11 +25,20 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. -11.1 +12.0 ---- *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: websockets 12.0 requires Python ≥ 3.8. + :class: tip + + websockets 11.0 is the last version supporting Python 3.7. + + 11.0 ---- diff --git a/pyproject.toml b/pyproject.toml index 0707c6442..a4622ae2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "websockets" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -requires-python = ">=3.7" +requires-python = ">=3.8" license = { text = "BSD-3-Clause" } authors = [ { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, @@ -19,7 +19,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 36a2cbaf9..a0a648463 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from typing import ( Any, Dict, @@ -9,17 +8,12 @@ List, Mapping, MutableMapping, + Protocol, Tuple, Union, ) -if sys.version_info[:2] >= (3, 8): - from typing import Protocol -else: # pragma: no cover - Protocol = object # mypy will report errors on Python 3.7. - - __all__ = ["Headers", "HeadersLike", "MultipleValuesError"] diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index c5e9d0d52..48622523e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -137,10 +137,6 @@ async def read_http_response(self) -> Tuple[int, Headers]: """ try: status_code, reason, headers = await read_response(self.reader) - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: # pragma: no cover - raise except Exception as exc: raise InvalidMessage("did not receive a valid HTTP response") from exc @@ -601,10 +597,6 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: try: async with self as protocol: yield protocol - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: # pragma: no cover - raise except Exception: # Add a random initial delay between 0 and 5 seconds. # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index cb9b02c86..6bd01e70d 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -1,30 +1,9 @@ from __future__ import annotations -import asyncio import sys -from typing import Any, Dict -__all__ = ["asyncio_timeout", "loop_if_py_lt_38"] - - -if sys.version_info[:2] >= (3, 8): - - def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: - """ - Helper for the removal of the loop argument in Python 3.10. - - """ - return {} - -else: - - def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: - """ - Helper for the removal of the loop argument in Python 3.10. - - """ - return {"loop": loop} +__all__ = ["asyncio_timeout"] if sys.version_info[:2] >= (3, 11): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 733abb3b9..0422c10da 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -53,7 +53,7 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import asyncio_timeout, loop_if_py_lt_38 +from .compatibility import asyncio_timeout from .framing import Frame @@ -244,7 +244,7 @@ def __init__( self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._drain_lock = asyncio.Lock(**loop_if_py_lt_38(loop)) + self._drain_lock = asyncio.Lock() # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -339,7 +339,7 @@ async def _drain(self) -> None: # pragma: no cover # write(...); yield from drain() # in a loop would never call connection_lost(), so it # would not see an error when the socket is closed. - await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) + await asyncio.sleep(0) await self._drain_helper() def connection_open(self) -> None: @@ -551,7 +551,6 @@ async def recv(self) -> Data: await asyncio.wait( [pop_message_waiter, self.transfer_data_task], return_when=asyncio.FIRST_COMPLETED, - **loop_if_py_lt_38(self.loop), ) finally: self._pop_message_waiter = None @@ -1247,10 +1246,7 @@ async def keepalive_ping(self) -> None: try: while True: - await asyncio.sleep( - self.ping_interval, - **loop_if_py_lt_38(self.loop), - ) + await asyncio.sleep(self.ping_interval) # ping() raises CancelledError if the connection is closed, # when close_connection() cancels self.keepalive_ping_task. @@ -1272,11 +1268,6 @@ async def keepalive_ping(self) -> None: self.fail_connection(1011, "keepalive ping timeout") break - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: - raise - except ConnectionClosed: pass diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 25d5a7144..a17c52328 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -46,7 +46,7 @@ from ..http import USER_AGENT from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol -from .compatibility import asyncio_timeout, loop_if_py_lt_38 +from .compatibility import asyncio_timeout from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol @@ -170,10 +170,6 @@ async def handler(self) -> None: available_subprotocols=self.available_subprotocols, extra_headers=self.extra_headers, ) - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: # pragma: no cover - raise except asyncio.TimeoutError: # pragma: no cover raise except ConnectionError: @@ -770,7 +766,7 @@ async def _close(self, close_connections: bool) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0, **loop_if_py_lt_38(self.get_loop())) + await asyncio.sleep(0) if close_connections: # Close OPEN connections with status code 1001. Since the server was @@ -784,18 +780,14 @@ async def _close(self, close_connections: bool) -> None: ] # asyncio.wait doesn't accept an empty first argument. if close_tasks: - await asyncio.wait( - close_tasks, - **loop_if_py_lt_38(self.get_loop()), - ) + await asyncio.wait(close_tasks) # Wait until all connection handlers are complete. # asyncio.wait doesn't accept an empty first argument. if self.websockets: await asyncio.wait( - [websocket.handler_task for websocket in self.websockets], - **loop_if_py_lt_38(self.get_loop()), + [websocket.handler_task for websocket in self.websockets] ) # Tell wait_closed() to return. diff --git a/src/websockets/sync/compatibility.py b/src/websockets/sync/compatibility.py deleted file mode 100644 index 38d2ab668..000000000 --- a/src/websockets/sync/compatibility.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - - -try: - from socket import create_server as socket_create_server -except ImportError: - import socket - - def socket_create_server(address, family=socket.AF_INET): # type: ignore - """Simplified backport of socket.create_server from Python 3.8.""" - sock = socket.socket(family, socket.SOCK_STREAM) - try: - sock.bind(address) - sock.listen() - return sock - except socket.error: - sock.close() - raise - - -__all__ = ["socket_create_server"] diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 9284c6188..e25646a5c 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -18,7 +18,6 @@ from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, Subprotocol -from .compatibility import socket_create_server from .connection import Connection from .utils import Deadline @@ -397,9 +396,9 @@ def handler(websocket): if unix: if path is None: raise TypeError("missing path argument") - sock = socket_create_server(path, family=socket.AF_UNIX) + sock = socket.create_server(path, family=socket.AF_UNIX) else: - sock = socket_create_server((host, port)) + sock = socket.create_server((host, port)) else: if path is not None: raise TypeError("path and sock arguments are incompatible") diff --git a/src/websockets/version.py b/src/websockets/version.py index 802dba546..f45f069ab 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1,5 +1,7 @@ from __future__ import annotations +import importlib.metadata + __all__ = ["tag", "version", "commit"] @@ -18,7 +20,7 @@ released = False -tag = version = commit = "11.1" +tag = version = commit = "12.0" if not released: # pragma: no cover @@ -56,8 +58,6 @@ def get_version(tag: str) -> str: # Read version from package metadata if it is installed. try: - import importlib.metadata # move up when dropping Python 3.7 - return importlib.metadata.version("websockets") except ImportError: pass diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index a05dcc6f6..9338e15bd 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -16,7 +16,6 @@ OP_TEXT, Close, ) -from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast from websockets.protocol import State @@ -117,7 +116,7 @@ def make_drain_slow(self, delay=MS): original_drain = self.protocol._drain async def delayed_drain(): - await asyncio.sleep(delay, **loop_if_py_lt_38(self.loop)) + await asyncio.sleep(delay) await original_drain() self.protocol._drain = delayed_drain diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index bb4eebb52..4a21dcaeb 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -9,7 +9,8 @@ class AsyncioTestCase(unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. - Replace with IsolatedAsyncioTestCase when dropping Python < 3.8. + IsolatedAsyncioTestCase was introduced in Python 3.8 for similar purposes + but isn't a drop-in replacement. """ diff --git a/tests/sync/client.py b/tests/sync/client.py index 51bbd4388..683893e88 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -1,7 +1,5 @@ import contextlib import ssl -import sys -import warnings from websockets.sync.client import * from websockets.sync.server import WebSocketServer @@ -19,20 +17,6 @@ CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) -# Work around https://github.com/openssl/openssl/issues/7967 - -# This bug causes connect() to hang in tests for the client. Including this -# workaround acknowledges that the issue could happen outside of the test suite. - -# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it -# happens, we can look for a library-level fix, but it won't be easy. - -if sys.version_info[:2] < (3, 8): # pragma: no cover - # ssl.OP_NO_TLSv1_3 was introduced and deprecated on Python 3.7. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - CLIENT_CONTEXT.options |= ssl.OP_NO_TLSv1_3 - @contextlib.contextmanager def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): diff --git a/tests/sync/server.py b/tests/sync/server.py index 5f0cd3b07..a9a77438c 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,6 +1,5 @@ import contextlib import ssl -import sys import threading from websockets.sync.server import * @@ -19,8 +18,7 @@ # It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it # happens, we can look for a library-level fix, but it won't be easy. -if sys.version_info[:2] >= (3, 8): # pragma: no cover - SERVER_CONTEXT.num_tickets = 0 +SERVER_CONTEXT.num_tickets = 0 def crash(ws): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 536858149..60e70c0a5 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -12,7 +12,6 @@ NegotiationError, ) from websockets.http11 import Request, Response -from websockets.sync.compatibility import socket_create_server from websockets.sync.server import * from ..utils import MS, temp_unix_socket_path @@ -72,7 +71,7 @@ def test_connection_handler_raises_exception(self): def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" - with socket_create_server(("localhost", 0)) as sock: + with socket.create_server(("localhost", 0)) as sock: with run_server(sock=sock): # Build WebSocket URI to ensure we connect to the right socket. with run_client("ws://{}:{}/".format(*sock.getsockname())) as client: From fff80c316e25af645ab5cc9b7217b2064d429d73 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Mon, 3 Apr 2023 00:03:43 +0300 Subject: [PATCH 1204/1539] Fix FAQ link in issue template --- .github/ISSUE_TEMPLATE/issue.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md index f2704152c..6c54e7fd5 100644 --- a/.github/ISSUE_TEMPLATE/issue.md +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -12,7 +12,7 @@ assignees: '' Thanks for taking the time to report an issue! Did you check the FAQ? Perhaps you'll find the answer you need: -https://websockets.readthedocs.io/en/stable/howto/faq.html +https://websockets.readthedocs.io/en/stable/faq/index.html Is your question really about asyncio? Perhaps the dev guide will help: https://docs.python.org/3/library/asyncio-dev.html From 19be15b387bf92697ff71e6f3e6d5038fe1515d8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 6 Apr 2023 07:54:30 +0200 Subject: [PATCH 1205/1539] Restore speedups.c in source distribution. --- .github/workflows/wheels.yml | 7 +++---- docs/project/changelog.rst | 9 +++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 68bfbdef4..00bd0ccc5 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -10,8 +10,6 @@ jobs: sdist: name: Build source distribution and architecture-independent wheel runs-on: ubuntu-latest - env: - BUILD_EXTENSION: no steps: - name: Check out repository uses: actions/checkout@v3 @@ -28,6 +26,8 @@ jobs: - name: Install wheel run: pip install wheel - name: Build wheel + env: + BUILD_EXTENSION: no run: python setup.py bdist_wheel - name: Save wheel uses: actions/upload-artifact@v3 @@ -37,8 +37,6 @@ jobs: wheels: name: Build architecture-specific wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} - env: - BUILD_EXTENSION: yes strategy: matrix: os: @@ -60,6 +58,7 @@ jobs: - name: Build wheels uses: pypa/cibuildwheel@v2.12.1 env: + BUILD_EXTENSION: yes CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" CIBW_SKIP: cp36-* diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7b6972fd0..f35caa19e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,15 @@ Backwards-incompatible changes websockets 11.0 is the last version supporting Python 3.7. +11.0.1 +------ + +*April 6, 2023* + +Bug fixes +......... + +* Restored the C extension in the source distribution. 11.0 ---- From 659e562d83187a419592e1a472b0d7af657d6642 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 6 Apr 2023 07:37:02 +0200 Subject: [PATCH 1206/1539] Move cibuildwheel configuration to pyproject.toml. cibuildwheel can read requires-python. --- .github/workflows/wheels.yml | 3 --- pyproject.toml | 8 ++++++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 00bd0ccc5..d9c9ea833 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -59,9 +59,6 @@ jobs: uses: pypa/cibuildwheel@v2.12.1 env: BUILD_EXTENSION: yes - CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" - CIBW_ARCHS_LINUX: "auto aarch64" - CIBW_SKIP: cp36-* - name: Save wheels uses: actions/upload-artifact@v3 with: diff --git a/pyproject.toml b/pyproject.toml index a4622ae2d..9065aab45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,14 @@ documentation = "https://websockets.readthedocs.io/" funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" tracker = "https://github.com/aaugustin/websockets/issues" +# On a macOS runner, build Intel, Universal, and Apple Silicon wheels. +[tool.cibuildwheel.macos] +archs = ["x86_64", "universal2", "arm64"] + +# On an Linux Intel runner with QEMU installed, build Intel and ARM wheels. +[tool.cibuildwheel.linux] +archs = ["auto", "aarch64"] + [tool.coverage.run] branch = true omit = [ From 1bc18193b32b210cc804e7e174d15c0b27db21f0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 08:11:30 +0200 Subject: [PATCH 1207/1539] Move repository to python-websockets organization. --- .github/ISSUE_TEMPLATE/issue.md | 4 ++-- README.rst | 14 +++++++------- docs/conf.py | 2 +- docs/faq/asyncio.rst | 2 +- docs/index.rst | 4 ++-- docs/intro/tutorial3.rst | 8 ++++---- docs/project/changelog.rst | 2 +- docs/project/contributing.rst | 6 +++--- docs/reference/features.rst | 10 +++++----- docs/topics/authentication.rst | 2 +- docs/topics/broadcast.rst | 2 +- docs/topics/compression.rst | 4 ++-- example/tutorial/step3/main.js | 2 +- experiments/compression/benchmark.py | 2 +- pyproject.toml | 4 ++-- tests/legacy/test_protocol.py | 2 +- 16 files changed, 35 insertions(+), 35 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md index 6c54e7fd5..3cf4e3b77 100644 --- a/.github/ISSUE_TEMPLATE/issue.md +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -18,12 +18,12 @@ Is your question really about asyncio? Perhaps the dev guide will help: https://docs.python.org/3/library/asyncio-dev.html Did you look for similar issues? Please keep the discussion in one place :-) -https://github.com/aaugustin/websockets/issues?q=is%3Aissue +https://github.com/python-websockets/websockets/issues?q=is%3Aissue Is your issue related to cryptocurrency in any way? Please don't file it. https://websockets.readthedocs.io/en/stable/project/contributing.html#cryptocurrency-users For bugs, providing a reproduction helps a lot. Take an existing example and tweak it! -https://github.com/aaugustin/websockets/tree/main/example +https://github.com/python-websockets/websockets/tree/main/example --> diff --git a/README.rst b/README.rst index 5ba523e8f..f53d3d0fc 100644 --- a/README.rst +++ b/README.rst @@ -13,8 +13,8 @@ .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main?label=tests - :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml +.. |tests| image:: https://img.shields.io/github/checks-status/python-websockets/websockets/main?label=tests + :target: https://github.com/python-websockets/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ @@ -84,7 +84,7 @@ Does that look good? .. raw:: html
- +

websockets for enterprise

Available as part of the Tidelift Subscription

The maintainers of websockets and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. Learn more.

@@ -147,13 +147,13 @@ contact`_. Tidelift will coordinate the fix and disclosure. For anything else, please open an issue_ or send a `pull request`_. -.. _issue: https://github.com/aaugustin/websockets/issues/new -.. _pull request: https://github.com/aaugustin/websockets/compare/ +.. _issue: https://github.com/python-websockets/websockets/issues/new +.. _pull request: https://github.com/python-websockets/websockets/compare/ Participants must uphold the `Contributor Covenant code of conduct`_. -.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/main/CODE_OF_CONDUCT.md +.. _Contributor Covenant code of conduct: https://github.com/python-websockets/websockets/blob/main/CODE_OF_CONDUCT.md ``websockets`` is released under the `BSD license`_. -.. _BSD license: https://github.com/aaugustin/websockets/blob/main/LICENSE +.. _BSD license: https://github.com/python-websockets/websockets/blob/main/LICENSE diff --git a/docs/conf.py b/docs/conf.py index 58f0c2d55..9d61dc717 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -97,7 +97,7 @@ # Configure viewcode extension. from websockets.version import commit -code_url = f"https://github.com/aaugustin/websockets/blob/{commit}" +code_url = f"https://github.com/python-websockets/websockets/blob/{commit}" def linkcode_resolve(domain, info): # Non-linkable objects from the starter kit in the tutorial. diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index e56a42d36..e77f50add 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -41,7 +41,7 @@ every coroutine. See `issue 867`_. -.. _issue 867: https://github.com/aaugustin/websockets/issues/867 +.. _issue 867: https://github.com/python-websockets/websockets/issues/867 Why am I having problems with threads? -------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index be6a0da05..d9737db12 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,8 +12,8 @@ websockets .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main?label=tests - :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml +.. |tests| image:: https://img.shields.io/github/checks-status/python-websockets/websockets/main?label=tests + :target: https://github.com/python-websockets/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 4d42447b7..6fdec113b 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -198,7 +198,7 @@ in ``main.js``: You can take this strategy one step further by checking the address of the HTTP server and determining the address of the WebSocket server accordingly. -Add this function to ``main.js``; replace ``aaugustin`` by your GitHub +Add this function to ``main.js``; replace ``python-websockets`` by your GitHub username and ``websockets-tutorial`` by the name of your app on Heroku: .. literalinclude:: ../../example/tutorial/step3/main.js @@ -226,12 +226,12 @@ Deploy the web application Go to GitHub and create a new repository called ``websockets-tutorial``. -Push your code to this repository. You must replace ``aaugustin`` by your -GitHub username in the following command: +Push your code to this repository. You must replace ``python-websockets`` by +your GitHub username in the following command: .. code-block:: console - $ git remote add origin git@github.com:aaugustin/websockets-tutorial.git + $ git remote add origin git@github.com:python-websockets/websockets-tutorial.git $ git push -u origin main Enumerating objects: 11, done. Counting objects: 100% (11/11), done. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f35caa19e..127ea3142 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -225,7 +225,7 @@ Improvements * Reverted optimization of default compression settings for clients, mainly to avoid triggering bugs in poorly implemented servers like `AWS API Gateway`_. - .. _AWS API Gateway: https://github.com/aaugustin/websockets/issues/1065 + .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 * Mirrored the entire :class:`~asyncio.Server` API in :class:`~server.WebSocketServer`. diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 43fd58dc8..020ed7ad8 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -10,7 +10,7 @@ This project and everyone participating in it is governed by the `Code of Conduct`_. By participating, you are expected to uphold this code. Please report inappropriate behavior to aymeric DOT augustin AT fractalideas DOT com. -.. _Code of Conduct: https://github.com/aaugustin/websockets/blob/main/CODE_OF_CONDUCT.md +.. _Code of Conduct: https://github.com/python-websockets/websockets/blob/main/CODE_OF_CONDUCT.md *(If I'm the person with the inappropriate behavior, please accept my apologies. I know I can mess up. I can't expect you to tell me, but if you @@ -31,8 +31,8 @@ If you're wondering why things are done in a certain way, the :doc:`design document <../topics/design>` provides lots of details about the internals of websockets. -.. _issue: https://github.com/aaugustin/websockets/issues/new -.. _pull request: https://github.com/aaugustin/websockets/compare/ +.. _issue: https://github.com/python-websockets/websockets/issues/new +.. _pull request: https://github.com/python-websockets/websockets/compare/ Questions --------- diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 7e6e262dc..3cc52ec10 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -78,7 +78,7 @@ Both sides | Heartbeat | ✅ | ❌ | — | +------------------------------------+--------+--------+--------+ -.. _#479: https://github.com/aaugustin/websockets/issues/479 +.. _#479: https://github.com/python-websockets/websockets/issues/479 Server ------ @@ -162,9 +162,9 @@ Client | (`#475`_) | | | | +------------------------------------+--------+--------+--------+ -.. _#364: https://github.com/aaugustin/websockets/issues/364 -.. _#475: https://github.com/aaugustin/websockets/issues/475 -.. _#784: https://github.com/aaugustin/websockets/issues/784 +.. _#364: https://github.com/python-websockets/websockets/issues/364 +.. _#475: https://github.com/python-websockets/websockets/issues/475 +.. _#784: https://github.com/python-websockets/websockets/issues/784 Known limitations ----------------- @@ -172,7 +172,7 @@ Known limitations There is no way to control compression of outgoing frames on a per-frame basis (`#538`_). If compression is enabled, all frames are compressed. -.. _#538: https://github.com/aaugustin/websockets/issues/538 +.. _#538: https://github.com/python-websockets/websockets/issues/538 The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 1849d635a..60dd38766 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -152,7 +152,7 @@ The `experiments/authentication`_ directory demonstrates these techniques. Run the experiment in an environment where websockets is installed: -.. _experiments/authentication: https://github.com/aaugustin/websockets/tree/main/experiments/authentication +.. _experiments/authentication: https://github.com/python-websockets/websockets/tree/main/experiments/authentication .. code-block:: console diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 9a25cbf7d..1acb372d4 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -51,7 +51,7 @@ message, or else it will never let the server run. That's why it includes A complete example is available in the `experiments/broadcast`_ directory. -.. _experiments/broadcast: https://github.com/aaugustin/websockets/tree/main/experiments/broadcast +.. _experiments/broadcast: https://github.com/python-websockets/websockets/tree/main/experiments/broadcast The naive way ------------- diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index f0b7ce898..eaf99070d 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -158,7 +158,7 @@ corpus. Defaults must be safe for all applications, hence a more conservative choice. -.. _compression/benchmark.py: https://github.com/aaugustin/websockets/blob/main/experiments/compression/benchmark.py +.. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py The benchmark focuses on compression because it's more expensive than decompression. Indeed, leaving aside small allocations, theoretical memory @@ -199,7 +199,7 @@ like the server side: 3. On a more pragmatic note, some servers misbehave badly when a client configures compression settings. `AWS API Gateway`_ is the worst offender. - .. _AWS API Gateway: https://github.com/aaugustin/websockets/issues/1065 + .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 Unfortunately, even though websockets is right and AWS is wrong, many users jump to the conclusion that websockets doesn't work. diff --git a/example/tutorial/step3/main.js b/example/tutorial/step3/main.js index 15afd4163..3000fa2f7 100644 --- a/example/tutorial/step3/main.js +++ b/example/tutorial/step3/main.js @@ -1,7 +1,7 @@ import { createBoard, playMove } from "./connect4.js"; function getWebSocketServer() { - if (window.location.host === "aaugustin.github.io") { + if (window.location.host === "python-websockets.github.io") { return "wss://websockets-tutorial.herokuapp.com/"; } else if (window.location.host === "localhost:8000") { return "ws://localhost:8001/"; diff --git a/experiments/compression/benchmark.py b/experiments/compression/benchmark.py index bdcdd8e95..c5b13c8fa 100644 --- a/experiments/compression/benchmark.py +++ b/experiments/compression/benchmark.py @@ -20,7 +20,7 @@ def _corpus(): OAUTH_TOKEN = getpass.getpass("OAuth Token? ") COMMIT_API = ( f'curl -H "Authorization: token {OAUTH_TOKEN}" ' - f"https://api.github.com/repos/aaugustin/websockets/git/commits/:sha" + f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha" ) commits = [] diff --git a/pyproject.toml b/pyproject.toml index 9065aab45..941b01c08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,11 @@ classifiers = [ dynamic = ["version", "readme"] [project.urls] -homepage = "https://github.com/aaugustin/websockets" +homepage = "https://github.com/python-websockets/websockets" changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" documentation = "https://websockets.readthedocs.io/" funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" -tracker = "https://github.com/aaugustin/websockets/issues" +tracker = "https://github.com/python-websockets/websockets/issues" # On a macOS runner, build Intel, Universal, and Apple Silicon wheels. [tool.cibuildwheel.macos] diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 9338e15bd..514d91ef8 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1373,7 +1373,7 @@ def test_remote_close_and_connection_lost(self): def test_simultaneous_close(self): # Receive the incoming close frame right after self.protocol.close() # starts executing. This reproduces the error described in: - # https://github.com/aaugustin/websockets/issues/339 + # https://github.com/python-websockets/websockets/issues/339 self.loop.call_soon(self.receive_frame, self.remote_close) self.loop.call_soon(self.receive_eof_if_client) self.run_loop_once() From 8c6f7267e6ebb619904df7a4c8c005cc30307de6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 08:13:02 +0200 Subject: [PATCH 1208/1539] Capitalize link titles for nicer display on PyPI. --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 941b01c08..c26e3aabc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,11 @@ classifiers = [ dynamic = ["version", "readme"] [project.urls] -homepage = "https://github.com/python-websockets/websockets" -changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" -documentation = "https://websockets.readthedocs.io/" -funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" -tracker = "https://github.com/python-websockets/websockets/issues" +Homepage = "https://github.com/python-websockets/websockets" +Changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" +Documentation = "https://websockets.readthedocs.io/" +Funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" +Tracker = "https://github.com/python-websockets/websockets/issues" # On a macOS runner, build Intel, Universal, and Apple Silicon wheels. [tool.cibuildwheel.macos] From a0daea5b870535996f61231541a38aa43b797b9a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Apr 2023 18:00:55 +0200 Subject: [PATCH 1209/1539] Make recv buffer size a class attribute. This supports overriding it by subclassing, rather than monkey-patching. Also use a more realistic value in the docs. --- docs/howto/sansio.rst | 2 +- src/websockets/sync/connection.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index 08b09f7ce..d41519ff0 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -133,7 +133,7 @@ When reaching the end of the data stream, call the protocol's For example, if ``sock`` is a :obj:`~socket.socket`:: try: - data = sock.recv(4096) + data = sock.recv(65536) except OSError: # socket closed data = b"" if data: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 64e5c8b44..59c0bc071 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -23,8 +23,6 @@ logger = logging.getLogger(__name__) -BUFSIZE = 65536 - class Connection: """ @@ -39,6 +37,8 @@ class Connection: """ + recv_bufsize = 65536 + def __init__( self, socket: socket.socket, @@ -525,7 +525,7 @@ def recv_events(self) -> None: try: if self.close_deadline is not None: self.socket.settimeout(self.close_deadline.timeout()) - data = self.socket.recv(BUFSIZE) + data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: self.logger.debug("error while receiving data", exc_info=True) From a6e14978190ddf79e03f34df39a27a99b5619c2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Apr 2023 18:30:22 +0200 Subject: [PATCH 1210/1539] Archive benchmark script for stream readers. --- experiments/optimization/streams.py | 301 ++++++++++++++++++++++++++++ 1 file changed, 301 insertions(+) create mode 100644 experiments/optimization/streams.py diff --git a/experiments/optimization/streams.py b/experiments/optimization/streams.py new file mode 100644 index 000000000..ca24a5983 --- /dev/null +++ b/experiments/optimization/streams.py @@ -0,0 +1,301 @@ +""" +Benchmark two possible implementations of a stream reader. + +The difference lies in the data structure that buffers incoming data: + +* ``ByteArrayStreamReader`` uses a ``bytearray``; +* ``BytesDequeStreamReader`` uses a ``deque[bytes]``. + +``ByteArrayStreamReader`` is faster for streaming small frames, which is the +standard use case of websockets, likely due to its simple implementation and +to ``bytearray`` being fast at appending data and removing data at the front +(https://hg.python.org/cpython/rev/499a96611baa). + +``BytesDequeStreamReader`` is faster for large frames and for bursts, likely +because it copies payloads only once, while ``ByteArrayStreamReader`` copies +them twice. + +""" + + +import collections +import os +import timeit + + +# Implementations + + +class ByteArrayStreamReader: + def __init__(self): + self.buffer = bytearray() + self.eof = False + + def readline(self): + n = 0 # number of bytes to read + p = 0 # number of bytes without a newline + while True: + n = self.buffer.find(b"\n", p) + 1 + if n > 0: + break + p = len(self.buffer) + yield + r = self.buffer[:n] + del self.buffer[:n] + return r + + def readexactly(self, n): + assert n >= 0 + while len(self.buffer) < n: + yield + r = self.buffer[:n] + del self.buffer[:n] + return r + + def feed_data(self, data): + self.buffer += data + + def feed_eof(self): + self.eof = True + + def at_eof(self): + return self.eof and not self.buffer + + +class BytesDequeStreamReader: + def __init__(self): + self.buffer = collections.deque() + self.eof = False + + def readline(self): + b = [] + while True: + # Read next chunk + while True: + try: + c = self.buffer.popleft() + except IndexError: + yield + else: + break + # Handle chunk + n = c.find(b"\n") + 1 + if n == len(c): + # Read exactly enough data + b.append(c) + break + elif n > 0: + # Read too much data + b.append(c[:n]) + self.buffer.appendleft(c[n:]) + break + else: # n == 0 + # Need to read more data + b.append(c) + return b"".join(b) + + def readexactly(self, n): + if n == 0: + return b"" + b = [] + while True: + # Read next chunk + while True: + try: + c = self.buffer.popleft() + except IndexError: + yield + else: + break + # Handle chunk + n -= len(c) + if n == 0: + # Read exactly enough data + b.append(c) + break + elif n < 0: + # Read too much data + b.append(c[:n]) + self.buffer.appendleft(c[n:]) + break + else: # n >= 0 + # Need to read more data + b.append(c) + return b"".join(b) + + def feed_data(self, data): + self.buffer.append(data) + + def feed_eof(self): + self.eof = True + + def at_eof(self): + return self.eof and not self.buffer + + +# Tests + + +class Protocol: + def __init__(self, StreamReader): + self.reader = StreamReader() + self.events = [] + # Start parser coroutine + self.parser = self.run_parser() + next(self.parser) + + def run_parser(self): + while True: + frame = yield from self.reader.readexactly(2) + self.events.append(frame) + frame = yield from self.reader.readline() + self.events.append(frame) + + def data_received(self, data): + self.reader.feed_data(data) + next(self.parser) # run parser until more data is needed + events, self.events = self.events, [] + return events + + +def run_test(StreamReader): + proto = Protocol(StreamReader) + + actual = proto.data_received(b"a") + expected = [] + assert actual == expected, f"{actual} != {expected}" + + actual = proto.data_received(b"b") + expected = [b"ab"] + assert actual == expected, f"{actual} != {expected}" + + actual = proto.data_received(b"c") + expected = [] + assert actual == expected, f"{actual} != {expected}" + + actual = proto.data_received(b"\n") + expected = [b"c\n"] + assert actual == expected, f"{actual} != {expected}" + + actual = proto.data_received(b"efghi\njklmn") + expected = [b"ef", b"ghi\n", b"jk"] + assert actual == expected, f"{actual} != {expected}" + + +# Benchmarks + + +def get_frame_packets(size, packet_size=None): + if size < 126: + frame = bytes([138, size]) + elif size < 65536: + frame = bytes([138, 126]) + bytes(divmod(size, 256)) + else: + size1, size2 = divmod(size, 65536) + frame = ( + bytes([138, 127]) + bytes(divmod(size1, 256)) + bytes(divmod(size2, 256)) + ) + frame += os.urandom(size) + if packet_size is None: + return [frame] + else: + packets = [] + while frame: + packets.append(frame[:packet_size]) + frame = frame[packet_size:] + return packets + + +def benchmark_stream(StreamReader, packets, size, count): + reader = StreamReader() + for _ in range(count): + for packet in packets: + reader.feed_data(packet) + yield from reader.readexactly(2) + if size >= 65536: + yield from reader.readexactly(4) + elif size >= 126: + yield from reader.readexactly(2) + yield from reader.readexactly(size) + reader.feed_eof() + assert reader.at_eof() + + +def benchmark_burst(StreamReader, packets, size, count): + reader = StreamReader() + for _ in range(count): + for packet in packets: + reader.feed_data(packet) + reader.feed_eof() + for _ in range(count): + yield from reader.readexactly(2) + if size >= 65536: + yield from reader.readexactly(4) + elif size >= 126: + yield from reader.readexactly(2) + yield from reader.readexactly(size) + assert reader.at_eof() + + +def run_benchmark(size, count, packet_size=None, number=1000): + stmt = f"list(benchmark(StreamReader, packets, {size}, {count}))" + setup = f"packets = get_frame_packets({size}, {packet_size})" + context = globals() + + context["StreamReader"] = context["ByteArrayStreamReader"] + context["benchmark"] = context["benchmark_stream"] + bas = min(timeit.repeat(stmt, setup, number=number, globals=context)) + context["benchmark"] = context["benchmark_burst"] + bab = min(timeit.repeat(stmt, setup, number=number, globals=context)) + + context["StreamReader"] = context["BytesDequeStreamReader"] + context["benchmark"] = context["benchmark_stream"] + bds = min(timeit.repeat(stmt, setup, number=number, globals=context)) + context["benchmark"] = context["benchmark_burst"] + bdb = min(timeit.repeat(stmt, setup, number=number, globals=context)) + + print( + f"Frame size = {size} bytes, " + f"frame count = {count}, " + f"packet size = {packet_size}" + ) + print(f"* ByteArrayStreamReader (stream): {bas / number * 1_000_000:.1f}µs") + print( + f"* BytesDequeStreamReader (stream): " + f"{bds / number * 1_000_000:.1f}µs ({(bds / bas - 1) * 100:+.1f}%)" + ) + print(f"* ByteArrayStreamReader (burst): {bab / number * 1_000_000:.1f}µs") + print( + f"* BytesDequeStreamReader (burst): " + f"{bdb / number * 1_000_000:.1f}µs ({(bdb / bab - 1) * 100:+.1f}%)" + ) + print() + + +if __name__ == "__main__": + run_test(ByteArrayStreamReader) + run_test(BytesDequeStreamReader) + + run_benchmark(size=8, count=1000) + run_benchmark(size=60, count=1000) + run_benchmark(size=500, count=500) + run_benchmark(size=4_000, count=200) + run_benchmark(size=30_000, count=100) + run_benchmark(size=250_000, count=50) + run_benchmark(size=2_000_000, count=20) + + run_benchmark(size=4_000, count=200, packet_size=1024) + run_benchmark(size=30_000, count=100, packet_size=1024) + run_benchmark(size=250_000, count=50, packet_size=1024) + run_benchmark(size=2_000_000, count=20, packet_size=1024) + + run_benchmark(size=30_000, count=100, packet_size=4096) + run_benchmark(size=250_000, count=50, packet_size=4096) + run_benchmark(size=2_000_000, count=20, packet_size=4096) + + run_benchmark(size=30_000, count=100, packet_size=16384) + run_benchmark(size=250_000, count=50, packet_size=16384) + run_benchmark(size=2_000_000, count=20, packet_size=16384) + + run_benchmark(size=250_000, count=50, packet_size=65536) + run_benchmark(size=2_000_000, count=20, packet_size=65536) From bb17be2fb0b99c8638788648fdd83f9049c1c344 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 11 Apr 2023 21:50:53 +0200 Subject: [PATCH 1211/1539] Add scripts to benchmark parsers. --- experiments/optimization/parse_frames.py | 101 +++++++++++++++++++ experiments/optimization/parse_handshake.py | 102 ++++++++++++++++++++ 2 files changed, 203 insertions(+) create mode 100644 experiments/optimization/parse_frames.py create mode 100644 experiments/optimization/parse_handshake.py diff --git a/experiments/optimization/parse_frames.py b/experiments/optimization/parse_frames.py new file mode 100644 index 000000000..e3acbe3c2 --- /dev/null +++ b/experiments/optimization/parse_frames.py @@ -0,0 +1,101 @@ +"""Benchark parsing WebSocket frames.""" + +import subprocess +import sys +import timeit + +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.frames import Frame, Opcode +from websockets.streams import StreamReader + + +# 256kB of text, compressible by about 70%. +text = subprocess.check_output(["git", "log", "8dd8e410"], text=True) + + +def get_frame(size): + repeat, remainder = divmod(size, 256 * 1024) + payload = repeat * text + text[:remainder] + return Frame(Opcode.TEXT, payload.encode(), True) + + +def parse_frame(data, count, mask, extensions): + reader = StreamReader() + for _ in range(count): + reader.feed_data(data) + parser = Frame.parse( + reader.read_exact, + mask=mask, + extensions=extensions, + ) + try: + next(parser) + except StopIteration: + pass + else: + assert False, "parser should return frame" + reader.feed_eof() + assert reader.at_eof(), "parser should consume all data" + + +def run_benchmark(size, count, compression=False, number=100): + if compression: + extensions = [PerMessageDeflate(True, True, 12, 12, {"memLevel": 5})] + else: + extensions = [] + globals = { + "get_frame": get_frame, + "parse_frame": parse_frame, + "extensions": extensions, + } + sppf = ( + min( + timeit.repeat( + f"parse_frame(data, {count}, mask=True, extensions=extensions)", + f"data = get_frame({size})" + f".serialize(mask=True, extensions=extensions)", + number=number, + globals=globals, + ) + ) + / number + / count + * 1_000_000 + ) + cppf = ( + min( + timeit.repeat( + f"parse_frame(data, {count}, mask=False, extensions=extensions)", + f"data = get_frame({size})" + f".serialize(mask=False, extensions=extensions)", + number=number, + globals=globals, + ) + ) + / number + / count + * 1_000_000 + ) + print(f"{size}\t{compression}\t{sppf:.2f}\t{cppf:.2f}") + + +if __name__ == "__main__": + print("Sizes are in bytes. Times are in µs per frame.", file=sys.stderr) + print("Run `tabs -16` for clean output. Pipe stdout to TSV for saving.") + print(file=sys.stderr) + + print("size\tcompression\tserver\tclient") + run_benchmark(size=8, count=1000, compression=False) + run_benchmark(size=60, count=1000, compression=False) + run_benchmark(size=500, count=1000, compression=False) + run_benchmark(size=4_000, count=1000, compression=False) + run_benchmark(size=30_000, count=200, compression=False) + run_benchmark(size=250_000, count=100, compression=False) + run_benchmark(size=2_000_000, count=20, compression=False) + + run_benchmark(size=8, count=1000, compression=True) + run_benchmark(size=60, count=1000, compression=True) + run_benchmark(size=500, count=200, compression=True) + run_benchmark(size=4_000, count=100, compression=True) + run_benchmark(size=30_000, count=20, compression=True) + run_benchmark(size=250_000, count=10, compression=True) diff --git a/experiments/optimization/parse_handshake.py b/experiments/optimization/parse_handshake.py new file mode 100644 index 000000000..af5a4ecae --- /dev/null +++ b/experiments/optimization/parse_handshake.py @@ -0,0 +1,102 @@ +"""Benchark parsing WebSocket handshake requests.""" + +# The parser for responses is designed similarly and should perform similarly. + +import sys +import timeit + +from websockets.http11 import Request +from websockets.streams import StreamReader + + +CHROME_HANDSHAKE = ( + b"GET / HTTP/1.1\r\n" + b"Host: localhost:5678\r\n" + b"Connection: Upgrade\r\n" + b"Pragma: no-cache\r\n" + b"Cache-Control: no-cache\r\n" + b"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36\r\n" + b"Upgrade: websocket\r\n" + b"Origin: null\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"Accept-Encoding: gzip, deflate, br\r\n" + b"Accept-Language: en-GB,en;q=0.9,en-US;q=0.8,fr;q=0.7\r\n" + b"Sec-WebSocket-Key: ebkySAl+8+e6l5pRKTMkyQ==\r\n" + b"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" + b"\r\n" +) + +FIREFOX_HANDSHAKE = ( + b"GET / HTTP/1.1\r\n" + b"Host: localhost:5678\r\n" + b"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) " + b"Gecko/20100101 Firefox/111.0\r\n" + b"Accept: */*\r\n" + b"Accept-Language: en-US,en;q=0.7,fr-FR;q=0.3\r\n" + b"Accept-Encoding: gzip, deflate, br\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"Origin: null\r\n" + b"Sec-WebSocket-Extensions: permessage-deflate\r\n" + b"Sec-WebSocket-Key: 1PuS+hnb+0AXsL7z2hNAhw==\r\n" + b"Connection: keep-alive, Upgrade\r\n" + b"Sec-Fetch-Dest: websocket\r\n" + b"Sec-Fetch-Mode: websocket\r\n" + b"Sec-Fetch-Site: cross-site\r\n" + b"Pragma: no-cache\r\n" + b"Cache-Control: no-cache\r\n" + b"Upgrade: websocket\r\n" + b"\r\n" +) + +WEBSOCKETS_HANDSHAKE = ( + b"GET / HTTP/1.1\r\n" + b"Host: localhost:8765\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: 9c55e0/siQ6tJPCs/QR8ZA==\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" + b"User-Agent: Python/3.11 websockets/11.0\r\n" + b"\r\n" +) + + +def parse_handshake(handshake): + reader = StreamReader() + reader.feed_data(handshake) + parser = Request.parse(reader.read_line) + try: + next(parser) + except StopIteration: + pass + else: + assert False, "parser should return request" + reader.feed_eof() + assert reader.at_eof(), "parser should consume all data" + + +def run_benchmark(name, handshake, number=10000): + ph = ( + min( + timeit.repeat( + "parse_handshake(handshake)", + number=number, + globals={"parse_handshake": parse_handshake, "handshake": handshake}, + ) + ) + / number + * 1_000_000 + ) + print(f"{name}\t{len(handshake)}\t{ph:.1f}") + + +if __name__ == "__main__": + print("Sizes are in bytes. Times are in µs per frame.", file=sys.stderr) + print("Run `tabs -16` for clean output. Pipe stdout to TSV for saving.") + print(file=sys.stderr) + + print("client\tsize\ttime") + run_benchmark("Chrome", CHROME_HANDSHAKE) + run_benchmark("Firefox", FIREFOX_HANDSHAKE) + run_benchmark("websockets", WEBSOCKETS_HANDSHAKE) From c38507fde13b9d2ac9bf62b8f68dba8d93f1fcd3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Apr 2023 08:14:33 +0200 Subject: [PATCH 1212/1539] Avoid deadlock when closing sync connection with unread messages. Fix #1336. --- docs/project/changelog.rst | 11 ++++++++ src/websockets/sync/connection.py | 23 +++++++++++------ tests/sync/test_connection.py | 43 +++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 127ea3142..f54880b06 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,17 @@ Backwards-incompatible changes websockets 11.0 is the last version supporting Python 3.7. +11.0.2 +------ + +*April 18, 2023* + +Bug fixes +......... + +* Fixed a deadlock in the :mod:`threading` implementation when closing a + connection without reading all messages. + 11.0.1 ------ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 59c0bc071..afebf5ea9 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -382,8 +382,9 @@ def close(self, code: int = 1000, reason: str = "") -> None: """ Perform the closing handshake. - :meth:`close` waits for the other end to complete the handshake and - for the TCP connection to terminate. + :meth:`close` waits for the other end to complete the handshake, for the + TCP connection to terminate, and for all incoming messages to be read + with :meth:`recv`. :meth:`close` is idempotent: it doesn't do anything once the connection is closed. @@ -574,9 +575,13 @@ def recv_events(self) -> None: # Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. - for event in events: - # This isn't expected to raise an exception. - self.process_event(event) + try: + for event in events: + # This may raise EOFError if the closing handshake + # times out while a message is waiting to be read. + self.process_event(event) + except EOFError: + break # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. @@ -600,7 +605,6 @@ def recv_events(self) -> None: self.protocol.state = CLOSED finally: # This isn't expected to raise an exception. - self.recv_messages.close() self.close_socket() @contextlib.contextmanager @@ -745,13 +749,16 @@ def set_recv_events_exc(self, exc: Optional[BaseException]) -> None: def close_socket(self) -> None: """ - Shutdown and close socket. + Shutdown and close socket. Close message assembler. - shutdown() is required to interrupt recv() on Linux. + Calling close_socket() guarantees that recv_events() terminates. Indeed, + recv_events() may block only on socket.recv() or on recv_messages.put(). """ + # shutdown() is required to interrupt recv() on Linux. try: self.socket.shutdown(socket.SHUT_RDWR) except OSError: pass # socket is already closed self.socket.close() + self.recv_messages.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index f26ec3f95..0e7cff948 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -458,6 +458,49 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + def test_close_waits_for_recv(self): + self.remote_connection.send("😀") + + close_thread = threading.Thread(target=self.connection.close) + close_thread.start() + + # Let close() initiate the closing handshake and send a close frame. + time.sleep(MS) + self.assertTrue(close_thread.is_alive()) + + # Connection isn't closed yet. + self.connection.recv() + + # Let close() receive a close frame and finish the closing handshake. + time.sleep(MS) + self.assertFalse(close_thread.is_alive()) + + # Connection is closed now. + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + def test_close_timeout_waiting_for_recv(self): + self.remote_connection.send("😀") + + close_thread = threading.Thread(target=self.connection.close) + close_thread.start() + + # Let close() time out during the closing handshake. + time.sleep(3 * MS) + self.assertFalse(close_thread.is_alive()) + + # Connection is closed now. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() From fe629dede6eb083013e0d2373d5c3120c0078db3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 20 Apr 2023 16:45:55 +0200 Subject: [PATCH 1213/1539] Fix typo in changelog. Fix #1340. --- docs/project/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f54880b06..c5574f442 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -105,7 +105,7 @@ Backwards-incompatible changes New features ............ -.. admonition:: websockets 10.0 introduces a implementation on top of :mod:`threading`. +.. admonition:: websockets 11.0 introduces a implementation on top of :mod:`threading`. :class: important It may be more convenient if you don't need to manage many connections and From e152ceda9ec5d9e5052b3ae395ade2fae83d73fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 4 May 2023 15:30:42 +0200 Subject: [PATCH 1214/1539] Add Open Collective and GitHub Sponsors. --- .github/FUNDING.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 7ae223b3d..c6c5426a5 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1 +1,3 @@ -tidelift: "pypi/websockets" +github: python-websockets +open_collective: websockets +tidelift: pypi/websockets From 5aafc9e88fb1abf03012d71bdace770e3771c376 Mon Sep 17 00:00:00 2001 From: Carl Harris Date: Sun, 7 May 2023 09:41:16 -0400 Subject: [PATCH 1215/1539] Use selectors instead of select.poll in sync.WebSocket Server for multi-platform support (#1349) * use multiplatform selector instead of poll * don't use os.pipe with the I/O multiplexing selector on win32 On the Win32 platform, only sockets can be used with I/O multiplexing (such as that performed by selectors.DefaultSelector); the pipe cannot be added to the selector. However, on the win32 platform, simply closing the listener socket is enough to cause the call to select to return -- the additional pipe is redundant. On Mac OS X (and possibly other BSD derivatives), closing the listener socket isn't enough. In the interest of maximum compatibility, we simply disable the use of os.pipe on the Win32 platform. * exclude platform checks for win32 from coverage testing --- pyproject.toml | 1 + src/websockets/sync/server.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c26e3aabc..530052ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ source = [ exclude_lines = [ "except ImportError:", "if self.debug:", + "if sys.platform != \"win32\":", "pragma: no cover", "raise AssertionError", "raise NotImplementedError", diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index e25646a5c..072fce717 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -3,9 +3,10 @@ import http import logging import os -import select +import selectors import socket import ssl +import sys import threading from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type @@ -199,7 +200,8 @@ def __init__( if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger - self.shutdown_watcher, self.shutdown_notifier = os.pipe() + if sys.platform != "win32": + self.shutdown_watcher, self.shutdown_notifier = os.pipe() def serve_forever(self) -> None: """ @@ -214,15 +216,16 @@ def serve_forever(self) -> None: server.serve_forever() """ - poller = select.poll() - poller.register(self.socket) - poller.register(self.shutdown_watcher) + poller = selectors.DefaultSelector() + poller.register(self.socket, selectors.EVENT_READ) + if sys.platform != "win32": + poller.register(self.shutdown_watcher, selectors.EVENT_READ) while True: - poller.poll() + poller.select() try: # If the socket is closed, this will raise an exception and exit - # the loop. So we don't need to check the return value of poll(). + # the loop. So we don't need to check the return value of select(). sock, addr = self.socket.accept() except OSError: break @@ -235,7 +238,8 @@ def shutdown(self) -> None: """ self.socket.close() - os.write(self.shutdown_notifier, b"x") + if sys.platform != "win32": + os.write(self.shutdown_notifier, b"x") def fileno(self) -> int: """ From bf51a57f3d9c1f9ed7c04709a3b09ec6f337c4c5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 May 2023 15:44:46 +0200 Subject: [PATCH 1216/1539] Add changelog for previous commit. --- docs/project/changelog.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c5574f442..36bcfc58b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,16 @@ Backwards-incompatible changes websockets 11.0 is the last version supporting Python 3.7. +11.0.3 +------ + +*May 7, 2023* + +Bug fixes +......... + +* Fixed the :mod:`threading` implementation of servers on Windows. + 11.0.2 ------ From 9c578a1f5a5e20b8901ac75f12a4dfec4877f6d2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 May 2023 16:20:55 +0200 Subject: [PATCH 1217/1539] Create GitHub release when pushing tag. Fix #1347. --- .github/workflows/wheels.yml | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index d9c9ea833..8aa5c0b7b 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -64,18 +64,25 @@ jobs: with: path: wheelhouse/*.whl - upload_pypi: - name: Upload to PyPI + release: + name: Release needs: - sdist - wheels runs-on: ubuntu-latest + # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/download-artifact@v3 + - name: Download artifacts + uses: actions/download-artifact@v3 with: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@release/v1 + - name: Upload to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} + - name: Create GitHub release + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: gh release create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." From 73394df61d3b4c62f4868b75d16e1a81b2329f0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 09:23:47 +0200 Subject: [PATCH 1218/1539] Document how to override DNS resolution. Ref #1343. --- docs/faq/client.rst | 10 ++++++++++ docs/faq/server.rst | 8 +++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index c4f5a35b9..c590ac107 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -48,6 +48,16 @@ In the :mod:`threading` API, this argument is named ``additional_headers``:: with connect(..., additional_headers={"Authorization": ...}) as websocket: ... +How do I force the IP address that the client connects to? +---------------------------------------------------------- + +Use the ``host`` argument of :meth:`~asyncio.loop.create_connection`:: + + await websockets.connect("ws://example.com", host="192.168.0.1") + +:func:`~client.connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection`. + How do I close a connection? ---------------------------- diff --git a/docs/faq/server.rst b/docs/faq/server.rst index d1388c701..08b412d30 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -252,10 +252,12 @@ It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address async def handler(websocket): remote_ip = websocket.remote_address[0] -How do I set the IP addresses my server listens on? ---------------------------------------------------- +How do I set the IP addresses that my server listens on? +-------------------------------------------------------- -Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. +Use the ``host`` argument of :meth:`~asyncio.loop.create_server`:: + + await websockets.serve(handler, host="192.168.0.1", port=8080) :func:`~server.serve` accepts the same arguments as :meth:`~asyncio.loop.create_server`. From 03d62c97fcafffa5cdbe4bb55b2a8d17a62eca33 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 16:52:39 +0200 Subject: [PATCH 1219/1539] Fix server shutdown on Python 3.12. Ref https://github.com/python/cpython/issues/79033. Fix #1356. --- src/websockets/legacy/server.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index a17c52328..77e0fdab7 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -761,18 +761,13 @@ async def _close(self, close_connections: bool) -> None: # Stop accepting new connections. self.server.close() - # Wait until self.server.close() completes. - await self.server.wait_closed() - # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. await asyncio.sleep(0) if close_connections: - # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING connections with an HTTP 503 - # error. Wait until all connections are closed. - + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. close_tasks = [ asyncio.create_task(websocket.close(1001)) for websocket in self.websockets @@ -782,8 +777,10 @@ async def _close(self, close_connections: bool) -> None: if close_tasks: await asyncio.wait(close_tasks) - # Wait until all connection handlers are complete. + # Wait until all TCP connections are closed. + await self.server.wait_closed() + # Wait until all connection handlers terminate. # asyncio.wait doesn't accept an empty first argument. if self.websockets: await asyncio.wait( From 2845257be549b41e8279b2ae19f38258f72fb587 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 18:45:25 +0200 Subject: [PATCH 1220/1539] Mention application-level heartbeats. Fix #1330. --- docs/faq/common.rst | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 505149a64..0f3cd5910 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -151,4 +151,11 @@ See :doc:`../topics/timeouts` for details. How do I respond to pings? -------------------------- -Don't bother; websockets takes care of responding to pings with pongs. +If you are referring to Ping_ and Pong_ frames defined in the WebSocket +protocol, don't bother, because websockets handles them for you. + +.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 +.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + +If you are connecting to a server that defines its own heartbeat at the +application level, then you need to build that logic into your application. From 89fc4086b9bd8bb4c070284276c2eb5973f8b27c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 18:49:14 +0200 Subject: [PATCH 1221/1539] Don't crash if git is super slow. Fix #1334. --- src/websockets/version.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index f45f069ab..3f171b391 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -46,7 +46,11 @@ def get_version(tag: str) -> str: text=True, ).stdout.strip() # subprocess.run raises FileNotFoundError if git isn't on $PATH. - except (FileNotFoundError, subprocess.CalledProcessError): + except ( + FileNotFoundError, + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ): pass else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" From 2b627b26b3cd15f222881450cf0e8c19f7e826fd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 10:21:41 +0200 Subject: [PATCH 1222/1539] Provide an enum for close codes. Fix #1335. --- docs/faq/common.rst | 8 +- docs/howto/django.rst | 2 +- docs/howto/fly.rst | 3 +- docs/howto/heroku.rst | 3 +- docs/howto/kubernetes.rst | 2 +- docs/howto/render.rst | 3 +- docs/project/changelog.rst | 5 + docs/reference/datastructures.rst | 23 +- docs/topics/authentication.rst | 2 +- example/django/authentication.py | 3 +- example/django/notifications.py | 3 +- experiments/authentication/app.py | 3 +- src/websockets/exceptions.py | 2 +- src/websockets/frames.py | 89 +++++--- src/websockets/legacy/protocol.py | 36 ++-- src/websockets/protocol.py | 17 +- src/websockets/sync/connection.py | 19 +- src/websockets/sync/server.py | 4 +- tests/extensions/test_permessage_deflate.py | 3 +- tests/legacy/test_client_server.py | 27 +-- tests/legacy/test_framing.py | 8 +- tests/legacy/test_protocol.py | 131 +++++++---- tests/sync/test_connection.py | 14 +- tests/sync/test_server.py | 4 +- tests/test_exceptions.py | 67 ++++-- tests/test_frames.py | 41 +++- tests/test_protocol.py | 227 ++++++++++++-------- 27 files changed, 480 insertions(+), 269 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 0f3cd5910..2c63c4f36 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -54,8 +54,8 @@ There are several reasons why long-lived connections may be lost: If you're facing a reproducible issue, :ref:`enable debug logs ` to see when and how connections are closed. -What does ``ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received`` mean? ------------------------------------------------------------------------------------------------------------------------ +What does ``ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received`` mean? +--------------------------------------------------------------------------------------------------------------------- If you're seeing this traceback in the logs of a server: @@ -70,7 +70,7 @@ If you're seeing this traceback in the logs of a server: Traceback (most recent call last): ... - websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received or if a client crashes with this traceback: @@ -84,7 +84,7 @@ or if a client crashes with this traceback: Traceback (most recent call last): ... - websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received it means that the WebSocket connection suffered from excessive latency and was closed after reaching the timeout of websockets' keepalive mechanism. diff --git a/docs/howto/django.rst b/docs/howto/django.rst index c955a5ec1..e3da0a878 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -158,7 +158,7 @@ closes the connection: $ python -m websockets ws://localhost:8888/ Connected to ws://localhost:8888. > not a token - Connection closed: 1011 (unexpected error) authentication failed. + Connection closed: 1011 (internal error) authentication failed. You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: diff --git a/docs/howto/fly.rst b/docs/howto/fly.rst index 7e404de61..ed001a2ae 100644 --- a/docs/howto/fly.rst +++ b/docs/howto/fly.rst @@ -174,5 +174,4 @@ away). Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing -handshake and the connection would be closed with code 1006 (connection closed -abnormally). +handshake and the connection would be closed with code 1006 (abnormal closure). diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 2b3a44819..a97d2e7ce 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -180,5 +180,4 @@ away). Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing -handshake and the connection would be closed with code 1006 (connection closed -abnormally). +handshake and the connection would be closed with code 1006 (abnormal closure). diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst index c217e5946..064a6ac4d 100644 --- a/docs/howto/kubernetes.rst +++ b/docs/howto/kubernetes.rst @@ -77,7 +77,7 @@ shut down gracefully: < Hey there! Connection closed: 1001 (going away). -If it didn't, you'd get code 1006 (connection closed abnormally). +If it didn't, you'd get code 1006 (abnormal closure). Deploy application ------------------ diff --git a/docs/howto/render.rst b/docs/howto/render.rst index b8c417b66..70bf8c376 100644 --- a/docs/howto/render.rst +++ b/docs/howto/render.rst @@ -167,7 +167,6 @@ deployment completes, the connection is closed with code 1001 (going away). Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing -handshake and the connection would be closed with code 1006 (connection closed -abnormally). +handshake and the connection would be closed with code 1006 (abnormal closure). Remember to downgrade to a free plan if you upgraded just for testing this feature. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 36bcfc58b..94fc5ebd9 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,11 @@ Backwards-incompatible changes websockets 11.0 is the last version supporting Python 3.7. +Improvements +............ + +* Added :class:`~frames.CloseCode`. + 11.0.3 ------ diff --git a/docs/reference/datastructures.rst b/docs/reference/datastructures.rst index 8217052d1..7c037da4c 100644 --- a/docs/reference/datastructures.rst +++ b/docs/reference/datastructures.rst @@ -11,19 +11,32 @@ WebSocket events .. autoclass:: Opcode .. autoattribute:: CONT - .. autoattribute:: TEXT - .. autoattribute:: BINARY - .. autoattribute:: CLOSE - .. autoattribute:: PING - .. autoattribute:: PONG .. autoclass:: Close + .. autoclass:: CloseCode + + .. autoattribute:: OK + .. autoattribute:: GOING_AWAY + .. autoattribute:: PROTOCOL_ERROR + .. autoattribute:: UNSUPPORTED_DATA + .. autoattribute:: NO_STATUS_RCVD + .. autoattribute:: CONNECTION_CLOSED_ABNORMALLY + .. autoattribute:: INVALID_DATA + .. autoattribute:: POLICY_VIOLATION + .. autoattribute:: MESSAGE_TOO_BIG + .. autoattribute:: MANDATORY_EXTENSION + .. autoattribute:: INTERNAL_ERROR + .. autoattribute:: SERVICE_RESTART + .. autoattribute:: TRY_AGAIN_LATER + .. autoattribute:: BAD_GATEWAY + .. autoattribute:: TLS_FAILURE + HTTP events ----------- diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 60dd38766..31bc8e6da 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -189,7 +189,7 @@ connection: token = await websocket.recv() user = get_user(token) if user is None: - await websocket.close(1011, "authentication failed") + await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return ... diff --git a/example/django/authentication.py b/example/django/authentication.py index 7f60f8275..f6dad0f55 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -8,13 +8,14 @@ django.setup() from sesame.utils import get_user +from websockets.frames import CloseCode async def handler(websocket): sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) if user is None: - await websocket.close(1011, "authentication failed") + await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return await websocket.send(f"Hello {user}!") diff --git a/example/django/notifications.py b/example/django/notifications.py index 7275a1ef7..3a9ed10cf 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -11,6 +11,7 @@ from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user +from websockets.frames import CloseCode CONNECTIONS = {} @@ -33,7 +34,7 @@ async def handler(websocket): sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) if user is None: - await websocket.close(1011, "authentication failed") + await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return ct_ids = await asyncio.to_thread(get_content_types, user) diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index 6b3b2ae3f..039e21174 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -9,6 +9,7 @@ import uuid import websockets +from websockets.frames import CloseCode # User accounts database @@ -95,7 +96,7 @@ async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) if user is None: - await websocket.close(1011, "authentication failed") + await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return await websocket.send(f"Hello {user}!") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 22a3b583f..9d8476648 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -121,7 +121,7 @@ def __str__(self) -> str: @property def code(self) -> int: if self.rcvd is None: - return 1006 + return frames.CloseCode.ABNORMAL_CLOSURE return self.rcvd.code @property diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 8e0e6d873..6b1befb2e 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -52,48 +52,69 @@ class Opcode(enum.IntEnum): CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG -# See https://www.iana.org/assignments/websocket/websocket.xhtml -CLOSE_CODES = { - 1000: "OK", - 1001: "going away", - 1002: "protocol error", - 1003: "unsupported type", +class CloseCode(enum.IntEnum): + """Close code values for WebSocket close frames.""" + + NORMAL_CLOSURE = 1000 + GOING_AWAY = 1001 + PROTOCOL_ERROR = 1002 + UNSUPPORTED_DATA = 1003 # 1004 is reserved - 1005: "no status code [internal]", - 1006: "connection closed abnormally [internal]", - 1007: "invalid data", - 1008: "policy violation", - 1009: "message too big", - 1010: "extension required", - 1011: "unexpected error", - 1012: "service restart", - 1013: "try again later", - 1014: "bad gateway", - 1015: "TLS failure [internal]", + NO_STATUS_RCVD = 1005 + ABNORMAL_CLOSURE = 1006 + INVALID_DATA = 1007 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + MANDATORY_EXTENSION = 1010 + INTERNAL_ERROR = 1011 + SERVICE_RESTART = 1012 + TRY_AGAIN_LATER = 1013 + BAD_GATEWAY = 1014 + TLS_HANDSHAKE = 1015 + + +# See https://www.iana.org/assignments/websocket/websocket.xhtml +CLOSE_CODE_EXPLANATIONS: dict[int, str] = { + CloseCode.NORMAL_CLOSURE: "OK", + CloseCode.GOING_AWAY: "going away", + CloseCode.PROTOCOL_ERROR: "protocol error", + CloseCode.UNSUPPORTED_DATA: "unsupported data", + CloseCode.NO_STATUS_RCVD: "no status received [internal]", + CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]", + CloseCode.INVALID_DATA: "invalid frame payload data", + CloseCode.POLICY_VIOLATION: "policy violation", + CloseCode.MESSAGE_TOO_BIG: "message too big", + CloseCode.MANDATORY_EXTENSION: "mandatory extension", + CloseCode.INTERNAL_ERROR: "internal error", + CloseCode.SERVICE_RESTART: "service restart", + CloseCode.TRY_AGAIN_LATER: "try again later", + CloseCode.BAD_GATEWAY: "bad gateway", + CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]", } # Close code that are allowed in a close frame. # Using a set optimizes `code in EXTERNAL_CLOSE_CODES`. EXTERNAL_CLOSE_CODES = { - 1000, - 1001, - 1002, - 1003, - 1007, - 1008, - 1009, - 1010, - 1011, - 1012, - 1013, - 1014, + CloseCode.NORMAL_CLOSURE, + CloseCode.GOING_AWAY, + CloseCode.PROTOCOL_ERROR, + CloseCode.UNSUPPORTED_DATA, + CloseCode.INVALID_DATA, + CloseCode.POLICY_VIOLATION, + CloseCode.MESSAGE_TOO_BIG, + CloseCode.MANDATORY_EXTENSION, + CloseCode.INTERNAL_ERROR, + CloseCode.SERVICE_RESTART, + CloseCode.TRY_AGAIN_LATER, + CloseCode.BAD_GATEWAY, } + OK_CLOSE_CODES = { - 1000, - 1001, - 1005, + CloseCode.NORMAL_CLOSURE, + CloseCode.GOING_AWAY, + CloseCode.NO_STATUS_RCVD, } @@ -397,7 +418,7 @@ def __str__(self) -> str: elif 4000 <= self.code < 5000: explanation = "private use" else: - explanation = CLOSE_CODES.get(self.code, "unknown") + explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown") result = f"{self.code} ({explanation})" if self.reason: @@ -425,7 +446,7 @@ def parse(cls, data: bytes) -> Close: close.check() return close elif len(data) == 0: - return cls(1005, "") + return cls(CloseCode.NO_STATUS_RCVD, "") else: raise exceptions.ProtocolError("close frame too short") diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 0422c10da..19cee0e65 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -47,6 +47,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, Opcode, prepare_ctrl, prepare_data, @@ -459,7 +460,7 @@ def close_code(self) -> Optional[int]: if self.state is not State.CLOSED: return None elif self.close_rcvd is None: - return 1006 + return CloseCode.ABNORMAL_CLOSURE else: return self.close_rcvd.code @@ -681,7 +682,7 @@ async def send( except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. - self.fail_connection(1011) + self.fail_connection(CloseCode.INTERNAL_ERROR) raise finally: @@ -726,7 +727,7 @@ async def send( except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. - self.fail_connection(1011) + self.fail_connection(CloseCode.INTERNAL_ERROR) raise finally: @@ -736,7 +737,11 @@ async def send( else: raise TypeError("data must be str, bytes-like, or iterable") - async def close(self, code: int = 1000, reason: str = "") -> None: + async def close( + self, + code: int = CloseCode.NORMAL_CLOSURE, + reason: str = "", + ) -> None: """ Perform the closing handshake. @@ -986,7 +991,7 @@ async def transfer_data(self) -> None: except ProtocolError as exc: self.transfer_data_exc = exc - self.fail_connection(1002) + self.fail_connection(CloseCode.PROTOCOL_ERROR) except (ConnectionError, TimeoutError, EOFError, ssl.SSLError) as exc: # Reading data with self.reader.readexactly may raise: @@ -997,15 +1002,15 @@ async def transfer_data(self) -> None: # bytes are available than requested; # - ssl.SSLError if the other side infringes the TLS protocol. self.transfer_data_exc = exc - self.fail_connection(1006) + self.fail_connection(CloseCode.ABNORMAL_CLOSURE) except UnicodeDecodeError as exc: self.transfer_data_exc = exc - self.fail_connection(1007) + self.fail_connection(CloseCode.INVALID_DATA) except PayloadTooBig as exc: self.transfer_data_exc = exc - self.fail_connection(1009) + self.fail_connection(CloseCode.MESSAGE_TOO_BIG) except Exception as exc: # This shouldn't happen often because exceptions expected under @@ -1014,7 +1019,7 @@ async def transfer_data(self) -> None: self.logger.error("data transfer failed", exc_info=True) self.transfer_data_exc = exc - self.fail_connection(1011) + self.fail_connection(CloseCode.INTERNAL_ERROR) async def read_message(self) -> Optional[Data]: """ @@ -1265,7 +1270,10 @@ async def keepalive_ping(self) -> None: except asyncio.TimeoutError: if self.debug: self.logger.debug("! timed out waiting for keepalive pong") - self.fail_connection(1011, "keepalive ping timeout") + self.fail_connection( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) break except ConnectionClosed: @@ -1377,7 +1385,11 @@ async def wait_for_connection_lost(self) -> bool: # and the moment this coroutine resumes running. return self.connection_lost_waiter.done() - def fail_connection(self, code: int = 1006, reason: str = "") -> None: + def fail_connection( + self, + code: int = CloseCode.ABNORMAL_CLOSURE, + reason: str = "", + ) -> None: """ 7.1.7. Fail the WebSocket Connection @@ -1408,7 +1420,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because of # an error reading from or writing to the network. # Don't send a close frame if the connection is broken. - if code != 1006 and self.state is State.OPEN: + if code != CloseCode.ABNORMAL_CLOSURE and self.state is State.OPEN: close = Close(code, reason) # Write the close frame without draining the write buffer. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 3fdd3881c..765e6b9bb 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -23,6 +23,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, Frame, ) from .http11 import Request, Response @@ -181,7 +182,7 @@ def close_code(self) -> Optional[int]: if self.state is not CLOSED: return None elif self.close_rcvd is None: - return 1006 + return CloseCode.ABNORMAL_CLOSURE else: return self.close_rcvd.code @@ -362,7 +363,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") - close = Close(1005, "") + close = Close(CloseCode.NO_STATUS_RCVD, "") data = b"" else: close = Close(code, reason) @@ -419,7 +420,7 @@ def fail(self, code: int, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because # of an error reading from or writing to the network. if self.state is OPEN: - if code != 1006: + if code != CloseCode.ABNORMAL_CLOSURE: close = Close(code, reason) data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) @@ -549,25 +550,25 @@ def parse(self) -> Generator[None, None, None]: self.recv_frame(frame) except ProtocolError as exc: - self.fail(1002, str(exc)) + self.fail(CloseCode.PROTOCOL_ERROR, str(exc)) self.parser_exc = exc except EOFError as exc: - self.fail(1006, str(exc)) + self.fail(CloseCode.ABNORMAL_CLOSURE, str(exc)) self.parser_exc = exc except UnicodeDecodeError as exc: - self.fail(1007, f"{exc.reason} at position {exc.start}") + self.fail(CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}") self.parser_exc = exc except PayloadTooBig as exc: - self.fail(1009, str(exc)) + self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) self.parser_exc = exc except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. - self.fail(1011) + self.fail(CloseCode.INTERNAL_ERROR) self.parser_exc = exc # During an abnormal closure, execution ends here after catching an diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index afebf5ea9..4a8879e37 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -11,7 +11,7 @@ from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError -from ..frames import DATA_OPCODES, BytesLike, Frame, Opcode, prepare_ctrl +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol @@ -141,7 +141,10 @@ def __exit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - self.close(1000 if exc_type is None else 1011) + if exc_type is None: + self.close() + else: + self.close(CloseCode.INTERNAL_ERROR) def __iter__(self) -> Iterator[Data]: """ @@ -372,13 +375,16 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise else: raise TypeError("data must be bytes, str, or iterable") - def close(self, code: int = 1000, reason: str = "") -> None: + def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: """ Perform the closing handshake. @@ -399,7 +405,10 @@ def close(self, code: int = 1000, reason: str = "") -> None: # to terminate after calling a method that sends a close frame. with self.send_context(): if self.send_in_progress: - self.protocol.fail(1011, "close during fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) else: self.protocol.send_close(code, reason) except ConnectionClosed: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 072fce717..14767968c 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -11,6 +11,8 @@ from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type +from websockets.frames import CloseCode + from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import validate_subprotocols @@ -497,7 +499,7 @@ def protocol_select_subprotocol( handler(connection) except Exception: protocol.logger.error("connection handler failed", exc_info=True) - connection.close(1011) + connection.close(CloseCode.INTERNAL_ERROR) else: connection.close() diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index c341fdb32..0e698566f 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -18,6 +18,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, Frame, ) @@ -74,7 +75,7 @@ def test_no_encode_decode_pong_frame(self): self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_close_frame(self): - frame = Frame(OP_CLOSE, Close(1000, "").serialize()) + frame = Frame(OP_CLOSE, Close(CloseCode.NORMAL_CLOSURE, "").serialize()) self.assertEqual(self.extension.encode(frame), frame) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 133af0536..cba24ad32 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -27,6 +27,7 @@ PerMessageDeflate, ServerPerMessageDeflateFactory, ) +from websockets.frames import CloseCode from websockets.http import USER_AGENT from websockets.legacy.client import * from websockets.legacy.compatibility import asyncio_timeout @@ -1150,7 +1151,7 @@ def test_server_handler_crashes(self, send): self.loop.run_until_complete(self.client.recv()) # Connection ends with an unexpected error. - self.assertEqual(self.client.close_code, 1011) + self.assertEqual(self.client.close_code, CloseCode.INTERNAL_ERROR) @with_server() @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.close") @@ -1163,7 +1164,7 @@ def test_server_close_crashes(self, close): self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. - self.assertEqual(self.client.close_code, 1006) + self.assertEqual(self.client.close_code, CloseCode.ABNORMAL_CLOSURE) @with_server() @with_client() @@ -1196,8 +1197,8 @@ def test_server_shuts_down_during_connection_handling(self): self.loop.run_until_complete(self.client.recv()) # Server closed the connection with 1001 Going Away. - self.assertEqual(self.client.close_code, 1001) - self.assertEqual(server_ws.close_code, 1001) + self.assertEqual(self.client.close_code, CloseCode.GOING_AWAY) + self.assertEqual(server_ws.close_code, CloseCode.GOING_AWAY) @with_server() def test_server_shuts_down_gracefully_during_connection_handling(self): @@ -1208,8 +1209,8 @@ def test_server_shuts_down_gracefully_during_connection_handling(self): self.loop.run_until_complete(self.client.recv()) # Client closed the connection with 1000 OK. - self.assertEqual(self.client.close_code, 1000) - self.assertEqual(server_ws.close_code, 1000) + self.assertEqual(self.client.close_code, CloseCode.NORMAL_CLOSURE) + self.assertEqual(server_ws.close_code, CloseCode.NORMAL_CLOSURE) @with_server() def test_server_shuts_down_and_waits_until_handlers_terminate(self): @@ -1271,7 +1272,7 @@ def test_connection_error_during_closing_handshake(self, close): self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. - self.assertEqual(self.client.close_code, 1006) + self.assertEqual(self.client.close_code, CloseCode.ABNORMAL_CLOSURE) class ClientServerTests( @@ -1467,12 +1468,12 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - async def echo_handler_1001(ws): + async def echo_handler_going_away(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) - await ws.close(1001) + await ws.close(CloseCode.GOING_AWAY) - @with_server(handler=echo_handler_1001) + @with_server(handler=echo_handler_going_away) def test_iterate_on_messages_going_away_exit_ok(self): messages = [] @@ -1486,12 +1487,12 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - async def echo_handler_1011(ws): + async def echo_handler_internal_error(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) - await ws.close(1011) + await ws.close(CloseCode.INTERNAL_ERROR) - @with_server(handler=echo_handler_1011) + @with_server(handler=echo_handler_internal_error) def test_iterate_on_messages_internal_error_exit_not_ok(self): messages = [] diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 035f0f03c..e1e4c891b 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -6,7 +6,7 @@ import warnings from websockets.exceptions import PayloadTooBig, ProtocolError -from websockets.frames import OP_BINARY, OP_CLOSE, OP_PING, OP_PONG, OP_TEXT +from websockets.frames import OP_BINARY, OP_CLOSE, OP_PING, OP_PONG, OP_TEXT, CloseCode from websockets.legacy.framing import * from .utils import AsyncioTestCase @@ -187,11 +187,11 @@ def assertCloseData(self, code, reason, data): self.assertEqual(parsed, (code, reason)) def test_parse_close_and_serialize_close(self): - self.assertCloseData(1000, "", b"\x03\xe8") - self.assertCloseData(1000, "OK", b"\x03\xe8OK") + self.assertCloseData(CloseCode.NORMAL_CLOSURE, "", b"\x03\xe8") + self.assertCloseData(CloseCode.NORMAL_CLOSURE, "OK", b"\x03\xe8OK") def test_parse_close_empty(self): - self.assertEqual(parse_close(b""), (1005, "")) + self.assertEqual(parse_close(b""), (CloseCode.NO_STATUS_RCVD, "")) def test_parse_close_errors(self): with self.assertRaises(ProtocolError): diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 514d91ef8..f2eb0fea0 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -15,6 +15,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, ) from websockets.legacy.framing import Frame from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast @@ -121,9 +122,21 @@ async def delayed_drain(): self.protocol._drain = delayed_drain - close_frame = Frame(True, OP_CLOSE, Close(1000, "close").serialize()) - local_close = Frame(True, OP_CLOSE, Close(1000, "local").serialize()) - remote_close = Frame(True, OP_CLOSE, Close(1000, "remote").serialize()) + close_frame = Frame( + True, + OP_CLOSE, + Close(CloseCode.NORMAL_CLOSURE, "close").serialize(), + ) + local_close = Frame( + True, + OP_CLOSE, + Close(CloseCode.NORMAL_CLOSURE, "local").serialize(), + ) + remote_close = Frame( + True, + OP_CLOSE, + Close(CloseCode.NORMAL_CLOSURE, "remote").serialize(), + ) def receive_frame(self, frame): """ @@ -157,7 +170,7 @@ def receive_eof_if_client(self): if self.protocol.is_client: self.receive_eof() - def close_connection(self, code=1000, reason="close"): + def close_connection(self, code=CloseCode.NORMAL_CLOSURE, reason="close"): """ Execute a closing handshake. @@ -175,7 +188,11 @@ def close_connection(self, code=1000, reason="close"): assert self.protocol.state is State.CLOSED - def half_close_connection_local(self, code=1000, reason="close"): + def half_close_connection_local( + self, + code=CloseCode.NORMAL_CLOSURE, + reason="close", + ): """ Start a closing handshake but do not complete it. @@ -205,7 +222,11 @@ def half_close_connection_local(self, code=1000, reason="close"): # This task must be awaited or canceled by the caller. return close_task - def half_close_connection_remote(self, code=1000, reason="close"): + def half_close_connection_remote( + self, + code=CloseCode.NORMAL_CLOSURE, + reason="close", + ): """ Receive a closing handshake but do not complete it. @@ -299,10 +320,10 @@ def assertConnectionFailed(self, code, message): # The following line guarantees that connection_lost was called. self.assertEqual(self.protocol.state, State.CLOSED) # No close frame was received. - self.assertEqual(self.protocol.close_code, 1006) + self.assertEqual(self.protocol.close_code, CloseCode.ABNORMAL_CLOSURE) self.assertEqual(self.protocol.close_reason, "") # A close frame was sent -- unless the connection was already lost. - if code == 1006: + if code == CloseCode.ABNORMAL_CLOSURE: self.assertNoFrameSent() else: self.assertOneFrameSent(True, OP_CLOSE, Close(code, message).serialize()) @@ -391,11 +412,11 @@ def test_wait_closed(self): self.assertTrue(wait_closed.done()) def test_close_code(self): - self.close_connection(1001, "Bye!") - self.assertEqual(self.protocol.close_code, 1001) + self.close_connection(CloseCode.GOING_AWAY, "Bye!") + self.assertEqual(self.protocol.close_code, CloseCode.GOING_AWAY) def test_close_reason(self): - self.close_connection(1001, "Bye!") + self.close_connection(CloseCode.GOING_AWAY, "Bye!") self.assertEqual(self.protocol.close_reason, "Bye!") def test_close_code_not_set(self): @@ -439,24 +460,24 @@ def test_recv_on_closed_connection(self): def test_recv_protocol_error(self): self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8"))) self.process_invalid_frames() - self.assertConnectionFailed(1002, "") + self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_recv_unicode_error(self): self.receive_frame(Frame(True, OP_TEXT, "café".encode("latin-1"))) self.process_invalid_frames() - self.assertConnectionFailed(1007, "") + self.assertConnectionFailed(CloseCode.INVALID_DATA, "") def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) self.process_invalid_frames() - self.assertConnectionFailed(1009, "") + self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342)) self.process_invalid_frames() - self.assertConnectionFailed(1009, "") + self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage @@ -531,7 +552,7 @@ async def read_message(): self.protocol.read_message = read_message self.process_invalid_frames() - self.assertConnectionFailed(1011, "") + self.assertConnectionFailed(CloseCode.INTERNAL_ERROR, "") def test_recv_canceled(self): recv = self.loop.create_task(self.protocol.recv()) @@ -667,7 +688,7 @@ def test_send_iterable_mixed_type_error(self): self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) self.assertFramesSent( (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, Close(1011, "").serialize()), + (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) def test_send_iterable_prevents_concurrent_send(self): @@ -741,7 +762,7 @@ def test_send_async_iterable_mixed_type_error(self): ) self.assertFramesSent( (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, Close(1011, "").serialize()), + (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) def test_send_async_iterable_prevents_concurrent_send(self): @@ -1068,14 +1089,14 @@ def test_fragmented_text_payload_too_big(self): self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) self.process_invalid_frames() - self.assertConnectionFailed(1009, "") + self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171)) self.receive_frame(Frame(True, OP_CONT, b"tea" * 171)) self.process_invalid_frames() - self.assertConnectionFailed(1009, "") + self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage @@ -1104,7 +1125,7 @@ def test_unterminated_fragmented_text(self): # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.process_invalid_frames() - self.assertConnectionFailed(1002, "") + self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_close_handshake_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) @@ -1114,12 +1135,12 @@ def test_close_handshake_in_fragmented_text(self): # can be interjected in the middle of a fragmented message and that a # close frame must be echoed. Even though there's an unterminated # message, technically, the closing handshake was successful. - self.assertConnectionClosed(1005, "") + self.assertConnectionClosed(CloseCode.NO_STATUS_RCVD, "") def test_connection_close_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) self.process_invalid_frames() - self.assertConnectionFailed(1006, "") + self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") # Test miscellaneous code paths to ensure full coverage. @@ -1127,7 +1148,7 @@ def test_connection_lost(self): # Test calling connection_lost without going through close_connection. self.protocol.connection_lost(None) - self.assertConnectionFailed(1006, "") + self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") def test_ensure_open_before_opening_handshake(self): # Simulate a bug by forcibly reverting the protocol state. @@ -1168,7 +1189,7 @@ def test_connection_closed_attributes(self): self.loop.run_until_complete(self.protocol.recv()) connection_closed_exc = context.exception - self.assertEqual(connection_closed_exc.code, 1000) + self.assertEqual(connection_closed_exc.code, CloseCode.NORMAL_CLOSURE) self.assertEqual(connection_closed_exc.reason, "close") # Test the protocol logic for sending keepalive pings. @@ -1228,7 +1249,9 @@ def test_keepalive_ping_not_acknowledged_closes_connection(self): # Connection is closed at 6ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) self.assertOneFrameSent( - True, OP_CLOSE, Close(1011, "keepalive ping timeout").serialize() + True, + OP_CLOSE, + Close(CloseCode.INTERNAL_ERROR, "keepalive ping timeout").serialize(), ) # The keepalive ping task is complete. @@ -1328,13 +1351,13 @@ def test_local_close(self): # Run the closing handshake. self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertNoFrameSent() def test_remote_close(self): @@ -1347,13 +1370,13 @@ def test_remote_close(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertNoFrameSent() def test_remote_close_and_connection_lost(self): @@ -1367,7 +1390,7 @@ def test_remote_close_and_connection_lost(self): with self.assertNoLogs(): self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) def test_simultaneous_close(self): @@ -1380,7 +1403,7 @@ def test_simultaneous_close(self): self.loop.run_until_complete(self.protocol.close(reason="local")) - self.assertConnectionClosed(1000, "remote") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "remote") # The current implementation sends a close frame in response to the # close frame received from the remote end. It skips the close frame # that should be sent as a result of calling close(). @@ -1394,7 +1417,7 @@ def test_close_preserves_incoming_frames(self): self.loop.call_later(MS, self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) next_message = self.loop.run_until_complete(self.protocol.recv()) @@ -1407,14 +1430,14 @@ def test_close_protocol_error(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionFailed(1002, "") + self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_close_connection_lost(self): self.receive_eof() self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionFailed(1006, "") + self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_during_recv(self): recv = self.loop.create_task(self.protocol.recv()) @@ -1427,7 +1450,7 @@ def test_local_close_during_recv(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(recv) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") # There is no test_remote_close_during_recv because it would be identical # to test_remote_close. @@ -1442,7 +1465,7 @@ def test_remote_close_during_send(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(send) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. @@ -1557,7 +1580,7 @@ def test_local_close_send_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") + self.assertConnectionClosed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1565,7 +1588,7 @@ def test_local_close_receive_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") + self.assertConnectionClosed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS @@ -1579,7 +1602,10 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1000, "close") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.NORMAL_CLOSURE, + "close", + ) def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1597,7 +1623,10 @@ def test_local_close_connection_lost_timeout_after_close(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1000, "close") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.NORMAL_CLOSURE, + "close", + ) class ClientTests(CommonTests, AsyncioTestCase): @@ -1616,7 +1645,10 @@ def test_local_close_send_close_frame_timeout(self): with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1006, "") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.ABNORMAL_CLOSURE, + "", + ) def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1627,7 +1659,10 @@ def test_local_close_receive_close_frame_timeout(self): with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1006, "") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.ABNORMAL_CLOSURE, + "", + ) def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS @@ -1643,7 +1678,10 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1000, "close") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.NORMAL_CLOSURE, + "close", + ) def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1664,4 +1702,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1000, "close") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.NORMAL_CLOSURE, + "close", + ) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 0e7cff948..63544d4ad 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -10,7 +10,7 @@ from unittest.mock import patch from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK -from websockets.frames import Frame, Opcode +from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol from websockets.sync.connection import * @@ -134,7 +134,7 @@ def test_iter_connection_closed_ok(self): def test_iter_connection_closed_error(self): """__iter__ raises ConnnectionClosedError after an error.""" iterator = iter(self.connection) - self.remote_connection.close(code=1011) + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): next(iterator) @@ -168,7 +168,7 @@ def test_recv_connection_closed_ok(self): def test_recv_connection_closed_error(self): """recv raises ConnectionClosedError after an error.""" - self.remote_connection.close(code=1011) + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): self.connection.recv() @@ -248,7 +248,7 @@ def test_recv_streaming_connection_closed_ok(self): def test_recv_streaming_connection_closed_error(self): """recv_streaming raises ConnectionClosedError after an error.""" - self.remote_connection.close(code=1011) + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): list(self.connection.recv_streaming()) @@ -322,7 +322,7 @@ def test_send_connection_closed_ok(self): def test_send_connection_closed_error(self): """send raises ConnectionClosedError after an error.""" - self.remote_connection.close(code=1011) + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): self.connection.send("😀") @@ -400,7 +400,7 @@ def test_close(self): def test_close_explicit_code_reason(self): """close sends a close frame with a given code and reason.""" - self.connection.close(1001, "bye!") + self.connection.close(CloseCode.GOING_AWAY, "bye!") self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) def test_close_waits_for_close_frame(self): @@ -567,7 +567,7 @@ def fragments(): exc = raised.exception self.assertEqual( str(exc), - "sent 1011 (unexpected error) close during fragmented message; " + "sent 1011 (internal error) close during fragmented message; " "no close frame received", ) self.assertIsNone(exc.__cause__) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 60e70c0a5..f9db84246 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -64,8 +64,8 @@ def test_connection_handler_raises_exception(self): with run_client(server) as client: with self.assertRaisesRegex( ConnectionClosedError, - r"received 1011 \(unexpected error\); " - r"then sent 1011 \(unexpected error\)", + r"received 1011 \(internal error\); " + r"then sent 1011 \(internal error\)", ): client.recv() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index a0f9dfcd2..1e6f58fad 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -2,49 +2,80 @@ from websockets.datastructures import Headers from websockets.exceptions import * -from websockets.frames import Close +from websockets.frames import Close, CloseCode from websockets.http11 import Response class ExceptionsTests(unittest.TestCase): def test_str(self): for exception, exception_str in [ - # fmt: off ( WebSocketException("something went wrong"), "something went wrong", ), ( - ConnectionClosed(Close(1000, ""), Close(1000, ""), True), + ConnectionClosed( + Close(CloseCode.NORMAL_CLOSURE, ""), + Close(CloseCode.NORMAL_CLOSURE, ""), + True, + ), "received 1000 (OK); then sent 1000 (OK)", ), ( - ConnectionClosed(Close(1001, "Bye!"), Close(1001, "Bye!"), False), + ConnectionClosed( + Close(CloseCode.GOING_AWAY, "Bye!"), + Close(CloseCode.GOING_AWAY, "Bye!"), + False, + ), "sent 1001 (going away) Bye!; then received 1001 (going away) Bye!", ), ( - ConnectionClosed(Close(1000, "race"), Close(1000, "cond"), True), + ConnectionClosed( + Close(CloseCode.NORMAL_CLOSURE, "race"), + Close(CloseCode.NORMAL_CLOSURE, "cond"), + True, + ), "received 1000 (OK) race; then sent 1000 (OK) cond", ), ( - ConnectionClosed(Close(1000, "cond"), Close(1000, "race"), False), + ConnectionClosed( + Close(CloseCode.NORMAL_CLOSURE, "cond"), + Close(CloseCode.NORMAL_CLOSURE, "race"), + False, + ), "sent 1000 (OK) race; then received 1000 (OK) cond", ), ( - ConnectionClosed(None, Close(1009, ""), None), + ConnectionClosed( + None, + Close(CloseCode.MESSAGE_TOO_BIG, ""), + None, + ), "sent 1009 (message too big); no close frame received", ), ( - ConnectionClosed(Close(1002, ""), None, None), + ConnectionClosed( + Close(CloseCode.PROTOCOL_ERROR, ""), + None, + None, + ), "received 1002 (protocol error); no close frame sent", ), ( - ConnectionClosedOK(Close(1000, ""), Close(1000, ""), True), + ConnectionClosedOK( + Close(CloseCode.NORMAL_CLOSURE, ""), + Close(CloseCode.NORMAL_CLOSURE, ""), + True, + ), "received 1000 (OK); then sent 1000 (OK)", ), ( - ConnectionClosedError(None, None, None), - "no close frame received or sent" + ConnectionClosedError( + None, + None, + None, + ), + "no close frame received or sent", ), ( InvalidHandshake("invalid request"), @@ -75,11 +106,8 @@ def test_str(self): "invalid Name header: Value", ), ( - InvalidHeaderFormat( - "Sec-WebSocket-Protocol", "expected token", "a=|", 3 - ), - "invalid Sec-WebSocket-Protocol header: " - "expected token at 3 in a=|", + InvalidHeaderFormat("Sec-WebSocket-Protocol", "exp. token", "a=|", 3), + "invalid Sec-WebSocket-Protocol header: exp. token at 3 in a=|", ), ( InvalidHeaderValue("Sec-WebSocket-Version", "42"), @@ -153,17 +181,16 @@ def test_str(self): ProtocolError("invalid opcode: 7"), "invalid opcode: 7", ), - # fmt: on ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) def test_connection_closed_attributes_backwards_compatibility(self): - exception = ConnectionClosed(Close(1000, "OK"), None, None) - self.assertEqual(exception.code, 1000) + exception = ConnectionClosed(Close(CloseCode.NORMAL_CLOSURE, "OK"), None, None) + self.assertEqual(exception.code, CloseCode.NORMAL_CLOSURE) self.assertEqual(exception.reason, "OK") def test_connection_closed_attributes_backwards_compatibility_defaults(self): exception = ConnectionClosed(None, None, None) - self.assertEqual(exception.code, 1006) + self.assertEqual(exception.code, CloseCode.ABNORMAL_CLOSURE) self.assertEqual(exception.reason, "") diff --git a/tests/test_frames.py b/tests/test_frames.py index e7c48b930..e323b3b57 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -5,6 +5,7 @@ from websockets.exceptions import PayloadTooBig, ProtocolError from websockets.frames import * +from websockets.frames import CloseCode from websockets.streams import StreamReader from .utils import GeneratorTestCase @@ -444,18 +445,42 @@ def assertCloseData(self, close, data): self.assertEqual(parsed, close) def test_str(self): - self.assertEqual(str(Close(1000, "")), "1000 (OK)") - self.assertEqual(str(Close(1001, "Bye!")), "1001 (going away) Bye!") - self.assertEqual(str(Close(3000, "")), "3000 (registered)") - self.assertEqual(str(Close(4000, "")), "4000 (private use)") - self.assertEqual(str(Close(5000, "")), "5000 (unknown)") + self.assertEqual( + str(Close(CloseCode.NORMAL_CLOSURE, "")), + "1000 (OK)", + ) + self.assertEqual( + str(Close(CloseCode.GOING_AWAY, "Bye!")), + "1001 (going away) Bye!", + ) + self.assertEqual( + str(Close(3000, "")), + "3000 (registered)", + ) + self.assertEqual( + str(Close(4000, "")), + "4000 (private use)", + ) + self.assertEqual( + str(Close(5000, "")), + "5000 (unknown)", + ) def test_parse_and_serialize(self): - self.assertCloseData(Close(1001, ""), b"\x03\xe9") - self.assertCloseData(Close(1000, "OK"), b"\x03\xe8OK") + self.assertCloseData( + Close(CloseCode.NORMAL_CLOSURE, "OK"), + b"\x03\xe8OK", + ) + self.assertCloseData( + Close(CloseCode.GOING_AWAY, ""), + b"\x03\xe9", + ) def test_parse_empty(self): - self.assertEqual(Close.parse(b""), Close(1005, "")) + self.assertEqual( + Close.parse(b""), + Close(CloseCode.NO_STATUS_RCVD, ""), + ) def test_parse_errors(self): with self.assertRaises(ProtocolError): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7321d2594..a64172b53 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -16,6 +16,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, Frame, ) from websockets.protocol import * @@ -133,14 +134,18 @@ def test_client_receives_masked_frame(self): client.receive_data(self.masked_text_frame_data) self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incorrect masking") - self.assertConnectionFailing(client, 1002, "incorrect masking") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "incorrect masking" + ) def test_server_receives_unmasked_frame(self): server = Protocol(SERVER) server.receive_data(self.unmasked_text_frame_date) self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incorrect masking") - self.assertConnectionFailing(server, 1002, "incorrect masking") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "incorrect masking" + ) class ContinuationTests(ProtocolTestCase): @@ -166,14 +171,18 @@ def test_client_receives_unexpected_continuation(self): client.receive_data(b"\x00\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "unexpected continuation frame") - self.assertConnectionFailing(client, 1002, "unexpected continuation frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "unexpected continuation frame" + ) def test_server_receives_unexpected_continuation(self): server = Protocol(SERVER) server.receive_data(b"\x00\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "unexpected continuation frame") - self.assertConnectionFailing(server, 1002, "unexpected continuation frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "unexpected continuation frame" + ) def test_client_sends_continuation_after_sending_close(self): client = Protocol(CLIENT) @@ -181,7 +190,7 @@ def test_client_sends_continuation_after_sending_close(self): # message (see test_client_send_close_in_fragmented_message), in fact, # this is the same test as test_client_sends_unexpected_continuation. with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(ProtocolError) as raised: client.send_continuation(b"", fin=False) @@ -192,7 +201,7 @@ def test_server_sends_continuation_after_sending_close(self): # message (see test_server_send_close_in_fragmented_message), in fact, # this is the same test as test_server_sends_unexpected_continuation. server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) @@ -201,7 +210,7 @@ def test_server_sends_continuation_after_sending_close(self): def test_client_receives_continuation_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x00\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -209,7 +218,7 @@ def test_client_receives_continuation_after_receiving_close(self): def test_server_receives_continuation_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x00\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -255,14 +264,18 @@ def test_client_receives_text_over_size_limit(self): client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + self.assertConnectionFailing( + client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + ) def test_server_receives_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + self.assertConnectionFailing( + server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + ) def test_client_receives_text_without_size_limit(self): client = Protocol(CLIENT, max_size=None) @@ -349,7 +362,9 @@ def test_client_receives_fragmented_text_over_size_limit(self): client.receive_data(b"\x80\x02\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + self.assertConnectionFailing( + client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + ) def test_server_receives_fragmented_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) @@ -361,7 +376,9 @@ def test_server_receives_fragmented_text_over_size_limit(self): server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + self.assertConnectionFailing( + server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + ) def test_client_receives_fragmented_text_without_size_limit(self): client = Protocol(CLIENT, max_size=None) @@ -423,7 +440,9 @@ def test_client_receives_unexpected_text(self): client.receive_data(b"\x01\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(client, 1002, "expected a continuation frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" + ) def test_server_receives_unexpected_text(self): server = Protocol(SERVER) @@ -435,19 +454,21 @@ def test_server_receives_unexpected_text(self): server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(server, 1002, "expected a continuation frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" + ) def test_client_sends_text_after_sending_close(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState): client.send_text(b"") def test_server_sends_text_after_sending_close(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_text(b"") @@ -455,7 +476,7 @@ def test_server_sends_text_after_sending_close(self): def test_client_receives_text_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x81\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -463,7 +484,7 @@ def test_client_receives_text_after_receiving_close(self): def test_server_receives_text_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x81\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -509,14 +530,18 @@ def test_client_receives_binary_over_size_limit(self): client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + self.assertConnectionFailing( + client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + ) def test_server_receives_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + self.assertConnectionFailing( + server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + ) def test_client_sends_fragmented_binary(self): client = Protocol(CLIENT) @@ -587,7 +612,9 @@ def test_client_receives_fragmented_binary_over_size_limit(self): client.receive_data(b"\x80\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + self.assertConnectionFailing( + client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + ) def test_server_receives_fragmented_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) @@ -599,7 +626,9 @@ def test_server_receives_fragmented_binary_over_size_limit(self): server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + self.assertConnectionFailing( + server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + ) def test_client_sends_unexpected_binary(self): client = Protocol(CLIENT) @@ -625,7 +654,9 @@ def test_client_receives_unexpected_binary(self): client.receive_data(b"\x02\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(client, 1002, "expected a continuation frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" + ) def test_server_receives_unexpected_binary(self): server = Protocol(SERVER) @@ -637,19 +668,21 @@ def test_server_receives_unexpected_binary(self): server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(server, 1002, "expected a continuation frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" + ) def test_client_sends_binary_after_sending_close(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState): client.send_binary(b"") def test_server_sends_binary_after_sending_close(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_binary(b"") @@ -657,7 +690,7 @@ def test_server_sends_binary_after_sending_close(self): def test_client_receives_binary_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x82\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -665,7 +698,7 @@ def test_client_receives_binary_after_receiving_close(self): def test_server_receives_binary_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x82\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -687,7 +720,7 @@ def test_close_code(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") client.receive_eof() - self.assertEqual(client.close_code, 1000) + self.assertEqual(client.close_code, CloseCode.NORMAL_CLOSURE) def test_close_reason(self): server = Protocol(SERVER) @@ -699,7 +732,7 @@ def test_close_code_not_provided(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x00\x00\x00\x00") server.receive_eof() - self.assertEqual(server.close_code, 1005) + self.assertEqual(server.close_code, CloseCode.NO_STATUS_RCVD) def test_close_reason_not_provided(self): client = Protocol(CLIENT) @@ -710,7 +743,7 @@ def test_close_reason_not_provided(self): def test_close_code_not_available(self): client = Protocol(CLIENT) client.receive_eof() - self.assertEqual(client.close_code, 1006) + self.assertEqual(client.close_code, CloseCode.ABNORMAL_CLOSURE) def test_close_reason_not_available(self): server = Protocol(SERVER) @@ -812,32 +845,32 @@ def test_server_receives_close_then_sends_close(self): def test_client_sends_close_with_code(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) self.assertIs(client.state, CLOSING) def test_server_sends_close_with_code(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) self.assertIs(server.state, CLOSING) def test_client_receives_close_with_code(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000, "") + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE, "") self.assertIs(client.state, CLOSING) def test_server_receives_close_with_code(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001, "") + self.assertConnectionClosing(server, CloseCode.GOING_AWAY, "") self.assertIs(server.state, CLOSING) def test_client_sends_close_with_code_and_reason(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001, "going away") + client.send_close(CloseCode.GOING_AWAY, "going away") self.assertEqual( client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] ) @@ -845,20 +878,20 @@ def test_client_sends_close_with_code_and_reason(self): def test_server_sends_close_with_code_and_reason(self): server = Protocol(SERVER) - server.send_close(1000, "OK") + server.send_close(CloseCode.NORMAL_CLOSURE, "OK") self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) self.assertIs(server.state, CLOSING) def test_client_receives_close_with_code_and_reason(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") - self.assertConnectionClosing(client, 1000, "OK") + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE, "OK") self.assertIs(client.state, CLOSING) def test_server_receives_close_with_code_and_reason(self): server = Protocol(SERVER) server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") - self.assertConnectionClosing(server, 1001, "going away") + self.assertConnectionClosing(server, CloseCode.GOING_AWAY, "going away") self.assertIs(server.state, CLOSING) def test_client_sends_close_with_reason_only(self): @@ -878,7 +911,9 @@ def test_client_receives_close_with_truncated_code(self): client.receive_data(b"\x88\x01\x03") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "close frame too short") - self.assertConnectionFailing(client, 1002, "close frame too short") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "close frame too short" + ) self.assertIs(client.state, CLOSING) def test_server_receives_close_with_truncated_code(self): @@ -886,7 +921,9 @@ def test_server_receives_close_with_truncated_code(self): server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "close frame too short") - self.assertConnectionFailing(server, 1002, "close frame too short") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "close frame too short" + ) self.assertIs(server.state, CLOSING) def test_client_receives_close_with_non_utf8_reason(self): @@ -898,7 +935,9 @@ def test_client_receives_close_with_non_utf8_reason(self): str(client.parser_exc), "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) - self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") + self.assertConnectionFailing( + client, CloseCode.INVALID_DATA, "invalid start byte at position 0" + ) self.assertIs(client.state, CLOSING) def test_server_receives_close_with_non_utf8_reason(self): @@ -910,7 +949,9 @@ def test_server_receives_close_with_non_utf8_reason(self): str(server.parser_exc), "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) - self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") + self.assertConnectionFailing( + server, CloseCode.INVALID_DATA, "invalid start byte at position 0" + ) self.assertIs(server.state, CLOSING) @@ -1011,19 +1052,23 @@ def test_client_receives_fragmented_ping_frame(self): client.receive_data(b"\x09\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") - self.assertConnectionFailing(client, 1002, "fragmented control frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "fragmented control frame" + ) def test_server_receives_fragmented_ping_frame(self): server = Protocol(SERVER) server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") - self.assertConnectionFailing(server, 1002, "fragmented control frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "fragmented control frame" + ) def test_client_sends_ping_after_sending_close(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) # before the connection is closed" but websockets doesn't support @@ -1037,7 +1082,7 @@ def test_client_sends_ping_after_sending_close(self): def test_server_sends_ping_after_sending_close(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) # before the connection is closed" but websockets doesn't support @@ -1052,7 +1097,7 @@ def test_server_sends_ping_after_sending_close(self): def test_client_receives_ping_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1060,7 +1105,7 @@ def test_client_receives_ping_after_receiving_close(self): def test_server_receives_ping_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -1147,19 +1192,23 @@ def test_client_receives_fragmented_pong_frame(self): client.receive_data(b"\x0a\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") - self.assertConnectionFailing(client, 1002, "fragmented control frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "fragmented control frame" + ) def test_server_receives_fragmented_pong_frame(self): server = Protocol(SERVER) server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") - self.assertConnectionFailing(server, 1002, "fragmented control frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "fragmented control frame" + ) def test_client_sends_pong_after_sending_close(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) # websockets doesn't support sending a Pong frame after a Close frame. with self.assertRaises(InvalidState): @@ -1167,7 +1216,7 @@ def test_client_sends_pong_after_sending_close(self): def test_server_sends_pong_after_sending_close(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # websockets doesn't support sending a Pong frame after a Close frame. with self.assertRaises(InvalidState): @@ -1176,7 +1225,7 @@ def test_server_sends_pong_after_sending_close(self): def test_client_receives_pong_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1184,7 +1233,7 @@ def test_client_receives_pong_after_receiving_close(self): def test_server_receives_pong_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -1200,15 +1249,15 @@ class FailTests(ProtocolTestCase): def test_client_stops_processing_frames_after_fail(self): client = Protocol(CLIENT) - client.fail(1002) - self.assertConnectionFailing(client, 1002) + client.fail(CloseCode.PROTOCOL_ERROR) + self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR) client.receive_data(b"\x88\x02\x03\xea") self.assertFrameReceived(client, None) def test_server_stops_processing_frames_after_fail(self): server = Protocol(SERVER) - server.fail(1002) - self.assertConnectionFailing(server, 1002) + server.fail(CloseCode.PROTOCOL_ERROR) + self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") self.assertFrameReceived(server, None) @@ -1320,7 +1369,7 @@ def test_client_send_close_in_fragmented_message(self): # endpoint must not send a data frame after a close frame, a close # frame can't be "in the middle" of a fragmented message. with self.assertRaises(ProtocolError) as raised: - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(str(raised.exception), "expected a continuation frame") client.send_continuation(b"Eggs", fin=True) @@ -1333,7 +1382,7 @@ def test_server_send_close_in_fragmented_message(self): # endpoint must not send a data frame after a close frame, a close # frame can't be "in the middle" of a fragmented message. with self.assertRaises(ProtocolError) as raised: - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receive_close_in_fragmented_message(self): @@ -1350,7 +1399,9 @@ def test_client_receive_close_in_fragmented_message(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incomplete fragmented message") - self.assertConnectionFailing(client, 1002, "incomplete fragmented message") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "incomplete fragmented message" + ) def test_server_receive_close_in_fragmented_message(self): server = Protocol(SERVER) @@ -1366,7 +1417,9 @@ def test_server_receive_close_in_fragmented_message(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incomplete fragmented message") - self.assertConnectionFailing(server, 1002, "incomplete fragmented message") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "incomplete fragmented message" + ) class EOFTests(ProtocolTestCase): @@ -1428,35 +1481,35 @@ def test_server_receives_eof_inside_frame(self): def test_client_receives_data_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") + self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_data(b"\x00\x00") self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") + self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_data(b"\x00\x00") self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") + self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_eof() self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") + self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_eof() self.assertFrameSent(server, None) def test_client_receives_data_and_eof_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") + self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_data(b"\x00\x00") client.receive_eof() self.assertFrameSent(client, None, eof=True) @@ -1464,7 +1517,7 @@ def test_client_receives_data_and_eof_after_exception(self): def test_server_receives_data_and_eof_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") + self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_data(b"\x00\x00") server.receive_eof() self.assertFrameSent(server, None) @@ -1554,12 +1607,12 @@ def test_server_receives_close(self): def test_client_fails_connection(self): client = Protocol(CLIENT) - client.fail(1002) + client.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(client.close_expected()) def test_server_fails_connection(self): server = Protocol(SERVER) - server.fail(1002) + server.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(server.close_expected()) @@ -1572,25 +1625,25 @@ class ConnectionClosedTests(ProtocolTestCase): def test_client_sends_close_then_receives_close(self): # Client-initiated close handshake on the client side complete. client = Protocol(CLIENT) - client.send_close(1000, "") + client.send_close(CloseCode.NORMAL_CLOSURE, "") client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertFalse(exc.rcvd_then_sent) def test_server_sends_close_then_receives_close(self): # Server-initiated close handshake on the server side complete. server = Protocol(SERVER) - server.send_close(1000, "") + server.send_close(CloseCode.NORMAL_CLOSURE, "") server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertFalse(exc.rcvd_then_sent) def test_client_receives_close_then_sends_close(self): @@ -1600,8 +1653,8 @@ def test_client_receives_close_then_sends_close(self): client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertTrue(exc.rcvd_then_sent) def test_server_receives_close_then_sends_close(self): @@ -1611,30 +1664,30 @@ def test_server_receives_close_then_sends_close(self): server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertTrue(exc.rcvd_then_sent) def test_client_sends_close_then_receives_eof(self): # Client-initiated close handshake on the client side times out. client = Protocol(CLIENT) - client.send_close(1000, "") + client.send_close(CloseCode.NORMAL_CLOSURE, "") client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertIsNone(exc.rcvd_then_sent) def test_server_sends_close_then_receives_eof(self): # Server-initiated close handshake on the server side times out. server = Protocol(SERVER) - server.send_close(1000, "") + server.send_close(CloseCode.NORMAL_CLOSURE, "") server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertIsNone(exc.rcvd_then_sent) def test_client_receives_eof(self): @@ -1671,7 +1724,7 @@ def test_client_hits_internal_error_reading_frame(self): client.receive_data(b"\x81\x00") self.assertIsInstance(client.parser_exc, RuntimeError) self.assertEqual(str(client.parser_exc), "BOOM") - self.assertConnectionFailing(client, 1011, "") + self.assertConnectionFailing(client, CloseCode.INTERNAL_ERROR, "") def test_server_hits_internal_error_reading_frame(self): server = Protocol(SERVER) @@ -1680,7 +1733,7 @@ def test_server_hits_internal_error_reading_frame(self): server.receive_data(b"\x81\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, RuntimeError) self.assertEqual(str(server.parser_exc), "BOOM") - self.assertConnectionFailing(server, 1011, "") + self.assertConnectionFailing(server, CloseCode.INTERNAL_ERROR, "") class ExtensionsTests(ProtocolTestCase): From e3abb88ad6c8fc3e57f42aed4a00b9644c4e65df Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 18:59:55 +0200 Subject: [PATCH 1223/1539] Document that the Host header isn't validated. --- docs/reference/features.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 3cc52ec10..98b3c0dda 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -174,6 +174,11 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 +The server doesn't check the Host header and respond with a HTTP 400 Bad Request +if it is missing or invalid (`#1246`). + +.. _#1246: https://github.com/python-websockets/websockets/issues/1246 + The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is `mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right From 1bf9d1d766c80da4887240737266926f173fbcef Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 19:16:31 +0200 Subject: [PATCH 1224/1539] Accept int status in process_request and reject. Fix #1309. --- src/websockets/exceptions.py | 3 ++- src/websockets/server.py | 2 ++ tests/legacy/test_client_server.py | 14 ++++++++++++++ tests/test_server.py | 6 ++++++ 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 9d8476648..0f0686872 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -334,7 +334,8 @@ def __init__( headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: - self.status = status + # If a user passes an int instead of a HTTPStatus, fix it automatically. + self.status = http.HTTPStatus(status) self.headers = datastructures.Headers(headers) self.body = body diff --git a/src/websockets/server.py b/src/websockets/server.py index ecb0f74a6..b9646ea81 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -499,6 +499,8 @@ def reject( Response: WebSocket handshake response event to send to the client. """ + # If a user passes an int instead of a HTTPStatus, fix it automatically. + status = http.HTTPStatus(status) body = text.encode() headers = Headers( [ diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index cba24ad32..02dbe9e3f 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -190,6 +190,12 @@ async def process_request(self, path, request_headers): return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" +class ProcessRequestReturningIntProtocol(WebSocketServerProtocol): + async def process_request(self, path, request_headers): + if path == "/__health__/": + return 200, [], b"OK\n" + + class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): await asyncio.sleep(10 * MS) @@ -757,6 +763,14 @@ def test_http_request_custom_server_header(self): with contextlib.closing(response): self.assertEqual(response.headers["Server"], "websockets") + @with_server(create_protocol=ProcessRequestReturningIntProtocol) + def test_process_request_returns_int_status(self): + response = self.loop.run_until_complete(self.make_http_request("/__health__/")) + + with contextlib.closing(response): + self.assertEqual(response.code, 200) + self.assertEqual(response.read(), b"OK\n") + def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() diff --git a/tests/test_server.py b/tests/test_server.py index ecf3d4cbe..b6f5e3568 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -188,6 +188,12 @@ def test_reject_response(self): ) self.assertEqual(response.body, b"Sorry folks.\n") + def test_reject_response_supports_int_status(self): + server = ServerProtocol() + response = server.reject(404, "Sorry folks.\n") + self.assertEqual(response.status_code, 404) + self.assertEqual(response.reason_phrase, "Not Found") + def test_basic(self): server = ServerProtocol() request = self.make_request() From 7b9a53b74c09b27cf0899a28f38f6835c5d141ed Mon Sep 17 00:00:00 2001 From: Benjamin Loison <12752145+Benjamin-Loison@users.noreply.github.com> Date: Tue, 27 Jun 2023 18:17:43 +0200 Subject: [PATCH 1225/1539] Remove the `asyncio` import in `README.rst` for the `threading` example --- README.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/README.rst b/README.rst index f53d3d0fc..870b208ba 100644 --- a/README.rst +++ b/README.rst @@ -65,7 +65,6 @@ Here's how a client sends and receives messages with the ``threading`` API: #!/usr/bin/env python - import asyncio from websockets.sync.client import connect def hello(): From b870c46b2778444ae2d8960a0ae0bf4be53dc81c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Jul 2023 10:14:45 +0200 Subject: [PATCH 1226/1539] Fix inconsistency in 2b627b26. --- docs/reference/datastructures.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/reference/datastructures.rst b/docs/reference/datastructures.rst index 7c037da4c..ec02d4210 100644 --- a/docs/reference/datastructures.rst +++ b/docs/reference/datastructures.rst @@ -21,12 +21,12 @@ WebSocket events .. autoclass:: CloseCode - .. autoattribute:: OK + .. autoattribute:: NORMAL_CLOSURE .. autoattribute:: GOING_AWAY .. autoattribute:: PROTOCOL_ERROR .. autoattribute:: UNSUPPORTED_DATA .. autoattribute:: NO_STATUS_RCVD - .. autoattribute:: CONNECTION_CLOSED_ABNORMALLY + .. autoattribute:: ABNORMAL_CLOSURE .. autoattribute:: INVALID_DATA .. autoattribute:: POLICY_VIOLATION .. autoattribute:: MESSAGE_TOO_BIG @@ -35,7 +35,7 @@ WebSocket events .. autoattribute:: SERVICE_RESTART .. autoattribute:: TRY_AGAIN_LATER .. autoattribute:: BAD_GATEWAY - .. autoattribute:: TLS_FAILURE + .. autoattribute:: TLS_HANDSHAKE HTTP events ----------- From 8ed5424a5bdcbe99565d05239e7ea6e5a15a3cdd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Jul 2023 10:24:40 +0200 Subject: [PATCH 1227/1539] Clarify that the TLS example is simplistic. Fix #1381. --- docs/howto/quickstart.rst | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/howto/quickstart.rst b/docs/howto/quickstart.rst index da3c3999e..ab870952c 100644 --- a/docs/howto/quickstart.rst +++ b/docs/howto/quickstart.rst @@ -57,9 +57,6 @@ Here's how to adapt the server to encrypt connections. You must download :download:`localhost.pem <../../example/quickstart/localhost.pem>` and save it in the same directory as ``server_secure.py``. -See the documentation of the :mod:`ssl` module for details on configuring the -TLS context securely. - .. literalinclude:: ../../example/quickstart/server_secure.py :caption: server_secure.py :language: python @@ -79,6 +76,15 @@ When connecting to a secure WebSocket server with a valid certificate — any certificate signed by a CA that your Python installation trusts — you can simply pass ``ssl=True`` to :func:`~client.connect`. +.. admonition:: Configure the TLS context securely + :class: attention + + This example demonstrates the ``ssl`` argument with a TLS certificate shared + between the client and the server. This is a simplistic setup. + + Please review the advice and security considerations in the documentation of + the :mod:`ssl` module to configure the TLS context securely. + Connect from a browser ---------------------- From adfb8d69a7a1f6f4c8381c9e7182619d202c3cf2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Jul 2023 10:47:27 +0200 Subject: [PATCH 1228/1539] Fix test that fails on IPv6. --- tests/legacy/test_client_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 02dbe9e3f..056ff193f 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -397,9 +397,9 @@ def test_explicit_host_port(self): wsuri = parse_uri(uri) # Change host and port to invalid values. - changed_uri = uri.replace(wsuri.host, "example.com").replace( - str(wsuri.port), str(65535 - wsuri.port) - ) + scheme = "wss" if wsuri.secure else "ws" + port = 65535 - wsuri.port + changed_uri = f"{scheme}://example.com:{port}/" with self.temp_client(uri=changed_uri, host=wsuri.host, port=wsuri.port): self.loop.run_until_complete(self.client.send("Hello!")) From c41ce814bb173e2b174c60ce935b8598da626863 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 15:55:46 +0200 Subject: [PATCH 1229/1539] Rejecting a connection isn't always a failure. Fix #1402. --- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- tests/legacy/test_client_server.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 77e0fdab7..c9b32c417 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -227,7 +227,7 @@ async def handler(self) -> None: self.write_http_response(status, headers, body) self.logger.info( - "connection failed (%d %s)", status.value, status.phrase + "connection rejected (%d %s)", status.value, status.phrase ) await self.close_transport() return diff --git a/src/websockets/server.py b/src/websockets/server.py index b9646ea81..34069faf0 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -516,7 +516,7 @@ def reject( # "handshake_exc is None if and only if opening handshake succeeded." if self.handshake_exc is None: self.handshake_exc = InvalidStatus(response) - self.logger.info("connection failed (%d %s)", status.value, status.phrase) + self.logger.info("connection rejected (%d %s)", status.value, status.phrase) return response def send_response(self, response: Response) -> None: diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 056ff193f..f7e02d101 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1597,12 +1597,12 @@ async def run_client(): self.assertEqual( [record.getMessage() for record in logs.records][4:-1], [ - "connection failed (503 Service Unavailable)", + "connection rejected (503 Service Unavailable)", "connection closed", "! connect failed; reconnecting in 0.0 seconds", ] + [ - "connection failed (503 Service Unavailable)", + "connection rejected (503 Service Unavailable)", "connection closed", "! connect failed again; retrying in 0 seconds", ] From 1b10ca16dbb4e6c84c8d27cd42e8e35cfb85f5e8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 16:39:03 +0200 Subject: [PATCH 1230/1539] Add compatibility imports from legacy package in __all__. Supersedes #1400. --- src/websockets/auth.py | 2 ++ src/websockets/client.py | 4 +++- src/websockets/server.py | 4 +++- tests/test_exports.py | 10 +++------- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 5292e4f7f..b792e02f5 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,4 +1,6 @@ from __future__ import annotations # See #940 for why lazy_import isn't used here for backwards compatibility. +# See #1400 for why listing compatibility imports in __all__ helps PyCharm. from .legacy.auth import * +from .legacy.auth import __all__ # noqa: F401 diff --git a/src/websockets/client.py b/src/websockets/client.py index bf8427c37..b2f622042 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -38,10 +38,12 @@ # See #940 for why lazy_import isn't used here for backwards compatibility. +# See #1400 for why listing compatibility imports in __all__ helps PyCharm. from .legacy.client import * # isort:skip # noqa: I001 +from .legacy.client import __all__ as legacy__all__ -__all__ = ["ClientProtocol"] +__all__ = ["ClientProtocol"] + legacy__all__ class ClientProtocol(Protocol): diff --git a/src/websockets/server.py b/src/websockets/server.py index 34069faf0..872e4e5af 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -39,10 +39,12 @@ # See #940 for why lazy_import isn't used here for backwards compatibility. +# See #1400 for why listing compatibility imports in __all__ helps PyCharm. from .legacy.server import * # isort:skip # noqa: I001 +from .legacy.server import __all__ as legacy__all__ -__all__ = ["ServerProtocol"] +__all__ = ["ServerProtocol"] + legacy__all__ class ServerProtocol(Protocol): diff --git a/tests/test_exports.py b/tests/test_exports.py index 978b1d0e7..d63cb590c 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -1,24 +1,20 @@ import unittest import websockets +import websockets.auth import websockets.client import websockets.exceptions -import websockets.legacy.auth -import websockets.legacy.client import websockets.legacy.protocol -import websockets.legacy.server import websockets.server import websockets.typing import websockets.uri combined_exports = ( - websockets.legacy.auth.__all__ - + websockets.legacy.client.__all__ - + websockets.legacy.protocol.__all__ - + websockets.legacy.server.__all__ + websockets.auth.__all__ + websockets.client.__all__ + websockets.exceptions.__all__ + + websockets.legacy.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ + websockets.uri.__all__ From c9eefae63d2e456296e32da177749b392dd37b35 Mon Sep 17 00:00:00 2001 From: kxxt Date: Wed, 14 Jun 2023 15:26:56 +0800 Subject: [PATCH 1231/1539] test: allow WEBSOCKETS_TESTS_TIMEOUT_FACTOR to be float --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 4e9ac9f0e..2937a2f15 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,7 +23,7 @@ # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) +MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) # PyPy has a performance penalty for this test suite. if platform.python_implementation() == "PyPy": # pragma: no cover From ca5926ed1532fcd1dd264ef4f5a7a357fef66cfc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 16:59:33 +0200 Subject: [PATCH 1232/1539] Go back to 100% coverage. Broken in 1bf9d1d7. --- tests/legacy/test_client_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f7e02d101..c49d91b70 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -192,8 +192,8 @@ async def process_request(self, path, request_headers): class ProcessRequestReturningIntProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): - if path == "/__health__/": - return 200, [], b"OK\n" + assert path == "/__health__/" + return 200, [], b"OK\n" class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): From 439dafa656029f27c26956181d27a781ec04b105 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 17:15:37 +0200 Subject: [PATCH 1233/1539] Import eagerly when running under a type checker. Fix #1292 (and many others). Also fix some inconsistencies in the lists. The new rule is: - when re-exporting a module, re-export it entirely; - don't re-export deprecated modules. --- docs/faq/misc.rst | 23 ---- docs/project/changelog.rst | 5 + docs/reference/index.rst | 13 --- src/websockets/__init__.py | 229 ++++++++++++++++++++++++------------- src/websockets/http.py | 29 +++-- tests/test_exports.py | 3 +- 6 files changed, 175 insertions(+), 127 deletions(-) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 4fc271322..ee5ad2372 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -12,29 +12,6 @@ instead of the websockets library. .. _real-import-paths: -Why does my IDE fail to show documentation for websockets APIs? -............................................................... - -You are probably using the convenience imports e.g.:: - - import websockets - - websockets.connect(...) - websockets.serve(...) - -This is incompatible with static code analysis. It may break auto-completion and -contextual documentation in IDEs, type checking with mypy_, etc. - -.. _mypy: https://github.com/python/mypy - -Instead, use the real import paths e.g.:: - - import websockets.client - import websockets.server - - websockets.client.connect(...) - websockets.server.serve(...) - Why is the default implementation located in ``websockets.legacy``? ................................................................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 94fc5ebd9..ad9ab5908 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,11 @@ Backwards-incompatible changes Improvements ............ +* Made convenience imports from ``websockets`` compatible with static code + analysis tools such as auto-completion in an IDE or type checking with mypy_. + + .. _mypy: https://github.com/python/mypy + * Added :class:`~frames.CloseCode`. 11.0.3 diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2a9556dd9..0b80f087a 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -88,16 +88,3 @@ Convenience imports For convenience, many public APIs can be imported directly from the ``websockets`` package. - - -.. admonition:: Convenience imports are incompatible with some development tools. - :class: caution - - Specifically, static code analysis tools don't understand them. This breaks - auto-completion and contextual documentation in IDEs, type checking with - mypy_, etc. - - .. _mypy: https://github.com/python/mypy - - If you're using such tools, stick to the full import paths, as explained in - this FAQ: :ref:`real-import-paths` diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index dcf3d8150..2b9c3bc54 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,23 +1,24 @@ from __future__ import annotations +import typing + from .imports import lazy_import from .version import version as __version__ # noqa: F401 __all__ = [ - "AbortHandshake", - "basic_auth_protocol_factory", - "BasicAuthWebSocketServerProtocol", - "broadcast", + # .client "ClientProtocol", - "connect", + # .datastructures + "Headers", + "HeadersLike", + "MultipleValuesError", + # .exceptions + "AbortHandshake", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", - "Data", "DuplicateParameter", - "ExtensionName", - "ExtensionParameter", "InvalidHandshake", "InvalidHeader", "InvalidHeaderFormat", @@ -31,84 +32,156 @@ "InvalidStatusCode", "InvalidUpgrade", "InvalidURI", - "LoggerLike", "NegotiationError", - "Origin", - "parse_uri", "PayloadTooBig", "ProtocolError", "RedirectHandshake", "SecurityError", - "serve", - "ServerProtocol", - "Subprotocol", - "unix_connect", - "unix_serve", - "WebSocketClientProtocol", - "WebSocketCommonProtocol", "WebSocketException", "WebSocketProtocolError", + # .legacy.auth + "BasicAuthWebSocketServerProtocol", + "basic_auth_protocol_factory", + # .legacy.client + "WebSocketClientProtocol", + "connect", + "unix_connect", + # .legacy.protocol + "WebSocketCommonProtocol", + "broadcast", + # .legacy.server "WebSocketServer", "WebSocketServerProtocol", - "WebSocketURI", + "serve", + "unix_serve", + # .server + "ServerProtocol", + # .typing + "Data", + "ExtensionName", + "ExtensionParameter", + "LoggerLike", + "Origin", + "Subprotocol", ] -lazy_import( - globals(), - aliases={ - "auth": ".legacy", - "basic_auth_protocol_factory": ".legacy.auth", - "BasicAuthWebSocketServerProtocol": ".legacy.auth", - "broadcast": ".legacy.protocol", - "ClientProtocol": ".client", - "connect": ".legacy.client", - "unix_connect": ".legacy.client", - "WebSocketClientProtocol": ".legacy.client", - "Headers": ".datastructures", - "MultipleValuesError": ".datastructures", - "WebSocketException": ".exceptions", - "ConnectionClosed": ".exceptions", - "ConnectionClosedError": ".exceptions", - "ConnectionClosedOK": ".exceptions", - "InvalidHandshake": ".exceptions", - "SecurityError": ".exceptions", - "InvalidMessage": ".exceptions", - "InvalidHeader": ".exceptions", - "InvalidHeaderFormat": ".exceptions", - "InvalidHeaderValue": ".exceptions", - "InvalidOrigin": ".exceptions", - "InvalidUpgrade": ".exceptions", - "InvalidStatus": ".exceptions", - "InvalidStatusCode": ".exceptions", - "NegotiationError": ".exceptions", - "DuplicateParameter": ".exceptions", - "InvalidParameterName": ".exceptions", - "InvalidParameterValue": ".exceptions", - "AbortHandshake": ".exceptions", - "RedirectHandshake": ".exceptions", - "InvalidState": ".exceptions", - "InvalidURI": ".exceptions", - "PayloadTooBig": ".exceptions", - "ProtocolError": ".exceptions", - "WebSocketProtocolError": ".exceptions", - "protocol": ".legacy", - "WebSocketCommonProtocol": ".legacy.protocol", - "ServerProtocol": ".server", - "serve": ".legacy.server", - "unix_serve": ".legacy.server", - "WebSocketServerProtocol": ".legacy.server", - "WebSocketServer": ".legacy.server", - "Data": ".typing", - "LoggerLike": ".typing", - "Origin": ".typing", - "ExtensionHeader": ".typing", - "ExtensionParameter": ".typing", - "Subprotocol": ".typing", - }, - deprecated_aliases={ - "framing": ".legacy", - "handshake": ".legacy", - "parse_uri": ".uri", - "WebSocketURI": ".uri", - }, -) +# When type checking, import non-deprecated aliases eagerly. Else, import on demand. +if typing.TYPE_CHECKING: + from .client import ClientProtocol + from .datastructures import Headers, HeadersLike, MultipleValuesError + from .exceptions import ( + AbortHandshake, + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + DuplicateParameter, + InvalidHandshake, + InvalidHeader, + InvalidHeaderFormat, + InvalidHeaderValue, + InvalidMessage, + InvalidOrigin, + InvalidParameterName, + InvalidParameterValue, + InvalidState, + InvalidStatus, + InvalidStatusCode, + InvalidUpgrade, + InvalidURI, + NegotiationError, + PayloadTooBig, + ProtocolError, + RedirectHandshake, + SecurityError, + WebSocketException, + WebSocketProtocolError, + ) + from .legacy.auth import ( + BasicAuthWebSocketServerProtocol, + basic_auth_protocol_factory, + ) + from .legacy.client import WebSocketClientProtocol, connect, unix_connect + from .legacy.protocol import WebSocketCommonProtocol, broadcast + from .legacy.server import ( + WebSocketServer, + WebSocketServerProtocol, + serve, + unix_serve, + ) + from .server import ServerProtocol + from .typing import ( + Data, + ExtensionName, + ExtensionParameter, + LoggerLike, + Origin, + Subprotocol, + ) +else: + lazy_import( + globals(), + aliases={ + # .client + "ClientProtocol": ".client", + # .datastructures + "Headers": ".datastructures", + "HeadersLike": ".datastructures", + "MultipleValuesError": ".datastructures", + # .exceptions + "AbortHandshake": ".exceptions", + "ConnectionClosed": ".exceptions", + "ConnectionClosedError": ".exceptions", + "ConnectionClosedOK": ".exceptions", + "DuplicateParameter": ".exceptions", + "InvalidHandshake": ".exceptions", + "InvalidHeader": ".exceptions", + "InvalidHeaderFormat": ".exceptions", + "InvalidHeaderValue": ".exceptions", + "InvalidMessage": ".exceptions", + "InvalidOrigin": ".exceptions", + "InvalidParameterName": ".exceptions", + "InvalidParameterValue": ".exceptions", + "InvalidState": ".exceptions", + "InvalidStatus": ".exceptions", + "InvalidStatusCode": ".exceptions", + "InvalidUpgrade": ".exceptions", + "InvalidURI": ".exceptions", + "NegotiationError": ".exceptions", + "PayloadTooBig": ".exceptions", + "ProtocolError": ".exceptions", + "RedirectHandshake": ".exceptions", + "SecurityError": ".exceptions", + "WebSocketException": ".exceptions", + "WebSocketProtocolError": ".exceptions", + # .legacy.auth + "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "basic_auth_protocol_factory": ".legacy.auth", + # .legacy.client + "WebSocketClientProtocol": ".legacy.client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + # .legacy.protocol + "WebSocketCommonProtocol": ".legacy.protocol", + "broadcast": ".legacy.protocol", + # .legacy.server + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + # .server + "ServerProtocol": ".server", + # .typing + "Data": ".typing", + "ExtensionName": ".typing", + "ExtensionParameter": ".typing", + "LoggerLike": ".typing", + "Origin": ".typing", + "Subprotocol": ".typing", + }, + deprecated_aliases={ + "framing": ".legacy", + "handshake": ".legacy", + "parse_uri": ".uri", + "WebSocketURI": ".uri", + }, + ) diff --git a/src/websockets/http.py b/src/websockets/http.py index b14fa94bd..9f86f6a1f 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import typing from .imports import lazy_import from .version import version as websockets_version @@ -9,18 +10,22 @@ # For backwards compatibility: -lazy_import( - globals(), - # Headers and MultipleValuesError used to be defined in this module. - aliases={ - "Headers": ".datastructures", - "MultipleValuesError": ".datastructures", - }, - deprecated_aliases={ - "read_request": ".legacy.http", - "read_response": ".legacy.http", - }, -) +# When type checking, import non-deprecated aliases eagerly. Else, import on demand. +if typing.TYPE_CHECKING: + from .datastructures import Headers, MultipleValuesError # noqa: F401 +else: + lazy_import( + globals(), + # Headers and MultipleValuesError used to be defined in this module. + aliases={ + "Headers": ".datastructures", + "MultipleValuesError": ".datastructures", + }, + deprecated_aliases={ + "read_request": ".legacy.http", + "read_response": ".legacy.http", + }, + ) __all__ = ["USER_AGENT"] diff --git a/tests/test_exports.py b/tests/test_exports.py index d63cb590c..67a1a6f99 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -3,6 +3,7 @@ import websockets import websockets.auth import websockets.client +import websockets.datastructures import websockets.exceptions import websockets.legacy.protocol import websockets.server @@ -13,11 +14,11 @@ combined_exports = ( websockets.auth.__all__ + websockets.client.__all__ + + websockets.datastructures.__all__ + websockets.exceptions.__all__ + websockets.legacy.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ - + websockets.uri.__all__ ) From 678b4380a258abd2a074475f3a817823241a4362 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 17:43:00 +0200 Subject: [PATCH 1234/1539] Fix coverage. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 530052ddd..f24616dd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ exclude_lines = [ "except ImportError:", "if self.debug:", "if sys.platform != \"win32\":", + "if typing.TYPE_CHECKING:", "pragma: no cover", "raise AssertionError", "raise NotImplementedError", From be551424ecf17b92ce5cc24f5eb73f2796eb1daa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 7 Oct 2023 05:42:16 +0000 Subject: [PATCH 1235/1539] Bump actions/checkout from 3 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/tests.yml | 6 +++--- .github/workflows/wheels.yml | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 603426412..470f5bc96 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -34,7 +34,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -72,7 +72,7 @@ jobs: is_main: false steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python ${{ matrix.python }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 8aa5c0b7b..e846c54d9 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -45,7 +45,7 @@ jobs: - macOS-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v4 with: From f62d44ac1a0a606fe5753cf0017ce9e1dbf3640e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 7 Oct 2023 05:42:20 +0000 Subject: [PATCH 1236/1539] Bump docker/setup-qemu-action from 2 to 3 Bumps [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) from 2 to 3. - [Release notes](https://github.com/docker/setup-qemu-action/releases) - [Commits](https://github.com/docker/setup-qemu-action/compare/v2...v3) --- updated-dependencies: - dependency-name: docker/setup-qemu-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index e846c54d9..0e182cf78 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -52,7 +52,7 @@ jobs: python-version: 3.x - name: Set up QEMU if: runner.os == 'Linux' - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 with: platforms: all - name: Build wheels From ed6cb1b8d0b3a6121ecef06fcec8e3f4de417a5b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 15:26:22 +0200 Subject: [PATCH 1237/1539] Make it explicit that status codes can be int. The code and tests already support it but the types didn't reflect it. Refs #1406. --- docs/reference/types.rst | 2 ++ src/websockets/__init__.py | 3 +++ src/websockets/exceptions.py | 3 ++- src/websockets/legacy/server.py | 8 ++++---- src/websockets/server.py | 3 ++- src/websockets/typing.py | 7 +++++++ 6 files changed, 20 insertions(+), 6 deletions(-) diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 88550d08d..9d3aa8310 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -7,6 +7,8 @@ Types .. autodata:: LoggerLike + .. autodata:: StatusLike + .. autodata:: Origin .. autodata:: Subprotocol diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 2b9c3bc54..fdb028f4c 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -61,6 +61,7 @@ "ExtensionName", "ExtensionParameter", "LoggerLike", + "StatusLike", "Origin", "Subprotocol", ] @@ -115,6 +116,7 @@ ExtensionParameter, LoggerLike, Origin, + StatusLike, Subprotocol, ) else: @@ -176,6 +178,7 @@ "ExtensionParameter": ".typing", "LoggerLike": ".typing", "Origin": ".typing", + "StatusLike": "typing", "Subprotocol": ".typing", }, deprecated_aliases={ diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 0f0686872..f7169e3b1 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -34,6 +34,7 @@ from typing import Optional from . import datastructures, frames, http11 +from .typing import StatusLike __all__ = [ @@ -330,7 +331,7 @@ class AbortHandshake(InvalidHandshake): def __init__( self, - status: http.HTTPStatus, + status: StatusLike, headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c9b32c417..7c24dd74a 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -45,7 +45,7 @@ ) from ..http import USER_AGENT from ..protocol import State -from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol +from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout from .handshake import build_response, check_request from .http import read_request @@ -57,7 +57,7 @@ HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] -HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes] +HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] class WebSocketServerProtocol(WebSocketCommonProtocol): @@ -349,7 +349,7 @@ async def process_request( request_headers: request headers. Returns: - Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]: :obj:`None` + Optional[Tuple[StatusLike, HeadersLike, bytes]]: :obj:`None` to continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, @@ -943,7 +943,7 @@ class Serve: It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. process_request (Optional[Callable[[str, Headers], \ - Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): + Awaitable[Optional[Tuple[StatusLike, HeadersLike, bytes]]]]]): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. diff --git a/src/websockets/server.py b/src/websockets/server.py index 872e4e5af..191660553 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -32,6 +32,7 @@ ExtensionHeader, LoggerLike, Origin, + StatusLike, Subprotocol, UpgradeProtocol, ) @@ -480,7 +481,7 @@ def select_subprotocol(protocol, subprotocols): def reject( self, - status: http.HTTPStatus, + status: StatusLike, text: str, ) -> Response: """ diff --git a/src/websockets/typing.py b/src/websockets/typing.py index e672ba006..cc3e3ec0d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import http import logging from typing import List, NewType, Optional, Tuple, Union @@ -7,6 +8,7 @@ __all__ = [ "Data", "LoggerLike", + "StatusLike", "Origin", "Subprotocol", "ExtensionName", @@ -30,6 +32,11 @@ """Types accepted where a :class:`~logging.Logger` is expected.""" +StatusLike = Union[http.HTTPStatus, int] +""" +Types accepted where an :class:`~http.HTTPStatus` is expected.""" + + Origin = NewType("Origin", str) """Value of a ``Origin`` header.""" From 1077920fec5ebc274df16efc86bbb0a8056f3601 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 15:39:16 +0200 Subject: [PATCH 1238/1539] Add changelog for previous commit. --- docs/project/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index ad9ab5908..d87f4545c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -46,6 +46,8 @@ Improvements .. _mypy: https://github.com/python/mypy +* Accepted a plain :class:`int` where an :class:`~http.HTTPStatus` is expected. + * Added :class:`~frames.CloseCode`. 11.0.3 From 20413649b09aa98f136b1354a1f3ee801b663f36 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 21 Oct 2023 13:50:17 +0000 Subject: [PATCH 1239/1539] Bump pypa/cibuildwheel from 2.12.1 to 2.16.2 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.12.1 to 2.16.2. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.12.1...v2.16.2) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 0e182cf78..707ef2c60 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -56,7 +56,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.12.1 + uses: pypa/cibuildwheel@v2.16.2 env: BUILD_EXTENSION: yes - name: Save wheels From 01195322d2620a44039b716cb93c108c2ca9b6b9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 15:41:38 +0200 Subject: [PATCH 1240/1539] Release version 12.0 --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index d87f4545c..264e6e42d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ notice. 12.0 ---- -*In development* +*October 21, 2023* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 3f171b391..d1c99458e 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "12.0" From 310c29512955b37fffee685120108795a8436b6c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:30:14 +0200 Subject: [PATCH 1241/1539] Rename workflow for making a release. --- .github/workflows/{wheels.yml => release.yml} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename .github/workflows/{wheels.yml => release.yml} (97%) diff --git a/.github/workflows/wheels.yml b/.github/workflows/release.yml similarity index 97% rename from .github/workflows/wheels.yml rename to .github/workflows/release.yml index 707ef2c60..90f24b6f1 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/release.yml @@ -1,4 +1,4 @@ -name: Build wheels +name: Make release on: push: @@ -64,8 +64,8 @@ jobs: with: path: wheelhouse/*.whl - release: - name: Release + upload: + name: Upload needs: - sdist - wheels From 88e702ddaf214b46fcf6b3ceca25961f79ca9d00 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:39:20 +0200 Subject: [PATCH 1242/1539] Upgrade to Trusted Publishing. --- .github/workflows/release.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 90f24b6f1..8fad13529 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -72,6 +72,8 @@ jobs: runs-on: ubuntu-latest # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + permissions: + id-token: write steps: - name: Download artifacts uses: actions/download-artifact@v3 @@ -80,8 +82,6 @@ jobs: path: dist - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.PYPI_API_TOKEN }} - name: Create GitHub release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} From 5121bd15f988cc446db95b15a0bcac8dc64b68ab Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:54:27 +0200 Subject: [PATCH 1243/1539] Blind fix for automatic release creation. --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8fad13529..6e895e64e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -85,4 +85,4 @@ jobs: - name: Create GitHub release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: gh release create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." + run: gh release -R python-websockets/websockets create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." From 2431e09eebc75578e310627f0eab38cd81df2f6b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Nov 2023 07:55:23 +0100 Subject: [PATCH 1244/1539] Fix import style (likely autogenerated). --- src/websockets/sync/server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 14767968c..d12da0c65 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -11,10 +11,9 @@ from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type -from websockets.frames import CloseCode - from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..frames import CloseCode from ..headers import validate_subprotocols from ..http import USER_AGENT from ..http11 import Request, Response From ec3bd2ab06278602c1d6018b476699e090036373 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Nov 2023 08:22:33 +0100 Subject: [PATCH 1245/1539] Make sync reassembler more readable. No logic changes. --- src/websockets/sync/messages.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 67a22313c..d98ff855b 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -47,13 +47,13 @@ def __init__(self) -> None: # queue for transferring frames from the writing thread (library code) # to the reading thread (user code). We're buffering when chunks_queue # is None and streaming when it's a SimpleQueue. None is a sentinel - # value marking the end of the stream, superseding message_complete. + # value marking the end of the message, superseding message_complete. # Stream data from frames belonging to the same message. # Remove quotes around type when dropping Python < 3.9. self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None - # This flag marks the end of the stream. + # This flag marks the end of the connection. self.closed = False def get(self, timeout: Optional[float] = None) -> Data: @@ -108,12 +108,12 @@ def get(self, timeout: Optional[float] = None) -> Data: # mypy cannot figure out that chunks have the proper type. message: Data = joiner.join(self.chunks) # type: ignore - assert not self.message_fetched.is_set() - self.message_fetched.set() - self.chunks = [] assert self.chunks_queue is None + assert not self.message_fetched.is_set() + self.message_fetched.set() + return message def get_iter(self) -> Iterator[Data]: @@ -169,26 +169,26 @@ def get_iter(self) -> Iterator[Data]: with self.mutex: self.get_in_progress = False - assert self.message_complete.is_set() - self.message_complete.clear() - # get_iter() was unblocked by close() rather than put(). if self.closed: raise EOFError("stream of frames ended") - assert not self.message_fetched.is_set() - self.message_fetched.set() + assert self.message_complete.is_set() + self.message_complete.clear() assert self.chunks == [] self.chunks_queue = None + assert not self.message_fetched.is_set() + self.message_fetched.set() + def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, either by calling :meth:`get` or by fully - consuming the return value of :meth:`get_iter`. + the message is fetched, which can be achieved by calling :meth:`get` or + by fully consuming the return value of :meth:`get_iter`. :meth:`put` assumes that the stream of frames respects the protocol. If it doesn't, the behavior is undefined. @@ -247,13 +247,13 @@ def put(self, frame: Frame) -> None: with self.mutex: self.put_in_progress = False - assert self.message_fetched.is_set() - self.message_fetched.clear() - # put() was unblocked by close() rather than get() or get_iter(). if self.closed: raise EOFError("stream of frames ended") + assert self.message_fetched.is_set() + self.message_fetched.clear() + self.decoder = None def close(self) -> None: From 5737b474ad7d4a3a5e04d68299f4e5ec34bd62ac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Nov 2023 14:46:44 +0100 Subject: [PATCH 1246/1539] Start version 12.1. This commit should have been made right after releasing 12.0. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 264e6e42d..200ca7ef3 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +12.1 +---- + +*In development* + 12.0 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index d1c99458e..f1de3cbf4 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "12.0" +tag = version = commit = "12.1" if not released: # pragma: no cover From 94dd203f63bb52b1a30faa228e63ada2f0f2e874 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 9 Dec 2023 06:25:11 +0000 Subject: [PATCH 1247/1539] Bump actions/setup-python from 4 to 5 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 4 ++-- .github/workflows/tests.yml | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6e895e64e..7d56b9aa5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.x - name: Build sdist @@ -47,7 +47,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.x - name: Set up QEMU diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 470f5bc96..b128defb5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,7 +19,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install tox @@ -36,7 +36,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install tox @@ -74,7 +74,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python ${{ matrix.python }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: Install tox From fe1833fb103f4d63baee525c5b62dedd24b9884e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Nov 2023 14:48:59 +0100 Subject: [PATCH 1248/1539] Confirm support for Python 3.12. Fix #1417. --- .github/workflows/tests.yml | 1 + docs/project/changelog.rst | 5 +++++ pyproject.toml | 1 + tox.ini | 1 + 4 files changed, 8 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b128defb5..8161f1cbb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,6 +61,7 @@ jobs: - "3.9" - "3.10" - "3.11" + - "3.12" - "pypy-3.8" - "pypy-3.9" is_main: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 200ca7ef3..963353d0e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,11 @@ notice. *In development* +New features +............ + +* Validated compatibility with Python 3.12. + 12.0 ---- diff --git a/pyproject.toml b/pyproject.toml index f24616dd7..a7b4a6a9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] dynamic = ["version", "readme"] diff --git a/tox.ini b/tox.ini index 939d8c0cd..538b638d9 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ envlist = py39 py310 py311 + py312 coverage black ruff From beeb9387dedb574c8d1a6c2a6e7312c17788c858 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 16 Dec 2023 06:24:38 +0000 Subject: [PATCH 1249/1539] Bump actions/download-artifact from 3 to 4 Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 3 to 4. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7d56b9aa5..c1b750c80 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -76,7 +76,7 @@ jobs: id-token: write steps: - name: Download artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: artifact path: dist From 33b20e11e86f8490770185c78ed39adab8db4560 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 16 Dec 2023 06:24:43 +0000 Subject: [PATCH 1250/1539] Bump actions/upload-artifact from 3 to 4 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 3 to 4. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c1b750c80..4a00bf8fc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ jobs: - name: Build sdist run: python setup.py sdist - name: Save sdist - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: dist/*.tar.gz - name: Install wheel @@ -30,7 +30,7 @@ jobs: BUILD_EXTENSION: no run: python setup.py bdist_wheel - name: Save wheel - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: dist/*.whl @@ -60,7 +60,7 @@ jobs: env: BUILD_EXTENSION: yes - name: Save wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: wheelhouse/*.whl From b3c51958849c80209b4d68fca081ef3fffc5e2bd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:02:51 +0100 Subject: [PATCH 1251/1539] Make test_local/remote_address more robust. Fix #1427. --- tests/sync/test_connection.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 63544d4ad..e128425d8 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -656,13 +656,17 @@ def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - def test_local_address(self): - """Connection has a local_address attribute.""" - self.assertIsNotNone(self.connection.local_address) - - def test_remote_address(self): - """Connection has a remote_address attribute.""" - self.assertIsNotNone(self.connection.remote_address) + @unittest.mock.patch("socket.socket.getsockname", return_value=("sock", 1234)) + def test_local_address(self, getsockname): + """Connection provides a local_address attribute.""" + self.assertEqual(self.connection.local_address, ("sock", 1234)) + getsockname.assert_called_with() + + @unittest.mock.patch("socket.socket.getpeername", return_value=("peer", 1234)) + def test_remote_address(self, getpeername): + """Connection provides a remote_address attribute.""" + self.assertEqual(self.connection.remote_address, ("peer", 1234)) + getpeername.assert_called_with() def test_request(self): """Connection has a request attribute.""" From 9038a62e7261af21109977407907038a1a0efc65 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:38:53 +0100 Subject: [PATCH 1252/1539] Make mypy 1.8.0 happy. --- src/websockets/legacy/auth.py | 2 +- src/websockets/typing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index d3425836e..e8d6b75d5 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -159,7 +159,7 @@ def basic_auth_protocol_factory( if is_credentials(credentials): credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(credentials) + credentials_list = list(cast(Iterable[Credentials], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/typing.py b/src/websockets/typing.py index cc3e3ec0d..e073e650d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,7 +2,7 @@ import http import logging -from typing import List, NewType, Optional, Tuple, Union +from typing import Any, List, NewType, Optional, Tuple, Union __all__ = [ @@ -28,7 +28,7 @@ """ -LoggerLike = Union[logging.Logger, logging.LoggerAdapter] +LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" From 230d5052a33c0d940d926a1fc88909d39f57efd8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:41:54 +0100 Subject: [PATCH 1253/1539] Add tests for abstract classes. This prevents Python 3.12 to complain that no test cases were run and to exit with code 5 (which breaks maxi_cov). --- pyproject.toml | 1 - tests/extensions/test_base.py | 28 +++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a7b4a6a9e..c4c5412c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,6 @@ exclude_lines = [ "if typing.TYPE_CHECKING:", "pragma: no cover", "raise AssertionError", - "raise NotImplementedError", "self.fail\\(\".*\"\\)", "@unittest.skip", ] diff --git a/tests/extensions/test_base.py b/tests/extensions/test_base.py index b18ffb6fb..62250b07f 100644 --- a/tests/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,4 +1,30 @@ +import unittest + from websockets.extensions.base import * +from websockets.frames import Frame, Opcode + + +class ExtensionTests(unittest.TestCase): + def test_encode(self): + with self.assertRaises(NotImplementedError): + Extension().encode(Frame(Opcode.TEXT, b"")) + + def test_decode(self): + with self.assertRaises(NotImplementedError): + Extension().decode(Frame(Opcode.TEXT, b"")) + + +class ClientExtensionFactoryTests(unittest.TestCase): + def test_get_request_params(self): + with self.assertRaises(NotImplementedError): + ClientExtensionFactory().get_request_params() + + def test_process_response_params(self): + with self.assertRaises(NotImplementedError): + ClientExtensionFactory().process_response_params([], []) -# Abstract classes don't provide any behavior to test. +class ServerExtensionFactoryTests(unittest.TestCase): + def test_process_request_params(self): + with self.assertRaises(NotImplementedError): + ServerExtensionFactory().process_request_params([], []) From 5209b2a1cba00b28b8f62502157d5dbb98625a49 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 16:11:16 +0100 Subject: [PATCH 1254/1539] Remove empty test modules. This prevents Python 3.12 to complain that no test cases were run and to exit with code 5 (which breaks maxi_cov). --- tests/maxi_cov.py | 33 ++++++++++++++++++++++----------- tests/test_auth.py | 1 - tests/test_http.py | 7 +++++++ tests/test_typing.py | 1 - 4 files changed, 29 insertions(+), 13 deletions(-) delete mode 100644 tests/test_auth.py delete mode 100644 tests/test_typing.py diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 2568dcf18..bc4a44e8c 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -8,8 +8,15 @@ import sys -UNMAPPED_SRC_FILES = ["websockets/version.py"] -UNMAPPED_TEST_FILES = ["tests/test_exports.py"] +UNMAPPED_SRC_FILES = [ + "websockets/auth.py", + "websockets/typing.py", + "websockets/version.py", +] + +UNMAPPED_TEST_FILES = [ + "tests/test_exports.py", +] def check_environment(): @@ -60,7 +67,7 @@ def get_mapping(src_dir="src"): # Map source files to test files. mapping = {} - unmapped_test_files = [] + unmapped_test_files = set() for test_file in test_files: dir_name, file_name = os.path.split(test_file) @@ -73,26 +80,30 @@ def get_mapping(src_dir="src"): if src_file in src_files: mapping[src_file] = test_file else: - unmapped_test_files.append(test_file) + unmapped_test_files.add(test_file) - unmapped_src_files = list(set(src_files) - set(mapping)) + unmapped_src_files = set(src_files) - set(mapping) # Ensure that all files are mapped. - assert unmapped_src_files == UNMAPPED_SRC_FILES - assert unmapped_test_files == UNMAPPED_TEST_FILES + assert unmapped_src_files == set(UNMAPPED_SRC_FILES) + assert unmapped_test_files == set(UNMAPPED_TEST_FILES) return mapping def get_ignored_files(src_dir="src"): """Return the list of files to exclude from coverage measurement.""" - + # */websockets matches src/websockets and .tox/**/site-packages/websockets. return [ - # */websockets matches src/websockets and .tox/**/site-packages/websockets. - # There are no tests for the __main__ module and for compatibility modules. + # There are no tests for the __main__ module. "*/websockets/__main__.py", + # There is nothing to test on type declarations. + "*/websockets/typing.py", + # We don't test compatibility modules with previous versions of Python + # or websockets (import locations). "*/websockets/*/compatibility.py", + "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. "*/websockets/legacy/*", @@ -125,7 +136,7 @@ def run_coverage(mapping, src_dir="src"): "-m", "unittest", ] - + UNMAPPED_TEST_FILES, + + list(UNMAPPED_TEST_FILES), check=True, ) # Append coverage of each source module by the corresponding test module. diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 28db93155..000000000 --- a/tests/test_auth.py +++ /dev/null @@ -1 +0,0 @@ -from websockets.auth import * diff --git a/tests/test_http.py b/tests/test_http.py index 036bc1410..baaa7d416 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1 +1,8 @@ +import unittest + from websockets.http import * + + +class HTTPTests(unittest.TestCase): + def test_user_agent(self): + USER_AGENT # exists diff --git a/tests/test_typing.py b/tests/test_typing.py deleted file mode 100644 index 202de840f..000000000 --- a/tests/test_typing.py +++ /dev/null @@ -1 +0,0 @@ -from websockets.typing import * From 3c6b1aab96adde1a4b0d3e8f1a93b7f2c7310af0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 13 Jan 2024 20:46:54 +0100 Subject: [PATCH 1255/1539] Restore compatibility with Python < 3.11. Broken in 9038a62e. --- src/websockets/typing.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/websockets/typing.py b/src/websockets/typing.py index e073e650d..5dfecf66f 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,6 +2,7 @@ import http import logging +import typing from typing import Any, List, NewType, Optional, Tuple, Union @@ -28,8 +29,12 @@ """ -LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] -"""Types accepted where a :class:`~logging.Logger` is expected.""" +if typing.TYPE_CHECKING: + LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] + """Types accepted where a :class:`~logging.Logger` is expected.""" +else: # remove this branch when dropping support for Python < 3.11 + LoggerLike = Union[logging.Logger, logging.LoggerAdapter] + """Types accepted where a :class:`~logging.Logger` is expected.""" StatusLike = Union[http.HTTPStatus, int] From 7b522ec0df8f4e26abe09046a0ae7861714f5a2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 13 Jan 2024 21:54:08 +0100 Subject: [PATCH 1256/1539] Simplify code. It had to be written in that way with asyncio.wait_for but that isn't necessary anymore with asyncio.timeout. --- src/websockets/legacy/client.py | 55 ++++++++++++++++----------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 48622523e..b85d22867 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -640,38 +640,35 @@ async def __aexit__( def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. - return self.__await_impl_timeout__().__await__() - - async def __await_impl_timeout__(self) -> WebSocketClientProtocol: - async with asyncio_timeout(self.open_timeout): - return await self.__await_impl__() + return self.__await_impl__().__await__() async def __await_impl__(self) -> WebSocketClientProtocol: - for redirects in range(self.MAX_REDIRECTS_ALLOWED): - _transport, _protocol = await self._create_connection() - protocol = cast(WebSocketClientProtocol, _protocol) - try: - await protocol.handshake( - self._wsuri, - origin=protocol.origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - except RedirectHandshake as exc: - protocol.fail_connection() - await protocol.wait_closed() - self.handle_redirect(exc.uri) - # Avoid leaking a connected socket when the handshake fails. - except (Exception, asyncio.CancelledError): - protocol.fail_connection() - await protocol.wait_closed() - raise + async with asyncio_timeout(self.open_timeout): + for _redirects in range(self.MAX_REDIRECTS_ALLOWED): + _transport, _protocol = await self._create_connection() + protocol = cast(WebSocketClientProtocol, _protocol) + try: + await protocol.handshake( + self._wsuri, + origin=protocol.origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + except RedirectHandshake as exc: + protocol.fail_connection() + await protocol.wait_closed() + self.handle_redirect(exc.uri) + # Avoid leaking a connected socket when the handshake fails. + except (Exception, asyncio.CancelledError): + protocol.fail_connection() + await protocol.wait_closed() + raise + else: + self.protocol = protocol + return protocol else: - self.protocol = protocol - return protocol - else: - raise SecurityError("too many redirects") + raise SecurityError("too many redirects") # ... = yield from connect(...) - remove when dropping Python < 3.10 From 35bc7dd8288445289134c335aae8af859862ccd1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Jan 2024 21:20:06 +0100 Subject: [PATCH 1257/1539] Create futures with create_future. This is the preferred way to create Futures in asyncio. --- README.rst | 2 +- compliance/test_server.py | 2 +- docs/intro/tutorial1.rst | 2 +- docs/topics/broadcast.rst | 5 +++-- example/django/authentication.py | 2 +- example/echo.py | 2 +- example/faq/health_check_server.py | 2 +- example/legacy/basic_auth_server.py | 2 +- example/legacy/unix_server.py | 2 +- example/quickstart/counter.py | 2 +- example/quickstart/server.py | 2 +- example/quickstart/server_secure.py | 2 +- example/quickstart/show_time.py | 2 +- example/tutorial/step1/app.py | 2 +- example/tutorial/step2/app.py | 2 +- experiments/broadcast/server.py | 5 +++-- src/websockets/legacy/protocol.py | 4 ++-- src/websockets/legacy/server.py | 6 ++++-- 18 files changed, 26 insertions(+), 22 deletions(-) diff --git a/README.rst b/README.rst index 870b208ba..94cd79ab9 100644 --- a/README.rst +++ b/README.rst @@ -55,7 +55,7 @@ Here's an echo server with the ``asyncio`` API: async def main(): async with serve(echo, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/compliance/test_server.py b/compliance/test_server.py index 92f895d92..5701e4485 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -21,7 +21,7 @@ async def echo(ws): async def main(): with websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): try: - await asyncio.Future() + await asyncio.get_running_loop().create_future() # run forever except KeyboardInterrupt: pass diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index ff85003b5..6b32d47f6 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -195,7 +195,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 1acb372d4..b6ddda734 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -273,10 +273,11 @@ Here's a message stream that supports multiple consumers:: class PubSub: def __init__(self): - self.waiter = asyncio.Future() + self.waiter = asyncio.get_running_loop().create_future() def publish(self, value): - waiter, self.waiter = self.waiter, asyncio.Future() + waiter = self.waiter + self.waiter = asyncio.get_running_loop().create_future() waiter.set_result((value, self.waiter)) async def subscribe(self): diff --git a/example/django/authentication.py b/example/django/authentication.py index f6dad0f55..83e128f07 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -23,7 +23,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "localhost", 8888): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/echo.py b/example/echo.py index 2e47e52d9..d11b33527 100755 --- a/example/echo.py +++ b/example/echo.py @@ -9,6 +9,6 @@ async def echo(websocket): async def main(): async with serve(echo, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index 7b8bded77..6c7681e8a 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -17,6 +17,6 @@ async def main(): echo, "localhost", 8765, process_request=health_check, ): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/legacy/basic_auth_server.py b/example/legacy/basic_auth_server.py index d2efeb7e5..6f6020253 100755 --- a/example/legacy/basic_auth_server.py +++ b/example/legacy/basic_auth_server.py @@ -16,6 +16,6 @@ async def main(): realm="example", credentials=("mary", "p@ssw0rd") ), ): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/legacy/unix_server.py b/example/legacy/unix_server.py index 335039c35..5bfb66072 100755 --- a/example/legacy/unix_server.py +++ b/example/legacy/unix_server.py @@ -18,6 +18,6 @@ async def hello(websocket): async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") async with websockets.unix_serve(hello, socket_path): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index 566e12965..414919e04 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -43,7 +43,7 @@ async def counter(websocket): async def main(): async with websockets.serve(counter, "localhost", 6789): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server.py b/example/quickstart/server.py index 31b182972..64d7adeb6 100755 --- a/example/quickstart/server.py +++ b/example/quickstart/server.py @@ -14,7 +14,7 @@ async def hello(websocket): async def main(): async with websockets.serve(hello, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server_secure.py b/example/quickstart/server_secure.py index de41d30dc..11db5fb3a 100755 --- a/example/quickstart/server_secure.py +++ b/example/quickstart/server_secure.py @@ -20,7 +20,7 @@ async def hello(websocket): async def main(): async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index a83078e8a..add226869 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -13,7 +13,7 @@ async def show_time(websocket): async def main(): async with websockets.serve(show_time, "localhost", 5678): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index 3b0fbd786..6ec1c60b8 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -58,7 +58,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index 2693d4304..db3e36374 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -183,7 +183,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index 9c9907b7f..b0407ba34 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -27,10 +27,11 @@ async def relay(queue, websocket): class PubSub: def __init__(self): - self.waiter = asyncio.Future() + self.waiter = asyncio.get_running_loop().create_future() def publish(self, value): - waiter, self.waiter = self.waiter, asyncio.Future() + waiter = self.waiter + self.waiter = asyncio.get_running_loop().create_future() waiter.set_result((value, self.waiter)) async def subscribe(self): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 19cee0e65..47f948b7a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -664,7 +664,7 @@ async def send( return opcode, data = prepare_data(fragment) - self._fragmented_message_waiter = asyncio.Future() + self._fragmented_message_waiter = self.loop.create_future() try: # First fragment. await self.write_frame(False, opcode, data) @@ -709,7 +709,7 @@ async def send( return opcode, data = prepare_data(fragment) - self._fragmented_message_waiter = asyncio.Future() + self._fragmented_message_waiter = self.loop.create_future() try: # First fragment. await self.write_frame(False, opcode, data) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 7c24dd74a..d95bec4f6 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -897,7 +897,8 @@ class Serve: Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object provides a :meth:`~WebSocketServer.close` method to shut down the server:: - stop = asyncio.Future() # set this future to exit the server + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() server = await serve(...) await stop @@ -906,7 +907,8 @@ class Serve: :func:`serve` can be used as an asynchronous context manager. Then, the server is shut down automatically when exiting the context:: - stop = asyncio.Future() # set this future to exit the server + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() async with serve(...): await stop From cba4c242614734a722891992e8bc005bc848c0c1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 18:56:07 +0100 Subject: [PATCH 1258/1539] Avoid duplicating signature in API docs. --- docs/reference/sansio/client.rst | 2 +- docs/reference/sansio/common.rst | 2 +- docs/reference/sansio/server.rst | 2 +- docs/reference/sync/client.rst | 4 ++-- docs/reference/sync/server.rst | 4 ++-- src/websockets/sync/client.py | 18 ++++++++++-------- src/websockets/sync/server.py | 14 +++++++++----- 7 files changed, 26 insertions(+), 20 deletions(-) diff --git a/docs/reference/sansio/client.rst b/docs/reference/sansio/client.rst index 09bafc745..12f88b8ed 100644 --- a/docs/reference/sansio/client.rst +++ b/docs/reference/sansio/client.rst @@ -5,7 +5,7 @@ Client (`Sans-I/O`_) .. currentmodule:: websockets.client -.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ClientProtocol .. automethod:: receive_data diff --git a/docs/reference/sansio/common.rst b/docs/reference/sansio/common.rst index cd1ef3c63..7d5447ac9 100644 --- a/docs/reference/sansio/common.rst +++ b/docs/reference/sansio/common.rst @@ -7,7 +7,7 @@ Both sides (`Sans-I/O`_) .. automodule:: websockets.protocol -.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) +.. autoclass:: Protocol .. automethod:: receive_data diff --git a/docs/reference/sansio/server.rst b/docs/reference/sansio/server.rst index d70df6277..3152f174e 100644 --- a/docs/reference/sansio/server.rst +++ b/docs/reference/sansio/server.rst @@ -5,7 +5,7 @@ Server (`Sans-I/O`_) .. currentmodule:: websockets.server -.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ServerProtocol .. automethod:: receive_data diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index 6cccd6ec4..af1132412 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -6,9 +6,9 @@ Client (:mod:`threading`) Opening a connection -------------------- -.. autofunction:: connect(uri, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: connect -.. autofunction:: unix_connect(path, uri=None, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: unix_connect Using a connection ------------------ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 35c112046..7ed744df2 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -6,9 +6,9 @@ Server (:mod:`threading`) Creating a server ----------------- -.. autofunction:: serve(handler, host=None, port=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: serve -.. autofunction:: unix_serve(handler, path=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: unix_serve Running a server ---------------- diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 087ff5f56..78a9a3c86 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -126,12 +126,10 @@ def recv_events(self) -> None: def connect( uri: str, *, - # TCP/TLS — unix and path are only for unix_connect() + # TCP/TLS sock: Optional[socket.socket] = None, ssl_context: Optional[ssl.SSLContext] = None, server_hostname: Optional[str] = None, - unix: bool = False, - path: Optional[str] = None, # WebSocket origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, @@ -148,6 +146,7 @@ def connect( logger: Optional[LoggerLike] = None, # Escape hatch for advanced customization create_connection: Optional[Type[ClientConnection]] = None, + **kwargs: Any, ) -> ClientConnection: """ Connect to the WebSocket server at ``uri``. @@ -210,13 +209,15 @@ def connect( if not wsuri.secure and ssl_context is not None: raise TypeError("ssl_context argument is incompatible with a ws:// URI") + # Private APIs for unix_connect() + unix: bool = kwargs.pop("unix", False) + path: Optional[str] = kwargs.pop("path", None) + if unix: if path is None and sock is None: raise TypeError("missing path argument") elif path is not None and sock is not None: raise TypeError("path and sock arguments are incompatible") - else: - assert path is None # private argument, only set by unix_connect() if subprotocols is not None: validate_subprotocols(subprotocols) @@ -241,7 +242,7 @@ def connect( if unix: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(deadline.timeout()) - assert path is not None # validated above -- this is for mpypy + assert path is not None # mypy cannot figure this out sock.connect(path) else: sock = socket.create_connection( @@ -308,8 +309,9 @@ def unix_connect( """ Connect to a WebSocket server listening on a Unix socket. - This function is identical to :func:`connect`, except for the additional - ``path`` argument. It's only available on Unix. + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. It's mainly useful for debugging servers listening on Unix sockets. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index d12da0c65..7faab0a3d 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -266,11 +266,9 @@ def serve( host: Optional[str] = None, port: Optional[int] = None, *, - # TCP/TLS — unix and path are only for unix_serve() + # TCP/TLS sock: Optional[socket.socket] = None, ssl_context: Optional[ssl.SSLContext] = None, - unix: bool = False, - path: Optional[str] = None, # WebSocket origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -304,6 +302,7 @@ def serve( logger: Optional[LoggerLike] = None, # Escape hatch for advanced customization create_connection: Optional[Type[ServerConnection]] = None, + **kwargs: Any, ) -> WebSocketServer: """ Create a WebSocket server listening on ``host`` and ``port``. @@ -397,6 +396,10 @@ def handler(websocket): # Bind socket and listen + # Private APIs for unix_connect() + unix: bool = kwargs.pop("unix", False) + path: Optional[str] = kwargs.pop("path", None) + if sock is None: if unix: if path is None: @@ -515,8 +518,9 @@ def unix_serve( """ Create a WebSocket server listening on a Unix socket. - This function is identical to :func:`serve`, except the ``host`` and - ``port`` arguments are replaced by ``path``. It's only available on Unix. + This function accepts the same keyword arguments as :func:`serve`. + + It's only available on Unix. It's useful for deploying a server behind a reverse proxy such as nginx. From cd4bc7960658db6d51f60f528b3b53c718426591 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 19:08:49 +0100 Subject: [PATCH 1259/1539] Pass arguments to create_server/connection. --- src/websockets/sync/client.py | 8 ++++---- src/websockets/sync/server.py | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 78a9a3c86..79af0132f 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -195,6 +195,8 @@ def connect( the connection. Set it to a wrapper or a subclass to customize connection handling. + Any other keyword arguments are passed to :func:`~socket.create_connection`. + Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. OSError: If the TCP connection fails. @@ -245,10 +247,8 @@ def connect( assert path is not None # mypy cannot figure this out sock.connect(path) else: - sock = socket.create_connection( - (wsuri.host, wsuri.port), - deadline.timeout(), - ) + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection((wsuri.host, wsuri.port), **kwargs) sock.settimeout(None) # Disable Nagle algorithm diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 7faab0a3d..c19992849 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -379,6 +379,9 @@ def handler(websocket): create_connection: Factory for the :class:`ServerConnection` managing the connection. Set it to a wrapper or a subclass to customize connection handling. + + Any other keyword arguments are passed to :func:`~socket.create_server`. + """ # Process parameters @@ -404,9 +407,10 @@ def handler(websocket): if unix: if path is None: raise TypeError("missing path argument") - sock = socket.create_server(path, family=socket.AF_UNIX) + kwargs.setdefault("family", socket.AF_UNIX) + sock = socket.create_server(path, **kwargs) else: - sock = socket.create_server((host, port)) + sock = socket.create_server((host, port), **kwargs) else: if path is not None: raise TypeError("path and sock arguments are incompatible") From 03ecfa5611f0c87ea9cfa7497f78e0c85408060e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 19:14:28 +0100 Subject: [PATCH 1260/1539] Standardize on .encode(). We had a mix of .encode() and .encode("utf-8") -- which is the default. --- experiments/compression/benchmark.py | 2 +- src/websockets/frames.py | 6 +- src/websockets/sync/connection.py | 8 +- tests/extensions/test_permessage_deflate.py | 26 +++--- tests/legacy/test_framing.py | 6 +- tests/legacy/test_protocol.py | 90 ++++++++++----------- tests/test_frames.py | 4 +- 7 files changed, 69 insertions(+), 73 deletions(-) diff --git a/experiments/compression/benchmark.py b/experiments/compression/benchmark.py index c5b13c8fa..4fbdf6220 100644 --- a/experiments/compression/benchmark.py +++ b/experiments/compression/benchmark.py @@ -66,7 +66,7 @@ def _run(data): for _ in range(REPEAT): for item in data: if isinstance(item, str): - item = item.encode("utf-8") + item = item.encode() # Taken from PerMessageDeflate.encode item = encoder.compress(item) + encoder.flush(zlib.Z_SYNC_FLUSH) if item.endswith(b"\x00\x00\xff\xff"): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 6b1befb2e..63c35ed4d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -364,7 +364,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: """ if isinstance(data, str): - return OP_TEXT, data.encode("utf-8") + return OP_TEXT, data.encode() elif isinstance(data, BytesLike): return OP_BINARY, data else: @@ -387,7 +387,7 @@ def prepare_ctrl(data: Data) -> bytes: """ if isinstance(data, str): - return data.encode("utf-8") + return data.encode() elif isinstance(data, BytesLike): return bytes(data) else: @@ -456,7 +456,7 @@ def serialize(self) -> bytes: """ self.check() - return struct.pack("!H", self.code) + self.reason.encode("utf-8") + return struct.pack("!H", self.code) + self.reason.encode() def check(self) -> None: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 4a8879e37..62aa17ffd 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -287,7 +287,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_text(message.encode("utf-8")) + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): @@ -324,7 +324,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: ) self.send_in_progress = True self.protocol.send_text( - chunk.encode("utf-8"), + chunk.encode(), fin=False, ) elif isinstance(chunk, BytesLike): @@ -349,7 +349,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: with self.send_context(): assert self.send_in_progress self.protocol.send_continuation( - chunk.encode("utf-8"), + chunk.encode(), fin=False, ) elif isinstance(chunk, BytesLike) and not text: @@ -630,7 +630,7 @@ def send_context( socket:: with self.send_context(): - self.protocol.send_text(message.encode("utf-8")) + self.protocol.send_text(message.encode()) When the connection isn't open on entry, when the connection is expected to close on exit, or when an unexpected error happens, terminating the diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 0e698566f..ee09813c4 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -84,7 +84,7 @@ def test_no_encode_decode_close_frame(self): # Data frames are encoded and decoded. def test_encode_decode_text_frame(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame = self.extension.encode(frame) @@ -112,9 +112,9 @@ def test_encode_decode_binary_frame(self): self.assertEqual(dec_frame, frame) def test_encode_decode_fragmented_text_frame(self): - frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) - frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) - frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode(), fin=False) + frame2 = Frame(OP_CONT, " & ".encode(), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode()) enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) @@ -168,7 +168,7 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_no_decode_text_frame(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) @@ -180,9 +180,9 @@ def test_no_decode_binary_frame(self): self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_fragmented_text_frame(self): - frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) - frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) - frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode(), fin=False) + frame2 = Frame(OP_CONT, " & ".encode(), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode()) dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -203,7 +203,7 @@ def test_no_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_context_takeover(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -215,7 +215,7 @@ def test_remote_no_context_takeover(self): # No context takeover when decoding messages. self.extension = PerMessageDeflate(True, False, 15, 15) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -233,7 +233,7 @@ def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. self.extension = PerMessageDeflate(True, True, 15, 15) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -253,7 +253,7 @@ def test_compress_settings(self): # Configure an extension so that no compression actually occurs. extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame = extension.encode(frame) @@ -269,7 +269,7 @@ def test_compress_settings(self): # Frames aren't decoded beyond max_size. def test_decompress_max_size(self): - frame = Frame(OP_TEXT, ("a" * 20).encode("utf-8")) + frame = Frame(OP_TEXT, ("a" * 20).encode()) enc_frame = self.extension.encode(frame) diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index e1e4c891b..6f811bd5e 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -76,14 +76,12 @@ def test_binary_masked(self): ) def test_non_ascii_text(self): - self.round_trip( - b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode("utf-8")) - ) + self.round_trip(b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode())) def test_non_ascii_text_masked(self): self.round_trip( b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", - Frame(True, OP_TEXT, "café".encode("utf-8")), + Frame(True, OP_TEXT, "café".encode()), mask=True, ) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index f2eb0fea0..f3dcd9ac7 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -428,7 +428,7 @@ def test_close_reason_not_set(self): # Test the recv coroutine. def test_recv_text(self): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -458,7 +458,7 @@ def test_recv_on_closed_connection(self): self.loop.run_until_complete(self.protocol.recv()) def test_recv_protocol_error(self): - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CONT, "café".encode())) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") @@ -469,7 +469,7 @@ def test_recv_unicode_error(self): def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) + self.receive_frame(Frame(True, OP_TEXT, "café".encode() * 205)) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") @@ -481,7 +481,7 @@ def test_recv_binary_payload_too_big(self): def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) + self.receive_frame(Frame(True, OP_TEXT, "café".encode() * 205)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café" * 205) @@ -498,7 +498,7 @@ def test_recv_queue_empty(self): asyncio.wait_for(asyncio.shield(recv), timeout=MS) ) - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(recv) self.assertEqual(data, "café") @@ -507,7 +507,7 @@ def test_recv_queue_full(self): # Test internals because it's hard to verify buffers from the outside. self.assertEqual(list(self.protocol.messages), []) - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.run_loop_once() self.assertEqual(list(self.protocol.messages), ["café"]) @@ -535,7 +535,7 @@ def test_recv_queue_no_limit(self): self.protocol.max_queue = None for _ in range(100): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.run_loop_once() # Incoming message queue can contain at least 100 messages. @@ -562,7 +562,7 @@ def test_recv_canceled(self): self.loop.run_until_complete(recv) # The next frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -570,15 +570,13 @@ def test_recv_canceled_race_condition(self): recv = self.loop.create_task( asyncio.wait_for(self.protocol.recv(), timeout=0.000_001) ) - self.loop.call_soon( - self.receive_frame, Frame(True, OP_TEXT, "café".encode("utf-8")) - ) + self.loop.call_soon(self.receive_frame, Frame(True, OP_TEXT, "café".encode())) with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(recv) # The previous frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, "tea".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "tea".encode())) data = self.loop.run_until_complete(self.protocol.recv()) # If we're getting "tea" there, it means "café" was swallowed (ha, ha). self.assertEqual(data, "café") @@ -586,7 +584,7 @@ def test_recv_canceled_race_condition(self): def test_recv_when_transfer_data_cancelled(self): # Clog incoming queue. self.protocol.max_queue = 1 - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.run_loop_once() @@ -620,7 +618,7 @@ def test_recv_prevents_concurrent_calls(self): def test_send_text(self): self.loop.run_until_complete(self.protocol.send("café")) - self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) def test_send_binary(self): self.loop.run_until_complete(self.protocol.send(b"tea")) @@ -647,9 +645,9 @@ def test_send_type_error(self): def test_send_iterable_text(self): self.loop.run_until_complete(self.protocol.send(["ca", "fé"])) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), ) def test_send_iterable_binary(self): @@ -687,7 +685,7 @@ def test_send_iterable_mixed_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) self.assertFramesSent( - (False, OP_TEXT, "café".encode("utf-8")), + (False, OP_TEXT, "café".encode()), (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) @@ -710,18 +708,18 @@ async def run_concurrently(): self.loop.run_until_complete(run_concurrently()) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), (True, OP_BINARY, b"tea"), ) def test_send_async_iterable_text(self): self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"]))) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), ) def test_send_async_iterable_binary(self): @@ -761,7 +759,7 @@ def test_send_async_iterable_mixed_type_error(self): self.protocol.send(async_iterable(["café", b"tea"])) ) self.assertFramesSent( - (False, OP_TEXT, "café".encode("utf-8")), + (False, OP_TEXT, "café".encode()), (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) @@ -784,9 +782,9 @@ async def run_concurrently(): self.loop.run_until_complete(run_concurrently()) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), (True, OP_BINARY, b"tea"), ) @@ -829,7 +827,7 @@ def test_ping_default(self): def test_ping_text(self): self.loop.run_until_complete(self.protocol.ping("café")) - self.assertOneFrameSent(True, OP_PING, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_PING, "café".encode()) def test_ping_binary(self): self.loop.run_until_complete(self.protocol.ping(b"tea")) @@ -882,7 +880,7 @@ def test_pong_default(self): def test_pong_text(self): self.loop.run_until_complete(self.protocol.pong("café")) - self.assertOneFrameSent(True, OP_PONG, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_PONG, "café".encode()) def test_pong_binary(self): self.loop.run_until_complete(self.protocol.pong(b"tea")) @@ -1072,8 +1070,8 @@ def test_return_latency_on_pong(self): # Test the protocol's logic for rebuilding fragmented messages. def test_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) - self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) + self.receive_frame(Frame(True, OP_CONT, "fé".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -1086,8 +1084,8 @@ def test_fragmented_binary(self): def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) + self.receive_frame(Frame(False, OP_TEXT, "café".encode() * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode() * 105)) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") @@ -1100,8 +1098,8 @@ def test_fragmented_binary_payload_too_big(self): def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) + self.receive_frame(Frame(False, OP_TEXT, "café".encode() * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode() * 105)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café" * 205) @@ -1113,22 +1111,22 @@ def test_fragmented_binary_no_max_size(self): self.assertEqual(data, b"tea" * 342) def test_control_frame_within_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.receive_frame(Frame(True, OP_PING, b"")) - self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CONT, "fé".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") self.assertOneFrameSent(True, OP_PONG, b"") def test_unterminated_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_close_handshake_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.receive_frame(Frame(True, OP_CLOSE, b"")) self.process_invalid_frames() # The RFC may have overlooked this case: it says that control frames @@ -1138,7 +1136,7 @@ def test_close_handshake_in_fragmented_text(self): self.assertConnectionClosed(CloseCode.NO_STATUS_RCVD, "") def test_connection_close_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") @@ -1472,7 +1470,7 @@ def test_remote_close_during_send(self): def test_broadcast_text(self): broadcast([self.protocol], "café") - self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) def test_broadcast_binary(self): broadcast([self.protocol], b"tea") @@ -1489,8 +1487,8 @@ def test_broadcast_no_clients(self): def test_broadcast_two_clients(self): broadcast([self.protocol, self.protocol], "café") self.assertFramesSent( - (True, OP_TEXT, "café".encode("utf-8")), - (True, OP_TEXT, "café".encode("utf-8")), + (True, OP_TEXT, "café".encode()), + (True, OP_TEXT, "café".encode()), ) def test_broadcast_skips_closed_connection(self): @@ -1513,7 +1511,7 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() - self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + self.assertOneFrameSent(False, OP_TEXT, "ca".encode()) with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") @@ -1530,7 +1528,7 @@ def test_broadcast_reports_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() - self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + self.assertOneFrameSent(False, OP_TEXT, "ca".encode()) with self.assertRaises(ExceptionGroup) as raised: broadcast([self.protocol], "café", raise_exceptions=True) diff --git a/tests/test_frames.py b/tests/test_frames.py index e323b3b57..3e9f5d6f8 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -77,14 +77,14 @@ def test_binary_masked(self): def test_non_ascii_text_unmasked(self): self.assertFrameData( - Frame(OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode()), b"\x81\x05caf\xc3\xa9", mask=False, ) def test_non_ascii_text_masked(self): self.assertFrameData( - Frame(OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode()), b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", mask=True, ) From ebc9890d4f2a7b1675d50d4fea167b9107082e9a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 09:27:40 +0100 Subject: [PATCH 1261/1539] Remove redundant return types from docs. sphinx picks them from function signatures. This changes slightly the output e.g. ExtensionParameter becomes Tuple[str, str | None]. While this can be a bit less readable, it looks like an improvement because the information is available without needing to navigate to the definition of ExtensionParameter. --- src/websockets/client.py | 6 +++--- src/websockets/extensions/base.py | 13 ++++++------- src/websockets/legacy/auth.py | 2 +- src/websockets/legacy/handshake.py | 4 ++-- src/websockets/legacy/protocol.py | 11 +++++------ src/websockets/legacy/server.py | 11 +++++------ src/websockets/protocol.py | 6 +++--- src/websockets/server.py | 16 +++++++--------- src/websockets/uri.py | 2 +- 9 files changed, 33 insertions(+), 38 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index b2f622042..85bc81b47 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -101,7 +101,7 @@ def connect(self) -> Request: You can modify it before sending it, for example to add HTTP headers. Returns: - Request: WebSocket handshake request event to send to the server. + WebSocket handshake request event to send to the server. """ headers = Headers() @@ -213,7 +213,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: headers: WebSocket handshake response headers. Returns: - List[Extension]: List of accepted extensions. + List of accepted extensions. Raises: InvalidHandshake: to abort the handshake. @@ -271,7 +271,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: headers: WebSocket handshake response headers. Returns: - Optional[Subprotocol]: Subprotocol, if one was selected. + Subprotocol, if one was selected. """ subprotocol: Optional[Subprotocol] = None diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 6c481a46c..9eba6c9e7 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -32,7 +32,7 @@ def decode( max_size: maximum payload size in bytes. Returns: - Frame: Decoded frame. + Decoded frame. Raises: PayloadTooBig: if decoding the payload exceeds ``max_size``. @@ -48,7 +48,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: frame (Frame): outgoing frame. Returns: - Frame: Encoded frame. + Encoded frame. """ raise NotImplementedError @@ -68,7 +68,7 @@ def get_request_params(self) -> List[ExtensionParameter]: Build parameters to send to the server for this extension. Returns: - List[ExtensionParameter]: Parameters to send to the server. + Parameters to send to the server. """ raise NotImplementedError @@ -88,7 +88,7 @@ def process_response_params( accepted extensions. Returns: - Extension: An extension instance. + An extension instance. Raises: NegotiationError: if parameters aren't acceptable. @@ -121,9 +121,8 @@ def process_request_params( accepted extensions. Returns: - Tuple[List[ExtensionParameter], Extension]: To accept the offer, - parameters to send to the client for this extension and an - extension instance. + To accept the offer, parameters to send to the client for this + extension and an extension instance. Raises: NegotiationError: to reject the offer, if parameters received from diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e8d6b75d5..8217afedd 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -66,7 +66,7 @@ async def check_credentials(self, username: str, password: str) -> bool: password: HTTP Basic Auth password. Returns: - bool: :obj:`True` if the handshake should continue; + :obj:`True` if the handshake should continue; :obj:`False` if it should fail with an HTTP 401 error. """ diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index ad8faf040..5853c31db 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -24,7 +24,7 @@ def build_request(headers: Headers) -> str: headers: Handshake request headers. Returns: - str: ``key`` that must be passed to :func:`check_response`. + ``key`` that must be passed to :func:`check_response`. """ key = generate_key() @@ -48,7 +48,7 @@ def check_request(headers: Headers) -> str: headers: Handshake request headers. Returns: - str: ``key`` that must be passed to :func:`build_response`. + ``key`` that must be passed to :func:`build_response`. Raises: InvalidHandshake: If the handshake request is invalid. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 47f948b7a..a9fbd5a7a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -518,7 +518,7 @@ async def recv(self) -> Data: :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. Returns: - Data: A string (:class:`str`) for a Text_ frame. A bytestring + A string (:class:`str`) for a Text_ frame. A bytestring (:class:`bytes`) for a Binary_ frame. .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 @@ -805,7 +805,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: """ Send a Ping_. @@ -827,10 +827,9 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: containing four random bytes. Returns: - ~asyncio.Future[float]: A future that will be completed when the - corresponding pong is received. You can ignore it if you don't - intend to wait. The result of the future is the latency of the - connection in seconds. + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. :: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index d95bec4f6..297613591 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -349,8 +349,8 @@ async def process_request( request_headers: request headers. Returns: - Optional[Tuple[StatusLike, HeadersLike, bytes]]: :obj:`None` - to continue the WebSocket handshake normally. + Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to + continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, headers, and body, to abort the WebSocket handshake and return @@ -534,8 +534,7 @@ def select_subprotocol( server_subprotocols: list of subprotocols available on the server. Returns: - Optional[Subprotocol]: Selected subprotocol, if a common subprotocol - was found. + Selected subprotocol, if a common subprotocol was found. :obj:`None` to continue without a subprotocol. @@ -572,7 +571,7 @@ async def handshake( the handshake succeeds. Returns: - str: path of the URI of the request. + path of the URI of the request. Raises: InvalidHandshake: if the handshake fails. @@ -968,7 +967,7 @@ class Serve: outside of websockets. Returns: - WebSocketServer: WebSocket server. + WebSocket server. """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 765e6b9bb..342aba413 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -452,7 +452,7 @@ def events_received(self) -> List[Event]: Process resulting events, likely by passing them to the application. Returns: - List[Event]: Events read from the connection. + Events read from the connection. """ events, self.events = self.events, [] return events @@ -473,7 +473,7 @@ def data_to_send(self) -> List[bytes]: connection. Returns: - List[bytes]: Data to write to the connection. + Data to write to the connection. """ writes, self.writes = self.writes, [] @@ -490,7 +490,7 @@ def close_expected(self) -> bool: short timeout if the other side hasn't already closed it. Returns: - bool: Whether the TCP connection is expected to close soon. + Whether the TCP connection is expected to close soon. """ # We expect a TCP close if and only if we sent a close frame: diff --git a/src/websockets/server.py b/src/websockets/server.py index 191660553..58391d3cf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -213,7 +213,6 @@ def process_request( request: WebSocket handshake request received from the client. Returns: - Tuple[str, Optional[str], Optional[str]]: ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and ``Sec-WebSocket-Protocol`` headers for the handshake response. @@ -294,7 +293,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: headers: WebSocket handshake request headers. Returns: - Optional[Origin]: origin, if it is acceptable. + origin, if it is acceptable. Raises: InvalidHandshake: if the Origin header is invalid. @@ -344,8 +343,8 @@ def process_extensions( headers: WebSocket handshake request headers. Returns: - Tuple[Optional[str], List[Extension]]: ``Sec-WebSocket-Extensions`` - HTTP response header and list of accepted extensions. + ``Sec-WebSocket-Extensions`` HTTP response header and list of + accepted extensions. Raises: InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid. @@ -401,8 +400,8 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: headers: WebSocket handshake request headers. Returns: - Optional[Subprotocol]: Subprotocol, if one was selected; this is - also the value of the ``Sec-WebSocket-Protocol`` response header. + Subprotocol, if one was selected; this is also the value of the + ``Sec-WebSocket-Protocol`` response header. Raises: InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid. @@ -449,8 +448,7 @@ def select_subprotocol(protocol, subprotocols): subprotocols: list of subprotocols offered by the client. Returns: - Optional[Subprotocol]: Selected subprotocol, if a common subprotocol - was found. + Selected subprotocol, if a common subprotocol was found. :obj:`None` to continue without a subprotocol. @@ -499,7 +497,7 @@ def reject( text: HTTP response body; will be encoded to UTF-8. Returns: - Response: WebSocket handshake response event to send to the client. + WebSocket handshake response event to send to the client. """ # If a user passes an int instead of a HTTPStatus, fix it automatically. diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 385090f66..970020e26 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -66,7 +66,7 @@ def parse_uri(uri: str) -> WebSocketURI: uri: WebSocket URI. Returns: - WebSocketURI: Parsed WebSocket URI. + Parsed WebSocket URI. Raises: InvalidURI: if ``uri`` isn't a valid WebSocket URI. From c53fc3b7eed17c12c1b4db5d456b7921c2cde98f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 10:01:53 +0100 Subject: [PATCH 1262/1539] Start argument descriptions with uppercase letter. This change was automated with regex replaces: ^( Args:\n(?: .*\n(?: .*\n)*)*?(?: \w+(?: \(.*\))?): )([a-z]) $1\U$2 ^( Args:\n(?: .*\n(?: .*\n)*)*?(?: \w+(?: \(.*\))?): )([a-z]) $1\U$2 Also remove redundant type annotations. --- src/websockets/client.py | 12 ++++---- src/websockets/datastructures.py | 2 +- src/websockets/extensions/base.py | 18 +++++------- .../extensions/permessage_deflate.py | 22 +++++++-------- src/websockets/frames.py | 14 +++++----- src/websockets/headers.py | 6 ++-- src/websockets/http11.py | 12 ++++---- src/websockets/legacy/protocol.py | 11 +++----- src/websockets/legacy/server.py | 28 +++++++++---------- src/websockets/protocol.py | 6 ++-- src/websockets/server.py | 14 +++++----- src/websockets/streams.py | 8 +++--- src/websockets/sync/utils.py | 2 +- src/websockets/utils.py | 4 +-- 14 files changed, 76 insertions(+), 83 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 85bc81b47..028e7ce47 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -53,17 +53,17 @@ class ClientProtocol(Protocol): Args: wsuri: URI of the WebSocket server, parsed with :func:`~websockets.uri.parse_uri`. - origin: value of the ``Origin`` header. This is useful when connecting + origin: Value of the ``Origin`` header. This is useful when connecting to a server that validates the ``Origin`` header to defend against Cross-Site WebSocket Hijacking attacks. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + subprotocols: List of supported subprotocols, in order of decreasing preference. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; + logger: Logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index a0a648463..c2a5acfee 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -152,7 +152,7 @@ def get_all(self, key: str) -> List[str]: Return the (possibly empty) list of all values for a header. Args: - key: header name. + key: Header name. """ return self._dict.get(key.lower(), []) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 9eba6c9e7..cca3fe513 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -28,8 +28,8 @@ def decode( Decode an incoming frame. Args: - frame (Frame): incoming frame. - max_size: maximum payload size in bytes. + frame: Incoming frame. + max_size: Maximum payload size in bytes. Returns: Decoded frame. @@ -45,7 +45,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: Encode an outgoing frame. Args: - frame (Frame): outgoing frame. + frame: Outgoing frame. Returns: Encoded frame. @@ -82,10 +82,8 @@ def process_response_params( Process parameters received from the server. Args: - params (Sequence[ExtensionParameter]): parameters received from - the server for this extension. - accepted_extensions (Sequence[Extension]): list of previously - accepted extensions. + params: Parameters received from the server for this extension. + accepted_extensions: List of previously accepted extensions. Returns: An extension instance. @@ -115,10 +113,8 @@ def process_request_params( Process parameters received from the client. Args: - params (Sequence[ExtensionParameter]): parameters received from - the client for this extension. - accepted_extensions (Sequence[Extension]): list of previously - accepted extensions. + params: Parameters received from the client for this extension. + accepted_extensions: List of previously accepted extensions. Returns: To accept the offer, parameters to send to the client for this diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index b391837c6..edccac3ca 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -268,14 +268,14 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): value or to an integer value to include them with this value. Args: - server_no_context_takeover: prevent server from using context takeover. - client_no_context_takeover: prevent client from using context takeover. - server_max_window_bits: maximum size of the server's LZ77 sliding window + server_no_context_takeover: Prevent server from using context takeover. + client_no_context_takeover: Prevent client from using context takeover. + server_max_window_bits: Maximum size of the server's LZ77 sliding window in bits, between 8 and 15. - client_max_window_bits: maximum size of the client's LZ77 sliding window + client_max_window_bits: Maximum size of the client's LZ77 sliding window in bits, between 8 and 15, or :obj:`True` to indicate support without setting a limit. - compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. """ @@ -468,15 +468,15 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): value or to an integer value to include them with this value. Args: - server_no_context_takeover: prevent server from using context takeover. - client_no_context_takeover: prevent client from using context takeover. - server_max_window_bits: maximum size of the server's LZ77 sliding window + server_no_context_takeover: Prevent server from using context takeover. + client_no_context_takeover: Prevent client from using context takeover. + server_max_window_bits: Maximum size of the server's LZ77 sliding window in bits, between 8 and 15. - client_max_window_bits: maximum size of the client's LZ77 sliding window + client_max_window_bits: Maximum size of the client's LZ77 sliding window in bits, between 8 and 15. - compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. - require_client_max_window_bits: do not enable compression at all if + require_client_max_window_bits: Do not enable compression at all if client doesn't advertise support for ``client_max_window_bits``; the default behavior is to enable compression without enforcing ``client_max_window_bits``. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 63c35ed4d..e5e2af8b4 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -208,12 +208,12 @@ def parse( This is a generator-based coroutine. Args: - read_exact: generator-based coroutine that reads the requested + read_exact: Generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. - mask: whether the frame should be masked i.e. whether the read + mask: Whether the frame should be masked i.e. whether the read happens on the server side. - max_size: maximum payload size in bytes. - extensions: list of extensions, applied in reverse order. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. Raises: EOFError: if the connection is closed without a full WebSocket frame. @@ -280,9 +280,9 @@ def serialize( Serialize a WebSocket frame. Args: - mask: whether the frame should be masked i.e. whether the write + mask: Whether the frame should be masked i.e. whether the write happens on the client side. - extensions: list of extensions, applied in order. + extensions: List of extensions, applied in order. Raises: ProtocolError: if the frame contains incorrect values. @@ -432,7 +432,7 @@ def parse(cls, data: bytes) -> Close: Parse the payload of a close frame. Args: - data: payload of the close frame. + data: Payload of the close frame. Raises: ProtocolError: if data is ill-formed. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 9ae3035a5..8391ad26c 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -289,7 +289,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: Return a list of HTTP protocols. Args: - header: value of the ``Upgrade`` header. + header: Value of the ``Upgrade`` header. Raises: InvalidHeaderFormat: on invalid inputs. @@ -486,7 +486,7 @@ def build_www_authenticate_basic(realm: str) -> str: Build a ``WWW-Authenticate`` header for HTTP Basic Auth. Args: - realm: identifier of the protection space. + realm: Identifier of the protection space. """ # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 @@ -532,7 +532,7 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: Return a ``(username, password)`` tuple. Args: - header: value of the ``Authorization`` header. + header: Value of the ``Authorization`` header. Raises: InvalidHeaderFormat: on invalid inputs. diff --git a/src/websockets/http11.py b/src/websockets/http11.py index ec4e3b8b7..c0a96f878 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -93,7 +93,7 @@ def parse( body, it may be read from the data stream after :meth:`parse` returns. Args: - read_line: generator-based coroutine that reads a LF-terminated + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data Raises: @@ -193,11 +193,11 @@ def parse( characters. Other characters are represented with surrogate escapes. Args: - read_line: generator-based coroutine that reads a LF-terminated + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. - read_exact: generator-based coroutine that reads the requested + read_exact: Generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. - read_to_eof: generator-based coroutine that reads until the end + read_to_eof: Generator-based coroutine that reads until the end of the stream. Raises: @@ -295,7 +295,7 @@ def parse_headers( Non-ASCII characters are represented with surrogate escapes. Args: - read_line: generator-based coroutine that reads a LF-terminated line + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. Raises: @@ -346,7 +346,7 @@ def parse_line( CRLF is stripped from the return value. Args: - read_line: generator-based coroutine that reads a LF-terminated line + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. Raises: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index a9fbd5a7a..26d50a2cc 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -624,8 +624,7 @@ async def send( error or a network failure. Args: - message (Union[Data, Iterable[Data], AsyncIterable[Data]): message - to send. + message: Message to send. Raises: ConnectionClosed: When the connection is closed. @@ -822,9 +821,8 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: effect. Args: - data (Optional[Data]): payload of the ping; a string will be - encoded to UTF-8; or :obj:`None` to generate a payload - containing four random bytes. + data: Payload of the ping. A string will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. Returns: A future that will be completed when the corresponding pong is @@ -878,8 +876,7 @@ async def pong(self, data: Data = b"") -> None: wait, you should close the connection. Args: - data (Data): Payload of the pong. A string will be encoded to - UTF-8. + data: Payload of the pong. A string will be encoded to UTF-8. Raises: ConnectionClosed: When the connection is closed. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 297613591..4af7ed109 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -345,8 +345,8 @@ async def process_request( from shutting down. Args: - path: request path, including optional query string. - request_headers: request headers. + path: Request path, including optional query string. + request_headers: Request headers. Returns: Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to @@ -377,8 +377,8 @@ def process_origin( Handle the Origin HTTP request header. Args: - headers: request headers. - origins: optional list of acceptable origins. + headers: Request headers. + origins: Optional list of acceptable origins. Raises: InvalidOrigin: if the origin isn't acceptable. @@ -428,8 +428,8 @@ def process_extensions( order of extensions, may be implemented by overriding this method. Args: - headers: request headers. - extensions: optional list of supported extensions. + headers: Request headers. + extensions: Optional list of supported extensions. Raises: InvalidHandshake: to abort the handshake with an HTTP 400 error. @@ -488,8 +488,8 @@ def process_subprotocol( as the selected subprotocol. Args: - headers: request headers. - available_subprotocols: optional list of supported subprotocols. + headers: Request headers. + available_subprotocols: Optional list of supported subprotocols. Raises: InvalidHandshake: to abort the handshake with an HTTP 400 error. @@ -530,8 +530,8 @@ def select_subprotocol( subprotocol. Args: - client_subprotocols: list of subprotocols offered by the client. - server_subprotocols: list of subprotocols available on the server. + client_subprotocols: List of subprotocols offered by the client. + server_subprotocols: List of subprotocols available on the server. Returns: Selected subprotocol, if a common subprotocol was found. @@ -561,13 +561,13 @@ async def handshake( Perform the server side of the opening handshake. Args: - origins: list of acceptable values of the Origin HTTP header; + origins: List of acceptable values of the Origin HTTP header; include :obj:`None` if the lack of an origin is acceptable. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of + subprotocols: List of supported subprotocols, in order of decreasing preference. - extra_headers: arbitrary HTTP headers to add to the response when + extra_headers: Arbitrary HTTP headers to add to the response when the handshake succeeds. Returns: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 342aba413..99c9ee1a8 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -74,10 +74,10 @@ class Protocol: Args: side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; depending on ``side``, + logger: Logger for this connection; depending on ``side``, defaults to ``logging.getLogger("websockets.client")`` or ``logging.getLogger("websockets.server")``; see the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/server.py b/src/websockets/server.py index 58391d3cf..6711a0bba 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -53,22 +53,22 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: acceptable values of the ``Origin`` header; include + origins: Acceptable values of the ``Origin`` header; include :obj:`None` in the list if the lack of an origin is acceptable. This is useful for defending against Cross-Site WebSocket Hijacking attacks. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + subprotocols: List of supported subprotocols, in order of decreasing preference. select_subprotocol: Callback for selecting a subprotocol among those supported by the client and the server. It has the same signature as the :meth:`select_subprotocol` method, including a :class:`ServerProtocol` instance as first argument. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; + logger: Logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. @@ -445,7 +445,7 @@ def select_subprotocol(protocol, subprotocols): return "chat" Args: - subprotocols: list of subprotocols offered by the client. + subprotocols: List of subprotocols offered by the client. Returns: Selected subprotocol, if a common subprotocol was found. diff --git a/src/websockets/streams.py b/src/websockets/streams.py index f861d4bd2..d288cf0cc 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -26,7 +26,7 @@ def read_line(self, m: int) -> Generator[None, None, bytes]: The return value includes the LF character. Args: - m: maximum number bytes to read; this is a security limit. + m: Maximum number bytes to read; this is a security limit. Raises: EOFError: if the stream ends without a LF. @@ -58,7 +58,7 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: This is a generator-based coroutine. Args: - n: how many bytes to read. + n: How many bytes to read. Raises: EOFError: if the stream ends in less than ``n`` bytes. @@ -81,7 +81,7 @@ def read_to_eof(self, m: int) -> Generator[None, None, bytes]: This is a generator-based coroutine. Args: - m: maximum number bytes to read; this is a security limit. + m: Maximum number bytes to read; this is a security limit. Raises: RuntimeError: if the stream ends in more than ``m`` bytes. @@ -119,7 +119,7 @@ def feed_data(self, data: bytes) -> None: :meth:`feed_data` cannot be called after :meth:`feed_eof`. Args: - data: data to write. + data: Data to write. Raises: EOFError: if the stream has ended. diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 471f32e19..3364bdc2d 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -28,7 +28,7 @@ def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: Calculate a timeout from a deadline. Args: - raise_if_elapsed (bool): Whether to raise :exc:`TimeoutError` + raise_if_elapsed: Whether to raise :exc:`TimeoutError` if the deadline lapsed. Raises: diff --git a/src/websockets/utils.py b/src/websockets/utils.py index c40404906..62d2dc177 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -26,7 +26,7 @@ def accept_key(key: str) -> str: Compute the value of the Sec-WebSocket-Accept header. Args: - key: value of the Sec-WebSocket-Key header. + key: Value of the Sec-WebSocket-Key header. """ sha1 = hashlib.sha1((key + GUID).encode()).digest() @@ -38,7 +38,7 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: Apply masking to the data of a WebSocket message. Args: - data: data to mask. + data: Data to mask. mask: 4-bytes mask. """ From 2865bdcc8b93f78d019aa0c605c86535dd66d026 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 10:14:25 +0100 Subject: [PATCH 1263/1539] Start exception descriptions with uppercase letter. This change was automated with regex replaces: ^( Raises:\n(?: .*\n(?: .*\n)*)*?(?: \w+): )([a-z]) $1\U$2 ^( Raises:\n(?: .*\n(?: .*\n)*)*?(?: \w+): )([a-z]) $1\U$2 --- src/websockets/client.py | 4 ++-- src/websockets/extensions/base.py | 6 +++--- src/websockets/frames.py | 22 +++++++++++----------- src/websockets/headers.py | 30 +++++++++++++++--------------- src/websockets/http11.py | 24 ++++++++++++------------ src/websockets/legacy/server.py | 10 +++++----- src/websockets/protocol.py | 16 ++++++++-------- src/websockets/server.py | 12 ++++++------ src/websockets/streams.py | 12 ++++++------ src/websockets/uri.py | 2 +- 10 files changed, 69 insertions(+), 69 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 028e7ce47..633b1960b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -144,7 +144,7 @@ def process_response(self, response: Response) -> None: request: WebSocket handshake response received from the server. Raises: - InvalidHandshake: if the handshake response is invalid. + InvalidHandshake: If the handshake response is invalid. """ @@ -216,7 +216,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: List of accepted extensions. Raises: - InvalidHandshake: to abort the handshake. + InvalidHandshake: To abort the handshake. """ accepted_extensions: List[Extension] = [] diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index cca3fe513..7446c990c 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -35,7 +35,7 @@ def decode( Decoded frame. Raises: - PayloadTooBig: if decoding the payload exceeds ``max_size``. + PayloadTooBig: If decoding the payload exceeds ``max_size``. """ raise NotImplementedError @@ -89,7 +89,7 @@ def process_response_params( An extension instance. Raises: - NegotiationError: if parameters aren't acceptable. + NegotiationError: If parameters aren't acceptable. """ raise NotImplementedError @@ -121,7 +121,7 @@ def process_request_params( extension and an extension instance. Raises: - NegotiationError: to reject the offer, if parameters received from + NegotiationError: To reject the offer, if parameters received from the client aren't acceptable. """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index e5e2af8b4..201bc5068 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -216,10 +216,10 @@ def parse( extensions: List of extensions, applied in reverse order. Raises: - EOFError: if the connection is closed without a full WebSocket frame. - UnicodeDecodeError: if the frame contains invalid UTF-8. - PayloadTooBig: if the frame's payload size exceeds ``max_size``. - ProtocolError: if the frame contains incorrect values. + EOFError: If the connection is closed without a full WebSocket frame. + UnicodeDecodeError: If the frame contains invalid UTF-8. + PayloadTooBig: If the frame's payload size exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. """ # Read the header. @@ -285,7 +285,7 @@ def serialize( extensions: List of extensions, applied in order. Raises: - ProtocolError: if the frame contains incorrect values. + ProtocolError: If the frame contains incorrect values. """ self.check() @@ -334,7 +334,7 @@ def check(self) -> None: Check that reserved bits and opcode have acceptable values. Raises: - ProtocolError: if a reserved bit or the opcode is invalid. + ProtocolError: If a reserved bit or the opcode is invalid. """ if self.rsv1 or self.rsv2 or self.rsv3: @@ -360,7 +360,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: object. Raises: - TypeError: if ``data`` doesn't have a supported type. + TypeError: If ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -383,7 +383,7 @@ def prepare_ctrl(data: Data) -> bytes: If ``data`` is a bytes-like object, return a :class:`bytes` object. Raises: - TypeError: if ``data`` doesn't have a supported type. + TypeError: If ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -435,8 +435,8 @@ def parse(cls, data: bytes) -> Close: data: Payload of the close frame. Raises: - ProtocolError: if data is ill-formed. - UnicodeDecodeError: if the reason isn't valid UTF-8. + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. """ if len(data) >= 2: @@ -463,7 +463,7 @@ def check(self) -> None: Check that the close code has a valid value for a close frame. Raises: - ProtocolError: if the close code is invalid. + ProtocolError: If the close code is invalid. """ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 8391ad26c..463df3061 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -103,7 +103,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _token_re.match(header, pos) @@ -127,7 +127,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i Return the unquoted value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _quoted_string_re.match(header, pos) @@ -180,7 +180,7 @@ def parse_list( Return a list of items. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient @@ -234,7 +234,7 @@ def parse_connection_option( Return the protocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -251,7 +251,7 @@ def parse_connection(header: str) -> List[ConnectionOption]: header: value of the ``Connection`` header. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_connection_option, header, 0, "Connection") @@ -271,7 +271,7 @@ def parse_upgrade_protocol( Return the protocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _protocol_re.match(header, pos) @@ -292,7 +292,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: header: Value of the ``Upgrade`` header. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") @@ -307,7 +307,7 @@ def parse_extension_item_param( Return a ``(name, value)`` pair and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Extract parameter name. @@ -344,7 +344,7 @@ def parse_extension_item( list of ``(name, value)`` pairs, and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Extract extension name. @@ -379,7 +379,7 @@ def parse_extension(header: str) -> List[ExtensionHeader]: Parameter values are :obj:`None` when no value is provided. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") @@ -431,7 +431,7 @@ def parse_subprotocol_item( Return the subprotocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -445,7 +445,7 @@ def parse_subprotocol(header: str) -> List[Subprotocol]: Return a list of WebSocket subprotocols. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") @@ -505,7 +505,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _token68_re.match(header, pos) @@ -535,8 +535,8 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: header: Value of the ``Authorization`` header. Raises: - InvalidHeaderFormat: on invalid inputs. - InvalidHeaderValue: on unsupported inputs. + InvalidHeaderFormat: On invalid inputs. + InvalidHeaderValue: On unsupported inputs. """ # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 diff --git a/src/websockets/http11.py b/src/websockets/http11.py index c0a96f878..6fe775eec 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -97,9 +97,9 @@ def parse( line or raises an exception if there isn't enough data Raises: - EOFError: if the connection is closed without a full HTTP request. - SecurityError: if the request exceeds a security limit. - ValueError: if the request isn't well formatted. + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -201,10 +201,10 @@ def parse( of the stream. Raises: - EOFError: if the connection is closed without a full HTTP response. - SecurityError: if the response exceeds a security limit. - LookupError: if the response isn't well formatted. - ValueError: if the response isn't well formatted. + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + LookupError: If the response isn't well formatted. + ValueError: If the response isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 @@ -299,9 +299,9 @@ def parse_headers( or raises an exception if there isn't enough data. Raises: - EOFError: if the connection is closed without complete headers. - SecurityError: if the request exceeds a security limit. - ValueError: if the request isn't well formatted. + EOFError: If the connection is closed without complete headers. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 @@ -350,8 +350,8 @@ def parse_line( or raises an exception if there isn't enough data. Raises: - EOFError: if the connection is closed without a CRLF. - SecurityError: if the response exceeds a security limit. + EOFError: If the connection is closed without a CRLF. + SecurityError: If the response exceeds a security limit. """ try: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4af7ed109..e8cf8220f 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -271,7 +271,7 @@ async def read_http_request(self) -> Tuple[str, Headers]: after this coroutine returns. Raises: - InvalidMessage: if the HTTP message is malformed or isn't an + InvalidMessage: If the HTTP message is malformed or isn't an HTTP/1.1 GET request. """ @@ -381,7 +381,7 @@ def process_origin( origins: Optional list of acceptable origins. Raises: - InvalidOrigin: if the origin isn't acceptable. + InvalidOrigin: If the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -432,7 +432,7 @@ def process_extensions( extensions: Optional list of supported extensions. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: To abort the handshake with an HTTP 400 error. """ response_header_value: Optional[str] = None @@ -492,7 +492,7 @@ def process_subprotocol( available_subprotocols: Optional list of supported subprotocols. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: To abort the handshake with an HTTP 400 error. """ subprotocol: Optional[Subprotocol] = None @@ -574,7 +574,7 @@ async def handshake( path of the URI of the request. Raises: - InvalidHandshake: if the handshake fails. + InvalidHandshake: If the handshake fails. """ path, request_headers = await self.read_http_request() diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 99c9ee1a8..6851f3b1f 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -217,7 +217,7 @@ def close_exc(self) -> ConnectionClosed: known only once the connection is closed. Raises: - AssertionError: if the connection isn't closed yet. + AssertionError: If the connection isn't closed yet. """ assert self.state is CLOSED, "connection isn't closed yet" @@ -252,7 +252,7 @@ def receive_data(self, data: bytes) -> None: - You should call :meth:`events_received` and process resulting events. Raises: - EOFError: if :meth:`receive_eof` was called earlier. + EOFError: If :meth:`receive_eof` was called earlier. """ self.reader.feed_data(data) @@ -270,7 +270,7 @@ def receive_eof(self) -> None: any new events. Raises: - EOFError: if :meth:`receive_eof` was called earlier. + EOFError: If :meth:`receive_eof` was called earlier. """ self.reader.feed_eof() @@ -292,7 +292,7 @@ def send_continuation(self, data: bytes, fin: bool) -> None: of a fragmented message and to :obj:`False` otherwise. Raises: - ProtocolError: if a fragmented message isn't in progress. + ProtocolError: If a fragmented message isn't in progress. """ if not self.expect_continuation_frame: @@ -313,7 +313,7 @@ def send_text(self, data: bytes, fin: bool = True) -> None: a fragmented message. Raises: - ProtocolError: if a fragmented message is in progress. + ProtocolError: If a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -334,7 +334,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: a fragmented message. Raises: - ProtocolError: if a fragmented message is in progress. + ProtocolError: If a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -354,7 +354,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: reason: close reason. Raises: - ProtocolError: if a fragmented message is being sent, if the code + ProtocolError: If a fragmented message is being sent, if the code isn't valid, or if a reason is provided without a code """ @@ -412,7 +412,7 @@ def fail(self, code: int, reason: str = "") -> None: reason: close reason Raises: - ProtocolError: if the code isn't valid. + ProtocolError: If the code isn't valid. """ # 7.1.7. Fail the WebSocket Connection diff --git a/src/websockets/server.py b/src/websockets/server.py index 6711a0bba..330e54f37 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -217,7 +217,7 @@ def process_request( ``Sec-WebSocket-Protocol`` headers for the handshake response. Raises: - InvalidHandshake: if the handshake request is invalid; + InvalidHandshake: If the handshake request is invalid; then the server must return 400 Bad Request error. """ @@ -296,8 +296,8 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: origin, if it is acceptable. Raises: - InvalidHandshake: if the Origin header is invalid. - InvalidOrigin: if the origin isn't acceptable. + InvalidHandshake: If the Origin header is invalid. + InvalidOrigin: If the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -347,7 +347,7 @@ def process_extensions( accepted extensions. Raises: - InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid. + InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. """ response_header_value: Optional[str] = None @@ -404,7 +404,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: ``Sec-WebSocket-Protocol`` response header. Raises: - InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid. + InvalidHandshake: If the Sec-WebSocket-Subprotocol header is invalid. """ subprotocols: Sequence[Subprotocol] = sum( @@ -453,7 +453,7 @@ def select_subprotocol(protocol, subprotocols): :obj:`None` to continue without a subprotocol. Raises: - NegotiationError: custom implementations may raise this exception + NegotiationError: Custom implementations may raise this exception to abort the handshake with an HTTP 400 error. """ diff --git a/src/websockets/streams.py b/src/websockets/streams.py index d288cf0cc..956f139d4 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -29,8 +29,8 @@ def read_line(self, m: int) -> Generator[None, None, bytes]: m: Maximum number bytes to read; this is a security limit. Raises: - EOFError: if the stream ends without a LF. - RuntimeError: if the stream ends in more than ``m`` bytes. + EOFError: If the stream ends without a LF. + RuntimeError: If the stream ends in more than ``m`` bytes. """ n = 0 # number of bytes to read @@ -61,7 +61,7 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: n: How many bytes to read. Raises: - EOFError: if the stream ends in less than ``n`` bytes. + EOFError: If the stream ends in less than ``n`` bytes. """ assert n >= 0 @@ -84,7 +84,7 @@ def read_to_eof(self, m: int) -> Generator[None, None, bytes]: m: Maximum number bytes to read; this is a security limit. Raises: - RuntimeError: if the stream ends in more than ``m`` bytes. + RuntimeError: If the stream ends in more than ``m`` bytes. """ while not self.eof: @@ -122,7 +122,7 @@ def feed_data(self, data: bytes) -> None: data: Data to write. Raises: - EOFError: if the stream has ended. + EOFError: If the stream has ended. """ if self.eof: @@ -136,7 +136,7 @@ def feed_eof(self) -> None: :meth:`feed_eof` cannot be called more than once. Raises: - EOFError: if the stream has ended. + EOFError: If the stream has ended. """ if self.eof: diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 970020e26..8cf581743 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -69,7 +69,7 @@ def parse_uri(uri: str) -> WebSocketURI: Parsed WebSocket URI. Raises: - InvalidURI: if ``uri`` isn't a valid WebSocket URI. + InvalidURI: If ``uri`` isn't a valid WebSocket URI. """ parsed = urllib.parse.urlparse(uri) From 908c7ba23168da52d0006d67bc068e315e90daae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Jan 2024 10:02:56 +0100 Subject: [PATCH 1264/1539] Clean up sync message assembler. Remove support for control frames, which isn't actually used. --- src/websockets/sync/messages.py | 34 +++++----- tests/sync/test_messages.py | 113 ++++++++++++-------------------- 2 files changed, 60 insertions(+), 87 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index d98ff855b..dcba183d9 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -5,7 +5,7 @@ import threading from typing import Iterator, List, Optional, cast -from ..frames import Frame, Opcode +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -25,8 +25,11 @@ def __init__(self) -> None: # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to ensure proper interleaving of - # writing and reading messages. + # We create a latch with two events to synchronize the production of + # frames and the consumption of messages (or frames) without a buffer. + # This design requires a switch between the library thread and the user + # thread for each message; that shouldn't be a performance bottleneck. + # put() sets this event to tell get() that a message can be fetched. self.message_complete = threading.Event() # get() sets this event to let put() that the message was fetched. @@ -72,8 +75,10 @@ def get(self, timeout: Optional[float] = None) -> Data: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` + RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. + TimeoutError: If a timeout is provided and elapses before a + complete message is received. """ with self.mutex: @@ -131,7 +136,7 @@ def get_iter(self) -> Iterator[Data]: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` + RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. """ @@ -159,11 +164,10 @@ def get_iter(self) -> Iterator[Data]: self.get_in_progress = True # Locking with get_in_progress ensures only one thread can get here. - yield from chunks - while True: - chunk = self.chunks_queue.get() - if chunk is None: - break + chunk: Optional[Data] + for chunk in chunks: + yield chunk + while (chunk := self.chunks_queue.get()) is not None: yield chunk with self.mutex: @@ -205,15 +209,12 @@ def put(self, frame: Frame) -> None: if self.put_in_progress: raise RuntimeError("put is already running") - if frame.opcode is Opcode.TEXT: + if frame.opcode is OP_TEXT: self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is Opcode.BINARY: + elif frame.opcode is OP_BINARY: self.decoder = None - elif frame.opcode is Opcode.CONT: - pass else: - # Ignore control frames. - return + assert frame.opcode is OP_CONT data: Data if self.decoder is not None: @@ -242,6 +243,7 @@ def put(self, frame: Frame) -> None: self.put_in_progress = True # Release the lock to allow get() to run and eventually set the event. + # Locking with put_in_progress ensures only one coroutine can get here. self.message_fetched.wait() with self.mutex: diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 825eb8797..c134b8304 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,6 +1,6 @@ import time -from websockets.frames import OP_BINARY, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from websockets.sync.messages import * from ..utils import MS @@ -350,76 +350,6 @@ def test_get_with_timeout_times_out(self): with self.assertRaises(TimeoutError): self.assembler.get(MS) - # Test control frames - - def test_control_frame_before_message_is_ignored(self): - """get ignores control frames between messages.""" - - def putter(): - self.assembler.put(Frame(OP_PING, b"")) - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - - self.assertEqual(message, "café") - - def test_control_frame_in_fragmented_message_is_ignored(self): - """get ignores control frames within fragmented messages.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_PING, b"")) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_PONG, b"")) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - - self.assertEqual(message, b"tea") - - # Test concurrency - - def test_get_fails_when_get_is_running(self): - """get cannot be called concurrently with itself.""" - with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): - self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_fails_when_get_iter_is_running(self): - """get cannot be called concurrently with get_iter.""" - with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): - self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_iter_fails_when_get_is_running(self): - """get_iter cannot be called concurrently with get.""" - with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): - list(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_iter_fails_when_get_iter_is_running(self): - """get_iter cannot be called concurrently with itself.""" - with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): - list(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently with itself.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - with self.assertRaises(RuntimeError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread - # Test termination def test_get_fails_when_interrupted_by_close(self): @@ -477,3 +407,44 @@ def test_close_is_idempotent(self): """close can be called multiple times safely.""" self.assembler.close() self.assembler.close() + + # Test (non-)concurrency + + def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_put_fails_when_put_is_running(self): + """put cannot be called concurrently with itself.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + with self.assertRaises(RuntimeError): + self.assembler.put(Frame(OP_BINARY, b"tea")) + self.assembler.get() # unblock other thread From e21811e751f3f4fef18ad13b1b6f7064be004af6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 11:12:27 +0100 Subject: [PATCH 1265/1539] Rename ssl_context to ssl in sync implementation. --- docs/project/changelog.rst | 14 ++++++++++- src/websockets/sync/client.py | 27 ++++++++++++-------- src/websockets/sync/server.py | 21 ++++++++++------ tests/sync/client.py | 3 ++- tests/sync/test_client.py | 47 ++++++++++++++++++++--------------- tests/sync/test_server.py | 23 +++++++++++------ 6 files changed, 89 insertions(+), 46 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 963353d0e..e288831be 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,11 +25,23 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. -12.1 +13.0 ---- *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` + and :func:`~sync.server.serve` is renamed to ``ssl``. + :class: note + + This aligns the API of the :mod:`threading` implementation with the + :mod:`asyncio` implementation. + + For backwards compatibility, ``ssl_context`` is still supported. + New features ............ diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 79af0132f..6faca7789 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -1,8 +1,9 @@ from __future__ import annotations import socket -import ssl +import ssl as ssl_module import threading +import warnings from typing import Any, Optional, Sequence, Type from ..client import ClientProtocol @@ -128,7 +129,7 @@ def connect( *, # TCP/TLS sock: Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, + ssl: Optional[ssl_module.SSLContext] = None, server_hostname: Optional[str] = None, # WebSocket origin: Optional[Origin] = None, @@ -166,7 +167,7 @@ def connect( sock: Preexisting TCP socket. ``sock`` overrides the host and port from ``uri``. You may call :func:`socket.create_connection` to create a suitable TCP socket. - ssl_context: Configuration for enabling TLS on the connection. + ssl: Configuration for enabling TLS on the connection. server_hostname: Host name for the TLS handshake. ``server_hostname`` overrides the host name from ``uri``. origin: Value of the ``Origin`` header, for servers that require it. @@ -207,9 +208,14 @@ def connect( # Process parameters + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + wsuri = parse_uri(uri) - if not wsuri.secure and ssl_context is not None: - raise TypeError("ssl_context argument is incompatible with a ws:// URI") + if not wsuri.secure and ssl is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) @@ -259,12 +265,12 @@ def connect( # Initialize TLS wrapper and perform TLS handshake if wsuri.secure: - if ssl_context is None: - ssl_context = ssl.create_default_context() + if ssl is None: + ssl = ssl_module.create_default_context() if server_hostname is None: server_hostname = wsuri.host sock.settimeout(deadline.timeout()) - sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) # Initialize WebSocket connection @@ -318,12 +324,13 @@ def unix_connect( Args: path: File system path to the Unix socket. uri: URI of the WebSocket server. ``uri`` defaults to - ``ws://localhost/`` or, when a ``ssl_context`` is provided, to + ``ws://localhost/`` or, when a ``ssl`` is provided, to ``wss://localhost/``. """ if uri is None: - if kwargs.get("ssl_context") is None: + # Backwards compatibility: ssl used to be called ssl_context. + if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None: uri = "ws://localhost/" else: uri = "wss://localhost/" diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index c19992849..fa6087d54 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -5,9 +5,10 @@ import os import selectors import socket -import ssl +import ssl as ssl_module import sys import threading +import warnings from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type @@ -268,7 +269,7 @@ def serve( *, # TCP/TLS sock: Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, + ssl: Optional[ssl_module.SSLContext] = None, # WebSocket origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -337,7 +338,7 @@ def handler(websocket): sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``. You may call :func:`socket.create_server` to create a suitable TCP socket. - ssl_context: Configuration for enabling TLS on the connection. + ssl: Configuration for enabling TLS on the connection. origins: Acceptable values of the ``Origin`` header, for defending against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` in the list if the lack of an origin is acceptable. @@ -386,6 +387,11 @@ def handler(websocket): # Process parameters + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + if subprotocols is not None: validate_subprotocols(subprotocols) @@ -417,8 +423,8 @@ def handler(websocket): # Initialize TLS wrapper - if ssl_context is not None: - sock = ssl_context.wrap_socket( + if ssl is not None: + sock = ssl.wrap_socket( sock, server_side=True, # Delay TLS handshake until after we set a timeout on the socket. @@ -441,9 +447,10 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: # Perform TLS handshake - if ssl_context is not None: + if ssl is not None: sock.settimeout(deadline.timeout()) - assert isinstance(sock, ssl.SSLSocket) # mypy cannot figure this out + # mypy cannot figure this out + assert isinstance(sock, ssl_module.SSLSocket) sock.do_handshake() sock.settimeout(None) diff --git a/tests/sync/client.py b/tests/sync/client.py index 683893e88..bb4855c7f 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -25,7 +25,8 @@ def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): else: assert isinstance(wsuri_or_server, WebSocketServer) if secure is None: - secure = "ssl_context" in kwargs + # Backwards compatibility: ssl used to be called ssl_context. + secure = "ssl" in kwargs or "ssl_context" in kwargs protocol = "wss" if secure else "ws" host, port = wsuri_or_server.socket.getsockname() wsuri = f"{protocol}://{host}:{port}{resource_name}" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index c900f3b0f..fa363debf 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,7 +7,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..utils import MS, temp_unix_socket_path +from ..utils import MS, DeprecationTestCase, temp_unix_socket_path from .client import CLIENT_CONTEXT, run_client, run_unix_client from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server @@ -137,18 +137,18 @@ def close_connection(self, request): class SecureClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + with run_server(ssl=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -156,17 +156,17 @@ def test_set_server_hostname_implicitly(self): def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, server_hostname="overridden", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaisesRegex( ssl.SSLCertVerificationError, r"certificate verify failed: self[ -]signed certificate", @@ -177,15 +177,13 @@ def test_reject_invalid_server_certificate(self): def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaisesRegex( ssl.SSLCertVerificationError, r"certificate verify failed: Hostname mismatch", ): # This hostname isn't included in the test certificate. - with run_client( - server, ssl_context=CLIENT_CONTEXT, server_hostname="invalid" - ): + with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): self.fail("did not raise") @@ -212,8 +210,8 @@ class SecureUnixClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): - with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + with run_unix_server(path, ssl=SERVER_CONTEXT): + with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") @@ -221,23 +219,23 @@ def test_set_server_hostname(self): """Client sets server_hostname to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") class ClientUsageErrorsTests(unittest.TestCase): - def test_ssl_context_without_secure_uri(self): - """Client rejects ssl_context when URI isn't secure.""" + def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" with self.assertRaisesRegex( TypeError, - "ssl_context argument is incompatible with a ws:// URI", + "ssl argument is incompatible with a ws:// URI", ): - connect("ws://localhost/", ssl_context=CLIENT_CONTEXT) + connect("ws://localhost/", ssl=CLIENT_CONTEXT) def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" @@ -272,3 +270,12 @@ def test_unsupported_compression(self): "unsupported compression: False", ): connect("ws://localhost/", compression=False) + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_ssl_context_argument(self): + """Client supports the deprecated ssl_context argument.""" + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertDeprecationWarning("ssl_context was renamed to ssl"): + with run_client(server, ssl_context=CLIENT_CONTEXT): + pass diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f9db84246..5e7e79c52 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -14,7 +14,7 @@ from websockets.http11 import Request, Response from websockets.sync.server import * -from ..utils import MS, temp_unix_socket_path +from ..utils import MS, DeprecationTestCase, temp_unix_socket_path from .client import CLIENT_CONTEXT, run_client, run_unix_client from .server import ( SERVER_CONTEXT, @@ -274,20 +274,20 @@ def handler(sock, addr): class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + with run_server(ssl=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" - with run_server(ssl_context=SERVER_CONTEXT, open_timeout=MS) as server: + with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: with socket.create_connection(server.socket.getsockname()) as sock: self.assertEqual(sock.recv(4096), b"") def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: # Patch handler to record a reference to the thread running it. server_thread = None conn_received = threading.Event() @@ -325,8 +325,8 @@ class SecureUnixServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): - with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + with run_unix_server(path, ssl=SERVER_CONTEXT): + with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") @@ -386,3 +386,12 @@ def test_shutdown(self): # Check that the server socket is closed. with self.assertRaises(OSError): server.socket.accept() + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_ssl_context_argument(self): + """Client supports the deprecated ssl_context argument.""" + with self.assertDeprecationWarning("ssl_context was renamed to ssl"): + with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT): + pass From 45d8de7495ea33724bf93d753d65cad932472aac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 21:06:38 +0100 Subject: [PATCH 1266/1539] Standardize style for testing exceptions. --- tests/legacy/test_client_server.py | 19 +++-- tests/legacy/test_http.py | 76 +++++++++++++++----- tests/sync/test_client.py | 90 ++++++++++++----------- tests/sync/test_connection.py | 54 +++++++------- tests/sync/test_server.py | 110 ++++++++++++++++------------- tests/test_server.py | 73 +++++++++++++------ 6 files changed, 265 insertions(+), 157 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c49d91b70..4a21f7cea 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1331,20 +1331,24 @@ def test_checking_origin_succeeds(self): @with_server(origins=["http://localhost"]) def test_checking_origin_fails(self): - with self.assertRaisesRegex( - InvalidHandshake, "server rejected WebSocket connection: HTTP 403" - ): + with self.assertRaises(InvalidHandshake) as raised: self.start_client(origin="http://otherhost") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) @with_server(origins=["http://localhost"]) def test_checking_origins_fails_with_multiple_headers(self): - with self.assertRaisesRegex( - InvalidHandshake, "server rejected WebSocket connection: HTTP 400" - ): + with self.assertRaises(InvalidHandshake) as raised: self.start_client( origin="http://localhost", extra_headers=[("Origin", "http://otherhost")], ) + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) @with_server(origins=[None]) @with_client() @@ -1574,8 +1578,9 @@ async def run_client(): pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: - with self.assertRaisesRegex(Exception, "BOOM"): + with self.assertRaises(Exception) as raised: self.loop.run_until_complete(run_client()) + self.assertEqual(str(raised.exception), "BOOM") # Iteration 1 self.assertEqual( diff --git a/tests/legacy/test_http.py b/tests/legacy/test_http.py index 15d53e08d..76af61122 100644 --- a/tests/legacy/test_http.py +++ b/tests/legacy/test_http.py @@ -31,30 +31,48 @@ async def test_read_request(self): async def test_read_request_empty(self): self.stream.feed_eof() - with self.assertRaisesRegex( - EOFError, "connection closed while reading HTTP request line" - ): + with self.assertRaises(EOFError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "connection closed while reading HTTP request line", + ) async def test_read_request_invalid_request_line(self): self.stream.feed_data(b"GET /\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP request line: GET /"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP request line: GET /", + ) async def test_read_request_unsupported_method(self): self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP method: OPTIONS"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP method: OPTIONS", + ) async def test_read_request_unsupported_version(self): self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: HTTP/1.0", + ) async def test_read_request_invalid_header(self): self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP header line: Oops", + ) async def test_read_response(self): # Example from the protocol overview in RFC 6455 @@ -73,40 +91,66 @@ async def test_read_response(self): async def test_read_response_empty(self): self.stream.feed_eof() - with self.assertRaisesRegex( - EOFError, "connection closed while reading HTTP status line" - ): + with self.assertRaises(EOFError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "connection closed while reading HTTP status line", + ) async def test_read_request_invalid_status_line(self): self.stream.feed_data(b"Hello!\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP status line: Hello!"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP status line: Hello!", + ) async def test_read_response_unsupported_version(self): self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: HTTP/1.0", + ) async def test_read_response_invalid_status(self): self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP status code: OMG"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP status code: OMG", + ) async def test_read_response_unsupported_status(self): self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP status code: 007"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP status code: 007", + ) async def test_read_response_invalid_reason(self): self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP reason phrase: \\x7f"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP reason phrase: \x7f", + ) async def test_read_response_invalid_header(self): self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP header line: Oops", + ) async def test_header_name(self): self.stream.feed_data(b"foo bar: baz qux\r\n\r\n") diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index fa363debf..c403b9632 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -28,12 +28,13 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. with run_server(do_nothing, process_response=remove_accept_header) as server: - with self.assertRaisesRegex( - InvalidHandshake, - "missing Sec-WebSocket-Accept header", - ): + with self.assertRaises(InvalidHandshake) as raised: with run_client(server, close_timeout=MS): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) def test_tcp_connection_fails(self): """Client fails to connect to server.""" @@ -107,15 +108,16 @@ def stall_connection(self, request): # Use a connection handler that exits immediately to avoid an exception. with run_server(do_nothing, process_request=stall_connection) as server: try: - with self.assertRaisesRegex( - TimeoutError, - "timed out during handshake", - ): + with self.assertRaises(TimeoutError) as raised: # While it shouldn't take 50ms to open a connection, this # test becomes flaky in CI when setting a smaller timeout, # even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR. with run_client(server, open_timeout=5 * MS): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) finally: gate.set() @@ -126,12 +128,13 @@ def close_connection(self, request): self.close_socket() with run_server(process_request=close_connection) as server: - with self.assertRaisesRegex( - ConnectionError, - "connection closed during handshake", - ): + with self.assertRaises(ConnectionError) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connection closed during handshake", + ) class SecureClientTests(unittest.TestCase): @@ -167,24 +170,26 @@ def test_set_server_hostname_explicitly(self): def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaisesRegex( - ssl.SSLCertVerificationError, - r"certificate verify failed: self[ -]signed certificate", - ): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. with run_client(server, secure=True): self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaisesRegex( - ssl.SSLCertVerificationError, - r"certificate verify failed: Hostname mismatch", - ): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): self.fail("did not raise") + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception), + ) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -231,45 +236,50 @@ def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.TestCase): def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaisesRegex( - TypeError, - "ssl argument is incompatible with a ws:// URI", - ): + with self.assertRaises(TypeError) as raised: connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" - with self.assertRaisesRegex( - TypeError, - "missing path argument", - ): + with self.assertRaises(TypeError) as raised: unix_connect() + self.assertEqual( + str(raised.exception), + "missing path argument", + ) def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaisesRegex( - TypeError, - "path and sock arguments are incompatible", - ): + with self.assertRaises(TypeError) as raised: unix_connect(path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock arguments are incompatible", + ) def test_invalid_subprotocol(self): """Client rejects single value of subprotocols.""" - with self.assertRaisesRegex( - TypeError, - "subprotocols must be a list", - ): + with self.assertRaises(TypeError) as raised: connect("ws://localhost/", subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) def test_unsupported_compression(self): """Client rejects incorrect value of compression.""" - with self.assertRaisesRegex( - ValueError, - "unsupported compression: False", - ): + with self.assertRaises(ValueError) as raised: connect("ws://localhost/", compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) class BackwardsCompatibilityTests(DeprecationTestCase): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index e128425d8..953c8c253 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -177,12 +177,13 @@ def test_recv_during_recv(self): recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + self.connection.recv() + self.assertEqual( + str(raised.exception), "cannot call recv while another thread " "is already running recv or recv_streaming", - ): - self.connection.recv() + ) self.remote_connection.send("") recv_thread.join() @@ -194,12 +195,13 @@ def test_recv_during_recv_streaming(self): ) recv_streaming_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + self.connection.recv() + self.assertEqual( + str(raised.exception), "cannot call recv while another thread " "is already running recv or recv_streaming", - ): - self.connection.recv() + ) self.remote_connection.send("") recv_streaming_thread.join() @@ -257,12 +259,13 @@ def test_recv_streaming_during_recv(self): recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + list(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), "cannot call recv_streaming while another thread " "is already running recv or recv_streaming", - ): - list(self.connection.recv_streaming()) + ) self.remote_connection.send("") recv_thread.join() @@ -274,12 +277,13 @@ def test_recv_streaming_during_recv_streaming(self): ) recv_streaming_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + list(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), r"cannot call recv_streaming while another thread " r"is already running recv or recv_streaming", - ): - list(self.connection.recv_streaming()) + ) self.remote_connection.send("") recv_streaming_thread.join() @@ -355,11 +359,12 @@ def fragments(): [b"\x01\x02", b"\xfe\xff"], ]: with self.subTest(message=message): - with self.assertRaisesRegex( - RuntimeError, - "cannot call send while another thread is already running send", - ): + with self.assertRaises(RuntimeError) as raised: self.connection.send(message) + self.assertEqual( + str(raised.exception), + "cannot call send while another thread is already running send", + ) exit_gate.set() send_thread.join() @@ -598,11 +603,12 @@ def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.remote_connection.protocol_mutex: # block response to ping pong_waiter = self.connection.ping("idem") - with self.assertRaisesRegex( - RuntimeError, - "already waiting for a pong with the same data", - ): + with self.assertRaises(RuntimeError) as raised: self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) self.assertTrue(pong_waiter.wait(MS)) self.connection.ping("idem") # doesn't raise an exception diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 5e7e79c52..f9f30baf1 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -41,33 +41,36 @@ def remove_key_header(self, request): del request.headers["Sec-WebSocket-Key"] with run_server(process_request=remove_key_header) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 400", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) def test_connection_handler_returns(self): """Connection handler returns.""" with run_server(do_nothing) as server: with run_client(server) as client: - with self.assertRaisesRegex( - ConnectionClosedOK, - r"received 1000 \(OK\); then sent 1000 \(OK\)", - ): + with self.assertRaises(ConnectionClosedOK) as raised: client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" with run_server(crash) as server: with run_client(server) as client: - with self.assertRaisesRegex( - ConnectionClosedError, - r"received 1011 \(internal error\); " - r"then sent 1011 \(internal error\)", - ): + with self.assertRaises(ConnectionClosedError) as raised: client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); " + "then sent 1011 (internal error)", + ) def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" @@ -100,12 +103,13 @@ def select_subprotocol(ws, subprotocols): raise NegotiationError with run_server(select_subprotocol=select_subprotocol) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 400", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) def test_select_subprotocol_raises_exception(self): """Server returns an error if select_subprotocol raises an exception.""" @@ -114,12 +118,13 @@ def select_subprotocol(ws, subprotocols): raise RuntimeError with run_server(select_subprotocol=select_subprotocol) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_process_request(self): """Server runs process_request before processing the handshake.""" @@ -139,12 +144,13 @@ def process_request(ws, request): return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") with run_server(process_request=process_request) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 403", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" @@ -153,12 +159,13 @@ def process_request(ws, request): raise RuntimeError with run_server(process_request=process_request) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_process_response(self): """Server runs process_response after processing the handshake.""" @@ -193,12 +200,13 @@ def process_response(ws, request, response): raise RuntimeError with run_server(process_response=process_response) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_override_server(self): """Server can override Server header with server_header.""" @@ -334,37 +342,41 @@ def test_connection(self): class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" - with self.assertRaisesRegex( - TypeError, - "missing path argument", - ): + with self.assertRaises(TypeError) as raised: unix_serve(eval_shell) + self.assertEqual( + str(raised.exception), + "missing path argument", + ) def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaisesRegex( - TypeError, - "path and sock arguments are incompatible", - ): + with self.assertRaises(TypeError) as raised: unix_serve(eval_shell, path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock arguments are incompatible", + ) def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" - with self.assertRaisesRegex( - TypeError, - "subprotocols must be a list", - ): + with self.assertRaises(TypeError) as raised: serve(eval_shell, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" - with self.assertRaisesRegex( - ValueError, - "unsupported compression: False", - ): + with self.assertRaises(ValueError) as raised: serve(eval_shell, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) class WebSocketServerTests(unittest.TestCase): diff --git a/tests/test_server.py b/tests/test_server.py index b6f5e3568..e4460dcba 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -213,7 +213,10 @@ def test_unexpected_exception(self): self.assertEqual(response.status_code, 500) with self.assertRaises(Exception) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "BOOM") + self.assertEqual( + str(raised.exception), + "BOOM", + ) def test_missing_connection(self): server = ServerProtocol() @@ -225,7 +228,10 @@ def test_missing_connection(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Connection header") + self.assertEqual( + str(raised.exception), + "missing Connection header", + ) def test_invalid_connection(self): server = ServerProtocol() @@ -238,7 +244,10 @@ def test_invalid_connection(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "invalid Connection header: close") + self.assertEqual( + str(raised.exception), + "invalid Connection header: close", + ) def test_missing_upgrade(self): server = ServerProtocol() @@ -250,7 +259,10 @@ def test_missing_upgrade(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Upgrade header") + self.assertEqual( + str(raised.exception), + "missing Upgrade header", + ) def test_invalid_upgrade(self): server = ServerProtocol() @@ -263,7 +275,10 @@ def test_invalid_upgrade(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") + self.assertEqual( + str(raised.exception), + "invalid Upgrade header: h2c", + ) def test_missing_key(self): server = ServerProtocol() @@ -274,7 +289,10 @@ def test_missing_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Key header", + ) def test_multiple_key(self): server = ServerProtocol() @@ -302,7 +320,8 @@ def test_invalid_key(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" + str(raised.exception), + "invalid Sec-WebSocket-Key header: not Base64 data!", ) def test_truncated_key(self): @@ -318,7 +337,8 @@ def test_truncated_key(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" + str(raised.exception), + f"invalid Sec-WebSocket-Key header: {KEY[:16]}", ) def test_missing_version(self): @@ -330,7 +350,10 @@ def test_missing_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Version header", + ) def test_multiple_version(self): server = ServerProtocol() @@ -358,7 +381,8 @@ def test_invalid_version(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Sec-WebSocket-Version header: 11" + str(raised.exception), + "invalid Sec-WebSocket-Version header: 11", ) def test_no_origin(self): @@ -369,7 +393,10 @@ def test_no_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Origin header") + self.assertEqual( + str(raised.exception), + "missing Origin header", + ) def test_origin(self): server = ServerProtocol(origins=["https://example.com"]) @@ -390,7 +417,8 @@ def test_unexpected_origin(self): with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Origin header: https://other.example.com" + str(raised.exception), + "invalid Origin header: https://other.example.com", ) def test_multiple_origin(self): @@ -435,7 +463,8 @@ def test_unsupported_origin(self): with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Origin header: https://original.example.com" + str(raised.exception), + "invalid Origin header: https://original.example.com", ) def test_no_origin_accepted(self): @@ -574,11 +603,12 @@ def test_no_subprotocol(self): response = server.accept(request) self.assertEqual(response.status_code, 400) - with self.assertRaisesRegex( - NegotiationError, - r"missing subprotocol", - ): + with self.assertRaises(NegotiationError) as raised: raise server.handshake_exc + self.assertEqual( + str(raised.exception), + "missing subprotocol", + ) def test_subprotocol(self): server = ServerProtocol(subprotocols=["chat"]) @@ -628,11 +658,12 @@ def test_unsupported_subprotocol(self): response = server.accept(request) self.assertEqual(response.status_code, 400) - with self.assertRaisesRegex( - NegotiationError, - r"invalid subprotocol; expected one of superchat, chat", - ): + with self.assertRaises(NegotiationError) as raised: raise server.handshake_exc + self.assertEqual( + str(raised.exception), + "invalid subprotocol; expected one of superchat, chat", + ) @staticmethod def optional_chat(protocol, subprotocols): From c06e44d214ca3650b12fbbcaa1a0266dae9432d0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 15:20:39 +0100 Subject: [PATCH 1267/1539] Support closing while sending a fragmented message. On one hand, it will close the connection with an unfinished fragmented message, which is less than ideal. On the other hand, RFC 6455 implies that it should be legal and it's probably best to let users close the connection if they want to close the connection (rather than force them to call fail() instead). --- src/websockets/protocol.py | 6 ++---- tests/test_protocol.py | 36 ++++++++++++------------------------ 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 6851f3b1f..4650cf16d 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -354,12 +354,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: reason: close reason. Raises: - ProtocolError: If a fragmented message is being sent, if the code - isn't valid, or if a reason is provided without a code + ProtocolError: If the code isn't valid or if a reason is provided + without a code. """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a64172b53..a1661231f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1364,26 +1364,22 @@ def test_client_send_close_in_fragmented_message(self): client = Protocol(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - client.send_close(CloseCode.GOING_AWAY) - self.assertEqual(str(raised.exception), "expected a continuation frame") - client.send_continuation(b"Eggs", fin=True) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.send_close() + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, CLOSING) + with self.assertRaises(InvalidState): + client.send_continuation(b"Eggs", fin=True) def test_server_send_close_in_fragmented_message(self): - server = Protocol(CLIENT) + server = Protocol(SERVER) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - server.send_close(CloseCode.NORMAL_CLOSURE) - self.assertEqual(str(raised.exception), "expected a continuation frame") + server.send_close() + self.assertEqual(server.data_to_send(), [b"\x88\x00"]) + self.assertIs(server.state, CLOSING) + with self.assertRaises(InvalidState): + server.send_continuation(b"Eggs", fin=True) def test_client_receive_close_in_fragmented_message(self): client = Protocol(CLIENT) @@ -1392,10 +1388,6 @@ def test_client_receive_close_in_fragmented_message(self): client, Frame(OP_TEXT, b"Spam", fin=False), ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. client.receive_data(b"\x88\x02\x03\xe8") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incomplete fragmented message") @@ -1410,10 +1402,6 @@ def test_server_receive_close_in_fragmented_message(self): server, Frame(OP_TEXT, b"Spam", fin=False), ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incomplete fragmented message") From d28b71dd297da99aad9d644a2f4721707e464707 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 15:41:42 +0100 Subject: [PATCH 1268/1539] Upgrade to the latest version of black. --- src/websockets/datastructures.py | 6 ++---- tests/test_protocol.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index c2a5acfee..aef11bf23 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -172,11 +172,9 @@ class SupportsKeysAndGetItem(Protocol): # pragma: no cover """ - def keys(self) -> Iterable[str]: - ... + def keys(self) -> Iterable[str]: ... - def __getitem__(self, key: str) -> str: - ... + def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a1661231f..b53c8a1ec 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -37,12 +37,14 @@ def assertFrameSent(self, connection, frame, eof=False): """ frames_sent = [ - None - if write is SEND_EOF - else self.parse( - write, - mask=connection.side is CLIENT, - extensions=connection.extensions, + ( + None + if write is SEND_EOF + else self.parse( + write, + mask=connection.side is CLIENT, + extensions=connection.extensions, + ) ) for write in connection.data_to_send() ] From 705dc85e87bb1184d926ab95a591097780c4b855 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 14:55:44 +0100 Subject: [PATCH 1269/1539] Allow sending ping and pong after close. Fix #1429. --- src/websockets/protocol.py | 28 ++++++++-- tests/test_protocol.py | 111 +++++++++++++++++++++++++++---------- 2 files changed, 105 insertions(+), 34 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 4650cf16d..0b36202e5 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -297,6 +297,8 @@ def send_continuation(self, data: bytes, fin: bool) -> None: """ if not self.expect_continuation_frame: raise ProtocolError("unexpected continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_CONT, data, fin)) @@ -318,6 +320,8 @@ def send_text(self, data: bytes, fin: bool = True) -> None: """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_TEXT, data, fin)) @@ -339,6 +343,8 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) @@ -358,6 +364,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: without a code. """ + # While RFC 6455 doesn't rule out sending more than one close Frame, + # websockets is conservative in what it sends and doesn't allow that. + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") @@ -383,6 +393,9 @@ def send_ping(self, data: bytes) -> None: data: payload containing arbitrary binary data. """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: @@ -396,6 +409,9 @@ def send_pong(self, data: bytes) -> None: data: payload containing arbitrary binary data. """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PONG, data)) def fail(self, code: int, reason: str = "") -> None: @@ -675,6 +691,8 @@ def recv_frame(self, frame: Frame) -> None: # 1.4. Closing Handshake: "after receiving a control frame # indicating the connection should be closed, a peer discards # any further data received." + # RFC 6455 allows reading Ping and Pong frames after a Close frame. + # However, that doesn't seem useful; websockets doesn't support it. self.parser = self.discard() next(self.parser) # start coroutine @@ -687,15 +705,13 @@ def recv_frame(self, frame: Frame) -> None: # Private methods for sending events. def send_frame(self, frame: Frame) -> None: - if self.state is not OPEN: - raise InvalidState( - f"cannot write to a WebSocket in the {self.state.name} state" - ) - if self.debug: self.logger.debug("> %s", frame) self.writes.append( - frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) + frame.serialize( + mask=self.side is CLIENT, + extensions=self.extensions, + ) ) def send_eof(self) -> None: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index b53c8a1ec..1d5dab7a0 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -465,15 +465,17 @@ def test_client_sends_text_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_text(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_text_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_text(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_text_after_receiving_close(self): client = Protocol(CLIENT) @@ -679,15 +681,17 @@ def test_client_sends_binary_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_binary(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_binary_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_binary(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_binary_after_receiving_close(self): client = Protocol(CLIENT) @@ -956,6 +960,37 @@ def test_server_receives_close_with_non_utf8_reason(self): ) self.assertIs(server.state, CLOSING) + def test_client_sends_close_twice(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState) as raised: + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(str(raised.exception), "connection is closing") + + def test_server_sends_close_twice(self): + server = Protocol(SERVER) + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(InvalidState) as raised: + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(str(raised.exception), "connection is closing") + + def test_client_sends_close_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_close_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(str(raised.exception), "connection is closed") + class PingTests(ProtocolTestCase): """ @@ -1072,35 +1107,23 @@ def test_client_sends_ping_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: + with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) + self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: - server.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) + server.send_ping(b"") + self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1109,9 +1132,24 @@ def test_server_receives_ping_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) + def test_client_sends_ping_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_ping(b"") + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_ping_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_ping(b"") + self.assertEqual(str(raised.exception), "connection is closed") + class PongTests(ProtocolTestCase): """ @@ -1212,23 +1250,23 @@ def test_client_sends_pong_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): + with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"") + self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): - server.send_pong(b"") + server.send_pong(b"") + self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1237,9 +1275,24 @@ def test_server_receives_pong_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) + def test_client_sends_pong_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_pong(b"") + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_pong_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_pong(b"") + self.assertEqual(str(raised.exception), "connection is closed") + class FailTests(ProtocolTestCase): """ @@ -1370,8 +1423,9 @@ def test_client_send_close_in_fragmented_message(self): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_continuation(b"Eggs", fin=True) + self.assertEqual(str(raised.exception), "connection is closing") def test_server_send_close_in_fragmented_message(self): server = Protocol(SERVER) @@ -1380,8 +1434,9 @@ def test_server_send_close_in_fragmented_message(self): server.send_close() self.assertEqual(server.data_to_send(), [b"\x88\x00"]) self.assertIs(server.state, CLOSING) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_continuation(b"Eggs", fin=True) + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receive_close_in_fragmented_message(self): client = Protocol(CLIENT) From 96fddaf49b5a5af1f3215076bf2a73dfb4b72ca1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Jan 2024 16:47:51 +0100 Subject: [PATCH 1270/1539] Wording and line wrapping fixes in changelog. --- docs/project/changelog.rst | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index e288831be..dc84a5ae2 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,7 +34,8 @@ Backwards-incompatible changes .............................. .. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` - and :func:`~sync.server.serve` is renamed to ``ssl``. + and :func:`~sync.server.serve` in the :mod:`threading` implementation is + renamed to ``ssl``. :class: note This aligns the API of the :mod:`threading` implementation with the @@ -140,7 +141,8 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. -.. admonition:: :func:`~server.serve` times out on the opening handshake after 10 seconds by default. +.. admonition:: :func:`~server.serve` times out on the opening handshake after + 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -149,7 +151,7 @@ Backwards-incompatible changes New features ............ -.. admonition:: websockets 11.0 introduces a implementation on top of :mod:`threading`. +.. admonition:: websockets 11.0 introduces a :mod:`threading` implementation. :class: important It may be more convenient if you don't need to manage many connections and @@ -211,7 +213,8 @@ Improvements Backwards-incompatible changes .............................. -.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated. +.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and + :class:`~http11.Response` is deprecated. :class: note Use the ``handshake_exc`` attribute of :class:`~server.ServerProtocol` and @@ -565,11 +568,11 @@ Backwards-incompatible changes .. admonition:: ``process_request`` is now expected to be a coroutine. :class: note - If you're passing a ``process_request`` argument to - :func:`~server.serve` or :class:`~server.WebSocketServerProtocol`, or if - you're overriding + If you're passing a ``process_request`` argument to :func:`~server.serve` + or :class:`~server.WebSocketServerProtocol`, or if you're overriding :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, - define it with ``async def`` instead of ``def``. Previously, both were supported. + define it with ``async def`` instead of ``def``. Previously, both were + supported. For backwards compatibility, functions are still accepted, but mixing functions and coroutines won't work in some inheritance scenarios. From 3b7fa7673bf6a96a5e9debd7dcfa65e04f85efbb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Jan 2024 16:49:13 +0100 Subject: [PATCH 1271/1539] Enable deprecation for second argument of handlers. --- docs/project/changelog.rst | 26 ++++++++++++++++++++------ src/websockets/legacy/server.py | 4 +--- tests/legacy/test_client_server.py | 6 ++---- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index dc84a5ae2..fd186a5fc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,6 +43,21 @@ Backwards-incompatible changes For backwards compatibility, ``ssl_context`` is still supported. +.. admonition:: Receiving the request path in the second parameter of connection + handlers is deprecated. + :class: note + + If you implemented the connection handler of a server as:: + + async def handler(request, path): + ... + + You should switch to the recommended pattern since 10.1:: + + async def handler(request): + path = request.path # only if handler() uses the path argument + ... + New features ............ @@ -257,20 +272,19 @@ New features * Added a tutorial. -* Made the second parameter of connection handlers optional. It will be - deprecated in the next major release. The request path is available in - the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of - the first argument. +* Made the second parameter of connection handlers optional. The request path is + available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` + attribute of the first argument. If you implemented the connection handler of a server as:: async def handler(request, path): ... - You should replace it by:: + You should replace it with:: async def handler(request): - path = request.path # if handler() uses the path argument + path = request.path # only if handler() uses the path argument ... * Added ``python -m websockets --version``. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index e8cf8220f..4659ed9a6 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1168,9 +1168,7 @@ def remove_path_argument( pass else: # ws_handler accepts two arguments; activate backwards compatibility. - - # Enable deprecation warning and announce deprecation in 11.0. - # warnings.warn("remove second argument of ws_handler", DeprecationWarning) + warnings.warn("remove second argument of ws_handler", DeprecationWarning) async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: return await cast( diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 4a21f7cea..51a74734b 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -480,8 +480,7 @@ async def handler_with_path(ws, path): with self.temp_server( handler=handler_with_path, - # Enable deprecation warning and announce deprecation in 11.0. - # deprecation_warnings=["remove second argument of ws_handler"], + deprecation_warnings=["remove second argument of ws_handler"], ): with self.temp_client("/path"): self.assertEqual( @@ -497,8 +496,7 @@ async def handler_with_path(ws, path, extra): with self.temp_server( handler=bound_handler_with_path, - # Enable deprecation warning and announce deprecation in 11.0. - # deprecation_warnings=["remove second argument of ws_handler"], + deprecation_warnings=["remove second argument of ws_handler"], ): with self.temp_client("/path"): self.assertEqual( From aa33161cd9498bfca39d64fc36319bc1fbce68f2 Mon Sep 17 00:00:00 2001 From: MtkN1 <51289448+MtkN1@users.noreply.github.com> Date: Wed, 7 Feb 2024 13:14:22 +0900 Subject: [PATCH 1272/1539] Fix wrong RFC number --- src/websockets/legacy/client.py | 2 +- tests/test_protocol.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b85d22867..255696580 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -599,7 +599,7 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: yield protocol except Exception: # Add a random initial delay between 0 and 5 seconds. - # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 1d5dab7a0..e1527525b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -714,7 +714,7 @@ class CloseTests(ProtocolTestCase): """ Test close frames. - See RFC 6544: + See RFC 6455: 5.5.1. Close 7.1.6. The WebSocket Connection Close Reason @@ -994,7 +994,7 @@ def test_server_sends_close_after_connection_is_closed(self): class PingTests(ProtocolTestCase): """ - Test ping. See 5.5.2. Ping in RFC 6544. + Test ping. See 5.5.2. Ping in RFC 6455. """ @@ -1153,7 +1153,7 @@ def test_server_sends_ping_after_connection_is_closed(self): class PongTests(ProtocolTestCase): """ - Test pong frames. See 5.5.3. Pong in RFC 6544. + Test pong frames. See 5.5.3. Pong in RFC 6455. """ @@ -1298,7 +1298,7 @@ class FailTests(ProtocolTestCase): """ Test failing the connection. - See 7.1.7. Fail the WebSocket Connection in RFC 6544. + See 7.1.7. Fail the WebSocket Connection in RFC 6455. """ @@ -1321,7 +1321,7 @@ class FragmentationTests(ProtocolTestCase): """ Test message fragmentation. - See 5.4. Fragmentation in RFC 6544. + See 5.4. Fragmentation in RFC 6455. """ From 87f58c7190025521e5dc380945b0cc536169bd0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 17:25:00 +0100 Subject: [PATCH 1273/1539] Fix make clean. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index cf3b53393..bf8c8dc58 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,6 @@ build: python setup.py build_ext --inplace clean: - find . -name '*.pyc' -o -name '*.so' -delete + find . -name '*.pyc' -delete -o -name '*.so' -delete find . -name __pycache__ -delete rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 9b5273c68323dd63598dfcba97339f03f61d3d0f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 17:34:57 +0100 Subject: [PATCH 1274/1539] Move CLIENT/SERVER_CONTEXT to utils. Then we can reuse them for testing other implementations. --- tests/sync/client.py | 8 -------- tests/sync/server.py | 17 ----------------- tests/sync/test_client.py | 12 +++++++++--- tests/sync/test_server.py | 11 ++++++++--- tests/utils.py | 18 ++++++++++++++++++ 5 files changed, 35 insertions(+), 31 deletions(-) diff --git a/tests/sync/client.py b/tests/sync/client.py index bb4855c7f..72eb5b8d2 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -1,23 +1,15 @@ import contextlib -import ssl from websockets.sync.client import * from websockets.sync.server import WebSocketServer -from ..utils import CERTIFICATE - __all__ = [ - "CLIENT_CONTEXT", "run_client", "run_unix_client", ] -CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) - - @contextlib.contextmanager def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): diff --git a/tests/sync/server.py b/tests/sync/server.py index a9a77438c..10ab789c2 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,25 +1,8 @@ import contextlib -import ssl import threading from websockets.sync.server import * -from ..utils import CERTIFICATE - - -SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) -SERVER_CONTEXT.load_cert_chain(CERTIFICATE) - -# Work around https://github.com/openssl/openssl/issues/7967 - -# This bug causes connect() to hang in tests for the client. Including this -# workaround acknowledges that the issue could happen outside of the test suite. - -# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it -# happens, we can look for a library-level fix, but it won't be easy. - -SERVER_CONTEXT.num_tickets = 0 - def crash(ws): raise RuntimeError diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index c403b9632..bebf68aa5 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,9 +7,15 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..utils import MS, DeprecationTestCase, temp_unix_socket_path -from .client import CLIENT_CONTEXT, run_client, run_unix_client -from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + DeprecationTestCase, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import do_nothing, run_server, run_unix_server class ClientTests(unittest.TestCase): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f9f30baf1..490a3f63e 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -14,10 +14,15 @@ from websockets.http11 import Request, Response from websockets.sync.server import * -from ..utils import MS, DeprecationTestCase, temp_unix_socket_path -from .client import CLIENT_CONTEXT, run_client, run_unix_client -from .server import ( +from ..utils import ( + CLIENT_CONTEXT, + MS, SERVER_CONTEXT, + DeprecationTestCase, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import ( EvalShellMixin, crash, do_nothing, diff --git a/tests/utils.py b/tests/utils.py index 2937a2f15..bd3b61d7b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import os import pathlib import platform +import ssl import tempfile import time import unittest @@ -17,6 +18,23 @@ CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) +CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) + + +SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +SERVER_CONTEXT.load_cert_chain(CERTIFICATE) + +# Work around https://github.com/openssl/openssl/issues/7967 + +# This bug causes connect() to hang in tests for the client. Including this +# workaround acknowledges that the issue could happen outside of the test suite. + +# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it +# happens, we can look for a library-level fix, but it won't be easy. + +SERVER_CONTEXT.num_tickets = 0 + DATE = email.utils.formatdate(usegmt=True) From de768cf65e7e2b1a3b67854fb9e08816a5ff7050 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:08:45 +0100 Subject: [PATCH 1275/1539] Improve tests for sync implementation. --- tests/sync/server.py | 8 +++--- tests/sync/test_client.py | 59 ++++++++++++++++++++------------------- tests/sync/test_server.py | 36 ++++++++++++------------ 3 files changed, 53 insertions(+), 50 deletions(-) diff --git a/tests/sync/server.py b/tests/sync/server.py index 10ab789c2..d5295ccd8 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -25,8 +25,8 @@ def assertEval(self, client, expr, value): @contextlib.contextmanager -def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs): - with serve(ws_handler, host, port, **kwargs) as server: +def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): + with serve(handler, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: @@ -37,8 +37,8 @@ def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs): @contextlib.contextmanager -def run_unix_server(path, ws_handler=eval_shell, **kwargs): - with unix_serve(ws_handler, path, **kwargs) as server: +def run_unix_server(path, handler=eval_shell, **kwargs): + with unix_serve(handler, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index bebf68aa5..03f4e972f 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -3,7 +3,7 @@ import threading import unittest -from websockets.exceptions import InvalidHandshake +from websockets.exceptions import InvalidHandshake, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -25,29 +25,6 @@ def test_connection(self): with run_client(server) as client: self.assertEqual(client.protocol.state.name, "OPEN") - def test_connection_fails(self): - """Client connects to server but the handshake fails.""" - - def remove_accept_header(self, request, response): - del response.headers["Sec-WebSocket-Accept"] - - # The connection will be open for the server but failed for the client. - # Use a connection handler that exits immediately to avoid an exception. - with run_server(do_nothing, process_response=remove_accept_header) as server: - with self.assertRaises(InvalidHandshake) as raised: - with run_client(server, close_timeout=MS): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "missing Sec-WebSocket-Accept header", - ) - - def test_tcp_connection_fails(self): - """Client fails to connect to server.""" - with self.assertRaises(OSError): - with run_client("ws://localhost:54321"): # invalid port - self.fail("did not raise") - def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: @@ -103,6 +80,35 @@ def create_connection(*args, **kwargs): with run_client(server, create_connection=create_connection) as client: self.assertTrue(client.create_connection_ran) + def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + with run_client("http://localhost"): # invalid scheme + self.fail("did not raise") + + def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + with run_server(do_nothing, process_response=remove_accept_header) as server: + with self.assertRaises(InvalidHandshake) as raised: + with run_client(server, close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" gate = threading.Event() @@ -115,10 +121,7 @@ def stall_connection(self, request): with run_server(do_nothing, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - # While it shouldn't take 50ms to open a connection, this - # test becomes flaky in CI when setting a smaller timeout, - # even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR. - with run_client(server, open_timeout=5 * MS): + with run_client(server, open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 490a3f63e..9d509a5c4 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -39,21 +39,6 @@ def test_connection(self): with run_client(server) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") - def test_connection_fails(self): - """Server receives connection from client but the handshake fails.""" - - def remove_key_header(self, request): - del request.headers["Sec-WebSocket-Key"] - - with run_server(process_request=remove_key_header) as server: - with self.assertRaises(InvalidStatus) as raised: - with run_client(server): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 400", - ) - def test_connection_handler_returns(self): """Connection handler returns.""" with run_server(do_nothing) as server: @@ -81,8 +66,8 @@ def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: with run_server(sock=sock): - # Build WebSocket URI to ensure we connect to the right socket. - with run_client("ws://{}:{}/".format(*sock.getsockname())) as client: + uri = "ws://{}:{}/".format(*sock.getsockname()) + with run_client(uri) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_select_subprotocol(self): @@ -185,7 +170,7 @@ def process_response(ws, request, response): self.assertEval(client, "ws.process_response_ran", "True") def test_process_response_override_response(self): - """Server runs process_response after processing the handshake.""" + """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): headers = response.headers.copy() @@ -253,6 +238,21 @@ def create_connection(*args, **kwargs): with run_client(server) as client: self.assertEval(client, "ws.create_connection_ran", "True") + def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" with run_server(open_timeout=MS) as server: From 50b6d20d7a652d39cffc7aea9f8c0abc88fb8f37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:19:24 +0100 Subject: [PATCH 1276/1539] Various cleanups in sync implementation. --- src/websockets/sync/client.py | 9 +++-- src/websockets/sync/connection.py | 58 +++++++++++++++---------------- src/websockets/sync/server.py | 27 +++++++------- 3 files changed, 45 insertions(+), 49 deletions(-) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 6faca7789..0bb7a76fd 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -25,7 +25,7 @@ class ClientConnection(Connection): """ - Threaded implementation of a WebSocket client connection. + :mod:`threading` implementation of a WebSocket client connection. :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. @@ -157,7 +157,7 @@ def connect( :func:`connect` may be used as a context manager:: - async with websockets.sync.client.connect(...) as websocket: + with websockets.sync.client.connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -273,19 +273,18 @@ def connect( sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) - # Initialize WebSocket connection + # Initialize WebSocket protocol protocol = ClientProtocol( wsuri, origin=origin, extensions=extensions, subprotocols=subprotocols, - state=CONNECTING, max_size=max_size, logger=logger, ) - # Initialize WebSocket protocol + # Initialize WebSocket connection connection = create_connection( sock, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 62aa17ffd..6ac40cd7c 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -21,12 +21,10 @@ __all__ = ["Connection"] -logger = logging.getLogger(__name__) - class Connection: """ - Threaded implementation of a WebSocket connection. + :mod:`threading` implementation of a WebSocket connection. :class:`Connection` provides APIs shared between WebSocket servers and clients. @@ -82,7 +80,7 @@ def __init__( self.close_deadline: Optional[Deadline] = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, threading.Event] = {} + self.ping_waiters: Dict[bytes, threading.Event] = {} # Receiving events from the socket. self.recv_events_thread = threading.Thread(target=self.recv_events) @@ -90,7 +88,7 @@ def __init__( # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. - self.recv_events_exc: Optional[BaseException] = None + self.recv_exc: Optional[BaseException] = None # Public attributes @@ -198,7 +196,7 @@ def recv(self, timeout: Optional[float] = None) -> Data: try: return self.recv_messages.get(timeout) except EOFError: - raise self.protocol.close_exc from self.recv_events_exc + raise self.protocol.close_exc from self.recv_exc except RuntimeError: raise RuntimeError( "cannot call recv while another thread " @@ -229,9 +227,10 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + for frame in self.recv_messages.get_iter(): + yield frame except EOFError: - raise self.protocol.close_exc from self.recv_events_exc + raise self.protocol.close_exc from self.recv_exc except RuntimeError: raise RuntimeError( "cannot call recv_streaming while another thread " @@ -273,7 +272,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If a connection is busy sending a fragmented message. + RuntimeError: If the connection is sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -449,15 +448,15 @@ def ping(self, data: Optional[Data] = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pings: + if data in self.ping_waiters: raise RuntimeError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pings: + while data is None or data in self.ping_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pings[data] = pong_waiter + self.ping_waiters[data] = pong_waiter self.protocol.send_ping(data) return pong_waiter @@ -504,22 +503,22 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.pings: + if data not in self.ping_waiters: return # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pings.items(): + for ping_id, ping in self.ping_waiters.items(): ping_ids.append(ping_id) ping.set() if ping_id == data: break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pings. + # Remove acknowledged pings from self.ping_waiters. for ping_id in ping_ids: - del self.pings[ping_id] + del self.ping_waiters[ping_id] def recv_events(self) -> None: """ @@ -541,10 +540,10 @@ def recv_events(self) -> None: self.logger.debug("error while receiving data", exc_info=True) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. - # In that case, send_context() already set recv_events_exc. - # Calling set_recv_events_exc() avoids overwriting it. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. with self.protocol_mutex: - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) break if data == b"": @@ -552,7 +551,7 @@ def recv_events(self) -> None: # Acquire the connection lock. with self.protocol_mutex: - # Feed incoming data to the connection. + # Feed incoming data to the protocol. self.protocol.receive_data(data) # This isn't expected to raise an exception. @@ -568,7 +567,7 @@ def recv_events(self) -> None: # set by send_context(), in case of a race condition # i.e. send_context() closes the socket after recv() # returns above but before send_data() calls send(). - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) break if self.protocol.close_expected(): @@ -595,7 +594,7 @@ def recv_events(self) -> None: # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. with self.protocol_mutex: - # Feed the end of the data stream to the connection. + # Feed the end of the data stream to the protocol. self.protocol.receive_eof() # This isn't expected to generate events. @@ -609,7 +608,7 @@ def recv_events(self) -> None: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) with self.protocol_mutex: - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) # We don't know where we crashed. Force protocol state to CLOSED. self.protocol.state = CLOSED finally: @@ -668,7 +667,6 @@ def send_context( wait_for_close = True # If the connection is expected to close soon, set the # close deadline based on the close timeout. - # Since we tested earlier that protocol.state was OPEN # (or CONNECTING) and we didn't release protocol_mutex, # it is certain that self.close_deadline is still None. @@ -710,11 +708,11 @@ def send_context( # original_exc is never set when wait_for_close is True. assert original_exc is None original_exc = TimeoutError("timed out while closing connection") - # Set recv_events_exc before closing the socket in order to get + # Set recv_exc before closing the socket in order to get # proper exception reporting. raise_close_exc = True with self.protocol_mutex: - self.set_recv_events_exc(original_exc) + self.set_recv_exc(original_exc) # If an error occurred, close the socket to terminate the connection and # raise an exception. @@ -745,16 +743,16 @@ def send_data(self) -> None: except OSError: # socket already closed pass - def set_recv_events_exc(self, exc: Optional[BaseException]) -> None: + def set_recv_exc(self, exc: Optional[BaseException]) -> None: """ - Set recv_events_exc, if not set yet. + Set recv_exc, if not set yet. This method requires holding protocol_mutex. """ assert self.protocol_mutex.locked() - if self.recv_events_exc is None: - self.recv_events_exc = exc + if self.recv_exc is None: + self.recv_exc = exc def close_socket(self) -> None: """ diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index fa6087d54..a070edf18 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -30,7 +30,7 @@ class ServerConnection(Connection): """ - Threaded implementation of a WebSocket server connection. + :mod:`threading` implementation of a WebSocket server connection. :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. @@ -188,6 +188,8 @@ class WebSocketServer: handler: Handler for one connection. Receives the socket and address returned by :meth:`~socket.socket.accept`. logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. """ @@ -311,16 +313,16 @@ def serve( Whenever a client connects, the server creates a :class:`ServerConnection`, performs the opening handshake, and delegates to the ``handler``. - The handler receives a :class:`ServerConnection` instance, which you can use - to send and receive messages. + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - :class:`WebSocketServer` mirrors the API of + This function returns a :class:`WebSocketServer` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call the :meth:`~WebSocketServer.serve_forever` - method to serve requests:: + that it will be closed and call :meth:`~WebSocketServer.serve_forever` to + serve requests:: def handler(websocket): ... @@ -454,15 +456,13 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: sock.do_handshake() sock.settimeout(None) - # Create a closure so that select_subprotocol has access to self. - + # Create a closure to give select_subprotocol access to connection. protocol_select_subprotocol: Optional[ Callable[ [ServerProtocol, Sequence[Subprotocol]], Optional[Subprotocol], ] ] = None - if select_subprotocol is not None: def protocol_select_subprotocol( @@ -475,19 +475,18 @@ def protocol_select_subprotocol( assert protocol is connection.protocol return select_subprotocol(connection, subprotocols) - # Initialize WebSocket connection + # Initialize WebSocket protocol protocol = ServerProtocol( origins=origins, extensions=extensions, subprotocols=subprotocols, select_subprotocol=protocol_select_subprotocol, - state=CONNECTING, max_size=max_size, logger=logger, ) - # Initialize WebSocket protocol + # Initialize WebSocket connection assert create_connection is not None # help mypy connection = create_connection( @@ -522,7 +521,7 @@ def protocol_select_subprotocol( def unix_serve( - handler: Callable[[ServerConnection], Any], + handler: Callable[[ServerConnection], None], path: Optional[str] = None, **kwargs: Any, ) -> WebSocketServer: @@ -541,4 +540,4 @@ def unix_serve( path: File system path to the Unix socket. """ - return serve(handler, path=path, unix=True, **kwargs) + return serve(handler, unix=True, path=path, **kwargs) From e217458ef8b692e45ca6f66c5aeb7fad0aee97ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:20:48 +0100 Subject: [PATCH 1277/1539] Small cleanups in legacy implementation. --- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 255696580..e5da8b13a 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -60,7 +60,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send` coroutines for receiving and sending messages. - It supports asynchronous iteration to receive incoming messages:: + It supports asynchronous iteration to receive messages:: async for message in websocket: await process(message) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4659ed9a6..0f3c1c150 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -649,9 +649,7 @@ class WebSocketServer: """ WebSocket server returned by :func:`serve`. - This class provides the same interface as :class:`~asyncio.Server`, - notably the :meth:`~asyncio.Server.close` - and :meth:`~asyncio.Server.wait_closed` methods. + This class mirrors the API of :class:`~asyncio.Server`. It keeps track of WebSocket connections in order to close them properly when shutting down. From 5f24866bfeefbe561fa76f7e5a494996d95a2757 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 08:48:42 +0200 Subject: [PATCH 1278/1539] Always mark background threads as daemon. Fix #1455. --- src/websockets/sync/connection.py | 9 +++++++-- src/websockets/sync/server.py | 3 +++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 6ac40cd7c..b41202dc9 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -82,8 +82,13 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: Dict[bytes, threading.Event] = {} - # Receiving events from the socket. - self.recv_events_thread = threading.Thread(target=self.recv_events) + # Receiving events from the socket. This thread explicitly is marked as + # to support creating a connection in a non-daemon thread then using it + # in a daemon thread; this shouldn't block the intpreter from exiting. + self.recv_events_thread = threading.Thread( + target=self.recv_events, + daemon=True, + ) self.recv_events_thread.start() # Exception raised in recv_events, to be chained to ConnectionClosed diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index a070edf18..fd4f5d3bd 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -233,6 +233,9 @@ def serve_forever(self) -> None: sock, addr = self.socket.accept() except OSError: break + # Since there isn't a mechanism for tracking connections and waiting + # for them to terminate, we cannot use daemon threads, or else all + # connections would be terminate brutally when closing the server. thread = threading.Thread(target=self.handler, args=(sock, addr)) thread.start() From 2774fabc13f09311dec345cc8513aa7b93200b92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Tue, 16 Apr 2024 16:52:01 +0200 Subject: [PATCH 1279/1539] docs(nginx): Fix a typo --- docs/howto/nginx.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index 30545fbc7..ff42c3c2b 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -17,9 +17,9 @@ Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/nginx/app.py :emphasize-lines: 21,23 -We'd like to nginx to connect to websockets servers via Unix sockets in order -to avoid the overhead of TCP for communicating between processes running in -the same OS. +We'd like nginx to connect to websockets servers via Unix sockets in order to +avoid the overhead of TCP for communicating between processes running in the +same OS. We start the app with :func:`~websockets.server.unix_serve`. Each server process listens on a different socket thanks to an environment variable set From 0fdc694a980ede0e91286ea5ea1d4f9c62bb42fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 08:55:36 +0200 Subject: [PATCH 1280/1539] Make it easy to monkey-patch length of frames repr. Fix #1451. --- src/websockets/frames.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 201bc5068..862eef3aa 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -146,6 +146,9 @@ class Frame: rsv2: bool = False rsv3: bool = False + # Monkey-patch if you want to see more in logs. Should be a multiple of 3. + MAX_LOG = 75 + def __str__(self) -> str: """ Return a human-readable representation of a frame. @@ -163,8 +166,9 @@ def __str__(self) -> str: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. binary = self.data - if len(binary) > 25: - binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) + if len(binary) > self.MAX_LOG // 3: + cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: data = str(Close.parse(self.data)) @@ -179,15 +183,17 @@ def __str__(self) -> str: coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data - if len(binary) > 25: - binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) + if len(binary) > self.MAX_LOG // 3: + cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: data = "''" - if len(data) > 75: - data = data[:48] + "..." + data[-24:] + if len(data) > self.MAX_LOG: + cut = self.MAX_LOG // 3 - 1 # by default cut = 24 + data = data[: 2 * cut] + "..." + data[-cut:] metadata = ", ".join(filter(None, [coding, length, non_final])) From f0398141d2efd28f64d8e1d6d9adc179a9e5e334 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 19:41:14 +0200 Subject: [PATCH 1281/1539] Bump asyncio_timeout to 4.0.3. This makes type checking pass again. --- src/websockets/legacy/async_timeout.py | 39 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/legacy/async_timeout.py index 8264094f5..6ffa89969 100644 --- a/src/websockets/legacy/async_timeout.py +++ b/src/websockets/legacy/async_timeout.py @@ -9,12 +9,12 @@ from typing import Optional, Type -# From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py -# Licensed under the Python Software Foundation License (PSF-2.0) - if sys.version_info >= (3, 11): from typing import final else: + # From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py + # Licensed under the Python Software Foundation License (PSF-2.0) + # @final exists in 3.8+, but we backport it for all versions # before 3.11 to keep support for the __final__ attribute. # See https://bugs.python.org/issue46342 @@ -49,10 +49,21 @@ class Other(Leaf): # Error reported by type checker pass return f + # End https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py + + +if sys.version_info >= (3, 11): + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + task.uncancel() + +else: + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + pass -# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py -__version__ = "4.0.2" +__version__ = "4.0.3" __all__ = ("timeout", "timeout_at", "Timeout") @@ -124,7 +135,7 @@ class Timeout: # The purpose is to time out as soon as possible # without waiting for the next await expression. - __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task") def __init__( self, deadline: Optional[float], loop: asyncio.AbstractEventLoop @@ -132,6 +143,7 @@ def __init__( self._loop = loop self._state = _State.INIT + self._task: Optional["asyncio.Task[object]"] = None self._timeout_handler = None # type: Optional[asyncio.Handle] if deadline is None: self._deadline = None # type: Optional[float] @@ -187,6 +199,7 @@ def reject(self) -> None: self._reject() def _reject(self) -> None: + self._task = None if self._timeout_handler is not None: self._timeout_handler.cancel() self._timeout_handler = None @@ -234,11 +247,11 @@ def _reschedule(self) -> None: if self._timeout_handler is not None: self._timeout_handler.cancel() - task = asyncio.current_task() + self._task = asyncio.current_task() if deadline <= now: - self._timeout_handler = self._loop.call_soon(self._on_timeout, task) + self._timeout_handler = self._loop.call_soon(self._on_timeout) else: - self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout) def _do_enter(self) -> None: if self._state != _State.INIT: @@ -248,15 +261,19 @@ def _do_enter(self) -> None: def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + assert self._task is not None + _uncancel_task(self._task) self._timeout_handler = None + self._task = None raise asyncio.TimeoutError # timeout has not expired self._state = _State.EXIT self._reject() return None - def _on_timeout(self, task: "asyncio.Task[None]") -> None: - task.cancel() + def _on_timeout(self) -> None: + assert self._task is not None + self._task.cancel() self._state = _State.TIMEOUT # drop the reference early self._timeout_handler = None From 33997631a04320a5f8d57fac0f2645dc2d654c29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 08:35:58 +0200 Subject: [PATCH 1282/1539] Update ruff. --- Makefile | 2 +- pyproject.toml | 4 ++-- tox.ini | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index bf8c8dc58..dacfe2a0b 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ default: style types tests style: black src tests - ruff --fix src tests + ruff check --fix src tests types: mypy --strict src diff --git a/pyproject.toml b/pyproject.toml index c4c5412c5..2367849ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,7 @@ exclude_lines = [ "@unittest.skip", ] -[tool.ruff] +[tool.ruff.lint] select = [ "E", # pycodestyle "F", # Pyflakes @@ -82,6 +82,6 @@ ignore = [ "F405", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true lines-after-imports = 2 diff --git a/tox.ini b/tox.ini index 538b638d9..b0e4a5931 100644 --- a/tox.ini +++ b/tox.ini @@ -32,7 +32,7 @@ commands = black --check src tests deps = black [testenv:ruff] -commands = ruff src tests +commands = ruff check src tests deps = ruff [testenv:mypy] From 2d195baaa632efd9fb87f09813d01af28464eb8c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:01:24 +0200 Subject: [PATCH 1283/1539] Don't run tests on Python 3.7. Forgotten in 1bf73423. --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index b0e4a5931..06003c85b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,5 @@ [tox] envlist = - py37 py38 py39 py310 From 7f402303fe1703767d9236494aacc3f197fbc708 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:02:07 +0200 Subject: [PATCH 1284/1539] Switch from typing.Optional to | None. --- src/websockets/client.py | 16 +-- src/websockets/exceptions.py | 19 ++- src/websockets/extensions/base.py | 4 +- .../extensions/permessage_deflate.py | 32 ++--- src/websockets/frames.py | 8 +- src/websockets/headers.py | 6 +- src/websockets/http11.py | 14 +- src/websockets/imports.py | 6 +- src/websockets/legacy/auth.py | 18 +-- src/websockets/legacy/client.py | 77 +++++----- src/websockets/legacy/framing.py | 8 +- src/websockets/legacy/protocol.py | 65 ++++----- src/websockets/legacy/server.py | 135 +++++++++--------- src/websockets/protocol.py | 28 ++-- src/websockets/server.py | 35 ++--- src/websockets/sync/client.py | 44 +++--- src/websockets/sync/connection.py | 28 ++-- src/websockets/sync/messages.py | 12 +- src/websockets/sync/server.py | 92 ++++++------ src/websockets/sync/utils.py | 7 +- src/websockets/typing.py | 2 +- src/websockets/uri.py | 8 +- 22 files changed, 334 insertions(+), 330 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 633b1960b..cfb441fd9 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Generator, List, Optional, Sequence +from typing import Any, Generator, List, Sequence from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -73,12 +73,12 @@ def __init__( self, wsuri: WebSocketURI, *, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, state: State = CONNECTING, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ): super().__init__( side=CLIENT, @@ -261,7 +261,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: return accepted_extensions - def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -274,7 +274,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: Subprotocol, if one was selected. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None subprotocols = headers.get_all("Sec-WebSocket-Protocol") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f7169e3b1..adb66e262 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,7 +31,6 @@ from __future__ import annotations import http -from typing import Optional from . import datastructures, frames, http11 from .typing import StatusLike @@ -78,11 +77,11 @@ class ConnectionClosed(WebSocketException): Raised when trying to interact with a closed connection. Attributes: - rcvd (Optional[Close]): if a close frame was received, its code and + rcvd (Close | None): if a close frame was received, its code and reason are available in ``rcvd.code`` and ``rcvd.reason``. - sent (Optional[Close]): if a close frame was sent, its code and reason + sent (Close | None): if a close frame was sent, its code and reason are available in ``sent.code`` and ``sent.reason``. - rcvd_then_sent (Optional[bool]): if close frames were received and + rcvd_then_sent (bool | None): if close frames were received and sent, this attribute tells in which order this happened, from the perspective of this side of the connection. @@ -90,9 +89,9 @@ class ConnectionClosed(WebSocketException): def __init__( self, - rcvd: Optional[frames.Close], - sent: Optional[frames.Close], - rcvd_then_sent: Optional[bool] = None, + rcvd: frames.Close | None, + sent: frames.Close | None, + rcvd_then_sent: bool | None = None, ) -> None: self.rcvd = rcvd self.sent = sent @@ -181,7 +180,7 @@ class InvalidHeader(InvalidHandshake): """ - def __init__(self, name: str, value: Optional[str] = None) -> None: + def __init__(self, name: str, value: str | None = None) -> None: self.name = name self.value = value @@ -221,7 +220,7 @@ class InvalidOrigin(InvalidHeader): """ - def __init__(self, origin: Optional[str]) -> None: + def __init__(self, origin: str | None) -> None: super().__init__("Origin", origin) @@ -301,7 +300,7 @@ class InvalidParameterValue(NegotiationError): """ - def __init__(self, name: str, value: Optional[str]) -> None: + def __init__(self, name: str, value: str | None) -> None: self.name = name self.value = value diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 7446c990c..5b5528a09 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, Sequence, Tuple +from typing import List, Sequence, Tuple from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -22,7 +22,7 @@ def decode( self, frame: frames.Frame, *, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index edccac3ca..e95b1064b 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple, Union from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -36,7 +36,7 @@ def __init__( local_no_context_takeover: bool, remote_max_window_bits: int, local_max_window_bits: int, - compress_settings: Optional[Dict[Any, Any]] = None, + compress_settings: Dict[Any, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension. @@ -84,7 +84,7 @@ def decode( self, frame: frames.Frame, *, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. @@ -174,8 +174,8 @@ def encode(self, frame: frames.Frame) -> frames.Frame: def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, - server_max_window_bits: Optional[int], - client_max_window_bits: Optional[Union[int, bool]], + server_max_window_bits: int | None, + client_max_window_bits: Union[int, bool] | None, ) -> List[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]: +) -> Tuple[bool, bool, int | None, Union[int, bool] | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -207,8 +207,8 @@ def _extract_parameters( """ server_no_context_takeover: bool = False client_no_context_takeover: bool = False - server_max_window_bits: Optional[int] = None - client_max_window_bits: Optional[Union[int, bool]] = None + server_max_window_bits: int | None = None + client_max_window_bits: Union[int, bool] | None = None for name, value in params: if name == "server_no_context_takeover": @@ -286,9 +286,9 @@ def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[Union[int, bool]] = True, - compress_settings: Optional[Dict[str, Any]] = None, + server_max_window_bits: int | None = None, + client_max_window_bits: Union[int, bool] | None = True, + compress_settings: Dict[str, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -433,7 +433,7 @@ def process_response_params( def enable_client_permessage_deflate( - extensions: Optional[Sequence[ClientExtensionFactory]], + extensions: Sequence[ClientExtensionFactory] | None, ) -> Sequence[ClientExtensionFactory]: """ Enable Per-Message Deflate with default settings in client extensions. @@ -489,9 +489,9 @@ def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[int] = None, - compress_settings: Optional[Dict[str, Any]] = None, + server_max_window_bits: int | None = None, + client_max_window_bits: int | None = None, + compress_settings: Dict[str, Any] | None = None, require_client_max_window_bits: bool = False, ) -> None: """ @@ -635,7 +635,7 @@ def process_request_params( def enable_server_permessage_deflate( - extensions: Optional[Sequence[ServerExtensionFactory]], + extensions: Sequence[ServerExtensionFactory] | None, ) -> Sequence[ServerExtensionFactory]: """ Enable Per-Message Deflate with default settings in server extensions. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 862eef3aa..5a304d6a7 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -5,7 +5,7 @@ import io import secrets import struct -from typing import Callable, Generator, Optional, Sequence, Tuple +from typing import Callable, Generator, Sequence, Tuple from . import exceptions, extensions from .typing import Data @@ -205,8 +205,8 @@ def parse( read_exact: Callable[[int], Generator[None, None, bytes]], *, mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence[extensions.Extension]] = None, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> Generator[None, None, Frame]: """ Parse a WebSocket frame. @@ -280,7 +280,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence[extensions.Extension]] = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> bytes: """ Serialize a WebSocket frame. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 463df3061..3b316e0bf 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,7 @@ import binascii import ipaddress import re -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast +from typing import Callable, List, Sequence, Tuple, TypeVar, cast from . import exceptions from .typing import ( @@ -63,7 +63,7 @@ def build_host(host: str, port: int, secure: bool) -> str: # https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. -def peek_ahead(header: str, pos: int) -> Optional[str]: +def peek_ahead(header: str, pos: int) -> str | None: """ Return the next character from ``header`` at the given position. @@ -314,7 +314,7 @@ def parse_extension_item_param( name, pos = parse_token(header, pos, header_name) pos = parse_OWS(header, pos) # Extract parameter value, if there is one. - value: Optional[str] = None + value: str | None = None if peek_ahead(header, pos) == "=": pos = parse_OWS(header, pos + 1) if peek_ahead(header, pos) == '"': diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 6fe775eec..a7e9ae682 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -3,7 +3,7 @@ import dataclasses import re import warnings -from typing import Callable, Generator, Optional +from typing import Callable, Generator from . import datastructures, exceptions @@ -62,10 +62,10 @@ class Request: headers: datastructures.Headers # body isn't useful is the context of this library. - _exception: Optional[Exception] = None + _exception: Exception | None = None @property - def exception(self) -> Optional[Exception]: # pragma: no cover + def exception(self) -> Exception | None: # pragma: no cover warnings.warn( "Request.exception is deprecated; " "use ServerProtocol.handshake_exc instead", @@ -164,12 +164,12 @@ class Response: status_code: int reason_phrase: str headers: datastructures.Headers - body: Optional[bytes] = None + body: bytes | None = None - _exception: Optional[Exception] = None + _exception: Exception | None = None @property - def exception(self) -> Optional[Exception]: # pragma: no cover + def exception(self) -> Exception | None: # pragma: no cover warnings.warn( "Response.exception is deprecated; " "use ClientProtocol.handshake_exc instead", @@ -245,7 +245,7 @@ def parse( if 100 <= status_code < 200 or status_code == 204 or status_code == 304: body = None else: - content_length: Optional[int] + content_length: int | None try: # MultipleValuesError is sufficiently unlikely that we don't # attempt to handle it. Instead we document that its parent diff --git a/src/websockets/imports.py b/src/websockets/imports.py index a6a59d4c2..9c05234f5 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable __all__ = ["lazy_import"] @@ -30,8 +30,8 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: def lazy_import( namespace: Dict[str, Any], - aliases: Optional[Dict[str, str]] = None, - deprecated_aliases: Optional[Dict[str, str]] = None, + aliases: Dict[str, str] | None = None, + deprecated_aliases: Dict[str, str] | None = None, ) -> None: """ Provide lazy, module-level imports. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8217afedd..067f9c78c 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,7 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast +from typing import Any, Awaitable, Callable, Iterable, Tuple, Union, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -39,14 +39,14 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): encoding of non-ASCII characters is undefined. """ - username: Optional[str] = None + username: str | None = None """Username of the authenticated user.""" def __init__( self, *args: Any, - realm: Optional[str] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, + realm: str | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, **kwargs: Any, ) -> None: if realm is not None: @@ -79,7 +79,7 @@ async def process_request( self, path: str, request_headers: Headers, - ) -> Optional[HTTPResponse]: + ) -> HTTPResponse | None: """ Check HTTP Basic Auth and return an HTTP 401 response if needed. @@ -115,10 +115,10 @@ async def process_request( def basic_auth_protocol_factory( - realm: Optional[str] = None, - credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, - create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, + realm: str | None = None, + credentials: Union[Credentials, Iterable[Credentials]] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, + create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index e5da8b13a..f7464368f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -13,7 +13,6 @@ Callable, Generator, List, - Optional, Sequence, Tuple, Type, @@ -86,12 +85,12 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__( self, *, - logger: Optional[LoggerLike] = None, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, + logger: LoggerLike | None = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, **kwargs: Any, ) -> None: if logger is None: @@ -152,7 +151,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: @staticmethod def process_extensions( headers: Headers, - available_extensions: Optional[Sequence[ClientExtensionFactory]], + available_extensions: Sequence[ClientExtensionFactory] | None, ) -> List[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -224,8 +223,8 @@ def process_extensions( @staticmethod def process_subprotocol( - headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -234,7 +233,7 @@ def process_subprotocol( Return the selected subprotocol. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -260,10 +259,10 @@ def process_subprotocol( async def handshake( self, wsuri: WebSocketURI, - origin: Optional[Origin] = None, - available_extensions: Optional[Sequence[ClientExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, + origin: Origin | None = None, + available_extensions: Sequence[ClientExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, ) -> None: """ Perform the client side of the opening handshake. @@ -427,26 +426,26 @@ def __init__( self, uri: str, *, - create_protocol: Optional[Callable[..., WebSocketClientProtocol]] = None, - logger: Optional[LoggerLike] = None, - compression: Optional[str] = "deflate", - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - open_timeout: Optional[float] = 10, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + create_protocol: Callable[..., WebSocketClientProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. - timeout: Optional[float] = kwargs.pop("timeout", None) + timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -456,7 +455,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None) + klass: Type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: @@ -469,7 +468,7 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: @@ -516,13 +515,13 @@ def __init__( ) if kwargs.pop("unix", False): - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) create_connection = functools.partial( loop.create_unix_connection, factory, path, **kwargs ) else: - host: Optional[str] - port: Optional[int] + host: str | None + port: int | None if kwargs.get("sock") is None: host, port = wsuri.host, wsuri.port else: @@ -630,9 +629,9 @@ async def __aenter__(self) -> WebSocketClientProtocol: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: await self.protocol.close() @@ -679,7 +678,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: def unix_connect( - path: Optional[str] = None, + path: str | None = None, uri: str = "ws://localhost/", **kwargs: Any, ) -> Connect: diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index b77b869e3..8a13fa446 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Awaitable, Callable, NamedTuple, Sequence, Tuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -44,8 +44,8 @@ async def read( reader: Callable[[int], Awaitable[bytes]], *, mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence[extensions.Extension]] = None, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> Frame: """ Read a WebSocket frame. @@ -122,7 +122,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence[extensions.Extension]] = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> None: """ Write a WebSocket frame. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 26d50a2cc..94d42cfdb 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -22,7 +22,6 @@ Iterable, List, Mapping, - Optional, Tuple, Union, cast, @@ -173,21 +172,21 @@ class WebSocketCommonProtocol(asyncio.Protocol): def __init__( self, *, - logger: Optional[LoggerLike] = None, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + logger: LoggerLike | None = None, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, # The following arguments are kept only for backwards compatibility. - host: Optional[str] = None, - port: Optional[int] = None, - secure: Optional[bool] = None, + host: str | None = None, + port: int | None = None, + secure: bool | None = None, legacy_recv: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, - timeout: Optional[float] = None, + loop: asyncio.AbstractEventLoop | None = None, + timeout: float | None = None, ) -> None: if legacy_recv: # pragma: no cover warnings.warn("legacy_recv is deprecated", DeprecationWarning) @@ -243,7 +242,7 @@ def __init__( # Copied from asyncio.FlowControlMixin self._paused = False - self._drain_waiter: Optional[asyncio.Future[None]] = None + self._drain_waiter: asyncio.Future[None] | None = None self._drain_lock = asyncio.Lock() @@ -265,13 +264,13 @@ def __init__( # WebSocket protocol parameters. self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None + self.subprotocol: Subprotocol | None = None """Subprotocol, if one was negotiated.""" # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost` callback to a :class:`~asyncio.Future` @@ -281,11 +280,11 @@ def __init__( # Queue of received messages. self.messages: Deque[Data] = collections.deque() - self._pop_message_waiter: Optional[asyncio.Future[None]] = None - self._put_message_waiter: Optional[asyncio.Future[None]] = None + self._pop_message_waiter: asyncio.Future[None] | None = None + self._put_message_waiter: asyncio.Future[None] | None = None # Protect sending fragmented messages. - self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None + self._fragmented_message_waiter: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} @@ -306,7 +305,7 @@ def __init__( self.transfer_data_task: asyncio.Task[None] # Exception that occurred during data transfer, if any. - self.transfer_data_exc: Optional[BaseException] = None + self.transfer_data_exc: BaseException | None = None # Task sending keepalive pings. self.keepalive_ping_task: asyncio.Task[None] @@ -363,19 +362,19 @@ def connection_open(self) -> None: self.close_connection_task = self.loop.create_task(self.close_connection()) @property - def host(self) -> Optional[str]: + def host(self) -> str | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) return self._host @property - def port(self) -> Optional[int]: + def port(self) -> int | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) return self._port @property - def secure(self) -> Optional[bool]: + def secure(self) -> bool | None: warnings.warn("don't use secure", DeprecationWarning) return self._secure @@ -447,7 +446,7 @@ def closed(self) -> bool: return self.state is State.CLOSED @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: """ WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. @@ -465,7 +464,7 @@ def close_code(self) -> Optional[int]: return self.close_rcvd.code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: """ WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. @@ -804,7 +803,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: + async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -1017,7 +1016,7 @@ async def transfer_data(self) -> None: self.transfer_data_exc = exc self.fail_connection(CloseCode.INTERNAL_ERROR) - async def read_message(self) -> Optional[Data]: + async def read_message(self) -> Data | None: """ Read a single message from the connection. @@ -1090,7 +1089,7 @@ def append(frame: Frame) -> None: return ("" if text else b"").join(fragments) - async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: + async def read_data_frame(self, max_size: int | None) -> Frame | None: """ Read a single data frame from the connection. @@ -1153,7 +1152,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: else: return frame - async def read_frame(self, max_size: Optional[int]) -> Frame: + async def read_frame(self, max_size: int | None) -> Frame: """ Read a single frame from the connection. @@ -1204,9 +1203,7 @@ async def write_frame( self.write_frame_sync(fin, opcode, data) await self.drain() - async def write_close_frame( - self, close: Close, data: Optional[bytes] = None - ) -> None: + async def write_close_frame(self, close: Close, data: bytes | None = None) -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1484,7 +1481,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: # Copied from asyncio.StreamReaderProtocol self.reader.set_transport(transport) - def connection_lost(self, exc: Optional[Exception]) -> None: + def connection_lost(self, exc: Exception | None) -> None: """ 7.1.4. The WebSocket Connection is Closed. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 0f3c1c150..551115174 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -16,7 +16,6 @@ Generator, Iterable, List, - Optional, Sequence, Set, Tuple, @@ -103,19 +102,19 @@ def __init__( ], ws_server: WebSocketServer, *, - logger: Optional[LoggerLike] = None, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - server_header: Optional[str] = USER_AGENT, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - open_timeout: Optional[float] = 10, + logger: LoggerLike | None = None, + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = USER_AGENT, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, **kwargs: Any, ) -> None: if logger is None: @@ -293,7 +292,7 @@ async def read_http_request(self) -> Tuple[str, Headers]: return path, headers def write_http_response( - self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None + self, status: http.HTTPStatus, headers: Headers, body: bytes | None = None ) -> None: """ Write status line and headers to the HTTP response. @@ -322,7 +321,7 @@ def write_http_response( async def process_request( self, path: str, request_headers: Headers - ) -> Optional[HTTPResponse]: + ) -> HTTPResponse | None: """ Intercept the HTTP request and return an HTTP response if appropriate. @@ -371,8 +370,8 @@ async def process_request( @staticmethod def process_origin( - headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None - ) -> Optional[Origin]: + headers: Headers, origins: Sequence[Origin | None] | None = None + ) -> Origin | None: """ Handle the Origin HTTP request header. @@ -387,9 +386,11 @@ def process_origin( # "The user agent MUST NOT include more than one Origin header field" # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: - origin = cast(Optional[Origin], headers.get("Origin")) + origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origin is not None: + origin = cast(Origin, origin) if origins is not None: if origin not in origins: raise InvalidOrigin(origin) @@ -398,8 +399,8 @@ def process_origin( @staticmethod def process_extensions( headers: Headers, - available_extensions: Optional[Sequence[ServerExtensionFactory]], - ) -> Tuple[Optional[str], List[Extension]]: + available_extensions: Sequence[ServerExtensionFactory] | None, + ) -> Tuple[str | None, List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -435,7 +436,7 @@ def process_extensions( InvalidHandshake: To abort the handshake with an HTTP 400 error. """ - response_header_value: Optional[str] = None + response_header_value: str | None = None extension_headers: List[ExtensionHeader] = [] accepted_extensions: List[Extension] = [] @@ -479,8 +480,8 @@ def process_extensions( # Not @staticmethod because it calls self.select_subprotocol() def process_subprotocol( - self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -495,7 +496,7 @@ def process_subprotocol( InvalidHandshake: To abort the handshake with an HTTP 400 error. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -514,7 +515,7 @@ def select_subprotocol( self, client_subprotocols: Sequence[Subprotocol], server_subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: """ Pick a subprotocol among those supported by the client and the server. @@ -552,10 +553,10 @@ def select_subprotocol( async def handshake( self, - origins: Optional[Sequence[Optional[Origin]]] = None, - available_extensions: Optional[Sequence[ServerExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, + origins: Sequence[Origin | None] | None = None, + available_extensions: Sequence[ServerExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, ) -> str: """ Perform the server side of the opening handshake. @@ -661,7 +662,7 @@ class WebSocketServer: """ - def __init__(self, logger: Optional[LoggerLike] = None): + def __init__(self, logger: LoggerLike | None = None): if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger @@ -670,7 +671,7 @@ def __init__(self, logger: Optional[LoggerLike] = None): self.websockets: Set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. - self.close_task: Optional[asyncio.Task[None]] = None + self.close_task: asyncio.Task[None] | None = None # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] @@ -869,9 +870,9 @@ async def __aenter__(self) -> WebSocketServer: # pragma: no cover async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: # pragma: no cover self.close() await self.wait_closed() @@ -941,8 +942,8 @@ class Serve: server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - process_request (Optional[Callable[[str, Headers], \ - Awaitable[Optional[Tuple[StatusLike, HeadersLike, bytes]]]]]): + process_request (Callable[[str, Headers], \ + Awaitable[Tuple[StatusLike, HeadersLike, bytes] | None]] | None): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. @@ -975,35 +976,35 @@ def __init__( Callable[[WebSocketServerProtocol], Awaitable[Any]], Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated ], - host: Optional[Union[str, Sequence[str]]] = None, - port: Optional[int] = None, + host: Union[str, Sequence[str]] | None = None, + port: int | None = None, *, - create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None, - logger: Optional[LoggerLike] = None, - compression: Optional[str] = "deflate", - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - server_header: Optional[str] = USER_AGENT, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - open_timeout: Optional[float] = 10, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + create_protocol: Callable[..., WebSocketServerProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = USER_AGENT, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. - timeout: Optional[float] = kwargs.pop("timeout", None) + timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -1013,7 +1014,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Optional[Type[WebSocketServerProtocol]] = kwargs.pop("klass", None) + klass: Type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: @@ -1026,7 +1027,7 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: @@ -1076,7 +1077,7 @@ def __init__( ) if kwargs.pop("unix", False): - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) # unix_serve(path) must not specify host and port parameters. assert host is None and port is None create_server = functools.partial( @@ -1098,9 +1099,9 @@ async def __aenter__(self) -> WebSocketServer: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.ws_server.close() await self.ws_server.wait_closed() @@ -1129,7 +1130,7 @@ def unix_serve( Callable[[WebSocketServerProtocol], Awaitable[Any]], Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated ], - path: Optional[str] = None, + path: str | None = None, **kwargs: Any, ) -> Serve: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0b36202e5..8aa222eeb 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,7 @@ import enum import logging import uuid -from typing import Generator, List, Optional, Type, Union +from typing import Generator, List, Type, Union from .exceptions import ( ConnectionClosed, @@ -89,8 +89,8 @@ def __init__( side: Side, *, state: State = OPEN, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ) -> None: # Unique identifier. For logs. self.id: uuid.UUID = uuid.uuid4() @@ -116,24 +116,24 @@ def __init__( # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. - self.cur_size: Optional[int] = None + self.cur_size: int | None = None # True while sending a fragmented message i.e. a data frames with the # FIN bit not set. self.expect_continuation_frame = False # WebSocket protocol parameters. - self.origin: Optional[Origin] = None + self.origin: Origin | None = None self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None + self.subprotocol: Subprotocol | None = None # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None # Track if an exception happened during the handshake. - self.handshake_exc: Optional[Exception] = None + self.handshake_exc: Exception | None = None """ Exception to raise if the opening handshake failed. @@ -150,7 +150,7 @@ def __init__( self.writes: List[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine - self.parser_exc: Optional[Exception] = None + self.parser_exc: Exception | None = None @property def state(self) -> State: @@ -169,7 +169,7 @@ def state(self, state: State) -> None: self._state = state @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: """ `WebSocket close code`_. @@ -187,7 +187,7 @@ def close_code(self) -> Optional[int]: return self.close_rcvd.code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: """ `WebSocket close reason`_. @@ -348,7 +348,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) - def send_close(self, code: Optional[int] = None, reason: str = "") -> None: + def send_close(self, code: int | None = None, reason: str = "") -> None: """ Send a `Close frame`_. diff --git a/src/websockets/server.py b/src/websockets/server.py index 330e54f37..a92541085 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, Generator, List, Sequence, Tuple, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -77,18 +77,19 @@ class ServerProtocol(Protocol): def __init__( self, *, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[ + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None, + | None + ) = None, state: State = CONNECTING, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ): super().__init__( side=SERVER, @@ -200,7 +201,7 @@ def accept(self, request: Request) -> Response: def process_request( self, request: Request, - ) -> Tuple[str, Optional[str], Optional[str]]: + ) -> Tuple[str, str | None, str | None]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -285,7 +286,7 @@ def process_request( protocol_header, ) - def process_origin(self, headers: Headers) -> Optional[Origin]: + def process_origin(self, headers: Headers) -> Origin | None: """ Handle the Origin HTTP request header. @@ -303,9 +304,11 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: # "The user agent MUST NOT include more than one Origin header field" # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: - origin = cast(Optional[Origin], headers.get("Origin")) + origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origin is not None: + origin = cast(Origin, origin) if self.origins is not None: if origin not in self.origins: raise InvalidOrigin(origin) @@ -314,7 +317,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: def process_extensions( self, headers: Headers, - ) -> Tuple[Optional[str], List[Extension]]: + ) -> Tuple[str | None, List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -350,7 +353,7 @@ def process_extensions( InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. """ - response_header_value: Optional[str] = None + response_header_value: str | None = None extension_headers: List[ExtensionHeader] = [] accepted_extensions: List[Extension] = [] @@ -392,7 +395,7 @@ def process_extensions( return response_header_value, accepted_extensions - def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -420,7 +423,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: def select_subprotocol( self, subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: """ Pick a subprotocol among those offered by the client. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 0bb7a76fd..60b49ebc3 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,7 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Optional, Sequence, Type +from typing import Any, Sequence, Type from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -52,7 +52,7 @@ def __init__( socket: socket.socket, protocol: ClientProtocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -64,9 +64,9 @@ def __init__( def handshake( self, - additional_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - timeout: Optional[float] = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + timeout: float | None = None, ) -> None: """ Perform the opening handshake. @@ -128,25 +128,25 @@ def connect( uri: str, *, # TCP/TLS - sock: Optional[socket.socket] = None, - ssl: Optional[ssl_module.SSLContext] = None, - server_hostname: Optional[str] = None, + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, # WebSocket - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - additional_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - compression: Optional[str] = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", # Timeouts - open_timeout: Optional[float] = 10, - close_timeout: Optional[float] = 10, + open_timeout: float | None = 10, + close_timeout: float | None = 10, # Limits - max_size: Optional[int] = 2**20, + max_size: int | None = 2**20, # Logging - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Optional[Type[ClientConnection]] = None, + create_connection: Type[ClientConnection] | None = None, **kwargs: Any, ) -> ClientConnection: """ @@ -219,7 +219,7 @@ def connect( # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) if unix: if path is None and sock is None: @@ -307,8 +307,8 @@ def connect( def unix_connect( - path: Optional[str] = None, - uri: Optional[str] = None, + path: str | None = None, + uri: str | None = None, **kwargs: Any, ) -> ClientConnection: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index b41202dc9..bb9743181 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union +from typing import Any, Dict, Iterable, Iterator, Mapping, Type, Union from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -42,7 +42,7 @@ def __init__( socket: socket.socket, protocol: Protocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.socket = socket self.protocol = protocol @@ -62,9 +62,9 @@ def __init__( self.debug = self.protocol.debug # HTTP handshake request and response. - self.request: Optional[Request] = None + self.request: Request | None = None """Opening handshake request.""" - self.response: Optional[Response] = None + self.response: Response | None = None """Opening handshake response.""" # Mutex serializing interactions with the protocol. @@ -77,7 +77,7 @@ def __init__( self.send_in_progress = False # Deadline for the closing handshake. - self.close_deadline: Optional[Deadline] = None + self.close_deadline: Deadline | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: Dict[bytes, threading.Event] = {} @@ -93,7 +93,7 @@ def __init__( # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: Optional[BaseException] = None + self.recv_exc: BaseException | None = None # Public attributes @@ -124,7 +124,7 @@ def remote_address(self) -> Any: return self.socket.getpeername() @property - def subprotocol(self) -> Optional[Subprotocol]: + def subprotocol(self) -> Subprotocol | None: """ Subprotocol negotiated during the opening handshake. @@ -140,9 +140,9 @@ def __enter__(self) -> Connection: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: if exc_type is None: self.close() @@ -166,7 +166,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: Optional[float] = None) -> Data: + def recv(self, timeout: float | None = None) -> Data: """ Receive the next message. @@ -420,7 +420,7 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: # They mean that the connection is closed, which was the goal. pass - def ping(self, data: Optional[Data] = None) -> threading.Event: + def ping(self, data: Data | None = None) -> threading.Event: """ Send a Ping_. @@ -647,7 +647,7 @@ def send_context( # Should we close the socket and raise ConnectionClosed? raise_close_exc = False # What exception should we chain ConnectionClosed to? - original_exc: Optional[BaseException] = None + original_exc: BaseException | None = None # Acquire the protocol lock. with self.protocol_mutex: @@ -748,7 +748,7 @@ def send_data(self) -> None: except OSError: # socket already closed pass - def set_recv_exc(self, exc: Optional[BaseException]) -> None: + def set_recv_exc(self, exc: BaseException | None) -> None: """ Set recv_exc, if not set yet. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index dcba183d9..2c604ba09 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Iterator, List, Optional, cast +from typing import Iterator, List, cast from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -41,7 +41,7 @@ def __init__(self) -> None: self.put_in_progress = False # Decoder for text frames, None for binary frames. - self.decoder: Optional[codecs.IncrementalDecoder] = None + self.decoder: codecs.IncrementalDecoder | None = None # Buffer of frames belonging to the same message. self.chunks: List[Data] = [] @@ -54,12 +54,12 @@ def __init__(self) -> None: # Stream data from frames belonging to the same message. # Remove quotes around type when dropping Python < 3.9. - self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None + self.chunks_queue: "queue.SimpleQueue[Data | None] | None" = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: Optional[float] = None) -> Data: + def get(self, timeout: float | None = None) -> Data: """ Read the next message. @@ -151,7 +151,7 @@ def get_iter(self) -> Iterator[Data]: self.chunks = [] self.chunks_queue = cast( # Remove quotes around type when dropping Python < 3.9. - "queue.SimpleQueue[Optional[Data]]", + "queue.SimpleQueue[Data | None]", queue.SimpleQueue(), ) @@ -164,7 +164,7 @@ def get_iter(self) -> Iterator[Data]: self.get_in_progress = True # Locking with get_in_progress ensures only one thread can get here. - chunk: Optional[Data] + chunk: Data | None for chunk in chunks: yield chunk while (chunk := self.chunks_queue.get()) is not None: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index fd4f5d3bd..b801510b4 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,7 +10,7 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Optional, Sequence, Type +from typing import Any, Callable, Sequence, Type from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate @@ -57,7 +57,7 @@ def __init__( socket: socket.socket, protocol: ServerProtocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -69,20 +69,22 @@ def __init__( def handshake( self, - process_request: Optional[ + process_request: ( Callable[ [ServerConnection, Request], - Optional[Response], + Response | None, ] - ] = None, - process_response: Optional[ + | None + ) = None, + process_response: ( Callable[ [ServerConnection, Request, Response], - Optional[Response], + Response | None, ] - ] = None, - server_header: Optional[str] = USER_AGENT, - timeout: Optional[float] = None, + | None + ) = None, + server_header: str | None = USER_AGENT, + timeout: float | None = None, ) -> None: """ Perform the opening handshake. @@ -197,7 +199,7 @@ def __init__( self, socket: socket.socket, handler: Callable[[socket.socket, Any], None], - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, ): self.socket = socket self.handler = handler @@ -260,54 +262,57 @@ def __enter__(self) -> WebSocketServer: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.shutdown() def serve( handler: Callable[[ServerConnection], None], - host: Optional[str] = None, - port: Optional[int] = None, + host: str | None = None, + port: int | None = None, *, # TCP/TLS - sock: Optional[socket.socket] = None, - ssl: Optional[ssl_module.SSLContext] = None, + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, # WebSocket - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[ + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( Callable[ [ServerConnection, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None, - process_request: Optional[ + | None + ) = None, + process_request: ( Callable[ [ServerConnection, Request], - Optional[Response], + Response | None, ] - ] = None, - process_response: Optional[ + | None + ) = None, + process_response: ( Callable[ [ServerConnection, Request, Response], - Optional[Response], + Response | None, ] - ] = None, - server_header: Optional[str] = USER_AGENT, - compression: Optional[str] = "deflate", + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", # Timeouts - open_timeout: Optional[float] = 10, - close_timeout: Optional[float] = 10, + open_timeout: float | None = 10, + close_timeout: float | None = 10, # Limits - max_size: Optional[int] = 2**20, + max_size: int | None = 2**20, # Logging - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Optional[Type[ServerConnection]] = None, + create_connection: Type[ServerConnection] | None = None, **kwargs: Any, ) -> WebSocketServer: """ @@ -412,7 +417,7 @@ def handler(websocket): # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) if sock is None: if unix: @@ -460,18 +465,19 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: sock.settimeout(None) # Create a closure to give select_subprotocol access to connection. - protocol_select_subprotocol: Optional[ + protocol_select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None + | None + ) = None if select_subprotocol is not None: def protocol_select_subprotocol( protocol: ServerProtocol, subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: # mypy doesn't know that select_subprotocol is immutable. assert select_subprotocol is not None # Ensure this function is only used in the intended context. @@ -525,7 +531,7 @@ def protocol_select_subprotocol( def unix_serve( handler: Callable[[ServerConnection], None], - path: Optional[str] = None, + path: str | None = None, **kwargs: Any, ) -> WebSocketServer: """ diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 3364bdc2d..00bce2cc6 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -from typing import Optional __all__ = ["Deadline"] @@ -16,14 +15,14 @@ class Deadline: """ - def __init__(self, timeout: Optional[float]) -> None: - self.deadline: Optional[float] + def __init__(self, timeout: float | None) -> None: + self.deadline: float | None if timeout is None: self.deadline = None else: self.deadline = time.monotonic() + timeout - def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: + def timeout(self, *, raise_if_elapsed: bool = True) -> float | None: """ Calculate a timeout from a deadline. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 5dfecf66f..7c5b3664d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -53,7 +53,7 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" - +# Change to str | None when dropping Python < 3.10. ExtensionParameter = Tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 8cf581743..902716066 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,7 +2,7 @@ import dataclasses import urllib.parse -from typing import Optional, Tuple +from typing import Tuple from . import exceptions @@ -33,8 +33,8 @@ class WebSocketURI: port: int path: str query: str - username: Optional[str] = None - password: Optional[str] = None + username: str | None = None + password: str | None = None @property def resource_name(self) -> str: @@ -47,7 +47,7 @@ def resource_name(self) -> str: return resource_name @property - def user_info(self) -> Optional[Tuple[str, str]]: + def user_info(self) -> Tuple[str, str] | None: if self.username is None: return None assert self.password is not None From 63a2d8eff62fe487c42a8f3176528730b7eed727 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:16:17 +0200 Subject: [PATCH 1285/1539] Switch from typing.Union to |. --- src/websockets/datastructures.py | 1 + .../extensions/permessage_deflate.py | 10 ++--- src/websockets/legacy/auth.py | 4 +- src/websockets/legacy/protocol.py | 3 +- src/websockets/legacy/server.py | 37 ++++++++++--------- src/websockets/protocol.py | 1 + src/websockets/sync/connection.py | 4 +- src/websockets/typing.py | 3 ++ 8 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index aef11bf23..5605772d8 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -177,6 +177,7 @@ def keys(self) -> Iterable[str]: ... def __getitem__(self, key: str) -> str: ... +# Change to Headers | Mapping[str, str] | ... when dropping Python < 3.10. HeadersLike = Union[ Headers, Mapping[str, str], diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index e95b1064b..48a6a0833 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -175,7 +175,7 @@ def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, server_max_window_bits: int | None, - client_max_window_bits: Union[int, bool] | None, + client_max_window_bits: int | bool | None, ) -> List[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, int | None, Union[int, bool] | None]: +) -> Tuple[bool, bool, int | None, int | bool | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -208,7 +208,7 @@ def _extract_parameters( server_no_context_takeover: bool = False client_no_context_takeover: bool = False server_max_window_bits: int | None = None - client_max_window_bits: Union[int, bool] | None = None + client_max_window_bits: int | bool | None = None for name, value in params: if name == "server_no_context_takeover": @@ -287,7 +287,7 @@ def __init__( server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, - client_max_window_bits: Union[int, bool] | None = True, + client_max_window_bits: int | bool | None = True, compress_settings: Dict[str, Any] | None = None, ) -> None: """ diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 067f9c78c..9d685d9f4 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,7 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Tuple, Union, cast +from typing import Any, Awaitable, Callable, Iterable, Tuple, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -116,7 +116,7 @@ async def process_request( def basic_auth_protocol_factory( realm: str | None = None, - credentials: Union[Credentials, Iterable[Credentials]] | None = None, + credentials: Credentials | Iterable[Credentials] | None = None, check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 94d42cfdb..f4c5901dc 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -23,7 +23,6 @@ List, Mapping, Tuple, - Union, cast, ) @@ -578,7 +577,7 @@ async def recv(self) -> Data: async def send( self, - message: Union[Data, Iterable[Data], AsyncIterable[Data]], + message: Data | Iterable[Data] | AsyncIterable[Data], ) -> None: """ Send a message. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 551115174..13a6f5591 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -54,6 +54,7 @@ __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] +# Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] @@ -96,10 +97,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), ws_server: WebSocketServer, *, logger: LoggerLike | None = None, @@ -934,7 +935,7 @@ class Serve: should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. - extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]): + extra_headers (HeadersLike | Callable[[str, Headers] | HeadersLike]): Arbitrary HTTP headers to add to the response. This can be a :data:`~websockets.datastructures.HeadersLike` or a callable taking the request path and headers in arguments and returning @@ -972,11 +973,11 @@ class Serve: def __init__( self, - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], - host: Union[str, Sequence[str]] | None = None, + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), + host: str | Sequence[str] | None = None, port: int | None = None, *, create_protocol: Callable[..., WebSocketServerProtocol] | None = None, @@ -1126,10 +1127,10 @@ async def __await_impl__(self) -> WebSocketServer: def unix_serve( - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), path: str | None = None, **kwargs: Any, ) -> Serve: @@ -1152,10 +1153,10 @@ def unix_serve( def remove_path_argument( - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], - ] + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] + ) ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: try: inspect.signature(ws_handler).bind(None) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8aa222eeb..f288a2733 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -38,6 +38,7 @@ "SEND_EOF", ] +# Change to Request | Response | Frame when dropping Python < 3.10. Event = Union[Request, Response, Frame] """Events that :meth:`~Protocol.events_received` may return.""" diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index bb9743181..7a750331d 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Type, Union +from typing import Any, Dict, Iterable, Iterator, Mapping, Type from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -242,7 +242,7 @@ def recv_streaming(self) -> Iterator[Data]: "is already running recv or recv_streaming" ) from None - def send(self, message: Union[Data, Iterable[Data]]) -> None: + def send(self, message: Data | Iterable[Data]) -> None: """ Send a message. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 7c5b3664d..73d4a4754 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -19,6 +19,7 @@ # Public types used in the signature of public APIs +# Change to str | bytes when dropping Python < 3.10. Data = Union[str, bytes] """Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. @@ -29,6 +30,7 @@ """ +# Change to logging.Logger | ... when dropping Python < 3.10. if typing.TYPE_CHECKING: LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" @@ -37,6 +39,7 @@ """Types accepted where a :class:`~logging.Logger` is expected.""" +# Change to http.HTTPStatus | int when dropping Python < 3.10. StatusLike = Union[http.HTTPStatus, int] """ Types accepted where an :class:`~http.HTTPStatus` is expected.""" From cd059d5633eed129775572054a7458d8e3f07166 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:28:36 +0200 Subject: [PATCH 1286/1539] Switch from typing.Dict/List/Tuple/Type/Set to native types. --- src/websockets/client.py | 12 +++---- src/websockets/datastructures.py | 11 +++--- src/websockets/extensions/base.py | 6 ++-- .../extensions/permessage_deflate.py | 18 +++++----- src/websockets/frames.py | 4 +-- src/websockets/headers.py | 34 +++++++++---------- src/websockets/imports.py | 10 +++--- src/websockets/legacy/auth.py | 1 + src/websockets/legacy/client.py | 15 ++++---- src/websockets/legacy/framing.py | 4 +-- src/websockets/legacy/handshake.py | 9 +++-- src/websockets/legacy/http.py | 5 ++- src/websockets/legacy/protocol.py | 9 ++--- src/websockets/legacy/server.py | 28 +++++++-------- src/websockets/protocol.py | 14 ++++---- src/websockets/server.py | 16 ++++----- src/websockets/sync/client.py | 4 +-- src/websockets/sync/connection.py | 6 ++-- src/websockets/sync/messages.py | 4 +-- src/websockets/sync/server.py | 6 ++-- src/websockets/typing.py | 4 ++- src/websockets/uri.py | 3 +- 22 files changed, 107 insertions(+), 116 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index cfb441fd9..8f78ac320 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Generator, List, Sequence +from typing import Any, Generator, Sequence from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -153,7 +153,7 @@ def process_response(self, response: Response) -> None: headers = response.headers - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) @@ -162,7 +162,7 @@ def process_response(self, response: Response) -> None: "Connection", ", ".join(connection) if connection else None ) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -188,7 +188,7 @@ def process_response(self, response: Response) -> None: self.subprotocol = self.process_subprotocol(headers) - def process_extensions(self, headers: Headers) -> List[Extension]: + def process_extensions(self, headers: Headers) -> list[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -219,7 +219,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: InvalidHandshake: To abort the handshake. """ - accepted_extensions: List[Extension] = [] + accepted_extensions: list[Extension] = [] extensions = headers.get_all("Sec-WebSocket-Extensions") @@ -227,7 +227,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: if self.available_extensions is None: raise InvalidHandshake("no extensions supported") - parsed_extensions: List[ExtensionHeader] = sum( + parsed_extensions: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in extensions], [] ) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 5605772d8..3d64d951e 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -2,10 +2,8 @@ from typing import ( Any, - Dict, Iterable, Iterator, - List, Mapping, MutableMapping, Protocol, @@ -72,8 +70,8 @@ class Headers(MutableMapping[str, str]): # Like dict, Headers accepts an optional "mapping or iterable" argument. def __init__(self, *args: HeadersLike, **kwargs: str) -> None: - self._dict: Dict[str, List[str]] = {} - self._list: List[Tuple[str, str]] = [] + self._dict: dict[str, list[str]] = {} + self._list: list[tuple[str, str]] = [] self.update(*args, **kwargs) def __str__(self) -> str: @@ -147,7 +145,7 @@ def update(self, *args: HeadersLike, **kwargs: str) -> None: # Methods for handling multiple values - def get_all(self, key: str) -> List[str]: + def get_all(self, key: str) -> list[str]: """ Return the (possibly empty) list of all values for a header. @@ -157,7 +155,7 @@ def get_all(self, key: str) -> List[str]: """ return self._dict.get(key.lower(), []) - def raw_items(self) -> Iterator[Tuple[str, str]]: + def raw_items(self) -> Iterator[tuple[str, str]]: """ Return an iterator of all values as ``(name, value)`` pairs. @@ -181,6 +179,7 @@ def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ Headers, Mapping[str, str], + # Change to tuple[str, str] when dropping Python < 3.9. Iterable[Tuple[str, str]], SupportsKeysAndGetItem, ] diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 5b5528a09..a6c76c3d4 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Sequence, Tuple +from typing import Sequence from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -63,7 +63,7 @@ class ClientExtensionFactory: name: ExtensionName """Extension identifier.""" - def get_request_params(self) -> List[ExtensionParameter]: + def get_request_params(self) -> list[ExtensionParameter]: """ Build parameters to send to the server for this extension. @@ -108,7 +108,7 @@ def process_request_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], - ) -> Tuple[List[ExtensionParameter], Extension]: + ) -> tuple[list[ExtensionParameter], Extension]: """ Process parameters received from the client. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 48a6a0833..579262f02 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Sequence, Tuple +from typing import Any, Sequence from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -36,7 +36,7 @@ def __init__( local_no_context_takeover: bool, remote_max_window_bits: int, local_max_window_bits: int, - compress_settings: Dict[Any, Any] | None = None, + compress_settings: dict[Any, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension. @@ -176,12 +176,12 @@ def _build_parameters( client_no_context_takeover: bool, server_max_window_bits: int | None, client_max_window_bits: int | bool | None, -) -> List[ExtensionParameter]: +) -> list[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. """ - params: List[ExtensionParameter] = [] + params: list[ExtensionParameter] = [] if server_no_context_takeover: params.append(("server_no_context_takeover", None)) if client_no_context_takeover: @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, int | None, int | bool | None]: +) -> tuple[bool, bool, int | None, int | bool | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -288,7 +288,7 @@ def __init__( client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, client_max_window_bits: int | bool | None = True, - compress_settings: Dict[str, Any] | None = None, + compress_settings: dict[str, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -314,7 +314,7 @@ def __init__( self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings - def get_request_params(self) -> List[ExtensionParameter]: + def get_request_params(self) -> list[ExtensionParameter]: """ Build request parameters. @@ -491,7 +491,7 @@ def __init__( client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, client_max_window_bits: int | None = None, - compress_settings: Dict[str, Any] | None = None, + compress_settings: dict[str, Any] | None = None, require_client_max_window_bits: bool = False, ) -> None: """ @@ -524,7 +524,7 @@ def process_request_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], - ) -> Tuple[List[ExtensionParameter], PerMessageDeflate]: + ) -> tuple[list[ExtensionParameter], PerMessageDeflate]: """ Process request parameters. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 5a304d6a7..0da676432 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -5,7 +5,7 @@ import io import secrets import struct -from typing import Callable, Generator, Sequence, Tuple +from typing import Callable, Generator, Sequence from . import exceptions, extensions from .typing import Data @@ -353,7 +353,7 @@ def check(self) -> None: raise exceptions.ProtocolError("fragmented control frame") -def prepare_data(data: Data) -> Tuple[int, bytes]: +def prepare_data(data: Data) -> tuple[int, bytes]: """ Convert a string or byte-like object to an opcode and a bytes-like object. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 3b316e0bf..bc42e0b72 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,7 @@ import binascii import ipaddress import re -from typing import Callable, List, Sequence, Tuple, TypeVar, cast +from typing import Callable, Sequence, TypeVar, cast from . import exceptions from .typing import ( @@ -96,7 +96,7 @@ def parse_OWS(header: str, pos: int) -> int: _token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") -def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a token from ``header`` at the given position. @@ -120,7 +120,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: _unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") -def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a quoted string from ``header`` at the given position. @@ -158,11 +158,11 @@ def build_quoted_string(value: str) -> str: def parse_list( - parse_item: Callable[[str, int, str], Tuple[T, int]], + parse_item: Callable[[str, int, str], tuple[T, int]], header: str, pos: int, header_name: str, -) -> List[T]: +) -> list[T]: """ Parse a comma-separated list from ``header`` at the given position. @@ -227,7 +227,7 @@ def parse_list( def parse_connection_option( header: str, pos: int, header_name: str -) -> Tuple[ConnectionOption, int]: +) -> tuple[ConnectionOption, int]: """ Parse a Connection option from ``header`` at the given position. @@ -241,7 +241,7 @@ def parse_connection_option( return cast(ConnectionOption, item), pos -def parse_connection(header: str) -> List[ConnectionOption]: +def parse_connection(header: str) -> list[ConnectionOption]: """ Parse a ``Connection`` header. @@ -264,7 +264,7 @@ def parse_connection(header: str) -> List[ConnectionOption]: def parse_upgrade_protocol( header: str, pos: int, header_name: str -) -> Tuple[UpgradeProtocol, int]: +) -> tuple[UpgradeProtocol, int]: """ Parse an Upgrade protocol from ``header`` at the given position. @@ -282,7 +282,7 @@ def parse_upgrade_protocol( return cast(UpgradeProtocol, match.group()), match.end() -def parse_upgrade(header: str) -> List[UpgradeProtocol]: +def parse_upgrade(header: str) -> list[UpgradeProtocol]: """ Parse an ``Upgrade`` header. @@ -300,7 +300,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: def parse_extension_item_param( header: str, pos: int, header_name: str -) -> Tuple[ExtensionParameter, int]: +) -> tuple[ExtensionParameter, int]: """ Parse a single extension parameter from ``header`` at the given position. @@ -336,7 +336,7 @@ def parse_extension_item_param( def parse_extension_item( header: str, pos: int, header_name: str -) -> Tuple[ExtensionHeader, int]: +) -> tuple[ExtensionHeader, int]: """ Parse an extension definition from ``header`` at the given position. @@ -359,7 +359,7 @@ def parse_extension_item( return (cast(ExtensionName, name), parameters), pos -def parse_extension(header: str) -> List[ExtensionHeader]: +def parse_extension(header: str) -> list[ExtensionHeader]: """ Parse a ``Sec-WebSocket-Extensions`` header. @@ -389,7 +389,7 @@ def parse_extension(header: str) -> List[ExtensionHeader]: def build_extension_item( - name: ExtensionName, parameters: List[ExtensionParameter] + name: ExtensionName, parameters: list[ExtensionParameter] ) -> str: """ Build an extension definition. @@ -424,7 +424,7 @@ def build_extension(extensions: Sequence[ExtensionHeader]) -> str: def parse_subprotocol_item( header: str, pos: int, header_name: str -) -> Tuple[Subprotocol, int]: +) -> tuple[Subprotocol, int]: """ Parse a subprotocol from ``header`` at the given position. @@ -438,7 +438,7 @@ def parse_subprotocol_item( return cast(Subprotocol, item), pos -def parse_subprotocol(header: str) -> List[Subprotocol]: +def parse_subprotocol(header: str) -> list[Subprotocol]: """ Parse a ``Sec-WebSocket-Protocol`` header. @@ -498,7 +498,7 @@ def build_www_authenticate_basic(realm: str) -> str: _token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*") -def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a token68 from ``header`` at the given position. @@ -525,7 +525,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None: raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos) -def parse_authorization_basic(header: str) -> Tuple[str, str]: +def parse_authorization_basic(header: str) -> tuple[str, str]: """ Parse an ``Authorization`` header for HTTP Basic Auth. diff --git a/src/websockets/imports.py b/src/websockets/imports.py index 9c05234f5..bb80e4eac 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,13 +1,13 @@ from __future__ import annotations import warnings -from typing import Any, Dict, Iterable +from typing import Any, Iterable __all__ = ["lazy_import"] -def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: +def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any: """ Import ``name`` from ``source`` in ``namespace``. @@ -29,9 +29,9 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: def lazy_import( - namespace: Dict[str, Any], - aliases: Dict[str, str] | None = None, - deprecated_aliases: Dict[str, str] | None = None, + namespace: dict[str, Any], + aliases: dict[str, str] | None = None, + deprecated_aliases: dict[str, str] | None = None, ) -> None: """ Provide lazy, module-level imports. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 9d685d9f4..c2d30e4b4 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -13,6 +13,7 @@ __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] +# Change to tuple[str, str] when dropping Python < 3.9. Credentials = Tuple[str, str] diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index f7464368f..d9d69fdaa 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -12,10 +12,7 @@ AsyncIterator, Callable, Generator, - List, Sequence, - Tuple, - Type, cast, ) @@ -122,7 +119,7 @@ def write_http_request(self, path: str, headers: Headers) -> None: self.transport.write(request.encode()) - async def read_http_response(self) -> Tuple[int, Headers]: + async def read_http_response(self) -> tuple[int, Headers]: """ Read status line and headers from the HTTP response. @@ -152,7 +149,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: def process_extensions( headers: Headers, available_extensions: Sequence[ClientExtensionFactory] | None, - ) -> List[Extension]: + ) -> list[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -179,7 +176,7 @@ def process_extensions( order of extensions, may be implemented by overriding this method. """ - accepted_extensions: List[Extension] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") @@ -187,7 +184,7 @@ def process_extensions( if available_extensions is None: raise InvalidHandshake("no extensions supported") - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) @@ -455,7 +452,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) + klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: @@ -629,7 +626,7 @@ async def __aenter__(self) -> WebSocketClientProtocol: async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 8a13fa446..1aaca5cc6 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Sequence, Tuple +from typing import Any, Awaitable, Callable, NamedTuple, Sequence from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -152,7 +152,7 @@ def write( ) -def parse_close(data: bytes) -> Tuple[int, str]: +def parse_close(data: bytes) -> tuple[int, str]: """ Parse the payload from a close frame. diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 5853c31db..2a39c1b03 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -2,7 +2,6 @@ import base64 import binascii -from typing import List from ..datastructures import Headers, MultipleValuesError from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade @@ -55,14 +54,14 @@ def check_request(headers: Headers) -> str: Then, the server must return a 400 Bad Request error. """ - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", ", ".join(connection)) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -135,14 +134,14 @@ def check_response(headers: Headers, key: str) -> None: InvalidHandshake: If the handshake response is invalid. """ - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", " ".join(connection)) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 2ac7f7092..9a553e175 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -2,7 +2,6 @@ import asyncio import re -from typing import Tuple from ..datastructures import Headers from ..exceptions import SecurityError @@ -42,7 +41,7 @@ def d(value: bytes) -> str: _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") -async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: +async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]: """ Read an HTTP/1.1 GET request and return ``(path, headers)``. @@ -91,7 +90,7 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: return path, headers -async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]: +async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]: """ Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index f4c5901dc..67161019f 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -18,11 +18,8 @@ Awaitable, Callable, Deque, - Dict, Iterable, - List, Mapping, - Tuple, cast, ) @@ -262,7 +259,7 @@ def __init__( """Opening handshake response headers.""" # WebSocket protocol parameters. - self.extensions: List[Extension] = [] + self.extensions: list[Extension] = [] self.subprotocol: Subprotocol | None = None """Subprotocol, if one was negotiated.""" @@ -286,7 +283,7 @@ def __init__( self._fragmented_message_waiter: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} + self.pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self.latency: float = 0 """ @@ -1042,7 +1039,7 @@ async def read_message(self) -> Data | None: return frame.data.decode("utf-8") if text else frame.data # 5.4. Fragmentation - fragments: List[Data] = [] + fragments: list[Data] = [] max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 13a6f5591..c0ea6a764 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -15,11 +15,8 @@ Callable, Generator, Iterable, - List, Sequence, - Set, Tuple, - Type, Union, cast, ) @@ -57,6 +54,7 @@ # Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] +# Change to tuple[...] when dropping Python < 3.9. HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] @@ -263,7 +261,7 @@ async def handler(self) -> None: self.ws_server.unregister(self) self.logger.info("connection closed") - async def read_http_request(self) -> Tuple[str, Headers]: + async def read_http_request(self) -> tuple[str, Headers]: """ Read request line and headers from the HTTP request. @@ -349,7 +347,7 @@ async def process_request( request_headers: Request headers. Returns: - Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to + tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, @@ -401,7 +399,7 @@ def process_origin( def process_extensions( headers: Headers, available_extensions: Sequence[ServerExtensionFactory] | None, - ) -> Tuple[str | None, List[Extension]]: + ) -> tuple[str | None, list[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -439,13 +437,13 @@ def process_extensions( """ response_header_value: str | None = None - extension_headers: List[ExtensionHeader] = [] - accepted_extensions: List[Extension] = [] + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) @@ -502,7 +500,7 @@ def process_subprotocol( header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: - parsed_header_values: List[Subprotocol] = sum( + parsed_header_values: list[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) @@ -669,7 +667,7 @@ def __init__(self, logger: LoggerLike | None = None): self.logger = logger # Keep track of active connections. - self.websockets: Set[WebSocketServerProtocol] = set() + self.websockets: set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. self.close_task: asyncio.Task[None] | None = None @@ -871,7 +869,7 @@ async def __aenter__(self) -> WebSocketServer: # pragma: no cover async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: # pragma: no cover @@ -944,7 +942,7 @@ class Serve: It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. process_request (Callable[[str, Headers], \ - Awaitable[Tuple[StatusLike, HeadersLike, bytes] | None]] | None): + Awaitable[tuple[StatusLike, HeadersLike, bytes] | None]] | None): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. @@ -1015,7 +1013,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) + klass: type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: @@ -1100,7 +1098,7 @@ async def __aenter__(self) -> WebSocketServer: async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index f288a2733..2f5542f6e 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,7 @@ import enum import logging import uuid -from typing import Generator, List, Type, Union +from typing import Generator, Union from .exceptions import ( ConnectionClosed, @@ -125,7 +125,7 @@ def __init__( # WebSocket protocol parameters. self.origin: Origin | None = None - self.extensions: List[Extension] = [] + self.extensions: list[Extension] = [] self.subprotocol: Subprotocol | None = None # Close code and reason, set when a close frame is sent or received. @@ -147,8 +147,8 @@ def __init__( # Parser state. self.reader = StreamReader() - self.events: List[Event] = [] - self.writes: List[bytes] = [] + self.events: list[Event] = [] + self.writes: list[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine self.parser_exc: Exception | None = None @@ -222,7 +222,7 @@ def close_exc(self) -> ConnectionClosed: """ assert self.state is CLOSED, "connection isn't closed yet" - exc_type: Type[ConnectionClosed] + exc_type: type[ConnectionClosed] if ( self.close_rcvd is not None and self.close_sent is not None @@ -458,7 +458,7 @@ def fail(self, code: int, reason: str = "") -> None: # Public method for getting incoming events after receiving data. - def events_received(self) -> List[Event]: + def events_received(self) -> list[Event]: """ Fetch events generated from data received from the network. @@ -474,7 +474,7 @@ def events_received(self) -> List[Event]: # Public method for getting outgoing data after receiving data or sending events. - def data_to_send(self) -> List[bytes]: + def data_to_send(self) -> list[bytes]: """ Obtain data to send to the network. diff --git a/src/websockets/server.py b/src/websockets/server.py index a92541085..f976ebad7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, List, Sequence, Tuple, cast +from typing import Any, Callable, Generator, Sequence, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -201,7 +201,7 @@ def accept(self, request: Request) -> Response: def process_request( self, request: Request, - ) -> Tuple[str, str | None, str | None]: + ) -> tuple[str, str | None, str | None]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -224,7 +224,7 @@ def process_request( """ headers = request.headers - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) @@ -233,7 +233,7 @@ def process_request( "Connection", ", ".join(connection) if connection else None ) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -317,7 +317,7 @@ def process_origin(self, headers: Headers) -> Origin | None: def process_extensions( self, headers: Headers, - ) -> Tuple[str | None, List[Extension]]: + ) -> tuple[str | None, list[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -355,13 +355,13 @@ def process_extensions( """ response_header_value: str | None = None - extension_headers: List[ExtensionHeader] = [] - accepted_extensions: List[Extension] = [] + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and self.available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 60b49ebc3..c97a09402 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,7 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Sequence, Type +from typing import Any, Sequence from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -146,7 +146,7 @@ def connect( # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Type[ClientConnection] | None = None, + create_connection: type[ClientConnection] | None = None, **kwargs: Any, ) -> ClientConnection: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 7a750331d..33d8299e2 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Type +from typing import Any, Iterable, Iterator, Mapping from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -80,7 +80,7 @@ def __init__( self.close_deadline: Deadline | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.ping_waiters: Dict[bytes, threading.Event] = {} + self.ping_waiters: dict[bytes, threading.Event] = {} # Receiving events from the socket. This thread explicitly is marked as # to support creating a connection in a non-daemon thread then using it @@ -140,7 +140,7 @@ def __enter__(self) -> Connection: def __exit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 2c604ba09..a6e78e7fd 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Iterator, List, cast +from typing import Iterator, cast from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -44,7 +44,7 @@ def __init__(self) -> None: self.decoder: codecs.IncrementalDecoder | None = None # Buffer of frames belonging to the same message. - self.chunks: List[Data] = [] + self.chunks: list[Data] = [] # When switching from "buffering" to "streaming", we use a thread-safe # queue for transferring frames from the writing thread (library code) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index b801510b4..4f088b63a 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,7 +10,7 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Sequence, Type +from typing import Any, Callable, Sequence from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate @@ -262,7 +262,7 @@ def __enter__(self) -> WebSocketServer: def __exit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: @@ -312,7 +312,7 @@ def serve( # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Type[ServerConnection] | None = None, + create_connection: type[ServerConnection] | None = None, **kwargs: Any, ) -> WebSocketServer: """ diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 73d4a4754..6360c7a0a 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -56,13 +56,15 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" -# Change to str | None when dropping Python < 3.10. +# Change to tuple[str, Optional[str]] when dropping Python < 3.9. +# Change to tuple[str, str | None] when dropping Python < 3.10. ExtensionParameter = Tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" # Private types +# Change to tuple[.., list[...]] when dropping Python < 3.9. ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 902716066..5cb38a9cc 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,7 +2,6 @@ import dataclasses import urllib.parse -from typing import Tuple from . import exceptions @@ -47,7 +46,7 @@ def resource_name(self) -> str: return resource_name @property - def user_info(self) -> Tuple[str, str] | None: + def user_info(self) -> tuple[str, str] | None: if self.username is None: return None assert self.password is not None From f45286b3b2d54f8b79087b060858042b2488688b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:37:09 +0200 Subject: [PATCH 1287/1539] Pick changes suggested by `pyupgrade --py38-plus`. Other changes were ignored, on purpose. --- src/websockets/sync/messages.py | 3 +-- tests/legacy/test_client_server.py | 2 +- tests/legacy/utils.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index a6e78e7fd..6cbff2595 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -53,8 +53,7 @@ def __init__(self) -> None: # value marking the end of the message, superseding message_complete. # Stream data from frames belonging to the same message. - # Remove quotes around type when dropping Python < 3.9. - self.chunks_queue: "queue.SimpleQueue[Data | None] | None" = None + self.chunks_queue: queue.SimpleQueue[Data | None] | None = None # This flag marks the end of the connection. self.closed = False diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 51a74734b..c38086572 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -65,7 +65,7 @@ async def default_handler(ws): await ws.wait_closed() await asyncio.sleep(2 * MS) else: - await ws.send((await ws.recv())) + await ws.send(await ws.recv()) async def redirect_request(path, headers, test, status): diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 4a21dcaeb..28bc90df3 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -79,6 +79,6 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): for recorded in recorded_warnings: self.assertEqual(type(recorded.message), DeprecationWarning) self.assertEqual( - set(str(recorded.message) for recorded in recorded_warnings), + {str(recorded.message) for recorded in recorded_warnings}, set(expected_warnings), ) From 650d08caf1c5f84c77a8bf8780a1b407a1432357 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:19:32 +0200 Subject: [PATCH 1288/1539] Upgrade to mypy 1.11. --- src/websockets/legacy/auth.py | 5 +++++ src/websockets/legacy/client.py | 6 ++++-- src/websockets/legacy/server.py | 3 +++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index c2d30e4b4..8526bad6b 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -178,6 +178,11 @@ async def check_credentials(username: str, password: str) -> bool: if create_protocol is None: create_protocol = BasicAuthWebSocketServerProtocol + # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] | + # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc] + create_protocol = cast( + Callable[..., BasicAuthWebSocketServerProtocol], create_protocol + ) return functools.partial( create_protocol, realm=realm, diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index d9d69fdaa..b15eddf75 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -489,6 +489,9 @@ def __init__( if subprotocols is not None: validate_subprotocols(subprotocols) + # Help mypy and avoid this error: "type[WebSocketClientProtocol] | + # Callable[..., WebSocketClientProtocol]" not callable [misc] + create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol) factory = functools.partial( create_protocol, logger=logger, @@ -641,8 +644,7 @@ def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: async def __await_impl__(self) -> WebSocketClientProtocol: async with asyncio_timeout(self.open_timeout): for _redirects in range(self.MAX_REDIRECTS_ALLOWED): - _transport, _protocol = await self._create_connection() - protocol = cast(WebSocketClientProtocol, _protocol) + _transport, protocol = await self._create_connection() try: await protocol.handshake( self._wsuri, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c0ea6a764..08c82df25 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1045,6 +1045,9 @@ def __init__( if subprotocols is not None: validate_subprotocols(subprotocols) + # Help mypy and avoid this error: "type[WebSocketServerProtocol] | + # Callable[..., WebSocketServerProtocol]" not callable [misc] + create_protocol = cast(Callable[..., WebSocketServerProtocol], create_protocol) factory = functools.partial( create_protocol, # For backwards compatibility with 10.0 or earlier. Done here in From e05f6dc83434dae7d91fc0db822ab15aa1e4c00b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:05:24 +0200 Subject: [PATCH 1289/1539] Support ws:// to wss:// redirects. Fix #1454. --- src/websockets/legacy/client.py | 14 ++++++++++---- tests/legacy/test_client_server.py | 13 ++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b15eddf75..b0e15b543 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -558,21 +558,27 @@ def handle_redirect(self, uri: str) -> None: raise SecurityError("redirect from WSS to WS") same_origin = ( - old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port + old_wsuri.secure == new_wsuri.secure + and old_wsuri.host == new_wsuri.host + and old_wsuri.port == new_wsuri.port ) - # Rewrite the host and port arguments for cross-origin redirects. + # Rewrite secure, host, and port for cross-origin redirects. # This preserves connection overrides with the host and port # arguments if the redirect points to the same host and port. if not same_origin: - # Replace the host and port argument passed to the protocol factory. factory = self._create_connection.args[0] + # Support TLS upgrade. + if not old_wsuri.secure and new_wsuri.secure: + factory.keywords["secure"] = True + self._create_connection.keywords.setdefault("ssl", True) + # Replace secure, host, and port arguments of the protocol factory. factory = functools.partial( factory.func, *factory.args, **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), ) - # Replace the host and port argument passed to create_connection. + # Replace secure, host, and port arguments of create_connection. self._create_connection = functools.partial( self._create_connection.func, *(factory, new_wsuri.host, new_wsuri.port), diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c38086572..09b3b361a 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -75,6 +75,8 @@ async def redirect_request(path, headers, test, status): location = "/" elif path == "/infinite": location = get_server_uri(test.server, test.secure, "/infinite") + elif path == "/force_secure": + location = get_server_uri(test.server, True, "/") elif path == "/force_insecure": location = get_server_uri(test.server, False, "/") elif path == "/missing_location": @@ -1290,7 +1292,16 @@ def test_connection_error_during_closing_handshake(self, close): class ClientServerTests( CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase ): - pass + + def test_redirect_secure(self): + with temp_test_redirecting_server(self): + # websockets doesn't support serving non-TLS and TLS connections + # from the same server and this test suite makes it difficult to + # run two servers. Therefore, we expect the redirect to create a + # TLS client connection to a non-TLS server, which will fail. + with self.assertRaises(ssl.SSLError): + with self.temp_client("/force_secure"): + self.fail("did not raise") class SecureClientServerTests( From 61b69db60cceff6c46ff308d2c10f7f81480788c Mon Sep 17 00:00:00 2001 From: Antonio Curado Date: Mon, 20 May 2024 18:29:55 +0200 Subject: [PATCH 1290/1539] Correct handle exceptions in `legacy/broadcast` --- src/websockets/legacy/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 67161019f..3d09440e1 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1630,5 +1630,5 @@ def broadcast( exc_info=True, ) - if raise_exceptions: + if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) From 96d3adf6617fd53fc7a7adcc5a560eeeb8493473 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:30:01 +0200 Subject: [PATCH 1291/1539] Add tests for previous commit. --- tests/legacy/test_protocol.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index f3dcd9ac7..05d2f3795 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1472,10 +1472,24 @@ def test_broadcast_text(self): broadcast([self.protocol], "café") self.assertOneFrameSent(True, OP_TEXT, "café".encode()) + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_text_reports_no_errors(self): + broadcast([self.protocol], "café", raise_exceptions=True) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) + def test_broadcast_binary(self): broadcast([self.protocol], b"tea") self.assertOneFrameSent(True, OP_BINARY, b"tea") + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_binary_reports_no_errors(self): + broadcast([self.protocol], b"tea", raise_exceptions=True) + self.assertOneFrameSent(True, OP_BINARY, b"tea") + def test_broadcast_type_error(self): with self.assertRaises(TypeError): broadcast([self.protocol], ["ca", "fé"]) From ee997c157d3214f758c2422fc44c2a582153f58a Mon Sep 17 00:00:00 2001 From: xuanzhi33 <37460139+xuanzhi33@users.noreply.github.com> Date: Wed, 28 Feb 2024 17:36:11 +0800 Subject: [PATCH 1292/1539] docs: Correct the example for "Starting a server" in the API reference --- src/websockets/legacy/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 08c82df25..fb91265d8 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -899,7 +899,8 @@ class Serve: server = await serve(...) await stop - await server.close() + server.close() + await server.wait_closed() :func:`serve` can be used as an asynchronous context manager. Then, the server is shut down automatically when exiting the context:: From 1210ee81e470bd8df7700d459ba101263fb7413c Mon Sep 17 00:00:00 2001 From: xuanzhi33 <37460139+xuanzhi33@users.noreply.github.com> Date: Wed, 28 Feb 2024 18:30:01 +0800 Subject: [PATCH 1293/1539] Update server.rst --- docs/faq/server.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 08b412d30..cba1cd35f 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -300,7 +300,8 @@ Here's how to adapt the example just above:: server = await websockets.serve(echo, "localhost", 8765) await stop - await server.close(close_connections=False) + server.close(close_connections=False) + await server.wait_closed() How do I implement a health check? ---------------------------------- From 41c42b8681dc1245a65e2db8491a573bba1827dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 13:39:14 +0200 Subject: [PATCH 1294/1539] Make it easier to debug version numbers. --- src/websockets/version.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index f1de3cbf4..145c7a9ed 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -55,7 +55,8 @@ def get_version(tag: str) -> str: else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" match = re.fullmatch(description_re, description) - assert match is not None + if match is None: + raise ValueError(f"Unexpected git description: {description}") distance, remainder = match.groups() remainder = remainder.replace("-", ".") # required by PEP 440 return f"{tag}.dev{distance}+{remainder}" @@ -75,7 +76,8 @@ def get_commit(tag: str, version: str) -> str: # Extract commit from version, falling back to tag if not available. version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?" match = re.fullmatch(version_re, version) - assert match is not None + if match is None: + raise ValueError(f"Unexpected version: {version}") (commit,) = match.groups() return tag if commit == "unknown" else commit From eaa64c07676a9c28d17c9538242fab4638754584 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 14:41:49 +0200 Subject: [PATCH 1295/1539] Avoid reading the wrong version. This was causing builds to fail on Read the Docs since sphinx-autobuild added websockets as a dependency. --- src/websockets/version.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 145c7a9ed..46ae34a47 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -34,8 +34,20 @@ def get_version(tag: str) -> str: file_path = pathlib.Path(__file__) root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2] - # Read version from git if available. This prevents reading stale - # information from src/websockets.egg-info after building a sdist. + # Read version from package metadata if it is installed. + try: + version = importlib.metadata.version("websockets") + except ImportError: + pass + else: + # Check that this file belongs to the installed package. + files = importlib.metadata.files("websockets") + if files: + version_file = [f for f in files if f.name == file_path.name][0] + if version_file.locate() == file_path: + return version + + # Read version from git if available. try: description = subprocess.run( ["git", "describe", "--dirty", "--tags", "--long"], @@ -61,12 +73,6 @@ def get_version(tag: str) -> str: remainder = remainder.replace("-", ".") # required by PEP 440 return f"{tag}.dev{distance}+{remainder}" - # Read version from package metadata if it is installed. - try: - return importlib.metadata.version("websockets") - except ImportError: - pass - # Avoid crashing if the development version cannot be determined. return f"{tag}.dev0+gunknown" From e10eebaec368b04f28102df513e26e933ed5a6fd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 15:24:25 +0200 Subject: [PATCH 1296/1539] Unshallow git clone on RtD. This is required for get_version to find the last tag. --- .readthedocs.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.readthedocs.yml b/.readthedocs.yml index 0369e0656..28c990c5c 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -4,6 +4,9 @@ build: os: ubuntu-20.04 tools: python: "3.10" + jobs: + post_checkout: + - git fetch --unshallow sphinx: configuration: docs/conf.py From c8c0a9bfee962540eb3c9c228e36d4ef7bd7ed42 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 15:56:42 +0200 Subject: [PATCH 1297/1539] Improve error reporting when header is too long. Refs #1471. --- src/websockets/legacy/server.py | 6 +++++- src/websockets/server.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index fb91265d8..c0c138767 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -199,10 +199,14 @@ async def handler(self) -> None: elif isinstance(exc, InvalidHandshake): if self.debug: self.logger.debug("! invalid handshake", exc_info=True) + exc_str = f"{exc}" + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += f"; {exc}" status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), - f"Failed to open a WebSocket connection: {exc}.\n".encode(), + f"Failed to open a WebSocket connection: {exc_str}.\n".encode(), ) else: self.logger.error("opening handshake failed", exc_info=True) diff --git a/src/websockets/server.py b/src/websockets/server.py index f976ebad7..7f5631230 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -163,9 +163,13 @@ def accept(self, request: Request) -> Response: self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) + exc_str = f"{exc}" + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += f"; {exc}" return self.reject( http.HTTPStatus.BAD_REQUEST, - f"Failed to open a WebSocket connection: {exc}.\n", + f"Failed to open a WebSocket connection: {exc_str}.\n", ) except Exception as exc: # Handle exceptions raised by user-provided select_subprotocol and From d26bac47eac98e6f3b77358b8c836ed02e493fc6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 30 Jul 2024 09:04:22 +0200 Subject: [PATCH 1298/1539] Make eaa64c07 more robust. This avoids crashing on ossfuzz, which uses a custom loader. --- src/websockets/version.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 46ae34a47..44709a91b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -43,9 +43,11 @@ def get_version(tag: str) -> str: # Check that this file belongs to the installed package. files = importlib.metadata.files("websockets") if files: - version_file = [f for f in files if f.name == file_path.name][0] - if version_file.locate() == file_path: - return version + version_files = [f for f in files if f.name == file_path.name] + if version_files: + version_file = version_files[0] + if version_file.locate() == file_path: + return version # Read version from git if available. try: From d2710227cd2464162861b584f5dd83c472208929 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:04:43 +0200 Subject: [PATCH 1299/1539] Make mypy happy. --- src/websockets/legacy/server.py | 9 +++++---- src/websockets/server.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c0c138767..93698e1cb 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -199,10 +199,11 @@ async def handler(self) -> None: elif isinstance(exc, InvalidHandshake): if self.debug: self.logger.debug("! invalid handshake", exc_info=True) - exc_str = f"{exc}" - while exc.__cause__ is not None: - exc = exc.__cause__ - exc_str += f"; {exc}" + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), diff --git a/src/websockets/server.py b/src/websockets/server.py index 7f5631230..baab400d4 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -163,10 +163,11 @@ def accept(self, request: Request) -> Response: self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) - exc_str = f"{exc}" - while exc.__cause__ is not None: - exc = exc.__cause__ - exc_str += f"; {exc}" + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc_str}.\n", From 309e62fa89311de51083e1a62adf06d4450fc5f2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:09:24 +0200 Subject: [PATCH 1300/1539] Test against current PyPy versions. --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8161f1cbb..15a45bdfb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,15 +62,15 @@ jobs: - "3.10" - "3.11" - "3.12" - - "pypy-3.8" - "pypy-3.9" + - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.8" - is_main: false - python: "pypy-3.9" is_main: false + - python: "pypy-3.10" + is_main: false steps: - name: Check out repository uses: actions/checkout@v4 From fab77d60b660585bcdd996a10ec904f79c901085 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:38:17 +0200 Subject: [PATCH 1301/1539] Annotate __init__ methods consistently. --- src/websockets/client.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- src/websockets/sync/server.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 8f78ac320..07d1d34ed 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -79,7 +79,7 @@ def __init__( state: State = CONNECTING, max_size: int | None = 2**20, logger: LoggerLike | None = None, - ): + ) -> None: super().__init__( side=CLIENT, state=state, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 93698e1cb..39464be6c 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -666,7 +666,7 @@ class WebSocketServer: """ - def __init__(self, logger: LoggerLike | None = None): + def __init__(self, logger: LoggerLike | None = None) -> None: if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger diff --git a/src/websockets/server.py b/src/websockets/server.py index baab400d4..7211d3cbf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -90,7 +90,7 @@ def __init__( state: State = CONNECTING, max_size: int | None = 2**20, logger: LoggerLike | None = None, - ): + ) -> None: super().__init__( side=SERVER, state=state, diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 4f088b63a..7fb46f5aa 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -200,7 +200,7 @@ def __init__( socket: socket.socket, handler: Callable[[socket.socket, Any], None], logger: LoggerLike | None = None, - ): + ) -> None: self.socket = socket self.handler = handler if logger is None: From 14cca7699971c19c30a38cad260aeb5f26e0c3ca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:46:04 +0200 Subject: [PATCH 1302/1539] Bugs in coverage were fixed \o/ --- src/websockets/legacy/protocol.py | 3 +-- tests/legacy/test_client_server.py | 1 - tests/legacy/test_protocol.py | 18 ++++++------------ 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 3d09440e1..de9ea59b6 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1349,8 +1349,7 @@ async def close_transport(self) -> None: # Abort the TCP connection. Buffers are discarded. if self.debug: self.logger.debug("x aborting TCP connection") - # Due to a bug in coverage, this is erroneously reported as not covered. - self.transport.abort() # pragma: no cover + self.transport.abort() # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 09b3b361a..0c3f22156 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1584,7 +1584,6 @@ async def run_client(): else: # Exit block with an exception. raise Exception("BOOM") - pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: with self.assertRaises(Exception) as raised: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 05d2f3795..d6303dcc7 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1613,8 +1613,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1634,8 +1633,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1656,8 +1654,7 @@ def test_local_close_send_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.ABNORMAL_CLOSURE, "", ) @@ -1670,8 +1667,7 @@ def test_local_close_receive_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.ABNORMAL_CLOSURE, "", ) @@ -1689,8 +1685,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1713,8 +1708,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) From 7bb18a6ea84d2651b68ad45f5e9464a47d314b6b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 07:53:06 +0200 Subject: [PATCH 1303/1539] Update references to Python's bug tracker. --- src/websockets/__main__.py | 2 +- src/websockets/legacy/protocol.py | 6 +++--- src/websockets/legacy/server.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index f2ea5cf4e..8647481d0 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -22,7 +22,7 @@ def win_enable_vt100() -> None: """ Enable VT-100 for console output on Windows. - See also https://bugs.python.org/issue29059. + See also https://github.com/python/cpython/issues/73245. """ import ctypes diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index de9ea59b6..57cb4e770 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1175,9 +1175,9 @@ def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None: async def drain(self) -> None: try: - # drain() cannot be called concurrently by multiple coroutines: - # http://bugs.python.org/issue29930. Remove this lock when no - # version of Python where this bugs exists is supported anymore. + # drain() cannot be called concurrently by multiple coroutines. + # See https://github.com/python/cpython/issues/74116 for details. + # This workaround can be removed when dropping Python < 3.10. async with self._drain_lock: # Handle flow control automatically. await self._drain() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 39464be6c..f4442fecc 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -764,7 +764,8 @@ async def _close(self, close_connections: bool) -> None: self.server.close() # Wait until all accepted connections reach connection_made() and call - # register(). See https://bugs.python.org/issue34852 for details. + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. await asyncio.sleep(0) if close_connections: From 273db5bcc4113061bd7d8f0a4edbf6c4d76c4d84 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Aug 2024 15:18:07 +0200 Subject: [PATCH 1304/1539] Make it easier to enable logs while running tests. --- tests/__init__.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index dd78609f5..bb1866f2d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,14 @@ import logging +import os -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) +format = "%(asctime)s %(levelname)s %(name)s %(message)s" + +if bool(os.environ.get("WEBSOCKETS_DEBUG")): # pragma: no cover + # Display every frame sent or received in debug mode. + level = logging.DEBUG +else: + # Hide stack traces of exceptions. + level = logging.CRITICAL + +logging.basicConfig(format=format, level=level) From cbcb7fd715be0f1efb98102302739e9d9f3ca08c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 15:59:16 +0200 Subject: [PATCH 1305/1539] Pass WEBSOCKETS_TESTS_TIMEOUT_FACTOR to tox. Previously, despite being declared in .github/workflows/tests.yml, it had no effect because tox insulates test runs from the environment. --- tox.ini | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 06003c85b..b00833e73 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = +env_list = py38 py39 py310 @@ -12,6 +12,7 @@ envlist = [testenv] commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} +pass_env = WEBSOCKETS_* [testenv:coverage] commands = From 8c4fd9c24a701b7050681f786b2918d205b91338 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 16:04:10 +0200 Subject: [PATCH 1306/1539] Remove superfluous `coverage erase`. `coverage run` starts clean unless `--append` is specified. --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index b00833e73..1edcfe261 100644 --- a/tox.ini +++ b/tox.ini @@ -16,7 +16,6 @@ pass_env = WEBSOCKETS_* [testenv:coverage] commands = - python -m coverage erase python -m coverage run --source {envsitepackagesdir}/websockets,tests -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 deps = coverage From 02b333829e385d4fef42fa0565996adcfea653b3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 08:59:51 +0200 Subject: [PATCH 1307/1539] Make Protocol.receive_eof idempotent. This removes the need for keeping track of whether you called it or not, especially in an asyncio context where it may be called in eof_received or in connection_lost. --- src/websockets/protocol.py | 5 +++-- tests/test_protocol.py | 8 ++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2f5542f6e..7f2b45c74 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -270,10 +270,11 @@ def receive_eof(self) -> None: - You aren't expected to call :meth:`events_received`; it won't return any new events. - Raises: - EOFError: If :meth:`receive_eof` was called earlier. + :meth:`receive_eof` is idempotent. """ + if self.reader.eof: + return self.reader.feed_eof() next(self.parser) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e1527525b..7f1276bb2 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1590,18 +1590,14 @@ def test_client_receives_eof_after_eof(self): client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + client.receive_eof() # this is idempotent def test_server_receives_eof_after_eof(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + server.receive_eof() # this is idempotent class TCPCloseTests(ProtocolTestCase): From 3ad92b50515e5c83344f0771a34e8d9e7cd8ff4e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 15:28:18 +0200 Subject: [PATCH 1308/1539] Don't specify the encoding when it's utf-8. --- src/websockets/frames.py | 2 +- src/websockets/legacy/protocol.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 0da676432..af56d3f8f 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -447,7 +447,7 @@ def parse(cls, data: bytes) -> Close: """ if len(data) >= 2: (code,) = struct.unpack("!H", data[:2]) - reason = data[2:].decode("utf-8") + reason = data[2:].decode() close = cls(code, reason) close.check() return close diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 57cb4e770..c28bdcf48 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1036,7 +1036,7 @@ async def read_message(self) -> Data | None: # Shortcut for the common case - no fragmentation if frame.fin: - return frame.data.decode("utf-8") if text else frame.data + return frame.data.decode() if text else frame.data # 5.4. Fragmentation fragments: list[Data] = [] From d0fd9cf61432a8dcf1cd639139ebb57d4b522c01 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Aug 2024 08:08:35 +0200 Subject: [PATCH 1309/1539] Improve tests for sync implementation slightly. --- src/websockets/sync/connection.py | 3 +- tests/sync/connection.py | 4 +- tests/sync/test_connection.py | 76 +++++++++++++++++++++---------- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 33d8299e2..2bcb3aa0e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -373,6 +373,7 @@ def send(self, message: Data | Iterable[Data]) -> None: except RuntimeError: # We didn't start sending a fragmented message. + # The connection is still usable. raise except Exception: @@ -756,7 +757,7 @@ def set_recv_exc(self, exc: BaseException | None) -> None: """ assert self.protocol_mutex.locked() - if self.recv_exc is None: + if self.recv_exc is None: # pragma: no branch self.recv_exc = exc def close_socket(self) -> None: diff --git a/tests/sync/connection.py b/tests/sync/connection.py index 89d4909ee..9c8bacea0 100644 --- a/tests/sync/connection.py +++ b/tests/sync/connection.py @@ -8,7 +8,7 @@ class InterceptingConnection(Connection): """ Connection subclass that can intercept outgoing packets. - By interfacing with this connection, you can simulate network conditions + By interfacing with this connection, we simulate network conditions affecting what the component being tested receives during a test. """ @@ -80,7 +80,7 @@ def drop_eof_sent(self): class InterceptingSocket: """ - Socket wrapper that intercepts calls to sendall and shutdown. + Socket wrapper that intercepts calls to ``sendall()`` and ``shutdown()``. This is coupled to the implementation, which relies on these two methods. diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 953c8c253..88cbcd669 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -246,13 +246,15 @@ def test_recv_streaming_connection_closed_ok(self): """recv_streaming raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() with self.assertRaises(ConnectionClosedOK): - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") def test_recv_streaming_connection_closed_error(self): """recv_streaming raises ConnectionClosedError after an error.""" self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") def test_recv_streaming_during_recv(self): """recv_streaming raises RuntimeError when called concurrently with recv.""" @@ -260,7 +262,8 @@ def test_recv_streaming_during_recv(self): recv_thread.start() with self.assertRaises(RuntimeError) as raised: - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") self.assertEqual( str(raised.exception), "cannot call recv_streaming while another thread " @@ -278,7 +281,8 @@ def test_recv_streaming_during_recv_streaming(self): recv_streaming_thread.start() with self.assertRaises(RuntimeError) as raised: - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") self.assertEqual( str(raised.exception), r"cannot call recv_streaming while another thread " @@ -374,7 +378,7 @@ def test_send_empty_iterable(self): """send does nothing when called with an empty iterable.""" self.connection.send([]) self.connection.close() - self.assertEqual(list(iter(self.remote_connection)), []) + self.assertEqual(list(self.remote_connection), []) def test_send_mixed_iterable(self): """send raises TypeError when called with an iterable of inconsistent types.""" @@ -437,7 +441,7 @@ def test_close_waits_for_connection_closed(self): def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" - with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() with self.assertRaises(ConnectionClosedError) as raised: @@ -464,6 +468,10 @@ def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) def test_close_waits_for_recv(self): + # The sync implementation doesn't have a buffer for incoming messsages. + # It requires reading incoming frames until the close frame is reached. + # This behavior — close() blocks until recv() is called — is less than + # ideal and inconsistent with the asyncio implementation. self.remote_connection.send("😀") close_thread = threading.Thread(target=self.connection.close) @@ -547,6 +555,25 @@ def closer(): close_thread.join() + def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + + def closer(): + time.sleep(MS) + self.connection.close() + + close_thread = threading.Thread(target=closer) + close_thread.start() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + close_thread.join() + def test_close_during_send(self): """close fails the connection when called concurrently with send.""" close_gate = threading.Event() @@ -599,42 +626,45 @@ def test_ping_explicit_binary(self): self.connection.ping(b"ping") self.assertFrameSent(Frame(Opcode.PING, b"ping")) - def test_ping_duplicate_payload(self): - """ping rejects the same payload until receiving the pong.""" - with self.remote_connection.protocol_mutex: # block response to ping - pong_waiter = self.connection.ping("idem") - with self.assertRaises(RuntimeError) as raised: - self.connection.ping("idem") - self.assertEqual( - str(raised.exception), - "already waiting for a pong with the same data", - ) - self.assertTrue(pong_waiter.wait(MS)) - self.connection.ping("idem") # doesn't raise an exception - def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") - self.assertFalse(pong_waiter.wait(MS)) self.remote_connection.pong("this") self.assertTrue(pong_waiter.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.remote_connection.pong("that") self.assertFalse(pong_waiter.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong with the same payload as a later ping.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") self.assertTrue(pong_waiter.wait(MS)) + def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + self.remote_connection.pong("idem") + self.assertTrue(pong_waiter.wait(MS)) + + self.connection.ping("idem") # doesn't raise an exception + # Test pong. def test_pong(self): From c92fba02db87a88af54a47e7f5bae050587490dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 23:11:11 +0200 Subject: [PATCH 1310/1539] Move asyncio compatibility to a new package. --- pyproject.toml | 4 ++-- src/websockets/asyncio/__init__.py | 0 src/websockets/{legacy => asyncio}/async_timeout.py | 0 src/websockets/{legacy => asyncio}/compatibility.py | 0 src/websockets/legacy/client.py | 2 +- src/websockets/legacy/protocol.py | 2 +- src/websockets/legacy/server.py | 2 +- tests/legacy/test_client_server.py | 2 +- tests/maxi_cov.py | 6 ++++-- 9 files changed, 10 insertions(+), 8 deletions(-) create mode 100644 src/websockets/asyncio/__init__.py rename src/websockets/{legacy => asyncio}/async_timeout.py (100%) rename src/websockets/{legacy => asyncio}/compatibility.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 2367849ca..de8acd6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ branch = true omit = [ # */websockets matches src/websockets and .tox/**/site-packages/websockets "*/websockets/__main__.py", - "*/websockets/legacy/async_timeout.py", - "*/websockets/legacy/compatibility.py", + "*/websockets/asyncio/async_timeout.py", + "*/websockets/asyncio/compatibility.py", "tests/maxi_cov.py", ] diff --git a/src/websockets/asyncio/__init__.py b/src/websockets/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/asyncio/async_timeout.py similarity index 100% rename from src/websockets/legacy/async_timeout.py rename to src/websockets/asyncio/async_timeout.py diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/asyncio/compatibility.py similarity index 100% rename from src/websockets/legacy/compatibility.py rename to src/websockets/asyncio/compatibility.py diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b0e15b543..d1d8d5608 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -16,6 +16,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike from ..exceptions import ( InvalidHandshake, @@ -40,7 +41,6 @@ from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .compatibility import asyncio_timeout from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c28bdcf48..120ff8e73 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -23,6 +23,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers from ..exceptions import ( ConnectionClosed, @@ -49,7 +50,6 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import asyncio_timeout from .framing import Frame diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index f4442fecc..208ffa780 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -21,6 +21,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( AbortHandshake, @@ -42,7 +43,6 @@ from ..http import USER_AGENT from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .compatibility import asyncio_timeout from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 0c3f22156..b5c5d726a 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -14,6 +14,7 @@ import urllib.request import warnings +from websockets.asyncio.compatibility import asyncio_timeout from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -30,7 +31,6 @@ from websockets.frames import CloseCode from websockets.http import USER_AGENT from websockets.legacy.client import * -from websockets.legacy.compatibility import asyncio_timeout from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index bc4a44e8c..83686c3d3 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -52,8 +52,9 @@ def get_mapping(src_dir="src"): os.path.relpath(src_file, src_dir) for src_file in sorted(src_files) if "legacy" not in os.path.dirname(src_file) - if os.path.basename(src_file) != "__init__.py" + and os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" + and os.path.basename(src_file) != "async_timeout.py" and os.path.basename(src_file) != "compatibility.py" ] test_files = [ @@ -102,7 +103,8 @@ def get_ignored_files(src_dir="src"): "*/websockets/typing.py", # We don't test compatibility modules with previous versions of Python # or websockets (import locations). - "*/websockets/*/compatibility.py", + "*/websockets/asyncio/async_timeout.py", + "*/websockets/asyncio/compatibility.py", "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. From 9f8f2f27218e4dc7ad4126109e6ffe012946b71b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 23:12:28 +0200 Subject: [PATCH 1311/1539] Add asyncio message assembler. --- src/websockets/asyncio/compatibility.py | 20 +- src/websockets/asyncio/messages.py | 283 ++++++++++++++ src/websockets/sync/messages.py | 4 +- tests/asyncio/__init__.py | 0 tests/asyncio/test_messages.py | 471 ++++++++++++++++++++++++ tests/asyncio/utils.py | 5 + 6 files changed, 777 insertions(+), 6 deletions(-) create mode 100644 src/websockets/asyncio/messages.py create mode 100644 tests/asyncio/__init__.py create mode 100644 tests/asyncio/test_messages.py create mode 100644 tests/asyncio/utils.py diff --git a/src/websockets/asyncio/compatibility.py b/src/websockets/asyncio/compatibility.py index 6bd01e70d..390f00ac7 100644 --- a/src/websockets/asyncio/compatibility.py +++ b/src/websockets/asyncio/compatibility.py @@ -3,10 +3,22 @@ import sys -__all__ = ["asyncio_timeout"] +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout"] if sys.version_info[:2] >= (3, 11): - from asyncio import timeout as asyncio_timeout # noqa: F401 -else: - from .async_timeout import timeout as asyncio_timeout # noqa: F401 + TimeoutError = TimeoutError + aiter = aiter + anext = anext + from asyncio import timeout as asyncio_timeout + +else: # Python < 3.11 + from asyncio import TimeoutError + + def aiter(async_iterable): + return type(async_iterable).__aiter__(async_iterable) + + async def anext(async_iterator): + return await type(async_iterator).__anext__(async_iterator) + + from .async_timeout import timeout as asyncio_timeout diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py new file mode 100644 index 000000000..2a9c4d37d --- /dev/null +++ b/src/websockets/asyncio/messages.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import asyncio +import codecs +import collections +from typing import ( + Any, + AsyncIterator, + Callable, + Generic, + Iterable, + TypeVar, +) + +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +T = TypeVar("T") + + +class SimpleQueue(Generic[T]): + """ + Simplified version of :class:`asyncio.Queue`. + + Provides only the subset of functionality needed by :class:`Assembler`. + + """ + + def __init__(self) -> None: + self.loop = asyncio.get_running_loop() + self.get_waiter: asyncio.Future[None] | None = None + self.queue: collections.deque[T] = collections.deque() + + def __len__(self) -> int: + return len(self.queue) + + def put(self, item: T) -> None: + """Put an item into the queue without waiting.""" + self.queue.append(item) + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_result(None) + + async def get(self) -> T: + """Remove and return an item from the queue, waiting if necessary.""" + if not self.queue: + if self.get_waiter is not None: + raise RuntimeError("get is already running") + self.get_waiter = self.loop.create_future() + try: + await self.get_waiter + finally: + self.get_waiter.cancel() + self.get_waiter = None + return self.queue.popleft() + + def reset(self, items: Iterable[T]) -> None: + """Put back items into an empty, idle queue.""" + assert self.get_waiter is None, "cannot reset() while get() is running" + assert not self.queue, "cannot reset() while queue isn't empty" + self.queue.extend(items) + + def abort(self) -> None: + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_exception(EOFError("stream of frames ended")) + # Clear the queue to avoid storing unnecessary data in memory. + self.queue.clear() + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + # coverage reports incorrectly: "line NN didn't jump to the function exit" + def __init__( # pragma: no cover + self, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming messages. Each item is a queue of frames. + self.frames: SimpleQueue[Frame] = SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + self.high = 16 + self.low = 4 + self.paused = False + self.pause = pause + self.resume = resume + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get() or get_iter() is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + + # First frame + try: + frame = await self.frames.get() + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + self.get_in_progress = False + + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`RuntimeError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get() or get_iter() is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + + # First frame + try: + frame = await self.frames.get() + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + # We cannot handle asyncio.CancelledError because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. Future calls to + # get() or get_iter() will raise RuntimeError. + frame = await self.frames.get() + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.frames.put(frame) + self.maybe_pause() + + def get_limits(self) -> tuple[int, int]: + """Return low and high water marks for flow control.""" + return self.low, self.high + + def set_limits(self, low: int = 4, high: int = 16) -> None: + """Configure low and high water marks for flow control.""" + self.low, self.high = low, high + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Check for "> high" to support high = 0 + if len(self.frames) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Check for "<= low" to support low = 0 + if len(self.frames) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.frames.abort() diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 6cbff2595..ff90345ac 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -85,7 +85,7 @@ def get(self, timeout: float | None = None) -> Data: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get or get_iter is already running") + raise RuntimeError("get() or get_iter() is already running") self.get_in_progress = True @@ -144,7 +144,7 @@ def get_iter(self) -> Iterator[Data]: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get or get_iter is already running") + raise RuntimeError("get() or get_iter() is already running") chunks = self.chunks self.chunks = [] diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py new file mode 100644 index 000000000..c8a2d7cd5 --- /dev/null +++ b/tests/asyncio/test_messages.py @@ -0,0 +1,471 @@ +import asyncio +import unittest +import unittest.mock + +from websockets.asyncio.compatibility import aiter, anext +from websockets.asyncio.messages import * +from websockets.asyncio.messages import SimpleQueue +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame + +from .utils import alist + + +class SimpleQueueTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.queue = SimpleQueue() + + async def test_len(self): + """__len__ returns queue length.""" + self.assertEqual(len(self.queue), 0) + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + await self.queue.get() + self.assertEqual(len(self.queue), 0) + + async def test_put_then_get(self): + """get returns an item that is already put.""" + self.queue.put(42) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_get_then_put(self): + """get returns an item when it is put.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.put(42) + item = await getter_task + self.assertEqual(item, 42) + + async def test_get_concurrently(self): + """get cannot be called concurrently with itself.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + with self.assertRaises(RuntimeError): + await self.queue.get() + getter_task.cancel() + + async def test_reset(self): + """reset sets the content of the queue.""" + self.queue.reset([42]) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_abort(self): + """abort throws an exception in get.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.abort() + with self.assertRaises(EOFError): + await getter_task + + async def test_abort_clears_queue(self): + """abort clears buffered data from the queue.""" + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + self.queue.abort() + self.assertEqual(len(self.queue), 0) + + +class AssemblerTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(pause=self.pause, resume=self.resume) + self.assembler.set_limits(low=1, high=2) + + # Test get + + async def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + async def test_cancel_get_before_first_frame(self): + """get can be canceled safely before reading the first frame.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_cancel_get_after_first_frame(self): + """get can be canceled safely after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + # Test get_iter + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await getter_task + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await getter_task + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + iterator = aiter(self.assembler.get_iter()) + + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + async def test_cancel_get_iter_before_first_frame(self): + """get_iter can be canceled safely before reading the first frame.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_cancel_get_iter_after_first_frame(self): + """get cannot be canceled after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + + # Test put + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + # Test termination + + async def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + async def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + self.assembler.close() # let task terminate + + # Test getting and setting limits + + async def test_get_limits(self): + """get_limits returns low and high water marks.""" + low, high = self.assembler.get_limits() + self.assertEqual(low, 1) + self.assertEqual(high, 2) + + async def test_set_limits(self): + """set_limits changes low and high water marks.""" + self.assembler.set_limits(low=2, high=4) + low, high = self.assembler.get_limits() + self.assertEqual(low, 2) + self.assertEqual(high, 4) diff --git a/tests/asyncio/utils.py b/tests/asyncio/utils.py new file mode 100644 index 000000000..a611bfc4b --- /dev/null +++ b/tests/asyncio/utils.py @@ -0,0 +1,5 @@ +async def alist(async_iterable): + items = [] + async for item in async_iterable: + items.append(item) + return items From 4a981688198f91385281b8c8e1cdfc0197d43bf5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Apr 2023 08:37:40 +0200 Subject: [PATCH 1312/1539] Add new asyncio-based implementation. --- docs/project/changelog.rst | 15 +- docs/reference/index.rst | 12 + docs/reference/new-asyncio/client.rst | 53 ++ docs/reference/new-asyncio/common.rst | 43 ++ docs/reference/new-asyncio/server.rst | 72 ++ src/websockets/asyncio/client.py | 331 +++++++++ src/websockets/asyncio/compatibility.py | 12 +- src/websockets/asyncio/connection.py | 883 ++++++++++++++++++++++ src/websockets/asyncio/server.py | 772 +++++++++++++++++++ tests/asyncio/client.py | 33 + tests/asyncio/connection.py | 115 +++ tests/asyncio/server.py | 50 ++ tests/asyncio/test_client.py | 306 ++++++++ tests/asyncio/test_connection.py | 948 ++++++++++++++++++++++++ tests/asyncio/test_server.py | 525 +++++++++++++ 15 files changed, 4165 insertions(+), 5 deletions(-) create mode 100644 docs/reference/new-asyncio/client.rst create mode 100644 docs/reference/new-asyncio/common.rst create mode 100644 docs/reference/new-asyncio/server.rst create mode 100644 src/websockets/asyncio/client.py create mode 100644 src/websockets/asyncio/connection.py create mode 100644 src/websockets/asyncio/server.py create mode 100644 tests/asyncio/client.py create mode 100644 tests/asyncio/connection.py create mode 100644 tests/asyncio/server.py create mode 100644 tests/asyncio/test_client.py create mode 100644 tests/asyncio/test_connection.py create mode 100644 tests/asyncio/test_server.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index fd186a5fc..108b7c9c0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,7 +52,7 @@ Backwards-incompatible changes async def handler(request, path): ... - You should switch to the recommended pattern since 10.1:: + You should switch to the pattern recommended since version 10.1:: async def handler(request): path = request.path # only if handler() uses the path argument @@ -61,6 +61,16 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 11.0 introduces a new :mod:`asyncio` implementation. + :class: important + + This new implementation is intended to be a drop-in replacement for the + current implementation. It will become the default in a future release. + Please try it and report any issue that you encounter! + + See :func:`websockets.asyncio.client.connect` and + :func:`websockets.asyncio.server.serve` for details. + * Validated compatibility with Python 3.12. 12.0 @@ -175,7 +185,8 @@ New features It is particularly suited to client applications that establish only one connection. It may be used for servers handling few connections. - See :func:`~sync.client.connect` and :func:`~sync.server.serve` for details. + See :func:`websockets.sync.client.connect` and + :func:`websockets.sync.server.serve` for details. * Added ``open_timeout`` to :func:`~server.serve`. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 0b80f087a..2486ac564 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -26,6 +26,18 @@ clients concurrently. asyncio/server asyncio/client +:mod:`asyncio` (new) +-------------------- + +This is a rewrite of the :mod:`asyncio` implementation. It will become the +default in the future. + +.. toctree:: + :titlesonly: + + new-asyncio/server + new-asyncio/client + :mod:`threading` ---------------- diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst new file mode 100644 index 000000000..552d83b2f --- /dev/null +++ b/docs/reference/new-asyncio/client.rst @@ -0,0 +1,53 @@ +Client (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.client + +Opening a connection +-------------------- + +.. autofunction:: connect + :async: + +.. autofunction:: unix_connect + :async: + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst new file mode 100644 index 000000000..ba23552dc --- /dev/null +++ b/docs/reference/new-asyncio/common.rst @@ -0,0 +1,43 @@ +:orphan: + +Both sides (:mod:`asyncio` - new) +================================= + +.. automodule:: websockets.asyncio.connection + +.. autoclass:: Connection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst new file mode 100644 index 000000000..f3446fb80 --- /dev/null +++ b/docs/reference/new-asyncio/server.rst @@ -0,0 +1,72 @@ +Server (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.server + +Creating a server +----------------- + +.. autofunction:: serve + :async: + +.. autofunction:: unix_serve + :async: + +Running a server +---------------- + +.. autoclass:: WebSocketServer + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + + .. autoattribute:: sockets + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py new file mode 100644 index 000000000..040d68ece --- /dev/null +++ b/src/websockets/asyncio/client.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +import asyncio +from types import TracebackType +from typing import Any, Generator, Sequence + +from ..client import ClientProtocol +from ..datastructures import HeadersLike +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Response +from ..protocol import CONNECTING, Event +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import parse_uri +from .compatibility import TimeoutError, asyncio_timeout +from .connection import Connection + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + + +class ClientConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ClientProtocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.response_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers["User-Agent"] = user_agent_header + self.protocol.send_request(self.request) + + # May raise CancelledError if open_timeout is exceeded. + await self.response_rcvd + + if self.response is None: + raise ConnectionError("connection closed during handshake") + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.response_rcvd.done(): + self.response_rcvd.set_result(None) + + +# This is spelled in lower case because it's exposed as a callable in the API. +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as an asynchronous context manager:: + + async with websockets.asyncio.client.connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + Args: + uri: URI of the WebSocket server. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings. + When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS + context is created with :func:`~ssl.create_default_context`. + + * You can set ``server_hostname`` to override the host name from ``uri`` in + the TLS handshake. + + * You can set ``host`` and ``port`` to connect to a different host and port + from those found in ``uri``. This only changes the destination of the TCP + connection. The host name from ``uri`` is still used in the TLS handshake + for secure connections and in the ``Host`` header. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_connection` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_connection` method) to create a suitable + client socket and customize it. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + def __init__( + self, + uri: str, + *, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to loop.create_connection + **kwargs: Any, + ) -> None: + + wsuri = parse_uri(uri) + + if wsuri.secure: + kwargs.setdefault("ssl", True) + kwargs.setdefault("server_hostname", wsuri.host) + if kwargs.get("ssl") is None: + raise TypeError("ssl=None is incompatible with a wss:// URI") + else: + if kwargs.get("ssl") is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ClientConnection + + def factory() -> ClientConnection: + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ClientProtocol( + wsuri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_connection = loop.create_unix_connection(factory, **kwargs) + else: + if kwargs.get("sock") is None: + kwargs.setdefault("host", wsuri.host) + kwargs.setdefault("port", wsuri.port) + self._create_connection = loop.create_connection(factory, **kwargs) + + self._handshake_args = ( + additional_headers, + user_agent_header, + ) + + self._open_timeout = open_timeout + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.connection.close() + + # ... = await connect(...) + + def __await__(self) -> Generator[Any, None, ClientConnection]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> ClientConnection: + try: + async with asyncio_timeout(self._open_timeout): + _transport, self.connection = await self._create_connection + try: + await self.connection.handshake(*self._handshake_args) + except (Exception, asyncio.CancelledError): + self.connection.transport.close() + raise + else: + return self.connection + except TimeoutError: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during handshake") from None + + # ... = yield from connect(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_connect( + path: str | None = None, + uri: str | None = None, + **kwargs: Any, +) -> connect: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl`` argument is provided, to + ``wss://localhost/``. + + """ + if uri is None: + if kwargs.get("ssl") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) diff --git a/src/websockets/asyncio/compatibility.py b/src/websockets/asyncio/compatibility.py index 390f00ac7..e17000069 100644 --- a/src/websockets/asyncio/compatibility.py +++ b/src/websockets/asyncio/compatibility.py @@ -3,14 +3,17 @@ import sys -__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout"] +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"] if sys.version_info[:2] >= (3, 11): TimeoutError = TimeoutError aiter = aiter anext = anext - from asyncio import timeout as asyncio_timeout + from asyncio import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) else: # Python < 3.11 from asyncio import TimeoutError @@ -21,4 +24,7 @@ def aiter(async_iterable): async def anext(async_iterator): return await type(async_iterator).__anext__(async_iterator) - from .async_timeout import timeout as asyncio_timeout + from .async_timeout import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py new file mode 100644 index 000000000..550c0ac97 --- /dev/null +++ b/src/websockets/asyncio/connection.py @@ -0,0 +1,883 @@ +from __future__ import annotations + +import asyncio +import collections +import contextlib +import logging +import random +import struct +import uuid +from types import TracebackType +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Iterable, + Mapping, + cast, +) + +from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .compatibility import TimeoutError, aiter, anext, asyncio_timeout_at +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection(asyncio.Protocol): + """ + :mod:`asyncio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.asyncio.client.ClientConnection` or + :class:`~websockets.asyncio.server.ServerConnection`. + + """ + + def __init__( + self, + protocol: Protocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol = protocol + self.close_timeout = close_timeout + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Event loop running this connection. + self.loop = asyncio.get_running_loop() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages: Assembler # initialized in connection_made + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Protect sending fragmented messages. + self.fragmented_send_waiter: asyncio.Future[None] | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + + # Adapted from asyncio.FlowControlMixin + self.paused: bool = False + self.drain_waiters: collections.deque[asyncio.Future[None]] = ( + collections.deque() + ) + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.transport.get_extra_info("peername") + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.close() + else: + await self.close(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + async def recv(self) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get() + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise :exc:`RuntimeError`, making the + connection unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`close`. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(): + yield frame + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If the connection busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.fragmented_send_waiter is not None: + await asyncio.shield(self.fragmented_send_waiter) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + else: + raise TypeError("data must be bytes, str, iterable, or async iterable") + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.fragmented_send_waiter is not None: + self.protocol.fail(1011, "close during fragmented message") + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: Data | None = None) -> Awaitable[None]: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. + + :: + + pong_waiter = await ws.ping() + # only if you want to wait for the corresponding pong + latency = await pong_waiter + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if data is not None: + data = prepare_ctrl(data) + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pong_waiters: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pong_waiters: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = self.loop.create_future() + # The event loop's default clock is time.monotonic(). Its resolution + # is a bit low on Windows (~16ms). We cannot use time.perf_counter() + # because it doesn't count time elapsed while the process sleeps. + ping_timestamp = self.loop.time() + self.pong_waiters[data] = (pong_waiter, ping_timestamp) + self.protocol.send_ping(data) + return pong_waiter + + async def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + data = prepare_ctrl(data) + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pong_waiters: + return + + pong_timestamp = self.loop.time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + ping_ids.append(ping_id) + pong_waiter.set_result(pong_timestamp - ping_timestamp) + if ping_id == data: + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pong_waiters. + for ping_id in ping_ids: + del self.pong_waiters[ping_id] + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.protocol.state is CLOSED + exc = self.protocol.close_exc + + for pong_waiter, _ping_timestamp in self.pong_waiters.values(): + pong_waiter.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + pong_waiter.cancel() + + self.pong_waiters.clear() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the transport and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, RuntimeError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING), self.close_deadline is still None. + if self.close_timeout is not None: + assert self.close_deadline is None + self.close_deadline = self.loop.time() + self.close_timeout + # Write outgoing data to the socket and enforce flow control. + try: + self.send_data() + await self.drain() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + try: + async with asyncio_timeout_at(self.close_deadline): + await asyncio.shield(self.connection_lost_waiter) + except TimeoutError: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + self.close_transport() + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + Raises: + OSError: When a socket operations fails. + + """ + for data in self.protocol.data_to_send(): + if data: + self.transport.write(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if self.transport.can_write_eof(): + if self.debug: + self.logger.debug("x half-closing TCP connection") + # write_eof() doesn't document which exceptions it raises. + # OSError is plausible. uvloop can raise RuntimeError here. + try: + self.transport.write_eof() + except (OSError, RuntimeError): # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + def close_transport(self) -> None: + """ + Close transport and message assembler. + + """ + self.transport.close() + self.recv_messages.close() + + # asyncio.Protocol methods + + # Connection callbacks + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.recv_messages = Assembler( + pause=self.transport.pause_reading, + resume=self.transport.resume_reading, + ) + + def connection_lost(self, exc: Exception | None) -> None: + self.protocol.receive_eof() # receive_eof is idempotent + self.recv_messages.close() + self.set_recv_exc(exc) + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + self.abort_pings() + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: # pragma: no cover + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + # Flow control callbacks + + def pause_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert not self.paused + self.paused = True + + def resume_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert self.paused + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + waiter.set_result(None) + + async def drain(self) -> None: # pragma: no cover + # We don't check if the connection is closed because we call drain() + # immediately after write() and write() would fail in that case. + + # Adapted from asyncio.streams.StreamWriter + # Yield to the event loop so that connection_lost() may be called. + if self.transport.is_closing(): + await asyncio.sleep(0) + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: + waiter = self.loop.create_future() + self.drain_waiters.append(waiter) + try: + await waiter + finally: + self.drain_waiters.remove(waiter) + + # Streaming protocol callbacks + + def data_received(self, data: bytes) -> None: + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the transport. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + self.set_recv_exc(exc) + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + def eof_received(self) -> None: + # Feed the end of the data stream to the connection. + self.protocol.receive_eof() + + # This isn't expected to generate events. + assert not self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it shouldn't raise errors. + self.send_data() + + # The WebSocket protocol has its own closing handshake: endpoints close + # the TCP or TLS connection after sending and receiving a close frame. + # As a consequence, they never need to write after receiving EOF, so + # there's no reason to keep the transport open by returning True. + # Besides, that doesn't work on TLS connections. diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py new file mode 100644 index 000000000..aa175f775 --- /dev/null +++ b/src/websockets/asyncio/server.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import asyncio +import http +import logging +import socket +import sys +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Generator, + Iterable, + Sequence, +) + +from websockets.frames import CloseCode + +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Request, Response +from ..protocol import CONNECTING, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, Subprotocol +from .compatibility import asyncio_timeout +from .connection import Connection + + +__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] + + +class ServerConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + server: Server that manages this connection. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ServerProtocol, + server: WebSocketServer, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.server = server + self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + # May raise CancelledError if open_timeout is exceeded. + await self.request_rcvd + + if self.request is None: + raise ConnectionError("connection closed during handshake") + + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if self.server.is_serving(): + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header is not None: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + self.server.start_connection_handler(self) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.request_rcvd.done(): + self.request_rcvd.set_result(None) + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~asyncio.Server`. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + *, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + logger: LoggerLike | None = None, + ) -> None: + self.loop = asyncio.get_running_loop() + self.handler = handler + self.process_request = process_request + self.process_response = process_response + self.server_header = server_header + self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + # Keep track of active connections. + self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} + + # Task responsible for closing the server and terminating connections. + self.close_task: asyncio.Task[None] | None = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + + def wrap(self, server: asyncio.Server) -> None: + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + + """ + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + async def conn_handler(self, connection: ServerConnection) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller that can handle exceptions, + it attempts to log relevant ones. + + It guarantees that the TCP connection is closed before exiting. + + """ + try: + # On failure, handshake() closes the transport, raises an + # exception, and logs it. + async with asyncio_timeout(self.open_timeout): + await connection.handshake( + self.process_request, + self.process_response, + self.server_header, + ) + + try: + await self.handler(connection) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + await connection.close(CloseCode.INTERNAL_ERROR) + else: + await connection.close() + + except Exception: + # Don't leak connections on errors. + connection.transport.abort() + + finally: + # Registration is tied to the lifecycle of conn_handler() because + # the server waits for connection handlers to terminate, even if + # all connections are already closed. + del self.handlers[connection] + + def start_connection_handler(self, connection: ServerConnection) -> None: + """ + Register a connection with this server. + + """ + # The connection must be registered in self.handlers immediately. + # If it was registered in conn_handler(), a race condition could + # happen when closing the server after scheduling conn_handler() + # but before it starts executing. + self.handlers[connection] = self.loop.create_task(self.conn_handler(connection)) + + def close(self, close_connections: bool = True) -> None: + """ + Close the server. + + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) + + async def _close(self, close_connections: bool) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.server.close() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. + await asyncio.sleep(0) + + if close_connections: + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. + close_tasks = [ + asyncio.create_task(connection.close(1001)) + for connection in self.handlers + if connection.protocol.state is not CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait(close_tasks) + + # Wait until all TCP connections are closed. + await self.server.wait_closed() + + # Wait until all connection handlers terminate. + # asyncio.wait doesn't accept an empty first argument. + if self.handlers: + await asyncio.wait(self.handlers.values()) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + self.logger.info("server closed") + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~ServerConnection.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~ServerConnection.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + + """ + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: # pragma: no cover + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.start_serving`. + + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + + """ + await self.server.start_serving() + + async def serve_forever(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.serve_forever`. + + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + + """ + await self.server.serve_forever() + + @property + def sockets(self) -> Iterable[socket.socket]: + """ + See :attr:`asyncio.Server.sockets`. + + """ + return self.server.sockets + + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() + + +# This is spelled in lower case because it's exposed as a callable in the API. +class serve: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler`` coroutine. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + This coroutine returns a :class:`WebSocketServer` whose API mirrors + :class:`~asyncio.Server`. Treat it as an asynchronous context manager to + ensure that the server will be closed:: + + def handler(websocket): + ... + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with websockets.asyncio.server.serve(handler, host, port): + await stop + + Alternatively, call :meth:`~WebSocketServer.serve_forever` to serve requests + and cancel it to stop the server:: + + server = await websockets.asyncio.server.serve(handler, host, port) + await server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_server` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_server` method) to create a suitable server + socket and customize it. + + * You can set ``start_serving`` to ``False`` to start accepting connections + only after you call :meth:`~WebSocketServer.start_serving()` or + :meth:`~WebSocketServer.serve_forever()`. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + host: str | None = None, + port: int | None = None, + *, + # WebSocket + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + # Other keyword arguments are passed to loop.create_server + **kwargs: Any, + ) -> None: + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + self.server = WebSocketServer( + handler, + process_request=process_request, + process_response=process_response, + server_header=server_header, + open_timeout=open_timeout, + logger=logger, + ) + + if kwargs.get("ssl") is not None: + kwargs.setdefault("ssl_handshake_timeout", open_timeout) + if sys.version_info[:2] >= (3, 11): # pragma: no branch + kwargs.setdefault("ssl_shutdown_timeout", close_timeout) + + def factory() -> ServerConnection: + """ + Create an asyncio protocol for managing a WebSocket connection. + + """ + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + self.server, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_server = loop.create_unix_server(factory, **kwargs) + else: + # mypy cannot tell that kwargs must provide sock when port is None. + self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] + + # async with serve(...) as ...: ... + + async def __aenter__(self) -> WebSocketServer: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.server.close() + await self.server.wait_closed() + + # ... = await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketServer: + server = await self._create_server + self.server.wrap(server) + return self.server + + # ... = yield from serve(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_serve( + handler: Callable[[ServerConnection], Awaitable[None]], + path: str | None = None, + **kwargs: Any, +) -> Awaitable[WebSocketServer]: + """ + Create a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, unix=True, path=path, **kwargs) diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py new file mode 100644 index 000000000..e5826add7 --- /dev/null +++ b/tests/asyncio/client.py @@ -0,0 +1,33 @@ +import contextlib + +from websockets.asyncio.client import * +from websockets.asyncio.server import WebSocketServer + +from .server import get_server_host_port + + +__all__ = [ + "run_client", + "run_unix_client", +] + + +@contextlib.asynccontextmanager +async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): + if isinstance(wsuri_or_server, str): + wsuri = wsuri_or_server + else: + assert isinstance(wsuri_or_server, WebSocketServer) + if secure is None: + secure = "ssl" in kwargs + protocol = "wss" if secure else "ws" + host, port = get_server_host_port(wsuri_or_server) + wsuri = f"{protocol}://{host}:{port}{resource_name}" + async with connect(wsuri, **kwargs) as client: + yield client + + +@contextlib.asynccontextmanager +async def run_unix_client(path, **kwargs): + async with unix_connect(path, **kwargs) as client: + yield client diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py new file mode 100644 index 000000000..ad1c121bf --- /dev/null +++ b/tests/asyncio/connection.py @@ -0,0 +1,115 @@ +import asyncio +import contextlib + +from websockets.asyncio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def connection_made(self, transport): + super().connection_made(InterceptingTransport(transport)) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write is None + self.transport.delay_write = delay + try: + yield + finally: + self.transport.delay_write = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write_eof is None + self.transport.delay_write_eof = delay + try: + yield + finally: + self.transport.delay_write_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write + self.transport.drop_write = True + try: + yield + finally: + self.transport.drop_write = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write_eof + self.transport.drop_write_eof = True + try: + yield + finally: + self.transport.drop_write_eof = False + + +class InterceptingTransport: + """ + Transport wrapper that intercepts calls to ``write()`` and ``write_eof()``. + + This is coupled to the implementation, which relies on these two methods. + + Since ``write()`` and ``write_eof()`` are not coroutines, this effect is + achieved by scheduling writes at a later time, after the methods return. + This can easily result in out-of-order writes, which is unrealistic. + + """ + + def __init__(self, transport): + self.loop = asyncio.get_running_loop() + self.transport = transport + self.delay_write = None + self.delay_write_eof = None + self.drop_write = False + self.drop_write_eof = False + + def __getattr__(self, name): + return getattr(self.transport, name) + + def write(self, data): + if not self.drop_write: + if self.delay_write is not None: + self.loop.call_later(self.delay_write, self.transport.write, data) + else: + self.transport.write(data) + + def write_eof(self): + if not self.drop_write_eof: + if self.delay_write_eof is not None: + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + else: + self.transport.write_eof() diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py new file mode 100644 index 000000000..0fe20dc65 --- /dev/null +++ b/tests/asyncio/server.py @@ -0,0 +1,50 @@ +import asyncio +import contextlib +import socket + +from websockets.asyncio.server import * + + +def get_server_host_port(server): + for sock in server.sockets: + if sock.family == socket.AF_INET: # pragma: no branch + return sock.getsockname() + raise AssertionError("expected at least one IPv4 socket") + + +async def eval_shell(ws): + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + + +class EvalShellMixin: + async def assertEval(self, client, expr, value): + await client.send(expr) + self.assertEqual(await client.recv(), value) + + +async def crash(ws): + raise RuntimeError + + +async def do_nothing(ws): + pass + + +async def keep_running(ws): + delay = float(await ws.recv()) + await ws.close() + await asyncio.sleep(delay) + + +@contextlib.asynccontextmanager +async def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): + async with serve(handler, host, port, **kwargs) as server: + yield server + + +@contextlib.asynccontextmanager +async def run_unix_server(path, handler=eval_shell, **kwargs): + async with unix_serve(handler, path, **kwargs) as server: + yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py new file mode 100644 index 000000000..aab65cd2e --- /dev/null +++ b/tests/asyncio/test_client.py @@ -0,0 +1,306 @@ +import asyncio +import socket +import ssl +import unittest + +from websockets.asyncio.client import * +from websockets.asyncio.compatibility import TimeoutError +from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.extensions.permessage_deflate import PerMessageDeflate + +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path +from .client import run_client, run_unix_client +from .server import do_nothing, get_server_host_port, run_server, run_unix_server + + +class ClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with run_server() as server: + with socket.create_connection(get_server_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to the right socket. + async with run_client("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with run_server() as server: + async with run_client( + server, additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with run_server() as server: + async with run_client(server, compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with run_server() as server: + async with run_client( + server, create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + async def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + async with run_client("http://localhost"): # invalid scheme + self.fail("did not raise") + + async def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + async with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + async def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server( + do_nothing, process_response=remove_accept_header + ) as server: + with self.assertRaises(InvalidHandshake) as raised: + async with run_client(server, close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + + async def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + gate = asyncio.get_running_loop().create_future() + + async def stall_connection(self, request): + await gate + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server(do_nothing, process_request=stall_connection) as server: + try: + with self.assertRaises(TimeoutError) as raised: + async with run_client(server, open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) + finally: + gate.set_result(None) + + async def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + def close_connection(self, request): + self.close_transport() + + async with run_server(process_request=close_connection) as server: + with self.assertRaises(ConnectionError) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connection closed during handshake", + ) + + +class SecureClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate isn't trusted system-wide. + async with run_client(server, secure=True): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + + async def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # This hostname isn't included in the test certificate. + async with run_client( + server, ssl=CLIENT_CONTEXT, server_hostname="invalid" + ): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception), + ) + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_set_host_header(self): + """Client sets the Host header to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path, uri="ws://overridden/") as client: + self.assertEqual(client.request.headers["Host"], "overridden") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + +class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_secure_uri_without_ssl(self): + """Client rejects no ssl when URI is secure.""" + with self.assertRaises(TypeError) as raised: + await connect("wss://localhost/", ssl=None) + self.assertEqual( + str(raised.exception), + "ssl=None is incompatible with a wss:// URI", + ) + + async def test_unix_without_path_or_sock(self): + """Unix client requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_connect() + self.assertEqual( + str(raised.exception), + "no path and sock were specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix client rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_connect(path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py new file mode 100644 index 000000000..a8b3980b4 --- /dev/null +++ b/tests/asyncio/test_connection.py @@ -0,0 +1,948 @@ +import asyncio +import contextlib +import logging +import socket +import unittest +import uuid +from unittest.mock import patch + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout +from websockets.asyncio.connection import * +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol + +from ..protocol import RecordingProtocol +from ..utils import MS +from .connection import InterceptingConnection +from .utils import alist + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + loop = asyncio.get_running_loop() + socket_, remote_socket = socket.socketpair() + self.transport, self.connection = await loop.create_connection( + lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), + sock=socket_, + ) + self.remote_transport, self.remote_connection = await loop.create_connection( + lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), + sock=remote_socket, + ) + + async def asyncTearDown(self): + await self.remote_connection.close() + await self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + # Run the event loop twice for consistency with assertFrameSent. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. + + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + aiterator = aiter(self.connection) + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(aiterator) + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnnectionClosedError after an error.""" + aiterator = aiter(self.connection) + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(aiterator) + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_during_recv(self): + """recv raises RuntimeError when called concurrently with itself.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_during_recv_streaming(self): + """recv raises RuntimeError when called concurrently with recv_streaming.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_cancellation_before_receiving(self): + """recv can be cancelled before receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv cannot be cancelled after receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the complete message. + gate.set_result(None) + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. + + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises RuntimeError when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises RuntimeError when called concurrently with itself.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be cancelled before receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be cancelled after receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + gate.set_result(None) + # Running recv_streaming again fails. + with self.assertRaises(RuntimeError): + await alist(self.connection.recv_streaming()) + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_while_send_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an Iterable. + self.connection.pause_writing() + asyncio.create_task(self.connection.send(["⏳", "⌛️"])) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_while_send_async_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + self.connection.pause_writing() + + async def fragments(): + yield "⏳" + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_during_send_async(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + gate.set_result(None) + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test close. + + async def test_close(self): + """close sends a close frame.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + await self.connection.close(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_close_frame(self): + """close without timeout waits for a close frame (then EOF) before returning.""" + self.connection.close_timeout = None + + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_connection_closed(self): + """close without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.drop_eof_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + # Remove socket.timeout when dropping Python < 3.10. + self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + + async def test_close_does_not_wait_for_recv(self): + # The asyncio implementation has a buffer for incoming messages. Closing + # the connection discards buffered messages. This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. + await self.remote_connection.send("😀") + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.close() + await self.assertNoFrameSent() + + async def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(MS) + await self.connection.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await recv_task + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + + asyncio.create_task(self.connection.close()) + await asyncio.sleep(MS) + + gate.set_result(None) + + with self.assertRaises(ConnectionClosedError) as raised: + await send_task + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test wait_closed. + + async def test_wait_closed(self): + """wait_closed waits for the connection to close.""" + wait_closed_task = asyncio.create_task(self.connection.wait_closed()) + await asyncio.sleep(0) # let the event loop start wait_closed_task + self.assertFalse(wait_closed_task.done()) + await self.connection.close() + self.assertTrue(wait_closed_task.done()) + + # Test ping. + + @patch("random.getrandbits") + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.return_value = 1918987876 + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("this") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong with the same payload as a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + async with asyncio_timeout(MS): + await pong_waiter + + await self.connection.ping("idem") # doesn't raise an exception + + # Test pong. + + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234) + ) + async def test_local_address(self, get_extra_info): + """Connection provides a local_address attribute.""" + self.assertEqual(self.connection.local_address, ("sock", 1234)) + get_extra_info.assert_called_with("sockname") + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234) + ) + async def test_remote_address(self, get_extra_info): + """Connection provides a remote_address attribute.""" + self.assertEqual(self.connection.remote_address, ("peer", 1234)) + get_extra_info.assert_called_with("peername") + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + # Test reporting of network errors. + + async def test_writing_in_data_received_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received") + async def test_unexpected_failure_in_data_received(self, events_received): + """Unexpected internal error in data_received() is correctly reported.""" + # Inject a fault in a random call in data_received(). + # This test is tightly coupled to the implementation. + events_received.side_effect = AssertionError + # Receive a message to trigger the fault. + await self.remote_connection.send("😀") + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + @patch("websockets.protocol.Protocol.send_text") + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + send_text.side_effect = AssertionError + + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send("😀") + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py new file mode 100644 index 000000000..2e59f49b1 --- /dev/null +++ b/tests/asyncio/test_server.py @@ -0,0 +1,525 @@ +import asyncio +import dataclasses +import http +import logging +import socket +import unittest + +from websockets.asyncio.compatibility import TimeoutError, asyncio_timeout +from websockets.asyncio.server import * +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response + +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import ( + EvalShellMixin, + crash, + do_nothing, + eval_shell, + get_server_host_port, + keep_running, + run_server, + run_unix_server, +) + + +class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_connection_handler_returns(self): + """Connection handler returns.""" + async with run_server(do_nothing) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) + + async def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + async with run_server(crash) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); " + "then sent 1011 (internal error)", + ) + + async def test_existing_socket(self): + """Server receives connection using a pre-existing socket.""" + with socket.create_server(("localhost", 0)) as sock: + async with run_server(sock=sock, host=None, port=None): + uri = "ws://{}:{}/".format(*sock.getsockname()) + async with run_client(uri) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + async with run_server( + subprotocols=["chat"], + select_subprotocol=select_subprotocol, + ) as server: + async with run_client(server, subprotocols=["chat"]) as client: + await self.assertEval(client, "ws.select_subprotocol_ran", "True") + await self.assertEval(client, "ws.subprotocol", "chat") + + async def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_request(self): + """Server runs process_request before processing the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_async_process_request(self): + """Server runs async process_request before processing the handshake.""" + + async def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_abort_handshake(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_async_process_request_abort_handshake(self): + """Server aborts handshake if async process_request returns a response.""" + + async def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_request_raises_exception(self): + """Server returns an error if async process_request raises an exception.""" + + async def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_async_process_response(self): + """Server runs async process_response after processing the handshake.""" + + async def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_process_response_override_response(self): + """Server runs process_response and overrides the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_async_process_response_override_response(self): + """Server runs async process_response and overrides the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_response_raises_exception(self): + """Server returns an error if async process_response raises an exception.""" + + async def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_override_server(self): + """Server can override Server header with server_header.""" + async with run_server(server_header="Neo") as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.response.headers['Server']", "Neo") + + async def test_remove_server(self): + """Server can remove Server header with server_header.""" + async with run_server(server_header=None) as server: + async with run_client(server) as client: + await self.assertEval( + client, "'Server' in ws.response.headers", "False" + ) + + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with run_server(compression=None) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + + async def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + async with run_server(create_connection=create_connection) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.create_connection_ran", "True") + + async def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + async with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + async with run_server(open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + async with run_server() as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + async def test_close_server_rejects_connecting_connections(self): + """Server rejects connecting connections with HTTP 503 when closing.""" + + async def process_request(ws, _request): + while ws.server.is_serving(): + await asyncio.sleep(0) + + async with run_server(process_request=process_request) as server: + asyncio.get_running_loop().call_later(MS, server.close) + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 503", + ) + + async def test_close_server_closes_open_connections(self): + """Server closes open connections with close code 1001 when closing.""" + async with run_server() as server: + async with run_client(server) as client: + server.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1001 (going away); then sent 1001 (going away)", + ) + + async def test_close_server_keeps_connections_open(self): + """Server waits for client to close open connections when closing.""" + async with run_server() as server: + async with run_client(server) as client: + server.close(close_connections=False) + + # Server cannot receive new connections. + await asyncio.sleep(0) + self.assertFalse(server.sockets) + + # The server waits for the client to close the connection. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + # Once the client closes the connection, the server terminates. + await client.close() + async with asyncio_timeout(MS): + await server.wait_closed() + + async def test_close_server_keeps_handlers_running(self): + """Server waits for connection handlers to terminate.""" + async with run_server(keep_running) as server: + async with run_client(server) as client: + # Delay termination of connection handler. + await client.send(str(2 * MS)) + + server.close() + + # The server waits for the connection handler to terminate. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + async with asyncio_timeout(2 * MS): + await server.wait_closed() + + +SSL_OBJECT = "ws.transport.get_extra_info('ssl_object')" + + +class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + async def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + +class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_unix_without_path_or_sock(self): + """Unix server requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell) + self.assertEqual( + str(raised.exception), + "path was not specified, and no sock specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix server rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell, path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await serve(eval_shell, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await serve(eval_shell, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + +class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): + async def test_logger(self): + """WebSocketServer accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) From d2120de4708d09d3cade465541f202d5a8bab722 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 08:31:06 +0200 Subject: [PATCH 1313/1539] Add an option to disable decoding of text frames. Also support decoding binary frames. Fix #1376. --- src/websockets/asyncio/connection.py | 34 ++++++++++++++++++++++++---- src/websockets/asyncio/messages.py | 10 ++++++++ tests/asyncio/test_connection.py | 26 +++++++++++++++++++++ 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 550c0ac97..152c6789e 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -172,7 +172,7 @@ async def __aiter__(self) -> AsyncIterator[Data]: except ConnectionClosedOK: return - async def recv(self) -> Data: + async def recv(self, decode: bool | None = None) -> Data: """ Receive the next message. @@ -192,6 +192,10 @@ async def recv(self) -> Data: When the message is fragmented, :meth:`recv` waits until all fragments are received, reassembles them, and returns the whole message. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -199,6 +203,15 @@ async def recv(self) -> Data: .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return a bytestring (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. RuntimeError: If two coroutines call :meth:`recv` or @@ -206,7 +219,7 @@ async def recv(self) -> Data: """ try: - return await self.recv_messages.get() + return await self.recv_messages.get(decode) except EOFError: raise self.protocol.close_exc from self.recv_exc except RuntimeError: @@ -215,7 +228,7 @@ async def recv(self) -> Data: "is already running recv or recv_streaming" ) from None - async def recv_streaming(self) -> AsyncIterator[Data]: + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Receive the next message frame by frame. @@ -232,6 +245,10 @@ async def recv_streaming(self) -> AsyncIterator[Data]: iterator in a partially consumed state, making the connection unusable. Instead, you should close the connection with :meth:`close`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -239,6 +256,15 @@ async def recv_streaming(self) -> AsyncIterator[Data]: .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. RuntimeError: If two coroutines call :meth:`recv` or @@ -246,7 +272,7 @@ async def recv_streaming(self) -> AsyncIterator[Data]: """ try: - async for frame in self.recv_messages.get_iter(): + async for frame in self.recv_messages.get_iter(decode): yield frame except EOFError: raise self.protocol.close_exc from self.recv_exc diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 2a9c4d37d..bc33df8d7 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -121,6 +121,11 @@ async def get(self, decode: bool | None = None) -> Data: received, then it reassembles the message and returns it. To receive messages frame by frame, use :meth:`get_iter` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` @@ -183,6 +188,11 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index a8b3980b4..2efd4e96d 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -167,6 +167,16 @@ async def test_recv_binary(self): await self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + async def test_recv_encoded_text(self): + """recv receives an UTF-8 encoded text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) + + async def test_recv_decoded_binary(self): + """recv receives an UTF-8 decoded binary message.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual(await self.connection.recv(decode=True), "😀") + async def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" await self.remote_connection.send(["😀", "😀"]) @@ -271,6 +281,22 @@ async def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + async def test_recv_streaming_encoded_text(self): + """recv_streaming receives an UTF-8 encoded text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + async def test_recv_streaming_decoded_binary(self): + """recv_streaming receives a UTF-8 decoded binary message.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual( + await alist(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + async def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" await self.remote_connection.send(["😀", "😀"]) From e35c15a2a70c347f6f7a3e503ff1181ac35e1298 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 18:45:02 +0200 Subject: [PATCH 1314/1539] Reduce MS for situations with performance penalties. Nowadays it's tuned with WEBSOCKETS_TESTS_TIMEOUT_FACTOR. --- tests/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index bd3b61d7b..1793f3e8b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,11 +45,11 @@ # PyPy has a performance penalty for this test suite. if platform.python_implementation() == "PyPy": # pragma: no cover - MS *= 5 + MS *= 2 -# asyncio's debug mode has a 10x performance penalty for this test suite. +# asyncio's debug mode has a performance penalty for this test suite. if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover - MS *= 10 + MS *= 2 # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From bbb316155a5aeb719f262873a5b29a98c19b25d9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 8 Aug 2024 10:07:44 +0200 Subject: [PATCH 1315/1539] Add env vars for configuring constants. --- docs/project/changelog.rst | 6 ++++ docs/reference/index.rst | 1 + docs/reference/variables.rst | 48 ++++++++++++++++++++++++++++++ docs/topics/logging.rst | 4 +++ docs/topics/security.rst | 38 +++++++++++++++++------ src/websockets/asyncio/client.py | 5 ++-- src/websockets/asyncio/server.py | 11 ++++--- src/websockets/frames.py | 17 ++++++----- src/websockets/http.py | 9 ------ src/websockets/http11.py | 36 +++++++++++++++++----- src/websockets/legacy/client.py | 4 +-- src/websockets/legacy/http.py | 9 +++--- src/websockets/legacy/server.py | 8 ++--- src/websockets/sync/client.py | 3 +- src/websockets/sync/server.py | 9 +++--- tests/legacy/test_client_server.py | 2 +- tests/test_http.py | 8 ----- 17 files changed, 149 insertions(+), 69 deletions(-) create mode 100644 docs/reference/variables.rst delete mode 100644 tests/test_http.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 108b7c9c0..8143e3483 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -73,6 +73,12 @@ New features * Validated compatibility with Python 3.12. +* Added :doc:`environment variables <../reference/variables>` to configure debug + logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. + + If you were monkey-patching constants, be aware that they were renamed, which + will break your configuration. You must switch to the environment variables. + 12.0 ---- diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2486ac564..d3a0e935c 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -85,6 +85,7 @@ These low-level APIs are shared by all implementations. datastructures exceptions types + variables API stability ------------- diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst new file mode 100644 index 000000000..4bca112da --- /dev/null +++ b/docs/reference/variables.rst @@ -0,0 +1,48 @@ +Environment variables +===================== + +Logging +------- + +.. envvar:: WEBSOCKETS_MAX_LOG_SIZE + + How much of each frame to show in debug logs. + + The default value is ``75``. + +See the :doc:`logging guide <../topics/logging>` for details. + +Security +........ + +.. envvar:: WEBSOCKETS_SERVER + + Server header sent by websockets. + + The default value uses the format ``"Python/x.y.z websockets/X.Y"``. + +.. envvar:: WEBSOCKETS_USER_AGENT + + User-Agent header sent by websockets. + + The default value uses the format ``"Python/x.y.z websockets/X.Y"``. + +.. envvar:: WEBSOCKETS_MAX_LINE_LENGTH + + Maximum length of the request or status line in the opening handshake. + + The default value is ``8192``. + +.. envvar:: WEBSOCKETS_MAX_NUM_HEADERS + + Maximum number of HTTP headers in the opening handshake. + + The default value is ``128``. + +.. envvar:: WEBSOCKETS_MAX_BODY_SIZE + + Maximum size of the body of an HTTP response in the opening handshake. + + The default value is ``1_048_576`` (1 MiB). + +See the :doc:`security guide <../topics/security>` for details. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index e7abd96ce..873c852c2 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -76,6 +76,10 @@ Here's how to enable debug logs for development:: level=logging.DEBUG, ) +By default, websockets elides the content of messages to improve readability. +If you want to see more, you can increase the :envvar:`WEBSOCKETS_MAX_LOG_SIZE` +environment variable. The default value is 75. + Furthermore, websockets adds a ``websocket`` attribute to log records, so you can include additional information about the current connection in logs. diff --git a/docs/topics/security.rst b/docs/topics/security.rst index d3dec21bd..83d79e35b 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -1,6 +1,8 @@ Security ======== +.. currentmodule:: websockets + Encryption ---------- @@ -27,15 +29,33 @@ an amplification factor of 1000 between network traffic and memory usage. Configuring a server to :doc:`optimize memory usage ` will improve security in addition to improving performance. -Other limits ------------- +HTTP limits +----------- + +In the opening handshake, websockets applies limits to the amount of data that +it accepts in order to minimize exposure to denial of service attacks. + +The request or status line is limited to 8192 bytes. Each header line, including +the name and value, is limited to 8192 bytes too. No more than 128 HTTP headers +are allowed. When the HTTP response includes a body, it is limited to 1 MiB. + +You may change these limits by setting the :envvar:`WEBSOCKETS_MAX_LINE_LENGTH`, +:envvar:`WEBSOCKETS_MAX_NUM_HEADERS`, and :envvar:`WEBSOCKETS_MAX_BODY_SIZE` +environment variables respectively. + +Identification +-------------- + +By default, websockets identifies itself with a ``Server`` or ``User-Agent`` +header in the format ``"Python/x.y.z websockets/X.Y"``. -websockets implements additional limits on the amount of data it accepts in -order to minimize exposure to security vulnerabilities. +You can set the ``server_header`` argument of :func:`~server.serve` or the +``user_agent_header`` argument of :func:`~client.connect` to configure another +value. Setting them to :obj:`None` removes the header. -In the opening handshake, websockets limits the number of HTTP headers to 256 -and the size of an individual header to 4096 bytes. These limits are 10 to 20 -times larger than what's expected in standard use cases. They're hard-coded. +Alternatively, you can set the :envvar:`WEBSOCKETS_SERVER` and +:envvar:`WEBSOCKETS_USER_AGENT` environment variables respectively. Setting them +to an empty string removes the header. -If you need to change these limits, you can monkey-patch the constants in -``websockets.http11``. +If both the argument and the environment variable are set, the argument takes +precedence. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 040d68ece..ac8ded8ca 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -9,8 +9,7 @@ from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Response +from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol from ..uri import parse_uri @@ -71,7 +70,7 @@ async def handshake( self.request = self.protocol.connect() if additional_headers is not None: self.request.headers.update(additional_headers) - if user_agent_header is not None: + if user_agent_header: self.request.headers["User-Agent"] = user_agent_header self.protocol.send_request(self.request) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index aa175f775..0c8b8780b 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -20,8 +20,7 @@ from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Request, Response +from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, Subprotocol @@ -88,7 +87,7 @@ async def handshake( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, ) -> None: """ Perform the opening handshake. @@ -131,7 +130,7 @@ async def handshake( assert isinstance(response, Response) # help mypy self.response = response - if server_header is not None: + if server_header: self.response.headers["Server"] = server_header response = None @@ -243,7 +242,7 @@ def __init__( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, open_timeout: float | None = 10, logger: LoggerLike | None = None, ) -> None: @@ -631,7 +630,7 @@ def __init__( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, diff --git a/src/websockets/frames.py b/src/websockets/frames.py index af56d3f8f..819fdd742 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -3,6 +3,7 @@ import dataclasses import enum import io +import os import secrets import struct from typing import Callable, Generator, Sequence @@ -146,8 +147,8 @@ class Frame: rsv2: bool = False rsv3: bool = False - # Monkey-patch if you want to see more in logs. Should be a multiple of 3. - MAX_LOG = 75 + # Configure if you want to see more in logs. Should be a multiple of 3. + MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75")) def __str__(self) -> str: """ @@ -166,8 +167,8 @@ def __str__(self) -> str: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. binary = self.data - if len(binary) > self.MAX_LOG // 3: - cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: @@ -183,16 +184,16 @@ def __str__(self) -> str: coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data - if len(binary) > self.MAX_LOG // 3: - cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: data = "''" - if len(data) > self.MAX_LOG: - cut = self.MAX_LOG // 3 - 1 # by default cut = 24 + if len(data) > self.MAX_LOG_SIZE: + cut = self.MAX_LOG_SIZE // 3 - 1 # by default cut = 24 data = data[: 2 * cut] + "..." + data[-cut:] metadata = ", ".join(filter(None, [coding, length, non_final])) diff --git a/src/websockets/http.py b/src/websockets/http.py index 9f86f6a1f..a24102307 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,10 +1,8 @@ from __future__ import annotations -import sys import typing from .imports import lazy_import -from .version import version as websockets_version # For backwards compatibility: @@ -26,10 +24,3 @@ "read_response": ".legacy.http", }, ) - - -__all__ = ["USER_AGENT"] - - -PYTHON_VERSION = "{}.{}".format(*sys.version_info) -USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}" diff --git a/src/websockets/http11.py b/src/websockets/http11.py index a7e9ae682..ed49fcbf9 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -1,23 +1,43 @@ from __future__ import annotations import dataclasses +import os import re +import sys import warnings from typing import Callable, Generator from . import datastructures, exceptions +from .version import version as websockets_version +__all__ = ["SERVER", "USER_AGENT", "Request", "Response"] + + +PYTHON_VERSION = "{}.{}".format(*sys.version_info) + +# User-Agent header for HTTP requests. +USER_AGENT = os.environ.get( + "WEBSOCKETS_USER_AGENT", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + +# Server header for HTTP responses. +SERVER = os.environ.get( + "WEBSOCKETS_SERVER", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + # Maximum total size of headers is around 128 * 8 KiB = 1 MiB. -MAX_HEADERS = 128 +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) # Limit request line and header lines. 8KiB is the most common default # configuration of popular HTTP servers. -MAX_LINE = 8192 +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) # Support for HTTP response bodies is intended to read an error message # returned by a server. It isn't designed to perform large file transfers. -MAX_BODY = 2**20 # 1 MiB +MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB def d(value: bytes) -> str: @@ -258,12 +278,12 @@ def parse( if content_length is None: try: - body = yield from read_to_eof(MAX_BODY) + body = yield from read_to_eof(MAX_BODY_SIZE) except RuntimeError: raise exceptions.SecurityError( - f"body too large: over {MAX_BODY} bytes" + f"body too large: over {MAX_BODY_SIZE} bytes" ) - elif content_length > MAX_BODY: + elif content_length > MAX_BODY_SIZE: raise exceptions.SecurityError( f"body too large: {content_length} bytes" ) @@ -309,7 +329,7 @@ def parse_headers( # We don't attempt to support obsolete line folding. headers = datastructures.Headers() - for _ in range(MAX_HEADERS + 1): + for _ in range(MAX_NUM_HEADERS + 1): try: line = yield from parse_line(read_line) except EOFError as exc: @@ -355,7 +375,7 @@ def parse_line( """ try: - line = yield from read_line(MAX_LINE) + line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: raise exceptions.SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index d1d8d5608..b61126c81 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -38,7 +38,7 @@ parse_subprotocol, validate_subprotocols, ) -from ..http import USER_AGENT +from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri from .handshake import build_request, check_response @@ -307,7 +307,7 @@ async def handshake( if self.extra_headers is not None: request_headers.update(self.extra_headers) - if self.user_agent_header is not None: + if self.user_agent_header: request_headers.setdefault("User-Agent", self.user_agent_header) self.write_http_request(wsuri.resource_name, request_headers) diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 9a553e175..b5df7e4c4 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import re from ..datastructures import Headers @@ -9,8 +10,8 @@ __all__ = ["read_request", "read_response"] -MAX_HEADERS = 128 -MAX_LINE = 8192 +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) def d(value: bytes) -> str: @@ -154,7 +155,7 @@ async def read_headers(stream: asyncio.StreamReader) -> Headers: # We don't attempt to support obsolete line folding. headers = Headers() - for _ in range(MAX_HEADERS + 1): + for _ in range(MAX_NUM_HEADERS + 1): try: line = await read_line(stream) except EOFError as exc: @@ -192,7 +193,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: # Security: this is bounded by the StreamReader's limit (default = 32 KiB). line = await stream.readline() # Security: this guarantees header values are small (hard-coded = 8 KiB) - if len(line) > MAX_LINE: + if len(line) > MAX_LINE_LENGTH: raise SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 208ffa780..cd7980e00 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -40,7 +40,7 @@ parse_subprotocol, validate_subprotocols, ) -from ..http import USER_AGENT +from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .handshake import build_response, check_request @@ -106,7 +106,7 @@ def __init__( extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, process_request: ( Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None ) = None, @@ -221,7 +221,7 @@ async def handler(self) -> None: ) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - if self.server_header is not None: + if self.server_header: headers.setdefault("Server", self.server_header) headers.setdefault("Content-Length", str(len(body))) @@ -992,7 +992,7 @@ def __init__( extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, process_request: ( Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None ) = None, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index c97a09402..e33d53f62 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -11,8 +11,7 @@ from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Response +from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, OPEN, Event from ..typing import LoggerLike, Origin, Subprotocol from ..uri import parse_uri diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 7fb46f5aa..ebbbd0312 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -16,8 +16,7 @@ from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..frames import CloseCode from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Request, Response +from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, Subprotocol @@ -83,7 +82,7 @@ def handshake( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, timeout: float | None = None, ) -> None: """ @@ -120,7 +119,7 @@ def handshake( if self.response is None: self.response = self.protocol.accept(self.request) - if server_header is not None: + if server_header: self.response.headers["Server"] = server_header if process_response is not None: @@ -302,7 +301,7 @@ def serve( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index b5c5d726a..329f59286 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -29,7 +29,7 @@ ServerPerMessageDeflateFactory, ) from websockets.frames import CloseCode -from websockets.http import USER_AGENT +from websockets.http11 import USER_AGENT from websockets.legacy.client import * from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response diff --git a/tests/test_http.py b/tests/test_http.py deleted file mode 100644 index baaa7d416..000000000 --- a/tests/test_http.py +++ /dev/null @@ -1,8 +0,0 @@ -import unittest - -from websockets.http import * - - -class HTTPTests(unittest.TestCase): - def test_user_agent(self): - USER_AGENT # exists From a7a5042bed89b96fa2e391f3be0e255a59bffb0a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 8 Aug 2024 10:21:23 +0200 Subject: [PATCH 1316/1539] Deprecate fully websockets.http. All public API within this module are deprecated since version 9.0 so there's nothing to document. --- src/websockets/connection.py | 1 - src/websockets/http.py | 31 ++++++++++--------------------- tests/test_http.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 22 deletions(-) create mode 100644 tests/test_http.py diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 88bcda1aa..7942c1a28 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -2,7 +2,6 @@ import warnings -# lazy_import doesn't support this use case. from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 diff --git a/src/websockets/http.py b/src/websockets/http.py index a24102307..3dc560062 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,26 +1,15 @@ from __future__ import annotations -import typing +import warnings -from .imports import lazy_import +from .datastructures import Headers, MultipleValuesError # noqa: F401 +from .legacy.http import read_request, read_response # noqa: F401 -# For backwards compatibility: - - -# When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: - from .datastructures import Headers, MultipleValuesError # noqa: F401 -else: - lazy_import( - globals(), - # Headers and MultipleValuesError used to be defined in this module. - aliases={ - "Headers": ".datastructures", - "MultipleValuesError": ".datastructures", - }, - deprecated_aliases={ - "read_request": ".legacy.http", - "read_response": ".legacy.http", - }, - ) +warnings.warn( + "Headers and MultipleValuesError were moved " + "from websockets.http to websockets.datastructures" + "and read_request and read_response were moved " + "from websockets.http to websockets.legacy.http", + DeprecationWarning, +) diff --git a/tests/test_http.py b/tests/test_http.py new file mode 100644 index 000000000..6e81199fc --- /dev/null +++ b/tests/test_http.py @@ -0,0 +1,16 @@ +from websockets.datastructures import Headers + +from .utils import DeprecationTestCase + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_headers_class(self): + with self.assertDeprecationWarning( + "Headers and MultipleValuesError were moved " + "from websockets.http to websockets.datastructures" + "and read_request and read_response were moved " + "from websockets.http to websockets.legacy.http", + ): + from websockets.http import Headers as OldHeaders + + self.assertIs(OldHeaders, Headers) From 5835da4967e130cd631a7601e75ea5228ab27537 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:25:19 +0200 Subject: [PATCH 1317/1539] Adjust timings to avoid spurious failures. --- tests/asyncio/test_server.py | 6 +++--- tests/utils.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 2e59f49b1..535083cbc 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -412,16 +412,16 @@ async def test_close_server_keeps_handlers_running(self): async with run_server(keep_running) as server: async with run_client(server) as client: # Delay termination of connection handler. - await client.send(str(2 * MS)) + await client.send(str(3 * MS)) server.close() # The server waits for the connection handler to terminate. with self.assertRaises(TimeoutError): - async with asyncio_timeout(MS): + async with asyncio_timeout(2 * MS): await server.wait_closed() - async with asyncio_timeout(2 * MS): + async with asyncio_timeout(3 * MS): await server.wait_closed() diff --git a/tests/utils.py b/tests/utils.py index 1793f3e8b..960439135 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -43,13 +43,14 @@ # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) -# PyPy has a performance penalty for this test suite. +# PyPy, asyncio's debug mode, and coverage penalize performance of this +# test suite. Increase timeouts to reduce the risk of spurious failures. if platform.python_implementation() == "PyPy": # pragma: no cover MS *= 2 - -# asyncio's debug mode has a performance penalty for this test suite. if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover MS *= 2 +if os.environ.get("COVERAGE_RUN"): # pragma: no branch + MS *= 2 # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From 906592908bb5850a4f78a5d4877fbc2412d611b7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:25:42 +0200 Subject: [PATCH 1318/1539] Avoid spurious coverage failures due to timing effects. --- tests/asyncio/test_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 535083cbc..4a8a76a21 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -363,7 +363,7 @@ async def test_close_server_rejects_connecting_connections(self): async def process_request(ws, _request): while ws.server.is_serving(): - await asyncio.sleep(0) + await asyncio.sleep(0) # pragma: no cover async with run_server(process_request=process_request) as server: asyncio.get_running_loop().call_later(MS, server.close) From 84e8bd879b8dfc528b4e57517f2e1f8b7ad0a378 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:49:55 +0200 Subject: [PATCH 1319/1539] Fix spurious exception while running tests. Due to a race condition between serve_forever and shutdown, test run logs randomly contained this exception: Exception in thread Thread-NNN (serve_forever): Traceback (most recent call last): ... ValueError: Invalid file descriptor: -1 --- src/websockets/sync/server.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index ebbbd0312..10fbe4859 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -222,7 +222,13 @@ def serve_forever(self) -> None: """ poller = selectors.DefaultSelector() - poller.register(self.socket, selectors.EVENT_READ) + try: + poller.register(self.socket, selectors.EVENT_READ) + except ValueError: # pragma: no cover + # If shutdown() is called before poller.register(), + # the socket is closed and poller.register() raises + # ValueError: Invalid file descriptor: -1 + return if sys.platform != "win32": poller.register(self.shutdown_watcher, selectors.EVENT_READ) From a3ed1604b0f331fe91df52641cfab2ae5349eb46 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 10:24:24 +0200 Subject: [PATCH 1320/1539] Make test_reconnect robust to slower runs. This avoids failures with higher WEBSOCKETS_TESTS_TIMEOUT_FACTOR, notably on PyPy. Refs #1483. --- tests/legacy/test_client_server.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 329f59286..0c5d66c92 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -5,6 +5,7 @@ import logging import platform import random +import re import socket import ssl import sys @@ -1608,16 +1609,19 @@ async def run_client(): ) # Iteration 3 self.assertEqual( - [record.getMessage() for record in logs.records][4:-1], + [ + re.sub(r"[0-9\.]+ seconds", "X seconds", record.getMessage()) + for record in logs.records + ][4:-1], [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed; reconnecting in 0.0 seconds", + "! connect failed; reconnecting in X seconds", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed again; retrying in 0 seconds", + "! connect failed again; retrying in X seconds", ] * ((len(logs.records) - 8) // 3) + [ From 58787cc6a58a1f1baf4be3f78d868594108afebd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 14:35:38 +0200 Subject: [PATCH 1321/1539] Confirm support for Python 3.13. --- .github/workflows/release.yml | 2 +- .github/workflows/tests.yml | 2 ++ docs/faq/misc.rst | 3 +++ docs/project/changelog.rst | 2 +- pyproject.toml | 1 + src/websockets/asyncio/connection.py | 3 +-- tox.ini | 1 + 7 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4a00bf8fc..ed52ddd80 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -56,7 +56,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.16.2 + uses: pypa/cibuildwheel@v2.20.0 env: BUILD_EXTENSION: yes - name: Save wheels diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 15a45bdfb..b9172b7fb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,6 +62,7 @@ jobs: - "3.10" - "3.11" - "3.12" + - "3.13" - "pypy-3.9" - "pypy-3.10" is_main: @@ -78,6 +79,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} + allow-prereleases: true - name: Install tox run: pip install tox - name: Run tests diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index ee5ad2372..0e74a784f 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -3,6 +3,9 @@ Miscellaneous .. currentmodule:: websockets +.. Remove this question when dropping Python < 3.13, which provides natively +.. a good error message in this case. + Why do I get the error: ``module 'websockets' has no attribute '...'``? ....................................................................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8143e3483..00b055dd1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -71,7 +71,7 @@ New features See :func:`websockets.asyncio.client.connect` and :func:`websockets.asyncio.server.serve` for details. -* Validated compatibility with Python 3.12. +* Validated compatibility with Python 3.12 and 3.13. * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. diff --git a/pyproject.toml b/pyproject.toml index de8acd6a3..c1d34c90b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dynamic = ["version", "readme"] diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 152c6789e..4f44d798c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -564,8 +564,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[None]: pong_waiter = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution - # is a bit low on Windows (~16ms). We cannot use time.perf_counter() - # because it doesn't count time elapsed while the process sleeps. + # is a bit low on Windows (~16ms). This is improved in Python 3.13. ping_timestamp = self.loop.time() self.pong_waiters[data] = (pong_waiter, ping_timestamp) self.protocol.send_ping(data) diff --git a/tox.ini b/tox.ini index 1edcfe261..16d9c9f16 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ env_list = py310 py311 py312 + py313 coverage black ruff From 9ec785d6f12cb1a3a3bc43f543f4a831a635472b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 15:07:40 +0200 Subject: [PATCH 1322/1539] Fix copy-paste error in tests. --- tests/asyncio/test_client.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index aab65cd2e..b74617ef0 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -156,27 +156,27 @@ async def test_connection(self): async def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" - with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - uri="wss://overridden/", - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with run_server(ssl=SERVER_CONTEXT) as server: + host, port = get_server_host_port(server) + async with run_client( + "wss://overridden/", + host=host, + port=port, + ssl=CLIENT_CONTEXT, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") async def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" - with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - server_hostname="overridden", - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client( + server, + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" From 1853a9b2d0247573633e2749fe1169f764abe03c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 15:59:59 +0200 Subject: [PATCH 1323/1539] Ignore ResourceWarning in test. This is expected to prevent a spurious test failure under PyPy. Refs #1483. --- tests/legacy/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 28bc90df3..5bb56b26f 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -2,6 +2,7 @@ import contextlib import functools import logging +import sys import unittest @@ -76,8 +77,19 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): Check recorded deprecation warnings match a list of expected messages. """ + # Work around https://github.com/python/cpython/issues/90476. + if sys.version_info[:2] < (3, 11): # pragma: no cover + recorded_warnings = [ + recorded + for recorded in recorded_warnings + if not ( + type(recorded.message) is ResourceWarning + and str(recorded.message).startswith("unclosed transport") + ) + ] + for recorded in recorded_warnings: - self.assertEqual(type(recorded.message), DeprecationWarning) + self.assertIs(type(recorded.message), DeprecationWarning) self.assertEqual( {str(recorded.message) for recorded in recorded_warnings}, set(expected_warnings), From 00b63afe7d921d17fd48abee2e25389050a2410c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 12 Aug 2024 09:44:02 +0200 Subject: [PATCH 1324/1539] Add new asyncio implementation to feature matrices. --- docs/reference/features.rst | 257 ++++++++++++++++++------------------ 1 file changed, 128 insertions(+), 129 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 98b3c0dda..946770fe3 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -14,9 +14,10 @@ Feature support matrices summarize which implementations support which features. .support-matrix-table td:not(:first-child) { text-align: center; } -.. |aio| replace:: :mod:`asyncio` +.. |aio| replace:: :mod:`asyncio` (new) .. |sync| replace:: :mod:`threading` .. |sans| replace:: `Sans-I/O`_ +.. |leg| replace:: :mod:`asyncio` (legacy) .. _Sans-I/O: https://sans-io.readthedocs.io/ Both sides @@ -25,60 +26,58 @@ Both sides .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Perform the opening handshake | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Send a message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Receive a message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Iterate over received messages | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+ - | Send a fragmented message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Receive a fragmented message after | ✅ | ✅ | ❌ | - | reassembly | | | | - +------------------------------------+--------+--------+--------+ - | Receive a fragmented message frame | ❌ | ✅ | ✅ | - | by frame (`#479`_) | | | | - +------------------------------------+--------+--------+--------+ - | Send a ping | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Respond to pings automatically | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Send a pong | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform the closing handshake | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Report close codes and reasons | ❌ | ✅ | ✅ | - | from both sides | | | | - +------------------------------------+--------+--------+--------+ - | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Tune memory usage for compression | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Negotiate extensions | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Implement custom extensions | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Negotiate a subprotocol | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Enforce security limits | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Log events | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Enforce opening timeout | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Enforce closing timeout | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Keepalive | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Heartbeat | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - -.. _#479: https://github.com/python-websockets/websockets/issues/479 + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Perform the opening handshake | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce opening timeout | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Receive a message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Iterate over received messages | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Receive a fragmented message frame | ✅ | ✅ | ✅ | ❌ | + | by frame | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Receive a fragmented message after | ✅ | ✅ | — | ✅ | + | reassembly | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Send a ping | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a pong | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Keepalive | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Heartbeat | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce closing timeout | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | + | from both sides | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Tune memory usage for compression | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Negotiate extensions | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Implement custom extensions | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Negotiate a subprotocol | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce security limits | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Log events | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ Server ------ @@ -86,39 +85,39 @@ Server .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Listen on a TCP socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Listen on a Unix socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Listen using a preexisting socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close server on context exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close connection on handler exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Shut down server gracefully | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Check ``Origin`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Customize subprotocol selection | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Configure ``Server`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake request | ❌ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake response | ❌ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ❌ | ❌ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | - +------------------------------------+--------+--------+--------+ - | Force HTTP response | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Listen on a TCP socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Listen on a Unix socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Listen using a preexisting socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close server on context exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close connection on handler exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Shut down server gracefully | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Check ``Origin`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Customize subprotocol selection | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``Server`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Alter opening handshake request | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Alter opening handshake response | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ❌ | ❌ | ❌ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | + +------------------------------------+--------+--------+--------+--------+ Client ------ @@ -126,41 +125,43 @@ Client .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Connect to a TCP socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Connect to a Unix socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Connect using a preexisting socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close connection on context exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Reconnect automatically | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Configure ``Origin`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake request | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | - | (`#784`_) | | | | - +------------------------------------+--------+--------+--------+ - | Follow HTTP redirects | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Connect via a SOCKS5 proxy | ❌ | ❌ | — | - | (`#475`_) | | | | - +------------------------------------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Connect to a TCP socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Connect to a Unix socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Connect using a preexisting socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close connection on context exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Reconnect automatically | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Modify opening handshake request | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Modify opening handshake response | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Follow HTTP redirects | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | + | (`#784`_) | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Connect via a SOCKS5 proxy | ❌ | ❌ | — | ❌ | + | (`#475`_) | | | | | + +------------------------------------+--------+--------+--------+--------+ .. _#364: https://github.com/python-websockets/websockets/issues/364 .. _#475: https://github.com/python-websockets/websockets/issues/475 @@ -174,14 +175,12 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 -The server doesn't check the Host header and respond with a HTTP 400 Bad Request -if it is missing or invalid (`#1246`). +The server doesn't check the Host header and doesn't respond with a HTTP 400 Bad +Request if it is missing or invalid (`#1246`). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right -layer for enforcing this constraint. It's the caller's responsibility. - -.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 +mandated by :rfc:`6455`, section 4.1. However, :func:`~client.connect()` isn't +the right layer for enforcing this constraint. It's the caller's responsibility. From 7345b31edc82abc200ebb58dc0fbe856e65d447b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 12 Aug 2024 09:45:39 +0200 Subject: [PATCH 1325/1539] Cool URI don't change... except when they do. --- src/websockets/asyncio/connection.py | 18 ++++++------- .../extensions/permessage_deflate.py | 4 +-- src/websockets/headers.py | 18 ++++++------- src/websockets/http11.py | 14 +++++----- src/websockets/legacy/http.py | 10 +++---- src/websockets/legacy/protocol.py | 26 +++++++++---------- src/websockets/legacy/server.py | 2 +- src/websockets/protocol.py | 4 +-- src/websockets/server.py | 2 +- src/websockets/sync/connection.py | 18 ++++++------- src/websockets/typing.py | 4 +-- src/websockets/uri.py | 2 +- 12 files changed, 61 insertions(+), 61 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 4f44d798c..0a3ddb9aa 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -200,8 +200,8 @@ async def recv(self, decode: bool | None = None) -> Data: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``decode`` argument: @@ -253,8 +253,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``decode`` argument: @@ -290,8 +290,8 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. @@ -299,7 +299,7 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you really want to send the keys of a dict-like object as fragments, @@ -524,7 +524,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[None]: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point @@ -574,7 +574,7 @@ async def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 579262f02..fea14131e 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -262,7 +262,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): Parameters behave as described in `section 7.1 of RFC 7692`_. - .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 + .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 Set them to :obj:`True` to include them in the negotiation offer without a value or to an integer value to include them with this value. @@ -462,7 +462,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): Parameters behave as described in `section 7.1 of RFC 7692`_. - .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 + .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 Set them to :obj:`True` to include them in the negotiation offer without a value or to an integer value to include them with this value. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index bc42e0b72..0ffd65233 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -40,7 +40,7 @@ def build_host(host: str, port: int, secure: bool) -> str: Build a ``Host`` header. """ - # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 + # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 # IPv6 addresses must be enclosed in brackets. try: address = ipaddress.ip_address(host) @@ -59,8 +59,8 @@ def build_host(host: str, port: int, secure: bool) -> str: # To avoid a dependency on a parsing library, we implement manually the ABNF -# described in https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 and -# https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and +# https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. def peek_ahead(header: str, pos: int) -> str | None: @@ -183,7 +183,7 @@ def parse_list( InvalidHeaderFormat: On invalid inputs. """ - # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient + # Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient # MUST parse and ignore a reasonable number of empty list elements"; # hence while loops that remove extra delimiters. @@ -320,7 +320,7 @@ def parse_extension_item_param( if peek_ahead(header, pos) == '"': pos_before = pos # for proper error reporting below value, pos = parse_quoted_string(header, pos, header_name) - # https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 says: + # https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says: # the value after quoted-string unescaping MUST conform to # the 'token' ABNF. if _token_re.fullmatch(value) is None: @@ -489,7 +489,7 @@ def build_www_authenticate_basic(realm: str) -> str: realm: Identifier of the protection space. """ - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 realm = build_quoted_string(realm) charset = build_quoted_string("UTF-8") return f"Basic realm={realm}, charset={charset}" @@ -539,8 +539,8 @@ def parse_authorization_basic(header: str) -> tuple[str, str]: InvalidHeaderValue: On unsupported inputs. """ - # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7235#section-2.1 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": raise exceptions.InvalidHeaderValue( @@ -580,7 +580,7 @@ def build_authorization_basic(username: str, password: str) -> str: This is the reverse of :func:`parse_authorization_basic`. """ - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 assert ":" not in username user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() diff --git a/src/websockets/http11.py b/src/websockets/http11.py index ed49fcbf9..b86c6ca4a 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -48,7 +48,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. # Regex for validating header names. @@ -122,7 +122,7 @@ def parse( ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -146,7 +146,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -227,7 +227,7 @@ def parse( ValueError: If the response isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 try: status_line = yield from parse_line(read_line) @@ -255,7 +255,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -324,7 +324,7 @@ def parse_headers( ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 # We don't attempt to support obsolete line folding. @@ -378,7 +378,7 @@ def parse_line( line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: raise exceptions.SecurityError("line too long") - # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index b5df7e4c4..a7c8a927e 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -22,7 +22,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. # Regex for validating header names. @@ -64,7 +64,7 @@ async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]: ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -111,7 +111,7 @@ async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers ValueError: If the response isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. @@ -150,7 +150,7 @@ async def read_headers(stream: asyncio.StreamReader) -> Headers: Non-ASCII characters are represented with surrogate escapes. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 # We don't attempt to support obsolete line folding. @@ -195,7 +195,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: # Security: this guarantees header values are small (hard-coded = 8 KiB) if len(line) > MAX_LINE_LENGTH: raise SecurityError("line too long") - # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 120ff8e73..6f8916576 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -80,14 +80,14 @@ class WebSocketCommonProtocol(asyncio.Protocol): especially in the presence of proxies with short timeouts on inactive connections. Set ``ping_interval`` to :obj:`None` to disable this behavior. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 If the corresponding Pong_ frame isn't received within ``ping_timeout`` seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to :obj:`None` to disable this behavior. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. @@ -447,7 +447,7 @@ def close_code(self) -> int | None: WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. .. _section 7.1.5 of RFC 6455: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -465,7 +465,7 @@ def close_reason(self) -> str | None: WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. .. _section 7.1.6 of RFC 6455: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. @@ -516,8 +516,8 @@ async def recv(self) -> Data: A string (:class:`str`) for a Text_ frame. A bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -583,8 +583,8 @@ async def send( bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. @@ -592,7 +592,7 @@ async def send( All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you want to send the keys of a dict-like object as fragments, call @@ -803,7 +803,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive, as a check that the remote endpoint received all messages up to this point, or to measure :attr:`latency`. @@ -862,7 +862,7 @@ async def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. @@ -1559,8 +1559,8 @@ def broadcast( object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :func:`broadcast` pushes the message synchronously to all connections even if their write buffers are overflowing. There's no backpressure. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index cd7980e00..d230f009e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -388,7 +388,7 @@ def process_origin( """ # "The user agent MUST NOT include more than one Origin header field" - # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") except MultipleValuesError as exc: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 7f2b45c74..917c19163 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -175,7 +175,7 @@ def close_code(self) -> int | None: `WebSocket close code`_. .. _WebSocket close code: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -193,7 +193,7 @@ def close_reason(self) -> str | None: `WebSocket close reason`_. .. _WebSocket close reason: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. diff --git a/src/websockets/server.py b/src/websockets/server.py index 7211d3cbf..1b4c3bf29 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -307,7 +307,7 @@ def process_origin(self, headers: Headers) -> Origin | None: """ # "The user agent MUST NOT include more than one Origin header field" - # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") except MultipleValuesError as exc: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 2bcb3aa0e..a4826c785 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -189,8 +189,8 @@ def recv(self, timeout: float | None = None) -> Data: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -222,8 +222,8 @@ def recv_streaming(self) -> Iterator[Data]: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -250,8 +250,8 @@ def send(self, message: Data | Iterable[Data]) -> None: bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a @@ -259,7 +259,7 @@ def send(self, message: Data | Iterable[Data]) -> None: same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you really want to send the keys of a dict-like object as fragments, @@ -425,7 +425,7 @@ def ping(self, data: Data | None = None) -> threading.Event: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point @@ -470,7 +470,7 @@ def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 6360c7a0a..447fe79da 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -24,8 +24,8 @@ """Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. -.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 -.. _Binary : https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 +.. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 """ diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 5cb38a9cc..82b35f92a 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -23,7 +23,7 @@ class WebSocketURI: username: Available when the URI contains `User Information`_. password: Available when the URI contains `User Information`_. - .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 """ From e2f0385119992317c0f49b32775ef50b2fa40218 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 13 Aug 2024 23:15:55 +0200 Subject: [PATCH 1326/1539] Add guide for upgrading to the new asyncio implementation. --- docs/howto/index.rst | 8 + docs/howto/upgrade.rst | 357 ++++++++++++++++++++++++++ docs/project/changelog.rst | 66 ++++- docs/reference/asyncio/client.rst | 4 +- docs/reference/asyncio/common.rst | 4 +- docs/reference/asyncio/server.rst | 10 +- docs/reference/new-asyncio/client.rst | 4 +- docs/reference/new-asyncio/common.rst | 4 +- docs/reference/new-asyncio/server.rst | 4 +- docs/spelling_wordlist.txt | 5 +- 10 files changed, 446 insertions(+), 20 deletions(-) create mode 100644 docs/howto/upgrade.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index ddbe67d3a..863c1c63c 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -8,6 +8,14 @@ In a hurry? Check out these examples. quickstart +Upgrading from the legacy :mod:`asyncio` implementation to the new one? +Read this. + +.. toctree:: + :titlesonly: + + upgrade + If you're stuck, perhaps you'll find the answer here. .. toctree:: diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst new file mode 100644 index 000000000..bb4c59bc4 --- /dev/null +++ b/docs/howto/upgrade.rst @@ -0,0 +1,357 @@ +Upgrade to the new :mod:`asyncio` implementation +================================================ + +.. currentmodule:: websockets + +The new :mod:`asyncio` implementation is a rewrite of the original +implementation of websockets. + +It provides a very similar API. However, there are a few differences. + +The recommended upgrade process is: + +1. Make sure that your application doesn't use any `deprecated APIs`_. If it + doesn't raise any warnings, you can skip this step. +2. Check if your application depends on `missing features`_. If it does, you + should stick to the original implementation until they're added. +3. `Update import paths`_. For straightforward usage of websockets, this could + be the only step you need to take. Upgrading could be transparent. +4. `Review API changes`_ and adapt your application to preserve its current + functionality or take advantage of improvements in the new implementation. + +In the interest of brevity, only :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` are discussed below but everything also applies +to :func:`~asyncio.client.unix_connect` and :func:`~asyncio.server.unix_serve` +respectively. + +.. admonition:: What will happen to the original implementation? + :class: hint + + The original implementation is now considered legacy. + + The next steps are: + + 1. Deprecating it once the new implementation reaches feature parity. + 2. Maintaining it for five years per the :ref:`backwards-compatibility + policy `. + 3. Removing it. This is expected to happen around 2030. + +.. _deprecated APIs: + +Deprecated APIs +--------------- + +Here's the list of deprecated behaviors that the original implementation still +supports and that the new implementation doesn't reproduce. + +If you're seeing a :class:`DeprecationWarning`, follow upgrade instructions from +the release notes of the version in which the feature was deprecated. + +* The ``path`` argument of connection handlers — unnecessary since :ref:`10.1` + and deprecated in :ref:`13.0`. +* The ``loop`` and ``legacy_recv`` arguments of :func:`~client.connect` and + :func:`~server.serve`, which were removed — deprecated in :ref:`10.0`. +* The ``timeout`` and ``klass`` arguments of :func:`~client.connect` and + :func:`~server.serve`, which were renamed to ``close_timeout`` and + ``create_protocol`` — deprecated in :ref:`7.0` and :ref:`3.4` respectively. +* An empty string in the ``origins`` argument of :func:`~server.serve` — + deprecated in :ref:`7.0`. +* The ``host``, ``port``, and ``secure`` attributes of connections — deprecated + in :ref:`8.0`. + +.. _missing features: + +Missing features +---------------- + +.. admonition:: All features listed below will be provided in a future release. + :class: tip + + If your application relies on one of them, you should stick to the original + implementation until the new implementation supports it in a future release. + +Broadcast +......... + +The new implementation doesn't support :doc:`broadcasting messages +<../topics/broadcast>` yet. + +Keepalive +......... + +The new implementation doesn't provide a :ref:`keepalive mechanism ` +yet. + +As a consequence, :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` don't accept the ``ping_interval`` and +``ping_timeout`` arguments and the +:attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property doesn't exist. + +HTTP Basic Authentication +......................... + +On the server side, :func:`~asyncio.server.serve` doesn't provide HTTP Basic +Authentication yet. + +For the avoidance of doubt, on the client side, :func:`~asyncio.client.connect` +performs HTTP Basic Authentication. + +Following redirects +................... + +The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP +redirects yet. + +Automatic reconnection +...................... + +The new implementation of :func:`~asyncio.client.connect` doesn't provide +automatic reconnection yet. + +In other words, the following pattern isn't supported:: + + from websockets.asyncio.client import connect + + async for websocket in connect(...): # this doesn't work yet + ... + +Configuring buffers +................... + +The new implementation doesn't provide a way to configure read and write buffers +yet. + +In practice, :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` +don't accept the ``max_queue``, ``read_limit``, and ``write_limit`` arguments. + +Here's the most likely outcome: + +* ``max_queue`` will be implemented but its semantics will change from "maximum + number of messages" to "maximum number of frames", which makes a difference + when messages are fragmented. +* ``read_limit`` won't be implemented because the buffer that it configured was + removed from the new implementation. The queue that ``max_queue`` configures + is the only read buffer now. +* ``write_limit`` will be implemented as in the original implementation. + Alternatively, the same functionality could be exposed with a different API. + +.. _Update import paths: + +Import paths +------------ + +For context, the ``websockets`` package is structured as follows: + +* The new implementation is found in the ``websockets.asyncio`` package. +* The original implementation was moved to the ``websockets.legacy`` package. +* The ``websockets`` package provides aliases for convenience. +* The ``websockets.client`` and ``websockets.server`` packages provide aliases + for backwards-compatibility with earlier versions of websockets. +* Currently, all aliases point to the original implementation. In the future, + they will point to the new implementation or they will be deprecated. + +To upgrade to the new :mod:`asyncio` implementation, change import paths as +shown in the tables below. + +.. |br| raw:: html + +
+ +Client APIs +........... + ++-------------------------------------------------------------------+-----------------------------------------------------+ +| Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | ++===================================================================+=====================================================+ +| ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | +| :func:`websockets.client.connect` |br| | | +| ``websockets.legacy.client.connect()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | +| :func:`websockets.client.unix_connect` |br| | | +| ``websockets.legacy.client.unix_connect()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | +| :class:`websockets.client.WebSocketClientProtocol` |br| | | +| ``websockets.legacy.client.WebSocketClientProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ + +Server APIs +........... + ++-------------------------------------------------------------------+-----------------------------------------------------+ +| Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | ++===================================================================+=====================================================+ +| ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | +| :func:`websockets.server.serve` |br| | | +| ``websockets.legacy.server.serve()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | +| :func:`websockets.server.unix_serve` |br| | | +| ``websockets.legacy.server.unix_serve()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | +| :class:`websockets.server.WebSocketServer` |br| | | +| ``websockets.legacy.server.WebSocketServer`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | +| :class:`websockets.server.WebSocketServerProtocol` |br| | | +| ``websockets.legacy.server.WebSocketServerProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| :func:`websockets.broadcast` |br| | *not available yet* | +| ``websockets.legacy.protocol.broadcast()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | +| :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | +| ``websockets.legacy.auth.BasicAuthWebSocketServerProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | +| :func:`websockets.auth.basic_auth_protocol_factory` |br| | | +| ``websockets.legacy.auth.basic_auth_protocol_factory()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ + +.. _Review API changes: + +API changes +----------- + +Controlling UTF-8 decoding +.......................... + +The new implementation of the :meth:`~asyncio.connection.Connection.recv` method +provides the ``decode`` argument to control UTF-8 decoding of messages. This +didn't exist in the original implementation. + +If you're calling :meth:`~str.encode` on a :class:`str` object returned by +:meth:`~asyncio.connection.Connection.recv`, using ``decode=False`` and removing +:meth:`~str.encode` saves a round-trip of UTF-8 decoding and encoding for text +messages. + +You can also force UTF-8 decoding of binary messages with ``decode=True``. This +is rarely useful and has no performance benefits over decoding a :class:`bytes` +object returned by :meth:`~asyncio.connection.Connection.recv`. + +Receiving fragmented messages +............................. + +The new implementation provides the +:meth:`~asyncio.connection.Connection.recv_streaming` method for receiving a +fragmented message frame by frame. There was no way to do this in the original +implementation. + +Depending on your use case, adopting this method may improve performance when +streaming large messages. Specifically, it could reduce memory usage. + +Customizing the opening handshake +................................. + +On the client side, if you're adding headers to the handshake request sent by +:func:`~client.connect` with the ``extra_headers`` argument, you must rename it +to ``additional_headers``. + +On the server side, if you're customizing how :func:`~server.serve` processes +the opening handshake with the ``process_request``, ``extra_headers``, or +``select_subprotocol``, you must update your code. ``process_response`` and +``select_subprotocol`` have new signatures; ``process_response`` replaces +``extra_headers`` and provides more flexibility. + +``process_request`` +~~~~~~~~~~~~~~~~~~~ + +The signature of ``process_request`` changed. This is easiest to illustrate with +an example:: + + import http + + # Original implementation + + def process_request(path, request_headers): + return http.HTTPStatus.OK, [], b"OK\n" + + serve(..., process_request=process_request, ...) + + # New implementation + + def process_request(connection, request): + return connection.protocol.reject(http.HTTPStatus.OK, "OK\n") + + serve(..., process_request=process_request, ...) + +``connection`` is always available in ``process_request``. In the original +implementation, you had to write a subclass of +:class:`~server.WebSocketServerProtocol` and pass it in the ``create_protocol`` +argument to make the connection object available in a ``process_request`` +method. This pattern isn't useful anymore; you can replace it with a +``process_request`` function or coroutine. + +``path`` and ``headers`` are available as attributes of the ``request`` object. + +``process_response`` +~~~~~~~~~~~~~~~~~~~~ + +``process_request`` replaces ``extra_headers`` and provides more flexibility. +In the most basic case, you would adapt your code as follows:: + + # Original implementation + + serve(..., extra_headers=HEADERS, ...) + + # New implementation + + def process_response(connection, request, response): + response.headers.update(HEADERS) + return response + + serve(..., process_response=process_response, ...) + +``connection`` is always available in ``process_response``, similar to +``process_request``. In the original implementation, there was no way to make +the connection object available. + +In addition, the ``request`` and ``response`` objects are available, which +enables a broader range of use cases (e.g., logging) and makes +``process_response`` more useful than ``extra_headers``. + +``select_subprotocol`` +~~~~~~~~~~~~~~~~~~~~~~ + +The signature of ``select_subprotocol`` changed. Here's an example:: + + # Original implementation + + def select_subprotocol(client_subprotocols, server_subprotocols): + if "chat" in client_subprotocols: + return "chat" + + # New implementation + + def select_subprotocol(connection, subprotocols): + if "chat" in subprotocols + return "chat" + + serve(..., select_subprotocol=select_subprotocol, ...) + +``connection`` is always available in ``select_subprotocol``. This brings the +same benefits as in ``process_request``. It may remove the need to subclass of +:class:`~server.WebSocketServerProtocol`. + +The ``subprotocols`` argument contains the list of subprotocols offered by the +client. The list of subprotocols supported by the server was removed because +``select_subprotocols`` already knows which subprotocols it may select and under +which conditions. + +Miscellaneous changes +..................... + +The first argument of :func:`~asyncio.server.serve` is called ``handler`` instead +of ``ws_handler``. It's usually passed as a positional argument, making this +change transparent. If you're passing it as a keyword argument, you must update +its name. + +The keyword argument of :func:`~asyncio.server.serve` for customizing the +creation of the connection object is called ``create_connection`` instead of +``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` +instead of a :class:`~server.WebSocketServerProtocol`. If you were customizing +connection objects, you should check the new implementation and possibly redo +your customization. Keep in mind that the changes to ``process_request`` and +``select_subprotocol`` remove most use cases for ``create_connection``. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 00b055dd1..f033f5632 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,8 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _13.0: + 13.0 ---- @@ -66,10 +68,10 @@ New features This new implementation is intended to be a drop-in replacement for the current implementation. It will become the default in a future release. - Please try it and report any issue that you encounter! - See :func:`websockets.asyncio.client.connect` and - :func:`websockets.asyncio.server.serve` for details. + Please try it and report any issue that you encounter! The :doc:`upgrade + guide <../howto/upgrade>` explains everything you need to know about the + upgrade process. * Validated compatibility with Python 3.12 and 3.13. @@ -79,6 +81,8 @@ New features If you were monkey-patching constants, be aware that they were renamed, which will break your configuration. You must switch to the environment variables. +.. _12.0: + 12.0 ---- @@ -135,6 +139,8 @@ Bug fixes * Restored the C extension in the source distribution. +.. _11.0: + 11.0 ---- @@ -211,6 +217,8 @@ Improvements * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. +.. _10.4: + 10.4 ---- @@ -237,6 +245,8 @@ Improvements * Improved FAQ. +.. _10.3: + 10.3 ---- @@ -259,6 +269,8 @@ Improvements * Reduced noise in logs when :mod:`ssl` or :mod:`zlib` raise exceptions. +.. _10.2: + 10.2 ---- @@ -279,6 +291,8 @@ Bug fixes * Avoided leaking open sockets when :func:`~client.connect` is canceled. +.. _10.1: + 10.1 ---- @@ -328,6 +342,8 @@ Bug fixes * Avoided half-closing TCP connections that are already closed. +.. _10.0: + 10.0 ---- @@ -434,6 +450,8 @@ Bug fixes * Avoided a crash when receiving a ping while the connection is closing. +.. _9.1: + 9.1 --- @@ -472,6 +490,8 @@ Bug fixes * Fixed issues with the packaging of the 9.0 release. +.. _9.0: + 9.0 --- @@ -549,6 +569,8 @@ Bug fixes * Ensured cancellation always propagates, even on Python versions where :exc:`~asyncio.CancelledError` inherits :exc:`Exception`. +.. _8.1: + 8.1 --- @@ -583,6 +605,8 @@ Bug fixes * Restored the ability to import ``WebSocketProtocolError`` from ``websockets``. +.. _8.0: + 8.0 --- @@ -692,6 +716,8 @@ Bug fixes * Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. +.. _7.0: + 7.0 --- @@ -786,6 +812,8 @@ Bug fixes :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: canceling it at the wrong time could result in messages being dropped. +.. _6.0: + 6.0 --- @@ -840,6 +868,8 @@ Bug fixes * Fixed a regression in 5.0 that broke some invocations of :func:`~server.serve` and :func:`~client.connect`. +.. _5.0: + 5.0 --- @@ -925,6 +955,8 @@ Bug fixes * Fixed issues with the packaging of the 4.0 release. +.. _4.0: + 4.0 --- @@ -984,6 +1016,8 @@ Bug fixes * Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on a connection while it's being closed. +.. _3.4: + 3.4 --- @@ -1027,6 +1061,8 @@ Bug fixes * Providing a ``sock`` argument to :func:`~client.connect` no longer crashes. +.. _3.3: + 3.3 --- @@ -1047,6 +1083,8 @@ Bug fixes * Avoided crashing on concurrent writes on slow connections. +.. _3.2: + 3.2 --- @@ -1063,6 +1101,8 @@ Improvements * Made server shutdown more robust. +.. _3.1: + 3.1 --- @@ -1078,6 +1118,8 @@ Bug fixes * Avoided a warning when closing a connection before the opening handshake. +.. _3.0: + 3.0 --- @@ -1135,6 +1177,8 @@ Improvements * Improved documentation. +.. _2.7: + 2.7 --- @@ -1150,6 +1194,8 @@ Improvements * Refreshed documentation. +.. _2.6: + 2.6 --- @@ -1167,6 +1213,8 @@ Bug fixes * Avoided TCP fragmentation of small frames. +.. _2.5: + 2.5 --- @@ -1200,6 +1248,8 @@ Bug fixes * Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. +.. _2.4: + 2.4 --- @@ -1213,6 +1263,8 @@ New features * Added ``loop`` argument to :func:`~client.connect` and :func:`~server.serve`. +.. _2.3: + 2.3 --- @@ -1223,6 +1275,8 @@ Improvements * Improved compliance of close codes. +.. _2.2: + 2.2 --- @@ -1233,6 +1287,8 @@ New features * Added support for limiting message size. +.. _2.1: + 2.1 --- @@ -1247,6 +1303,8 @@ New features .. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 +.. _2.0: + 2.0 --- @@ -1275,6 +1333,8 @@ New features * Added flow control for outgoing data. +.. _1.0: + 1.0 --- diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index 5086015b7..f9ce2f2d8 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,5 +1,5 @@ -Client (:mod:`asyncio`) -======================= +Client (legacy :mod:`asyncio`) +============================== .. automodule:: websockets.client diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index dc7a54ee1..aee774479 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (:mod:`asyncio`) -=========================== +Both sides (legacy :mod:`asyncio`) +================================== .. automodule:: websockets.legacy.protocol diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 106317916..4bd52b40b 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,15 +1,15 @@ -Server (:mod:`asyncio`) -======================= +Server (legacy :mod:`asyncio`) +============================== .. automodule:: websockets.server Starting a server ----------------- -.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: -.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Stopping a server @@ -34,7 +34,7 @@ Stopping a server Using a connection ------------------ -.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index 552d83b2f..196bda2b7 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -1,5 +1,5 @@ -Client (:mod:`asyncio` - new) -============================= +Client (new :mod:`asyncio`) +=========================== .. automodule:: websockets.asyncio.client diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index ba23552dc..4fa97dcf2 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (:mod:`asyncio` - new) -================================= +Both sides (new :mod:`asyncio`) +=============================== .. automodule:: websockets.asyncio.connection diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index f3446fb80..c43673d33 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -1,5 +1,5 @@ -Server (:mod:`asyncio` - new) -============================= +Server (new :mod:`asyncio`) +=========================== .. automodule:: websockets.asyncio.server diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index dfa7065e7..a1ba59a37 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -21,8 +21,8 @@ cryptocurrency css ctrl deserialize -django dev +django Dockerfile dyno formatter @@ -44,6 +44,7 @@ linkerd liveness lookups MiB +middleware mutex mypy nginx @@ -77,8 +78,8 @@ uple uvicorn uvloop virtualenv -WebSocket websocket +WebSocket websockets ws wsgi From 8385cf02fccd5e171e1ee5b8949df11773c0f954 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 14 Aug 2024 08:40:46 +0200 Subject: [PATCH 1327/1539] Add write_limit parameter to the new asyncio API. --- docs/howto/upgrade.rst | 84 ++++++++++++++++++---------- src/websockets/asyncio/client.py | 16 ++++++ src/websockets/asyncio/connection.py | 16 +++++- src/websockets/asyncio/messages.py | 21 ++++--- src/websockets/asyncio/server.py | 16 ++++++ tests/asyncio/test_connection.py | 38 ++++++++++++- tests/asyncio/test_messages.py | 35 +++++++----- 7 files changed, 166 insertions(+), 60 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index bb4c59bc4..10e8967d8 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -115,26 +115,6 @@ In other words, the following pattern isn't supported:: async for websocket in connect(...): # this doesn't work yet ... -Configuring buffers -................... - -The new implementation doesn't provide a way to configure read and write buffers -yet. - -In practice, :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` -don't accept the ``max_queue``, ``read_limit``, and ``write_limit`` arguments. - -Here's the most likely outcome: - -* ``max_queue`` will be implemented but its semantics will change from "maximum - number of messages" to "maximum number of frames", which makes a difference - when messages are fragmented. -* ``read_limit`` won't be implemented because the buffer that it configured was - removed from the new implementation. The queue that ``max_queue`` configures - is the only read buffer now. -* ``write_limit`` will be implemented as in the original implementation. - Alternatively, the same functionality could be exposed with a different API. - .. _Update import paths: Import paths @@ -340,18 +320,60 @@ client. The list of subprotocols supported by the server was removed because ``select_subprotocols`` already knows which subprotocols it may select and under which conditions. -Miscellaneous changes -..................... +Arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` +.............................................................................. + +``ws_handler`` → ``handler`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The first argument of :func:`~asyncio.server.serve` is called ``handler`` instead -of ``ws_handler``. It's usually passed as a positional argument, making this -change transparent. If you're passing it as a keyword argument, you must update -its name. +The first argument of :func:`~asyncio.server.serve` is now called ``handler`` +instead of ``ws_handler``. It's usually passed as a positional argument, making +this change transparent. If you're passing it as a keyword argument, you must +update its name. + +``create_protocol`` → ``create_connection`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The keyword argument of :func:`~asyncio.server.serve` for customizing the -creation of the connection object is called ``create_connection`` instead of +creation of the connection object is now called ``create_connection`` instead of ``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` -instead of a :class:`~server.WebSocketServerProtocol`. If you were customizing -connection objects, you should check the new implementation and possibly redo -your customization. Keep in mind that the changes to ``process_request`` and -``select_subprotocol`` remove most use cases for ``create_connection``. +instead of a :class:`~server.WebSocketServerProtocol`. + +If you were customizing connection objects, you should check the new +implementation and possibly redo your customization. Keep in mind that the +changes to ``process_request`` and ``select_subprotocol`` remove most use cases +for ``create_connection``. + +``max_queue`` +~~~~~~~~~~~~~ + +The ``max_queue`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` has a new meaning but achieves a similar effect. + +It is now the high-water mark of a buffer of incoming frames. It defaults to 16 +frames. It used to be the size of a buffer of incoming messages that refilled as +soon as a message was read. It used to default to 32 messages. + +This can make a difference when messages are fragmented in several frames. In +that case, you may want to increase ``max_queue``. If you're writing a high +performance server and you know that you're receiving fragmented messages, +probably you should adopt :meth:`~asyncio.connection.Connection.recv_streaming` +and optimize the performance of reads again. In all other cases, given how +uncommon fragmentation is, you shouldn't worry about this change. + +``read_limit`` +~~~~~~~~~~~~~~ + +The ``read_limit`` argument doesn't exist in the new implementation because it +doesn't buffer data received from the network in a +:class:`~asyncio.StreamReader`. With a better design, this buffer could be +removed. + +The buffer of incoming frames configured by ``max_queue`` is the only read +buffer now. + +``write_limit`` +~~~~~~~~~~~~~~~ + +The ``write_limit`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index ac8ded8ca..b2eaf9a65 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -49,11 +49,15 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol super().__init__( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) self.response_rcvd: asyncio.Future[None] = self.loop.create_future() @@ -146,6 +150,14 @@ class connect: :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. + write_limit: High-water mark of write buffer in bytes. It is passed to + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults + to 32 KiB. You may pass a ``(high, low)`` tuple to set the + high-water and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -199,6 +211,8 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -243,6 +257,8 @@ def factory() -> ClientConnection: connection = create_connection( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) return connection diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 0a3ddb9aa..1c4424f0d 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -48,9 +48,17 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int): + max_queue = (max_queue, None) + self.max_queue = max_queue + if isinstance(write_limit, int): + write_limit = (write_limit, None) + self.write_limit = write_limit # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -803,11 +811,13 @@ def close_transport(self) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: transport = cast(asyncio.Transport, transport) - self.transport = transport self.recv_messages = Assembler( - pause=self.transport.pause_reading, - resume=self.transport.resume_reading, + *self.max_queue, + pause=transport.pause_reading, + resume=transport.resume_reading, ) + transport.set_write_buffer_limits(*self.write_limit) + self.transport = transport def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index bc33df8d7..33ab6a5e9 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -89,6 +89,8 @@ class Assembler: # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, + high: int = 16, + low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: @@ -99,11 +101,16 @@ def __init__( # pragma: no cover # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - self.high = 16 - self.low = 4 - self.paused = False + if low is None: + low = high // 4 + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low self.pause = pause self.resume = resume + self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False @@ -254,14 +261,6 @@ def put(self, frame: Frame) -> None: self.frames.put(frame) self.maybe_pause() - def get_limits(self) -> tuple[int, int]: - """Return low and high water marks for flow control.""" - return self.low, self.high - - def set_limits(self, low: int = 4, high: int = 16) -> None: - """Configure low and high water marks for flow control.""" - self.low, self.high = low, high - def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" # Check for "> high" to support high = 0 diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 0c8b8780b..4feea13c4 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -62,11 +62,15 @@ def __init__( server: WebSocketServer, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol super().__init__( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() @@ -574,6 +578,14 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. + write_limit: High-water mark of write buffer in bytes. It is passed to + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults + to 32 KiB. You may pass a ``(high, low)`` tuple to set the + high-water and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -637,6 +649,8 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -709,6 +723,8 @@ def protocol_select_subprotocol( protocol, self.server, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) return connection diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 2efd4e96d..02029b754 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -4,7 +4,7 @@ import socket import unittest import uuid -from unittest.mock import patch +from unittest.mock import Mock, patch from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * @@ -867,6 +867,42 @@ async def test_pong_explicit_binary(self): await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + # Test parameters. + + async def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS) + self.assertEqual(connection.close_timeout, 42 * MS) + + async def test_max_queue(self): + """max_queue parameter configures high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=4) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, 4) + + async def test_max_queue_tuple(self): + """max_queue parameter configures high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + + async def test_write_limit(self): + """write_limit parameter configures high-water mark of write buffer.""" + connection = Connection(Protocol(self.LOCAL), write_limit=4096) + transport = Mock() + connection.connection_made(transport) + transport.set_write_buffer_limits.assert_called_once_with(4096, None) + + async def test_write_limits(self): + """write_limit parameter configures high and low-water marks of write buffer.""" + connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + transport = Mock() + connection.connection_made(transport) + transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) + # Test attributes. async def test_id(self): diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index c8a2d7cd5..615b1f3a8 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -70,8 +70,7 @@ class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.pause = unittest.mock.Mock() self.resume = unittest.mock.Mock() - self.assembler = Assembler(pause=self.pause, resume=self.resume) - self.assembler.set_limits(low=1, high=2) + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get @@ -455,17 +454,25 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await alist(self.assembler.get_iter()) self.assembler.close() # let task terminate - # Test getting and setting limits + # Test setting limits - async def test_get_limits(self): - """get_limits returns low and high water marks.""" - low, high = self.assembler.get_limits() - self.assertEqual(low, 1) - self.assertEqual(high, 2) + async def test_set_high_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) - async def test_set_limits(self): - """set_limits changes low and high water marks.""" - self.assembler.set_limits(low=2, high=4) - low, high = self.assembler.get_limits() - self.assertEqual(low, 2) - self.assertEqual(high, 4) + async def test_set_high_and_low_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + async def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + async def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) From 5eafbe466b909f21dc7e74b1350583b4d5ae0606 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 15 Aug 2024 16:25:51 +0200 Subject: [PATCH 1328/1539] Rewrite documentation of buffers. Describe all implementations. Also update documentation of compression. --- .gitignore | 2 +- docs/topics/compression.rst | 173 ++++++++++-------- docs/topics/design.rst | 49 ----- docs/topics/memory.rst | 156 +++++++++++++--- experiments/compression/benchmark.py | 74 ++------ experiments/compression/client.py | 18 +- experiments/compression/corpus.py | 52 ++++++ experiments/compression/server.py | 10 +- .../extensions/permessage_deflate.py | 6 +- 9 files changed, 316 insertions(+), 224 deletions(-) create mode 100644 experiments/compression/corpus.py diff --git a/.gitignore b/.gitignore index 324e77069..d8e6697a8 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,7 @@ .tox build/ compliance/reports/ -experiments/compression/corpus.pkl +experiments/compression/corpus/ dist/ docs/_build/ htmlcov/ diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index eaf99070d..be263e56f 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -7,37 +7,36 @@ Most WebSocket servers exchange JSON messages because they're convenient to parse and serialize in a browser. These messages contain text data and tend to be repetitive. -This makes the stream of messages highly compressible. Enabling compression +This makes the stream of messages highly compressible. Compressing messages can reduce network traffic by more than 80%. -There's a standard for compressing messages. :rfc:`7692` defines WebSocket -Per-Message Deflate, a compression extension based on the Deflate_ algorithm. +websockets implements WebSocket Per-Message Deflate, a compression extension +based on the Deflate_ algorithm specified in :rfc:`7692`. .. _Deflate: https://en.wikipedia.org/wiki/Deflate -Configuring compression ------------------------ +:func:`~websockets.asyncio.client.connect` and +:func:`~websockets.asyncio.server.serve` enable compression by default because +the reduction in network bandwidth is usually worth the additional memory and +CPU cost. -:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable -compression by default because the reduction in network bandwidth is usually -worth the additional memory and CPU cost. -If you want to disable compression, set ``compression=None``:: +Configuring compression +----------------------- - import websockets +To disable compression, set ``compression=None``:: - websockets.connect(..., compression=None) + connect(..., compression=None, ...) - websockets.serve(..., compression=None) + serve(..., compression=None, ...) -If you want to customize compression settings, you can enable the Per-Message -Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or +To customize compression settings, enable the Per-Message Deflate extension +explicitly with :class:`ClientPerMessageDeflateFactory` or :class:`ServerPerMessageDeflateFactory`:: - import websockets from websockets.extensions import permessage_deflate - websockets.connect( + connect( ..., extensions=[ permessage_deflate.ClientPerMessageDeflateFactory( @@ -46,9 +45,10 @@ Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], + ..., ) - websockets.serve( + serve( ..., extensions=[ permessage_deflate.ServerPerMessageDeflateFactory( @@ -57,13 +57,14 @@ Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], + ..., ) The Window Bits and Memory Level values in these examples reduce memory usage at the expense of compression rate. -Compression settings --------------------- +Compression parameters +---------------------- When a client and a server enable the Per-Message Deflate extension, they negotiate two parameters to guarantee compatibility between compression and @@ -81,9 +82,9 @@ and memory usage for both sides. This requires retaining the compression context and state between messages, which increases the memory footprint of a connection. -* **Window Bits** controls the size of the compression context. It must be - an integer between 9 (lowest memory usage) and 15 (best compression). - Setting it to 8 is possible but rejected by some versions of zlib. +* **Window Bits** controls the size of the compression context. It must be an + integer between 9 (lowest memory usage) and 15 (best compression). Setting it + to 8 is possible but rejected by some versions of zlib and not very useful. On the server side, websockets defaults to 12. Specifically, the compression window size (server to client) is always 12 while the decompression window @@ -94,9 +95,8 @@ and memory usage for both sides. has the same effect as defaulting to 15. :mod:`zlib` offers additional parameters for tuning compression. They control -the trade-off between compression rate, memory usage, and CPU usage only for -compressing. They're transparent for decompressing. Unless mentioned -otherwise, websockets inherits defaults of :func:`~zlib.compressobj`. +the trade-off between compression rate, memory usage, and CPU usage for +compressing. They're transparent for decompressing. * **Memory Level** controls the size of the compression state. It must be an integer between 1 (lowest memory usage) and 9 (best compression). @@ -108,87 +108,82 @@ otherwise, websockets inherits defaults of :func:`~zlib.compressobj`. * **Compression Level** controls the effort to optimize compression. It must be an integer between 1 (lowest CPU usage) and 9 (best compression). + websockets relies on the default value chosen by :func:`~zlib.compressobj`, + ``Z_DEFAULT_COMPRESSION``. + * **Strategy** selects the compression strategy. The best choice depends on the type of data being compressed. + websockets relies on the default value chosen by :func:`~zlib.compressobj`, + ``Z_DEFAULT_STRATEGY``. -Tuning compression ------------------- +To customize these parameters, add keyword arguments for +:func:`~zlib.compressobj` in ``compress_settings``. -For servers -........... +Default settings for servers +---------------------------- By default, websockets enables compression with conservative settings that optimize memory usage at the cost of a slightly worse compression rate: -Window Bits = 12 and Memory Level = 5. This strikes a good balance for small +Window Bits = 12 and Memory Level = 5. This strikes a good balance for small messages that are typical of WebSocket servers. -Here's how various compression settings affect memory usage of a single -connection on a 64-bit system, as well a benchmark of compressed size and -compression time for a corpus of small JSON documents. +Here's an example of how compression settings affect memory usage per +connection, compressed size, and compression time for a corpus of JSON +documents. =========== ============ ============ ================ ================ Window Bits Memory Level Memory usage Size vs. default Time vs. default =========== ============ ============ ================ ================ -15 8 322 KiB -4.0% +15% -14 7 178 KiB -2.6% +10% -13 6 106 KiB -1.4% +5% -**12** **5** **70 KiB** **=** **=** -11 4 52 KiB +3.7% -5% -10 3 43 KiB +90% +50% -9 2 39 KiB +160% +100% -— — 19 KiB +452% — +15 8 316 KiB -10% +10% +14 7 172 KiB -7% +5% +13 6 100 KiB -3% +2% +**12** **5** **64 KiB** **=** **=** +11 4 46 KiB +10% +4% +10 3 37 KiB +70% +40% +9 2 33 KiB +130% +90% +— — 14 KiB +350% — =========== ============ ============ ================ ================ Window Bits and Memory Level don't have to move in lockstep. However, other combinations don't yield significantly better results than those shown above. -Compressed size and compression time depend heavily on the kind of messages -exchanged by the application so this example may not apply to your use case. - -You can adapt `compression/benchmark.py`_ by creating a list of typical -messages and passing it to the ``_run`` function. - -Window Bits = 11 and Memory Level = 4 looks like the sweet spot in this table. - -websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from -Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts -on what could happen at Window Bits = 11 and Memory Level = 4 on a different +websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from +Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts +on what could happen at Window Bits = 11 and Memory Level = 4 on a different corpus. Defaults must be safe for all applications, hence a more conservative choice. -.. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py +Optimizing settings +------------------- -The benchmark focuses on compression because it's more expensive than -decompression. Indeed, leaving aside small allocations, theoretical memory -usage is: +Compressed size and compression time depend on the structure of messages +exchanged by your application. As a consequence, default settings may not be +optimal for your use case. -* ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; -* ``1 << windowBits`` for decompression. +To compare how various compression settings perform for your use case: -CPU usage is also higher for compression than decompression. +1. Create a corpus of typical messages in a directory, one message per file. +2. Run the `compression/benchmark.py`_ script, passing the directory in + argument. -While it's always possible for a server to use a smaller window size for -compressing outgoing messages, using a smaller window size for decompressing -incoming messages requires collaboration from clients. +The script measures compressed size and compression time for all combinations of +Window Bits and Memory Level. It outputs two tables with absolute values and two +tables with values relative to websockets' default settings. -When a client doesn't support configuring the size of its compression window, -websockets enables compression with the largest possible decompression window. -In most use cases, this is more efficient than disabling compression both ways. +Pick your favorite settings in these tables and configure them as shown above. -If you are very sensitive to memory usage, you can reverse this behavior by -setting the ``require_client_max_window_bits`` parameter of -:class:`ServerPerMessageDeflateFactory` to ``True``. +.. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py -For clients -........... +Default settings for clients +---------------------------- -By default, websockets enables compression with Memory Level = 5 but leaves +By default, websockets enables compression with Memory Level = 5 but leaves the Window Bits setting up to the server. -There's two good reasons and one bad reason for not optimizing the client side -like the server side: +There's two good reasons and one bad reason for not optimizing Window Bits on +the client side as on the server side: 1. If the maintainers of a server configured some optimized settings, we don't want to override them with more restrictive settings. @@ -196,8 +191,9 @@ like the server side: 2. Optimizing memory usage doesn't matter very much for clients because it's uncommon to open thousands of client connections in a program. -3. On a more pragmatic note, some servers misbehave badly when a client - configures compression settings. `AWS API Gateway`_ is the worst offender. +3. On a more pragmatic and annoying note, some servers misbehave badly when a + client configures compression settings. `AWS API Gateway`_ is the worst + offender. .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 @@ -207,6 +203,29 @@ like the server side: Until the ecosystem levels up, interoperability with buggy servers seems more valuable than optimizing memory usage. +Decompression +------------- + +The discussion above focuses on compression because it's more expensive than +decompression. Indeed, leaving aside small allocations, theoretical memory +usage is: + +* ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; +* ``1 << windowBits`` for decompression. + +CPU usage is also higher for compression than decompression. + +While it's always possible for a server to use a smaller window size for +compressing outgoing messages, using a smaller window size for decompressing +incoming messages requires collaboration from clients. + +When a client doesn't support configuring the size of its compression window, +websockets enables compression with the largest possible decompression window. +In most use cases, this is more efficient than disabling compression both ways. + +If you are very sensitive to memory usage, you can reverse this behavior by +setting the ``require_client_max_window_bits`` parameter of +:class:`ServerPerMessageDeflateFactory` to ``True``. Further reading --------------- @@ -216,7 +235,7 @@ settings affect memory usage and how to optimize them. .. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ -This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory -Level = 4 for optimizing memory usage. +This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory +Level = 4 for optimizing memory usage. .. _experiment by Peter Thorson: https://mailarchive.ietf.org/arch/msg/hybi/F9t4uPufVEy8KBLuL36cZjCmM_Y/ diff --git a/docs/topics/design.rst b/docs/topics/design.rst index f164d2990..cc65e6a70 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -488,55 +488,6 @@ they're drained. That's why all APIs that write frames are asynchronous. Of course, it's still possible for an application to create its own unbounded buffers and break the backpressure. Be careful with queues. - -.. _buffers: - -Buffers -------- - -.. note:: - - This section discusses buffers from the perspective of a server but it - applies to clients as well. - -An asynchronous systems works best when its buffers are almost always empty. - -For example, if a client sends data too fast for a server, the queue of -incoming messages will be constantly full. The server will always be 32 -messages (by default) behind the client. This consumes memory and increases -latency for no good reason. The problem is called bufferbloat. - -If buffers are almost always full and that problem cannot be solved by adding -capacity — typically because the system is bottlenecked by the output and -constantly regulated by backpressure — reducing the size of buffers minimizes -negative consequences. - -By default websockets has rather high limits. You can decrease them according -to your application's characteristics. - -Bufferbloat can happen at every level in the stack where there is a buffer. -For each connection, the receiving side contains these buffers: - -- OS buffers: tuning them is an advanced optimization. -- :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64 KiB. - You can set another limit by passing a ``read_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve`. -- Incoming messages :class:`~collections.deque`: its size depends both on - the size and the number of messages it contains. By default the maximum - UTF-8 encoded size is 1 MiB and the maximum number is 32. In the worst case, - after UTF-8 decoding, a single message could take up to 4 MiB of memory and - the overall memory consumption could reach 128 MiB. You should adjust these - limits by setting the ``max_size`` and ``max_queue`` keyword arguments of - :func:`~client.connect()` or :func:`~server.serve` according to your - application's requirements. - -For each connection, the sending side contains these buffers: - -- :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64 KiB. - You can set another limit by passing a ``write_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve`. -- OS buffers: tuning them is an advanced optimization. - Concurrency ----------- diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index e44247a77..efbcbb83f 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -1,5 +1,5 @@ -Memory usage -============ +Memory and buffers +================== .. currentmodule:: websockets @@ -9,40 +9,148 @@ memory usage can become a bottleneck. Memory usage of a single connection is the sum of: -1. the baseline amount of memory websockets requires for each connection, -2. the amount of data held in buffers before the application processes it, -3. any additional memory allocated by the application itself. +1. the baseline amount of memory that websockets uses for each connection; +2. the amount of memory needed by your application code; +3. the amount of data held in buffers. -Baseline --------- +Connection +---------- -Compression settings are the main factor affecting the baseline amount of -memory used by each connection. +Compression settings are the primary factor affecting how much memory each +connection uses. -With websockets' defaults, on the server side, a single connections uses -70 KiB of memory. +The :mod:`asyncio` implementation with default settings uses 64 KiB of memory +for each connection. + +You can reduce memory usage to 14 KiB per connection if you disable compression +entirely. Refer to the :doc:`topic guide on compression <../topics/compression>` to learn more about tuning compression settings. +Application +----------- + +Your application will allocate memory for its data structures. Memory usage +depends on your use case and your implementation. + +Make sure that you don't keep references to data that you don't need anymore +because this prevents garbage collection. + Buffers ------- -Under normal circumstances, buffers are almost always empty. +Typical WebSocket applications exchange small messages at a rate that doesn't +saturate the CPU or the network. Buffers are almost always empty. This is the +optimal situation. Buffers absorb bursts of incoming or outgoing messages +without having to pause reading or writing. + +If the application receives messages faster than it can process them, receive +buffers will fill up when. If the application sends messages faster than the +network can transmit them, send buffers will fill up. + +When buffers are almost always full, not only does the additional memory usage +fail to bring any benefit, but latency degrades as well. This problem is called +bufferbloat_. If it cannot be resolved by adding capacity, typically because the +system is bottlenecked by its output and constantly regulated by +:ref:`backpressure `, then buffers should be kept small to ensure +that backpressure kicks in quickly. + +.. _bufferbloat: https://en.wikipedia.org/wiki/Bufferbloat + +To sum up, buffers should be sized to absorb bursts of messages. Making them +larger than necessary often causes more harm than good. + +There are three levels of buffering in an application built with websockets. + +TCP buffers +........... + +The operating system allocates buffers for each TCP connection. The receive +buffer stores data received from the network until the application reads it. +The send buffer stores data written by the application until it's sent to +the network and acknowledged by the recipient. + +Modern operating systems adjust the size of TCP buffers automatically to match +network conditions. Overall, you shouldn't worry about TCP buffers. Just be +aware that they exist. + +In very high throughput scenarios, TCP buffers may grow to several megabytes +to store the data in flight. Then, they can make up the bulk of the memory +usage of a connection. + +I/O library buffers +................... + +I/O libraries like :mod:`asyncio` may provide read and write buffers to reduce +the frequency of system calls or the need to pause reading or writing. + +You should keep these buffers small. Increasing them can help with spiky +workloads but it can also backfire because it delays backpressure. + +* In the new :mod:`asyncio` implementation, there is no library-level read + buffer. + + There is a write buffer. The ``write_limit`` argument of + :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` controls its + size. When the write buffer grows above the high-water mark, + :meth:`~asyncio.connection.Connection.send` waits until it drains under the + low-water mark to return. This creates backpressure on coroutines that send + messages. + +* In the legacy :mod:`asyncio` implementation, there is a library-level read + buffer. The ``read_limit`` argument of :func:`~client.connect` and + :func:`~server.serve` controls its size. When the read buffer grows above the + high-water mark, the connection stops reading from the network until it drains + under the low-water mark. This creates backpressure on the TCP connection. + + There is a write buffer. It as controlled by ``write_limit``. It behaves like + the new :mod:`asyncio` implementation described above. + +* In the :mod:`threading` implementation, there are no library-level buffers. + All I/O operations are performed directly on the :class:`~socket.socket`. + +websockets' buffers +................... + +Incoming messages are queued in a buffer after they have been received from the +network and parsed. A larger buffer may help a slow applications handle bursts +of messages while remaining responsive to control frames. + +The memory footprint of this buffer is bounded by the product of ``max_size``, +which controls the size of items in the queue, and ``max_queue``, which controls +the number of items. + +The ``max_size`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` defaults to 1 MiB. Most applications never receive +such large messages. Configuring a smaller value puts a tighter boundary on +memory usage. This can make your application more resilient to denial of service +attacks. + +The behavior of the ``max_queue`` argument of :func:`~asyncio.client.connect` +and :func:`~asyncio.server.serve` varies across implementations. -Under high load, if a server receives more messages than it can process, -bufferbloat can result in excessive memory usage. +* In the new :mod:`asyncio` implementation, ``max_queue`` is the high-water mark + of a queue of incoming frames. It defaults to 16 frames. If the queue grows + larger, the connection stops reading from the network until the application + consumes messages and the queue goes below the low-water mark. This creates + backpressure on the TCP connection. -By default websockets has generous limits. It is strongly recommended to adapt -them to your application. When you call :func:`~server.serve`: + Each item in the queue is a frame. A frame can be a message or a message + fragment. Either way, it must be smaller than ``max_size``, the maximum size + of a message. The queue may use up to ``max_size * max_queue`` bytes of + memory. By default, this is 16 MiB. -- Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of - messages your application generates. -- Set ``max_queue`` (default: 32) to the maximum number of messages your - application expects to receive faster than it can process them. The queue - provides burst tolerance without slowing down the TCP connection. +* In the legacy :mod:`asyncio` implementation, ``max_queue`` is the maximum + size of a queue of incoming messages. It defaults to 32 messages. If the queue + fills up, the connection stops reading from the library-level read buffer + described above. If that buffer fills up as well, it will create backpressure + on the TCP connection. -Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: -64 KiB) to reduce the size of buffers for incoming and outgoing data. + Text messages are decoded before they're added to the queue. Since Python can + use up to 4 bytes of memory per character, the queue may use up to ``4 * + max_size * max_queue`` bytes of memory. By default, this is 128 MiB. -The design document provides :ref:`more details about buffers `. +* In the :mod:`threading` implementation, there is no queue of incoming + messages. The ``max_queue`` argument doesn't exist. The connection keeps at + most one message in memory at a time. diff --git a/experiments/compression/benchmark.py b/experiments/compression/benchmark.py index 4fbdf6220..86ebece31 100644 --- a/experiments/compression/benchmark.py +++ b/experiments/compression/benchmark.py @@ -1,72 +1,32 @@ #!/usr/bin/env python -import getpass -import json -import pickle -import subprocess +import collections +import pathlib import sys import time import zlib -CORPUS_FILE = "corpus.pkl" - REPEAT = 10 WB, ML = 12, 5 # defaults used as a reference -def _corpus(): - OAUTH_TOKEN = getpass.getpass("OAuth Token? ") - COMMIT_API = ( - f'curl -H "Authorization: token {OAUTH_TOKEN}" ' - f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha" - ) - - commits = [] - - head = subprocess.check_output("git rev-parse HEAD", shell=True).decode().strip() - todo = [head] - seen = set() - - while todo: - sha = todo.pop(0) - commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True) - commits.append(commit) - seen.add(sha) - for parent in json.loads(commit)["parents"]: - sha = parent["sha"] - if sha not in seen and sha not in todo: - todo.append(sha) - time.sleep(1) # rate throttling - - return commits - - -def corpus(): - data = _corpus() - with open(CORPUS_FILE, "wb") as handle: - pickle.dump(data, handle) - - -def _run(data): - size = {} - duration = {} +def benchmark(data): + size = collections.defaultdict(dict) + duration = collections.defaultdict(dict) for wbits in range(9, 16): - size[wbits] = {} - duration[wbits] = {} - for memLevel in range(1, 10): encoder = zlib.compressobj(wbits=-wbits, memLevel=memLevel) encoded = [] + print(f"Compressing {REPEAT} times with {wbits=} and {memLevel=}") + t0 = time.perf_counter() for _ in range(REPEAT): for item in data: - if isinstance(item, str): - item = item.encode() # Taken from PerMessageDeflate.encode item = encoder.compress(item) + encoder.flush(zlib.Z_SYNC_FLUSH) if item.endswith(b"\x00\x00\xff\xff"): @@ -75,7 +35,7 @@ def _run(data): t1 = time.perf_counter() - size[wbits][memLevel] = sum(len(item) for item in encoded) + size[wbits][memLevel] = sum(len(item) for item in encoded) / REPEAT duration[wbits][memLevel] = (t1 - t0) / REPEAT raw_size = sum(len(item) for item in data) @@ -149,15 +109,13 @@ def _run(data): print() -def run(): - with open(CORPUS_FILE, "rb") as handle: - data = pickle.load(handle) - _run(data) +def main(corpus): + data = [file.read_bytes() for file in corpus.iterdir()] + benchmark(data) -try: - run = globals()[sys.argv[1]] -except (KeyError, IndexError): - print(f"Usage: {sys.argv[0]} [corpus|run]") -else: - run() +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} [directory]") + sys.exit(2) + main(pathlib.Path(sys.argv[1])) diff --git a/experiments/compression/client.py b/experiments/compression/client.py index 3ee19ddc5..69bfd5e7c 100644 --- a/experiments/compression/client.py +++ b/experiments/compression/client.py @@ -4,8 +4,8 @@ import statistics import tracemalloc -import websockets -from websockets.extensions import permessage_deflate +from websockets.asyncio.client import connect +from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory CLIENTS = 20 @@ -16,16 +16,16 @@ MEM_SIZE = [] -async def client(client): +async def client(num): # Space out connections to make them sequential. - await asyncio.sleep(client * INTERVAL) + await asyncio.sleep(num * INTERVAL) tracemalloc.start() - async with websockets.connect( + async with connect( "ws://localhost:8765", extensions=[ - permessage_deflate.ClientPerMessageDeflateFactory( + ClientPerMessageDeflateFactory( server_max_window_bits=WB, client_max_window_bits=WB, compress_settings={"memLevel": ML}, @@ -42,11 +42,13 @@ async def client(client): tracemalloc.stop() # Hold connection open until the end of the test. - await asyncio.sleep(CLIENTS * INTERVAL) + await asyncio.sleep((CLIENTS + 1 - num) * INTERVAL) async def clients(): - await asyncio.gather(*[client(client) for client in range(CLIENTS + 1)]) + # Start one more client than necessary because we will ignore + # non-representative results from the first connection. + await asyncio.gather(*[client(num) for num in range(CLIENTS + 1)]) asyncio.run(clients()) diff --git a/experiments/compression/corpus.py b/experiments/compression/corpus.py new file mode 100644 index 000000000..da5661dfa --- /dev/null +++ b/experiments/compression/corpus.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +import getpass +import json +import pathlib +import subprocess +import sys +import time + + +def github_commits(): + OAUTH_TOKEN = getpass.getpass("OAuth Token? ") + COMMIT_API = ( + f'curl -H "Authorization: token {OAUTH_TOKEN}" ' + f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha" + ) + + commits = [] + + head = subprocess.check_output( + "git rev-parse origin/main", + shell=True, + text=True, + ).strip() + todo = [head] + seen = set() + + while todo: + sha = todo.pop(0) + commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True) + commits.append(commit) + seen.add(sha) + for parent in json.loads(commit)["parents"]: + sha = parent["sha"] + if sha not in seen and sha not in todo: + todo.append(sha) + time.sleep(1) # rate throttling + + return commits + + +def main(corpus): + data = github_commits() + for num, content in enumerate(reversed(data)): + (corpus / f"{num:04d}.json").write_bytes(content) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} [directory]") + sys.exit(2) + main(pathlib.Path(sys.argv[1])) diff --git a/experiments/compression/server.py b/experiments/compression/server.py index 8d1ee3cd7..1c28f7355 100644 --- a/experiments/compression/server.py +++ b/experiments/compression/server.py @@ -6,8 +6,8 @@ import statistics import tracemalloc -import websockets -from websockets.extensions import permessage_deflate +from websockets.asyncio.server import serve +from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory CLIENTS = 20 @@ -44,12 +44,12 @@ async def server(): print() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( handler, "localhost", 8765, extensions=[ - permessage_deflate.ServerPerMessageDeflateFactory( + ServerPerMessageDeflateFactory( server_max_window_bits=WB, client_max_window_bits=WB, compress_settings={"memLevel": ML}, @@ -63,7 +63,7 @@ async def server(): asyncio.run(server()) -# First connection may incur non-representative setup costs. +# First connection incurs non-representative setup costs. del MEM_SIZE[0] print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB") diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index fea14131e..5b907b79f 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -62,7 +62,8 @@ def __init__( if not self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, **self.compress_settings + wbits=-self.local_max_window_bits, + **self.compress_settings, ) # To handle continuation frames properly, we must keep track of @@ -156,7 +157,8 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, **self.compress_settings + wbits=-self.local_max_window_bits, + **self.compress_settings, ) # Compress data. From c3b162d05c3788b9367eb3cce8c5001c37a3e6fa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 08:53:45 +0200 Subject: [PATCH 1329/1539] Add broadcast to the new asyncio implementation. --- docs/faq/server.rst | 2 +- docs/howto/upgrade.rst | 10 +- docs/intro/tutorial2.rst | 16 +-- docs/project/changelog.rst | 4 +- docs/reference/asyncio/server.rst | 2 +- docs/reference/new-asyncio/server.rst | 5 + docs/topics/broadcast.rst | 69 ++++++------ docs/topics/logging.rst | 2 +- docs/topics/performance.rst | 6 +- experiments/broadcast/server.py | 21 ++-- src/websockets/asyncio/connection.py | 100 ++++++++++++++++- src/websockets/legacy/protocol.py | 14 +-- src/websockets/sync/connection.py | 2 +- tests/asyncio/test_connection.py | 156 ++++++++++++++++++++++++++ tests/legacy/test_protocol.py | 12 +- 15 files changed, 341 insertions(+), 80 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index cba1cd35f..53e34632f 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -102,7 +102,7 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.broadcast`:: +Then, call :func:`~asyncio.connection.broadcast`:: import websockets diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 10e8967d8..6efaf0f56 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -70,12 +70,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -Broadcast -......... - -The new implementation doesn't support :doc:`broadcasting messages -<../topics/broadcast>` yet. - Keepalive ......... @@ -178,8 +172,8 @@ Server APIs | :class:`websockets.server.WebSocketServerProtocol` |br| | | | ``websockets.legacy.server.WebSocketServerProtocol`` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| :func:`websockets.broadcast` |br| | *not available yet* | -| ``websockets.legacy.protocol.broadcast()`` | | +| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | +| :func:`websockets.legacy.protocol.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | | :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 5ac4ae9dd..b8e35f292 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -482,7 +482,7 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`broadcast` helper for this purpose: +the :func:`~legacy.protocol.broadcast` helper for this purpose: .. code-block:: python @@ -494,13 +494,14 @@ the :func:`broadcast` helper for this purpose: ... -Calling :func:`broadcast` once is more efficient than +Calling :func:`legacy.protocol.broadcast` once is more efficient than calling :meth:`~legacy.protocol.WebSocketCommonProtocol.send` in a loop. -However, there's a subtle difference in behavior. Did you notice that there's -no ``await`` in the second version? Indeed, :func:`broadcast` is a function, -not a coroutine like :meth:`~legacy.protocol.WebSocketCommonProtocol.send` -or :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. +However, there's a subtle difference in behavior. Did you notice that there's no +``await`` in the second version? Indeed, :func:`legacy.protocol.broadcast` is a +function, not a coroutine like +:meth:`~legacy.protocol.WebSocketCommonProtocol.send` or +:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. It's quite obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` is a coroutine. When you want to receive the next message, you have to wait @@ -521,7 +522,8 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`broadcast` doesn't wait until write buffers drain. +That's why :func:`legacy.protocol.broadcast` doesn't wait until write buffers +drain. For our Connect Four game, there's no difference in practice: the total amount of data sent on a connection for a game of Connect Four is less than 64 KB, diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f033f5632..eaabb2e9f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -212,7 +212,7 @@ Improvements * Added platform-independent wheels. -* Improved error handling in :func:`~websockets.broadcast`. +* Improved error handling in :func:`~legacy.protocol.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. @@ -402,7 +402,7 @@ New features * Added compatibility with Python 3.10. -* Added :func:`~websockets.broadcast` to send a message to many clients. +* Added :func:`~legacy.protocol.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using :func:`~client.connect` as an asynchronous iterator. diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 4bd52b40b..3636f0b33 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -110,4 +110,4 @@ websockets supports HTTP Basic Authentication according to Broadcast --------- -.. autofunction:: websockets.broadcast +.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index c43673d33..7f9de6148 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -70,3 +70,8 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + +Broadcast +--------- + +.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index b6ddda734..671319136 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -1,21 +1,22 @@ -Broadcasting messages -===================== +Broadcasting +============ .. currentmodule:: websockets - -.. admonition:: If you just want to send a message to all connected clients, - use :func:`broadcast`. +.. admonition:: If you want to send a message to all connected clients, + use :func:`~asyncio.connection.broadcast`. :class: tip - If you want to learn about its design in depth, continue reading this - document. + If you want to learn about its design, continue reading this document. + + For the legacy :mod:`asyncio` implementation, use + :func:`~legacy.protocol.broadcast`. WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. -Let's explore options for broadcasting a message, explain the design -of :func:`broadcast`, and discuss alternatives. +Let's explore options for broadcasting a message, explain the design of +:func:`~asyncio.connection.broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all @@ -24,7 +25,7 @@ connected clients. Integrating them is left as an exercise for the reader. You could start with:: import asyncio - import websockets + from websockets.asyncio.server import serve async def handler(websocket): ... @@ -39,7 +40,7 @@ Integrating them is left as an exercise for the reader. You could start with:: await broadcast(message) async def main(): - async with websockets.serve(handler, "localhost", 8765): + async with serve(handler, "localhost", 8765): await broadcast_messages() # runs forever if __name__ == "__main__": @@ -82,11 +83,13 @@ to:: Here's a coroutine that broadcasts a message to all clients:: + from websockets import ConnectionClosed + async def broadcast(message): for websocket in CLIENTS.copy(): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass There are two tricks in this version of ``broadcast()``. @@ -117,11 +120,11 @@ which is usually outside of the control of the server. If you know for sure that you will never write more than ``write_limit`` bytes within ``ping_interval + ping_timeout``, then websockets will terminate slow -connections before the write buffer has time to fill up. +connections before the write buffer can fill up. -Don't set extreme ``write_limit``, ``ping_interval``, and ``ping_timeout`` -values to ensure that this condition holds. Set reasonable values and use the -built-in :func:`broadcast` function instead. +Don't set extreme values of ``write_limit``, ``ping_interval``, or +``ping_timeout`` to ensure that this condition holds! Instead, set reasonable +values and use the built-in :func:`~asyncio.connection.broadcast` function. The concurrent way ------------------ @@ -134,7 +137,7 @@ Let's modify ``broadcast()`` to send messages concurrently:: async def send(websocket, message): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass def broadcast(message): @@ -179,20 +182,20 @@ doesn't work well when broadcasting a message to thousands of clients. When you're sending messages to a single client, you don't want to send them faster than the network can transfer them and the client accept them. This is -why :meth:`~server.WebSocketServerProtocol.send` checks if the write buffer -is full and, if it is, waits until it drain, giving the network and the -client time to catch up. This provides backpressure. +why :meth:`~asyncio.server.ServerConnection.send` checks if the write buffer is +above the high-water mark and, if it is, waits until it drains, giving the +network and the client time to catch up. This provides backpressure. Without backpressure, you could pile up data in the write buffer until the server process runs out of memory and the operating system kills it. -The :meth:`~server.WebSocketServerProtocol.send` API is designed to enforce +The :meth:`~asyncio.server.ServerConnection.send` API is designed to enforce backpressure by default. This helps users of websockets write robust programs even if they never heard about backpressure. For comparison, :class:`asyncio.StreamWriter` requires users to understand -backpressure and to await :meth:`~asyncio.StreamWriter.drain` explicitly -after each :meth:`~asyncio.StreamWriter.write`. +backpressure and to await :meth:`~asyncio.StreamWriter.drain` after each +:meth:`~asyncio.StreamWriter.write` — or at least sufficiently frequently. When broadcasting messages, backpressure consists in slowing down all clients in an attempt to let the slowest client catch up. With thousands of clients, @@ -203,14 +206,14 @@ How do we avoid running out of memory when slow clients can't keep up with the broadcast rate, then? The most straightforward option is to disconnect them. If a client gets too far behind, eventually it reaches the limit defined by -``ping_timeout`` and websockets terminates the connection. You can read the -discussion of :doc:`keepalive and timeouts <./timeouts>` for details. +``ping_timeout`` and websockets terminates the connection. You can refer to +the discussion of :doc:`keepalive and timeouts ` for details. -How :func:`broadcast` works ---------------------------- +How :func:`~asyncio.connection.broadcast` works +----------------------------------------------- -The built-in :func:`broadcast` function is similar to the naive way. The main -difference is that it doesn't apply backpressure. +The built-in :func:`~asyncio.connection.broadcast` function is similar to the +naive way. The main difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. @@ -321,9 +324,9 @@ the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- -The built-in :func:`broadcast` function sends all messages without yielding -control to the event loop. So does the naive way when the network and clients -are fast and reliable. +The built-in :func:`~asyncio.connection.broadcast` function sends all messages +without yielding control to the event loop. So does the naive way when the +network and clients are fast and reliable. For each client, a WebSocket frame is prepared and sent to the network. This is the minimum amount of work required to broadcast a message. @@ -343,7 +346,7 @@ However, this isn't possible in general for two reasons: All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower -than the built-in :func:`broadcast` function. +than the built-in :func:`~asyncio.connection.broadcast` function. There is no major difference between the performance of per-client queues and publish–subscribe. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 873c852c2..765278360 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -220,7 +220,7 @@ Here's what websockets logs at each level. ``WARNING`` ........... -* Failures in :func:`~websockets.broadcast` +* Failures in :func:`~asyncio.connection.broadcast` ``INFO`` ........ diff --git a/docs/topics/performance.rst b/docs/topics/performance.rst index 45e23b239..b226cec43 100644 --- a/docs/topics/performance.rst +++ b/docs/topics/performance.rst @@ -1,6 +1,8 @@ Performance =========== +.. currentmodule:: websockets + Here are tips to optimize performance. uvloop @@ -16,5 +18,5 @@ application.) broadcast --------- -:func:`~websockets.broadcast` is the most efficient way to send a message to -many clients. +:func:`~asyncio.connection.broadcast` is the most efficient way to send a +message to many clients. diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index b0407ba34..0a5c82b3c 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -6,7 +6,9 @@ import sys import time -import websockets +from websockets import ConnectionClosed +from websockets.asyncio.server import serve +from websockets.asyncio.connection import broadcast CLIENTS = set() @@ -15,7 +17,7 @@ async def send(websocket, message): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass @@ -43,9 +45,6 @@ async def subscribe(self): __aiter__ = subscribe -PUBSUB = PubSub() - - async def handler(websocket, method=None): if method in ["default", "naive", "task", "wait"]: CLIENTS.add(websocket) @@ -63,14 +62,18 @@ async def handler(websocket, method=None): CLIENTS.remove(queue) relay_task.cancel() elif method == "pubsub": + global PUBSUB async for message in PUBSUB: await websocket.send(message) else: raise NotImplementedError(f"unsupported method: {method}") -async def broadcast(method, size, delay): +async def broadcast_messages(method, size, delay): """Broadcast messages at regular intervals.""" + if method == "pubsub": + global PUBSUB + PUBSUB = PubSub() load_average = 0 time_average = 0 pc1, pt1 = time.perf_counter_ns(), time.process_time_ns() @@ -90,7 +93,7 @@ async def broadcast(method, size, delay): message = str(time.time_ns()).encode() + b" " + os.urandom(size - 20) if method == "default": - websockets.broadcast(CLIENTS, message) + broadcast(CLIENTS, message) elif method == "naive": # Since the loop can yield control, make a copy of CLIENTS # to avoid: RuntimeError: Set changed size during iteration @@ -128,14 +131,14 @@ async def broadcast(method, size, delay): async def main(method, size, delay): - async with websockets.serve( + async with serve( functools.partial(handler, method=method), "localhost", 8765, compression=None, ping_timeout=None, ): - await broadcast(method, size, delay) + await broadcast_messages(method, size, delay) if __name__ == "__main__": diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 1c4424f0d..9d2f087da 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -6,6 +6,7 @@ import logging import random import struct +import sys import uuid from types import TracebackType from typing import ( @@ -27,7 +28,7 @@ from .messages import Assembler -__all__ = ["Connection"] +__all__ = ["Connection", "broadcast"] class Connection(asyncio.Protocol): @@ -338,7 +339,6 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If the connection busy sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -488,7 +488,7 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No self.fragmented_send_waiter = None else: - raise TypeError("data must be bytes, str, iterable, or async iterable") + raise TypeError("data must be str, bytes, iterable, or async iterable") async def close(self, code: int = 1000, reason: str = "") -> None: """ @@ -673,7 +673,7 @@ async def send_context( On entry, :meth:`send_context` checks that the connection is open; on exit, it writes outgoing data to the socket:: - async async with self.send_context(): + async with self.send_context(): self.protocol.send_text(message.encode()) When the connection isn't open on entry, when the connection is expected @@ -916,3 +916,95 @@ def eof_received(self) -> None: # As a consequence, they never need to write after receiving EOF, so # there's no reason to keep the transport open by returning True. # Besides, that doesn't work on TLS connections. + + +def broadcast( + connections: Iterable[Connection], + message: Data, + raise_exceptions: bool = False, +) -> None: + """ + Broadcast a message to several WebSocket connections. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent + as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + :func:`broadcast` pushes the message synchronously to all connections even + if their write buffers are overflowing. There's no backpressure. + + If you broadcast messages faster than a connection can handle them, messages + will pile up in its write buffer until the connection times out. Keep + ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage + from slow connections. + + Unlike :meth:`~Connection.send`, :func:`broadcast` doesn't support sending + fragmented messages. Indeed, fragmentation is useful for sending large + messages without buffering them in memory, while :func:`broadcast` buffers + one copy per connection as fast as possible. + + :func:`broadcast` skips connections that aren't open in order to avoid + errors on connections where the closing handshake is in progress. + + :func:`broadcast` ignores failures to write the message on some connections. + It continues writing to other connections. On Python 3.11 and above, you may + set ``raise_exceptions`` to :obj:`True` to record failures and raise all + exceptions in a :pep:`654` :exc:`ExceptionGroup`. + + Args: + websockets: WebSocket connections to which the message will be sent. + message: Message to send. + raise_exceptions: Whether to raise an exception in case of failures. + + Raises: + TypeError: If ``message`` doesn't have a supported type. + + """ + if isinstance(message, str): + send_method = "send_text" + message = message.encode() + elif isinstance(message, BytesLike): + send_method = "send_binary" + else: + raise TypeError("data must be str or bytes") + + if raise_exceptions: + if sys.version_info[:2] < (3, 11): # pragma: no cover + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions = [] + + for connection in connections: + if connection.protocol.state is not OPEN: + continue + + if connection.fragmented_send_waiter is not None: + if raise_exceptions: + exception = RuntimeError("sending a fragmented message") + exceptions.append(exception) + else: + connection.logger.warning( + "skipped broadcast: sending a fragmented message", + ) + continue + + try: + # Call connection.protocol.send_text or send_binary. + # Either way, message is already converted to bytes. + getattr(connection.protocol, send_method)(message) + connection.send_data() + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + connection.logger.warning( + "skipped broadcast: failed to write message", + exc_info=True, + ) + + if raise_exceptions and exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 6f8916576..b948257e0 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1570,18 +1570,17 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, - :func:`broadcast` doesn't support sending fragmented messages. Indeed, - fragmentation is useful for sending large messages without buffering them in - memory, while :func:`broadcast` buffers one copy per connection as fast as - possible. + Unlike :meth:`~WebSocketCommonProtocol.send`, :func:`broadcast` doesn't + support sending fragmented messages. Indeed, fragmentation is useful for + sending large messages without buffering them in memory, while + :func:`broadcast` buffers one copy per connection as fast as possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. :func:`broadcast` ignores failures to write the message on some connections. - It continues writing to other connections. On Python 3.11 and above, you - may set ``raise_exceptions`` to :obj:`True` to record failures and raise all + It continues writing to other connections. On Python 3.11 and above, you may + set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. Args: @@ -1615,6 +1614,7 @@ def broadcast( websocket.logger.warning( "skipped broadcast: sending a fragmented message", ) + continue try: websocket.write_frame_sync(True, opcode, data) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index a4826c785..88d6aee1f 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -387,7 +387,7 @@ def send(self, message: Data | Iterable[Data]) -> None: raise else: - raise TypeError("data must be bytes, str, or iterable") + raise TypeError("data must be str, bytes, or iterable") def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: """ diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 02029b754..1cf382a01 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -2,6 +2,7 @@ import contextlib import logging import socket +import sys import unittest import uuid from unittest.mock import Mock, patch @@ -1004,6 +1005,161 @@ async def test_unexpected_failure_in_send_context(self, send_text): self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) + # Test broadcast. + + async def test_broadcast_text(self): + """broadcast broadcasts a text message.""" + broadcast([self.connection], "😀") + await self.assertFrameSent(Frame(Opcode.TEXT, "😀".encode())) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_text_reports_no_errors(self): + """broadcast broadcasts a text message without raising exceptions.""" + broadcast([self.connection], "😀", raise_exceptions=True) + await self.assertFrameSent(Frame(Opcode.TEXT, "😀".encode())) + + async def test_broadcast_binary(self): + """broadcast broadcasts a binary message.""" + broadcast([self.connection], b"\x01\x02\xfe\xff") + await self.assertFrameSent(Frame(Opcode.BINARY, b"\x01\x02\xfe\xff")) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_binary_reports_no_errors(self): + """broadcast broadcasts a binary message without raising exceptions.""" + broadcast([self.connection], b"\x01\x02\xfe\xff", raise_exceptions=True) + await self.assertFrameSent(Frame(Opcode.BINARY, b"\x01\x02\xfe\xff")) + + async def test_broadcast_no_clients(self): + """broadcast does nothing when called with an empty list of clients.""" + broadcast([], "😀") + await self.assertNoFrameSent() + + async def test_broadcast_two_clients(self): + """broadcast broadcasts a message to several clients.""" + broadcast([self.connection, self.connection], "😀") + await self.assertFramesSent( + [ + Frame(Opcode.TEXT, "😀".encode()), + Frame(Opcode.TEXT, "😀".encode()), + ] + ) + + async def test_broadcast_skips_closed_connection(self): + """broadcast ignores closed connections.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + with self.assertNoLogs(): + broadcast([self.connection], "😀") + await self.assertNoFrameSent() + + async def test_broadcast_skips_closing_connection(self): + """broadcast ignores closing connections.""" + async with self.delay_frames_rcvd(MS): + close_task = asyncio.create_task(self.connection.close()) + await asyncio.sleep(0) + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + with self.assertNoLogs(): + broadcast([self.connection], "😀") + await self.assertNoFrameSent() + + await close_task + + async def test_broadcast_skips_connection_with_send_blocked(self): + """broadcast logs a warning when a connection is blocked in send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent(Frame(Opcode.TEXT, "⏳".encode(), fin=False)) + + with self.assertLogs("websockets", logging.WARNING) as logs: + broadcast([self.connection], "😀") + + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: sending a fragmented message"], + ) + + gate.set_result(None) + await send_task + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_reports_connection_with_send_blocked(self): + """broadcast raises exceptions for connections blocked in send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent(Frame(Opcode.TEXT, "⏳".encode(), fin=False)) + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.connection], "😀", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + exc = raised.exception.exceptions[0] + self.assertEqual(str(exc), "sending a fragmented message") + self.assertIsInstance(exc, RuntimeError) + + gate.set_result(None) + await send_task + + async def test_broadcast_skips_connection_failing_to_send(self): + """broadcast logs a warning when a connection fails to send.""" + # Inject a fault by shutting down the transport for writing. + self.transport.write_eof() + + with self.assertLogs("websockets", logging.WARNING) as logs: + broadcast([self.connection], "😀") + + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: failed to write message"], + ) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_reports_connection_failing_to_send(self): + """broadcast raises exceptions for connections failing to send.""" + # Inject a fault by shutting down the transport for writing. + self.transport.write_eof() + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.connection], "😀", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + exc = raised.exception.exceptions[0] + self.assertEqual(str(exc), "failed to write message") + self.assertIsInstance(exc, RuntimeError) + cause = exc.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + async def test_broadcast_type_error(self): + """broadcast raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + broadcast([self.connection], ["⏳", "⌛️"]) + class ServerConnectionTests(ClientConnectionTests): LOCAL = SERVER diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index d6303dcc7..ccea34719 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1473,7 +1473,8 @@ def test_broadcast_text(self): self.assertOneFrameSent(True, OP_TEXT, "café".encode()) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_text_reports_no_errors(self): broadcast([self.protocol], "café", raise_exceptions=True) @@ -1484,7 +1485,8 @@ def test_broadcast_binary(self): self.assertOneFrameSent(True, OP_BINARY, b"tea") @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_binary_reports_no_errors(self): broadcast([self.protocol], b"tea", raise_exceptions=True) @@ -1536,7 +1538,8 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): ) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_reports_connection_sending_fragmented_text(self): self.make_drain_slow() @@ -1565,7 +1568,8 @@ def test_broadcast_skips_connection_failing_to_send(self): ) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_reports_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. From 8d9f9a1cc791df01d7995693551cd9cf83e154c2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 11:15:51 +0200 Subject: [PATCH 1330/1539] Expose connection state in the new asyncio implementation. --- docs/reference/new-asyncio/client.rst | 2 ++ docs/reference/new-asyncio/common.rst | 2 ++ docs/reference/new-asyncio/server.rst | 2 ++ src/websockets/asyncio/connection.py | 12 ++++++++++++ src/websockets/protocol.py | 2 +- tests/asyncio/test_connection.py | 6 +++++- 6 files changed, 24 insertions(+), 2 deletions(-) diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index 196bda2b7..efd143f14 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -43,6 +43,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index 4fa97dcf2..60ea6bb37 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -33,6 +33,8 @@ Both sides (new :mod:`asyncio`) .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index 7f9de6148..b163e0fcd 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -62,6 +62,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 9d2f087da..a323376ca 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -137,6 +137,18 @@ def remote_address(self) -> Any: """ return self.transport.get_extra_info("peername") + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.state + @property def subprotocol(self) -> Subprotocol | None: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 917c19163..de065c544 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -156,7 +156,7 @@ def __init__( @property def state(self) -> State: """ - WebSocket connection state. + State of the WebSocket connection. Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 1cf382a01..239b5312e 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -11,7 +11,7 @@ from websockets.asyncio.connection import * from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from websockets.frames import CloseCode, Frame, Opcode -from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol from ..utils import MS @@ -930,6 +930,10 @@ async def test_remote_address(self, get_extra_info): self.assertEqual(self.connection.remote_address, ("peer", 1234)) get_extra_info.assert_called_with("peername") + async def test_state(self): + """Connection has a state attribute.""" + self.assertEqual(self.connection.state, State.OPEN) + async def test_request(self): """Connection has a request attribute.""" self.assertIsNone(self.connection.request) From 7c8e0b9d6246cd7bdd304f630f719fc55620f89a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 10:20:37 +0200 Subject: [PATCH 1331/1539] Document removal of open and closed properties. They won't be added to the new asyncio implementation. --- docs/howto/upgrade.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 6efaf0f56..8ff18c594 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -371,3 +371,29 @@ buffer now. The ``write_limit`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. + +Attributes of connections +......................... + +``open`` and ``closed`` +~~~~~~~~~~~~~~~~~~~~~~~ + +The :attr:`~legacy.protocol.WebSocketCommonProtocol.open` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.closed` properties are removed. +Using them was discouraged. + +Instead, you should call :meth:`~asyncio.connection.Connection.recv` or +:meth:`~asyncio.connection.Connection.send` and handle +:exc:`~exceptions.ConnectionClosed` exceptions. + +If your code relies on them, you can replace:: + + connection.open + connection.closed + +with:: + + from websockets.protocol import State + + connection.state is State.OPEN + connection.state is State.CLOSED From 7c1d1d9b97fa698034d2b3651eb5a757e42b3dfb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 11:53:37 +0200 Subject: [PATCH 1332/1539] Add a respond method to server connections. It's an alias of the reject method of the underlying server protocol. It makes it easier to write process_request It's called respond because the semantics are "consider the request as an HTTP request and create an HTTP response". There isn't a similar alias for accept because process_request should just return and websockets will call accept. --- docs/howto/upgrade.rst | 2 +- docs/reference/new-asyncio/server.rst | 2 ++ docs/reference/sync/server.rst | 2 ++ src/websockets/asyncio/server.py | 23 ++++++++++++++++++++++- src/websockets/server.py | 27 ++++++++++++++------------- src/websockets/sync/server.py | 23 ++++++++++++++++++++++- tests/asyncio/test_server.py | 4 ++-- tests/sync/test_server.py | 2 +- 8 files changed, 66 insertions(+), 19 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 8ff18c594..fe95a6517 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -247,7 +247,7 @@ an example:: # New implementation def process_request(connection, request): - return connection.protocol.reject(http.HTTPStatus.OK, "OK\n") + return connection.respond(http.HTTPStatus.OK, "OK\n") serve(..., process_request=process_request, ...) diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index b163e0fcd..5ffcff843 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -52,6 +52,8 @@ Using a connection .. automethod:: pong + .. automethod:: respond + WebSocket connection objects also provide these attributes: .. autoattribute:: id diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 7ed744df2..26ab872c8 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -40,6 +40,8 @@ Using a connection .. automethod:: pong + .. automethod:: respond + WebSocket connection objects also provide these attributes: .. autoattribute:: id diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 4feea13c4..cc2f46216 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -23,7 +23,7 @@ from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol -from ..typing import LoggerLike, Origin, Subprotocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout from .connection import Connection @@ -75,6 +75,27 @@ def __init__( self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + async def handshake( self, process_request: ( diff --git a/src/websockets/server.py b/src/websockets/server.py index 1b4c3bf29..2ab9102f7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -113,18 +113,22 @@ def accept(self, request: Request) -> Response: """ Create a handshake response to accept the connection. - If the connection cannot be established, the handshake response - actually rejects the handshake. + If the handshake request is valid and the handshake successful, + :meth:`accept` returns an HTTP response with status code 101. + + Else, it returns an HTTP response with another status code. This rejects + the connection, like :meth:`reject` would. You must send the handshake response with :meth:`send_response`. - You may modify it before sending it, for example to add HTTP headers. + You may modify the response before sending it, typically by adding HTTP + headers. Args: - request: WebSocket handshake request event received from the client. + request: WebSocket handshake request received from the client. Returns: - WebSocket handshake response event to send to the client. + WebSocket handshake response or HTTP response to send to the client. """ try: @@ -485,11 +489,7 @@ def select_subprotocol(protocol, subprotocols): + ", ".join(self.available_subprotocols) ) - def reject( - self, - status: StatusLike, - text: str, - ) -> Response: + def reject(self, status: StatusLike, text: str) -> Response: """ Create a handshake response to reject the connection. @@ -498,14 +498,15 @@ def reject( You must send the handshake response with :meth:`send_response`. - You can modify it before sending it, for example to alter HTTP headers. + You may modify the response before sending it, for example by changing + HTTP headers. Args: status: HTTP status code. - text: HTTP response body; will be encoded to UTF-8. + text: HTTP response body; it will be encoded to UTF-8. Returns: - WebSocket handshake response event to send to the client. + HTTP response to send to the client. """ # If a user passes an int instead of a HTTPStatus, fix it automatically. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 10fbe4859..b381908ca 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -19,7 +19,7 @@ from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol -from ..typing import LoggerLike, Origin, Subprotocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection from .utils import Deadline @@ -66,6 +66,27 @@ def __init__( close_timeout=close_timeout, ) + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + def handshake( self, process_request: ( diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 4a8a76a21..fa590210f 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -144,7 +144,7 @@ async def test_process_request_abort_handshake(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: @@ -159,7 +159,7 @@ async def test_async_process_request_abort_handshake(self): """Server aborts handshake if async process_request returns a response.""" async def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 9d509a5c4..4e04a39d5 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -131,7 +131,7 @@ def test_process_request_abort_handshake(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: From 7b19e790ce766dadb0b90b040be68694074a7e0d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 23:12:36 +0200 Subject: [PATCH 1333/1539] Add keepalive to the new asyncio implementation. --- docs/faq/common.rst | 4 +- docs/howto/upgrade.rst | 11 --- docs/reference/features.rst | 4 +- docs/reference/new-asyncio/client.rst | 2 + docs/reference/new-asyncio/common.rst | 2 + docs/reference/new-asyncio/server.rst | 2 + docs/topics/broadcast.rst | 4 +- docs/topics/index.rst | 2 +- docs/topics/{timeouts.rst => keepalive.rst} | 16 ++-- src/websockets/asyncio/client.py | 21 ++++- src/websockets/asyncio/connection.py | 89 ++++++++++++++++++-- src/websockets/asyncio/server.py | 21 ++++- src/websockets/legacy/protocol.py | 16 ++-- tests/asyncio/test_client.py | 15 ++++ tests/asyncio/test_connection.py | 91 ++++++++++++++++++++- tests/asyncio/test_server.py | 21 +++++ tests/legacy/test_protocol.py | 4 +- 17 files changed, 274 insertions(+), 51 deletions(-) rename docs/topics/{timeouts.rst => keepalive.rst} (90%) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 2c63c4f36..84256fdfe 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -97,7 +97,7 @@ There are two main reasons why latency may increase: * Poor network connectivity. * More traffic than the recipient can handle. -See the discussion of :doc:`timeouts <../topics/timeouts>` for details. +See the discussion of :doc:`keepalive <../topics/keepalive>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. @@ -146,7 +146,7 @@ It closes the connection if it doesn't get a pong within 20 seconds. You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. -See :doc:`../topics/timeouts` for details. +See :doc:`../topics/keepalive` for details. How do I respond to pings? -------------------------- diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index fe95a6517..16b010aca 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -70,17 +70,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -Keepalive -......... - -The new implementation doesn't provide a :ref:`keepalive mechanism ` -yet. - -As a consequence, :func:`~asyncio.client.connect` and -:func:`~asyncio.server.serve` don't accept the ``ping_interval`` and -``ping_timeout`` arguments and the -:attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property doesn't exist. - HTTP Basic Authentication ......................... diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 946770fe3..45fa79c48 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -53,9 +53,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Keepalive | ❌ | ❌ | — | ✅ | + | Keepalive | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Heartbeat | ❌ | ❌ | — | ✅ | + | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index efd143f14..77a3c5d53 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -43,6 +43,8 @@ Using a connection .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index 60ea6bb37..a58325fb9 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -33,6 +33,8 @@ Both sides (new :mod:`asyncio`) .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index 5ffcff843..7bceca5a0 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -64,6 +64,8 @@ Using a connection .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 671319136..ec358bbd2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -206,8 +206,8 @@ How do we avoid running out of memory when slow clients can't keep up with the broadcast rate, then? The most straightforward option is to disconnect them. If a client gets too far behind, eventually it reaches the limit defined by -``ping_timeout`` and websockets terminates the connection. You can refer to -the discussion of :doc:`keepalive and timeouts ` for details. +``ping_timeout`` and websockets terminates the connection. You can refer to the +discussion of :doc:`keepalive ` for details. How :func:`~asyncio.connection.broadcast` works ----------------------------------------------- diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 120a3dd32..a2b8ca879 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -11,7 +11,7 @@ Get a deeper understanding of how websockets is built and why. authentication broadcast compression - timeouts + keepalive design memory security diff --git a/docs/topics/timeouts.rst b/docs/topics/keepalive.rst similarity index 90% rename from docs/topics/timeouts.rst rename to docs/topics/keepalive.rst index 633fc1ab4..1c7a43264 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/keepalive.rst @@ -1,5 +1,5 @@ -Timeouts -======== +Keepalive and latency +===================== .. currentmodule:: websockets @@ -49,9 +49,9 @@ This mechanism serves two purposes: application gets a :exc:`~exceptions.ConnectionClosed` exception. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` -arguments of :func:`~client.connect` and :func:`~server.serve`. Shorter values -will detect connection drops faster but they will increase network traffic and -they will be more sensitive to latency. +arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve`. +Shorter values will detect connection drops faster but they will increase +network traffic and they will be more sensitive to latency. Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and heartbeat mechanism. @@ -111,6 +111,6 @@ Latency between a client and a server may increase for two reasons: than the client can accept. The latency measured during the last exchange of Ping and Pong frames is -available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.latency` -attribute. Alternatively, you can measure the latency at any time with the -:attr:`~legacy.protocol.WebSocketCommonProtocol.ping` method. +available in the :attr:`~asyncio.connection.Connection.latency` attribute. +Alternatively, you can measure the latency at any time with the +:attr:`~asyncio.connection.Connection.ping` method. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b2eaf9a65..632d3ac2b 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -37,10 +37,11 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, + and ``write_limit`` arguments the same meaning as in :func:`connect`. + Args: protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. - :obj:`None` disables the timeout. """ @@ -48,6 +49,8 @@ def __init__( self, protocol: ClientProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, @@ -55,6 +58,8 @@ def __init__( self.protocol: ClientProtocol super().__init__( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, @@ -84,7 +89,9 @@ async def handshake( if self.response is None: raise ConnectionError("connection closed during handshake") - if self.protocol.handshake_exc is not None: + if self.protocol.handshake_exc is None: + self.start_keepalive() + else: try: async with asyncio_timeout(self.close_timeout): await self.connection_lost_waiter @@ -146,6 +153,10 @@ class connect: :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -208,6 +219,8 @@ def __init__( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -256,6 +269,8 @@ def factory() -> ClientConnection: # This is a connection in websockets and a protocol in asyncio. connection = create_connection( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index a323376ca..b232b7956 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -24,7 +24,13 @@ from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import TimeoutError, aiter, anext, asyncio_timeout_at +from .compatibility import ( + TimeoutError, + aiter, + anext, + asyncio_timeout, + asyncio_timeout_at, +) from .messages import Assembler @@ -48,11 +54,15 @@ def __init__( self, protocol: Protocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout self.close_timeout = close_timeout if isinstance(max_queue, int): max_queue = (max_queue, None) @@ -95,6 +105,21 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + This value is updated after sending a ping frame and receiving a + matching pong frame. Before the first ping, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends ping frames automatically at regular intervals. You can also + send ping frames and measure latency with :meth:`ping`. + """ + + # Task that sends keepalive pings. None when ping_interval is None. + self.keepalive_task: asyncio.Task[None] | None = None + # Exception raised while reading from the connection, to be chained to # ConnectionClosed in order to show why the TCP connection dropped. self.recv_exc: BaseException | None = None @@ -144,7 +169,8 @@ def state(self) -> State: This attribute is provided for completeness. Typical applications shouldn't check its value. Instead, they should call :meth:`~recv` or - :meth:`send` and handle :exc:`~exceptions.ConnectionClosed` exceptions. + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. """ return self.protocol.state @@ -540,7 +566,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Data | None = None) -> Awaitable[None]: + async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -643,8 +669,10 @@ def acknowledge_pings(self, data: bytes) -> None: ping_ids = [] for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) - pong_waiter.set_result(pong_timestamp - ping_timestamp) + latency = pong_timestamp - ping_timestamp + pong_waiter.set_result(latency) if ping_id == data: + self.latency = latency break else: raise AssertionError("solicited pong not found in pings") @@ -664,7 +692,8 @@ def abort_pings(self) -> None: exc = self.protocol.close_exc for pong_waiter, _ping_timestamp in self.pong_waiters.values(): - pong_waiter.set_exception(exc) + if not pong_waiter.done(): + pong_waiter.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does @@ -673,6 +702,50 @@ def abort_pings(self) -> None: self.pong_waiters.clear() + async def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + latency = 0.0 + try: + while True: + # If self.ping_timeout > latency > self.ping_interval, pings + # will be sent immediately after receiving pongs. The period + # will be longer than self.ping_interval. + await asyncio.sleep(self.ping_interval - latency) + + self.logger.debug("% sending keepalive ping") + pong_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + async with asyncio_timeout(self.ping_timeout): + latency = await pong_waiter + self.logger.debug("% received keepalive pong") + except asyncio.TimeoutError: + if self.debug: + self.logger.debug("! timed out waiting for keepalive pong") + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except ConnectionClosed: + pass + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a task, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + self.keepalive_task = self.loop.create_task(self.keepalive()) + @contextlib.asynccontextmanager async def send_context( self, @@ -835,11 +908,15 @@ def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent self.recv_messages.close() self.set_recv_exc(exc) + self.abort_pings() + # If keepalive() was waiting for a pong, abort_pings() terminated it. + # If it was sleeping until the next ping, we need to cancel it now + if self.keepalive_task is not None: + self.keepalive_task.cancel() # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. self.connection_lost_waiter.set_result(None) - self.abort_pings() # Adapted from asyncio.streams.FlowControlMixin if self.paused: # pragma: no cover diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index cc2f46216..1f55502bb 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -48,11 +48,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, + and ``write_limit`` arguments the same meaning as in :func:`serve`. + Args: protocol: Sans-I/O connection. server: Server that manages this connection. - close_timeout: Timeout for closing connections in seconds. - :obj:`None` disables the timeout. """ @@ -61,6 +62,8 @@ def __init__( protocol: ServerProtocol, server: WebSocketServer, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, @@ -68,6 +71,8 @@ def __init__( self.protocol: ServerProtocol super().__init__( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, @@ -182,7 +187,9 @@ async def handshake( self.protocol.send_response(self.response) - if self.protocol.handshake_exc is not None: + if self.protocol.handshake_exc is None: + self.start_keepalive() + else: try: async with asyncio_timeout(self.close_timeout): await self.connection_lost_waiter @@ -595,6 +602,10 @@ def handler(websocket): :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing connections in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -667,6 +678,8 @@ def __init__( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -743,6 +756,8 @@ def protocol_select_subprotocol( connection = create_connection( protocol, self.server, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b948257e0..191350de3 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -89,7 +89,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. + See the discussion of :doc:`timeouts <../../topics/keepalive>` for details. The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy @@ -144,8 +144,8 @@ class WebSocketCommonProtocol(asyncio.Protocol): logger: Logger for this server. It defaults to ``logging.getLogger("websockets.protocol")``. See the :doc:`logging guide <../../topics/logging>` for details. - ping_interval: Delay between keepalive pings in seconds. - :obj:`None` disables keepalive pings. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. ping_timeout: Timeout for keepalive pings in seconds. :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. @@ -1242,18 +1242,16 @@ async def keepalive_ping(self) -> None: while True: await asyncio.sleep(self.ping_interval) - # ping() raises CancelledError if the connection is closed, - # when close_connection() cancels self.keepalive_ping_task. - - # ping() raises ConnectionClosed if the connection is lost, - # when connection_lost() calls abort_pings(). - self.logger.debug("% sending keepalive ping") pong_waiter = await self.ping() if self.ping_timeout is not None: try: async with asyncio_timeout(self.ping_timeout): + # Raises CancelledError if the connection is closed, + # when close_connection() cancels keepalive_ping(). + # Raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index b74617ef0..0bd2af4f1 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -63,6 +63,21 @@ async def test_disable_compression(self): async with run_client(server, compression=None) as client: self.assertEqual(client.protocol.extensions, []) + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with run_server() as server: + async with run_client(server, ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await asyncio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server() as server: + async with run_client(server, ping_interval=None) as client: + await asyncio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + async def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 239b5312e..9b84a6b81 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -868,6 +868,93 @@ async def test_pong_explicit_binary(self): await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + # Test keepalive. + + @patch("random.getrandbits") + async def test_keepalive(self, getrandbits): + """keepalive sends pings.""" + self.connection.ping_interval = 2 * MS + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + @patch("random.getrandbits") + async def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + async with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for MS. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xf3keepalive ping timeout") + ) + + @patch("random.getrandbits") + async def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + async with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for MS. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertNoFrameSent() + + async def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 2 * MS + self.connection.start_keepalive() + await asyncio.sleep(MS) + await self.connection.close() + self.assertTrue(self.connection.keepalive_task.done()) + + async def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = 2 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for MS. + await self.connection.close() + self.assertTrue(self.connection.keepalive_task.done()) + + async def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + # Inject a fault by raising an exception in a pending pong waiter. + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for MS. + pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] + with self.assertLogs("websockets", logging.ERROR) as logs: + pong_waiter.set_exception(Exception("BOOM")) + await asyncio.sleep(0) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + # Test parameters. async def test_close_timeout(self): @@ -1092,7 +1179,7 @@ async def fragments(): broadcast([self.connection], "😀") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: sending a fragmented message"], ) @@ -1135,7 +1222,7 @@ async def test_broadcast_skips_connection_failing_to_send(self): broadcast([self.connection], "😀") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: failed to write message"], ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fa590210f..b3023434b 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -312,6 +312,27 @@ async def test_disable_compression(self): async with run_client(server) as client: await self.assertEval(client, "ws.protocol.extensions", "[]") + async def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + async with run_server(ping_interval=MS) as server: + async with run_client(server) as client: + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + await asyncio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertGreater(latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server(ping_interval=None) as server: + async with run_client(server) as client: + await asyncio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ccea34719..8751b9ac6 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1533,7 +1533,7 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): broadcast([self.protocol], "café") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: sending a fragmented message"], ) @@ -1563,7 +1563,7 @@ def test_broadcast_skips_connection_failing_to_send(self): broadcast([self.protocol], "café") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: failed to write message"], ) From 60381d2566b55126f0f89f0c8380cf44ddc51aa1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 07:58:31 +0200 Subject: [PATCH 1334/1539] Fix exception chaining for ConnectionClosed. --- src/websockets/asyncio/connection.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index b232b7956..284fe2124 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -906,13 +906,17 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent - self.recv_messages.close() + + # Abort recv() and pending pings with a ConnectionClosed exception. + # Set recv_exc first to get proper exception reporting. self.set_recv_exc(exc) + self.recv_messages.close() self.abort_pings() # If keepalive() was waiting for a pong, abort_pings() terminated it. # If it was sleeping until the next ping, we need to cancel it now if self.keepalive_task is not None: self.keepalive_task.cancel() + # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. From a78b5546074ed9e89a265eec1b54292a628d9b25 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 08:00:16 +0200 Subject: [PATCH 1335/1539] Document legacy implementation in websockets.legacy. Until now it was documented directly in the websockets package. Also update most examples to use the new asyncio implementation. Some drive-by documentation improvements too. --- compliance/test_client.py | 10 +- compliance/test_server.py | 6 +- docs/conf.py | 28 +- docs/faq/asyncio.rst | 44 ++-- docs/faq/client.rst | 62 +++-- docs/faq/common.rst | 53 ++-- docs/faq/misc.rst | 17 +- docs/faq/server.rst | 111 ++++---- docs/howto/cheatsheet.rst | 57 ++-- docs/howto/django.rst | 2 +- docs/howto/heroku.rst | 2 +- docs/howto/nginx.rst | 6 +- docs/howto/patterns.rst | 10 +- docs/howto/quickstart.rst | 14 +- docs/howto/upgrade.rst | 91 ++++--- docs/intro/tutorial1.rst | 34 +-- docs/intro/tutorial3.rst | 6 +- docs/project/changelog.rst | 117 ++++----- docs/reference/asyncio/client.rst | 37 ++- docs/reference/asyncio/common.rst | 33 +-- docs/reference/asyncio/server.rst | 70 ++--- docs/reference/features.rst | 5 +- docs/reference/index.rst | 19 +- docs/reference/legacy/client.rst | 64 +++++ docs/reference/legacy/common.rst | 54 ++++ docs/reference/legacy/server.rst | 113 ++++++++ docs/reference/new-asyncio/client.rst | 57 ---- docs/reference/new-asyncio/common.rst | 47 ---- docs/reference/new-asyncio/server.rst | 83 ------ docs/topics/authentication.rst | 22 +- docs/topics/deployment.rst | 13 +- docs/topics/design.rst | 286 ++++++++++----------- docs/topics/logging.rst | 11 +- docs/topics/memory.rst | 9 +- docs/topics/security.rst | 6 +- example/deployment/fly/app.py | 10 +- example/deployment/haproxy/app.py | 4 +- example/deployment/heroku/app.py | 4 +- example/deployment/kubernetes/app.py | 18 +- example/deployment/kubernetes/benchmark.py | 5 +- example/deployment/nginx/app.py | 4 +- example/deployment/render/app.py | 10 +- example/deployment/supervisor/app.py | 4 +- example/django/authentication.py | 4 +- example/django/notifications.py | 7 +- example/echo.py | 2 +- example/faq/health_check_server.py | 15 +- example/faq/shutdown_client.py | 8 +- example/faq/shutdown_server.py | 5 +- example/legacy/basic_auth_client.py | 5 +- example/legacy/basic_auth_server.py | 8 +- example/legacy/unix_client.py | 5 +- example/legacy/unix_server.py | 5 +- example/logging/json_log_formatter.py | 2 +- example/quickstart/client.py | 5 +- example/quickstart/client_secure.py | 5 +- example/quickstart/counter.py | 13 +- example/quickstart/server.py | 5 +- example/quickstart/server_secure.py | 5 +- example/quickstart/show_time.py | 5 +- example/quickstart/show_time_2.py | 8 +- example/tutorial/step1/app.py | 4 +- example/tutorial/step2/app.py | 4 +- example/tutorial/step3/app.py | 4 +- experiments/authentication/app.py | 19 +- experiments/broadcast/clients.py | 4 +- src/websockets/exceptions.py | 4 +- src/websockets/legacy/auth.py | 6 +- src/websockets/legacy/client.py | 4 +- src/websockets/legacy/protocol.py | 16 +- tests/asyncio/test_connection.py | 2 +- tests/asyncio/test_server.py | 9 +- tests/sync/test_connection.py | 2 +- tests/sync/test_server.py | 5 +- 74 files changed, 951 insertions(+), 902 deletions(-) create mode 100644 docs/reference/legacy/client.rst create mode 100644 docs/reference/legacy/common.rst create mode 100644 docs/reference/legacy/server.rst delete mode 100644 docs/reference/new-asyncio/client.rst delete mode 100644 docs/reference/new-asyncio/common.rst delete mode 100644 docs/reference/new-asyncio/server.rst diff --git a/compliance/test_client.py b/compliance/test_client.py index 1ed4d711e..8e22569fd 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -1,9 +1,9 @@ +import asyncio import json import logging import urllib.parse -import asyncio -import websockets +from websockets.asyncio.client import connect logging.basicConfig(level=logging.WARNING) @@ -18,21 +18,21 @@ async def get_case_count(server): uri = f"{server}/getCaseCount" - async with websockets.connect(uri) as ws: + async with connect(uri) as ws: msg = ws.recv() return json.loads(msg) async def run_case(server, case, agent): uri = f"{server}/runCase?case={case}&agent={agent}" - async with websockets.connect(uri, max_size=2 ** 25, max_queue=1) as ws: + async with connect(uri, max_size=2 ** 25, max_queue=1) as ws: async for msg in ws: await ws.send(msg) async def update_reports(server, agent): uri = f"{server}/updateReports?agent={agent}" - async with websockets.connect(uri): + async with connect(uri): pass diff --git a/compliance/test_server.py b/compliance/test_server.py index 5701e4485..39176e902 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -1,7 +1,7 @@ +import asyncio import logging -import asyncio -import websockets +from websockets.asyncio.server import serve logging.basicConfig(level=logging.WARNING) @@ -19,7 +19,7 @@ async def echo(ws): async def main(): - with websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): + with serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): try: await asyncio.get_running_loop().create_future() # run forever except KeyboardInterrupt: diff --git a/docs/conf.py b/docs/conf.py index 9d61dc717..2c621bf41 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,20 +36,20 @@ # topics/design.rst discusses undocumented APIs ("py:meth", "client.WebSocketClientProtocol.handshake"), ("py:meth", "server.WebSocketServerProtocol.handshake"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.is_client"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.messages"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.close_connection"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.close_connection_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.transfer_data"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.transfer_data_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_open"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.ensure_open"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.fail_connection"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_lost"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.read_message"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.write_frame"), + ("py:attr", "protocol.WebSocketCommonProtocol.is_client"), + ("py:attr", "protocol.WebSocketCommonProtocol.messages"), + ("py:meth", "protocol.WebSocketCommonProtocol.close_connection"), + ("py:attr", "protocol.WebSocketCommonProtocol.close_connection_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.keepalive_ping"), + ("py:attr", "protocol.WebSocketCommonProtocol.keepalive_ping_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.transfer_data"), + ("py:attr", "protocol.WebSocketCommonProtocol.transfer_data_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.connection_open"), + ("py:meth", "protocol.WebSocketCommonProtocol.ensure_open"), + ("py:meth", "protocol.WebSocketCommonProtocol.fail_connection"), + ("py:meth", "protocol.WebSocketCommonProtocol.connection_lost"), + ("py:meth", "protocol.WebSocketCommonProtocol.read_message"), + ("py:meth", "protocol.WebSocketCommonProtocol.write_frame"), ] # Add any Sphinx extension module names here, as strings. They can be diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index e77f50add..3bc381cfd 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -1,7 +1,12 @@ Using asyncio ============= -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.connection + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. How do I run two coroutines in parallel? ---------------------------------------- @@ -9,8 +14,8 @@ How do I run two coroutines in parallel? You must start two tasks, which the event loop will run concurrently. You can achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. -Keep track of the tasks and make sure they terminate or you cancel them when -the connection terminates. +Keep track of the tasks and make sure that they terminate or that you cancel +them when the connection terminates. Why does my program never receive any messages? ----------------------------------------------- @@ -22,13 +27,12 @@ Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough to yield control. Awaiting a coroutine may yield control, but there's no guarantee that it will. -For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields -control when send buffers are full, which never happens in most practical -cases. +For example, :meth:`~Connection.send` only yields control when send buffers are +full, which never happens in most practical cases. -If you run a loop that contains only synchronous operations and -a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield -control explicitly with :func:`asyncio.sleep`:: +If you run a loop that contains only synchronous operations and a +:meth:`~Connection.send` call, you must yield control explicitly with +:func:`asyncio.sleep`:: async def producer(websocket): message = generate_next_message() @@ -46,16 +50,19 @@ See `issue 867`_. Why am I having problems with threads? -------------------------------------- -If you choose websockets' default implementation based on :mod:`asyncio`, then -you shouldn't use threads. Indeed, choosing :mod:`asyncio` to handle concurrency -is mutually exclusive with :mod:`threading`. +If you choose websockets' :mod:`asyncio` implementation, then you shouldn't use +threads. Indeed, choosing :mod:`asyncio` to handle concurrency is mutually +exclusive with :mod:`threading`. If you believe that you need to run websockets in a thread and some logic in another thread, you should run that logic in a :class:`~asyncio.Task` instead. -If it blocks the event loop, :meth:`~asyncio.loop.run_in_executor` will help. -This question is really about :mod:`asyncio`. Please review the advice about -:ref:`asyncio-multithreading` in the Python documentation. +If it has to run in another thread because it would block the event loop, +:func:`~asyncio.to_thread` or :meth:`~asyncio.loop.run_in_executor` is the way +to go. + +Please review the advice about :ref:`asyncio-multithreading` in the Python +documentation. Why does my simple program misbehave mysteriously? -------------------------------------------------- @@ -63,7 +70,6 @@ Why does my simple program misbehave mysteriously? You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which blocks the event loop and prevents asyncio from operating normally. -This may lead to messages getting send but not received, to connection -timeouts, and to unexpected results of shotgun debugging e.g. adding an -unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` -makes the program functional. +This may lead to messages getting send but not received, to connection timeouts, +and to unexpected results of shotgun debugging e.g. adding an unnecessary call +to a coroutine makes the program functional. diff --git a/docs/faq/client.rst b/docs/faq/client.rst index c590ac107..0dfc84253 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -1,7 +1,16 @@ Client ====== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.client + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. + + They translate to the :mod:`threading` implementation by removing ``await`` + and ``async`` keywords and by using a :class:`~threading.Thread` instead of + a :class:`~asyncio.Task` for concurrent execution. Why does the client close the connection prematurely? ----------------------------------------------------- @@ -22,46 +31,47 @@ change it to:: How do I access HTTP headers? ----------------------------- -Once the connection is established, HTTP headers are available in -:attr:`~client.WebSocketClientProtocol.request_headers` and -:attr:`~client.WebSocketClientProtocol.response_headers`. +Once the connection is established, HTTP headers are available in the +:attr:`~ClientConnection.request` and :attr:`~ClientConnection.response` +objects:: + + async with connect(...) as websocket: + websocket.request.headers + websocket.response.headers How do I set HTTP headers? -------------------------- To set the ``Origin``, ``Sec-WebSocket-Extensions``, or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the -``origin``, ``extensions``, or ``subprotocols`` arguments of -:func:`~client.connect`. +``origin``, ``extensions``, or ``subprotocols`` arguments of :func:`~connect`. To override the ``User-Agent`` header, use the ``user_agent_header`` argument. Set it to :obj:`None` to remove the header. To set other HTTP headers, for example the ``Authorization`` header, use the -``extra_headers`` argument:: +``additional_headers`` argument:: - async with connect(..., extra_headers={"Authorization": ...}) as websocket: + async with connect(..., additional_headers={"Authorization": ...}) as websocket: ... -In the :mod:`threading` API, this argument is named ``additional_headers``:: - - with connect(..., additional_headers={"Authorization": ...}) as websocket: - ... +In the legacy :mod:`asyncio` API, this argument is named ``extra_headers``. How do I force the IP address that the client connects to? ---------------------------------------------------------- -Use the ``host`` argument of :meth:`~asyncio.loop.create_connection`:: +Use the ``host`` argument :func:`~connect`:: - await websockets.connect("ws://example.com", host="192.168.0.1") + async with connect(..., host="192.168.0.1") as websocket: + ... -:func:`~client.connect` accepts the same arguments as -:meth:`~asyncio.loop.create_connection`. +:func:`~connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection` and passes them through. How do I close a connection? ---------------------------- -The easiest is to use :func:`~client.connect` as a context manager:: +The easiest is to use :func:`~connect` as a context manager:: async with connect(...) as websocket: ... @@ -71,9 +81,17 @@ The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- -Use :func:`~client.connect` as an asynchronous iterator:: +.. admonition:: This feature is only supported by the legacy :mod:`asyncio` + implementation. + :class: warning + + It will be added to the new :mod:`asyncio` implementation soon. + +Use :func:`~websockets.legacy.client.connect` as an asynchronous iterator:: + + from websockets.legacy.client import connect - async for websocket in websockets.connect(...): + async for websocket in connect(...): try: ... except websockets.ConnectionClosed: @@ -90,12 +108,12 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_client.py - :emphasize-lines: 10-13 + :emphasize-lines: 11-13 How do I disable TLS/SSL certificate verification? -------------------------------------------------- Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. -:func:`~client.connect` accepts the same arguments as -:meth:`~asyncio.loop.create_connection`. +:func:`~connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection` and passes them through. diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 84256fdfe..0dc4a3aeb 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -1,7 +1,7 @@ Both sides ========== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.connection What does ``ConnectionClosedError: no close frame received or sent`` mean? -------------------------------------------------------------------------- @@ -11,12 +11,6 @@ If you're seeing this traceback in the logs of a server: .. code-block:: pytb connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: no close frame received or sent @@ -25,12 +19,6 @@ or if a client crashes with this traceback: .. code-block:: pytb - Traceback (most recent call last): - ... - ConnectionResetError: [Errno 54] Connection reset by peer - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: no close frame received or sent @@ -39,8 +27,8 @@ it means that the TCP connection was lost. As a consequence, the WebSocket connection was closed without receiving and sending a close frame, which is abnormal. -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. +You can catch and handle :exc:`~websockets.exceptions.ConnectionClosed` to +prevent it from being logged. There are several reasons why long-lived connections may be lost: @@ -62,12 +50,6 @@ If you're seeing this traceback in the logs of a server: .. code-block:: pytb connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received @@ -76,12 +58,6 @@ or if a client crashes with this traceback: .. code-block:: pytb - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received @@ -89,8 +65,8 @@ or if a client crashes with this traceback: it means that the WebSocket connection suffered from excessive latency and was closed after reaching the timeout of websockets' keepalive mechanism. -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. +You can catch and handle :exc:`~websockets.exceptions.ConnectionClosed` to +prevent it from being logged. There are two main reasons why latency may increase: @@ -102,8 +78,8 @@ See the discussion of :doc:`keepalive <../topics/keepalive>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. -How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? --------------------------------------------------------------------------------- +How do I set a timeout on :meth:`~Connection.recv`? +--------------------------------------------------- On Python ≥ 3.11, use :func:`asyncio.timeout`:: @@ -117,23 +93,24 @@ On older versions of Python, use :func:`asyncio.wait_for`:: This technique works for most APIs. When it doesn't, for example with asynchronous context managers, websockets provides an ``open_timeout`` argument. -How can I pass arguments to a custom protocol subclass? -------------------------------------------------------- +How can I pass arguments to a custom connection subclass? +--------------------------------------------------------- -You can bind additional arguments to the protocol factory with +You can bind additional arguments to the connection factory with :func:`functools.partial`:: import asyncio import functools - import websockets + from websockets.asyncio.server import ServerConnection, serve - class MyServerProtocol(websockets.WebSocketServerProtocol): + class MyServerConnection(ServerConnection): def __init__(self, *args, extra_argument=None, **kwargs): super().__init__(*args, **kwargs) # do something with extra_argument - create_protocol = functools.partial(MyServerProtocol, extra_argument=42) - start_server = websockets.serve(..., create_protocol=create_protocol) + create_connection = functools.partial(ServerConnection, extra_argument=42) + async with serve(..., create_connection=create_connection): + ... This example was for a server. The same pattern applies on a client. diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 0e74a784f..4936aa6f3 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -13,27 +13,12 @@ Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. -.. _real-import-paths: - -Why is the default implementation located in ``websockets.legacy``? -................................................................... - -This is an artifact of websockets' history. For its first eight years, only the -:mod:`asyncio` implementation existed. Then, the Sans-I/O implementation was -added. Moving the code in a ``legacy`` submodule eased this refactoring and -optimized maintainability. - -All public APIs were kept at their original locations. ``websockets.legacy`` -isn't a public API. It's only visible in the source code and in stack traces. -There is no intent to deprecate this implementation — at least until a superior -alternative exists. - Why is websockets slower than another library in my benchmark? .............................................................. Not all libraries are as feature-complete as websockets. For a fair benchmark, you should disable features that the other library doesn't provide. Typically, -you may need to disable: +you must disable: * Compression: set ``compression=None`` * Keepalive: set ``ping_interval=None`` diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 53e34632f..e6b068316 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -1,7 +1,16 @@ Server ====== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.server + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. + + They translate to the :mod:`threading` implementation by removing ``await`` + and ``async`` keywords and by using a :class:`~threading.Thread` instead of + a :class:`~asyncio.Task` for concurrent execution. Why does the server close the connection prematurely? ----------------------------------------------------- @@ -36,8 +45,13 @@ change it like this:: async for message in websocket: print(message) -*Don't feel bad if this happens to you — it's the most common question in -websockets' issue tracker :-)* +If you have prior experience with an API that relies on callbacks, you may +assume that ``handler()`` is executed every time a message is received. The API +of websockets relies on coroutines instead. + +The handler coroutine is started when a new connection is established. Then, it +is responsible for receiving or sending messages throughout the lifetime of that +connection. Why can only one client connect at a time? ------------------------------------------ @@ -69,9 +83,9 @@ continuously:: while True: await websocket.send("firehose!") -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` completes synchronously as -long as there's space in send buffers. The event loop never runs. (This pattern -is uncommon in real-world applications. It occurs mostly in toy programs.) +:meth:`~ServerConnection.send` completes synchronously as long as there's space +in send buffers. The event loop never runs. (This pattern is uncommon in +real-world applications. It occurs mostly in toy programs.) You can avoid the issue by yielding control to the event loop explicitly:: @@ -102,12 +116,12 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~asyncio.connection.broadcast`:: +Then, call :func:`~websockets.asyncio.connection.broadcast`:: - import websockets + from websockets.asyncio.connection import broadcast def message_all(message): - websockets.broadcast(CONNECTIONS, message) + broadcast(CONNECTIONS, message) If you're running multiple server processes, make sure you call ``message_all`` in each process. @@ -129,7 +143,7 @@ Record connections in a global variable, keyed by user identifier:: finally: del CONNECTIONS[user_id] -Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: +Then, call :meth:`~ServerConnection.send`:: async def message_user(user_id, message): websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected @@ -178,15 +192,12 @@ How do I pass arguments to the connection handler? You can bind additional arguments to the connection handler with :func:`functools.partial`:: - import asyncio import functools - import websockets async def handler(websocket, extra_argument): ... bound_handler = functools.partial(handler, extra_argument=42) - start_server = websockets.serve(bound_handler, ...) Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it @@ -195,14 +206,14 @@ through an argument. How do I access the request path? --------------------------------- -It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. +It is available in the :attr:`~ServerConnection.request` object. You may route a connection to different handlers depending on the request path:: async def handler(websocket): - if websocket.path == "/blue": + if websocket.request.path == "/blue": await blue_handler(websocket) - elif websocket.path == "/green": + elif websocket.request.path == "/green": await green_handler(websocket) else: # No handler for this path; close the connection. @@ -219,35 +230,46 @@ it may ignore the request path entirely. How do I access HTTP headers? ----------------------------- -To access HTTP headers during the WebSocket handshake, you can override -:attr:`~server.WebSocketServerProtocol.process_request`:: +You can access HTTP headers during the WebSocket handshake by providing a +``process_request`` callable or coroutine:: - async def process_request(self, path, request_headers): - authorization = request_headers["Authorization"] + def process_request(connection, request): + authorization = request.headers["Authorization"] + ... + + async with serve(handler, process_request=process_request): + ... -Once the connection is established, HTTP headers are available in -:attr:`~server.WebSocketServerProtocol.request_headers` and -:attr:`~server.WebSocketServerProtocol.response_headers`:: +Once the connection is established, HTTP headers are available in the +:attr:`~ServerConnection.request` and :attr:`~ServerConnection.response` +objects:: async def handler(websocket): - authorization = websocket.request_headers["Authorization"] + authorization = websocket.request.headers["Authorization"] How do I set HTTP headers? -------------------------- To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` -arguments of :func:`~server.serve`. +arguments of :func:`~serve`. To override the ``Server`` header, use the ``server_header`` argument. Set it to :obj:`None` to remove the header. -To set other HTTP headers, use the ``extra_headers`` argument. +To set other HTTP headers, provide a ``process_response`` callable or +coroutine:: + + def process_response(connection, request, response): + response.headers["X-Blessing"] = "May the network be with you" + + async with serve(handler, process_response=process_response): + ... How do I get the IP address of the client? ------------------------------------------ -It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: +It's available in :attr:`~ServerConnection.remote_address`:: async def handler(websocket): remote_ip = websocket.remote_address[0] @@ -255,18 +277,19 @@ It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address How do I set the IP addresses that my server listens on? -------------------------------------------------------- -Use the ``host`` argument of :meth:`~asyncio.loop.create_server`:: +Use the ``host`` argument of :meth:`~serve`:: - await websockets.serve(handler, host="192.168.0.1", port=8080) + async with serve(handler, host="192.168.0.1", port=8080): + ... -:func:`~server.serve` accepts the same arguments as -:meth:`~asyncio.loop.create_server`. +:func:`~serve` accepts the same arguments as +:meth:`~asyncio.loop.create_server` and passes them through. What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? -------------------------------------------------------------------------------------------------------------------------- -You are calling :func:`~server.serve` without a ``host`` argument in a context -where IPv6 isn't available. +You are calling :func:`~serve` without a ``host`` argument in a context where +IPv6 isn't available. To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. @@ -280,17 +303,17 @@ websockets takes care of closing the connection when the handler exits. How do I stop a server? ----------------------- -Exit the :func:`~server.serve` context manager. +Exit the :func:`~serve` context manager. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 12-15,18 + :emphasize-lines: 13-16,19 How do I stop a server while keeping existing connections open? --------------------------------------------------------------- -Call the server's :meth:`~server.WebSocketServer.close` method with +Call the server's :meth:`~WebSocketServer.close` method with ``close_connections=False``. Here's how to adapt the example just above:: @@ -298,7 +321,7 @@ Here's how to adapt the example just above:: async def server(): ... - server = await websockets.serve(echo, "localhost", 8765) + server = await serve(echo, "localhost", 8765) await stop server.close(close_connections=False) await server.wait_closed() @@ -306,14 +329,14 @@ Here's how to adapt the example just above:: How do I implement a health check? ---------------------------------- -Intercept WebSocket handshake requests with the -:meth:`~server.WebSocketServerProtocol.process_request` hook. - -When a request is sent to the health check endpoint, treat is as an HTTP request -and return a ``(status, headers, body)`` tuple, as in this example: +Intercept requests with the ``process_request`` hook. When a request is sent to +the health check endpoint, treat is as an HTTP request and return a response: .. literalinclude:: ../../example/faq/health_check_server.py - :emphasize-lines: 7-9,18 + :emphasize-lines: 7-9,16 + +:meth:`~ServerConnection.respond` makes it easy to send a plain text response. +You can also construct a :class:`~websockets.http11.Response` object directly. How do I run HTTP and WebSocket servers on the same port? --------------------------------------------------------- @@ -327,7 +350,7 @@ Providing an HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the -:attr:`~server.WebSocketServerProtocol.process_request` hook. +``process_request`` hook. If you need more, pick an HTTP server and run it separately. diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst index 95b551f67..8df2f234b 100644 --- a/docs/howto/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -9,24 +9,24 @@ Server * Write a coroutine that handles a single connection. It receives a WebSocket protocol instance and the URI path in argument. - * Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to receive and send - messages at any time. + * Call :meth:`~asyncio.connection.Connection.recv` and + :meth:`~asyncio.connection.Connection.send` to receive and send messages at + any time. - * When :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` raises - :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started - other :class:`asyncio.Task`, terminate them before exiting. + * When :meth:`~asyncio.connection.Connection.recv` or + :meth:`~asyncio.connection.Connection.send` raises + :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started other + :class:`asyncio.Task`, terminate them before exiting. - * If you aren't awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, - consider awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed` - to detect quickly when the connection is closed. + * If you aren't awaiting :meth:`~asyncio.connection.Connection.recv`, consider + awaiting :meth:`~asyncio.connection.Connection.wait_closed` to detect + quickly when the connection is closed. - * You may :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't - needed in general. + * You may :meth:`~asyncio.connection.Connection.ping` or + :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed + in general. -* Create a server with :func:`~server.serve` which is similar to asyncio's +* Create a server with :func:`~asyncio.server.serve` which is similar to asyncio's :meth:`~asyncio.loop.create_server`. You can also use it as an asynchronous context manager. @@ -35,30 +35,30 @@ Server handler exits normally or with an exception. * For advanced customization, you may subclass - :class:`~server.WebSocketServerProtocol` and pass either this subclass or - a factory function as the ``create_protocol`` argument. + :class:`~asyncio.server.ServerConnection` and pass either this subclass or a + factory function as the ``create_connection`` argument. Client ------ -* Create a client with :func:`~client.connect` which is similar to asyncio's - :meth:`~asyncio.loop.create_connection`. You can also use it as an +* Create a client with :func:`~asyncio.client.connect` which is similar to + asyncio's :meth:`~asyncio.loop.create_connection`. You can also use it as an asynchronous context manager. * For advanced customization, you may subclass - :class:`~client.WebSocketClientProtocol` and pass either this subclass or - a factory function as the ``create_protocol`` argument. + :class:`~asyncio.client.ClientConnection` and pass either this subclass or + a factory function as the ``create_connection`` argument. -* Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to receive and send messages - at any time. +* Call :meth:`~asyncio.connection.Connection.recv` and + :meth:`~asyncio.connection.Connection.send` to receive and send messages at + any time. -* You may :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't - needed in general. +* You may :meth:`~asyncio.connection.Connection.ping` or + :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed in + general. -* If you aren't using :func:`~client.connect` as a context manager, call - :meth:`~legacy.protocol.WebSocketCommonProtocol.close` to terminate the connection. +* If you aren't using :func:`~asyncio.client.connect` as a context manager, call + :meth:`~asyncio.connection.Connection.close` to terminate the connection. .. _debugging: @@ -84,4 +84,3 @@ particular. Fortunately Python's official documentation provides advice to `develop with asyncio`_. Check it out: it's invaluable! .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html - diff --git a/docs/howto/django.rst b/docs/howto/django.rst index e3da0a878..dada9c5e4 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -124,7 +124,7 @@ support asynchronous I/O. It would block the event loop if it didn't run in a separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. -Finally, we start a server with :func:`~websockets.server.serve`. +Finally, we start a server with :func:`~websockets.asyncio.server.serve`. We're ready to test! diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index a97d2e7ce..b335e14c5 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -42,7 +42,7 @@ Here's the implementation of the app, an echo server. Save it in a file called Heroku expects the server to `listen on a specific port`_, which is provided in the ``$PORT`` environment variable. The app reads it and passes it to -:func:`~websockets.server.serve`. +:func:`~websockets.asyncio.server.serve`. .. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index ff42c3c2b..872353cad 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -21,9 +21,9 @@ We'd like nginx to connect to websockets servers via Unix sockets in order to avoid the overhead of TCP for communicating between processes running in the same OS. -We start the app with :func:`~websockets.server.unix_serve`. Each server -process listens on a different socket thanks to an environment variable set -by Supervisor to a different value. +We start the app with :func:`~websockets.asyncio.server.unix_serve`. Each server +process listens on a different socket thanks to an environment variable set by +Supervisor to a different value. Save this configuration to ``supervisord.conf``: diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst index c6f325d21..60bc8ab42 100644 --- a/docs/howto/patterns.rst +++ b/docs/howto/patterns.rst @@ -8,7 +8,7 @@ client. You will certainly implement some of them in your application. This page gives examples of connection handlers for a server. However, they're also applicable to a client, simply by assuming that ``websocket`` is a -connection created with :func:`~client.connect`. +connection created with :func:`~asyncio.client.connect`. WebSocket connections are long-lived. You will usually write a loop to process several messages during the lifetime of a connection. @@ -42,10 +42,10 @@ In this example, ``producer()`` is a coroutine implementing your business logic for generating the next message to send on the WebSocket connection. Each message must be :class:`str` or :class:`bytes`. -Iteration terminates when the client disconnects -because :meth:`~server.WebSocketServerProtocol.send` raises a -:exc:`~exceptions.ConnectionClosed` exception, -which breaks out of the ``while True`` loop. +Iteration terminates when the client disconnects because +:meth:`~asyncio.server.ServerConnection.send` raises a +:exc:`~exceptions.ConnectionClosed` exception, which breaks out of the ``while +True`` loop. Consumer and producer --------------------- diff --git a/docs/howto/quickstart.rst b/docs/howto/quickstart.rst index ab870952c..e6bd362a4 100644 --- a/docs/howto/quickstart.rst +++ b/docs/howto/quickstart.rst @@ -17,9 +17,9 @@ It receives a name from the client, sends a greeting, and closes the connection. :language: python :linenos: -:func:`~server.serve` executes the connection handler coroutine ``hello()`` -once for each WebSocket connection. It closes the WebSocket connection when -the handler returns. +:func:`~asyncio.server.serve` executes the connection handler coroutine +``hello()`` once for each WebSocket connection. It closes the WebSocket +connection when the handler returns. Here's a corresponding WebSocket client. @@ -30,8 +30,8 @@ It sends a name to the server, receives a greeting, and closes the connection. :language: python :linenos: -Using :func:`~client.connect` as an asynchronous context manager ensures the -WebSocket connection is closed. +Using :func:`~asyncio.client.connect` as an asynchronous context manager ensures +the WebSocket connection is closed. .. _secure-server-example: @@ -73,8 +73,8 @@ In this example, the client needs a TLS context because the server uses a self-signed certificate. When connecting to a secure WebSocket server with a valid certificate — any -certificate signed by a CA that your Python installation trusts — you can -simply pass ``ssl=True`` to :func:`~client.connect`. +certificate signed by a CA that your Python installation trusts — you can simply +pass ``ssl=True`` to :func:`~asyncio.client.connect`. .. admonition:: Configure the TLS context securely :class: attention diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 16b010aca..40c8c5ec9 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -49,12 +49,13 @@ the release notes of the version in which the feature was deprecated. * The ``path`` argument of connection handlers — unnecessary since :ref:`10.1` and deprecated in :ref:`13.0`. -* The ``loop`` and ``legacy_recv`` arguments of :func:`~client.connect` and - :func:`~server.serve`, which were removed — deprecated in :ref:`10.0`. -* The ``timeout`` and ``klass`` arguments of :func:`~client.connect` and - :func:`~server.serve`, which were renamed to ``close_timeout`` and +* The ``loop`` and ``legacy_recv`` arguments of :func:`~legacy.client.connect` + and :func:`~legacy.server.serve`, which were removed — deprecated in + :ref:`10.0`. +* The ``timeout`` and ``klass`` arguments of :func:`~legacy.client.connect` and + :func:`~legacy.server.serve`, which were renamed to ``close_timeout`` and ``create_protocol`` — deprecated in :ref:`7.0` and :ref:`3.4` respectively. -* An empty string in the ``origins`` argument of :func:`~server.serve` — +* An empty string in the ``origins`` argument of :func:`~legacy.server.serve` — deprecated in :ref:`7.0`. * The ``host``, ``port``, and ``secure`` attributes of connections — deprecated in :ref:`8.0`. @@ -127,16 +128,16 @@ Client APIs | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ | ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | -| :func:`websockets.client.connect` |br| | | -| ``websockets.legacy.client.connect()`` | | +| ``websockets.client.connect()`` |br| | | +| :func:`websockets.legacy.client.connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | -| :func:`websockets.client.unix_connect` |br| | | -| ``websockets.legacy.client.unix_connect()`` | | +| ``websockets.client.unix_connect()`` |br| | | +| :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | -| :class:`websockets.client.WebSocketClientProtocol` |br| | | -| ``websockets.legacy.client.WebSocketClientProtocol`` | | +| ``websockets.client.WebSocketClientProtocol`` |br| | | +| :class:`websockets.legacy.client.WebSocketClientProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ Server APIs @@ -146,31 +147,31 @@ Server APIs | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ | ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | -| :func:`websockets.server.serve` |br| | | -| ``websockets.legacy.server.serve()`` | | +| ``websockets.server.serve()`` |br| | | +| :func:`websockets.legacy.server.serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | -| :func:`websockets.server.unix_serve` |br| | | -| ``websockets.legacy.server.unix_serve()`` | | +| ``websockets.server.unix_serve()`` |br| | | +| :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | -| :class:`websockets.server.WebSocketServer` |br| | | -| ``websockets.legacy.server.WebSocketServer`` | | +| ``websockets.server.WebSocketServer`` |br| | | +| :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | -| :class:`websockets.server.WebSocketServerProtocol` |br| | | -| ``websockets.legacy.server.WebSocketServerProtocol`` | | +| ``websockets.server.WebSocketServerProtocol`` |br| | | +| :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | | :func:`websockets.legacy.protocol.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | -| :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | -| ``websockets.legacy.auth.BasicAuthWebSocketServerProtocol`` | | +| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | +| :class:`websockets.legacy.auth.BasicAuthWebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | -| :func:`websockets.auth.basic_auth_protocol_factory` |br| | | -| ``websockets.legacy.auth.basic_auth_protocol_factory()`` | | +| ``websockets.auth.basic_auth_protocol_factory()`` |br| | | +| :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ .. _Review API changes: @@ -209,12 +210,12 @@ Customizing the opening handshake ................................. On the client side, if you're adding headers to the handshake request sent by -:func:`~client.connect` with the ``extra_headers`` argument, you must rename it -to ``additional_headers``. +:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must +rename it to ``additional_headers``. -On the server side, if you're customizing how :func:`~server.serve` processes -the opening handshake with the ``process_request``, ``extra_headers``, or -``select_subprotocol``, you must update your code. ``process_response`` and +On the server side, if you're customizing how :func:`~legacy.server.serve` +processes the opening handshake with the ``process_request``, ``extra_headers``, +or ``select_subprotocol``, you must update your code. ``process_response`` and ``select_subprotocol`` have new signatures; ``process_response`` replaces ``extra_headers`` and provides more flexibility. @@ -242,10 +243,10 @@ an example:: ``connection`` is always available in ``process_request``. In the original implementation, you had to write a subclass of -:class:`~server.WebSocketServerProtocol` and pass it in the ``create_protocol`` -argument to make the connection object available in a ``process_request`` -method. This pattern isn't useful anymore; you can replace it with a -``process_request`` function or coroutine. +:class:`~legacy.server.WebSocketServerProtocol` and pass it in the +``create_protocol`` argument to make the connection object available in a +``process_request`` method. This pattern isn't useful anymore; you can replace +it with a ``process_request`` function or coroutine. ``path`` and ``headers`` are available as attributes of the ``request`` object. @@ -296,7 +297,7 @@ The signature of ``select_subprotocol`` changed. Here's an example:: ``connection`` is always available in ``select_subprotocol``. This brings the same benefits as in ``process_request``. It may remove the need to subclass of -:class:`~server.WebSocketServerProtocol`. +:class:`~legacy.server.WebSocketServerProtocol`. The ``subprotocols`` argument contains the list of subprotocols offered by the client. The list of subprotocols supported by the server was removed because @@ -320,7 +321,7 @@ update its name. The keyword argument of :func:`~asyncio.server.serve` for customizing the creation of the connection object is now called ``create_connection`` instead of ``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` -instead of a :class:`~server.WebSocketServerProtocol`. +instead of a :class:`~legacy.server.WebSocketServerProtocol`. If you were customizing connection objects, you should check the new implementation and possibly redo your customization. Keep in mind that the @@ -364,6 +365,28 @@ The ``write_limit`` argument of :func:`~asyncio.client.connect` and Attributes of connections ......................... +``path``, ``request_headers`` and ``response_headers`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :attr:`~legacy.protocol.WebSocketCommonProtocol.path`, +:attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` properties are +replaced by :attr:`~asyncio.connection.Connection.request` and +:attr:`~asyncio.connection.Connection.response`, which provide a ``headers`` +attribute. + +If your code relies on them, you can replace:: + + connection.path + connection.request_headers + connection.response_headers + +with:: + + connection.request.path + connection.request.headers + connection.response.headers + ``open`` and ``closed`` ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 6b32d47f6..74f5f79a3 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -184,7 +184,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: import asyncio - import websockets + from websockets.asyncio.server import serve async def handler(websocket): @@ -194,7 +194,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever @@ -204,8 +204,9 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: The entry point of this program is ``asyncio.run(main())``. It creates an asyncio event loop, runs the ``main()`` coroutine, and shuts down the loop. -The ``main()`` coroutine calls :func:`~server.serve` to start a websockets -server. :func:`~server.serve` takes three positional arguments: +The ``main()`` coroutine calls :func:`~asyncio.server.serve` to start a +websockets server. :func:`~asyncio.server.serve` takes three positional +arguments: * ``handler`` is a coroutine that manages a connection. When a client connects, websockets calls ``handler`` with the connection in argument. @@ -215,7 +216,7 @@ server. :func:`~server.serve` takes three positional arguments: on the same local network can connect. * The third argument is the port on which the server listens. -Invoking :func:`~server.serve` as an asynchronous context manager, in an +Invoking :func:`~asyncio.server.serve` as an asynchronous context manager, in an ``async with`` block, ensures that the server shuts down properly when terminating the program. @@ -258,11 +259,11 @@ stack trace of an exception: ... websockets.exceptions.ConnectionClosedOK: received 1000 (OK); then sent 1000 (OK) -Indeed, the server was waiting for the next message -with :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` when the client -disconnected. When this happens, websockets raises -a :exc:`~exceptions.ConnectionClosedOK` exception to let you know that you -won't receive another message on this connection. +Indeed, the server was waiting for the next message with +:meth:`~asyncio.server.ServerConnection.recv` when the client disconnected. +When this happens, websockets raises a :exc:`~exceptions.ConnectionClosedOK` +exception to let you know that you won't receive another message on this +connection. This exception creates noise in the server logs, making it more difficult to spot real errors when you add functionality to the server. Catch it in the @@ -551,13 +552,12 @@ Summary In this first part of the tutorial, you learned how to: -* build and run a WebSocket server in Python with :func:`~server.serve`; -* receive a message in a connection handler - with :meth:`~server.WebSocketServerProtocol.recv`; -* send a message in a connection handler - with :meth:`~server.WebSocketServerProtocol.send`; -* iterate over incoming messages with ``async for - message in websocket: ...``; +* build and run a WebSocket server in Python with :func:`~asyncio.server.serve`; +* receive a message in a connection handler with + :meth:`~asyncio.server.ServerConnection.recv`; +* send a message in a connection handler with + :meth:`~asyncio.server.ServerConnection.send`; +* iterate over incoming messages with ``async for message in websocket: ...``; * open a WebSocket connection in JavaScript with the ``WebSocket`` API; * send messages in a browser with ``WebSocket.send()``; * receive messages in a browser by listening to ``message`` events; diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 6fdec113b..21d51371b 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -93,9 +93,9 @@ called ``stop`` and registers a signal handler that sets the result of this future. The value of the future doesn't matter; it's only for waiting for ``SIGTERM``. -Then, by using :func:`~server.serve` as a context manager and exiting the -context when ``stop`` has a result, ``main()`` ensures that the server closes -connections cleanly and exits on ``SIGTERM``. +Then, by using :func:`~asyncio.server.serve` as a context manager and exiting +the context when ``stop`` has a result, ``main()`` ensures that the server +closes connections cleanly and exits on ``SIGTERM``. The app is now fully compatible with Heroku. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index eaabb2e9f..df5af54f4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -178,8 +178,8 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. -.. admonition:: :func:`~server.serve` times out on the opening handshake after - 10 seconds by default. +.. admonition:: :func:`~legacy.server.serve` times out on the opening handshake + after 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -200,7 +200,7 @@ New features See :func:`websockets.sync.client.connect` and :func:`websockets.sync.server.serve` for details. -* Added ``open_timeout`` to :func:`~server.serve`. +* Added ``open_timeout`` to :func:`~legacy.server.serve`. * Made it possible to close a server without closing existing connections. @@ -289,7 +289,7 @@ Bug fixes * Fixed backwards-incompatibility in 10.1 for connection handlers created with :func:`functools.partial`. -* Avoided leaking open sockets when :func:`~client.connect` is canceled. +* Avoided leaking open sockets when :func:`~legacy.client.connect` is canceled. .. _10.1: @@ -330,8 +330,8 @@ Improvements .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 -* Mirrored the entire :class:`~asyncio.Server` API - in :class:`~server.WebSocketServer`. +* Mirrored the entire :class:`~asyncio.Server` API in + :class:`~legacy.server.WebSocketServer`. * Improved performance for large messages on ARM processors. @@ -364,9 +364,9 @@ Backwards-incompatible changes Python 3.10 for details. The ``loop`` parameter is also removed - from :class:`~server.WebSocketServer`. This should be transparent. + from :class:`~legacy.server.WebSocketServer`. This should be transparent. -.. admonition:: :func:`~client.connect` times out after 10 seconds by default. +.. admonition:: :func:`~legacy.client.connect` times out after 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -405,9 +405,9 @@ New features * Added :func:`~legacy.protocol.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using - :func:`~client.connect` as an asynchronous iterator. + :func:`~legacy.client.connect` as an asynchronous iterator. -* Added ``open_timeout`` to :func:`~client.connect`. +* Added ``open_timeout`` to :func:`~legacy.client.connect`. * Documented how to integrate with `Django `_. @@ -427,12 +427,12 @@ Improvements * Optimized processing of client-to-server messages when the C extension isn't available. -* Supported relative redirects in :func:`~client.connect`. +* Supported relative redirects in :func:`~legacy.client.connect`. * Handled TCP connection drops during the opening handshake. * Made it easier to customize authentication with - :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. + :meth:`~legacy.auth.BasicAuthWebSocketServerProtocol.check_credentials`. * Provided additional information in :exc:`~exceptions.ConnectionClosed` exceptions. @@ -590,7 +590,7 @@ Bug fixes ......... * Restored the ability to pass a socket with the ``sock`` parameter of - :func:`~server.serve`. + :func:`~legacy.server.serve`. * Removed an incorrect assertion when a connection drops. @@ -623,11 +623,12 @@ Backwards-incompatible changes .. admonition:: ``process_request`` is now expected to be a coroutine. :class: note - If you're passing a ``process_request`` argument to :func:`~server.serve` - or :class:`~server.WebSocketServerProtocol`, or if you're overriding - :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, - define it with ``async def`` instead of ``def``. Previously, both were - supported. + If you're passing a ``process_request`` argument to + :func:`~legacy.server.serve` or + :class:`~legacy.server.WebSocketServerProtocol`, or if you're overriding + :meth:`~legacy.server.WebSocketServerProtocol.process_request` in a + subclass, define it with ``async def`` instead of ``def``. Previously, both + were supported. For backwards compatibility, functions are still accepted, but mixing functions and coroutines won't work in some inheritance scenarios. @@ -661,15 +662,15 @@ Backwards-incompatible changes New features ............ -* Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP - Basic Auth on the server side. +* Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP Basic + Auth on the server side. -* :func:`~client.connect` handles redirects from the server during the +* :func:`~legacy.client.connect` handles redirects from the server during the handshake. -* :func:`~client.connect` supports overriding ``host`` and ``port``. +* :func:`~legacy.client.connect` supports overriding ``host`` and ``port``. -* Added :func:`~client.unix_connect` for connecting to Unix sockets. +* Added :func:`~legacy.client.unix_connect` for connecting to Unix sockets. * Added support for asynchronous generators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send` @@ -699,9 +700,8 @@ Improvements :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* Changed :meth:`WebSocketServer.close() - ` to perform a proper closing handshake - instead of failing the connection. +* Changed :meth:`WebSocketServer.close() ` + to perform a proper closing handshake instead of failing the connection. * Improved error messages when HTTP parsing fails. @@ -734,7 +734,7 @@ Backwards-incompatible changes See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. .. admonition:: Termination of connections by :meth:`WebSocketServer.close() - ` changes. + ` changes. :class: caution Previously, connections handlers were canceled. Now, connections are @@ -758,15 +758,16 @@ Backwards-incompatible changes Concurrent calls lead to non-deterministic behavior because there are no guarantees about which coroutine will receive which message. -.. admonition:: The ``timeout`` argument of :func:`~server.serve` - and :func:`~client.connect` is renamed to ``close_timeout`` . +.. admonition:: The ``timeout`` argument of :func:`~legacy.server.serve` + and :func:`~legacy.client.connect` is renamed to ``close_timeout`` . :class: note This prevents confusion with ``ping_timeout``. For backwards compatibility, ``timeout`` is still supported. -.. admonition:: The ``origins`` argument of :func:`~server.serve` changes. +.. admonition:: The ``origins`` argument of :func:`~legacy.server.serve` + changes. :class: note Include :obj:`None` in the list rather than ``''`` to allow requests that @@ -786,10 +787,10 @@ New features ............ * Added ``process_request`` and ``select_subprotocol`` arguments to - :func:`~server.serve` and - :class:`~server.WebSocketServerProtocol` to facilitate customization of - :meth:`~server.WebSocketServerProtocol.process_request` and - :meth:`~server.WebSocketServerProtocol.select_subprotocol`. + :func:`~legacy.server.serve` and + :class:`~legacy.server.WebSocketServerProtocol` to facilitate customization of + :meth:`~legacy.server.WebSocketServerProtocol.process_request` and + :meth:`~legacy.server.WebSocketServerProtocol.select_subprotocol`. * Added support for sending fragmented messages. @@ -826,10 +827,10 @@ Backwards-incompatible changes several APIs are updated to use it. :class: caution - * The ``request_headers`` argument - of :meth:`~server.WebSocketServerProtocol.process_request` is now - a :class:`~datastructures.Headers` instead of - an ``http.client.HTTPMessage``. + * The ``request_headers`` argument of + :meth:`~legacy.server.WebSocketServerProtocol.process_request` is now a + :class:`~datastructures.Headers` instead of an + ``http.client.HTTPMessage``. * The ``request_headers`` and ``response_headers`` attributes of :class:`~legacy.protocol.WebSocketCommonProtocol` are now @@ -866,7 +867,7 @@ Bug fixes ......... * Fixed a regression in 5.0 that broke some invocations of - :func:`~server.serve` and :func:`~client.connect`. + :func:`~legacy.server.serve` and :func:`~legacy.client.connect`. .. _5.0: @@ -900,10 +901,10 @@ Backwards-incompatible changes New features ............ -* :func:`~client.connect` performs HTTP Basic Auth when the URI contains +* :func:`~legacy.client.connect` performs HTTP Basic Auth when the URI contains credentials. -* :func:`~server.unix_serve` can be used as an asynchronous context +* :func:`~legacy.server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property @@ -979,7 +980,7 @@ Backwards-incompatible changes Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~server.serve` or :func:`~client.connect`. + :func:`~legacy.server.serve` or :func:`~legacy.client.connect`. .. admonition:: The ``state_name`` attribute of protocols is deprecated. :class: note @@ -992,10 +993,10 @@ New features * :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. -* Added :func:`~server.unix_serve` for listening on Unix sockets. +* Added :func:`~legacy.server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~server.WebSocketServer.sockets` attribute to the - return value of :func:`~server.serve`. +* Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the + return value of :func:`~legacy.server.serve`. * Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. @@ -1030,24 +1031,24 @@ Backwards-incompatible changes by :class:`~exceptions.InvalidStatusCode`. :class: note - This exception is raised when :func:`~client.connect` receives an invalid + This exception is raised when :func:`~legacy.client.connect` receives an invalid response status code from the server. New features ............ -* :func:`~server.serve` can be used as an asynchronous context manager +* :func:`~legacy.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with - :meth:`~server.WebSocketServerProtocol.process_request`. + :meth:`~legacy.server.WebSocketServerProtocol.process_request`. * Made read and write buffer sizes configurable. Improvements ............ -* Renamed :func:`~server.serve` and :func:`~client.connect`'s +* Renamed :func:`~legacy.server.serve` and :func:`~legacy.client.connect`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. @@ -1058,7 +1059,7 @@ Improvements Bug fixes ......... -* Providing a ``sock`` argument to :func:`~client.connect` no longer +* Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer crashes. .. _3.3: @@ -1094,7 +1095,7 @@ New features ............ * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~client.connect` and :func:`~server.serve`. + :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. Improvements ............ @@ -1151,15 +1152,15 @@ Backwards-incompatible changes In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to - :func:`~server.serve`, :func:`~client.connect`, - :class:`~server.WebSocketServerProtocol`, or - :class:`~client.WebSocketClientProtocol`. + :func:`~legacy.server.serve`, :func:`~legacy.client.connect`, + :class:`~legacy.server.WebSocketServerProtocol`, or + :class:`~legacy.client.WebSocketClientProtocol`. New features ............ -* :func:`~client.connect` can be used as an asynchronous context - manager on Python ≥ 3.5.1. +* :func:`~legacy.client.connect` can be used as an asynchronous context manager + on Python ≥ 3.5.1. * :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` and :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support data passed as @@ -1260,8 +1261,8 @@ New features * Added support for subprotocols. -* Added ``loop`` argument to :func:`~client.connect` and - :func:`~server.serve`. +* Added ``loop`` argument to :func:`~legacy.client.connect` and + :func:`~legacy.server.serve`. .. _2.3: diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index f9ce2f2d8..77a3c5d53 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,24 +1,28 @@ -Client (legacy :mod:`asyncio`) -============================== +Client (new :mod:`asyncio`) +=========================== -.. automodule:: websockets.client +.. automodule:: websockets.asyncio.client Opening a connection -------------------- -.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: connect :async: -.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_connect :async: Using a connection ------------------ -.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -39,26 +43,15 @@ Using a connection .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: request - .. autoproperty:: close_code + .. autoattribute:: response - .. autoproperty:: close_reason + .. autoproperty:: subprotocol diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index aee774479..a58325fb9 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,14 +1,18 @@ :orphan: -Both sides (legacy :mod:`asyncio`) -================================== +Both sides (new :mod:`asyncio`) +=============================== -.. automodule:: websockets.legacy.protocol +.. automodule:: websockets.asyncio.connection -.. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: Connection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -29,26 +33,15 @@ Both sides (legacy :mod:`asyncio`) .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: request - .. autoproperty:: close_code + .. autoattribute:: response - .. autoproperty:: close_reason + .. autoproperty:: subprotocol diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 3636f0b33..7bceca5a0 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,19 +1,19 @@ -Server (legacy :mod:`asyncio`) -============================== +Server (new :mod:`asyncio`) +=========================== -.. automodule:: websockets.server +.. automodule:: websockets.asyncio.server -Starting a server +Creating a server ----------------- -.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: serve :async: -.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_serve :async: -Stopping a server ------------------ +Running a server +---------------- .. autoclass:: WebSocketServer @@ -34,10 +34,14 @@ Stopping a server Using a connection ------------------ -.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -48,11 +52,7 @@ Using a connection .. automethod:: pong - You can customize the opening handshake in a subclass by overriding these methods: - - .. automethod:: process_request - - .. automethod:: select_subprotocol + .. automethod:: respond WebSocket connection objects also provide these attributes: @@ -64,50 +64,20 @@ Using a connection .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: - - .. autoproperty:: close_code - - .. autoproperty:: close_reason - - -Basic authentication --------------------- - -.. automodule:: websockets.auth - -websockets supports HTTP Basic Authentication according to -:rfc:`7235` and :rfc:`7617`. - -.. autofunction:: basic_auth_protocol_factory - -.. autoclass:: BasicAuthWebSocketServerProtocol - - .. autoattribute:: realm + .. autoattribute:: request - .. autoattribute:: username + .. autoattribute:: response - .. automethod:: check_credentials + .. autoproperty:: subprotocol Broadcast --------- -.. autofunction:: websockets.legacy.protocol.broadcast +.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 45fa79c48..6840fe15b 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -182,5 +182,6 @@ Request if it is missing or invalid (`#1246`). The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -mandated by :rfc:`6455`, section 4.1. However, :func:`~client.connect()` isn't -the right layer for enforcing this constraint. It's the caller's responsibility. +mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` +isn't the right layer for enforcing this constraint. It's the caller's +responsibility. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index d3a0e935c..77b538b78 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -13,12 +13,12 @@ Check which implementations support which features and known limitations. features +:mod:`asyncio` (new) +-------------------- -:mod:`asyncio` --------------- +It's ideal for servers that handle many clients concurrently. -This is the default implementation. It's ideal for servers that handle many -clients concurrently. +It's a rewrite of the legacy :mod:`asyncio` implementation. .. toctree:: :titlesonly: @@ -26,17 +26,16 @@ clients concurrently. asyncio/server asyncio/client -:mod:`asyncio` (new) --------------------- +:mod:`asyncio` (legacy) +----------------------- -This is a rewrite of the :mod:`asyncio` implementation. It will become the -default in the future. +This is the historical implementation. .. toctree:: :titlesonly: - new-asyncio/server - new-asyncio/client + legacy/server + legacy/client :mod:`threading` ---------------- diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst new file mode 100644 index 000000000..fca45d218 --- /dev/null +++ b/docs/reference/legacy/client.rst @@ -0,0 +1,64 @@ +Client (legacy :mod:`asyncio`) +============================== + +.. automodule:: websockets.legacy.client + +Opening a connection +-------------------- + +.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +Using a connection +------------------ + +.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst new file mode 100644 index 000000000..aee774479 --- /dev/null +++ b/docs/reference/legacy/common.rst @@ -0,0 +1,54 @@ +:orphan: + +Both sides (legacy :mod:`asyncio`) +================================== + +.. automodule:: websockets.legacy.protocol + +.. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst new file mode 100644 index 000000000..c2758f5a2 --- /dev/null +++ b/docs/reference/legacy/server.rst @@ -0,0 +1,113 @@ +Server (legacy :mod:`asyncio`) +============================== + +.. automodule:: websockets.legacy.server + +Starting a server +----------------- + +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +Stopping a server +----------------- + +.. autoclass:: WebSocketServer + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + + .. autoattribute:: sockets + +Using a connection +------------------ + +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + You can customize the opening handshake in a subclass by overriding these methods: + + .. automethod:: process_request + + .. automethod:: select_subprotocol + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + +Basic authentication +-------------------- + +.. automodule:: websockets.legacy.auth + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: basic_auth_protocol_factory + +.. autoclass:: BasicAuthWebSocketServerProtocol + + .. autoattribute:: realm + + .. autoattribute:: username + + .. automethod:: check_credentials + +Broadcast +--------- + +.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst deleted file mode 100644 index 77a3c5d53..000000000 --- a/docs/reference/new-asyncio/client.rst +++ /dev/null @@ -1,57 +0,0 @@ -Client (new :mod:`asyncio`) -=========================== - -.. automodule:: websockets.asyncio.client - -Opening a connection --------------------- - -.. autofunction:: connect - :async: - -.. autofunction:: unix_connect - :async: - -Using a connection ------------------- - -.. autoclass:: ClientConnection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst deleted file mode 100644 index a58325fb9..000000000 --- a/docs/reference/new-asyncio/common.rst +++ /dev/null @@ -1,47 +0,0 @@ -:orphan: - -Both sides (new :mod:`asyncio`) -=============================== - -.. automodule:: websockets.asyncio.connection - -.. autoclass:: Connection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst deleted file mode 100644 index 7bceca5a0..000000000 --- a/docs/reference/new-asyncio/server.rst +++ /dev/null @@ -1,83 +0,0 @@ -Server (new :mod:`asyncio`) -=========================== - -.. automodule:: websockets.asyncio.server - -Creating a server ------------------ - -.. autofunction:: serve - :async: - -.. autofunction:: unix_serve - :async: - -Running a server ----------------- - -.. autoclass:: WebSocketServer - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: get_loop - - .. automethod:: is_serving - - .. automethod:: start_serving - - .. automethod:: serve_forever - - .. autoattribute:: sockets - -Using a connection ------------------- - -.. autoclass:: ServerConnection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - .. automethod:: respond - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol - -Broadcast ---------- - -.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 31bc8e6da..86d2e2587 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -212,7 +212,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class QueryParamProtocol(websockets.WebSocketServerProtocol): + from websockets.legacy.server import WebSocketServerProtocol + + class QueryParamProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): token = get_query_parameter(path, "token") if token is None: @@ -258,7 +260,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class CookieProtocol(websockets.WebSocketServerProtocol): + from websockets.legacy.server import WebSocketServerProtocol + + class CookieProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): # Serve iframe on non-WebSocket requests ... @@ -299,7 +303,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): + from websockets.legacy.auth import BasicAuthWebSocketServerProtocol + + class UserInfoProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): if username != "token": return False @@ -328,8 +334,10 @@ To authenticate a websockets client with HTTP Basic Authentication .. code-block:: python - async with websockets.connect( - f"wss://{username}:{password}@example.com", + from websockets.legacy.client import connect + + async with connect( + f"wss://{username}:{password}@example.com" ) as websocket: ... @@ -341,7 +349,9 @@ To authenticate a websockets client with HTTP Bearer Authentication .. code-block:: python - async with websockets.connect( + from websockets.legacy.client import connect + + async with connect( "wss://example.com", extra_headers={"Authorization": f"Bearer {token}"} ) as websocket: diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index 2a1fe9a78..48ef72b56 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -78,7 +78,7 @@ Option 2 almost always combines with option 3. How do I start a process? ......................... -Run a Python program that invokes :func:`~server.serve`. That's it. +Run a Python program that invokes :func:`~asyncio.server.serve`. That's it. Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're alternatives to websockets, not complements. @@ -98,18 +98,19 @@ signal and exit the server to ensure a graceful shutdown. Here's an example: .. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 12-15,18 + :emphasize-lines: 13-16,19 -When exiting the context manager, :func:`~server.serve` closes all connections +When exiting the context manager, :func:`~asyncio.server.serve` closes all +connections with code 1001 (going away). As a consequence: * If the connection handler is awaiting - :meth:`~server.WebSocketServerProtocol.recv`, it receives a + :meth:`~asyncio.server.ServerConnection.recv`, it receives a :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception and clean up before exiting. * Otherwise, it should be waiting on - :meth:`~server.WebSocketServerProtocol.wait_closed`, so it can receive the + :meth:`~asyncio.server.ServerConnection.wait_closed`, so it can receive the :exc:`~exceptions.ConnectionClosedOK` exception and exit. This example is easily adapted to handle other signals. @@ -173,7 +174,7 @@ Load balancers need a way to check whether server processes are up and running to avoid routing connections to a non-functional backend. websockets provide minimal support for responding to HTTP requests with the -:meth:`~server.WebSocketServerProtocol.process_request` hook. +``process_request`` hook. Here's an example: diff --git a/docs/topics/design.rst b/docs/topics/design.rst index cc65e6a70..d2fd18d0c 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,10 +1,11 @@ -Design -====== +Design (legacy :mod:`asyncio`) +============================== -.. currentmodule:: websockets +.. currentmodule:: websockets.legacy -This document describes the design of websockets. It assumes familiarity with -the specification of the WebSocket protocol in :rfc:`6455`. +This document describes the design of the legacy implementation of websockets. +It assumes familiarity with the specification of the WebSocket protocol in +:rfc:`6455`. It's primarily intended at maintainers. It may also be useful for users who wish to understand what happens under the hood. @@ -32,21 +33,19 @@ WebSocket connections go through a trivial state machine: Transitions happen in the following places: - ``CONNECTING -> OPEN``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` which runs when - the :ref:`opening handshake ` completes and the WebSocket + :meth:`~protocol.WebSocketCommonProtocol.connection_open` which runs when the + :ref:`opening handshake ` completes and the WebSocket connection is established — not to be confused with - :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP connection - is established; -- ``OPEN -> CLOSING``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.write_frame` immediately before - sending a close frame; since receiving a close frame triggers sending a - close frame, this does the right thing regardless of which side started the - :ref:`closing handshake `; also in - :meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` which duplicates - a few lines of code from ``write_close_frame()`` and ``write_frame()``; -- ``* -> CLOSED``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_lost` which is always - called exactly once when the TCP connection is closed. + :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP + connection is established; +- ``OPEN -> CLOSING``: in :meth:`~protocol.WebSocketCommonProtocol.write_frame` + immediately before sending a close frame; since receiving a close frame + triggers sending a close frame, this does the right thing regardless of which + side started the :ref:`closing handshake `; also in + :meth:`~protocol.WebSocketCommonProtocol.fail_connection` which duplicates a + few lines of code from ``write_close_frame()`` and ``write_frame()``; +- ``* -> CLOSED``: in :meth:`~protocol.WebSocketCommonProtocol.connection_lost` + which is always called exactly once when the TCP connection is closed. Coroutines .......... @@ -57,38 +56,38 @@ connection lifecycle on the client side. .. image:: lifecycle.svg :target: _images/lifecycle.svg -The lifecycle is identical on the server side, except inversion of control -makes the equivalent of :meth:`~client.connect` implicit. +The lifecycle is identical on the server side, except inversion of control makes +the equivalent of :meth:`~client.connect` implicit. Coroutines shown in green are called by the application. Multiple coroutines may interact with the WebSocket connection concurrently. Coroutines shown in gray manage the connection. When the opening handshake -succeeds, :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` starts -two tasks: - -- :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.transfer_data` which handles - incoming data and lets :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` - consume it. It may be canceled to terminate the connection. It never exits - with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data - transfer ` below. - -- :attr:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping +succeeds, :meth:`~protocol.WebSocketCommonProtocol.connection_open` starts two +tasks: + +- :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs + :meth:`~protocol.WebSocketCommonProtocol.transfer_data` which handles incoming + data and lets :meth:`~protocol.WebSocketCommonProtocol.recv` consume it. It + may be canceled to terminate the connection. It never exits with an exception + other than :exc:`~asyncio.CancelledError`. See :ref:`data transfer + ` below. + +- :attr:`~protocol.WebSocketCommonProtocol.keepalive_ping_task` runs + :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping frames at regular intervals and ensures that corresponding Pong frames are - received. It is canceled when the connection terminates. It never exits - with an exception other than :exc:`~asyncio.CancelledError`. + received. It is canceled when the connection terminates. It never exits with + an exception other than :exc:`~asyncio.CancelledError`. -- :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.close_connection` which waits for - the data transfer to terminate, then takes care of closing the TCP - connection. It must not be canceled. It never exits with an exception. See - :ref:`connection termination ` below. +- :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs + :meth:`~protocol.WebSocketCommonProtocol.close_connection` which waits for the + data transfer to terminate, then takes care of closing the TCP connection. It + must not be canceled. It never exits with an exception. See :ref:`connection + termination ` below. -Besides, :meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` starts -the same :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` when -the opening handshake fails, in order to close the TCP connection. +Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection` starts the +same :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` when the +opening handshake fails, in order to close the TCP connection. Splitting the responsibilities between two tasks makes it easier to guarantee that websockets can terminate connections: @@ -99,11 +98,11 @@ that websockets can terminate connections: regardless of whether the connection terminates normally or abnormally. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` completes when no +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` completes when no more data will be received on the connection. Under normal circumstances, it exits after exchanging close frames. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` completes when +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` completes when the TCP connection is closed. @@ -113,10 +112,9 @@ Opening handshake ----------------- websockets performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~client.connect` executes it -before returning the protocol to the caller. On the server side, it's executed -before passing the protocol to the ``ws_handler`` coroutine handling the -connection. +connection. On the client side, :meth:`~client.connect` executes it before +returning the protocol to the caller. On the server side, it's executed before +passing the protocol to the ``ws_handler`` coroutine handling the connection. While the opening handshake is asymmetrical — the client sends an HTTP Upgrade request and the server replies with an HTTP Switching Protocols response — @@ -136,9 +134,9 @@ On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: - reads an HTTP request from the network; -- calls :meth:`~server.WebSocketServerProtocol.process_request` which may - abort the WebSocket handshake and return an HTTP response instead; this - hook only makes sense on the server side; +- calls :meth:`~server.WebSocketServerProtocol.process_request` which may abort + the WebSocket handshake and return an HTTP response instead; this hook only + makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds an HTTP response based on the above and parameters passed to @@ -178,13 +176,13 @@ differences between a server and a client: These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented -in the same class, :class:`~legacy.protocol.WebSocketCommonProtocol`. +in the same class, :class:`~protocol.WebSocketCommonProtocol`. .. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 .. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 .. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 -The :attr:`~legacy.protocol.WebSocketCommonProtocol.is_client` attribute tells which +The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the :attr:`~server.WebSocketServerProtocol` and :attr:`~client.WebSocketClientProtocol` classes. @@ -211,11 +209,11 @@ The left side of the diagram shows how websockets receives data. Incoming data is written to a :class:`~asyncio.StreamReader` in order to implement flow control and provide backpressure on the TCP connection. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`, which is started +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which is started when the WebSocket connection is established, processes this data. When it receives data frames, it reassembles fragments and puts the resulting -messages in the :attr:`~legacy.protocol.WebSocketCommonProtocol.messages` queue. +messages in the :attr:`~protocol.WebSocketCommonProtocol.messages` queue. When it encounters a control frame: @@ -227,11 +225,11 @@ When it encounters a control frame: Running this process in a task guarantees that control frames are processed promptly. Without such a task, websockets would depend on the application to drive the connection by having exactly one coroutine awaiting -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv` at any time. While this -happens naturally in many use cases, it cannot be relied upon. +:meth:`~protocol.WebSocketCommonProtocol.recv` at any time. While this happens +naturally in many use cases, it cannot be relied upon. -Then :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` fetches the next message -from the :attr:`~legacy.protocol.WebSocketCommonProtocol.messages` queue, with some +Then :meth:`~protocol.WebSocketCommonProtocol.recv` fetches the next message +from the :attr:`~protocol.WebSocketCommonProtocol.messages` queue, with some complexity added for handling backpressure and termination correctly. Sending data @@ -239,19 +237,19 @@ Sending data The right side of the diagram shows how websockets sends data. -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` writes one or several data -frames containing the message. While sending a fragmented message, concurrent -calls to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` are put on hold until -all fragments are sent. This makes concurrent calls safe. +:meth:`~protocol.WebSocketCommonProtocol.send` writes one or several data frames +containing the message. While sending a fragmented message, concurrent calls to +:meth:`~protocol.WebSocketCommonProtocol.send` are put on hold until all +fragments are sent. This makes concurrent calls safe. -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping` writes a ping frame and -yields a :class:`~asyncio.Future` which will be completed when a matching pong -frame is received. +:meth:`~protocol.WebSocketCommonProtocol.ping` writes a ping frame and yields a +:class:`~asyncio.Future` which will be completed when a matching pong frame is +received. -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong` writes a pong frame. +:meth:`~protocol.WebSocketCommonProtocol.pong` writes a pong frame. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` writes a close frame and -waits for the TCP connection to terminate. +:meth:`~protocol.WebSocketCommonProtocol.close` writes a close frame and waits +for the TCP connection to terminate. Outgoing data is written to a :class:`~asyncio.StreamWriter` in order to implement flow control and provide backpressure from the TCP connection. @@ -262,17 +260,17 @@ Closing handshake ................. When the other side of the connection initiates the closing handshake, -:meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives a close -frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a -close frame, and returns :obj:`None`, causing -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. +:meth:`~protocol.WebSocketCommonProtocol.read_message` receives a close frame +while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a close +frame, and returns :obj:`None`, causing +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. When this side of the connection initiates the closing handshake with -:meth:`~legacy.protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` +:meth:`~protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` state and sends a close frame. When the other side sends a close frame, -:meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives it in the +:meth:`~protocol.WebSocketCommonProtocol.read_message` receives it in the ``CLOSING`` state and returns :obj:`None`, also causing -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close timeout, websockets :ref:`fails the connection `. @@ -289,33 +287,33 @@ Then websockets terminates the TCP connection. Connection termination ---------------------- -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which is +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which is started when the WebSocket connection is established, is responsible for eventually closing the TCP connection. -First :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` waits -for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, -which may happen as a result of: +First :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` waits for +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, which +may happen as a result of: - a successful closing handshake: as explained above, this exits the infinite - loop in :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`; + loop in :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a timeout while waiting for the closing handshake to complete: this cancels - :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`; + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a protocol error, including connection errors: depending on the exception, - :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the connection ` with a suitable code and exits. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` is separate -from :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to make it -easier to implement the timeout on the closing handshake. Canceling -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk -of canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` -and failing to close the TCP connection, thus leaking resources. +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` is separate from +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to make it easier +to implement the timeout on the closing handshake. Canceling +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk of +canceling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` and +failing to close the TCP connection, thus leaking resources. -Then :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` cancels -:meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no -protocol compliance responsibilities. Terminating it to avoid leaking it is -the only concern. +Then :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` cancels +:meth:`~protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no +protocol compliance responsibilities. Terminating it to avoid leaking it is the +only concern. Terminating the TCP connection can take up to ``2 * close_timeout`` on the server side and ``3 * close_timeout`` on the client side. Clients start by @@ -335,11 +333,11 @@ If the opening handshake doesn't complete successfully, websockets fails the connection by closing the TCP connection. Once the opening handshake has completed, websockets fails the connection by -canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -and sending a close frame if appropriate. +canceling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +sending a close frame if appropriate. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which closes +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which closes the TCP connection. @@ -348,7 +346,7 @@ the TCP connection. Server shutdown --------------- -:class:`~websockets.server.WebSocketServer` closes asynchronously like +:class:`~server.WebSocketServer` closes asynchronously like :class:`asyncio.Server`. The shutdown happen in two steps: 1. Stop listening and accepting new connections; @@ -356,10 +354,10 @@ Server shutdown the opening handshake is still in progress, with HTTP status code 503 (Service Unavailable). -The first call to :class:`~websockets.server.WebSocketServer.close` starts a -task that performs this sequence. Further calls are ignored. This is the -easiest way to make :class:`~websockets.server.WebSocketServer.close` and -:class:`~websockets.server.WebSocketServer.wait_closed` idempotent. +The first call to :class:`~server.WebSocketServer.close` starts a task that +performs this sequence. Further calls are ignored. This is the easiest way to +make :class:`~server.WebSocketServer.close` and +:class:`~server.WebSocketServer.wait_closed` idempotent. .. _cancellation: @@ -415,45 +413,45 @@ happen on the client side. On the server side, the opening handshake is managed by websockets and nothing results in a cancellation. Once the WebSocket connection is established, internal tasks -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` mustn't get +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get accidentally canceled if a coroutine that awaits them is canceled. In other words, they must be shielded from cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv` waits for the next message in -the queue or for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -to terminate, whichever comes first. It relies on :func:`~asyncio.wait` for -waiting on two futures in parallel. As a consequence, even though it's waiting -on a :class:`~asyncio.Future` signaling the next message and on -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't +:meth:`~protocol.WebSocketCommonProtocol.recv` waits for the next message in the +queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to +terminate, whichever comes first. It relies on :func:`~asyncio.wait` for waiting +on two futures in parallel. As a consequence, even though it's waiting on a +:class:`~asyncio.Future` signaling the next message and on +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't propagate cancellation to them. -:meth:`~legacy.protocol.WebSocketCommonProtocol.ensure_open` is called by -:meth:`~legacy.protocol.WebSocketCommonProtocol.send`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong`. When the connection state is +:meth:`~protocol.WebSocketCommonProtocol.ensure_open` is called by +:meth:`~protocol.WebSocketCommonProtocol.send`, +:meth:`~protocol.WebSocketCommonProtocol.ping`, and +:meth:`~protocol.WebSocketCommonProtocol.pong`. When the connection state is ``CLOSING``, it waits for -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to prevent cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` waits for the data transfer -task to terminate with :func:`~asyncio.timeout`. If it's canceled or if the -timeout elapses, :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -is canceled, which is correct at this point. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` then waits for -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` but shields it +:meth:`~protocol.WebSocketCommonProtocol.close` waits for the data transfer task +to terminate with :func:`~asyncio.timeout`. If it's canceled or if the timeout +elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is +canceled, which is correct at this point. +:meth:`~protocol.WebSocketCommonProtocol.close` then waits for +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it to prevent cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` and -:meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` are the only -places where :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` may -be canceled. +:meth:`~protocol.WebSocketCommonProtocol.close` and +:meth:`~protocol.WebSocketCommonProtocol.fail_connection` are the only places +where :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` may be +canceled. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` starts by -waiting for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`. It +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` starts by +waiting for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`. It catches :exc:`~asyncio.CancelledError` to prevent a cancellation of -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` from propagating -to :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`. +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` from propagating to +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`. .. _backpressure: @@ -491,28 +489,28 @@ buffers and break the backpressure. Be careful with queues. Concurrency ----------- -Awaiting any combination of :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.send`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, or -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong` concurrently is safe, including +Awaiting any combination of :meth:`~protocol.WebSocketCommonProtocol.recv`, +:meth:`~protocol.WebSocketCommonProtocol.send`, +:meth:`~protocol.WebSocketCommonProtocol.close` +:meth:`~protocol.WebSocketCommonProtocol.ping`, or +:meth:`~protocol.WebSocketCommonProtocol.pong` concurrently is safe, including multiple calls to the same method, with one exception and one limitation. -* **Only one coroutine can receive messages at a time.** This constraint - avoids non-deterministic behavior (and simplifies the implementation). If a - coroutine is awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, - awaiting it again in another coroutine raises :exc:`RuntimeError`. +* **Only one coroutine can receive messages at a time.** This constraint avoids + non-deterministic behavior (and simplifies the implementation). If a coroutine + is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, awaiting it again + in another coroutine raises :exc:`RuntimeError`. * **Sending a fragmented message forces serialization.** Indeed, the WebSocket protocol doesn't support multiplexing messages. If a coroutine is awaiting - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to send a fragmented message, + :meth:`~protocol.WebSocketCommonProtocol.send` to send a fragmented message, awaiting it again in another coroutine waits until the first call completes. - This will be transparent in many cases. It may be a concern if the - fragmented message is generated slowly by an asynchronous iterator. + This will be transparent in many cases. It may be a concern if the fragmented + message is generated slowly by an asynchronous iterator. Receiving frames is independent from sending frames. This isolates -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, which receives frames, from -the other methods, which send frames. +:meth:`~protocol.WebSocketCommonProtocol.recv`, which receives frames, from the +other methods, which send frames. While the connection is open, each frame is sent with a single write. Combined with the concurrency model of :mod:`asyncio`, this enforces serialization. The diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 765278360..cad49ba55 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -101,9 +101,10 @@ However, this technique runs into two problems: * Even with :meth:`str.format` style, you're restricted to attribute and index lookups, which isn't enough to implement some fairly simple requirements. -There's a better way. :func:`~client.connect` and :func:`~server.serve` accept -a ``logger`` argument to override the default :class:`~logging.Logger`. You -can set ``logger`` to a :class:`~logging.LoggerAdapter` that enriches logs. +There's a better way. :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` accept a ``logger`` argument to override the +default :class:`~logging.Logger`. You can set ``logger`` to a +:class:`~logging.LoggerAdapter` that enriches logs. For example, if the server is behind a reverse proxy, :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` gives @@ -128,7 +129,7 @@ Here's how to include them in logs, assuming they're in the xff = websocket.request_headers.get("X-Forwarded-For") return f"{websocket.id} {xff} {msg}", kwargs - async with websockets.serve( + async with serve( ..., # Python < 3.10 requires passing None as the second argument. logger=LoggerAdapter(logging.getLogger("websockets.server"), None), @@ -170,7 +171,7 @@ a :class:`~logging.LoggerAdapter`:: } return msg, kwargs - async with websockets.serve( + async with serve( ..., # Python < 3.10 requires passing None as the second argument. logger=LoggerAdapter(logging.getLogger("websockets.server"), None), diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index efbcbb83f..61b1113e2 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -99,10 +99,11 @@ workloads but it can also backfire because it delays backpressure. messages. * In the legacy :mod:`asyncio` implementation, there is a library-level read - buffer. The ``read_limit`` argument of :func:`~client.connect` and - :func:`~server.serve` controls its size. When the read buffer grows above the - high-water mark, the connection stops reading from the network until it drains - under the low-water mark. This creates backpressure on the TCP connection. + buffer. The ``read_limit`` argument of :func:`~legacy.client.connect` and + :func:`~legacy.server.serve` controls its size. When the read buffer grows + above the high-water mark, the connection stops reading from the network until + it drains under the low-water mark. This creates backpressure on the TCP + connection. There is a write buffer. It as controlled by ``write_limit``. It behaves like the new :mod:`asyncio` implementation described above. diff --git a/docs/topics/security.rst b/docs/topics/security.rst index 83d79e35b..a22b752c7 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -49,9 +49,9 @@ Identification By default, websockets identifies itself with a ``Server`` or ``User-Agent`` header in the format ``"Python/x.y.z websockets/X.Y"``. -You can set the ``server_header`` argument of :func:`~server.serve` or the -``user_agent_header`` argument of :func:`~client.connect` to configure another -value. Setting them to :obj:`None` removes the header. +You can set the ``server_header`` argument of :func:`~asyncio.server.serve` or +the ``user_agent_header`` argument of :func:`~asyncio.client.connect` to +configure another value. Setting them to :obj:`None` removes the header. Alternatively, you can set the :envvar:`WEBSOCKETS_SERVER` and :envvar:`WEBSOCKETS_USER_AGENT` environment variables respectively. Setting them diff --git a/example/deployment/fly/app.py b/example/deployment/fly/app.py index 4ca34d23b..c8e6af4f9 100644 --- a/example/deployment/fly/app.py +++ b/example/deployment/fly/app.py @@ -4,7 +4,7 @@ import http import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -12,9 +12,9 @@ async def echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): @@ -23,7 +23,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py index 360479b8e..ef6d9c42d 100644 --- a/example/deployment/haproxy/app.py +++ b/example/deployment/haproxy/app.py @@ -4,7 +4,7 @@ import os import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="localhost", port=8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]), diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index d4ba3edb5..17ad09d26 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -4,7 +4,7 @@ import signal import os -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=int(os.environ["PORT"]), diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py index a8bcef688..387f0ade1 100755 --- a/example/deployment/kubernetes/app.py +++ b/example/deployment/kubernetes/app.py @@ -6,7 +6,7 @@ import sys import time -import websockets +from websockets.asyncio.server import serve async def slow_echo(websocket): @@ -17,17 +17,17 @@ async def slow_echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" - if path == "/inemuri": +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") + if request.path == "/inemuri": loop = asyncio.get_running_loop() loop.call_later(1, time.sleep, 10) - return http.HTTPStatus.OK, [], b"Sleeping for 10s\n" - if path == "/seppuku": + return connection.respond(http.HTTPStatus.OK, "Sleeping for 10s\n") + if request.path == "/seppuku": loop = asyncio.get_running_loop() loop.call_later(1, sys.exit, 69) - return http.HTTPStatus.OK, [], b"Terminating\n" + return connection.respond(http.HTTPStatus.OK, "Terminating\n") async def main(): @@ -36,7 +36,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( slow_echo, host="", port=80, diff --git a/example/deployment/kubernetes/benchmark.py b/example/deployment/kubernetes/benchmark.py index 22ee4c5bd..11a452d55 100755 --- a/example/deployment/kubernetes/benchmark.py +++ b/example/deployment/kubernetes/benchmark.py @@ -2,14 +2,15 @@ import asyncio import sys -import websockets + +from websockets.asyncio.client import connect URI = "ws://localhost:32080" async def run(client_id, messages): - async with websockets.connect(URI) as websocket: + async with connect(URI) as websocket: for message_id in range(messages): await websocket.send(f"{client_id}:{message_id}") await websocket.recv() diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py index 24e608975..134070f61 100644 --- a/example/deployment/nginx/app.py +++ b/example/deployment/nginx/app.py @@ -4,7 +4,7 @@ import os import signal -import websockets +from websockets.asyncio.server import unix_serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.unix_serve( + async with unix_serve( echo, path=f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock", ): diff --git a/example/deployment/render/app.py b/example/deployment/render/app.py index 4ca34d23b..c8e6af4f9 100644 --- a/example/deployment/render/app.py +++ b/example/deployment/render/app.py @@ -4,7 +4,7 @@ import http import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -12,9 +12,9 @@ async def echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): @@ -23,7 +23,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index bf61983ef..5e69f16a6 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -3,7 +3,7 @@ import asyncio import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -17,7 +17,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/django/authentication.py b/example/django/authentication.py index 83e128f07..c4f12a3f8 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -3,11 +3,11 @@ import asyncio import django -import websockets django.setup() from sesame.utils import get_user +from websockets.asyncio.server import serve from websockets.frames import CloseCode @@ -22,7 +22,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "localhost", 8888): + async with serve(handler, "localhost", 8888): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/django/notifications.py b/example/django/notifications.py index 3a9ed10cf..445438d2d 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -5,12 +5,13 @@ import aioredis import django -import websockets django.setup() from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve from websockets.frames import CloseCode @@ -61,11 +62,11 @@ async def process_events(): for websocket, connection in CONNECTIONS.items() if event["content_type_id"] in connection["content_type_ids"] ) - websockets.broadcast(recipients, payload) + broadcast(recipients, payload) async def main(): - async with websockets.serve(handler, "localhost", 8888): + async with serve(handler, "localhost", 8888): await process_events() # runs forever diff --git a/example/echo.py b/example/echo.py index d11b33527..b952a5cfb 100755 --- a/example/echo.py +++ b/example/echo.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import asyncio -from websockets.server import serve +from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index 6c7681e8a..c0fa4327f 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -1,22 +1,19 @@ #!/usr/bin/env python import asyncio -import http -import websockets +from http import HTTPStatus +from websockets.asyncio.server import serve -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(HTTPStatus.OK, b"OK\n") async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): - async with websockets.serve( - echo, "localhost", 8765, - process_request=health_check, - ): + async with serve(echo, "localhost", 8765, process_request=health_check): await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/faq/shutdown_client.py b/example/faq/shutdown_client.py index 539dd0304..5c8bd8cbe 100755 --- a/example/faq/shutdown_client.py +++ b/example/faq/shutdown_client.py @@ -2,15 +2,15 @@ import asyncio import signal -import websockets + +from websockets.asyncio.client import connect async def client(): uri = "ws://localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() - loop.add_signal_handler( - signal.SIGTERM, loop.create_task, websocket.close()) + loop.add_signal_handler(signal.SIGTERM, loop.create_task, websocket.close()) # Process messages received on the connection. async for message in websocket: diff --git a/example/faq/shutdown_server.py b/example/faq/shutdown_server.py index 1bcc9c90b..3f7bc5732 100755 --- a/example/faq/shutdown_server.py +++ b/example/faq/shutdown_server.py @@ -2,7 +2,8 @@ import asyncio import signal -import websockets + +from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: @@ -14,7 +15,7 @@ async def server(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve(echo, "localhost", 8765): + async with serve(echo, "localhost", 8765): await stop asyncio.run(server()) diff --git a/example/legacy/basic_auth_client.py b/example/legacy/basic_auth_client.py index 164732152..0252894b7 100755 --- a/example/legacy/basic_auth_client.py +++ b/example/legacy/basic_auth_client.py @@ -3,11 +3,12 @@ # WS client example with HTTP Basic Authentication import asyncio -import websockets + +from websockets.legacy.client import connect async def hello(): uri = "ws://mary:p@ssw0rd@localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: greeting = await websocket.recv() print(greeting) diff --git a/example/legacy/basic_auth_server.py b/example/legacy/basic_auth_server.py index 6f6020253..fc45a0270 100755 --- a/example/legacy/basic_auth_server.py +++ b/example/legacy/basic_auth_server.py @@ -3,16 +3,18 @@ # Server example with HTTP Basic Authentication over TLS import asyncio -import websockets + +from websockets.legacy.auth import basic_auth_protocol_factory +from websockets.legacy.server import serve async def hello(websocket): greeting = f"Hello {websocket.username}!" await websocket.send(greeting) async def main(): - async with websockets.serve( + async with serve( hello, "localhost", 8765, - create_protocol=websockets.basic_auth_protocol_factory( + create_protocol=basic_auth_protocol_factory( realm="example", credentials=("mary", "p@ssw0rd") ), ): diff --git a/example/legacy/unix_client.py b/example/legacy/unix_client.py index 926156730..87201c9e4 100755 --- a/example/legacy/unix_client.py +++ b/example/legacy/unix_client.py @@ -4,11 +4,12 @@ import asyncio import os.path -import websockets + +from websockets.legacy.client import unix_connect async def hello(): socket_path = os.path.join(os.path.dirname(__file__), "socket") - async with websockets.unix_connect(socket_path) as websocket: + async with unix_connect(socket_path) as websocket: name = input("What's your name? ") await websocket.send(name) print(f">>> {name}") diff --git a/example/legacy/unix_server.py b/example/legacy/unix_server.py index 5bfb66072..8a4981f5f 100755 --- a/example/legacy/unix_server.py +++ b/example/legacy/unix_server.py @@ -4,7 +4,8 @@ import asyncio import os.path -import websockets + +from websockets.legacy.server import unix_serve async def hello(websocket): name = await websocket.recv() @@ -17,7 +18,7 @@ async def hello(websocket): async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") - async with websockets.unix_serve(hello, socket_path): + async with unix_serve(hello, socket_path): await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/logging/json_log_formatter.py b/example/logging/json_log_formatter.py index b8fc8d6dc..ff7fce8b5 100644 --- a/example/logging/json_log_formatter.py +++ b/example/logging/json_log_formatter.py @@ -1,6 +1,6 @@ +import datetime import json import logging -import datetime class JSONFormatter(logging.Formatter): """ diff --git a/example/quickstart/client.py b/example/quickstart/client.py index 8d588c2b0..934af69e3 100755 --- a/example/quickstart/client.py +++ b/example/quickstart/client.py @@ -1,11 +1,12 @@ #!/usr/bin/env python import asyncio -import websockets + +from websockets.asyncio.client import connect async def hello(): uri = "ws://localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/quickstart/client_secure.py b/example/quickstart/client_secure.py index f4b39f2b8..a1449587a 100755 --- a/example/quickstart/client_secure.py +++ b/example/quickstart/client_secure.py @@ -3,7 +3,8 @@ import asyncio import pathlib import ssl -import websockets + +from websockets.asyncio.client import connect ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") @@ -11,7 +12,7 @@ async def hello(): uri = "wss://localhost:8765" - async with websockets.connect(uri, ssl=ssl_context) as websocket: + async with connect(uri, ssl=ssl_context) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index 414919e04..d42069e64 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -3,7 +3,8 @@ import asyncio import json import logging -import websockets +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve logging.basicConfig() @@ -22,7 +23,7 @@ async def counter(websocket): try: # Register user USERS.add(websocket) - websockets.broadcast(USERS, users_event()) + broadcast(USERS, users_event()) # Send current state to user await websocket.send(value_event()) # Manage state changes @@ -30,19 +31,19 @@ async def counter(websocket): event = json.loads(message) if event["action"] == "minus": VALUE -= 1 - websockets.broadcast(USERS, value_event()) + broadcast(USERS, value_event()) elif event["action"] == "plus": VALUE += 1 - websockets.broadcast(USERS, value_event()) + broadcast(USERS, value_event()) else: logging.error("unsupported event: %s", event) finally: # Unregister user USERS.remove(websocket) - websockets.broadcast(USERS, users_event()) + broadcast(USERS, users_event()) async def main(): - async with websockets.serve(counter, "localhost", 6789): + async with serve(counter, "localhost", 6789): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/server.py b/example/quickstart/server.py index 64d7adeb6..bde5e6126 100755 --- a/example/quickstart/server.py +++ b/example/quickstart/server.py @@ -1,7 +1,8 @@ #!/usr/bin/env python import asyncio -import websockets + +from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() @@ -13,7 +14,7 @@ async def hello(websocket): print(f">>> {greeting}") async def main(): - async with websockets.serve(hello, "localhost", 8765): + async with serve(hello, "localhost", 8765): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/server_secure.py b/example/quickstart/server_secure.py index 11db5fb3a..8b456ed6e 100755 --- a/example/quickstart/server_secure.py +++ b/example/quickstart/server_secure.py @@ -3,7 +3,8 @@ import asyncio import pathlib import ssl -import websockets + +from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() @@ -19,7 +20,7 @@ async def hello(websocket): ssl_context.load_cert_chain(localhost_pem) async def main(): - async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): + async with serve(hello, "localhost", 8765, ssl=ssl_context): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index add226869..8aeb811db 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -3,7 +3,8 @@ import asyncio import datetime import random -import websockets + +from websockets.asyncio.server import serve async def show_time(websocket): while True: @@ -12,7 +13,7 @@ async def show_time(websocket): await asyncio.sleep(random.random() * 2 + 1) async def main(): - async with websockets.serve(show_time, "localhost", 5678): + async with serve(show_time, "localhost", 5678): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py index 08e87f593..4fa244a23 100755 --- a/example/quickstart/show_time_2.py +++ b/example/quickstart/show_time_2.py @@ -3,7 +3,9 @@ import asyncio import datetime import random -import websockets + +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve CONNECTIONS = set() @@ -17,11 +19,11 @@ async def register(websocket): async def show_time(): while True: message = datetime.datetime.utcnow().isoformat() + "Z" - websockets.broadcast(CONNECTIONS, message) + broadcast(CONNECTIONS, message) await asyncio.sleep(random.random() * 2 + 1) async def main(): - async with websockets.serve(register, "localhost", 5678): + async with serve(register, "localhost", 5678): await show_time() if __name__ == "__main__": diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index 6ec1c60b8..db69070a1 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -4,7 +4,7 @@ import itertools import json -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -57,7 +57,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index db3e36374..feaf223a0 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,7 +4,7 @@ import json import secrets -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -182,7 +182,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index c2ee020d2..a428e29e7 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,7 +6,7 @@ import secrets import signal -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -190,7 +190,7 @@ async def main(): loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) port = int(os.environ.get("PORT", "8001")) - async with websockets.serve(handler, "", port): + async with serve(handler, "", port): await stop diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index 039e21174..e3b2cf1f6 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -8,8 +8,9 @@ import urllib.parse import uuid -import websockets from websockets.frames import CloseCode +from websockets.legacy.auth import BasicAuthWebSocketServerProtocol +from websockets.legacy.server import WebSocketServerProtocol, serve # User accounts database @@ -107,7 +108,7 @@ async def first_message_handler(websocket): # Add credentials to the WebSocket URI in a query parameter -class QueryParamProtocol(websockets.WebSocketServerProtocol): +class QueryParamProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): token = get_query_param(path, "token") if token is None: @@ -131,7 +132,7 @@ async def query_param_handler(websocket): # Set a cookie on the domain of the WebSocket URI -class CookieProtocol(websockets.WebSocketServerProtocol): +class CookieProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): if "Upgrade" not in headers: template = pathlib.Path(__file__).with_name(path[1:]) @@ -161,7 +162,7 @@ async def cookie_handler(websocket): # Adding credentials to the WebSocket URI in user information -class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): +class UserInfoProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): if username != "token": return False @@ -192,26 +193,26 @@ async def main(): loop.add_signal_handler(signal.SIGINT, stop.set_result, None) loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( noop_handler, host="", port=8000, process_request=serve_html, - ), websockets.serve( + ), serve( first_message_handler, host="", port=8001, - ), websockets.serve( + ), serve( query_param_handler, host="", port=8002, create_protocol=QueryParamProtocol, - ), websockets.serve( + ), serve( cookie_handler, host="", port=8003, create_protocol=CookieProtocol, - ), websockets.serve( + ), serve( user_info_handler, host="", port=8004, diff --git a/experiments/broadcast/clients.py b/experiments/broadcast/clients.py index fe39dfe05..64334f20f 100644 --- a/experiments/broadcast/clients.py +++ b/experiments/broadcast/clients.py @@ -5,7 +5,7 @@ import sys import time -import websockets +from websockets.asyncio.client import connect LATENCIES = {} @@ -26,7 +26,7 @@ async def log_latency(interval): async def client(): try: - async with websockets.connect( + async with connect( "ws://localhost:8765", ping_timeout=None, ) as websocket: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index adb66e262..52cc48898 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -319,8 +319,8 @@ class AbortHandshake(InvalidHandshake): This exception is an implementation detail. - The public API - is :meth:`~websockets.server.WebSocketServerProtocol.process_request`. + The public API is + :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. Attributes: status (~http.HTTPStatus): HTTP status code. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8526bad6b..4d030e5e2 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -125,11 +125,11 @@ def basic_auth_protocol_factory( Protocol factory that enforces HTTP Basic Auth. :func:`basic_auth_protocol_factory` is designed to integrate with - :func:`~websockets.server.serve` like this:: + :func:`~websockets.legacy.server.serve` like this:: - websockets.serve( + serve( ..., - create_protocol=websockets.basic_auth_protocol_factory( + create_protocol=basic_auth_protocol_factory( realm="my dev server", credentials=("hello", "iloveyou"), ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b61126c81..256bee14c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -342,7 +342,7 @@ class Connect: :func:`connect` can be used as a asynchronous context manager:: - async with websockets.connect(...) as websocket: + async with connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -350,7 +350,7 @@ class Connect: :func:`connect` can be used as an infinite asynchronous iterator to reconnect automatically on errors:: - async for websocket in websockets.connect(...): + async for websocket in connect(...): try: ... except websockets.ConnectionClosed: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 191350de3..66eb94199 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -67,13 +67,13 @@ class WebSocketCommonProtocol(asyncio.Protocol): :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket servers and clients. You shouldn't use it directly. Instead, use - :class:`~websockets.client.WebSocketClientProtocol` or - :class:`~websockets.server.WebSocketServerProtocol`. + :class:`~websockets.legacy.client.WebSocketClientProtocol` or + :class:`~websockets.legacy.server.WebSocketServerProtocol`. This documentation focuses on low-level details that aren't covered in the - documentation of :class:`~websockets.client.WebSocketClientProtocol` and - :class:`~websockets.server.WebSocketServerProtocol` for the sake of - simplicity. + documentation of :class:`~websockets.legacy.client.WebSocketClientProtocol` + and :class:`~websockets.legacy.server.WebSocketServerProtocol` for the sake + of simplicity. Once the connection is open, a Ping_ frame is sent every ``ping_interval`` seconds. This serves as a keepalive. It helps keeping the connection open, @@ -89,7 +89,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - See the discussion of :doc:`timeouts <../../topics/keepalive>` for details. + See the discussion of :doc:`keepalive <../../topics/keepalive>` for details. The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy @@ -99,8 +99,8 @@ class WebSocketCommonProtocol(asyncio.Protocol): ``close_timeout`` is a parameter of the protocol because websockets usually calls :meth:`close` implicitly upon exit: - * on the client side, when using :func:`~websockets.client.connect` as a - context manager; + * on the client side, when using :func:`~websockets.legacy.client.connect` + as a context manager; * on the server side, when the connection handler terminates. To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 9b84a6b81..3cdfeb21b 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -150,7 +150,7 @@ async def test_aiter_connection_closed_ok(self): await anext(aiterator) async def test_aiter_connection_closed_error(self): - """__aiter__ raises ConnnectionClosedError after an error.""" + """__aiter__ raises ConnectionClosedError after an error.""" aiterator = aiter(self.connection) await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index b3023434b..5d4f0e2f8 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,5 +1,4 @@ import asyncio -import dataclasses import http import logging import socket @@ -228,9 +227,7 @@ async def test_process_response_override_response(self): """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" async with run_server(process_response=process_response) as server: async with run_client(server) as client: @@ -242,9 +239,7 @@ async def test_async_process_response_override_response(self): """Server runs async process_response and overrides the handshake response.""" async def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" async with run_server(process_response=process_response) as server: async with run_client(server) as client: diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 88cbcd669..877adc4bf 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -132,7 +132,7 @@ def test_iter_connection_closed_ok(self): next(iterator) def test_iter_connection_closed_error(self): - """__iter__ raises ConnnectionClosedError after an error.""" + """__iter__ raises ConnectionClosedError after an error.""" iterator = iter(self.connection) self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 4e04a39d5..c0a5f01e6 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,4 +1,3 @@ -import dataclasses import http import logging import socket @@ -173,9 +172,7 @@ def test_process_response_override_response(self): """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" with run_server(process_response=process_response) as server: with run_client(server) as client: From 472f9517b0f8d1f190ae5961fe10a064ef016972 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 19:34:38 +0200 Subject: [PATCH 1336/1539] Explain new asyncio implementation in docs index page. --- docs/index.rst | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index d9737db12..218a489a3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,23 +28,42 @@ with a focus on correctness, simplicity, robustness, and performance. It supports several network I/O and control flow paradigms: -1. The default implementation builds upon :mod:`asyncio`, Python's standard +1. The primary implementation builds upon :mod:`asyncio`, Python's standard asynchronous I/O framework. It provides an elegant coroutine-based API. It's ideal for servers that handle many clients concurrently. + + .. admonition:: As of version :ref:`13.0`, there is a new :mod:`asyncio` + implementation. + :class: important + + The historical implementation in ``websockets.legacy`` traces its roots to + early versions of websockets. Although it's stable and robust, it is now + considered legacy. + + The new implementation in ``websockets.asyncio`` is a rewrite on top of + the Sans-I/O implementation. It adds a few features that were impossible + to implement within the original design. + + The new implementation will become the default as soon as it reaches + feature parity. If you're using the historical implementation, you should + :doc:`ugrade to the new implementation `. It's usually + straightforward. + 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used for servers that don't need to serve many clients. + 3. The `Sans-I/O`_ implementation is designed for integrating in third-party libraries, typically application servers, in addition being used internally by websockets. .. _Sans-I/O: https://sans-io.readthedocs.io/ -Here's an echo server with the :mod:`asyncio` API: +Here's an echo server using the :mod:`asyncio` API: .. literalinclude:: ../example/echo.py -Here's how a client sends and receives messages with the :mod:`threading` API: +Here's a client using the :mod:`threading` API: .. literalinclude:: ../example/hello.py From 14ca557f53cf19084eb64aef2e4563e5630c211b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 19:35:02 +0200 Subject: [PATCH 1337/1539] Proof-read tutorial. Switch to the new asyncio implementation. Change exception to ValueError when arguments have incorrect values. --- docs/intro/tutorial1.rst | 4 ++-- docs/intro/tutorial2.rst | 38 ++++++++++++++++-------------- example/tutorial/start/connect4.py | 6 ++--- example/tutorial/step1/app.py | 2 +- example/tutorial/step2/app.py | 7 +++--- example/tutorial/step3/app.py | 7 +++--- 6 files changed, 34 insertions(+), 30 deletions(-) diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 74f5f79a3..6e91867c8 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -123,7 +123,7 @@ wins. Here's its API. :param player: :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2`. :param column: between ``0`` and ``6``. :returns: Row where the checker lands, between ``0`` and ``5``. - :raises RuntimeError: if the move is illegal. + :raises ValueError: if the move is illegal. .. attribute:: moves @@ -520,7 +520,7 @@ Then, you're going to iterate over incoming messages and take these steps: interface sends; * play the move in the board with the :meth:`~connect4.Connect4.play` method, alternating between the two players; -* if :meth:`~connect4.Connect4.play` raises :exc:`RuntimeError` because the +* if :meth:`~connect4.Connect4.play` raises :exc:`ValueError` because the move is illegal, send an event of type ``"error"``; * else, send an event of type ``"play"`` to tell the user interface where the checker lands; diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index b8e35f292..b5d3a3dc8 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -84,7 +84,7 @@ When the second player joins the game, look it up: async def handler(websocket): ... - join_key = ... # TODO + join_key = ... # Find the Connect Four game. game, connected = JOIN[join_key] @@ -434,7 +434,7 @@ Once the initialization sequence is done, watching a game is as simple as registering the WebSocket connection in the ``connected`` set in order to receive game events and doing nothing until the spectator disconnects. You can wait for a connection to terminate with -:meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed`: +:meth:`~asyncio.server.ServerConnection.wait_closed`: .. code-block:: python @@ -482,38 +482,40 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`~legacy.protocol.broadcast` helper for this purpose: +the :func:`~asyncio.connection.broadcast` helper for this purpose: .. code-block:: python + from websockets.asyncio.connection import broadcast + async def handler(websocket): ... - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) ... -Calling :func:`legacy.protocol.broadcast` once is more efficient than -calling :meth:`~legacy.protocol.WebSocketCommonProtocol.send` in a loop. +Calling :func:`~asyncio.connection.broadcast` once is more efficient than +calling :meth:`~asyncio.server.ServerConnection.send` in a loop. However, there's a subtle difference in behavior. Did you notice that there's no -``await`` in the second version? Indeed, :func:`legacy.protocol.broadcast` is a -function, not a coroutine like -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` or -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. +``await`` in the second version? Indeed, :func:`~asyncio.connection.broadcast` +is a function, not a coroutine like +:meth:`~asyncio.server.ServerConnection.send` or +:meth:`~asyncio.server.ServerConnection.recv`. -It's quite obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` +It's quite obvious why :meth:`~asyncio.server.ServerConnection.recv` is a coroutine. When you want to receive the next message, you have to wait until the client sends it and the network transmits it. -It's less obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.send` is +It's less obvious why :meth:`~asyncio.server.ServerConnection.send` is a coroutine. If you send many messages or large messages, you could write data faster than the network can transmit it or the client can read it. Then, outgoing data will pile up in buffers, which will consume memory and may crash your application. -To avoid this problem, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +To avoid this problem, :meth:`~asyncio.server.ServerConnection.send` waits until the write buffer drains. By slowing down the application as necessary, this ensures that the server doesn't send data too quickly. This is called backpressure and it's useful for building robust systems. @@ -522,12 +524,12 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`legacy.protocol.broadcast` doesn't wait until write buffers -drain. +That's why :func:`~asyncio.connection.broadcast` doesn't wait until write +buffers drain and therefore doesn't need to be a coroutine. -For our Connect Four game, there's no difference in practice: the total amount -of data sent on a connection for a game of Connect Four is less than 64 KB, -so the write buffer never fills up and backpressure never kicks in anyway. +For our Connect Four game, there's no difference in practice. The total amount +of data sent on a connection for a game of Connect Four is so small that the +write buffer cannot fill up. As a consequence, backpressure never kicks in. Summary ------- diff --git a/example/tutorial/start/connect4.py b/example/tutorial/start/connect4.py index 0a61e7c7e..104476962 100644 --- a/example/tutorial/start/connect4.py +++ b/example/tutorial/start/connect4.py @@ -43,15 +43,15 @@ def play(self, player, column): Returns the row where the checker lands. - Raises :exc:`RuntimeError` if the move is illegal. + Raises :exc:`ValueError` if the move is illegal. """ if player == self.last_player: - raise RuntimeError("It isn't your turn.") + raise ValueError("It isn't your turn.") row = self.top[column] if row == 6: - raise RuntimeError("This slot is full.") + raise ValueError("This slot is full.") self.moves.append((player, column, row)) self.top[column] += 1 diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index db69070a1..595a10dc7 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -26,7 +26,7 @@ async def handler(websocket): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. event = { "type": "error", diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index feaf223a0..86b2c88c3 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,6 +4,7 @@ import json import secrets +from websockets.asyncio.connection import broadcast from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -59,7 +60,7 @@ async def play(websocket, game, player, connected): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. await error(websocket, str(exc)) continue @@ -71,7 +72,7 @@ async def play(websocket, game, player, connected): "column": column, "row": row, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) # If move is winning, send a "win" event. if game.winner is not None: @@ -79,7 +80,7 @@ async def play(websocket, game, player, connected): "type": "win", "player": game.winner, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) async def start(websocket): diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index a428e29e7..34024d087 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,6 +6,7 @@ import secrets import signal +from websockets.asyncio.connection import broadcast from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -61,7 +62,7 @@ async def play(websocket, game, player, connected): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. await error(websocket, str(exc)) continue @@ -73,7 +74,7 @@ async def play(websocket, game, player, connected): "column": column, "row": row, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) # If move is winning, send a "win" event. if game.winner is not None: @@ -81,7 +82,7 @@ async def play(websocket, game, player, connected): "type": "win", "player": game.winner, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) async def start(websocket): From 2a17e1dac6a4514f3663c4e15546ebe33bf90e4b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 21:49:46 +0200 Subject: [PATCH 1338/1539] Move broadcast() to the server module. Rationale: * It is only useful for servers. (Maybe there's a use case for client but I couldn't picture it.) * It was already documented in the page covering the server module. * Within this page, it was the only API from the connection or protocol module. The implementation remains in the connection or protocol module because moving it would require refactoring tests. I'd rather keep them simple. (And I'm lazy.) This change doesn't require a backwards compatibility shim because the documentated location of the legacy implementation of broadcast was websockets.broadcast, it's changing with the introduction of the new asyncio API, and the changes are already documented. --- docs/faq/server.rst | 4 ++-- docs/howto/patterns.rst | 2 +- docs/howto/upgrade.rst | 4 ++-- docs/intro/tutorial2.rst | 15 +++++++-------- docs/project/changelog.rst | 4 ++-- docs/reference/asyncio/server.rst | 2 +- docs/reference/features.rst | 2 ++ docs/reference/legacy/server.rst | 10 +++++----- docs/topics/broadcast.rst | 20 ++++++++++---------- docs/topics/logging.rst | 2 +- docs/topics/performance.rst | 4 ++-- example/django/notifications.py | 3 +-- example/quickstart/counter.py | 3 +-- example/quickstart/show_time_2.py | 3 +-- example/tutorial/step2/app.py | 3 +-- example/tutorial/step3/app.py | 3 +-- experiments/broadcast/server.py | 3 +-- src/websockets/__init__.py | 7 ++++--- src/websockets/asyncio/connection.py | 25 ++++++++++++++++++++----- src/websockets/asyncio/server.py | 4 ++-- src/websockets/legacy/protocol.py | 25 ++++++++++++++++++++----- src/websockets/legacy/server.py | 10 ++++++++-- tests/asyncio/test_connection.py | 1 + 23 files changed, 96 insertions(+), 63 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index e6b068316..66e81edfe 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -116,9 +116,9 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.asyncio.connection.broadcast`:: +Then, call :func:`~websockets.asyncio.server.broadcast`:: - from websockets.asyncio.connection import broadcast + from websockets.asyncio.server import broadcast def message_all(message): broadcast(CONNECTIONS, message) diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst index 60bc8ab42..bfb78b6ca 100644 --- a/docs/howto/patterns.rst +++ b/docs/howto/patterns.rst @@ -90,7 +90,7 @@ connect and unregister them when they disconnect:: connected.add(websocket) try: # Broadcast a message to all connected clients. - websockets.broadcast(connected, "Hello!") + broadcast(connected, "Hello!") await asyncio.sleep(10) finally: # Unregister. diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 40c8c5ec9..c5320155e 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -162,8 +162,8 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | -| :func:`websockets.legacy.protocol.broadcast()` | | +| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.server.broadcast` | +| :func:`websockets.legacy.server.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | | ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index b5d3a3dc8..0211615d1 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -482,11 +482,11 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`~asyncio.connection.broadcast` helper for this purpose: +the :func:`~asyncio.server.broadcast` helper for this purpose: .. code-block:: python - from websockets.asyncio.connection import broadcast + from websockets.asyncio.server import broadcast async def handler(websocket): @@ -496,13 +496,12 @@ the :func:`~asyncio.connection.broadcast` helper for this purpose: ... -Calling :func:`~asyncio.connection.broadcast` once is more efficient than +Calling :func:`~asyncio.server.broadcast` once is more efficient than calling :meth:`~asyncio.server.ServerConnection.send` in a loop. However, there's a subtle difference in behavior. Did you notice that there's no -``await`` in the second version? Indeed, :func:`~asyncio.connection.broadcast` -is a function, not a coroutine like -:meth:`~asyncio.server.ServerConnection.send` or +``await`` in the second version? Indeed, :func:`~asyncio.server.broadcast` is a +function, not a coroutine like :meth:`~asyncio.server.ServerConnection.send` or :meth:`~asyncio.server.ServerConnection.recv`. It's quite obvious why :meth:`~asyncio.server.ServerConnection.recv` @@ -524,8 +523,8 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`~asyncio.connection.broadcast` doesn't wait until write -buffers drain and therefore doesn't need to be a coroutine. +That's why :func:`~asyncio.server.broadcast` doesn't wait until write buffers +drain and therefore doesn't need to be a coroutine. For our Connect Four game, there's no difference in practice. The total amount of data sent on a connection for a game of Connect Four is so small that the diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index df5af54f4..e85c3a395 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -212,7 +212,7 @@ Improvements * Added platform-independent wheels. -* Improved error handling in :func:`~legacy.protocol.broadcast`. +* Improved error handling in :func:`~legacy.server.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. @@ -402,7 +402,7 @@ New features * Added compatibility with Python 3.10. -* Added :func:`~legacy.protocol.broadcast` to send a message to many clients. +* Added :func:`~legacy.server.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using :func:`~legacy.client.connect` as an asynchronous iterator. diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 7bceca5a0..541c9952c 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -80,4 +80,4 @@ Using a connection Broadcast --------- -.. autofunction:: websockets.asyncio.connection.broadcast +.. autofunction:: websockets.asyncio.server.broadcast diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 6840fe15b..cb0e564f9 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -35,6 +35,8 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ + | Broadcast a message | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ | Receive a message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Iterate over received messages | ✅ | ✅ | — | ✅ | diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index c2758f5a2..b6c383ce7 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -89,6 +89,11 @@ Using a connection .. autoproperty:: close_reason +Broadcast +--------- + +.. autofunction:: websockets.legacy.server.broadcast + Basic authentication -------------------- @@ -106,8 +111,3 @@ websockets supports HTTP Basic Authentication according to .. autoattribute:: username .. automethod:: check_credentials - -Broadcast ---------- - -.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index ec358bbd2..c9699feb2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -4,19 +4,19 @@ Broadcasting .. currentmodule:: websockets .. admonition:: If you want to send a message to all connected clients, - use :func:`~asyncio.connection.broadcast`. + use :func:`~asyncio.server.broadcast`. :class: tip If you want to learn about its design, continue reading this document. For the legacy :mod:`asyncio` implementation, use - :func:`~legacy.protocol.broadcast`. + :func:`~legacy.server.broadcast`. WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. Let's explore options for broadcasting a message, explain the design of -:func:`~asyncio.connection.broadcast`, and discuss alternatives. +:func:`~asyncio.server.broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all @@ -124,7 +124,7 @@ connections before the write buffer can fill up. Don't set extreme values of ``write_limit``, ``ping_interval``, or ``ping_timeout`` to ensure that this condition holds! Instead, set reasonable -values and use the built-in :func:`~asyncio.connection.broadcast` function. +values and use the built-in :func:`~asyncio.server.broadcast` function. The concurrent way ------------------ @@ -209,11 +209,11 @@ If a client gets too far behind, eventually it reaches the limit defined by ``ping_timeout`` and websockets terminates the connection. You can refer to the discussion of :doc:`keepalive ` for details. -How :func:`~asyncio.connection.broadcast` works ------------------------------------------------ +How :func:`~asyncio.server.broadcast` works +------------------------------------------- -The built-in :func:`~asyncio.connection.broadcast` function is similar to the -naive way. The main difference is that it doesn't apply backpressure. +The built-in :func:`~asyncio.server.broadcast` function is similar to the naive +way. The main difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. @@ -324,7 +324,7 @@ the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- -The built-in :func:`~asyncio.connection.broadcast` function sends all messages +The built-in :func:`~asyncio.server.broadcast` function sends all messages without yielding control to the event loop. So does the naive way when the network and clients are fast and reliable. @@ -346,7 +346,7 @@ However, this isn't possible in general for two reasons: All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower -than the built-in :func:`~asyncio.connection.broadcast` function. +than the built-in :func:`~asyncio.server.broadcast` function. There is no major difference between the performance of per-client queues and publish–subscribe. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index cad49ba55..9580b4c50 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -221,7 +221,7 @@ Here's what websockets logs at each level. ``WARNING`` ........... -* Failures in :func:`~asyncio.connection.broadcast` +* Failures in :func:`~asyncio.server.broadcast` ``INFO`` ........ diff --git a/docs/topics/performance.rst b/docs/topics/performance.rst index b226cec43..b0828fe0d 100644 --- a/docs/topics/performance.rst +++ b/docs/topics/performance.rst @@ -18,5 +18,5 @@ application.) broadcast --------- -:func:`~asyncio.connection.broadcast` is the most efficient way to send a -message to many clients. +:func:`~asyncio.server.broadcast` is the most efficient way to send a message to +many clients. diff --git a/example/django/notifications.py b/example/django/notifications.py index 445438d2d..76ce9c2d7 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -10,8 +10,7 @@ from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from websockets.frames import CloseCode diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index d42069e64..91eedc56a 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -3,8 +3,7 @@ import asyncio import json import logging -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve logging.basicConfig() diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py index 4fa244a23..9c9659d14 100755 --- a/example/quickstart/show_time_2.py +++ b/example/quickstart/show_time_2.py @@ -4,8 +4,7 @@ import datetime import random -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve CONNECTIONS = set() diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index 86b2c88c3..ef3dd9483 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,8 +4,7 @@ import json import secrets -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from connect4 import PLAYER1, PLAYER2, Connect4 diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 34024d087..261057f9a 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,8 +6,7 @@ import secrets import signal -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from connect4 import PLAYER1, PLAYER2, Connect4 diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index 0a5c82b3c..d5b50bd71 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -7,8 +7,7 @@ import time from websockets import ConnectionClosed -from websockets.asyncio.server import serve -from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import broadcast, serve CLIENTS = set() diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index fdb028f4c..b618a6dff 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -48,10 +48,10 @@ "unix_connect", # .legacy.protocol "WebSocketCommonProtocol", - "broadcast", # .legacy.server "WebSocketServer", "WebSocketServerProtocol", + "broadcast", "serve", "unix_serve", # .server @@ -102,10 +102,11 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect - from .legacy.protocol import WebSocketCommonProtocol, broadcast + from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, WebSocketServerProtocol, + broadcast, serve, unix_serve, ) @@ -164,10 +165,10 @@ "unix_connect": ".legacy.client", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", - "broadcast": ".legacy.protocol", # .legacy.server "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", "serve": ".legacy.server", "unix_serve": ".legacy.server", # .server diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 284fe2124..a6b909c72 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -34,7 +34,7 @@ from .messages import Assembler -__all__ = ["Connection", "broadcast"] +__all__ = ["Connection"] class Connection(asyncio.Protocol): @@ -1011,6 +1011,12 @@ def eof_received(self) -> None: # Besides, that doesn't work on TLS connections. +# broadcast() is defined in the connection module even though it's primarily +# used by servers and documented in the server module because it works with +# client connections too and because it's easier to test together with the +# Connection class. + + def broadcast( connections: Iterable[Connection], message: Data, @@ -1034,10 +1040,11 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~Connection.send`, :func:`broadcast` doesn't support sending - fragmented messages. Indeed, fragmentation is useful for sending large - messages without buffering them in memory, while :func:`broadcast` buffers - one copy per connection as fast as possible. + Unlike :meth:`~websockets.asyncio.connection.Connection.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. @@ -1047,6 +1054,10 @@ def broadcast( set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. + While :func:`broadcast` makes more sense for servers, it works identically + with clients, if you have a use case for opening connections to many servers + and broadcasting a message to them. + Args: websockets: WebSocket connections to which the message will be sent. message: Message to send. @@ -1101,3 +1112,7 @@ def broadcast( if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) + + +# Pretend that broadcast is actually defined in the server module. +broadcast.__module__ = "websockets.asyncio.server" diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 1f55502bb..35637a18f 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -25,10 +25,10 @@ from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout -from .connection import Connection +from .connection import Connection, broadcast -__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "WebSocketServer"] class ServerConnection(Connection): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 66eb94199..e83e146f9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -53,7 +53,7 @@ from .framing import Frame -__all__ = ["WebSocketCommonProtocol", "broadcast"] +__all__ = ["WebSocketCommonProtocol"] # In order to ensure consistency, the code always checks the current value of @@ -1545,6 +1545,12 @@ def eof_received(self) -> None: self.reader.feed_eof() +# broadcast() is defined in the protocol module even though it's primarily +# used by servers and documented in the server module because it works with +# client connections too and because it's easier to test together with the +# WebSocketCommonProtocol class. + + def broadcast( websockets: Iterable[WebSocketCommonProtocol], message: Data, @@ -1568,10 +1574,11 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~WebSocketCommonProtocol.send`, :func:`broadcast` doesn't - support sending fragmented messages. Indeed, fragmentation is useful for - sending large messages without buffering them in memory, while - :func:`broadcast` buffers one copy per connection as fast as possible. + Unlike :meth:`~websockets.legacy.protocol.WebSocketCommonProtocol.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. @@ -1581,6 +1588,10 @@ def broadcast( set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. + While :func:`broadcast` makes more sense for servers, it works identically + with clients, if you have a use case for opening connections to many servers + and broadcasting a message to them. + Args: websockets: WebSocket connections to which the message will be sent. message: Message to send. @@ -1629,3 +1640,7 @@ def broadcast( if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) + + +# Pretend that broadcast is actually defined in the server module. +broadcast.__module__ = "websockets.legacy.server" diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index d230f009e..43136db3e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -45,10 +45,16 @@ from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .handshake import build_response, check_request from .http import read_request -from .protocol import WebSocketCommonProtocol +from .protocol import WebSocketCommonProtocol, broadcast -__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] +__all__ = [ + "broadcast", + "serve", + "unix_serve", + "WebSocketServerProtocol", + "WebSocketServer", +] # Change to HeadersLike | ... when dropping Python < 3.10. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 3cdfeb21b..52e4fc5c8 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -9,6 +9,7 @@ from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * +from websockets.asyncio.connection import broadcast from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol, State From b05fa2cceefcc5cfaba4e0e06e40d588505c8334 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 22:17:14 +0200 Subject: [PATCH 1339/1539] Rename WebSocketServer to Server. The shorter name is better. Representation remains unambiguous: The change isn't applied to the legacy implementation because it has longer names for other API too. --- docs/faq/server.rst | 3 +-- docs/howto/upgrade.rst | 2 +- docs/project/changelog.rst | 30 ++++++++++++++-------- docs/reference/asyncio/server.rst | 2 +- docs/reference/sync/server.rst | 2 +- src/websockets/asyncio/server.py | 42 +++++++++++++++---------------- src/websockets/sync/server.py | 28 ++++++++++++++------- tests/asyncio/client.py | 4 +-- tests/asyncio/test_server.py | 2 +- tests/sync/client.py | 4 +-- tests/sync/test_server.py | 11 +++++--- 11 files changed, 77 insertions(+), 53 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 66e81edfe..63eb5ffc6 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -313,8 +313,7 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: How do I stop a server while keeping existing connections open? --------------------------------------------------------------- -Call the server's :meth:`~WebSocketServer.close` method with -``close_connections=False``. +Call the server's :meth:`~Server.close` method with ``close_connections=False``. Here's how to adapt the example just above:: diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index c5320155e..8d0895638 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -154,7 +154,7 @@ Server APIs | ``websockets.server.unix_serve()`` |br| | | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | +| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | | ``websockets.server.WebSocketServer`` |br| | | | :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index e85c3a395..f4ae76702 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,16 +35,6 @@ notice. Backwards-incompatible changes .............................. -.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` - and :func:`~sync.server.serve` in the :mod:`threading` implementation is - renamed to ``ssl``. - :class: note - - This aligns the API of the :mod:`threading` implementation with the - :mod:`asyncio` implementation. - - For backwards compatibility, ``ssl_context`` is still supported. - .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note @@ -60,6 +50,26 @@ Backwards-incompatible changes path = request.path # only if handler() uses the path argument ... +.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` + and :func:`~sync.server.serve` in the :mod:`threading` implementation is + renamed to ``ssl``. + :class: note + + This aligns the API of the :mod:`threading` implementation with the + :mod:`asyncio` implementation. + + For backwards compatibility, ``ssl_context`` is still supported. + +.. admonition:: The ``WebSocketServer`` class in the :mod:`threading` + implementation is renamed to :class:`~sync.server.Server`. + :class: note + + This class isn't designed to be imported or instantiated directly. + :func:`~sync.server.serve` returns an instance. For this reason, + the change should be transparent. + + Regardless, an alias provides backwards compatibility. + New features ............ diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 541c9952c..bd5a34b19 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -15,7 +15,7 @@ Creating a server Running a server ---------------- -.. autoclass:: WebSocketServer +.. autoclass:: Server .. automethod:: close diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 26ab872c8..23cb04097 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -13,7 +13,7 @@ Creating a server Running a server ---------------- -.. autoclass:: WebSocketServer +.. autoclass:: Server .. automethod:: serve_forever diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 35637a18f..8ebbddb67 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -28,7 +28,7 @@ from .connection import Connection, broadcast -__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "Server"] class ServerConnection(Connection): @@ -60,7 +60,7 @@ class ServerConnection(Connection): def __init__( self, protocol: ServerProtocol, - server: WebSocketServer, + server: Server, *, ping_interval: float | None = 20, ping_timeout: float | None = 20, @@ -223,11 +223,11 @@ def connection_lost(self, exc: Exception | None) -> None: self.request_rcvd.set_result(None) -class WebSocketServer: +class Server: """ WebSocket server returned by :func:`serve`. - This class mirrors the API of :class:`~asyncio.Server`. + This class mirrors the API of :class:`asyncio.Server`. It keeps track of WebSocket connections in order to close them properly when shutting down. @@ -299,16 +299,16 @@ def __init__( def wrap(self, server: asyncio.Server) -> None: """ - Attach to a given :class:`~asyncio.Server`. + Attach to a given :class:`asyncio.Server`. Since :meth:`~asyncio.loop.create_server` doesn't support injecting a custom ``Server`` class, the easiest solution that doesn't rely on private :mod:`asyncio` APIs is to: - - instantiate a :class:`WebSocketServer` + - instantiate a :class:`Server` - give the protocol factory a reference to that instance - call :meth:`~asyncio.loop.create_server` with the factory - - attach the resulting :class:`~asyncio.Server` with this method + - attach the resulting :class:`asyncio.Server` with this method """ self.server = server @@ -378,7 +378,7 @@ def close(self, close_connections: bool = True) -> None: """ Close the server. - * Close the underlying :class:`~asyncio.Server`. + * Close the underlying :class:`asyncio.Server`. * When ``close_connections`` is :obj:`True`, which is the default, close existing connections. Specifically: @@ -402,7 +402,7 @@ async def _close(self, close_connections: bool) -> None: Implementation of :meth:`close`. This calls :meth:`~asyncio.Server.close` on the underlying - :class:`~asyncio.Server` object to stop accepting new connections and + :class:`asyncio.Server` object to stop accepting new connections and then closes open connections with close code 1001. """ @@ -516,7 +516,7 @@ def sockets(self) -> Iterable[socket.socket]: """ return self.server.sockets - async def __aenter__(self) -> WebSocketServer: # pragma: no cover + async def __aenter__(self) -> Server: # pragma: no cover return self async def __aexit__( @@ -543,8 +543,8 @@ class serve: Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - This coroutine returns a :class:`WebSocketServer` whose API mirrors - :class:`~asyncio.Server`. Treat it as an asynchronous context manager to + This coroutine returns a :class:`Server` whose API mirrors + :class:`asyncio.Server`. Treat it as an asynchronous context manager to ensure that the server will be closed:: def handler(websocket): @@ -556,8 +556,8 @@ def handler(websocket): async with websockets.asyncio.server.serve(handler, host, port): await stop - Alternatively, call :meth:`~WebSocketServer.serve_forever` to serve requests - and cancel it to stop the server:: + Alternatively, call :meth:`~Server.serve_forever` to serve requests and + cancel it to stop the server:: server = await websockets.asyncio.server.serve(handler, host, port) await server.serve_forever() @@ -638,8 +638,8 @@ def handler(websocket): socket and customize it. * You can set ``start_serving`` to ``False`` to start accepting connections - only after you call :meth:`~WebSocketServer.start_serving()` or - :meth:`~WebSocketServer.serve_forever()`. + only after you call :meth:`~Server.start_serving()` or + :meth:`~Server.serve_forever()`. """ @@ -704,7 +704,7 @@ def __init__( if create_connection is None: create_connection = ServerConnection - self.server = WebSocketServer( + self.server = Server( handler, process_request=process_request, process_response=process_response, @@ -773,7 +773,7 @@ def protocol_select_subprotocol( # async with serve(...) as ...: ... - async def __aenter__(self) -> WebSocketServer: + async def __aenter__(self) -> Server: return await self async def __aexit__( @@ -787,11 +787,11 @@ async def __aexit__( # ... = await serve(...) - def __await__(self) -> Generator[Any, None, WebSocketServer]: + def __await__(self) -> Generator[Any, None, Server]: # Create a suitable iterator by calling __await__ on a coroutine. return self.__await_impl__().__await__() - async def __await_impl__(self) -> WebSocketServer: + async def __await_impl__(self) -> Server: server = await self._create_server self.server.wrap(server) return self.server @@ -805,7 +805,7 @@ def unix_serve( handler: Callable[[ServerConnection], Awaitable[None]], path: str | None = None, **kwargs: Any, -) -> Awaitable[WebSocketServer]: +) -> Awaitable[Server]: """ Create a WebSocket server listening on a Unix socket. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index b381908ca..85a7e9907 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -24,7 +24,7 @@ from .utils import Deadline -__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["serve", "unix_serve", "ServerConnection", "Server"] class ServerConnection(Connection): @@ -196,7 +196,7 @@ def recv_events(self) -> None: self.request_rcvd.set() -class WebSocketServer: +class Server: """ WebSocket server returned by :func:`serve`. @@ -283,7 +283,7 @@ def fileno(self) -> int: """ return self.socket.fileno() - def __enter__(self) -> WebSocketServer: + def __enter__(self) -> Server: return self def __exit__( @@ -295,6 +295,16 @@ def __exit__( self.shutdown() +def __getattr__(name: str) -> Any: + if name == "WebSocketServer": + warnings.warn( + "WebSocketServer was renamed to Server", + DeprecationWarning, + ) + return Server + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + def serve( handler: Callable[[ServerConnection], None], host: str | None = None, @@ -340,7 +350,7 @@ def serve( # Escape hatch for advanced customization create_connection: type[ServerConnection] | None = None, **kwargs: Any, -) -> WebSocketServer: +) -> Server: """ Create a WebSocket server listening on ``host`` and ``port``. @@ -353,10 +363,10 @@ def serve( Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - This function returns a :class:`WebSocketServer` whose API mirrors + This function returns a :class:`Server` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call :meth:`~WebSocketServer.serve_forever` to - serve requests:: + that it will be closed and call :meth:`~Server.serve_forever` to serve + requests:: def handler(websocket): ... @@ -552,14 +562,14 @@ def protocol_select_subprotocol( # Initialize server - return WebSocketServer(sock, conn_handler, logger) + return Server(sock, conn_handler, logger) def unix_serve( handler: Callable[[ServerConnection], None], path: str | None = None, **kwargs: Any, -) -> WebSocketServer: +) -> Server: """ Create a WebSocket server listening on a Unix socket. diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py index e5826add7..a73079c6e 100644 --- a/tests/asyncio/client.py +++ b/tests/asyncio/client.py @@ -1,7 +1,7 @@ import contextlib from websockets.asyncio.client import * -from websockets.asyncio.server import WebSocketServer +from websockets.asyncio.server import Server from .server import get_server_host_port @@ -17,7 +17,7 @@ async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): wsuri = wsuri_or_server else: - assert isinstance(wsuri_or_server, WebSocketServer) + assert isinstance(wsuri_or_server, Server) if secure is None: secure = "ssl" in kwargs protocol = "wss" if secure else "ws" diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 5d4f0e2f8..4b637f3af 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -535,7 +535,7 @@ async def test_unsupported_compression(self): class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): async def test_logger(self): - """WebSocketServer accepts a logger argument.""" + """Server accepts a logger argument.""" logger = logging.getLogger("test") async with run_server(logger=logger) as server: self.assertIs(server.logger, logger) diff --git a/tests/sync/client.py b/tests/sync/client.py index 72eb5b8d2..acbf97fa7 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -1,7 +1,7 @@ import contextlib from websockets.sync.client import * -from websockets.sync.server import WebSocketServer +from websockets.sync.server import Server __all__ = [ @@ -15,7 +15,7 @@ def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): wsuri = wsuri_or_server else: - assert isinstance(wsuri_or_server, WebSocketServer) + assert isinstance(wsuri_or_server, Server) if secure is None: # Backwards compatibility: ssl used to be called ssl_context. secure = "ssl" in kwargs or "ssl_context" in kwargs diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index c0a5f01e6..315601eca 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -383,18 +383,18 @@ def test_unsupported_compression(self): class WebSocketServerTests(unittest.TestCase): def test_logger(self): - """WebSocketServer accepts a logger argument.""" + """Server accepts a logger argument.""" logger = logging.getLogger("test") with run_server(logger=logger) as server: self.assertIs(server.logger, logger) def test_fileno(self): - """WebSocketServer provides a fileno attribute.""" + """Server provides a fileno attribute.""" with run_server() as server: self.assertIsInstance(server.fileno(), int) def test_shutdown(self): - """WebSocketServer provides a shutdown method.""" + """Server provides a shutdown method.""" with run_server() as server: server.shutdown() # Check that the server socket is closed. @@ -409,3 +409,8 @@ def test_ssl_context_argument(self): with run_server(ssl_context=SERVER_CONTEXT) as server: with run_client(server, ssl=CLIENT_CONTEXT): pass + + def test_web_socket_server_class(self): + with self.assertDeprecationWarning("WebSocketServer was renamed to Server"): + from websockets.sync.server import WebSocketServer + self.assertIs(WebSocketServer, Server) From 09b1d8d4d585ed6d4e2c0db6e200e48f176215b1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 22:26:56 +0200 Subject: [PATCH 1340/1539] Fix tests on Python < 3.10. --- tests/asyncio/test_connection.py | 17 +++++++++++++++++ tests/legacy/utils.py | 31 ++++++++++++++++--------------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 52e4fc5c8..29bb00418 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -44,6 +44,23 @@ async def asyncTearDown(self): await self.remote_connection.close() await self.connection.close() + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger="websockets", level=logging.ERROR): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + # Test helpers built upon RecordingProtocol and InterceptingConnection. async def assertFrameSent(self, frame): diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 5bb56b26f..5b56050d5 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -56,21 +56,22 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - # Remove when dropping Python < 3.10 - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger="websockets", level=logging.ERROR): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): """ From 8eaa5a26b667fccb1b9d75034a77b2d6906e2b2e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 07:49:47 +0200 Subject: [PATCH 1341/1539] Document & test process_response modifying the response. a78b5546 inadvertently changed the test from "returning a new response" to "modifying the existing response". Both are supported.. --- src/websockets/asyncio/server.py | 11 +++--- src/websockets/sync/server.py | 13 ++++--- tests/asyncio/test_server.py | 65 +++++++++++++++++++++----------- tests/sync/test_server.py | 33 ++++++++++------ 4 files changed, 78 insertions(+), 44 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 8ebbddb67..8f04ec318 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -236,15 +236,16 @@ class Server: handler: Connection handler. It receives the WebSocket connection, which is a :class:`ServerConnection`, in argument. process_request: Intercept the request during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to + Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. ``process_request`` may be a function or a coroutine. process_response: Intercept the response during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, the - handshake is successful. Else, the connection is aborted. - ``process_response`` may be a function or a coroutine. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. ``process_response`` may be a function or a + coroutine. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 85a7e9907..86c162af3 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -401,13 +401,14 @@ def handler(websocket): :meth:`ServerProtocol.select_subprotocol ` method. process_request: Intercept the request during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, - the handshake is successful. Else, the connection is aborted. + Return an HTTP response to force the response. Return :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. process_response: Intercept the response during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, - the handshake is successful. Else, the connection is aborted. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 4b637f3af..b899998f4 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import http import logging import socket @@ -117,8 +118,8 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) - async def test_process_request(self): - """Server runs process_request before processing the handshake.""" + async def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" def process_request(ws, request): self.assertIsInstance(request, Request) @@ -128,8 +129,8 @@ def process_request(ws, request): async with run_client(server) as client: await self.assertEval(client, "ws.process_request_ran", "True") - async def test_async_process_request(self): - """Server runs async process_request before processing the handshake.""" + async def test_async_process_request_returns_none(self): + """Server runs async process_request and continues the handshake.""" async def process_request(ws, request): self.assertIsInstance(request, Request) @@ -139,7 +140,7 @@ async def process_request(ws, request): async with run_client(server) as client: await self.assertEval(client, "ws.process_request_ran", "True") - async def test_process_request_abort_handshake(self): + async def test_process_request_returns_response(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): @@ -154,7 +155,7 @@ def process_request(ws, request): "server rejected WebSocket connection: HTTP 403", ) - async def test_async_process_request_abort_handshake(self): + async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" async def process_request(ws, request): @@ -199,8 +200,8 @@ async def process_request(ws, request): "server rejected WebSocket connection: HTTP 500", ) - async def test_process_response(self): - """Server runs process_response after processing the handshake.""" + async def test_process_response_returns_none(self): + """Server runs process_response but keeps the handshake response.""" def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -211,8 +212,8 @@ def process_response(ws, request, response): async with run_client(server) as client: await self.assertEval(client, "ws.process_response_ran", "True") - async def test_async_process_response(self): - """Server runs async process_response after processing the handshake.""" + async def test_async_process_response_returns_none(self): + """Server runs async process_response but keeps the handshake response.""" async def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -223,29 +224,49 @@ async def process_response(ws, request, response): async with run_client(server) as client: await self.assertEval(client, "ws.process_response_ran", "True") - async def test_process_response_override_response(self): - """Server runs process_response and overrides the handshake response.""" + async def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: async with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") - async def test_async_process_response_override_response(self): - """Server runs async process_response and overrides the handshake response.""" + async def test_async_process_response_modifies_response(self): + """Server runs async process_response and modifies the handshake response.""" async def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: async with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_replaces_response(self): + """Server runs async process_response and replaces the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 315601eca..e3dfeb271 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,3 +1,4 @@ +import dataclasses import http import logging import socket @@ -115,8 +116,8 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) - def test_process_request(self): - """Server runs process_request before processing the handshake.""" + def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" def process_request(ws, request): self.assertIsInstance(request, Request) @@ -126,7 +127,7 @@ def process_request(ws, request): with run_client(server) as client: self.assertEval(client, "ws.process_request_ran", "True") - def test_process_request_abort_handshake(self): + def test_process_request_returns_response(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): @@ -156,8 +157,8 @@ def process_request(ws, request): "server rejected WebSocket connection: HTTP 500", ) - def test_process_response(self): - """Server runs process_response after processing the handshake.""" + def test_process_response_returns_none(self): + """Server runs process_response but keeps the handshake response.""" def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -168,17 +169,27 @@ def process_response(ws, request, response): with run_client(server) as client: self.assertEval(client, "ws.process_response_ran", "True") - def test_process_response_override_response(self): - """Server runs process_response and overrides the handshake response.""" + def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" with run_server(process_response=process_response) as server: with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + with run_server(process_response=process_response) as server: + with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" From 9e5b91bf8f9039de0af85e597e7fd643cfd1a139 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 08:26:44 +0200 Subject: [PATCH 1342/1539] Improve documentation of latency. Also fix #1414. --- docs/reference/features.rst | 2 ++ docs/topics/keepalive.rst | 28 ++++++++++++++++++++-------- src/websockets/asyncio/connection.py | 9 +++++---- src/websockets/legacy/protocol.py | 9 +++++---- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index cb0e564f9..a380f4555 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -59,6 +59,8 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ + | Measure latency | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 1c7a43264..91f11fb11 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -40,13 +40,16 @@ It loops through these steps: If the Pong frame isn't received, websockets considers the connection broken and closes it. -This mechanism serves two purposes: +This mechanism serves three purposes: 1. It creates a trickle of traffic so that the TCP connection isn't idle and network infrastructure along the path keeps it open ("keepalive"). 2. It detects if the connection drops or becomes so slow that it's unusable in practice ("heartbeat"). In that case, it terminates the connection and your application gets a :exc:`~exceptions.ConnectionClosed` exception. +3. It measures the :attr:`~asyncio.connection.Connection.latency` of the + connection. The time between sending a Ping frame and receiving a matching + Pong frame approximates the round-trip time. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve`. @@ -54,7 +57,7 @@ Shorter values will detect connection drops faster but they will increase network traffic and they will be more sensitive to latency. Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and -heartbeat mechanism. +heartbeat mechanism, including measurement of latency. Setting ``ping_timeout`` to :obj:`None` disables only timeouts. This enables keepalive, to keep idle connections open, and disables heartbeat, to support large @@ -85,9 +88,23 @@ Unfortunately, the WebSocket API in browsers doesn't expose the native Ping and Pong functionality in the WebSocket protocol. You have to roll your own in the application layer. +Read this `blog post `_ for +a complete walk-through of this issue. + Latency issues -------------- +The :attr:`~asyncio.connection.Connection.latency` attribute stores latency +measured during the last exchange of Ping and Pong frames:: + + latency = websocket.latency + +Alternatively, you can measure the latency at any time by calling +:attr:`~asyncio.connection.Connection.ping` and awaiting its result:: + + pong_waiter = await websocket.ping() + latency = await pong_waiter + Latency between a client and a server may increase for two reasons: * Network connectivity is poor. When network packets are lost, TCP attempts to @@ -97,7 +114,7 @@ Latency between a client and a server may increase for two reasons: * Traffic is high. For example, if a client sends messages on the connection faster than a server can process them, this manifests as latency as well, - because data is waiting in flight, mostly in OS buffers. + because data is waiting in :doc:`buffers `. If the server is more than 20 seconds behind, it doesn't see the Pong before the default timeout elapses. As a consequence, it closes the connection. @@ -109,8 +126,3 @@ Latency between a client and a server may increase for two reasons: The same reasoning applies to situations where the server sends more traffic than the client can accept. - -The latency measured during the last exchange of Ping and Pong frames is -available in the :attr:`~asyncio.connection.Connection.latency` attribute. -Alternatively, you can measure the latency at any time with the -:attr:`~asyncio.connection.Connection.ping` method. diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index a6b909c72..9e7ea3d8c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -109,12 +109,13 @@ def __init__( """ Latency of the connection, in seconds. - This value is updated after sending a ping frame and receiving a - matching pong frame. Before the first ping, :attr:`latency` is ``0``. + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. By default, websockets enables a :ref:`keepalive ` mechanism - that sends ping frames automatically at regular intervals. You can also - send ping frames and measure latency with :meth:`ping`. + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ # Task that sends keepalive pings. None when ping_interval is None. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index e83e146f9..3b9a8c4aa 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -289,12 +289,13 @@ def __init__( """ Latency of the connection, in seconds. - This value is updated after sending a ping frame and receiving a - matching pong frame. Before the first ping, :attr:`latency` is ``0``. + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. By default, websockets enables a :ref:`keepalive ` mechanism - that sends ping frames automatically at regular intervals. You can also - send ping frames and measure latency with :meth:`ping`. + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ # Task running the data transfer. From 453e55ac2a20a50bfd30c0b4c011c50d01e7bb0a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 09:52:31 +0200 Subject: [PATCH 1343/1539] Standardize on raise AssertionError(...). --- experiments/optimization/parse_frames.py | 2 +- experiments/optimization/parse_handshake.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/experiments/optimization/parse_frames.py b/experiments/optimization/parse_frames.py index e3acbe3c2..9ea71c58e 100644 --- a/experiments/optimization/parse_frames.py +++ b/experiments/optimization/parse_frames.py @@ -33,7 +33,7 @@ def parse_frame(data, count, mask, extensions): except StopIteration: pass else: - assert False, "parser should return frame" + raise AssertionError("parser should return frame") reader.feed_eof() assert reader.at_eof(), "parser should consume all data" diff --git a/experiments/optimization/parse_handshake.py b/experiments/optimization/parse_handshake.py index af5a4ecae..393e0215c 100644 --- a/experiments/optimization/parse_handshake.py +++ b/experiments/optimization/parse_handshake.py @@ -71,7 +71,7 @@ def parse_handshake(handshake): except StopIteration: pass else: - assert False, "parser should return request" + raise AssertionError("parser should return request") reader.feed_eof() assert reader.at_eof(), "parser should consume all data" From 9d355bfeb784e41b2a879645113c73c4560c9a91 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 09:53:05 +0200 Subject: [PATCH 1344/1539] Remove unnecessary code paths in keepalive(). Also add comments in tests to clarify the intended sequence. --- src/websockets/asyncio/connection.py | 14 +++++--- tests/asyncio/test_connection.py | 50 ++++++++++++++++++---------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 9e7ea3d8c..005e9b4bb 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -723,6 +723,10 @@ async def keepalive(self) -> None: if self.ping_timeout is not None: try: async with asyncio_timeout(self.ping_timeout): + # connection_lost cancels keepalive immediately + # after setting a ConnectionClosed exception on + # pong_waiter. A CancelledError is raised here, + # not a ConnectionClosed exception. latency = await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: @@ -733,9 +737,10 @@ async def keepalive(self) -> None: CloseCode.INTERNAL_ERROR, "keepalive ping timeout", ) - break - except ConnectionClosed: - pass + raise AssertionError( + "send_context() should wait for connection_lost(), " + "which cancels keepalive()" + ) except Exception: self.logger.error("keepalive ping failed", exc_info=True) @@ -913,8 +918,7 @@ def connection_lost(self, exc: Exception | None) -> None: self.set_recv_exc(exc) self.recv_messages.close() self.abort_pings() - # If keepalive() was waiting for a pong, abort_pings() terminated it. - # If it was sleeping until the next ping, we need to cancel it now + if self.keepalive_task is not None: self.keepalive_task.cancel() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 29bb00418..59218de4b 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -890,12 +890,25 @@ async def test_pong_explicit_binary(self): @patch("random.getrandbits") async def test_keepalive(self, getrandbits): - """keepalive sends pings.""" + """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 2 * MS getrandbits.return_value = 1918987876 self.connection.start_keepalive() + self.assertEqual(self.connection.latency, 0) + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is received. await asyncio.sleep(3 * MS) + # 3 ms: check that the ping frame was sent. await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertNoFrameSent() @patch("random.getrandbits") async def test_keepalive_times_out(self, getrandbits): @@ -905,13 +918,14 @@ async def test_keepalive_times_out(self, getrandbits): async with self.drop_frames_rcvd(): getrandbits.return_value = 1918987876 self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for MS. - await self.assertFrameSent(Frame(Opcode.PING, b"rand")) - await asyncio.sleep(MS) - await self.assertFrameSent( - Frame(Opcode.CLOSE, b"\x03\xf3keepalive ping timeout") - ) + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection is closed. + await asyncio.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) @patch("random.getrandbits") async def test_keepalive_ignores_timeout(self, getrandbits): @@ -921,18 +935,14 @@ async def test_keepalive_ignores_timeout(self, getrandbits): async with self.drop_frames_rcvd(): getrandbits.return_value = 1918987876 self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for MS. - await self.assertFrameSent(Frame(Opcode.PING, b"rand")) - await asyncio.sleep(MS) - await self.assertNoFrameSent() - - async def test_disable_keepalive(self): - """keepalive is disabled when ping_interval is None.""" - self.connection.ping_interval = None - self.connection.start_keepalive() - await asyncio.sleep(3 * MS) - await self.assertNoFrameSent() + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection remains open. + await asyncio.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" @@ -945,21 +955,27 @@ async def test_keepalive_terminates_while_sleeping(self): async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = 2 * MS + self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for MS. + # 2.x ms: a pong frame is dropped. + # 3 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) async def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS - # Inject a fault by raising an exception in a pending pong waiter. async with self.drop_frames_rcvd(): self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for MS. + # 2.x ms: a pong frame is dropped. + # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: pong_waiter.set_exception(Exception("BOOM")) From 12fa8bc8fcc03a120ceb05700905b6e4698df563 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:24:25 +0200 Subject: [PATCH 1345/1539] Complete changelog with changes since 12.0. --- docs/project/changelog.rst | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f4ae76702..06d8a7774 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,11 @@ notice. Backwards-incompatible changes .............................. +.. admonition:: websockets 13.0 requires Python ≥ 3.8. + :class: tip + + websockets 12.0 is the last version supporting Python 3.7. + .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note @@ -64,9 +69,8 @@ Backwards-incompatible changes implementation is renamed to :class:`~sync.server.Server`. :class: note - This class isn't designed to be imported or instantiated directly. - :func:`~sync.server.serve` returns an instance. For this reason, - the change should be transparent. + This change should be transparent because this class shouldn't be + instantiated directly; :func:`~sync.server.serve` returns an instance. Regardless, an alias provides backwards compatibility. @@ -91,6 +95,22 @@ New features If you were monkey-patching constants, be aware that they were renamed, which will break your configuration. You must switch to the environment variables. +Improvements +............ + +* The error message in server logs when a header is too long is more explicit. + +Bug fixes +......... + +* Fixed a bug in the :mod:`threading` implementation that could prevent the + program from exiting when a connection wasn't closed properly. + +* Redirecting from a ``ws://`` URI to a ``wss://`` URI now works. + +* ``broadcast(raise_exceptions=True)`` no longer crashes when there isn't any + exception. + .. _12.0: 12.0 From 0019943e551d285ec27c29315a23dcc959a2ec29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:39:43 +0200 Subject: [PATCH 1346/1539] Release version 13.0. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 06d8a7774..7c5998288 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 13.0 ---- -*In development* +*August 20, 2024* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 44709a91b..56c321940 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True -tag = version = commit = "12.1" +tag = version = commit = "13.0" if not released: # pragma: no cover From 4d0e0e10c6ebb779780a9e590667661381df78dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:55:32 +0200 Subject: [PATCH 1347/1539] Build sdist and arch-independent wheel with build. This removes the dependency on setuptools, which isn't installed by default anymore, causing the build to fail. --- .github/workflows/release.yml | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ed52ddd80..184444e56 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,22 +17,18 @@ jobs: uses: actions/setup-python@v5 with: python-version: 3.x - - name: Build sdist - run: python setup.py sdist - - name: Save sdist - uses: actions/upload-artifact@v4 - with: - path: dist/*.tar.gz - - name: Install wheel - run: pip install wheel - - name: Build wheel + - name: Install build + run: pip install build + - name: Build sdist & wheel + run: python -m build env: BUILD_EXTENSION: no - run: python setup.py bdist_wheel - - name: Save wheel + - name: Save sdist & wheel uses: actions/upload-artifact@v4 with: - path: dist/*.whl + path: | + dist/*.tar.gz + dist/*.whl wheels: name: Build architecture-specific wheels on ${{ matrix.os }} From f9c20d0e4c9a25b66d4643879bc4594137036793 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 15:06:07 +0200 Subject: [PATCH 1348/1539] Avoid deleting .so files in .direnv or equivalent. --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index dacfe2a0b..a69248b6e 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,7 @@ build: python setup.py build_ext --inplace clean: - find . -name '*.pyc' -delete -o -name '*.so' -delete + find src -name '*.so' -delete + find . -name '*.pyc' -delete find . -name __pycache__ -delete rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 323adef1f3000cf07617d7ee649c27c0801126e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 15:19:28 +0200 Subject: [PATCH 1349/1539] Migrate to actions/upload-artifact@v4. The version number was increased without accounting for backwards-incompatible changes. --- .github/workflows/release.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 184444e56..4d2b5b75e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,6 +26,7 @@ jobs: - name: Save sdist & wheel uses: actions/upload-artifact@v4 with: + name: dist-architecture-independent path: | dist/*.tar.gz dist/*.whl @@ -58,6 +59,7 @@ jobs: - name: Save wheels uses: actions/upload-artifact@v4 with: + name: dist-${{ matrix.os }} path: wheelhouse/*.whl upload: @@ -74,7 +76,8 @@ jobs: - name: Download artifacts uses: actions/download-artifact@v4 with: - name: artifact + pattern: dist-* + merge-multiple: true path: dist - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 From ed2f21e3ce5bb494e0c9a51833b7da6692b5b9fc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 16:12:31 +0200 Subject: [PATCH 1350/1539] Attempt to fix automatic creation of GitHub release. --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4d2b5b75e..a68714ed1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -71,7 +71,7 @@ jobs: # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') permissions: - id-token: write + contents: write steps: - name: Download artifacts uses: actions/download-artifact@v4 From 3944595a1f1de2271b2f682f65b534f27628d3b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 16:13:40 +0200 Subject: [PATCH 1351/1539] Start version 13.1. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7c5998288..955136ac0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _13.1: + +13.1 +---- + +*In development* + .. _13.0: 13.0 diff --git a/src/websockets/version.py b/src/websockets/version.py index 56c321940..bbda56d6b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "13.0" +tag = version = commit = "13.1" if not released: # pragma: no cover From 6b1cc94caa2aa8f0395b8d8d2e6e9e211c927e8c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 21 Aug 2024 08:37:37 +0200 Subject: [PATCH 1352/1539] Move prepare_data/ctrl back to the legacy framing module. They are deprecated and they must go away with the legacy implementation. They were only documented in the framing module, not in the new frames module, so this doesn't require more backwards-compatibility shims. --- src/websockets/asyncio/connection.py | 17 ++++++-- src/websockets/frames.py | 50 ------------------------ src/websockets/legacy/framing.py | 58 +++++++++++++++++++++++++--- src/websockets/legacy/protocol.py | 4 +- src/websockets/speedups.c | 1 - src/websockets/sync/connection.py | 17 ++++++-- tests/asyncio/test_connection.py | 10 +++++ tests/legacy/test_framing.py | 56 +++++++++++++++++++++++++++ tests/sync/test_connection.py | 10 +++++ tests/test_frames.py | 56 --------------------------- 10 files changed, 156 insertions(+), 123 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 005e9b4bb..069b3e1d2 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -20,7 +20,7 @@ ) from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError -from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol @@ -597,8 +597,12 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: the corresponding pong wasn't received yet. """ - if data is not None: - data = prepare_ctrl(data) + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") async with self.send_context(): # Protect against duplicates if a payload is explicitly set. @@ -632,7 +636,12 @@ async def pong(self, data: Data = b"") -> None: ConnectionClosed: When the connection is closed. """ - data = prepare_ctrl(data) + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") async with self.send_context(): self.protocol.send_pong(data) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 819fdd742..8e44dd3a2 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -9,7 +9,6 @@ from typing import Callable, Generator, Sequence from . import exceptions, extensions -from .typing import Data try: @@ -29,8 +28,6 @@ "DATA_OPCODES", "CTRL_OPCODES", "Frame", - "prepare_data", - "prepare_ctrl", "Close", ] @@ -354,53 +351,6 @@ def check(self) -> None: raise exceptions.ProtocolError("fragmented control frame") -def prepare_data(data: Data) -> tuple[int, bytes]: - """ - Convert a string or byte-like object to an opcode and a bytes-like object. - - This function is designed for data frames. - - If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` - object encoding ``data`` in UTF-8. - - If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like - object. - - Raises: - TypeError: If ``data`` doesn't have a supported type. - - """ - if isinstance(data, str): - return OP_TEXT, data.encode() - elif isinstance(data, BytesLike): - return OP_BINARY, data - else: - raise TypeError("data must be str or bytes-like") - - -def prepare_ctrl(data: Data) -> bytes: - """ - Convert a string or byte-like object to bytes. - - This function is designed for ping and pong frames. - - If ``data`` is a :class:`str`, return a :class:`bytes` object encoding - ``data`` in UTF-8. - - If ``data`` is a bytes-like object, return a :class:`bytes` object. - - Raises: - TypeError: If ``data`` doesn't have a supported type. - - """ - if isinstance(data, str): - return data.encode() - elif isinstance(data, BytesLike): - return bytes(data) - else: - raise TypeError("data must be str or bytes-like") - - @dataclasses.dataclass class Close: """ diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 1aaca5cc6..4c2f8c23f 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -5,6 +5,8 @@ from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError +from ..frames import BytesLike +from ..typing import Data try: @@ -144,12 +146,58 @@ def write( write(self.new_frame.serialize(mask=mask, extensions=extensions)) +def prepare_data(data: Data) -> tuple[int, bytes]: + """ + Convert a string or byte-like object to an opcode and a bytes-like object. + + This function is designed for data frames. + + If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` + object encoding ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like + object. + + Raises: + TypeError: If ``data`` doesn't have a supported type. + + """ + if isinstance(data, str): + return frames.Opcode.TEXT, data.encode() + elif isinstance(data, BytesLike): + return frames.Opcode.BINARY, data + else: + raise TypeError("data must be str or bytes-like") + + +def prepare_ctrl(data: Data) -> bytes: + """ + Convert a string or byte-like object to bytes. + + This function is designed for ping and pong frames. + + If ``data`` is a :class:`str`, return a :class:`bytes` object encoding + ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return a :class:`bytes` object. + + Raises: + TypeError: If ``data`` doesn't have a supported type. + + """ + if isinstance(data, str): + return data.encode() + elif isinstance(data, BytesLike): + return bytes(data) + else: + raise TypeError("data must be str or bytes-like") + + +# Backwards compatibility with previously documented public APIs +encode_data = prepare_ctrl + # Backwards compatibility with previously documented public APIs -from ..frames import ( # noqa: E402, F401, I001 - Close, - prepare_ctrl as encode_data, - prepare_data, -) +from ..frames import Close # noqa: E402 F401, I001 def parse_close(data: bytes) -> tuple[int, str]: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 3b9a8c4aa..998e390d4 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -45,12 +45,10 @@ Close, CloseCode, Opcode, - prepare_ctrl, - prepare_data, ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .framing import Frame +from .framing import Frame, prepare_ctrl, prepare_data __all__ = ["WebSocketCommonProtocol"] diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index a19590419..cb10dedb8 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -19,7 +19,6 @@ _PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ss { // This supports bytes, bytearrays, and memoryview objects, // which are common data structures for handling byte streams. - // websockets.framing.prepare_data() returns only these types. // If *tmp isn't NULL, the caller gets a new reference. if (PyBytes_Check(obj)) { diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 88d6aee1f..16e51abda 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -11,7 +11,7 @@ from typing import Any, Iterable, Iterator, Mapping from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError -from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol @@ -449,8 +449,12 @@ def ping(self, data: Data | None = None) -> threading.Event: the corresponding pong wasn't received yet. """ - if data is not None: - data = prepare_ctrl(data) + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") with self.send_context(): # Protect against duplicates if a payload is explicitly set. @@ -481,7 +485,12 @@ def pong(self, data: Data = b"") -> None: ConnectionClosed: When the connection is closed. """ - data = prepare_ctrl(data) + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") with self.send_context(): self.protocol.send_pong(data) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 59218de4b..78f3adf68 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -869,6 +869,11 @@ async def test_ping_duplicate_payload(self): await self.connection.ping("idem") # doesn't raise an exception + async def test_ping_unsupported_type(self): + """ping raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.ping([]) + # Test pong. async def test_pong(self): @@ -886,6 +891,11 @@ async def test_pong_explicit_binary(self): await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + async def test_pong_unsupported_type(self): + """pong raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.pong([]) + # Test keepalive. @patch("random.getrandbits") diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 6f811bd5e..e816b91e0 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -173,6 +173,62 @@ def decode(frame, *, max_size=None): ) +class PrepareDataTests(unittest.TestCase): + def test_prepare_data_str(self): + self.assertEqual( + prepare_data("café"), + (OP_TEXT, b"caf\xc3\xa9"), + ) + + def test_prepare_data_bytes(self): + self.assertEqual( + prepare_data(b"tea"), + (OP_BINARY, b"tea"), + ) + + def test_prepare_data_bytearray(self): + self.assertEqual( + prepare_data(bytearray(b"tea")), + (OP_BINARY, bytearray(b"tea")), + ) + + def test_prepare_data_memoryview(self): + self.assertEqual( + prepare_data(memoryview(b"tea")), + (OP_BINARY, memoryview(b"tea")), + ) + + def test_prepare_data_list(self): + with self.assertRaises(TypeError): + prepare_data([]) + + def test_prepare_data_none(self): + with self.assertRaises(TypeError): + prepare_data(None) + + +class PrepareCtrlTests(unittest.TestCase): + def test_prepare_ctrl_str(self): + self.assertEqual(prepare_ctrl("café"), b"caf\xc3\xa9") + + def test_prepare_ctrl_bytes(self): + self.assertEqual(prepare_ctrl(b"tea"), b"tea") + + def test_prepare_ctrl_bytearray(self): + self.assertEqual(prepare_ctrl(bytearray(b"tea")), b"tea") + + def test_prepare_ctrl_memoryview(self): + self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea") + + def test_prepare_ctrl_list(self): + with self.assertRaises(TypeError): + prepare_ctrl([]) + + def test_prepare_ctrl_none(self): + with self.assertRaises(TypeError): + prepare_ctrl(None) + + class ParseAndSerializeCloseTests(unittest.TestCase): def assertCloseData(self, code, reason, data): """ diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 877adc4bf..d9fb2093b 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -665,6 +665,11 @@ def test_ping_duplicate_payload(self): self.connection.ping("idem") # doesn't raise an exception + def test_ping_unsupported_type(self): + """ping raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + self.connection.ping([]) + # Test pong. def test_pong(self): @@ -682,6 +687,11 @@ def test_pong_explicit_binary(self): self.connection.pong(b"pong") self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + def test_pong_unsupported_type(self): + """pong raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + self.connection.pong([]) + # Test attributes. def test_id(self): diff --git a/tests/test_frames.py b/tests/test_frames.py index 3e9f5d6f8..857b41fe4 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -377,62 +377,6 @@ def test_pong_binary(self): ) -class PrepareDataTests(unittest.TestCase): - def test_prepare_data_str(self): - self.assertEqual( - prepare_data("café"), - (OP_TEXT, b"caf\xc3\xa9"), - ) - - def test_prepare_data_bytes(self): - self.assertEqual( - prepare_data(b"tea"), - (OP_BINARY, b"tea"), - ) - - def test_prepare_data_bytearray(self): - self.assertEqual( - prepare_data(bytearray(b"tea")), - (OP_BINARY, bytearray(b"tea")), - ) - - def test_prepare_data_memoryview(self): - self.assertEqual( - prepare_data(memoryview(b"tea")), - (OP_BINARY, memoryview(b"tea")), - ) - - def test_prepare_data_list(self): - with self.assertRaises(TypeError): - prepare_data([]) - - def test_prepare_data_none(self): - with self.assertRaises(TypeError): - prepare_data(None) - - -class PrepareCtrlTests(unittest.TestCase): - def test_prepare_ctrl_str(self): - self.assertEqual(prepare_ctrl("café"), b"caf\xc3\xa9") - - def test_prepare_ctrl_bytes(self): - self.assertEqual(prepare_ctrl(b"tea"), b"tea") - - def test_prepare_ctrl_bytearray(self): - self.assertEqual(prepare_ctrl(bytearray(b"tea")), b"tea") - - def test_prepare_ctrl_memoryview(self): - self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea") - - def test_prepare_ctrl_list(self): - with self.assertRaises(TypeError): - prepare_ctrl([]) - - def test_prepare_ctrl_none(self): - with self.assertRaises(TypeError): - prepare_ctrl(None) - - class CloseTests(unittest.TestCase): def assertCloseData(self, close, data): """ From d3dcfd150385970a0d1df7187a989680cc426f55 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 21 Aug 2024 23:09:46 +0200 Subject: [PATCH 1353/1539] Add HTTP Basic Auth to asyncio & threading servers. --- docs/howto/upgrade.rst | 86 +++++++++++++--- docs/project/changelog.rst | 6 ++ docs/reference/asyncio/server.rst | 8 ++ docs/reference/features.rst | 2 +- docs/reference/sync/server.rst | 8 ++ docs/topics/compression.rst | 2 - src/websockets/asyncio/client.py | 4 +- src/websockets/asyncio/server.py | 149 ++++++++++++++++++++++++++-- src/websockets/sync/client.py | 4 +- src/websockets/sync/server.py | 133 ++++++++++++++++++++++++- tests/asyncio/test_server.py | 160 ++++++++++++++++++++++++++++++ tests/sync/test_server.py | 145 +++++++++++++++++++++++++++ 12 files changed, 679 insertions(+), 28 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 8d0895638..d68f5fe99 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -71,15 +71,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -HTTP Basic Authentication -......................... - -On the server side, :func:`~asyncio.server.serve` doesn't provide HTTP Basic -Authentication yet. - -For the avoidance of doubt, on the client side, :func:`~asyncio.client.connect` -performs HTTP Basic Authentication. - Following redirects ................... @@ -165,12 +156,12 @@ Server APIs | ``websockets.broadcast`` |br| | :func:`websockets.asyncio.server.broadcast` | | :func:`websockets.legacy.server.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | -| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | +| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | +| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | :func:`websockets.asyncio.server.basic_auth`. | | :class:`websockets.legacy.auth.BasicAuthWebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | -| ``websockets.auth.basic_auth_protocol_factory()`` |br| | | +| ``websockets.basic_auth_protocol_factory()`` |br| | See below :ref:`how to migrate ` to | +| ``websockets.auth.basic_auth_protocol_factory()`` |br| | :func:`websockets.asyncio.server.basic_auth`. | | :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ @@ -206,6 +197,75 @@ implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. +.. _basic-auth: + +Performing HTTP Basic Authentication +.................................... + +.. admonition:: This section applies only to servers. + :class: tip + + On the client side, :func:`~asyncio.client.connect` performs HTTP Basic + Authentication automatically when the URI contains credentials. + +In the original implementation, the recommended way to add HTTP Basic +Authentication to a server was to set the ``create_protocol`` argument of +:func:`~legacy.server.serve` to a factory function generated by +:func:`~legacy.auth.basic_auth_protocol_factory`:: + + from websockets.legacy.auth import basic_auth_protocol_factory + from websockets.legacy.server import serve + + async with serve(..., create_protocol=basic_auth_protocol_factory(...)): + ... + +In the new implementation, the :func:`~asyncio.server.basic_auth` function +generates a ``process_request`` coroutine that performs HTTP Basic +Authentication:: + + from websockets.asyncio.server import basic_auth, serve + + async with serve(..., process_request=basic_auth(...)): + ... + +:func:`~asyncio.server.basic_auth` accepts either hard coded ``credentials`` or +a ``check_credentials`` coroutine as well as an optional ``realm`` just like +:func:`~legacy.auth.basic_auth_protocol_factory`. Furthermore, +``check_credentials`` may be a function instead of a coroutine. + +This new API has more obvious semantics. That makes it easier to understand and +also easier to extend. + +In the original implementation, overriding ``create_protocol`` changed the type +of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, +a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP +Basic Authentication in its ``process_request`` method. If you wanted to +customize ``process_request`` further, you had: + +* an ill-defined option: add a ``process_request`` argument to + :func:`~legacy.server.serve`; to tell which one would run first, you had to + experiment or read the code; +* a cumbersome option: subclass + :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that + subclass in the ``create_protocol`` argument of + :func:`~legacy.auth.basic_auth_protocol_factory`. + +In the new implementation, you just write a ``process_request`` coroutine:: + + from websockets.asyncio.server import basic_auth, serve + + process_basic_auth = basic_auth(...) + + async def process_request(connection, request): + ... # some logic here + response = await process_basic_auth(connection, request) + if response is not None: + return response + ... # more logic here + + async with serve(..., process_request=process_request): + ... + Customizing the opening handshake ................................. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 955136ac0..d940b1ea2 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,12 @@ notice. *In development* +New features +............ + +* The new :mod:`asyncio` and :mod:`threading` implementations provide an API for + enforcing HTTP Basic Auth on the server side. + .. _13.0: 13.0 diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index bd5a34b19..d4d20aeb6 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -81,3 +81,11 @@ Broadcast --------- .. autofunction:: websockets.asyncio.server.broadcast + +HTTP Basic Authentication +------------------------- + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: websockets.asyncio.server.basic_auth diff --git a/docs/reference/features.rst b/docs/reference/features.rst index a380f4555..d9941e408 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -118,7 +118,7 @@ Server +------------------------------------+--------+--------+--------+--------+ | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ❌ | ❌ | ❌ | ✅ | + | Perform HTTP Basic Authentication | ✅ | ✅ | ❌ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 23cb04097..80e9c17bb 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -60,3 +60,11 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + +HTTP Basic Authentication +------------------------- + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: websockets.sync.server.basic_auth diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index be263e56f..5f09bbf73 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -45,7 +45,6 @@ explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], - ..., ) serve( @@ -57,7 +56,6 @@ explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], - ..., ) The Window Bits and Memory Level values in these examples reduce memory usage diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 632d3ac2b..033887e87 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -131,7 +131,9 @@ class connect: :func:`connect` may be used as an asynchronous context manager:: - async with websockets.asyncio.client.connect(...) as websocket: + from websockets.asyncio.client import connect + + async with connect(...) as websocket: ... The connection is closed automatically when exiting the context. diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 8f04ec318..f6cd9a1b5 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import hmac import http import logging import socket @@ -13,13 +14,19 @@ Generator, Iterable, Sequence, + Tuple, + cast, ) -from websockets.frames import CloseCode - +from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate -from ..headers import validate_subprotocols +from ..frames import CloseCode +from ..headers import ( + build_www_authenticate_basic, + parse_authorization_basic, + validate_subprotocols, +) from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol @@ -28,7 +35,14 @@ from .connection import Connection, broadcast -__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "Server"] +__all__ = [ + "broadcast", + "serve", + "unix_serve", + "ServerConnection", + "Server", + "basic_auth", +] class ServerConnection(Connection): @@ -79,6 +93,7 @@ def __init__( ) self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + self.username: str # see basic_auth() def respond(self, status: StatusLike, text: str) -> Response: """ @@ -548,19 +563,21 @@ class serve: :class:`asyncio.Server`. Treat it as an asynchronous context manager to ensure that the server will be closed:: + from websockets.asyncio.server import serve + def handler(websocket): ... # set this future to exit the server stop = asyncio.get_running_loop().create_future() - async with websockets.asyncio.server.serve(handler, host, port): + async with serve(handler, host, port): await stop Alternatively, call :meth:`~Server.serve_forever` to serve requests and cancel it to stop the server:: - server = await websockets.asyncio.server.serve(handler, host, port) + server = await serve(handler, host, port) await server.serve_forever() Args: @@ -822,3 +839,123 @@ def unix_serve( """ return serve(handler, unix=True, path=path, **kwargs) + + +def is_credentials(credentials: Any) -> bool: + try: + username, password = credentials + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None, +) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]: + """ + Factory for ``process_request`` to enforce HTTP Basic Authentication. + + :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: + + from websockets.asyncio.server import basic_auth, serve + + async with serve( + ..., + process_request=basic_auth( + realm="my dev server", + credentials=("hello", "iloveyou"), + ), + ): + + If authentication succeeds, the connection's ``username`` attribute is set. + If it fails, the server responds with an HTTP 401 Unauthorized status. + + One of ``credentials`` or ``check_credentials`` must be provided; not both. + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. Refer to + section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Function or coroutine that verifies credentials. + It receives ``username`` and ``password`` arguments and returns + whether they're valid. + Raises: + TypeError: If ``credentials`` or ``check_credentials`` is wrong. + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(Tuple[str, str], credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None # help mypy + + async def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + """ + Perform HTTP Basic Authentication. + + If it succeeds, set the connection's ``username`` attribute and return + :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. + + """ + try: + authorization = request.headers["Authorization"] + except KeyError: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Missing credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Unsupported credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + valid_credentials = check_credentials(username, password) + if isinstance(valid_credentials, Awaitable): + valid_credentials = await valid_credentials + + if not valid_credentials: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Invalid credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + connection.username = username + return None + + return process_request diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e33d53f62..3c700a377 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -156,7 +156,9 @@ def connect( :func:`connect` may be used as a context manager:: - with websockets.sync.client.connect(...) as websocket: + from websockets.sync.client import connect + + with connect(...) as websocket: ... The connection is closed automatically when exiting the context. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 86c162af3..5e22e112e 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hmac import http import logging import os @@ -10,12 +11,17 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Sequence +from typing import Any, Callable, Iterable, Sequence, Tuple, cast +from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..frames import CloseCode -from ..headers import validate_subprotocols +from ..headers import ( + build_www_authenticate_basic, + parse_authorization_basic, + validate_subprotocols, +) from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol @@ -24,7 +30,7 @@ from .utils import Deadline -__all__ = ["serve", "unix_serve", "ServerConnection", "Server"] +__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"] class ServerConnection(Connection): @@ -65,6 +71,7 @@ def __init__( protocol, close_timeout=close_timeout, ) + self.username: str # see basic_auth() def respond(self, status: StatusLike, text: str) -> Response: """ @@ -368,10 +375,12 @@ def serve( that it will be closed and call :meth:`~Server.serve_forever` to serve requests:: + from websockets.sync.server import serve + def handler(websocket): ... - with websockets.sync.server.serve(handler, ...) as server: + with serve(handler, ...) as server: server.serve_forever() Args: @@ -587,3 +596,119 @@ def unix_serve( """ return serve(handler, unix=True, path=path, **kwargs) + + +def is_credentials(credentials: Any) -> bool: + try: + username, password = credentials + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], bool] | None = None, +) -> Callable[[ServerConnection, Request], Response | None]: + """ + Factory for ``process_request`` to enforce HTTP Basic Authentication. + + :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: + + from websockets.sync.server import basic_auth, serve + + with serve( + ..., + process_request=basic_auth( + realm="my dev server", + credentials=("hello", "iloveyou"), + ), + ): + + If authentication succeeds, the connection's ``username`` attribute is set. + If it fails, the server responds with an HTTP 401 Unauthorized status. + + One of ``credentials`` or ``check_credentials`` must be provided; not both. + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. Refer to + section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Function that verifies credentials. + It receives ``username`` and ``password`` arguments and returns + whether they're valid. + Raises: + TypeError: If ``credentials`` or ``check_credentials`` is wrong. + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(Tuple[str, str], credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None # help mypy + + def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + """ + Perform HTTP Basic Authentication. + + If it succeeds, set the connection's ``username`` attribute and return + :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. + + """ + try: + authorization = request.headers["Authorization"] + except KeyError: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Missing credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Unsupported credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + if not check_credentials(username, password): + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Invalid credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + connection.username = username + return None + + return process_request diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index b899998f4..f05b9f1e4 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,5 +1,6 @@ import asyncio import dataclasses +import hmac import http import logging import socket @@ -560,3 +561,162 @@ async def test_logger(self): logger = logging.getLogger("test") async with run_server(logger=logger) as server: self.assertIs(server.logger, logger) + + +class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_list_of_credentials(self): + """basic_auth accepts a list of hard coded credentials.""" + async with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + await self.assertEval(client, "ws.username", "bye") + + async def test_check_credentials_function(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_check_credentials_coroutine(self): + """basic_auth accepts a check_credentials coroutine.""" + + async def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + async def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + ) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index e3dfeb271..39d7501bc 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,4 +1,5 @@ import dataclasses +import hmac import http import logging import socket @@ -413,6 +414,150 @@ def test_shutdown(self): server.socket.accept() +class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + self.assertEval(client, "ws.username", "hello") + + def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_list_of_credentials(self): + """basic_auth accepts a list of hard coded credentials.""" + with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + self.assertEval(client, "ws.username", "bye") + + def test_check_credentials(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + self.assertEval(client, "ws.username", "hello") + + def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + ) + + class BackwardsCompatibilityTests(DeprecationTestCase): def test_ssl_context_argument(self): """Client supports the deprecated ssl_context argument.""" From 4920a585beef9c4f8a1aa7e332dc20d75f252423 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 22 Aug 2024 09:12:16 +0200 Subject: [PATCH 1354/1539] Make logging examples more robust. Prevent them from crashing when the opening handshake didn't complete successfully. Also migrate them to the new asyncio API. Fix #1428. --- docs/topics/logging.rst | 15 +++++++++------ src/websockets/server.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 9580b4c50..be5678455 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -124,9 +124,11 @@ Here's how to include them in logs, assuming they're in the def process(self, msg, kwargs): try: websocket = kwargs["extra"]["websocket"] - except KeyError: + except KeyError: # log entry not coming from a connection + return msg, kwargs + if websocket.request is None: # opening handshake not complete return msg, kwargs - xff = websocket.request_headers.get("X-Forwarded-For") + xff = headers.get("X-Forwarded-For") return f"{websocket.id} {xff} {msg}", kwargs async with serve( @@ -165,10 +167,11 @@ a :class:`~logging.LoggerAdapter`:: websocket = kwargs["extra"]["websocket"] except KeyError: return msg, kwargs - kwargs["extra"]["event_data"] = { - "connection_id": str(websocket.id), - "remote_addr": websocket.request_headers.get("X-Forwarded-For"), - } + event_data = {"connection_id": str(websocket.id)} + if websocket.request is not None: # opening handshake complete + headers = websocket.request.headers + event_data["remote_addr"] = headers.get("X-Forwarded-For") + kwargs["extra"]["event_data"] = event_data return msg, kwargs async with serve( diff --git a/src/websockets/server.py b/src/websockets/server.py index 2ab9102f7..11ba8b425 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -69,7 +69,7 @@ class ServerProtocol(Protocol): max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. logger: Logger for this connection; - defaults to ``logging.getLogger("websockets.client")``; + defaults to ``logging.getLogger("websockets.server")``; see the :doc:`logging guide <../../topics/logging>` for details. """ From d341bbabd81facbf084e7c5b321e7c8c2cd977f5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 22 Aug 2024 21:55:55 +0200 Subject: [PATCH 1355/1539] Add API for the set of active connections. Fix #1486. --- docs/howto/upgrade.rst | 10 ++++++++ docs/project/changelog.rst | 7 ++++-- docs/reference/asyncio/server.rst | 2 ++ src/websockets/asyncio/server.py | 14 ++++++++++- tests/asyncio/test_server.py | 24 ++++++++++++------- tests/sync/test_server.py | 40 +++++++++++++++---------------- 6 files changed, 65 insertions(+), 32 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index d68f5fe99..602d8a4e6 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -197,6 +197,16 @@ implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. +Tracking open connections +......................... + +The new implementation of :class:`~asyncio.server.Server` provides a +:attr:`~asyncio.server.Server.connections` property, which is a set of all open +connections. This didn't exist in the original implementation. + +If you were keeping track of open connections, you may be able to simplify your +code by using this property. + .. _basic-auth: Performing HTTP Basic Authentication diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index d940b1ea2..1c87882f0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,8 +35,11 @@ notice. New features ............ -* The new :mod:`asyncio` and :mod:`threading` implementations provide an API for - enforcing HTTP Basic Auth on the server side. +* Made the set of active connections available in the :attr:`Server.connections + ` property. + +* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` + implementations of servers. .. _13.0: diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index d4d20aeb6..2fcaeb414 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -17,6 +17,8 @@ Running a server .. autoclass:: Server + .. autoattribute:: connections + .. automethod:: close .. automethod:: wait_closed diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index f6cd9a1b5..29860e565 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -28,7 +28,7 @@ validate_subprotocols, ) from ..http11 import SERVER, Request, Response -from ..protocol import CONNECTING, Event +from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout @@ -313,6 +313,18 @@ def __init__( # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + @property + def connections(self) -> set[ServerConnection]: + """ + Set of active connections. + + This property contains all connections that completed the opening + handshake successfully and didn't start the closing handshake yet. + It can be useful in combination with :func:`~broadcast`. + + """ + return {connection for connection in self.handlers if connection.state is OPEN} + def wrap(self, server: asyncio.Server) -> None: """ Attach to a given :class:`asyncio.Server`. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index f05b9f1e4..38f226903 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -350,6 +350,12 @@ async def test_disable_keepalive(self): latency = eval(await client.recv()) self.assertEqual(latency, 0) + async def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" @@ -362,6 +368,16 @@ def create_connection(*args, **kwargs): async with run_client(server) as client: await self.assertEval(client, "ws.create_connection_ran", "True") + async def test_connections(self): + """Server provides a connections property.""" + async with run_server() as server: + self.assertEqual(server.connections, set()) + async with run_client(server) as client: + self.assertEqual(len(server.connections), 1) + ws_id = str(next(iter(server.connections)).id) + await self.assertEval(client, "ws.id", ws_id) + self.assertEqual(server.connections, set()) + async def test_handshake_fails(self): """Server receives connection from client but the handshake fails.""" @@ -555,14 +571,6 @@ async def test_unsupported_compression(self): ) -class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): - async def test_logger(self): - """Server accepts a logger argument.""" - logger = logging.getLogger("test") - async with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) - - class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 39d7501bc..a17634716 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -235,6 +235,12 @@ def test_disable_compression(self): with run_client(server) as client: self.assertEval(client, "ws.protocol.extensions", "[]") + def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" @@ -247,6 +253,19 @@ def create_connection(*args, **kwargs): with run_client(server) as client: self.assertEval(client, "ws.create_connection_ran", "True") + def test_fileno(self): + """Server provides a fileno attribute.""" + with run_server() as server: + self.assertIsInstance(server.fileno(), int) + + def test_shutdown(self): + """Server provides a shutdown method.""" + with run_server() as server: + server.shutdown() + # Check that the server socket is closed. + with self.assertRaises(OSError): + server.socket.accept() + def test_handshake_fails(self): """Server receives connection from client but the handshake fails.""" @@ -393,27 +412,6 @@ def test_unsupported_compression(self): ) -class WebSocketServerTests(unittest.TestCase): - def test_logger(self): - """Server accepts a logger argument.""" - logger = logging.getLogger("test") - with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) - - def test_fileno(self): - """Server provides a fileno attribute.""" - with run_server() as server: - self.assertIsInstance(server.fileno(), int) - - def test_shutdown(self): - """Server provides a shutdown method.""" - with run_server() as server: - server.shutdown() - # Check that the server socket is closed. - with self.assertRaises(OSError): - server.socket.accept() - - class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" From 2306835d7bf936d3b7c09caabd12dd072c7d7a39 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 08:41:35 +0200 Subject: [PATCH 1356/1539] Set Protocol.close_rcvd_then_sent in an edge case. When receiving a close frame in the middle of a fragmented message -- cf. test_(client|server)_receive_close_in_fragmented_message) -- recv_frame() was raising a ProtocolError, fail() was sending a Close frame, and close_rcvd_then_sent was never set. These tests don't proceed to the CLOSED state, making it difficult to test close_exc. --- src/websockets/protocol.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index de065c544..3b3e80cf5 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -378,9 +378,11 @@ def send_close(self, code: int | None = None, reason: str = "") -> None: else: close = Close(code, reason) data = close.serialize() - # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started self.send_frame(Frame(OP_CLOSE, data)) + # Since the state is OPEN, no close frame was received yet. + # As a consequence, self.close_rcvd_then_sent remains None. + assert self.close_rcvd is None self.close_sent = close self.state = CLOSING @@ -441,6 +443,12 @@ def fail(self, code: int, reason: str = "") -> None: data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) self.close_sent = close + # If recv_messages() raised an exception upon receiving a close + # frame but before echoing it, then close_rcvd is not None even + # though the state is OPEN. This happens when the connection is + # closed while receiving a fragmented message. + if self.close_rcvd is not None: + self.close_rcvd_then_sent = True self.state = CLOSING # When failing the connection, a server closes the TCP connection From 61f42acdbe47c51ad51283890d814091d2c201e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 07:20:15 +0200 Subject: [PATCH 1357/1539] Deprecate private attributes of ConnectionClosed. While I don't think they were ever documented, they were tested and ConnectionClosed is such a common API that they may have been used. --- src/websockets/exceptions.py | 11 +++++++++++ tests/legacy/test_protocol.py | 10 ---------- tests/test_exceptions.py | 32 ++++++++++++++++++++++++++------ 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 52cc48898..64b7d3102 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,6 +31,7 @@ from __future__ import annotations import http +import warnings from . import datastructures, frames, http11 from .typing import StatusLike @@ -120,12 +121,22 @@ def __str__(self) -> str: @property def code(self) -> int: + warnings.warn( # deprecated in 13.1 + "ConnectionClosed.code is deprecated; " + "use Protocol.close_code or ConnectionClosed.rcvd.code", + DeprecationWarning, + ) if self.rcvd is None: return frames.CloseCode.ABNORMAL_CLOSURE return self.rcvd.code @property def reason(self) -> str: + warnings.warn( # deprecated in 13.1 + "ConnectionClosed.reason is deprecated; " + "use Protocol.close_reason or ConnectionClosed.rcvd.reason", + DeprecationWarning, + ) if self.rcvd is None: return "" return self.rcvd.reason diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 8751b9ac6..de2a320b5 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1180,16 +1180,6 @@ def test_legacy_recv(self): # Now recv() returns None instead of raising ConnectionClosed. self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) - def test_connection_closed_attributes(self): - self.close_connection() - - with self.assertRaises(ConnectionClosed) as context: - self.loop.run_until_complete(self.protocol.recv()) - - connection_closed_exc = context.exception - self.assertEqual(connection_closed_exc.code, CloseCode.NORMAL_CLOSURE) - self.assertEqual(connection_closed_exc.reason, "close") - # Test the protocol logic for sending keepalive pings. def restart_protocol_with_keepalive_ping( diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 1e6f58fad..92ba7dda8 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -5,6 +5,8 @@ from websockets.frames import Close, CloseCode from websockets.http11 import Response +from .utils import DeprecationTestCase + class ExceptionsTests(unittest.TestCase): def test_str(self): @@ -185,12 +187,30 @@ def test_str(self): with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) - def test_connection_closed_attributes_backwards_compatibility(self): + +class DeprecationTests(DeprecationTestCase): + def test_connection_closed_attributes_deprecation(self): exception = ConnectionClosed(Close(CloseCode.NORMAL_CLOSURE, "OK"), None, None) - self.assertEqual(exception.code, CloseCode.NORMAL_CLOSURE) - self.assertEqual(exception.reason, "OK") + with self.assertDeprecationWarning( + "ConnectionClosed.code is deprecated; " + "use Protocol.close_code or ConnectionClosed.rcvd.code" + ): + self.assertEqual(exception.code, CloseCode.NORMAL_CLOSURE) + with self.assertDeprecationWarning( + "ConnectionClosed.reason is deprecated; " + "use Protocol.close_reason or ConnectionClosed.rcvd.reason" + ): + self.assertEqual(exception.reason, "OK") - def test_connection_closed_attributes_backwards_compatibility_defaults(self): + def test_connection_closed_attributes_deprecation_defaults(self): exception = ConnectionClosed(None, None, None) - self.assertEqual(exception.code, CloseCode.ABNORMAL_CLOSURE) - self.assertEqual(exception.reason, "") + with self.assertDeprecationWarning( + "ConnectionClosed.code is deprecated; " + "use Protocol.close_code or ConnectionClosed.rcvd.code" + ): + self.assertEqual(exception.code, CloseCode.ABNORMAL_CLOSURE) + with self.assertDeprecationWarning( + "ConnectionClosed.reason is deprecated; " + "use Protocol.close_reason or ConnectionClosed.rcvd.reason" + ): + self.assertEqual(exception.reason, "") From 5638611169aa763ec84cdc8d2a967b0533401129 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 07:25:06 +0200 Subject: [PATCH 1358/1539] Clean up implementation of ConnectionClosed. --- src/websockets/exceptions.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 64b7d3102..8df319dae 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -78,13 +78,13 @@ class ConnectionClosed(WebSocketException): Raised when trying to interact with a closed connection. Attributes: - rcvd (Close | None): if a close frame was received, its code and - reason are available in ``rcvd.code`` and ``rcvd.reason``. - sent (Close | None): if a close frame was sent, its code and reason - are available in ``sent.code`` and ``sent.reason``. - rcvd_then_sent (bool | None): if close frames were received and - sent, this attribute tells in which order this happened, from the - perspective of this side of the connection. + rcvd: If a close frame was received, its code and reason are available + in ``rcvd.code`` and ``rcvd.reason``. + sent: If a close frame was sent, its code and reason are available + in ``sent.code`` and ``sent.reason``. + rcvd_then_sent: If close frames were received and sent, this attribute + tells in which order this happened, from the perspective of this + side of the connection. """ @@ -97,21 +97,18 @@ def __init__( self.rcvd = rcvd self.sent = sent self.rcvd_then_sent = rcvd_then_sent + assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None) def __str__(self) -> str: if self.rcvd is None: if self.sent is None: - assert self.rcvd_then_sent is None return "no close frame received or sent" else: - assert self.rcvd_then_sent is None return f"sent {self.sent}; no close frame received" else: if self.sent is None: - assert self.rcvd_then_sent is None return f"received {self.rcvd}; no close frame sent" else: - assert self.rcvd_then_sent is not None if self.rcvd_then_sent: return f"received {self.rcvd}; then sent {self.sent}" else: From 567f9859158fe9849b8a1ae7536471119ce4c61f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 07:32:28 +0200 Subject: [PATCH 1359/1539] Move InvalidMessage to the legacy package. It is only used by the legacy implementation. --- src/websockets/__init__.py | 8 +++++--- src/websockets/exceptions.py | 23 +++++++++++++++-------- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/exceptions.py | 8 ++++++++ src/websockets/legacy/server.py | 2 +- tests/legacy/test_exceptions.py | 15 +++++++++++++++ tests/test_exceptions.py | 4 ---- 7 files changed, 45 insertions(+), 17 deletions(-) create mode 100644 src/websockets/legacy/exceptions.py create mode 100644 tests/legacy/test_exceptions.py diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index b618a6dff..63b0a260b 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -23,7 +23,6 @@ "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", - "InvalidMessage", "InvalidOrigin", "InvalidParameterName", "InvalidParameterValue", @@ -46,6 +45,8 @@ "WebSocketClientProtocol", "connect", "unix_connect", + # .legacy.exceptions + "InvalidMessage", # .legacy.protocol "WebSocketCommonProtocol", # .legacy.server @@ -80,7 +81,6 @@ InvalidHeader, InvalidHeaderFormat, InvalidHeaderValue, - InvalidMessage, InvalidOrigin, InvalidParameterName, InvalidParameterValue, @@ -102,6 +102,7 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect + from .legacy.exceptions import InvalidMessage from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, @@ -140,7 +141,6 @@ "InvalidHeader": ".exceptions", "InvalidHeaderFormat": ".exceptions", "InvalidHeaderValue": ".exceptions", - "InvalidMessage": ".exceptions", "InvalidOrigin": ".exceptions", "InvalidParameterName": ".exceptions", "InvalidParameterValue": ".exceptions", @@ -163,6 +163,8 @@ "WebSocketClientProtocol": ".legacy.client", "connect": ".legacy.client", "unix_connect": ".legacy.client", + # .legacy.exceptions + "InvalidMessage": ".legacy.exceptions", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", # .legacy.server diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 8df319dae..b9b100150 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -7,7 +7,7 @@ * :exc:`ConnectionClosedOK` * :exc:`InvalidHandshake` * :exc:`SecurityError` - * :exc:`InvalidMessage` + * :exc:`InvalidMessage` (legacy) * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` @@ -31,9 +31,11 @@ from __future__ import annotations import http +import typing import warnings from . import datastructures, frames, http11 +from .imports import lazy_import from .typing import StatusLike @@ -175,13 +177,6 @@ class SecurityError(InvalidHandshake): """ -class InvalidMessage(InvalidHandshake): - """ - Raised when a handshake request or response is malformed. - - """ - - class InvalidHeader(InvalidHandshake): """ Raised when an HTTP header doesn't have a valid format or value. @@ -410,3 +405,15 @@ class ProtocolError(WebSocketException): WebSocketProtocolError = ProtocolError # for backwards compatibility + + +# When type checking, import non-deprecated aliases eagerly. Else, import on demand. +if typing.TYPE_CHECKING: + from .legacy.exceptions import InvalidMessage +else: + lazy_import( + globals(), + aliases={ + "InvalidMessage": ".legacy.exceptions", + }, + ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 256bee14c..bd56c5d16 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -21,7 +21,6 @@ from ..exceptions import ( InvalidHandshake, InvalidHeader, - InvalidMessage, InvalidStatusCode, NegotiationError, RedirectHandshake, @@ -41,6 +40,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri +from .exceptions import InvalidMessage from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py new file mode 100644 index 000000000..9ea173e1a --- /dev/null +++ b/src/websockets/legacy/exceptions.py @@ -0,0 +1,8 @@ +from ..exceptions import InvalidHandshake + + +class InvalidMessage(InvalidHandshake): + """ + Raised when a handshake request or response is malformed. + + """ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 43136db3e..1bf359f32 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -27,7 +27,6 @@ AbortHandshake, InvalidHandshake, InvalidHeader, - InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -43,6 +42,7 @@ from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol +from .exceptions import InvalidMessage from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol, broadcast diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py new file mode 100644 index 000000000..1850b3bf2 --- /dev/null +++ b/tests/legacy/test_exceptions.py @@ -0,0 +1,15 @@ +import unittest + +from websockets.legacy.exceptions import * + + +class ExceptionsTests(unittest.TestCase): + def test_str(self): + for exception, exception_str in [ + ( + InvalidMessage("malformed HTTP message"), + "malformed HTTP message", + ), + ]: + with self.subTest(exception=exception): + self.assertEqual(str(exception), exception_str) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 92ba7dda8..36a4d1724 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -87,10 +87,6 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), - ( - InvalidMessage("malformed HTTP message"), - "malformed HTTP message", - ), ( InvalidHeader("Name"), "missing Name header", From c02648e3ce9e96deec83f966dbb74fca1d9b2e0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 07:56:36 +0200 Subject: [PATCH 1360/1539] Move WebSocketProtocolError to the legacy package. This ensures that it will be deprecated and removed. --- src/websockets/exceptions.py | 6 +++--- src/websockets/legacy/exceptions.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b9b100150..1188ec432 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -404,16 +404,16 @@ class ProtocolError(WebSocketException): """ -WebSocketProtocolError = ProtocolError # for backwards compatibility - - # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: from .legacy.exceptions import InvalidMessage + + WebSocketProtocolError = ProtocolError else: lazy_import( globals(), aliases={ "InvalidMessage": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", }, ) diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 9ea173e1a..79bfa7f4d 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -1,4 +1,7 @@ -from ..exceptions import InvalidHandshake +from ..exceptions import ( + InvalidHandshake, + ProtocolError as WebSocketProtocolError, # noqa: F401 +) class InvalidMessage(InvalidHandshake): From 938f1075ebcada7e46e9fb4caacfb32839c51605 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 09:02:11 +0200 Subject: [PATCH 1361/1539] Move InvalidStatusCode to the legacy package. It is only used by the legacy implementation. --- src/websockets/__init__.py | 7 +++---- src/websockets/exceptions.py | 20 +++++--------------- src/websockets/legacy/client.py | 3 +-- src/websockets/legacy/exceptions.py | 15 +++++++++++++++ tests/legacy/test_exceptions.py | 5 +++++ tests/test_exceptions.py | 4 ---- 6 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 63b0a260b..09d7aac88 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -28,7 +28,6 @@ "InvalidParameterValue", "InvalidState", "InvalidStatus", - "InvalidStatusCode", "InvalidUpgrade", "InvalidURI", "NegotiationError", @@ -47,6 +46,7 @@ "unix_connect", # .legacy.exceptions "InvalidMessage", + "InvalidStatusCode", # .legacy.protocol "WebSocketCommonProtocol", # .legacy.server @@ -86,7 +86,6 @@ InvalidParameterValue, InvalidState, InvalidStatus, - InvalidStatusCode, InvalidUpgrade, InvalidURI, NegotiationError, @@ -102,7 +101,7 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect - from .legacy.exceptions import InvalidMessage + from .legacy.exceptions import InvalidMessage, InvalidStatusCode from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, @@ -146,7 +145,6 @@ "InvalidParameterValue": ".exceptions", "InvalidState": ".exceptions", "InvalidStatus": ".exceptions", - "InvalidStatusCode": ".exceptions", "InvalidUpgrade": ".exceptions", "InvalidURI": ".exceptions", "NegotiationError": ".exceptions", @@ -165,6 +163,7 @@ "unix_connect": ".legacy.client", # .legacy.exceptions "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", # .legacy.server diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 1188ec432..e35bf430e 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -250,20 +250,6 @@ def __str__(self) -> str: ) -class InvalidStatusCode(InvalidHandshake): - """ - Raised when a handshake response status code is invalid. - - """ - - def __init__(self, status_code: int, headers: datastructures.Headers) -> None: - self.status_code = status_code - self.headers = headers - - def __str__(self) -> str: - return f"server rejected WebSocket connection: HTTP {self.status_code}" - - class NegotiationError(InvalidHandshake): """ Raised when negotiating an extension fails. @@ -406,7 +392,10 @@ class ProtocolError(WebSocketException): # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: - from .legacy.exceptions import InvalidMessage + from .legacy.exceptions import ( + InvalidMessage, + InvalidStatusCode, + ) WebSocketProtocolError = ProtocolError else: @@ -414,6 +403,7 @@ class ProtocolError(WebSocketException): globals(), aliases={ "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", }, ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index bd56c5d16..fd0803e7d 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -21,7 +21,6 @@ from ..exceptions import ( InvalidHandshake, InvalidHeader, - InvalidStatusCode, NegotiationError, RedirectHandshake, SecurityError, @@ -40,7 +39,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .exceptions import InvalidMessage +from .exceptions import InvalidMessage, InvalidStatusCode from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 79bfa7f4d..d82676d5e 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -1,3 +1,4 @@ +from .. import datastructures from ..exceptions import ( InvalidHandshake, ProtocolError as WebSocketProtocolError, # noqa: F401 @@ -9,3 +10,17 @@ class InvalidMessage(InvalidHandshake): Raised when a handshake request or response is malformed. """ + + +class InvalidStatusCode(InvalidHandshake): + """ + Raised when a handshake response status code is invalid. + + """ + + def __init__(self, status_code: int, headers: datastructures.Headers) -> None: + self.status_code = status_code + self.headers = headers + + def __str__(self) -> str: + return f"server rejected WebSocket connection: HTTP {self.status_code}" diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index 1850b3bf2..8b9a616e9 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -1,5 +1,6 @@ import unittest +from websockets.datastructures import Headers from websockets.legacy.exceptions import * @@ -10,6 +11,10 @@ def test_str(self): InvalidMessage("malformed HTTP message"), "malformed HTTP message", ), + ( + InvalidStatusCode(403, Headers()), + "server rejected WebSocket connection: HTTP 403", + ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 36a4d1724..b79489e44 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -127,10 +127,6 @@ def test_str(self): InvalidStatus(Response(401, "Unauthorized", Headers())), "server rejected WebSocket connection: HTTP 401", ), - ( - InvalidStatusCode(403, Headers()), - "server rejected WebSocket connection: HTTP 403", - ), ( NegotiationError("unsupported subprotocol: spam"), "unsupported subprotocol: spam", From 30cbbc3adaf949c0801a2fd37f82d0bd1eaa0ffc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 09:09:41 +0200 Subject: [PATCH 1362/1539] Move AbortHandshake to the legacy package. It is only used by the legacy implementation. --- src/websockets/__init__.py | 11 +++++--- src/websockets/exceptions.py | 42 +++-------------------------- src/websockets/legacy/exceptions.py | 37 +++++++++++++++++++++++++ src/websockets/legacy/server.py | 3 +-- tests/legacy/test_exceptions.py | 4 +++ tests/test_exceptions.py | 4 --- 6 files changed, 53 insertions(+), 48 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 09d7aac88..a623d32c4 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -14,7 +14,6 @@ "HeadersLike", "MultipleValuesError", # .exceptions - "AbortHandshake", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", @@ -45,6 +44,7 @@ "connect", "unix_connect", # .legacy.exceptions + "AbortHandshake", "InvalidMessage", "InvalidStatusCode", # .legacy.protocol @@ -72,7 +72,6 @@ from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( - AbortHandshake, ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, @@ -101,7 +100,11 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect - from .legacy.exceptions import InvalidMessage, InvalidStatusCode + from .legacy.exceptions import ( + AbortHandshake, + InvalidMessage, + InvalidStatusCode, + ) from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, @@ -131,7 +134,6 @@ "HeadersLike": ".datastructures", "MultipleValuesError": ".datastructures", # .exceptions - "AbortHandshake": ".exceptions", "ConnectionClosed": ".exceptions", "ConnectionClosedError": ".exceptions", "ConnectionClosedOK": ".exceptions", @@ -162,6 +164,7 @@ "connect": ".legacy.client", "unix_connect": ".legacy.client", # .legacy.exceptions + "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", # .legacy.protocol diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index e35bf430e..803e3fa53 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -19,7 +19,7 @@ * :exc:`DuplicateParameter` * :exc:`InvalidParameterName` * :exc:`InvalidParameterValue` - * :exc:`AbortHandshake` + * :exc:`AbortHandshake` (legacy) * :exc:`RedirectHandshake` * :exc:`InvalidState` * :exc:`InvalidURI` @@ -30,13 +30,11 @@ from __future__ import annotations -import http import typing import warnings -from . import datastructures, frames, http11 +from . import frames, http11 from .imports import lazy_import -from .typing import StatusLike __all__ = [ @@ -302,40 +300,6 @@ def __str__(self) -> str: return f"invalid value for parameter {self.name}: {self.value}" -class AbortHandshake(InvalidHandshake): - """ - Raised to abort the handshake on purpose and return an HTTP response. - - This exception is an implementation detail. - - The public API is - :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. - - Attributes: - status (~http.HTTPStatus): HTTP status code. - headers (Headers): HTTP response headers. - body (bytes): HTTP response body. - """ - - def __init__( - self, - status: StatusLike, - headers: datastructures.HeadersLike, - body: bytes = b"", - ) -> None: - # If a user passes an int instead of a HTTPStatus, fix it automatically. - self.status = http.HTTPStatus(status) - self.headers = datastructures.Headers(headers) - self.body = body - - def __str__(self) -> str: - return ( - f"HTTP {self.status:d}, " - f"{len(self.headers)} headers, " - f"{len(self.body)} bytes" - ) - - class RedirectHandshake(InvalidHandshake): """ Raised when a handshake gets redirected. @@ -393,6 +357,7 @@ class ProtocolError(WebSocketException): # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: from .legacy.exceptions import ( + AbortHandshake, InvalidMessage, InvalidStatusCode, ) @@ -402,6 +367,7 @@ class ProtocolError(WebSocketException): lazy_import( globals(), aliases={ + "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index d82676d5e..d02a1f933 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -1,8 +1,11 @@ +import http + from .. import datastructures from ..exceptions import ( InvalidHandshake, ProtocolError as WebSocketProtocolError, # noqa: F401 ) +from ..typing import StatusLike class InvalidMessage(InvalidHandshake): @@ -24,3 +27,37 @@ def __init__(self, status_code: int, headers: datastructures.Headers) -> None: def __str__(self) -> str: return f"server rejected WebSocket connection: HTTP {self.status_code}" + + +class AbortHandshake(InvalidHandshake): + """ + Raised to abort the handshake on purpose and return an HTTP response. + + This exception is an implementation detail. + + The public API is + :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. + + Attributes: + status (~http.HTTPStatus): HTTP status code. + headers (Headers): HTTP response headers. + body (bytes): HTTP response body. + """ + + def __init__( + self, + status: StatusLike, + headers: datastructures.HeadersLike, + body: bytes = b"", + ) -> None: + # If a user passes an int instead of a HTTPStatus, fix it automatically. + self.status = http.HTTPStatus(status) + self.headers = datastructures.Headers(headers) + self.body = body + + def __str__(self) -> str: + return ( + f"HTTP {self.status:d}, " + f"{len(self.headers)} headers, " + f"{len(self.body)} bytes" + ) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 1bf359f32..a71996e45 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -24,7 +24,6 @@ from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( - AbortHandshake, InvalidHandshake, InvalidHeader, InvalidOrigin, @@ -42,7 +41,7 @@ from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .exceptions import InvalidMessage +from .exceptions import AbortHandshake, InvalidMessage from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol, broadcast diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index 8b9a616e9..1bab24c4f 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -15,6 +15,10 @@ def test_str(self): InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", ), + ( + AbortHandshake(200, Headers(), b"OK\n"), + "HTTP 200, 0 headers, 3 bytes", + ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b79489e44..b45642840 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -151,10 +151,6 @@ def test_str(self): InvalidParameterValue("a", "|"), "invalid value for parameter a: |", ), - ( - AbortHandshake(200, Headers(), b"OK\n"), - "HTTP 200, 0 headers, 3 bytes", - ), ( RedirectHandshake("wss://example.com"), "redirect to wss://example.com", From 5f31bcb5a52e4ec8f2b4fc52e9e9b1ffef356771 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 09:12:20 +0200 Subject: [PATCH 1363/1539] Move RedirectHandshake to the legacy package. It is only used by the legacy implementation. --- src/websockets/__init__.py | 6 +++--- src/websockets/exceptions.py | 19 +++---------------- src/websockets/legacy/client.py | 4 ++-- src/websockets/legacy/exceptions.py | 15 +++++++++++++++ tests/legacy/test_exceptions.py | 4 ++++ tests/test_exceptions.py | 4 ---- 6 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index a623d32c4..12141adb0 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -32,7 +32,6 @@ "NegotiationError", "PayloadTooBig", "ProtocolError", - "RedirectHandshake", "SecurityError", "WebSocketException", "WebSocketProtocolError", @@ -47,6 +46,7 @@ "AbortHandshake", "InvalidMessage", "InvalidStatusCode", + "RedirectHandshake", # .legacy.protocol "WebSocketCommonProtocol", # .legacy.server @@ -90,7 +90,6 @@ NegotiationError, PayloadTooBig, ProtocolError, - RedirectHandshake, SecurityError, WebSocketException, WebSocketProtocolError, @@ -104,6 +103,7 @@ AbortHandshake, InvalidMessage, InvalidStatusCode, + RedirectHandshake, ) from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( @@ -152,7 +152,6 @@ "NegotiationError": ".exceptions", "PayloadTooBig": ".exceptions", "ProtocolError": ".exceptions", - "RedirectHandshake": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", "WebSocketProtocolError": ".exceptions", @@ -167,6 +166,7 @@ "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", # .legacy.server diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 803e3fa53..b5aabf75d 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -20,7 +20,7 @@ * :exc:`InvalidParameterName` * :exc:`InvalidParameterValue` * :exc:`AbortHandshake` (legacy) - * :exc:`RedirectHandshake` + * :exc:`RedirectHandshake` (legacy) * :exc:`InvalidState` * :exc:`InvalidURI` * :exc:`PayloadTooBig` @@ -300,21 +300,6 @@ def __str__(self) -> str: return f"invalid value for parameter {self.name}: {self.value}" -class RedirectHandshake(InvalidHandshake): - """ - Raised when a handshake gets redirected. - - This exception is an implementation detail. - - """ - - def __init__(self, uri: str) -> None: - self.uri = uri - - def __str__(self) -> str: - return f"redirect to {self.uri}" - - class InvalidState(WebSocketException, AssertionError): """ Raised when an operation is forbidden in the current state. @@ -360,6 +345,7 @@ class ProtocolError(WebSocketException): AbortHandshake, InvalidMessage, InvalidStatusCode, + RedirectHandshake, ) WebSocketProtocolError = ProtocolError @@ -370,6 +356,7 @@ class ProtocolError(WebSocketException): "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", }, ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index fd0803e7d..057acb656 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -22,7 +22,7 @@ InvalidHandshake, InvalidHeader, NegotiationError, - RedirectHandshake, + SecurityError, ) from ..extensions import ClientExtensionFactory, Extension @@ -39,7 +39,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .exceptions import InvalidMessage, InvalidStatusCode +from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index d02a1f933..9ca9b7aff 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -61,3 +61,18 @@ def __str__(self) -> str: f"{len(self.headers)} headers, " f"{len(self.body)} bytes" ) + + +class RedirectHandshake(InvalidHandshake): + """ + Raised when a handshake gets redirected. + + This exception is an implementation detail. + + """ + + def __init__(self, uri: str) -> None: + self.uri = uri + + def __str__(self) -> str: + return f"redirect to {self.uri}" diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index 1bab24c4f..e5d22a917 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -19,6 +19,10 @@ def test_str(self): AbortHandshake(200, Headers(), b"OK\n"), "HTTP 200, 0 headers, 3 bytes", ), + ( + RedirectHandshake("wss://example.com"), + "redirect to wss://example.com", + ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b45642840..b54903275 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -151,10 +151,6 @@ def test_str(self): InvalidParameterValue("a", "|"), "invalid value for parameter a: |", ), - ( - RedirectHandshake("wss://example.com"), - "redirect to wss://example.com", - ), ( InvalidState("WebSocket connection isn't established yet"), "WebSocket connection isn't established yet", From 58d72bc995b3e254b5804b36860d948a9c2cf38f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 09:21:39 +0200 Subject: [PATCH 1364/1539] Document legacy exceptions in their own section. This change has the side effect of dropping the WebSocketProtocolError exception from the list. It was documented as an alias of ProtocolError as a side effect of using the automodule directive. Eventually it will get deprecated with the rest of the legacy module so potential users will be aware. --- docs/reference/exceptions.rst | 53 ++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 907a650d2..062e69df1 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -2,5 +2,56 @@ Exceptions ========== .. automodule:: websockets.exceptions - :members: +.. autoexception:: WebSocketException + +.. autoexception:: ConnectionClosed + +.. autoexception:: ConnectionClosedError + +.. autoexception:: ConnectionClosedOK + +.. autoexception:: InvalidHandshake + +.. autoexception:: SecurityError + +.. autoexception:: InvalidHeader + +.. autoexception:: InvalidHeaderFormat + +.. autoexception:: InvalidHeaderValue + +.. autoexception:: InvalidOrigin + +.. autoexception:: InvalidUpgrade + +.. autoexception:: InvalidStatus + +.. autoexception:: NegotiationError + +.. autoexception:: DuplicateParameter + +.. autoexception:: InvalidParameterName + +.. autoexception:: InvalidParameterValue + +.. autoexception:: InvalidState + +.. autoexception:: InvalidURI + +.. autoexception:: PayloadTooBig + +.. autoexception:: ProtocolError + +Legacy exceptions +----------------- + +These exceptions are only used by the legacy :mod:`asyncio` implementation. + +.. autoexception:: InvalidMessage + +.. autoexception:: InvalidStatusCode + +.. autoexception:: AbortHandshake + +.. autoexception:: RedirectHandshake From 60fc23788f8e7d20e97780a73573e68f5d29db68 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 14:22:35 +0200 Subject: [PATCH 1365/1539] Always raise subclasses of InvalidHandshake. Also shorten messages for InvalidHeader exceptions. --- src/websockets/client.py | 15 +++++++-------- src/websockets/datastructures.py | 2 +- src/websockets/legacy/client.py | 13 +++++++------ src/websockets/legacy/handshake.py | 12 +++--------- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 11 +++-------- tests/test_client.py | 7 ++++--- tests/test_server.py | 8 +++----- 8 files changed, 29 insertions(+), 41 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 07d1d34ed..c408c1447 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -176,10 +176,7 @@ def process_response(self, response: Response) -> None: except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Accept") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Accept", - "more than one Sec-WebSocket-Accept header found", - ) from exc + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) @@ -225,7 +222,7 @@ def process_extensions(self, headers: Headers) -> list[Extension]: if extensions: if self.available_extensions is None: - raise InvalidHandshake("no extensions supported") + raise NegotiationError("no extensions supported") parsed_extensions: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in extensions], [] @@ -280,15 +277,17 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None: if subprotocols: if self.available_subprotocols is None: - raise InvalidHandshake("no subprotocols supported") + raise NegotiationError("no subprotocols supported") parsed_subprotocols: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in subprotocols], [] ) if len(parsed_subprotocols) > 1: - subprotocols_display = ", ".join(parsed_subprotocols) - raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}") + raise InvalidHeader( + "Sec-WebSocket-Protocol", + f"multiple values: {', '.join(parsed_subprotocols)}", + ) subprotocol = parsed_subprotocols[0] diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 3d64d951e..106d6f393 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -17,7 +17,7 @@ class MultipleValuesError(LookupError): """ - Exception raised when :class:`Headers` has more than one value for a key. + Exception raised when :class:`Headers` has multiple values for a key. """ diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 057acb656..25142ea25 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -19,10 +19,9 @@ from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike from ..exceptions import ( - InvalidHandshake, InvalidHeader, + InvalidHeaderValue, NegotiationError, - SecurityError, ) from ..extensions import ClientExtensionFactory, Extension @@ -181,7 +180,7 @@ def process_extensions( if header_values: if available_extensions is None: - raise InvalidHandshake("no extensions supported") + raise NegotiationError("no extensions supported") parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] @@ -235,15 +234,17 @@ def process_subprotocol( if header_values: if available_subprotocols is None: - raise InvalidHandshake("no subprotocols supported") + raise NegotiationError("no subprotocols supported") parsed_header_values: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) if len(parsed_header_values) > 1: - subprotocols = ", ".join(parsed_header_values) - raise InvalidHandshake(f"multiple subprotocols: {subprotocols}") + raise InvalidHeaderValue( + "Sec-WebSocket-Protocol", + f"multiple values: {', '.join(parsed_header_values)}", + ) subprotocol = parsed_header_values[0] diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 2a39c1b03..6a7157c01 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -76,9 +76,7 @@ def check_request(headers: Headers) -> str: except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Key") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" - ) from exc + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc try: raw_key = base64.b64decode(s_w_key.encode(), validate=True) @@ -92,9 +90,7 @@ def check_request(headers: Headers) -> str: except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Version") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found" - ) from exc + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc if s_w_version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) @@ -156,9 +152,7 @@ def check_response(headers: Headers, key: str) -> None: except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Accept") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found" - ) from exc + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc if s_w_accept != accept(key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index a71996e45..b31cc25b8 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -397,7 +397,7 @@ def process_origin( try: origin = headers.get("Origin") except MultipleValuesError as exc: - raise InvalidHeader("Origin", "more than one Origin header found") from exc + raise InvalidHeader("Origin", "multiple values") from exc if origin is not None: origin = cast(Origin, origin) if origins is not None: diff --git a/src/websockets/server.py b/src/websockets/server.py index 11ba8b425..3a03378b2 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -257,9 +257,7 @@ def process_request( except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Key") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" - ) from exc + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc try: raw_key = base64.b64decode(key.encode(), validate=True) @@ -273,10 +271,7 @@ def process_request( except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Version") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Version", - "more than one Sec-WebSocket-Version header found", - ) from exc + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) @@ -315,7 +310,7 @@ def process_origin(self, headers: Headers) -> Origin | None: try: origin = headers.get("Origin") except MultipleValuesError as exc: - raise InvalidHeader("Origin", "more than one Origin header found") from exc + raise InvalidHeader("Origin", "multiple values") from exc if origin is not None: origin = cast(Origin, origin) if self.origins is not None: diff --git a/tests/test_client.py b/tests/test_client.py index c83c87038..d798a66f9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -342,8 +342,7 @@ def test_multiple_accept(self): raise client.handshake_exc self.assertEqual( str(raised.exception), - "invalid Sec-WebSocket-Accept header: " - "more than one Sec-WebSocket-Accept header found", + "invalid Sec-WebSocket-Accept header: multiple values", ) def test_invalid_accept(self): @@ -556,7 +555,9 @@ def test_multiple_subprotocols(self): with self.assertRaises(InvalidHandshake) as raised: raise client.handshake_exc self.assertEqual( - str(raised.exception), "multiple subprotocols: superchat, chat" + str(raised.exception), + "invalid Sec-WebSocket-Protocol header: " + "multiple values: superchat, chat", ) def test_supported_subprotocol(self): diff --git a/tests/test_server.py b/tests/test_server.py index e4460dcba..e7d249f49 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -305,8 +305,7 @@ def test_multiple_key(self): raise server.handshake_exc self.assertEqual( str(raised.exception), - "invalid Sec-WebSocket-Key header: " - "more than one Sec-WebSocket-Key header found", + "invalid Sec-WebSocket-Key header: multiple values", ) def test_invalid_key(self): @@ -366,8 +365,7 @@ def test_multiple_version(self): raise server.handshake_exc self.assertEqual( str(raised.exception), - "invalid Sec-WebSocket-Version header: " - "more than one Sec-WebSocket-Version header found", + "invalid Sec-WebSocket-Version header: multiple values", ) def test_invalid_version(self): @@ -437,7 +435,7 @@ def test_multiple_origin(self): raise server.handshake_exc self.assertEqual( str(raised.exception), - "invalid Origin header: more than one Origin header found", + "invalid Origin header: multiple values", ) def test_supported_origin(self): From 4e17142e81d6d832061fd69694875e0d962c54ed Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 14:10:18 +0200 Subject: [PATCH 1366/1539] Improve documentation of exceptions. Group and order them. Extend docstrings. --- docs/project/changelog.rst | 2 +- docs/reference/exceptions.rst | 33 +++++++-- src/websockets/exceptions.py | 132 +++++++++++++++++++--------------- tests/test_exceptions.py | 24 +++---- 4 files changed, 114 insertions(+), 77 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1c87882f0..e82f61753 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -613,7 +613,7 @@ Bug fixes * Aligned maximum cookie size with popular web browsers. * Ensured cancellation always propagates, even on Python versions where - :exc:`~asyncio.CancelledError` inherits :exc:`Exception`. + :exc:`~asyncio.CancelledError` inherits from :exc:`Exception`. .. _8.1: diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 062e69df1..14a8edcd1 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -5,16 +5,35 @@ Exceptions .. autoexception:: WebSocketException +Connection closed +----------------- + +:meth:`~websockets.asyncio.connection.Connection.recv`, +:meth:`~websockets.asyncio.connection.Connection.send`, and similar methods +raise the exceptions below when the connection is closed. This is the expected +way to detect disconnections. + .. autoexception:: ConnectionClosed +.. autoexception:: ConnectionClosedOK + .. autoexception:: ConnectionClosedError -.. autoexception:: ConnectionClosedOK +Connection failed +----------------- + +These exceptions are raised by :func:`~websockets.asyncio.client.connect` when +the opening handshake fails and the connection cannot be established. They are +also reported by :func:`~websockets.asyncio.server.serve` in logs. + +.. autoexception:: InvalidURI .. autoexception:: InvalidHandshake .. autoexception:: SecurityError +.. autoexception:: InvalidStatus + .. autoexception:: InvalidHeader .. autoexception:: InvalidHeaderFormat @@ -25,8 +44,6 @@ Exceptions .. autoexception:: InvalidUpgrade -.. autoexception:: InvalidStatus - .. autoexception:: NegotiationError .. autoexception:: DuplicateParameter @@ -35,13 +52,17 @@ Exceptions .. autoexception:: InvalidParameterValue -.. autoexception:: InvalidState +Sans-I/O exceptions +------------------- -.. autoexception:: InvalidURI +These exceptions are only raised by the Sans-I/O implementation. They are +translated to :exc:`ConnectionClosedError` in the other implementations. + +.. autoexception:: ProtocolError .. autoexception:: PayloadTooBig -.. autoexception:: ProtocolError +.. autoexception:: InvalidState Legacy exceptions ----------------- diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b5aabf75d..8d998bfdb 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -1,30 +1,30 @@ """ -:mod:`websockets.exceptions` defines the following exception hierarchy: +:mod:`websockets.exceptions` defines the following hierarchy of exceptions. * :exc:`WebSocketException` * :exc:`ConnectionClosed` - * :exc:`ConnectionClosedError` * :exc:`ConnectionClosedOK` + * :exc:`ConnectionClosedError` + * :exc:`InvalidURI` * :exc:`InvalidHandshake` * :exc:`SecurityError` * :exc:`InvalidMessage` (legacy) + * :exc:`InvalidStatus` + * :exc:`InvalidStatusCode` (legacy) * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` * :exc:`InvalidOrigin` * :exc:`InvalidUpgrade` - * :exc:`InvalidStatus` - * :exc:`InvalidStatusCode` (legacy) * :exc:`NegotiationError` * :exc:`DuplicateParameter` * :exc:`InvalidParameterName` * :exc:`InvalidParameterValue` * :exc:`AbortHandshake` (legacy) * :exc:`RedirectHandshake` (legacy) - * :exc:`InvalidState` - * :exc:`InvalidURI` - * :exc:`PayloadTooBig` - * :exc:`ProtocolError` + * :exc:`ProtocolError` (Sans-I/O) + * :exc:`PayloadTooBig` (Sans-I/O) + * :exc:`InvalidState` (Sans-I/O) """ @@ -40,29 +40,29 @@ __all__ = [ "WebSocketException", "ConnectionClosed", - "ConnectionClosedError", "ConnectionClosedOK", + "ConnectionClosedError", + "InvalidURI", "InvalidHandshake", "SecurityError", "InvalidMessage", + "InvalidStatus", + "InvalidStatusCode", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", "InvalidOrigin", "InvalidUpgrade", - "InvalidStatus", - "InvalidStatusCode", "NegotiationError", "DuplicateParameter", "InvalidParameterName", "InvalidParameterValue", "AbortHandshake", "RedirectHandshake", - "InvalidState", - "InvalidURI", - "PayloadTooBig", "ProtocolError", "WebSocketProtocolError", + "PayloadTooBig", + "InvalidState", ] @@ -139,6 +139,16 @@ def reason(self) -> str: return self.rcvd.reason +class ConnectionClosedOK(ConnectionClosed): + """ + Like :exc:`ConnectionClosed`, when the connection terminated properly. + + A close code with code 1000 (OK) or 1001 (going away) or without a code was + received and sent. + + """ + + class ConnectionClosedError(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated with an error. @@ -149,19 +159,23 @@ class ConnectionClosedError(ConnectionClosed): """ -class ConnectionClosedOK(ConnectionClosed): +class InvalidURI(WebSocketException): """ - Like :exc:`ConnectionClosed`, when the connection terminated properly. - - A close code with code 1000 (OK) or 1001 (going away) or without a code was - received and sent. + Raised when connecting to a URI that isn't a valid WebSocket URI. """ + def __init__(self, uri: str, msg: str) -> None: + self.uri = uri + self.msg = msg + + def __str__(self) -> str: + return f"{self.uri} isn't a valid URI: {self.msg}" + class InvalidHandshake(WebSocketException): """ - Raised during the handshake when the WebSocket connection fails. + Base class for exceptions raised when the opening handshake fails. """ @@ -170,10 +184,27 @@ class SecurityError(InvalidHandshake): """ Raised when a handshake request or response breaks a security rule. - Security limits are hard coded. + Security limits can be configured with :doc:`environment variables + <../reference/variables>`. + + """ + + +class InvalidStatus(InvalidHandshake): + """ + Raised when a handshake response rejects the WebSocket upgrade. """ + def __init__(self, response: http11.Response) -> None: + self.response = response + + def __str__(self) -> str: + return ( + "server rejected WebSocket connection: " + f"HTTP {self.response.status_code:d}" + ) + class InvalidHeader(InvalidHandshake): """ @@ -210,7 +241,7 @@ class InvalidHeaderValue(InvalidHeader): """ Raised when an HTTP header has a wrong value. - The format of the header is correct but a value isn't acceptable. + The format of the header is correct but the value isn't acceptable. """ @@ -232,25 +263,9 @@ class InvalidUpgrade(InvalidHeader): """ -class InvalidStatus(InvalidHandshake): - """ - Raised when a handshake response rejects the WebSocket upgrade. - - """ - - def __init__(self, response: http11.Response) -> None: - self.response = response - - def __str__(self) -> str: - return ( - "server rejected WebSocket connection: " - f"HTTP {self.response.status_code:d}" - ) - - class NegotiationError(InvalidHandshake): """ - Raised when negotiating an extension fails. + Raised when negotiating an extension or a subprotocol fails. """ @@ -300,41 +315,42 @@ def __str__(self) -> str: return f"invalid value for parameter {self.name}: {self.value}" -class InvalidState(WebSocketException, AssertionError): +class ProtocolError(WebSocketException): """ - Raised when an operation is forbidden in the current state. + Raised when receiving or sending a frame that breaks the protocol. - This exception is an implementation detail. + The Sans-I/O implementation raises this exception when: - It should never be raised in normal circumstances. + * receiving or sending a frame that contains invalid data; + * receiving or sending an invalid sequence of frames. """ -class InvalidURI(WebSocketException): +class PayloadTooBig(WebSocketException): """ - Raised when connecting to a URI that isn't a valid WebSocket URI. + Raised when parsing a frame with a payload that exceeds the maximum size. - """ - - def __init__(self, uri: str, msg: str) -> None: - self.uri = uri - self.msg = msg - - def __str__(self) -> str: - return f"{self.uri} isn't a valid URI: {self.msg}" + The Sans-I/O layer uses this exception internally. It doesn't bubble up to + the I/O layer. + The :meth:`~websockets.extensions.Extension.decode` method of extensions + must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit. -class PayloadTooBig(WebSocketException): """ - Raised when receiving a frame with a payload exceeding the maximum size. + +class InvalidState(WebSocketException, AssertionError): """ + Raised when sending a frame is forbidden in the current state. + Specifically, the Sans-I/O layer raises this exception when: -class ProtocolError(WebSocketException): - """ - Raised when a frame breaks the protocol. + * sending a data frame to a connection in a state other + :attr:`~websockets.protocol.State.OPEN`; + * sending a control frame to a connection in a state other than + :attr:`~websockets.protocol.State.OPEN` or + :attr:`~websockets.protocol.State.CLOSING`. """ diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b54903275..5620b8a53 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -79,6 +79,10 @@ def test_str(self): ), "no close frame received or sent", ), + ( + InvalidURI("|", "not at all!"), + "| isn't a valid URI: not at all!", + ), ( InvalidHandshake("invalid request"), "invalid request", @@ -87,6 +91,10 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), + ( + InvalidStatus(Response(401, "Unauthorized", Headers())), + "server rejected WebSocket connection: HTTP 401", + ), ( InvalidHeader("Name"), "missing Name header", @@ -123,10 +131,6 @@ def test_str(self): InvalidUpgrade("Connection", "websocket"), "invalid Connection header: websocket", ), - ( - InvalidStatus(Response(401, "Unauthorized", Headers())), - "server rejected WebSocket connection: HTTP 401", - ), ( NegotiationError("unsupported subprotocol: spam"), "unsupported subprotocol: spam", @@ -152,20 +156,16 @@ def test_str(self): "invalid value for parameter a: |", ), ( - InvalidState("WebSocket connection isn't established yet"), - "WebSocket connection isn't established yet", - ), - ( - InvalidURI("|", "not at all!"), - "| isn't a valid URI: not at all!", + ProtocolError("invalid opcode: 7"), + "invalid opcode: 7", ), ( PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), "payload length exceeds limit: 2 > 1 bytes", ), ( - ProtocolError("invalid opcode: 7"), - "invalid opcode: 7", + InvalidState("WebSocket connection isn't established yet"), + "WebSocket connection isn't established yet", ), ]: with self.subTest(exception=exception): From 2a9dfb593eb7dd17424242cf344945af0510aa3f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 14:40:58 +0200 Subject: [PATCH 1367/1539] Simplify untangling of import cycles Three imports at the bottom, related to type annotations, is a lesser evil compared to dozens of module-prefixed identifiers, a departure from the coding style of this library. Refs #989. --- src/websockets/exceptions.py | 4 +- src/websockets/extensions/base.py | 11 ++--- .../extensions/permessage_deflate.py | 48 +++++++++++-------- src/websockets/frames.py | 24 +++++----- src/websockets/headers.py | 32 +++++-------- src/websockets/http11.py | 23 ++++----- src/websockets/uri.py | 10 ++-- 7 files changed, 74 insertions(+), 78 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 8d998bfdb..b2b679e6b 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -33,7 +33,6 @@ import typing import warnings -from . import frames, http11 from .imports import lazy_import @@ -376,3 +375,6 @@ class InvalidState(WebSocketException, AssertionError): "WebSocketProtocolError": ".legacy.exceptions", }, ) + +# At the bottom to break import cycles created by type annotations. +from . import frames, http11 # noqa: E402 diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index a6c76c3d4..75bae6b77 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -2,7 +2,7 @@ from typing import Sequence -from .. import frames +from ..frames import Frame from ..typing import ExtensionName, ExtensionParameter @@ -18,12 +18,7 @@ class Extension: name: ExtensionName """Extension identifier.""" - def decode( - self, - frame: frames.Frame, - *, - max_size: int | None = None, - ) -> frames.Frame: + def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame: """ Decode an incoming frame. @@ -40,7 +35,7 @@ def decode( """ raise NotImplementedError - def encode(self, frame: frames.Frame) -> frames.Frame: + def encode(self, frame: Frame) -> Frame: """ Encode an outgoing frame. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 5b907b79f..25d2c1c45 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -4,7 +4,15 @@ import zlib from typing import Any, Sequence -from .. import exceptions, frames +from .. import frames +from ..exceptions import ( + DuplicateParameter, + InvalidParameterName, + InvalidParameterValue, + NegotiationError, + PayloadTooBig, + ProtocolError, +) from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -129,9 +137,9 @@ def decode( try: data = self.decoder.decompress(data, max_length) except zlib.error as exc: - raise exceptions.ProtocolError("decompression failed") from exc + raise ProtocolError("decompression failed") from exc if self.decoder.unconsumed_tail: - raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)") + raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -215,40 +223,40 @@ def _extract_parameters( for name, value in params: if name == "server_no_context_takeover": if server_no_context_takeover: - raise exceptions.DuplicateParameter(name) + raise DuplicateParameter(name) if value is None: server_no_context_takeover = True else: - raise exceptions.InvalidParameterValue(name, value) + raise InvalidParameterValue(name, value) elif name == "client_no_context_takeover": if client_no_context_takeover: - raise exceptions.DuplicateParameter(name) + raise DuplicateParameter(name) if value is None: client_no_context_takeover = True else: - raise exceptions.InvalidParameterValue(name, value) + raise InvalidParameterValue(name, value) elif name == "server_max_window_bits": if server_max_window_bits is not None: - raise exceptions.DuplicateParameter(name) + raise DuplicateParameter(name) if value in _MAX_WINDOW_BITS_VALUES: server_max_window_bits = int(value) else: - raise exceptions.InvalidParameterValue(name, value) + raise InvalidParameterValue(name, value) elif name == "client_max_window_bits": if client_max_window_bits is not None: - raise exceptions.DuplicateParameter(name) + raise DuplicateParameter(name) if is_server and value is None: # only in handshake requests client_max_window_bits = True elif value in _MAX_WINDOW_BITS_VALUES: client_max_window_bits = int(value) else: - raise exceptions.InvalidParameterValue(name, value) + raise InvalidParameterValue(name, value) else: - raise exceptions.InvalidParameterName(name) + raise InvalidParameterName(name) return ( server_no_context_takeover, @@ -340,7 +348,7 @@ def process_response_params( """ if any(other.name == self.name for other in accepted_extensions): - raise exceptions.NegotiationError(f"received duplicate {self.name}") + raise NegotiationError(f"received duplicate {self.name}") # Request parameters are available in instance variables. @@ -366,7 +374,7 @@ def process_response_params( if self.server_no_context_takeover: if not server_no_context_takeover: - raise exceptions.NegotiationError("expected server_no_context_takeover") + raise NegotiationError("expected server_no_context_takeover") # client_no_context_takeover # @@ -396,9 +404,9 @@ def process_response_params( else: if server_max_window_bits is None: - raise exceptions.NegotiationError("expected server_max_window_bits") + raise NegotiationError("expected server_max_window_bits") elif server_max_window_bits > self.server_max_window_bits: - raise exceptions.NegotiationError("unsupported server_max_window_bits") + raise NegotiationError("unsupported server_max_window_bits") # client_max_window_bits @@ -414,7 +422,7 @@ def process_response_params( if self.client_max_window_bits is None: if client_max_window_bits is not None: - raise exceptions.NegotiationError("unexpected client_max_window_bits") + raise NegotiationError("unexpected client_max_window_bits") elif self.client_max_window_bits is True: pass @@ -423,7 +431,7 @@ def process_response_params( if client_max_window_bits is None: client_max_window_bits = self.client_max_window_bits elif client_max_window_bits > self.client_max_window_bits: - raise exceptions.NegotiationError("unsupported client_max_window_bits") + raise NegotiationError("unsupported client_max_window_bits") return PerMessageDeflate( server_no_context_takeover, # remote_no_context_takeover @@ -534,7 +542,7 @@ def process_request_params( """ if any(other.name == self.name for other in accepted_extensions): - raise exceptions.NegotiationError(f"skipped duplicate {self.name}") + raise NegotiationError(f"skipped duplicate {self.name}") # Load request parameters in local variables. ( @@ -613,7 +621,7 @@ def process_request_params( else: if client_max_window_bits is None: if self.require_client_max_window_bits: - raise exceptions.NegotiationError("required client_max_window_bits") + raise NegotiationError("required client_max_window_bits") elif client_max_window_bits is True: client_max_window_bits = self.client_max_window_bits elif self.client_max_window_bits < client_max_window_bits: diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 8e44dd3a2..a63bdc3b6 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -8,7 +8,7 @@ import struct from typing import Callable, Generator, Sequence -from . import exceptions, extensions +from .exceptions import PayloadTooBig, ProtocolError try: @@ -239,10 +239,10 @@ def parse( try: opcode = Opcode(head1 & 0b00001111) except ValueError as exc: - raise exceptions.ProtocolError("invalid opcode") from exc + raise ProtocolError("invalid opcode") from exc if (True if head2 & 0b10000000 else False) != mask: - raise exceptions.ProtocolError("incorrect masking") + raise ProtocolError("incorrect masking") length = head2 & 0b01111111 if length == 126: @@ -252,9 +252,7 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise exceptions.PayloadTooBig( - f"over size limit ({length} > {max_size} bytes)" - ) + raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") if mask: mask_bytes = yield from read_exact(4) @@ -342,13 +340,13 @@ def check(self) -> None: """ if self.rsv1 or self.rsv2 or self.rsv3: - raise exceptions.ProtocolError("reserved bits must be 0") + raise ProtocolError("reserved bits must be 0") if self.opcode in CTRL_OPCODES: if len(self.data) > 125: - raise exceptions.ProtocolError("control frame too long") + raise ProtocolError("control frame too long") if not self.fin: - raise exceptions.ProtocolError("fragmented control frame") + raise ProtocolError("fragmented control frame") @dataclasses.dataclass @@ -405,7 +403,7 @@ def parse(cls, data: bytes) -> Close: elif len(data) == 0: return cls(CloseCode.NO_STATUS_RCVD, "") else: - raise exceptions.ProtocolError("close frame too short") + raise ProtocolError("close frame too short") def serialize(self) -> bytes: """ @@ -424,4 +422,8 @@ def check(self) -> None: """ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): - raise exceptions.ProtocolError("invalid status code") + raise ProtocolError("invalid status code") + + +# At the bottom to break import cycles created by type annotations. +from . import extensions # noqa: E402 diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 0ffd65233..9103018a0 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -6,7 +6,7 @@ import re from typing import Callable, Sequence, TypeVar, cast -from . import exceptions +from .exceptions import InvalidHeaderFormat, InvalidHeaderValue from .typing import ( ConnectionOption, ExtensionHeader, @@ -108,7 +108,7 @@ def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]: """ match = _token_re.match(header, pos) if match is None: - raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos) + raise InvalidHeaderFormat(header_name, "expected token", header, pos) return match.group(), match.end() @@ -132,9 +132,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, i """ match = _quoted_string_re.match(header, pos) if match is None: - raise exceptions.InvalidHeaderFormat( - header_name, "expected quoted string", header, pos - ) + raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos) return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() @@ -206,9 +204,7 @@ def parse_list( if peek_ahead(header, pos) == ",": pos = parse_OWS(header, pos + 1) else: - raise exceptions.InvalidHeaderFormat( - header_name, "expected comma", header, pos - ) + raise InvalidHeaderFormat(header_name, "expected comma", header, pos) # Remove extra delimiters before the next item. while peek_ahead(header, pos) == ",": @@ -276,9 +272,7 @@ def parse_upgrade_protocol( """ match = _protocol_re.match(header, pos) if match is None: - raise exceptions.InvalidHeaderFormat( - header_name, "expected protocol", header, pos - ) + raise InvalidHeaderFormat(header_name, "expected protocol", header, pos) return cast(UpgradeProtocol, match.group()), match.end() @@ -324,7 +318,7 @@ def parse_extension_item_param( # the value after quoted-string unescaping MUST conform to # the 'token' ABNF. if _token_re.fullmatch(value) is None: - raise exceptions.InvalidHeaderFormat( + raise InvalidHeaderFormat( header_name, "invalid quoted header content", header, pos_before ) else: @@ -510,9 +504,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]: """ match = _token68_re.match(header, pos) if match is None: - raise exceptions.InvalidHeaderFormat( - header_name, "expected token68", header, pos - ) + raise InvalidHeaderFormat(header_name, "expected token68", header, pos) return match.group(), match.end() @@ -522,7 +514,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None: """ if pos < len(header): - raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos) + raise InvalidHeaderFormat(header_name, "trailing data", header, pos) def parse_authorization_basic(header: str) -> tuple[str, str]: @@ -543,12 +535,12 @@ def parse_authorization_basic(header: str) -> tuple[str, str]: # https://datatracker.ietf.org/doc/html/rfc7617#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": - raise exceptions.InvalidHeaderValue( + raise InvalidHeaderValue( "Authorization", f"unsupported scheme: {scheme}", ) if peek_ahead(header, pos) != " ": - raise exceptions.InvalidHeaderFormat( + raise InvalidHeaderFormat( "Authorization", "expected space after scheme", header, pos ) pos += 1 @@ -558,14 +550,14 @@ def parse_authorization_basic(header: str) -> tuple[str, str]: try: user_pass = base64.b64decode(basic_credentials.encode()).decode() except binascii.Error: - raise exceptions.InvalidHeaderValue( + raise InvalidHeaderValue( "Authorization", "expected base64-encoded credentials", ) from None try: username, password = user_pass.split(":", 1) except ValueError: - raise exceptions.InvalidHeaderValue( + raise InvalidHeaderValue( "Authorization", "expected username:password credentials", ) from None diff --git a/src/websockets/http11.py b/src/websockets/http11.py index b86c6ca4a..562bcb72c 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -7,7 +7,8 @@ import warnings from typing import Callable, Generator -from . import datastructures, exceptions +from .datastructures import Headers +from .exceptions import SecurityError from .version import version as websockets_version @@ -79,7 +80,7 @@ class Request: """ path: str - headers: datastructures.Headers + headers: Headers # body isn't useful is the context of this library. _exception: Exception | None = None @@ -183,7 +184,7 @@ class Response: status_code: int reason_phrase: str - headers: datastructures.Headers + headers: Headers body: bytes | None = None _exception: Exception | None = None @@ -280,13 +281,9 @@ def parse( try: body = yield from read_to_eof(MAX_BODY_SIZE) except RuntimeError: - raise exceptions.SecurityError( - f"body too large: over {MAX_BODY_SIZE} bytes" - ) + raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") elif content_length > MAX_BODY_SIZE: - raise exceptions.SecurityError( - f"body too large: {content_length} bytes" - ) + raise SecurityError(f"body too large: {content_length} bytes") else: body = yield from read_exact(content_length) @@ -308,7 +305,7 @@ def serialize(self) -> bytes: def parse_headers( read_line: Callable[[int], Generator[None, None, bytes]], -) -> Generator[None, None, datastructures.Headers]: +) -> Generator[None, None, Headers]: """ Parse HTTP headers. @@ -328,7 +325,7 @@ def parse_headers( # We don't attempt to support obsolete line folding. - headers = datastructures.Headers() + headers = Headers() for _ in range(MAX_NUM_HEADERS + 1): try: line = yield from parse_line(read_line) @@ -352,7 +349,7 @@ def parse_headers( headers[name] = value else: - raise exceptions.SecurityError("too many HTTP headers") + raise SecurityError("too many HTTP headers") return headers @@ -377,7 +374,7 @@ def parse_line( try: line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: - raise exceptions.SecurityError("line too long") + raise SecurityError("line too long") # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 82b35f92a..16bb3f1c1 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -3,7 +3,7 @@ import dataclasses import urllib.parse -from . import exceptions +from .exceptions import InvalidURI __all__ = ["parse_uri", "WebSocketURI"] @@ -73,11 +73,11 @@ def parse_uri(uri: str) -> WebSocketURI: """ parsed = urllib.parse.urlparse(uri) if parsed.scheme not in ["ws", "wss"]: - raise exceptions.InvalidURI(uri, "scheme isn't ws or wss") + raise InvalidURI(uri, "scheme isn't ws or wss") if parsed.hostname is None: - raise exceptions.InvalidURI(uri, "hostname isn't provided") + raise InvalidURI(uri, "hostname isn't provided") if parsed.fragment != "": - raise exceptions.InvalidURI(uri, "fragment identifier is meaningless") + raise InvalidURI(uri, "fragment identifier is meaningless") secure = parsed.scheme == "wss" host = parsed.hostname @@ -89,7 +89,7 @@ def parse_uri(uri: str) -> WebSocketURI: # urllib.parse.urlparse accepts URLs with a username but without a # password. This doesn't make sense for HTTP Basic Auth credentials. if username is not None and password is None: - raise exceptions.InvalidURI(uri, "username provided without password") + raise InvalidURI(uri, "username provided without password") try: uri.encode("ascii") From ed7d392f58bc08ead4556c874d9d980b11874d2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 21:56:13 +0200 Subject: [PATCH 1368/1539] Note when each deprecation occurred. The legacy package was ignored because it will be deprecated and removed wholesale. --- docs/project/changelog.rst | 11 +++++++++++ src/websockets/__init__.py | 1 + src/websockets/client.py | 2 +- src/websockets/connection.py | 2 +- src/websockets/http.py | 2 +- src/websockets/http11.py | 4 ++-- src/websockets/server.py | 2 +- src/websockets/sync/client.py | 5 ++++- src/websockets/sync/server.py | 7 +++++-- 9 files changed, 27 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index e82f61753..c239cfe5d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,17 @@ notice. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: The ``code`` and ``reason`` attributes of + :exc:`~exceptions.ConnectionClosed` are deprecated. + :class: note + + They were removed from the documentation in version 10.0, due to their + spec-compliant but counter-intuitive behavior, but they were kept in + the code for backwards compatibility. They're now formally deprecated. + New features ............ diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 12141adb0..ac02a9f7e 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -187,6 +187,7 @@ "Subprotocol": ".typing", }, deprecated_aliases={ + # deprecated in 9.0 - 2021-09-01 "framing": ".legacy", "handshake": ".legacy", "parse_uri": ".uri", diff --git a/src/websockets/client.py b/src/websockets/client.py index c408c1447..ae467993a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -352,7 +352,7 @@ def parse(self) -> Generator[None, None, None]: class ClientConnection(ClientProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: - warnings.warn( + warnings.warn( # deprecated in 11.0 - 2023-04-02 "ClientConnection was renamed to ClientProtocol", DeprecationWarning, ) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 7942c1a28..5e78e3447 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -5,7 +5,7 @@ from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 -warnings.warn( +warnings.warn( # deprecated in 11.0 - 2023-04-02 "websockets.connection was renamed to websockets.protocol " "and Connection was renamed to Protocol", DeprecationWarning, diff --git a/src/websockets/http.py b/src/websockets/http.py index 3dc560062..0ff5598c7 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -6,7 +6,7 @@ from .legacy.http import read_request, read_response # noqa: F401 -warnings.warn( +warnings.warn( # deprecated in 9.0 - 2021-09-01 "Headers and MultipleValuesError were moved " "from websockets.http to websockets.datastructures" "and read_request and read_response were moved " diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 562bcb72c..61865bb92 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -87,7 +87,7 @@ class Request: @property def exception(self) -> Exception | None: # pragma: no cover - warnings.warn( + warnings.warn( # deprecated in 10.3 - 2022-04-17 "Request.exception is deprecated; " "use ServerProtocol.handshake_exc instead", DeprecationWarning, @@ -191,7 +191,7 @@ class Response: @property def exception(self) -> Exception | None: # pragma: no cover - warnings.warn( + warnings.warn( # deprecated in 10.3 - 2022-04-17 "Response.exception is deprecated; " "use ClientProtocol.handshake_exc instead", DeprecationWarning, diff --git a/src/websockets/server.py b/src/websockets/server.py index 3a03378b2..ac62800d6 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -575,7 +575,7 @@ def parse(self) -> Generator[None, None, None]: class ServerConnection(ServerProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: - warnings.warn( + warnings.warn( # deprecated in 11.0 - 2023-04-02 "ServerConnection was renamed to ServerProtocol", DeprecationWarning, ) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 3c700a377..6a04515f0 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -212,7 +212,10 @@ def connect( # Backwards compatibility: ssl used to be called ssl_context. if ssl is None and "ssl_context" in kwargs: ssl = kwargs.pop("ssl_context") - warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + warnings.warn( # deprecated in 13.0 - 2024-08-20 + "ssl_context was renamed to ssl", + DeprecationWarning, + ) wsuri = parse_uri(uri) if not wsuri.secure and ssl is not None: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 5e22e112e..15de458b5 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -304,7 +304,7 @@ def __exit__( def __getattr__(name: str) -> Any: if name == "WebSocketServer": - warnings.warn( + warnings.warn( # deprecated in 13.0 - 2024-08-20 "WebSocketServer was renamed to Server", DeprecationWarning, ) @@ -446,7 +446,10 @@ def handler(websocket): # Backwards compatibility: ssl used to be called ssl_context. if ssl is None and "ssl_context" in kwargs: ssl = kwargs.pop("ssl_context") - warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + warnings.warn( # deprecated in 13.0 - 2024-08-20 + "ssl_context was renamed to ssl", + DeprecationWarning, + ) if subprotocols is not None: validate_subprotocols(subprotocols) From d61dc43614063b1620b30f50df3947e6079bb371 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 22:46:57 +0200 Subject: [PATCH 1369/1539] Restructure upgrade guide. --- docs/howto/upgrade.rst | 346 +++++++++++++++++++++++------------------ 1 file changed, 194 insertions(+), 152 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 602d8a4e6..18c3cc127 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -16,8 +16,10 @@ The recommended upgrade process is: should stick to the original implementation until they're added. 3. `Update import paths`_. For straightforward usage of websockets, this could be the only step you need to take. Upgrading could be transparent. -4. `Review API changes`_ and adapt your application to preserve its current - functionality or take advantage of improvements in the new implementation. +4. Check out `new features and improvements`_ and consider taking advantage of + them to improve your application. +5. Review `API changes`_ and adapt your application to preserve its current + functionality. In the interest of brevity, only :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` are discussed below but everything also applies @@ -99,11 +101,12 @@ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package. -* The ``websockets`` package provides aliases for convenience. +* The ``websockets`` package provides aliases for convenience. Currently, they + point to the original implementation. They will be updated to point to the new + implementation when it feels mature. * The ``websockets.client`` and ``websockets.server`` packages provide aliases - for backwards-compatibility with earlier versions of websockets. -* Currently, all aliases point to the original implementation. In the future, - they will point to the new implementation or they will be deprecated. + for backwards-compatibility with earlier versions of websockets. They will + be deprecated together with the original implementation. To upgrade to the new :mod:`asyncio` implementation, change import paths as shown in the tables below. @@ -153,7 +156,7 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.server.broadcast` | +| ``websockets.broadcast()`` |br| | :func:`websockets.asyncio.server.broadcast` | | :func:`websockets.legacy.server.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | @@ -165,10 +168,32 @@ Server APIs | :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -.. _Review API changes: +.. _new features and improvements: -API changes ------------ +New features and improvements +----------------------------- + +Customizing the opening handshake +................................. + +On the server side, if you're customizing how :func:`~legacy.server.serve` +processes the opening handshake with the ``process_request``, ``extra_headers``, +or ``select_subprotocol``, you must update your code and you can probably make +it simpler. + +``process_request`` and ``select_subprotocol`` have new signatures. +``process_response`` replaces ``extra_headers`` and provides more flexibility. +See process_request_, select_subprotocol_, and process_response_ below. + +Tracking open connections +......................... + +The new implementation of :class:`~asyncio.server.Server` provides a +:attr:`~asyncio.server.Server.connections` property, which is a set of all open +connections. This didn't exist in the original implementation. + +If you're keeping track of open connections in order to broadcast messages to +all of them, you can simplify your code by using this property. Controlling UTF-8 decoding .......................... @@ -197,97 +222,76 @@ implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. -Tracking open connections -......................... - -The new implementation of :class:`~asyncio.server.Server` provides a -:attr:`~asyncio.server.Server.connections` property, which is a set of all open -connections. This didn't exist in the original implementation. - -If you were keeping track of open connections, you may be able to simplify your -code by using this property. - -.. _basic-auth: - -Performing HTTP Basic Authentication -.................................... +.. _API changes: -.. admonition:: This section applies only to servers. - :class: tip - - On the client side, :func:`~asyncio.client.connect` performs HTTP Basic - Authentication automatically when the URI contains credentials. +API changes +----------- -In the original implementation, the recommended way to add HTTP Basic -Authentication to a server was to set the ``create_protocol`` argument of -:func:`~legacy.server.serve` to a factory function generated by -:func:`~legacy.auth.basic_auth_protocol_factory`:: +Attributes of connection objects +................................ - from websockets.legacy.auth import basic_auth_protocol_factory - from websockets.legacy.server import serve +``path``, ``request_headers``, and ``response_headers`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - async with serve(..., create_protocol=basic_auth_protocol_factory(...)): - ... +The :attr:`~legacy.protocol.WebSocketCommonProtocol.path`, +:attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` properties are +replaced by :attr:`~asyncio.connection.Connection.request` and +:attr:`~asyncio.connection.Connection.response`. -In the new implementation, the :func:`~asyncio.server.basic_auth` function -generates a ``process_request`` coroutine that performs HTTP Basic -Authentication:: +If your code uses them, you can update it as follows. - from websockets.asyncio.server import basic_auth, serve +========================================== ========================================== +Legacy :mod:`asyncio` implementation New :mod:`asyncio` implementation +========================================== ========================================== +``connection.path`` ``connection.request.path`` +``connection.request_headers`` ``connection.request.headers`` +``connection.response_headers`` ``connection.response.headers`` +========================================== ========================================== - async with serve(..., process_request=basic_auth(...)): - ... +``open`` and ``closed`` +~~~~~~~~~~~~~~~~~~~~~~~ -:func:`~asyncio.server.basic_auth` accepts either hard coded ``credentials`` or -a ``check_credentials`` coroutine as well as an optional ``realm`` just like -:func:`~legacy.auth.basic_auth_protocol_factory`. Furthermore, -``check_credentials`` may be a function instead of a coroutine. +The :attr:`~legacy.protocol.WebSocketCommonProtocol.open` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.closed` properties are removed. +Using them was discouraged. -This new API has more obvious semantics. That makes it easier to understand and -also easier to extend. +Instead, you should call :meth:`~asyncio.connection.Connection.recv` or +:meth:`~asyncio.connection.Connection.send` and handle +:exc:`~exceptions.ConnectionClosed` exceptions. -In the original implementation, overriding ``create_protocol`` changed the type -of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, -a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP -Basic Authentication in its ``process_request`` method. If you wanted to -customize ``process_request`` further, you had: +If your code uses them, you can update it as follows. -* an ill-defined option: add a ``process_request`` argument to - :func:`~legacy.server.serve`; to tell which one would run first, you had to - experiment or read the code; -* a cumbersome option: subclass - :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that - subclass in the ``create_protocol`` argument of - :func:`~legacy.auth.basic_auth_protocol_factory`. +========================================== ========================================== +Legacy :mod:`asyncio` implementation New :mod:`asyncio` implementation +========================================== ========================================== +.. ``from websockets.protocol import State`` +``connection.open`` ``connection.state is State.OPEN`` +``connection.closed`` ``connection.state is State.CLOSED`` +========================================== ========================================== -In the new implementation, you just write a ``process_request`` coroutine:: +Arguments of :func:`~asyncio.client.connect` +............................................ - from websockets.asyncio.server import basic_auth, serve +``extra_headers`` → ``additional_headers`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - process_basic_auth = basic_auth(...) +If you're adding headers to the handshake request sent by +:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must +rename it to ``additional_headers``. - async def process_request(connection, request): - ... # some logic here - response = await process_basic_auth(connection, request) - if response is not None: - return response - ... # more logic here +Arguments of :func:`~asyncio.server.serve` +.......................................... - async with serve(..., process_request=process_request): - ... +``ws_handler`` → ``handler`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Customizing the opening handshake -................................. +The first argument of :func:`~asyncio.server.serve` is now called ``handler`` +instead of ``ws_handler``. It's usually passed as a positional argument, making +this change transparent. If you're passing it as a keyword argument, you must +update its name. -On the client side, if you're adding headers to the handshake request sent by -:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must -rename it to ``additional_headers``. - -On the server side, if you're customizing how :func:`~legacy.server.serve` -processes the opening handshake with the ``process_request``, ``extra_headers``, -or ``select_subprotocol``, you must update your code. ``process_response`` and -``select_subprotocol`` have new signatures; ``process_response`` replaces -``extra_headers`` and provides more flexibility. +.. _process_request: ``process_request`` ~~~~~~~~~~~~~~~~~~~ @@ -302,8 +306,6 @@ an example:: def process_request(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" - serve(..., process_request=process_request, ...) - # New implementation def process_request(connection, request): @@ -312,16 +314,22 @@ an example:: serve(..., process_request=process_request, ...) ``connection`` is always available in ``process_request``. In the original -implementation, you had to write a subclass of +implementation, if you wanted to make the connection object available in a +``process_request`` method, you had to write a subclass of :class:`~legacy.server.WebSocketServerProtocol` and pass it in the -``create_protocol`` argument to make the connection object available in a -``process_request`` method. This pattern isn't useful anymore; you can replace -it with a ``process_request`` function or coroutine. +``create_protocol`` argument. This pattern isn't useful anymore; you can +replace it with a ``process_request`` function or coroutine. ``path`` and ``headers`` are available as attributes of the ``request`` object. -``process_response`` -~~~~~~~~~~~~~~~~~~~~ +.. _process_response: + +``extra_headers`` → ``process_response`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you're adding headers to the handshake response sent by +:func:`~legacy.server.serve` with the ``extra_headers`` argument, you must write +a ``process_response`` callable instead. ``process_request`` replaces ``extra_headers`` and provides more flexibility. In the most basic case, you would adapt your code as follows:: @@ -346,10 +354,13 @@ In addition, the ``request`` and ``response`` objects are available, which enables a broader range of use cases (e.g., logging) and makes ``process_response`` more useful than ``extra_headers``. +.. _select_subprotocol: + ``select_subprotocol`` ~~~~~~~~~~~~~~~~~~~~~~ -The signature of ``select_subprotocol`` changed. Here's an example:: +If you're selecting a subprotocol, you must update your code because the +signature of ``select_subprotocol`` changed. Here's an example:: # Original implementation @@ -366,37 +377,30 @@ The signature of ``select_subprotocol`` changed. Here's an example:: serve(..., select_subprotocol=select_subprotocol, ...) ``connection`` is always available in ``select_subprotocol``. This brings the -same benefits as in ``process_request``. It may remove the need to subclass of +same benefits as in ``process_request``. It may remove the need to subclass :class:`~legacy.server.WebSocketServerProtocol`. The ``subprotocols`` argument contains the list of subprotocols offered by the client. The list of subprotocols supported by the server was removed because -``select_subprotocols`` already knows which subprotocols it may select and under +``select_subprotocols`` has to know which subprotocols it may select and under which conditions. -Arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` -.............................................................................. - -``ws_handler`` → ``handler`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The first argument of :func:`~asyncio.server.serve` is now called ``handler`` -instead of ``ws_handler``. It's usually passed as a positional argument, making -this change transparent. If you're passing it as a keyword argument, you must -update its name. +Furthermore, the default behavior when ``select_subprotocol`` isn't provided +changed in two ways: -``create_protocol`` → ``create_connection`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +1. In the original implementation, a server with a list of subprotocols accepted + to continue without a subprotocol. In the new implementation, a server that + is configured with subprotocols rejects connections that don't support any. +2. In the original implementation, when several subprotocols were available, the + server averaged the client's preferences with its own preferences. In the new + implementation, the server just picks the first subprotocol from its list. -The keyword argument of :func:`~asyncio.server.serve` for customizing the -creation of the connection object is now called ``create_connection`` instead of -``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` -instead of a :class:`~legacy.server.WebSocketServerProtocol`. +If you had a ``select_subprotocol`` for the sole purpose of rejecting +connections without a subprotocol, you can remove it and keep only the +``subprotocols`` argument. -If you were customizing connection objects, you should check the new -implementation and possibly redo your customization. Keep in mind that the -changes to ``process_request`` and ``select_subprotocol`` remove most use cases -for ``create_connection``. +Arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` +.............................................................................. ``max_queue`` ~~~~~~~~~~~~~ @@ -409,11 +413,15 @@ frames. It used to be the size of a buffer of incoming messages that refilled as soon as a message was read. It used to default to 32 messages. This can make a difference when messages are fragmented in several frames. In -that case, you may want to increase ``max_queue``. If you're writing a high -performance server and you know that you're receiving fragmented messages, -probably you should adopt :meth:`~asyncio.connection.Connection.recv_streaming` -and optimize the performance of reads again. In all other cases, given how -uncommon fragmentation is, you shouldn't worry about this change. +that case, you may want to increase ``max_queue``. + +If you're writing a high performance server and you know that you're receiving +fragmented messages, probably you should adopt +:meth:`~asyncio.connection.Connection.recv_streaming` and optimize the +performance of reads again. + +In all other cases, given how uncommon fragmentation is, you shouldn't worry +about this change. ``read_limit`` ~~~~~~~~~~~~~~ @@ -432,50 +440,84 @@ buffer now. The ``write_limit`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. -Attributes of connections -......................... +``create_protocol`` → ``create_connection`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -``path``, ``request_headers`` and ``response_headers`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The keyword argument of :func:`~asyncio.server.serve` for customizing the +creation of the connection object is now called ``create_connection`` instead of +``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` +instead of a :class:`~legacy.server.WebSocketServerProtocol`. -The :attr:`~legacy.protocol.WebSocketCommonProtocol.path`, -:attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and -:attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` properties are -replaced by :attr:`~asyncio.connection.Connection.request` and -:attr:`~asyncio.connection.Connection.response`, which provide a ``headers`` -attribute. +If you were customizing connection objects, probably you need to redo your +customization. Consider switching to ``process_request`` and +``select_subprotocol`` as their new design removes most use cases for +``create_connection``. -If your code relies on them, you can replace:: +.. _basic-auth: - connection.path - connection.request_headers - connection.response_headers +Performing HTTP Basic Authentication +.................................... -with:: +.. admonition:: This section applies only to servers. + :class: tip - connection.request.path - connection.request.headers - connection.response.headers + On the client side, :func:`~asyncio.client.connect` performs HTTP Basic + Authentication automatically when the URI contains credentials. -``open`` and ``closed`` -~~~~~~~~~~~~~~~~~~~~~~~ +In the original implementation, the recommended way to add HTTP Basic +Authentication to a server was to set the ``create_protocol`` argument of +:func:`~legacy.server.serve` to a factory function generated by +:func:`~legacy.auth.basic_auth_protocol_factory`:: -The :attr:`~legacy.protocol.WebSocketCommonProtocol.open` and -:attr:`~legacy.protocol.WebSocketCommonProtocol.closed` properties are removed. -Using them was discouraged. + from websockets.legacy.auth import basic_auth_protocol_factory + from websockets.legacy.server import serve -Instead, you should call :meth:`~asyncio.connection.Connection.recv` or -:meth:`~asyncio.connection.Connection.send` and handle -:exc:`~exceptions.ConnectionClosed` exceptions. + async with serve(..., create_protocol=basic_auth_protocol_factory(...)): + ... + +In the new implementation, the :func:`~asyncio.server.basic_auth` function +generates a ``process_request`` coroutine that performs HTTP Basic +Authentication:: -If your code relies on them, you can replace:: + from websockets.asyncio.server import basic_auth, serve - connection.open - connection.closed + async with serve(..., process_request=basic_auth(...)): + ... -with:: +:func:`~asyncio.server.basic_auth` accepts either hard coded ``credentials`` or +a ``check_credentials`` coroutine as well as an optional ``realm`` just like +:func:`~legacy.auth.basic_auth_protocol_factory`. Furthermore, +``check_credentials`` may be a function instead of a coroutine. - from websockets.protocol import State +This new API has more obvious semantics. That makes it easier to understand and +also easier to extend. - connection.state is State.OPEN - connection.state is State.CLOSED +In the original implementation, overriding ``create_protocol`` changed the type +of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, +a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP +Basic Authentication in its ``process_request`` method. If you wanted to +customize ``process_request`` further, you had: + +* an ill-defined option: add a ``process_request`` argument to + :func:`~legacy.server.serve`; to tell which one would run first, you had to + experiment or read the code; +* a cumbersome option: subclass + :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that + subclass in the ``create_protocol`` argument of + :func:`~legacy.auth.basic_auth_protocol_factory`. + +In the new implementation, you just write a ``process_request`` coroutine:: + + from websockets.asyncio.server import basic_auth, serve + + process_basic_auth = basic_auth(...) + + async def process_request(connection, request): + ... # some logic here + response = await process_basic_auth(connection, request) + if response is not None: + return response + ... # more logic here + + async with serve(..., process_request=process_request): + ... From 451944aa2d58e6d1c2d42042b1d2eea61a97b578 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 08:35:34 +0200 Subject: [PATCH 1370/1539] process_request/response can be coroutines. --- src/websockets/asyncio/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 29860e565..f24281252 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -693,14 +693,14 @@ def __init__( process_request: ( Callable[ [ServerConnection, Request], - Response | None, + Awaitable[Response | None] | Response | None, ] | None ) = None, process_response: ( Callable[ [ServerConnection, Request, Response], - Response | None, + Awaitable[Response | None] | Response | None, ] | None ) = None, From 6261ad360db65c95d229a02f40a92892a20183b2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 21:33:04 +0200 Subject: [PATCH 1371/1539] Remove run_(unix_)client abstraction from tests. It wasn't adding any value, as shown in the diff. --- tests/asyncio/client.py | 33 --------- tests/asyncio/server.py | 9 ++- tests/asyncio/test_client.py | 70 +++++++++----------- tests/asyncio/test_server.py | 125 +++++++++++++++++------------------ tests/sync/client.py | 32 --------- tests/sync/server.py | 8 +++ tests/sync/test_client.py | 67 +++++++++---------- tests/sync/test_server.py | 77 ++++++++++----------- 8 files changed, 178 insertions(+), 243 deletions(-) delete mode 100644 tests/asyncio/client.py delete mode 100644 tests/sync/client.py diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py deleted file mode 100644 index a73079c6e..000000000 --- a/tests/asyncio/client.py +++ /dev/null @@ -1,33 +0,0 @@ -import contextlib - -from websockets.asyncio.client import * -from websockets.asyncio.server import Server - -from .server import get_server_host_port - - -__all__ = [ - "run_client", - "run_unix_client", -] - - -@contextlib.asynccontextmanager -async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): - if isinstance(wsuri_or_server, str): - wsuri = wsuri_or_server - else: - assert isinstance(wsuri_or_server, Server) - if secure is None: - secure = "ssl" in kwargs - protocol = "wss" if secure else "ws" - host, port = get_server_host_port(wsuri_or_server) - wsuri = f"{protocol}://{host}:{port}{resource_name}" - async with connect(wsuri, **kwargs) as client: - yield client - - -@contextlib.asynccontextmanager -async def run_unix_client(path, **kwargs): - async with unix_connect(path, **kwargs) as client: - yield client diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index 0fe20dc65..06fa92dea 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -5,13 +5,20 @@ from websockets.asyncio.server import * -def get_server_host_port(server): +def get_host_port(server): for sock in server.sockets: if sock.family == socket.AF_INET: # pragma: no branch return sock.getsockname() raise AssertionError("expected at least one IPv4 socket") +def get_uri(server): + secure = server.server._ssl_context is not None # hack + protocol = "wss" if secure else "ws" + host, port = get_host_port(server) + return f"{protocol}://{host}:{port}" + + async def eval_shell(ws): async for expr in ws: value = eval(expr) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 0bd2af4f1..a8ef6ef9d 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -9,49 +9,48 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path -from .client import run_client, run_unix_client -from .server import do_nothing, get_server_host_port, run_server, run_unix_server +from .server import do_nothing, get_host_port, get_uri, run_server, run_unix_server class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server and the handshake succeeds.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_existing_socket(self): """Client connects using a pre-existing socket.""" async with run_server() as server: - with socket.create_connection(get_server_host_port(server)) as sock: + with socket.create_connection(get_host_port(server)) as sock: # Use a non-existing domain to ensure we connect to the right socket. - async with run_client("ws://invalid/", sock=sock) as client: + async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_additional_headers(self): """Client can set additional headers with additional_headers.""" async with run_server() as server: - async with run_client( - server, additional_headers={"Authorization": "Bearer ..."} + async with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} ) as client: self.assertEqual(client.request.headers["Authorization"], "Bearer ...") async def test_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" async with run_server() as server: - async with run_client(server, user_agent_header="Smith") as client: + async with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") async def test_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" async with run_server() as server: - async with run_client(server, user_agent_header=None) as client: + async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) async def test_compression_is_enabled(self): """Client enables compression by default.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual( [type(ext) for ext in client.protocol.extensions], [PerMessageDeflate], @@ -60,13 +59,13 @@ async def test_compression_is_enabled(self): async def test_disable_compression(self): """Client disables compression.""" async with run_server() as server: - async with run_client(server, compression=None) as client: + async with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" async with run_server() as server: - async with run_client(server, ping_interval=MS) as client: + async with connect(get_uri(server), ping_interval=MS) as client: self.assertEqual(client.latency, 0) await asyncio.sleep(2 * MS) self.assertGreater(client.latency, 0) @@ -74,7 +73,7 @@ async def test_keepalive_is_enabled(self): async def test_disable_keepalive(self): """Client disables keepalive.""" async with run_server() as server: - async with run_client(server, ping_interval=None) as client: + async with connect(get_uri(server), ping_interval=None) as client: await asyncio.sleep(2 * MS) self.assertEqual(client.latency, 0) @@ -87,21 +86,21 @@ def create_connection(*args, **kwargs): return client async with run_server() as server: - async with run_client( - server, create_connection=create_connection + async with connect( + get_uri(server), create_connection=create_connection ) as client: self.assertTrue(client.create_connection_ran) async def test_invalid_uri(self): """Client receives an invalid URI.""" with self.assertRaises(InvalidURI): - async with run_client("http://localhost"): # invalid scheme + async with connect("http://localhost"): # invalid scheme self.fail("did not raise") async def test_tcp_connection_fails(self): """Client fails to connect to server.""" with self.assertRaises(OSError): - async with run_client("ws://localhost:54321"): # invalid port + async with connect("ws://localhost:54321"): # invalid port self.fail("did not raise") async def test_handshake_fails(self): @@ -116,7 +115,7 @@ def remove_accept_header(self, request, response): do_nothing, process_response=remove_accept_header ) as server: with self.assertRaises(InvalidHandshake) as raised: - async with run_client(server, close_timeout=MS): + async with connect(get_uri(server), close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -135,7 +134,7 @@ async def stall_connection(self, request): async with run_server(do_nothing, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - async with run_client(server, open_timeout=2 * MS): + async with connect(get_uri(server), open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -152,7 +151,7 @@ def close_connection(self, request): async with run_server(process_request=close_connection) as server: with self.assertRaises(ConnectionError) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -164,7 +163,7 @@ class SecureClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server securely.""" async with run_server(ssl=SERVER_CONTEXT) as server: - async with run_client(server, ssl=CLIENT_CONTEXT) as client: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.version()[:3], "TLS") @@ -172,12 +171,9 @@ async def test_connection(self): async def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" async with run_server(ssl=SERVER_CONTEXT) as server: - host, port = get_server_host_port(server) - async with run_client( - "wss://overridden/", - host=host, - port=port, - ssl=CLIENT_CONTEXT, + host, port = get_host_port(server) + async with connect( + "wss://overridden/", host=host, port=port, ssl=CLIENT_CONTEXT ) as client: ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.server_hostname, "overridden") @@ -185,10 +181,8 @@ async def test_set_server_hostname_implicitly(self): async def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" async with run_server(ssl=SERVER_CONTEXT) as server: - async with run_client( - server, - ssl=CLIENT_CONTEXT, - server_hostname="overridden", + async with connect( + get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="overridden" ) as client: ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.server_hostname, "overridden") @@ -198,7 +192,7 @@ async def test_reject_invalid_server_certificate(self): async with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. - async with run_client(server, secure=True): + async with connect(get_uri(server)): self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", @@ -210,8 +204,8 @@ async def test_reject_invalid_server_hostname(self): async with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. - async with run_client( - server, ssl=CLIENT_CONTEXT, server_hostname="invalid" + async with connect( + get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="invalid" ): self.fail("did not raise") self.assertIn( @@ -226,7 +220,7 @@ async def test_connection(self): """Client connects to server over a Unix socket.""" with temp_unix_socket_path() as path: async with run_unix_server(path): - async with run_unix_client(path) as client: + async with unix_connect(path) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_set_host_header(self): @@ -234,7 +228,7 @@ async def test_set_host_header(self): # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: async with run_unix_server(path): - async with run_unix_client(path, uri="ws://overridden/") as client: + async with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") @@ -244,7 +238,7 @@ async def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.version()[:3], "TLS") @@ -254,7 +248,7 @@ async def test_set_server_hostname(self): # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client( + async with unix_connect( path, ssl=CLIENT_CONTEXT, uri="wss://overridden/", diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 38f226903..a16267439 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -6,6 +6,7 @@ import socket import unittest +from websockets.asyncio.client import connect, unix_connect from websockets.asyncio.compatibility import TimeoutError, asyncio_timeout from websockets.asyncio.server import * from websockets.exceptions import ( @@ -22,13 +23,13 @@ SERVER_CONTEXT, temp_unix_socket_path, ) -from .client import run_client, run_unix_client from .server import ( EvalShellMixin, crash, do_nothing, eval_shell, - get_server_host_port, + get_host_port, + get_uri, keep_running, run_server, run_unix_server, @@ -39,13 +40,13 @@ class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_connection_handler_returns(self): """Connection handler returns.""" async with run_server(do_nothing) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() self.assertEqual( @@ -56,7 +57,7 @@ async def test_connection_handler_returns(self): async def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" async with run_server(crash) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: with self.assertRaises(ConnectionClosedError) as raised: await client.recv() self.assertEqual( @@ -70,7 +71,7 @@ async def test_existing_socket(self): with socket.create_server(("localhost", 0)) as sock: async with run_server(sock=sock, host=None, port=None): uri = "ws://{}:{}/".format(*sock.getsockname()) - async with run_client(uri) as client: + async with connect(uri) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_select_subprotocol(self): @@ -85,7 +86,7 @@ def select_subprotocol(ws, subprotocols): subprotocols=["chat"], select_subprotocol=select_subprotocol, ) as server: - async with run_client(server, subprotocols=["chat"]) as client: + async with connect(get_uri(server), subprotocols=["chat"]) as client: await self.assertEval(client, "ws.select_subprotocol_ran", "True") await self.assertEval(client, "ws.subprotocol", "chat") @@ -97,7 +98,7 @@ def select_subprotocol(ws, subprotocols): async with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -112,7 +113,7 @@ def select_subprotocol(ws, subprotocols): async with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -127,7 +128,7 @@ def process_request(ws, request): ws.process_request_ran = True async with run_server(process_request=process_request) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") async def test_async_process_request_returns_none(self): @@ -138,7 +139,7 @@ async def process_request(ws, request): ws.process_request_ran = True async with run_server(process_request=process_request) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") async def test_process_request_returns_response(self): @@ -149,7 +150,7 @@ def process_request(ws, request): async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -164,7 +165,7 @@ async def process_request(ws, request): async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -179,7 +180,7 @@ def process_request(ws, request): async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -194,7 +195,7 @@ async def process_request(ws, request): async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -210,7 +211,7 @@ def process_response(ws, request, response): ws.process_response_ran = True async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") async def test_async_process_response_returns_none(self): @@ -222,7 +223,7 @@ async def process_response(ws, request, response): ws.process_response_ran = True async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") async def test_process_response_modifies_response(self): @@ -232,7 +233,7 @@ def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_async_process_response_modifies_response(self): @@ -242,7 +243,7 @@ async def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_replaces_response(self): @@ -254,7 +255,7 @@ def process_response(ws, request, response): return dataclasses.replace(response, headers=headers) async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_async_process_response_replaces_response(self): @@ -266,7 +267,7 @@ async def process_response(ws, request, response): return dataclasses.replace(response, headers=headers) async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_raises_exception(self): @@ -277,7 +278,7 @@ def process_response(ws, request, response): async with run_server(process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -292,7 +293,7 @@ async def process_response(ws, request, response): async with run_server(process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -302,13 +303,13 @@ async def process_response(ws, request, response): async def test_override_server(self): """Server can override Server header with server_header.""" async with run_server(server_header="Neo") as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.response.headers['Server']", "Neo") async def test_remove_server(self): """Server can remove Server header with server_header.""" async with run_server(server_header=None) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval( client, "'Server' in ws.response.headers", "False" ) @@ -316,7 +317,7 @@ async def test_remove_server(self): async def test_compression_is_enabled(self): """Server enables compression by default.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval( client, "[type(ext).__name__ for ext in ws.protocol.extensions]", @@ -326,13 +327,13 @@ async def test_compression_is_enabled(self): async def test_disable_compression(self): """Server disables compression.""" async with run_server(compression=None) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.extensions", "[]") async def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" async with run_server(ping_interval=MS) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await client.send("ws.latency") latency = eval(await client.recv()) self.assertEqual(latency, 0) @@ -344,7 +345,7 @@ async def test_keepalive_is_enabled(self): async def test_disable_keepalive(self): """Client disables keepalive.""" async with run_server(ping_interval=None) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await asyncio.sleep(2 * MS) await client.send("ws.latency") latency = eval(await client.recv()) @@ -365,14 +366,14 @@ def create_connection(*args, **kwargs): return server async with run_server(create_connection=create_connection) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.create_connection_ran", "True") async def test_connections(self): """Server provides a connections property.""" async with run_server() as server: self.assertEqual(server.connections, set()) - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(len(server.connections), 1) ws_id = str(next(iter(server.connections)).id) await self.assertEval(client, "ws.id", ws_id) @@ -386,7 +387,7 @@ def remove_key_header(self, request): async with run_server(process_request=remove_key_header) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -396,9 +397,7 @@ def remove_key_header(self, request): async def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" async with run_server(open_timeout=MS) as server: - reader, writer = await asyncio.open_connection( - *get_server_host_port(server) - ) + reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") finally: @@ -407,9 +406,7 @@ async def test_timeout_during_handshake(self): async def test_connection_closed_during_handshake(self): """Server reads EOF before receiving handshake request from client.""" async with run_server() as server: - _reader, writer = await asyncio.open_connection( - *get_server_host_port(server) - ) + _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() async def test_close_server_rejects_connecting_connections(self): @@ -422,7 +419,7 @@ async def process_request(ws, _request): async with run_server(process_request=process_request) as server: asyncio.get_running_loop().call_later(MS, server.close) with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -432,7 +429,7 @@ async def process_request(ws, _request): async def test_close_server_closes_open_connections(self): """Server closes open connections with close code 1001 when closing.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: server.close() with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() @@ -444,7 +441,7 @@ async def test_close_server_closes_open_connections(self): async def test_close_server_keeps_connections_open(self): """Server waits for client to close open connections when closing.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: server.close(close_connections=False) # Server cannot receive new connections. @@ -464,7 +461,7 @@ async def test_close_server_keeps_connections_open(self): async def test_close_server_keeps_handlers_running(self): """Server waits for connection handlers to terminate.""" async with run_server(keep_running) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: # Delay termination of connection handler. await client.send(str(3 * MS)) @@ -486,16 +483,14 @@ class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives secure connection from client.""" async with run_server(ssl=SERVER_CONTEXT) as server: - async with run_client(server, ssl=CLIENT_CONTEXT) as client: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") async def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: - reader, writer = await asyncio.open_connection( - *get_server_host_port(server) - ) + reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") finally: @@ -504,9 +499,7 @@ async def test_timeout_during_tls_handshake(self): async def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" async with run_server(ssl=SERVER_CONTEXT) as server: - _reader, writer = await asyncio.open_connection( - *get_server_host_port(server) - ) + _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() @@ -516,7 +509,7 @@ async def test_connection(self): """Server receives connection from client over a Unix socket.""" with temp_unix_socket_path() as path: async with run_unix_server(path): - async with run_unix_client(path) as client: + async with unix_connect(path) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") @@ -526,7 +519,7 @@ async def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") @@ -577,8 +570,8 @@ async def test_valid_authorization(self): async with run_server( process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") @@ -589,7 +582,7 @@ async def test_missing_authorization(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -602,8 +595,8 @@ async def test_unsupported_authorization(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Negotiate ..."}, ): self.fail("did not raise") @@ -618,8 +611,8 @@ async def test_authorization_with_unknown_username(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ): self.fail("did not raise") @@ -634,8 +627,8 @@ async def test_authorization_with_incorrect_password(self): process_request=basic_auth(credentials=("hello", "changeme")), ) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ): self.fail("did not raise") @@ -654,8 +647,8 @@ async def test_list_of_credentials(self): ] ), ) as server: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ) as client: await self.assertEval(client, "ws.username", "bye") @@ -669,8 +662,8 @@ def check_credentials(username, password): async with run_server( process_request=basic_auth(check_credentials=check_credentials), ) as server: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") @@ -684,8 +677,8 @@ async def check_credentials(username, password): async with run_server( process_request=basic_auth(check_credentials=check_credentials), ) as server: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") diff --git a/tests/sync/client.py b/tests/sync/client.py deleted file mode 100644 index acbf97fa7..000000000 --- a/tests/sync/client.py +++ /dev/null @@ -1,32 +0,0 @@ -import contextlib - -from websockets.sync.client import * -from websockets.sync.server import Server - - -__all__ = [ - "run_client", - "run_unix_client", -] - - -@contextlib.contextmanager -def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): - if isinstance(wsuri_or_server, str): - wsuri = wsuri_or_server - else: - assert isinstance(wsuri_or_server, Server) - if secure is None: - # Backwards compatibility: ssl used to be called ssl_context. - secure = "ssl" in kwargs or "ssl_context" in kwargs - protocol = "wss" if secure else "ws" - host, port = wsuri_or_server.socket.getsockname() - wsuri = f"{protocol}://{host}:{port}{resource_name}" - with connect(wsuri, **kwargs) as client: - yield client - - -@contextlib.contextmanager -def run_unix_client(path, **kwargs): - with unix_connect(path, **kwargs) as client: - yield client diff --git a/tests/sync/server.py b/tests/sync/server.py index d5295ccd8..a86cf88ce 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,9 +1,17 @@ import contextlib +import ssl import threading from websockets.sync.server import * +def get_uri(server): + secure = isinstance(server.socket, ssl.SSLSocket) # hack + protocol = "wss" if secure else "ws" + host, port = server.socket.getsockname() + return f"{protocol}://{host}:{port}" + + def crash(ws): raise RuntimeError diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 03f4e972f..44cbdd6c4 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -14,15 +14,14 @@ DeprecationTestCase, temp_unix_socket_path, ) -from .client import run_client, run_unix_client -from .server import do_nothing, run_server, run_unix_server +from .server import do_nothing, get_uri, run_server, run_unix_server class ClientTests(unittest.TestCase): def test_connection(self): """Client connects to server and the handshake succeeds.""" with run_server() as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_existing_socket(self): @@ -30,33 +29,33 @@ def test_existing_socket(self): with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: # Use a non-existing domain to ensure we connect to the right socket. - with run_client("ws://invalid/", sock=sock) as client: + with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_additional_headers(self): """Client can set additional headers with additional_headers.""" with run_server() as server: - with run_client( - server, additional_headers={"Authorization": "Bearer ..."} + with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} ) as client: self.assertEqual(client.request.headers["Authorization"], "Bearer ...") def test_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" with run_server() as server: - with run_client(server, user_agent_header="Smith") as client: + with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") def test_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" with run_server() as server: - with run_client(server, user_agent_header=None) as client: + with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) def test_compression_is_enabled(self): """Client enables compression by default.""" with run_server() as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEqual( [type(ext) for ext in client.protocol.extensions], [PerMessageDeflate], @@ -65,7 +64,7 @@ def test_compression_is_enabled(self): def test_disable_compression(self): """Client disables compression.""" with run_server() as server: - with run_client(server, compression=None) as client: + with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) def test_custom_connection_factory(self): @@ -77,19 +76,21 @@ def create_connection(*args, **kwargs): return client with run_server() as server: - with run_client(server, create_connection=create_connection) as client: + with connect( + get_uri(server), create_connection=create_connection + ) as client: self.assertTrue(client.create_connection_ran) def test_invalid_uri(self): """Client receives an invalid URI.""" with self.assertRaises(InvalidURI): - with run_client("http://localhost"): # invalid scheme + with connect("http://localhost"): # invalid scheme self.fail("did not raise") def test_tcp_connection_fails(self): """Client fails to connect to server.""" with self.assertRaises(OSError): - with run_client("ws://localhost:54321"): # invalid port + with connect("ws://localhost:54321"): # invalid port self.fail("did not raise") def test_handshake_fails(self): @@ -102,7 +103,7 @@ def remove_accept_header(self, request, response): # Use a connection handler that exits immediately to avoid an exception. with run_server(do_nothing, process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: - with run_client(server, close_timeout=MS): + with connect(get_uri(server), close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -121,7 +122,7 @@ def stall_connection(self, request): with run_server(do_nothing, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - with run_client(server, open_timeout=2 * MS): + with connect(get_uri(server), open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -138,7 +139,7 @@ def close_connection(self, request): with run_server(process_request=close_connection) as server: with self.assertRaises(ConnectionError) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -150,7 +151,7 @@ class SecureClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely.""" with run_server(ssl=SERVER_CONTEXT) as server: - with run_client(server, ssl=CLIENT_CONTEXT) as client: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") @@ -158,10 +159,8 @@ def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - uri="wss://overridden/", + with unix_connect( + path, ssl=CLIENT_CONTEXT, uri="wss://overridden/" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -169,10 +168,8 @@ def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - server_hostname="overridden", + with unix_connect( + path, ssl=CLIENT_CONTEXT, server_hostname="overridden" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -181,7 +178,7 @@ def test_reject_invalid_server_certificate(self): with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. - with run_client(server, secure=True): + with connect(get_uri(server)): self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", @@ -193,7 +190,9 @@ def test_reject_invalid_server_hostname(self): with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. - with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): + with connect( + get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="invalid" + ): self.fail("did not raise") self.assertIn( "certificate verify failed: Hostname mismatch", @@ -207,7 +206,7 @@ def test_connection(self): """Client connects to server over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path): - with run_unix_client(path) as client: + with unix_connect(path) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_set_host_header(self): @@ -215,7 +214,7 @@ def test_set_host_header(self): # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: with run_unix_server(path): - with run_unix_client(path, uri="ws://overridden/") as client: + with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") @@ -225,7 +224,7 @@ def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") @@ -234,10 +233,8 @@ def test_set_server_hostname(self): # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - uri="wss://overridden/", + with unix_connect( + path, ssl=CLIENT_CONTEXT, uri="wss://overridden/" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -296,5 +293,5 @@ def test_ssl_context_argument(self): """Client supports the deprecated ssl_context argument.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertDeprecationWarning("ssl_context was renamed to ssl"): - with run_client(server, ssl_context=CLIENT_CONTEXT): + with connect(get_uri(server), ssl_context=CLIENT_CONTEXT): pass diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index a17634716..e6a17d02f 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -13,6 +13,7 @@ NegotiationError, ) from websockets.http11 import Request, Response +from websockets.sync.client import connect, unix_connect from websockets.sync.server import * from ..utils import ( @@ -22,12 +23,12 @@ DeprecationTestCase, temp_unix_socket_path, ) -from .client import run_client, run_unix_client from .server import ( EvalShellMixin, crash, do_nothing, eval_shell, + get_uri, run_server, run_unix_server, ) @@ -37,13 +38,13 @@ class ServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives connection from client and the handshake succeeds.""" with run_server() as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_connection_handler_returns(self): """Connection handler returns.""" with run_server(do_nothing) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: with self.assertRaises(ConnectionClosedOK) as raised: client.recv() self.assertEqual( @@ -54,7 +55,7 @@ def test_connection_handler_returns(self): def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" with run_server(crash) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: with self.assertRaises(ConnectionClosedError) as raised: client.recv() self.assertEqual( @@ -68,7 +69,7 @@ def test_existing_socket(self): with socket.create_server(("localhost", 0)) as sock: with run_server(sock=sock): uri = "ws://{}:{}/".format(*sock.getsockname()) - with run_client(uri) as client: + with connect(uri) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_select_subprotocol(self): @@ -83,7 +84,7 @@ def select_subprotocol(ws, subprotocols): subprotocols=["chat"], select_subprotocol=select_subprotocol, ) as server: - with run_client(server, subprotocols=["chat"]) as client: + with connect(get_uri(server), subprotocols=["chat"]) as client: self.assertEval(client, "ws.select_subprotocol_ran", "True") self.assertEval(client, "ws.subprotocol", "chat") @@ -95,7 +96,7 @@ def select_subprotocol(ws, subprotocols): with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -110,7 +111,7 @@ def select_subprotocol(ws, subprotocols): with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -125,7 +126,7 @@ def process_request(ws, request): ws.process_request_ran = True with run_server(process_request=process_request) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.process_request_ran", "True") def test_process_request_returns_response(self): @@ -136,7 +137,7 @@ def process_request(ws, request): with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -151,7 +152,7 @@ def process_request(ws, request): with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -167,7 +168,7 @@ def process_response(ws, request, response): ws.process_response_ran = True with run_server(process_response=process_response) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.process_response_ran", "True") def test_process_response_modifies_response(self): @@ -177,7 +178,7 @@ def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" with run_server(process_response=process_response) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_replaces_response(self): @@ -189,7 +190,7 @@ def process_response(ws, request, response): return dataclasses.replace(response, headers=headers) with run_server(process_response=process_response) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_raises_exception(self): @@ -200,7 +201,7 @@ def process_response(ws, request, response): with run_server(process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -210,19 +211,19 @@ def process_response(ws, request, response): def test_override_server(self): """Server can override Server header with server_header.""" with run_server(server_header="Neo") as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.response.headers['Server']", "Neo") def test_remove_server(self): """Server can remove Server header with server_header.""" with run_server(server_header=None) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "'Server' in ws.response.headers", "False") def test_compression_is_enabled(self): """Server enables compression by default.""" with run_server() as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval( client, "[type(ext).__name__ for ext in ws.protocol.extensions]", @@ -232,7 +233,7 @@ def test_compression_is_enabled(self): def test_disable_compression(self): """Server disables compression.""" with run_server(compression=None) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.extensions", "[]") def test_logger(self): @@ -250,7 +251,7 @@ def create_connection(*args, **kwargs): return server with run_server(create_connection=create_connection) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.create_connection_ran", "True") def test_fileno(self): @@ -274,7 +275,7 @@ def remove_key_header(self, request): with run_server(process_request=remove_key_header) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -316,7 +317,7 @@ class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client.""" with run_server(ssl=SERVER_CONTEXT) as server: - with run_client(server, ssl=CLIENT_CONTEXT) as client: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") @@ -357,7 +358,7 @@ def test_connection(self): """Server receives connection from client over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path): - with run_unix_client(path) as client: + with unix_connect(path) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") @@ -367,7 +368,7 @@ def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") @@ -418,8 +419,8 @@ def test_valid_authorization(self): with run_server( process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: self.assertEval(client, "ws.username", "hello") @@ -430,7 +431,7 @@ def test_missing_authorization(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -443,8 +444,8 @@ def test_unsupported_authorization(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Negotiate ..."}, ): self.fail("did not raise") @@ -459,8 +460,8 @@ def test_authorization_with_unknown_username(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ): self.fail("did not raise") @@ -475,8 +476,8 @@ def test_authorization_with_incorrect_password(self): process_request=basic_auth(credentials=("hello", "changeme")), ) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ): self.fail("did not raise") @@ -495,8 +496,8 @@ def test_list_of_credentials(self): ] ), ) as server: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ) as client: self.assertEval(client, "ws.username", "bye") @@ -510,8 +511,8 @@ def check_credentials(username, password): with run_server( process_request=basic_auth(check_credentials=check_credentials), ) as server: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: self.assertEval(client, "ws.username", "hello") @@ -561,7 +562,7 @@ def test_ssl_context_argument(self): """Client supports the deprecated ssl_context argument.""" with self.assertDeprecationWarning("ssl_context was renamed to ssl"): with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl=CLIENT_CONTEXT): + with connect(get_uri(server), ssl=CLIENT_CONTEXT): pass def test_web_socket_server_class(self): From f0d1ebb238466ec48cffc657341c7928ef08f43e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 21:48:31 +0200 Subject: [PATCH 1372/1539] Use the same connection handler for all tests. --- tests/asyncio/server.py | 39 ++++++++++++++++++------------------ tests/asyncio/test_client.py | 12 +++++------ tests/asyncio/test_server.py | 25 ++++++++++------------- tests/sync/server.py | 29 ++++++++++++++------------- tests/sync/test_client.py | 10 ++++----- tests/sync/test_server.py | 20 +++++++++--------- 6 files changed, 64 insertions(+), 71 deletions(-) diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index 06fa92dea..097950971 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -19,10 +19,23 @@ def get_uri(server): return f"{protocol}://{host}:{port}" -async def eval_shell(ws): - async for expr in ws: - value = eval(expr) - await ws.send(str(value)) +async def handler(ws): + path = ws.request.path + if path == "/": + # The default path is an eval shell. + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + elif path == "/crash": + raise RuntimeError + elif path == "/no-op": + pass + elif path == "/delay": + delay = float(await ws.recv()) + await ws.close() + await asyncio.sleep(delay) + else: + raise AssertionError(f"unexpected path: {path}") class EvalShellMixin: @@ -31,27 +44,13 @@ async def assertEval(self, client, expr, value): self.assertEqual(await client.recv(), value) -async def crash(ws): - raise RuntimeError - - -async def do_nothing(ws): - pass - - -async def keep_running(ws): - delay = float(await ws.recv()) - await ws.close() - await asyncio.sleep(delay) - - @contextlib.asynccontextmanager -async def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): +async def run_server(handler=handler, host="localhost", port=0, **kwargs): async with serve(handler, host, port, **kwargs) as server: yield server @contextlib.asynccontextmanager -async def run_unix_server(path, handler=eval_shell, **kwargs): +async def run_unix_server(path, handler=handler, **kwargs): async with unix_serve(handler, path, **kwargs) as server: yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index a8ef6ef9d..77261129a 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -9,7 +9,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path -from .server import do_nothing, get_host_port, get_uri, run_server, run_unix_server +from .server import get_host_port, get_uri, run_server, run_unix_server class ClientTests(unittest.IsolatedAsyncioTestCase): @@ -111,11 +111,9 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - async with run_server( - do_nothing, process_response=remove_accept_header - ) as server: + async with run_server(process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: - async with connect(get_uri(server), close_timeout=MS): + async with connect(get_uri(server) + "/no-op", close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -131,10 +129,10 @@ async def stall_connection(self, request): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - async with run_server(do_nothing, process_request=stall_connection) as server: + async with run_server(process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - async with connect(get_uri(server), open_timeout=2 * MS): + async with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index a16267439..fd9c69c40 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -25,12 +25,9 @@ ) from .server import ( EvalShellMixin, - crash, - do_nothing, - eval_shell, get_host_port, get_uri, - keep_running, + handler, run_server, run_unix_server, ) @@ -45,8 +42,8 @@ async def test_connection(self): async def test_connection_handler_returns(self): """Connection handler returns.""" - async with run_server(do_nothing) as server: - async with connect(get_uri(server)) as client: + async with run_server() as server: + async with connect(get_uri(server) + "/no-op") as client: with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() self.assertEqual( @@ -56,8 +53,8 @@ async def test_connection_handler_returns(self): async def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" - async with run_server(crash) as server: - async with connect(get_uri(server)) as client: + async with run_server() as server: + async with connect(get_uri(server) + "/crash") as client: with self.assertRaises(ConnectionClosedError) as raised: await client.recv() self.assertEqual( @@ -460,8 +457,8 @@ async def test_close_server_keeps_connections_open(self): async def test_close_server_keeps_handlers_running(self): """Server waits for connection handlers to terminate.""" - async with run_server(keep_running) as server: - async with connect(get_uri(server)) as client: + async with run_server() as server: + async with connect(get_uri(server) + "/delay") as client: # Delay termination of connection handler. await client.send(str(3 * MS)) @@ -528,7 +525,7 @@ class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): async def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: - await unix_serve(eval_shell) + await unix_serve(handler) self.assertEqual( str(raised.exception), "path was not specified, and no sock specified", @@ -539,7 +536,7 @@ async def test_unix_with_path_and_sock(self): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) with self.assertRaises(ValueError) as raised: - await unix_serve(eval_shell, path="/", sock=sock) + await unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), "path and sock can not be specified at the same time", @@ -548,7 +545,7 @@ async def test_unix_with_path_and_sock(self): async def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: - await serve(eval_shell, subprotocols="chat") + await serve(handler, subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", @@ -557,7 +554,7 @@ async def test_invalid_subprotocol(self): async def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: - await serve(eval_shell, compression=False) + await serve(handler, compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", diff --git a/tests/sync/server.py b/tests/sync/server.py index a86cf88ce..114c1545b 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -12,18 +12,19 @@ def get_uri(server): return f"{protocol}://{host}:{port}" -def crash(ws): - raise RuntimeError - - -def do_nothing(ws): - pass - - -def eval_shell(ws): - for expr in ws: - value = eval(expr) - ws.send(str(value)) +def handler(ws): + path = ws.request.path + if path == "/": + # The default path is an eval shell. + for expr in ws: + value = eval(expr) + ws.send(str(value)) + elif path == "/crash": + raise RuntimeError + elif path == "/no-op": + pass + else: + raise AssertionError(f"unexpected path: {path}") class EvalShellMixin: @@ -33,7 +34,7 @@ def assertEval(self, client, expr, value): @contextlib.contextmanager -def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): +def run_server(handler=handler, host="localhost", port=0, **kwargs): with serve(handler, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() @@ -45,7 +46,7 @@ def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): @contextlib.contextmanager -def run_unix_server(path, handler=eval_shell, **kwargs): +def run_unix_server(path, handler=handler, **kwargs): with unix_serve(handler, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 44cbdd6c4..0d5273d12 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -14,7 +14,7 @@ DeprecationTestCase, temp_unix_socket_path, ) -from .server import do_nothing, get_uri, run_server, run_unix_server +from .server import get_uri, run_server, run_unix_server class ClientTests(unittest.TestCase): @@ -101,9 +101,9 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - with run_server(do_nothing, process_response=remove_accept_header) as server: + with run_server(process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: - with connect(get_uri(server), close_timeout=MS): + with connect(get_uri(server) + "/no-op", close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -119,10 +119,10 @@ def stall_connection(self, request): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - with run_server(do_nothing, process_request=stall_connection) as server: + with run_server(process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - with connect(get_uri(server), open_timeout=2 * MS): + with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index e6a17d02f..a4b537c66 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -25,10 +25,8 @@ ) from .server import ( EvalShellMixin, - crash, - do_nothing, - eval_shell, get_uri, + handler, run_server, run_unix_server, ) @@ -43,8 +41,8 @@ def test_connection(self): def test_connection_handler_returns(self): """Connection handler returns.""" - with run_server(do_nothing) as server: - with connect(get_uri(server)) as client: + with run_server() as server: + with connect(get_uri(server) + "/no-op") as client: with self.assertRaises(ConnectionClosedOK) as raised: client.recv() self.assertEqual( @@ -54,8 +52,8 @@ def test_connection_handler_returns(self): def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" - with run_server(crash) as server: - with connect(get_uri(server)) as client: + with run_server() as server: + with connect(get_uri(server) + "/crash") as client: with self.assertRaises(ConnectionClosedError) as raised: client.recv() self.assertEqual( @@ -377,7 +375,7 @@ class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" with self.assertRaises(TypeError) as raised: - unix_serve(eval_shell) + unix_serve(handler) self.assertEqual( str(raised.exception), "missing path argument", @@ -388,7 +386,7 @@ def test_unix_with_path_and_sock(self): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) with self.assertRaises(TypeError) as raised: - unix_serve(eval_shell, path="/", sock=sock) + unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), "path and sock arguments are incompatible", @@ -397,7 +395,7 @@ def test_unix_with_path_and_sock(self): def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: - serve(eval_shell, subprotocols="chat") + serve(handler, subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", @@ -406,7 +404,7 @@ def test_invalid_subprotocol(self): def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: - serve(eval_shell, compression=False) + serve(handler, compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", From 15eb223e52aafad92849cc784837691f47c474cb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 22:06:59 +0200 Subject: [PATCH 1373/1539] Remove run_(unix_)server abstraction from asyncio tests. It wasn't adding much value. Calling serve directly is more obvious. --- tests/asyncio/server.py | 19 ++---- tests/asyncio/test_client.py | 47 +++++++------- tests/asyncio/test_server.py | 116 +++++++++++++++++++---------------- 3 files changed, 90 insertions(+), 92 deletions(-) diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index 097950971..acf6500c6 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -1,9 +1,6 @@ import asyncio -import contextlib import socket -from websockets.asyncio.server import * - def get_host_port(server): for sock in server.sockets: @@ -38,19 +35,11 @@ async def handler(ws): raise AssertionError(f"unexpected path: {path}") +# This shortcut avoids repeating serve(handler, "localhost", 0) for every test. +args = handler, "localhost", 0 + + class EvalShellMixin: async def assertEval(self, client, expr, value): await client.send(expr) self.assertEqual(await client.recv(), value) - - -@contextlib.asynccontextmanager -async def run_server(handler=handler, host="localhost", port=0, **kwargs): - async with serve(handler, host, port, **kwargs) as server: - yield server - - -@contextlib.asynccontextmanager -async def run_unix_server(path, handler=handler, **kwargs): - async with unix_serve(handler, path, **kwargs) as server: - yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 77261129a..53a6eaaf1 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -5,23 +5,24 @@ from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError +from websockets.asyncio.server import serve, unix_serve from websockets.exceptions import InvalidHandshake, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path -from .server import get_host_port, get_uri, run_server, run_unix_server +from .server import args, get_host_port, get_uri, handler class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server and the handshake succeeds.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_existing_socket(self): """Client connects using a pre-existing socket.""" - async with run_server() as server: + async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: # Use a non-existing domain to ensure we connect to the right socket. async with connect("ws://invalid/", sock=sock) as client: @@ -29,7 +30,7 @@ async def test_existing_socket(self): async def test_additional_headers(self): """Client can set additional headers with additional_headers.""" - async with run_server() as server: + async with serve(*args) as server: async with connect( get_uri(server), additional_headers={"Authorization": "Bearer ..."} ) as client: @@ -37,19 +38,19 @@ async def test_additional_headers(self): async def test_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") async def test_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) async def test_compression_is_enabled(self): """Client enables compression by default.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual( [type(ext) for ext in client.protocol.extensions], @@ -58,13 +59,13 @@ async def test_compression_is_enabled(self): async def test_disable_compression(self): """Client disables compression.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), ping_interval=MS) as client: self.assertEqual(client.latency, 0) await asyncio.sleep(2 * MS) @@ -72,7 +73,7 @@ async def test_keepalive_is_enabled(self): async def test_disable_keepalive(self): """Client disables keepalive.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), ping_interval=None) as client: await asyncio.sleep(2 * MS) self.assertEqual(client.latency, 0) @@ -85,7 +86,7 @@ def create_connection(*args, **kwargs): client.create_connection_ran = True return client - async with run_server() as server: + async with serve(*args) as server: async with connect( get_uri(server), create_connection=create_connection ) as client: @@ -111,7 +112,7 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - async with run_server(process_response=remove_accept_header) as server: + async with serve(*args, process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: async with connect(get_uri(server) + "/no-op", close_timeout=MS): self.fail("did not raise") @@ -129,7 +130,7 @@ async def stall_connection(self, request): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - async with run_server(process_request=stall_connection) as server: + async with serve(*args, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: async with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): @@ -147,7 +148,7 @@ async def test_connection_closed_during_handshake(self): def close_connection(self, request): self.close_transport() - async with run_server(process_request=close_connection) as server: + async with serve(*args, process_request=close_connection) as server: with self.assertRaises(ConnectionError) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -160,7 +161,7 @@ def close_connection(self, request): class SecureClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server securely.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") @@ -168,7 +169,7 @@ async def test_connection(self): async def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: host, port = get_host_port(server) async with connect( "wss://overridden/", host=host, port=port, ssl=CLIENT_CONTEXT @@ -178,7 +179,7 @@ async def test_set_server_hostname_implicitly(self): async def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect( get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="overridden" ) as client: @@ -187,7 +188,7 @@ async def test_set_server_hostname_explicitly(self): async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. async with connect(get_uri(server)): @@ -199,7 +200,7 @@ async def test_reject_invalid_server_certificate(self): async def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. async with connect( @@ -217,7 +218,7 @@ class UnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server over a Unix socket.""" with temp_unix_socket_path() as path: - async with run_unix_server(path): + async with unix_serve(handler, path): async with unix_connect(path) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -225,7 +226,7 @@ async def test_set_host_header(self): """Client sets the Host header to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: - async with run_unix_server(path): + async with unix_serve(handler, path): async with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") @@ -235,7 +236,7 @@ class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") @@ -245,7 +246,7 @@ async def test_set_server_hostname(self): """Client sets server_hostname to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect( path, ssl=CLIENT_CONTEXT, diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fd9c69c40..fe0cafe13 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -25,24 +25,23 @@ ) from .server import ( EvalShellMixin, + args, get_host_port, get_uri, handler, - run_server, - run_unix_server, ) class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_connection_handler_returns(self): """Connection handler returns.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server) + "/no-op") as client: with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() @@ -53,7 +52,7 @@ async def test_connection_handler_returns(self): async def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server) + "/crash") as client: with self.assertRaises(ConnectionClosedError) as raised: await client.recv() @@ -66,7 +65,7 @@ async def test_connection_handler_raises_exception(self): async def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: - async with run_server(sock=sock, host=None, port=None): + async with serve(handler, sock=sock, host=None, port=None): uri = "ws://{}:{}/".format(*sock.getsockname()) async with connect(uri) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") @@ -79,7 +78,8 @@ def select_subprotocol(ws, subprotocols): assert "chat" in subprotocols return "chat" - async with run_server( + async with serve( + *args, subprotocols=["chat"], select_subprotocol=select_subprotocol, ) as server: @@ -93,7 +93,7 @@ async def test_select_subprotocol_rejects_handshake(self): def select_subprotocol(ws, subprotocols): raise NegotiationError - async with run_server(select_subprotocol=select_subprotocol) as server: + async with serve(*args, select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -108,7 +108,7 @@ async def test_select_subprotocol_raises_exception(self): def select_subprotocol(ws, subprotocols): raise RuntimeError - async with run_server(select_subprotocol=select_subprotocol) as server: + async with serve(*args, select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -124,7 +124,7 @@ def process_request(ws, request): self.assertIsInstance(request, Request) ws.process_request_ran = True - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") @@ -135,7 +135,7 @@ async def process_request(ws, request): self.assertIsInstance(request, Request) ws.process_request_ran = True - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") @@ -145,7 +145,7 @@ async def test_process_request_returns_response(self): def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -160,7 +160,7 @@ async def test_async_process_request_returns_response(self): async def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -175,7 +175,7 @@ async def test_process_request_raises_exception(self): def process_request(ws, request): raise RuntimeError - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -190,7 +190,7 @@ async def test_async_process_request_raises_exception(self): async def process_request(ws, request): raise RuntimeError - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -207,7 +207,7 @@ def process_response(ws, request, response): self.assertIsInstance(response, Response) ws.process_response_ran = True - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") @@ -219,7 +219,7 @@ async def process_response(ws, request, response): self.assertIsInstance(response, Response) ws.process_response_ran = True - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") @@ -229,7 +229,7 @@ async def test_process_response_modifies_response(self): def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") @@ -239,7 +239,7 @@ async def test_async_process_response_modifies_response(self): async def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") @@ -251,7 +251,7 @@ def process_response(ws, request, response): headers["X-ProcessResponse"] = "OK" return dataclasses.replace(response, headers=headers) - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") @@ -263,7 +263,7 @@ async def process_response(ws, request, response): headers["X-ProcessResponse"] = "OK" return dataclasses.replace(response, headers=headers) - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") @@ -273,7 +273,7 @@ async def test_process_response_raises_exception(self): def process_response(ws, request, response): raise RuntimeError - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -288,7 +288,7 @@ async def test_async_process_response_raises_exception(self): async def process_response(ws, request, response): raise RuntimeError - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -299,13 +299,13 @@ async def process_response(ws, request, response): async def test_override_server(self): """Server can override Server header with server_header.""" - async with run_server(server_header="Neo") as server: + async with serve(*args, server_header="Neo") as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.response.headers['Server']", "Neo") async def test_remove_server(self): """Server can remove Server header with server_header.""" - async with run_server(server_header=None) as server: + async with serve(*args, server_header=None) as server: async with connect(get_uri(server)) as client: await self.assertEval( client, "'Server' in ws.response.headers", "False" @@ -313,7 +313,7 @@ async def test_remove_server(self): async def test_compression_is_enabled(self): """Server enables compression by default.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: await self.assertEval( client, @@ -323,13 +323,13 @@ async def test_compression_is_enabled(self): async def test_disable_compression(self): """Server disables compression.""" - async with run_server(compression=None) as server: + async with serve(*args, compression=None) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.extensions", "[]") async def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" - async with run_server(ping_interval=MS) as server: + async with serve(*args, ping_interval=MS) as server: async with connect(get_uri(server)) as client: await client.send("ws.latency") latency = eval(await client.recv()) @@ -341,7 +341,7 @@ async def test_keepalive_is_enabled(self): async def test_disable_keepalive(self): """Client disables keepalive.""" - async with run_server(ping_interval=None) as server: + async with serve(*args, ping_interval=None) as server: async with connect(get_uri(server)) as client: await asyncio.sleep(2 * MS) await client.send("ws.latency") @@ -351,7 +351,7 @@ async def test_disable_keepalive(self): async def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") - async with run_server(logger=logger) as server: + async with serve(*args, logger=logger) as server: self.assertIs(server.logger, logger) async def test_custom_connection_factory(self): @@ -362,13 +362,13 @@ def create_connection(*args, **kwargs): server.create_connection_ran = True return server - async with run_server(create_connection=create_connection) as server: + async with serve(*args, create_connection=create_connection) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.create_connection_ran", "True") async def test_connections(self): """Server provides a connections property.""" - async with run_server() as server: + async with serve(*args) as server: self.assertEqual(server.connections, set()) async with connect(get_uri(server)) as client: self.assertEqual(len(server.connections), 1) @@ -382,7 +382,7 @@ async def test_handshake_fails(self): def remove_key_header(self, request): del request.headers["Sec-WebSocket-Key"] - async with run_server(process_request=remove_key_header) as server: + async with serve(*args, process_request=remove_key_header) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -393,7 +393,7 @@ def remove_key_header(self, request): async def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" - async with run_server(open_timeout=MS) as server: + async with serve(*args, open_timeout=MS) as server: reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") @@ -402,7 +402,7 @@ async def test_timeout_during_handshake(self): async def test_connection_closed_during_handshake(self): """Server reads EOF before receiving handshake request from client.""" - async with run_server() as server: + async with serve(*args) as server: _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() @@ -413,7 +413,7 @@ async def process_request(ws, _request): while ws.server.is_serving(): await asyncio.sleep(0) # pragma: no cover - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: asyncio.get_running_loop().call_later(MS, server.close) with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): @@ -425,7 +425,7 @@ async def process_request(ws, _request): async def test_close_server_closes_open_connections(self): """Server closes open connections with close code 1001 when closing.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: server.close() with self.assertRaises(ConnectionClosedOK) as raised: @@ -437,7 +437,7 @@ async def test_close_server_closes_open_connections(self): async def test_close_server_keeps_connections_open(self): """Server waits for client to close open connections when closing.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: server.close(close_connections=False) @@ -457,7 +457,7 @@ async def test_close_server_keeps_connections_open(self): async def test_close_server_keeps_handlers_running(self): """Server waits for connection handlers to terminate.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server) + "/delay") as client: # Delay termination of connection handler. await client.send(str(3 * MS)) @@ -479,14 +479,14 @@ async def test_close_server_keeps_handlers_running(self): class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives secure connection from client.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") async def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" - async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + async with serve(*args, ssl=SERVER_CONTEXT, open_timeout=MS) as server: reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") @@ -495,7 +495,7 @@ async def test_timeout_during_tls_handshake(self): async def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() @@ -505,7 +505,7 @@ class UnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client over a Unix socket.""" with temp_unix_socket_path() as path: - async with run_unix_server(path): + async with unix_serve(handler, path): async with unix_connect(path) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") @@ -515,7 +515,7 @@ class SecureUnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") @@ -545,7 +545,7 @@ async def test_unix_with_path_and_sock(self): async def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: - await serve(handler, subprotocols="chat") + await serve(*args, subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", @@ -554,7 +554,7 @@ async def test_invalid_subprotocol(self): async def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: - await serve(handler, compression=False) + await serve(*args, compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", @@ -564,7 +564,8 @@ async def test_unsupported_compression(self): class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: async with connect( @@ -575,7 +576,8 @@ async def test_valid_authorization(self): async def test_missing_authorization(self): """basic_auth rejects client without credentials.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: @@ -588,7 +590,8 @@ async def test_missing_authorization(self): async def test_unsupported_authorization(self): """basic_auth rejects client with unsupported credentials.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: @@ -604,7 +607,8 @@ async def test_unsupported_authorization(self): async def test_authorization_with_unknown_username(self): """basic_auth rejects client with unknown username.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: @@ -620,7 +624,8 @@ async def test_authorization_with_unknown_username(self): async def test_authorization_with_incorrect_password(self): """basic_auth rejects client with incorrect password.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "changeme")), ) as server: with self.assertRaises(InvalidStatus) as raised: @@ -636,7 +641,8 @@ async def test_authorization_with_incorrect_password(self): async def test_list_of_credentials(self): """basic_auth accepts a list of hard coded credentials.""" - async with run_server( + async with serve( + *args, process_request=basic_auth( credentials=[ ("hello", "iloveyou"), @@ -656,7 +662,8 @@ async def test_check_credentials_function(self): def check_credentials(username, password): return hmac.compare_digest(password, "iloveyou") - async with run_server( + async with serve( + *args, process_request=basic_auth(check_credentials=check_credentials), ) as server: async with connect( @@ -671,7 +678,8 @@ async def test_check_credentials_coroutine(self): async def check_credentials(username, password): return hmac.compare_digest(password, "iloveyou") - async with run_server( + async with serve( + *args, process_request=basic_auth(check_credentials=check_credentials), ) as server: async with connect( From 317123119434e22e11c273bf422917bfa0f2a626 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 21:40:33 +0200 Subject: [PATCH 1374/1539] Add tests for the logger argument of clients. --- tests/asyncio/test_client.py | 8 ++++++++ tests/asyncio/test_server.py | 2 +- tests/sync/test_client.py | 8 ++++++++ tests/sync/test_server.py | 2 +- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 53a6eaaf1..15178f8b8 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -1,4 +1,5 @@ import asyncio +import logging import socket import ssl import unittest @@ -78,6 +79,13 @@ async def test_disable_keepalive(self): await asyncio.sleep(2 * MS) self.assertEqual(client.latency, 0) + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with serve(*args) as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + async def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fe0cafe13..ceb0417a7 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -352,7 +352,7 @@ async def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") async with serve(*args, logger=logger) as server: - self.assertIs(server.logger, logger) + self.assertEqual(server.logger.name, logger.name) async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 0d5273d12..812412203 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,3 +1,4 @@ +import logging import socket import ssl import threading @@ -67,6 +68,13 @@ def test_disable_compression(self): with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) + def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + with run_server() as server: + with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index a4b537c66..541a14602 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -238,7 +238,7 @@ def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) + self.assertEqual(server.logger.name, logger.name) def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" From 04ac475805f20cf73d79bb3fd003e412c8526396 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 21:40:47 +0200 Subject: [PATCH 1375/1539] Fix copy-paste errors in test docstrings. --- tests/asyncio/test_server.py | 2 +- tests/sync/test_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index ceb0417a7..bc1d0444c 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -340,7 +340,7 @@ async def test_keepalive_is_enabled(self): self.assertGreater(latency, 0) async def test_disable_keepalive(self): - """Client disables keepalive.""" + """Server disables keepalive.""" async with serve(*args, ping_interval=None) as server: async with connect(get_uri(server)) as client: await asyncio.sleep(2 * MS) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 541a14602..7bcf144a2 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -557,7 +557,7 @@ def test_bad_list_of_credentials(self): class BackwardsCompatibilityTests(DeprecationTestCase): def test_ssl_context_argument(self): - """Client supports the deprecated ssl_context argument.""" + """Server supports the deprecated ssl_context argument.""" with self.assertDeprecationWarning("ssl_context was renamed to ssl"): with run_server(ssl_context=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT): From 033bf40c692029e3acaf2bb3adc946e7f8cf15ac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 08:06:30 +0200 Subject: [PATCH 1376/1539] Add a generator of backoff delays. --- docs/reference/variables.rst | 39 +++++++++++++++++++++++++++++---- src/websockets/client.py | 32 +++++++++++++++++++++++++++ src/websockets/legacy/client.py | 11 +++++----- tests/test_client.py | 13 +++++++++++ 4 files changed, 86 insertions(+), 9 deletions(-) diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index 4bca112da..498132251 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -13,7 +13,7 @@ Logging See the :doc:`logging guide <../topics/logging>` for details. Security -........ +-------- .. envvar:: WEBSOCKETS_SERVER @@ -31,18 +31,49 @@ Security Maximum length of the request or status line in the opening handshake. - The default value is ``8192``. + The default value is ``8192`` bytes. .. envvar:: WEBSOCKETS_MAX_NUM_HEADERS Maximum number of HTTP headers in the opening handshake. - The default value is ``128``. + The default value is ``128`` bytes. .. envvar:: WEBSOCKETS_MAX_BODY_SIZE Maximum size of the body of an HTTP response in the opening handshake. - The default value is ``1_048_576`` (1 MiB). + The default value is ``1_048_576`` bytes (1 MiB). See the :doc:`security guide <../topics/security>` for details. + +Reconnection +------------ + +Reconnection attempts are spaced out with truncated exponential backoff. + +.. envvar:: BACKOFF_INITIAL_DELAY + + The first attempt is delayed by a random amount of time between ``0`` and + ``BACKOFF_INITIAL_DELAY`` seconds. + + The default value is ``5.0`` seconds. + +.. envvar:: BACKOFF_MIN_DELAY + + The second attempt is delayed by ``BACKOFF_MIN_DELAY`` seconds. + + The default value is ``3.1`` seconds. + +.. envvar:: BACKOFF_FACTOR + + After the second attempt, the delay is multiplied by ``BACKOFF_FACTOR`` + between each attempt. + + The default value is ``1.618``. + +.. envvar:: BACKOFF_MAX_DELAY + + The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds. + + The default value is ``90.0`` seconds. diff --git a/src/websockets/client.py b/src/websockets/client.py index ae467993a..95de99dc5 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os +import random import warnings from typing import Any, Generator, Sequence @@ -357,3 +359,33 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: DeprecationWarning, ) super().__init__(*args, **kwargs) + + +BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) +BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) +BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) +BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) + + +def backoff( + initial_delay: float = BACKOFF_INITIAL_DELAY, + min_delay: float = BACKOFF_MIN_DELAY, + max_delay: float = BACKOFF_MAX_DELAY, + factor: float = BACKOFF_FACTOR, +) -> Generator[float, None, None]: + """ + Generate a series of backoff delays between reconnection attempts. + + Yields: + How many seconds to wait before retrying to connect. + + """ + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. + yield random.random() * initial_delay + delay = min_delay + while delay < max_delay: + yield delay + delay *= factor + while True: + yield max_delay diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 25142ea25..a1bc5cbae 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -3,6 +3,7 @@ import asyncio import functools import logging +import os import random import urllib.parse import warnings @@ -591,13 +592,13 @@ def handle_redirect(self, uri: str) -> None: # async for ... in connect(...): - BACKOFF_MIN = 1.92 - BACKOFF_MAX = 60.0 - BACKOFF_FACTOR = 1.618 - BACKOFF_INITIAL = 5 + BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) + BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) + BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) + BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: - backoff_delay = self.BACKOFF_MIN + backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR while True: try: async with self as protocol: diff --git a/tests/test_client.py b/tests/test_client.py index d798a66f9..8b3bf4232 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,6 +3,7 @@ import unittest.mock from websockets.client import * +from websockets.client import backoff from websockets.datastructures import Headers from websockets.exceptions import InvalidHandshake, InvalidHeader from websockets.frames import OP_TEXT, Frame @@ -613,3 +614,15 @@ def test_client_connection_class(self): client = ClientConnection("ws://localhost/") self.assertIsInstance(client, ClientProtocol) + + +class BackoffTests(unittest.TestCase): + def test_backoff(self): + backoff_gen = backoff() + + initial_delay = next(backoff_gen) + self.assertGreaterEqual(initial_delay, 0) + self.assertLess(initial_delay, 5) + + following_delays = [int(next(backoff_gen)) for _ in range(9)] + self.assertEqual(following_delays, [3, 5, 8, 13, 21, 34, 55, 89, 90]) From 5ada5f776be2d1693f56a13de781d4b1b13f4ecf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 08:22:58 +0200 Subject: [PATCH 1377/1539] Reorder functions logically in ClientConnection. Define callers after callees. --- src/websockets/asyncio/client.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 033887e87..170ae5912 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -295,19 +295,6 @@ def factory() -> ClientConnection: self._open_timeout = open_timeout - # async with connect(...) as ...: ... - - async def __aenter__(self) -> ClientConnection: - return await self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - await self.connection.close() - # ... = await connect(...) def __await__(self) -> Generator[Any, None, ClientConnection]: @@ -333,6 +320,19 @@ async def __await_impl__(self) -> ClientConnection: __iter__ = __await__ + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.connection.close() + def unix_connect( path: str | None = None, From 6ffb6b057caeb1d1d35b60a1a82afa7629694f82 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 08:28:53 +0200 Subject: [PATCH 1378/1539] Remove _-prefixed attributes. Overall websockets doesn't use this convention for private attributes. --- src/websockets/asyncio/client.py | 14 +++++++------- src/websockets/asyncio/server.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 170ae5912..f669359a5 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -281,19 +281,19 @@ def factory() -> ClientConnection: loop = asyncio.get_running_loop() if kwargs.pop("unix", False): - self._create_connection = loop.create_unix_connection(factory, **kwargs) + self.create_connection = loop.create_unix_connection(factory, **kwargs) else: if kwargs.get("sock") is None: kwargs.setdefault("host", wsuri.host) kwargs.setdefault("port", wsuri.port) - self._create_connection = loop.create_connection(factory, **kwargs) + self.create_connection = loop.create_connection(factory, **kwargs) - self._handshake_args = ( + self.handshake_args = ( additional_headers, user_agent_header, ) - self._open_timeout = open_timeout + self.open_timeout = open_timeout # ... = await connect(...) @@ -303,10 +303,10 @@ def __await__(self) -> Generator[Any, None, ClientConnection]: async def __await_impl__(self) -> ClientConnection: try: - async with asyncio_timeout(self._open_timeout): - _transport, self.connection = await self._create_connection + async with asyncio_timeout(self.open_timeout): + _transport, self.connection = await self.create_connection try: - await self.connection.handshake(*self._handshake_args) + await self.connection.handshake(*self.handshake_args) except (Exception, asyncio.CancelledError): self.connection.transport.close() raise diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index f24281252..5e71a892b 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -796,10 +796,10 @@ def protocol_select_subprotocol( loop = asyncio.get_running_loop() if kwargs.pop("unix", False): - self._create_server = loop.create_unix_server(factory, **kwargs) + self.create_server = loop.create_unix_server(factory, **kwargs) else: # mypy cannot tell that kwargs must provide sock when port is None. - self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] + self.create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] # async with serve(...) as ...: ... @@ -822,7 +822,7 @@ def __await__(self) -> Generator[Any, None, Server]: return self.__await_impl__().__await__() async def __await_impl__(self) -> Server: - server = await self._create_server + server = await self.create_server self.server.wrap(server) return self.server From 16456e2bec3ad324d545c8eb5ea9d1eed583dea4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 22:17:39 +0200 Subject: [PATCH 1379/1539] Restore id-token permission. Removing it in ed2f21e3 was a mistake. --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a68714ed1..0a30c049b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -71,6 +71,7 @@ jobs: # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') permissions: + id-token: write contents: write steps: - name: Download artifacts From 62d70f4a8fb5e30d8744b1378659cc999deb0715 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 22:06:53 +0200 Subject: [PATCH 1380/1539] Restore speedups.c in source distribution. Fix #1494. --- MANIFEST.in | 1 + docs/project/changelog.rst | 10 ++++++++++ src/websockets/version.py | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index 1c660b95b..d4598bda0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include LICENSE include src/websockets/py.typed +include src/websockets/speedups.c # required when BUILD_EXTENSION=no diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7c5998288..2292c5aa1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,16 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +13.0.1 +------ + +*August 28, 2024* + +Bug fixes +......... + +* Restored the C extension in the source distribution. + .. _13.0: 13.0 diff --git a/src/websockets/version.py b/src/websockets/version.py index 56c321940..e82b285db 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -20,7 +20,7 @@ released = True -tag = version = commit = "13.0" +tag = version = commit = "13.0.1" if not released: # pragma: no cover From 157f7908c33cb540f7cf76568dde8aa6cf400504 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 22:22:02 +0200 Subject: [PATCH 1381/1539] Add provenance attestations. --- .github/workflows/release.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0a30c049b..863d88aa9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -72,6 +72,7 @@ jobs: if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') permissions: id-token: write + attestations: write contents: write steps: - name: Download artifacts @@ -80,6 +81,10 @@ jobs: pattern: dist-* merge-multiple: true path: dist + - name: Attest provenance + uses: actions/attest-build-provenance@v1 + with: + subject-path: dist/* - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - name: Create GitHub release From d8ab09b2566cdfcb7a0a95dc992454bb9f4b1647 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 14:59:28 +0200 Subject: [PATCH 1382/1539] Add automatic reconnection to the new asyncio implementation. Missing tests for now. Fix #1480. --- docs/howto/upgrade.rst | 33 ++++---- docs/project/changelog.rst | 8 +- docs/reference/asyncio/client.rst | 2 + src/websockets/asyncio/client.py | 123 ++++++++++++++++++++++++++++-- tests/asyncio/test_client.py | 120 ++++++++++++++++++++++++++++- 5 files changed, 263 insertions(+), 23 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 18c3cc127..42edb978a 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -79,19 +79,6 @@ Following redirects The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP redirects yet. -Automatic reconnection -...................... - -The new implementation of :func:`~asyncio.client.connect` doesn't provide -automatic reconnection yet. - -In other words, the following pattern isn't supported:: - - from websockets.asyncio.client import connect - - async for websocket in connect(...): # this doesn't work yet - ... - .. _Update import paths: Import paths @@ -185,6 +172,26 @@ it simpler. ``process_response`` replaces ``extra_headers`` and provides more flexibility. See process_request_, select_subprotocol_, and process_response_ below. +Customizing automatic reconnection +.................................. + +On the client side, if you're reconnecting automatically with ``async for ... in +connect(...)``, the behavior when a connection attempt fails was enhanced and +made configurable. + +The original implementation retried on any error. The new implementation uses an +heuristic to determine whether an error is retryable or fatal. By default, only +network errors and server errors (HTTP 500, 502, 503, or 504) are considered +retryable. You can customize this behavior with the ``process_exception`` +argument of :func:`~asyncio.client.connect`. + +See :func:`~asyncio.client.process_exception` for more information. + +Here's how to revert to the behavior of the original implementation:: + + async for ... in connect(..., process_exception=lambda exc: exc): + ... + Tracking open connections ......................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c239cfe5d..78a547d0b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -46,12 +46,16 @@ Backwards-incompatible changes New features ............ -* Made the set of active connections available in the :attr:`Server.connections - ` property. +* Added support for reconnecting automatically by using + :func:`~asyncio.client.connect` as an asynchronous iterator to the new + :mod:`asyncio` implementation. * Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` implementations of servers. +* Made the set of active connections available in the :attr:`Server.connections + ` property. + .. _13.0: 13.0 diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index 77a3c5d53..e2b0ff550 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -12,6 +12,8 @@ Opening a connection .. autofunction:: unix_connect :async: +.. autofunction:: process_exception + Using a connection ------------------ diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index f669359a5..860db3238 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -1,11 +1,14 @@ from __future__ import annotations import asyncio +import functools +import logging from types import TracebackType -from typing import Any, Generator, Sequence +from typing import Any, AsyncIterator, Callable, Generator, Sequence -from ..client import ClientProtocol +from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike +from ..exceptions import InvalidStatus from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None: self.response_rcvd.set_result(None) +def process_exception(exc: Exception) -> Exception | None: + """ + Determine whether an error is retryable or fatal. + + When reconnecting automatically with ``async for ... in connect(...)``, if a + connection attempt fails, :func:`process_exception` is called to determine + whether to retry connecting or to raise the exception. + + This function defines the default behavior, which is to retry on: + + * :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors; + * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500, + 502, 503, or 504: server or proxy errors. + + All other exceptions are considered fatal. + + You can change this behavior with the ``process_exception`` argument of + :func:`connect`. + + Return :obj:`None` if the exception is retryable i.e. when the error could + be transient and trying to reconnect with the same parameters could succeed. + The exception will be logged at the ``INFO`` level. + + Return an exception, either ``exc`` or a new exception, if the exception is + fatal i.e. when trying to reconnect will most likely produce the same error. + That exception will be raised, breaking out of the retry loop. + + """ + if isinstance(exc, (OSError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidStatus) and exc.response.status_code in [ + 500, # Internal Server Error + 502, # Bad Gateway + 503, # Service Unavailable + 504, # Gateway Timeout + ]: + return None + return exc + + # This is spelled in lower case because it's exposed as a callable in the API. class connect: """ @@ -138,6 +181,21 @@ class connect: The connection is closed automatically when exiting the context. + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + try: + ... + except websockets.ConnectionClosed: + continue + + If the connection fails with a transient error, it is retried with + exponential backoff. If it fails with a fatal error, the exception is + raised, breaking out of the loop. + + The connection is closed automatically after each iteration of the loop. + Args: uri: URI of the WebSocket server. origin: Value of the ``Origin`` header, for servers that require it. @@ -153,6 +211,9 @@ class connect: compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. + process_exception: When reconnecting automatically, tell whether an + error is transient or fatal. The default behavior is defined by + :func:`process_exception`. Refer to its documentation for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -219,6 +280,7 @@ def __init__( additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, compression: str | None = "deflate", + process_exception: Callable[[Exception], Exception | None] = process_exception, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -281,19 +343,26 @@ def factory() -> ClientConnection: loop = asyncio.get_running_loop() if kwargs.pop("unix", False): - self.create_connection = loop.create_unix_connection(factory, **kwargs) + self.create_connection = functools.partial( + loop.create_unix_connection, factory, **kwargs + ) else: if kwargs.get("sock") is None: kwargs.setdefault("host", wsuri.host) kwargs.setdefault("port", wsuri.port) - self.create_connection = loop.create_connection(factory, **kwargs) + self.create_connection = functools.partial( + loop.create_connection, factory, **kwargs + ) self.handshake_args = ( additional_headers, user_agent_header, ) - + self.process_exception = process_exception self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.client") + self.logger = logger # ... = await connect(...) @@ -304,7 +373,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]: async def __await_impl__(self) -> ClientConnection: try: async with asyncio_timeout(self.open_timeout): - _transport, self.connection = await self.create_connection + _transport, self.connection = await self.create_connection() try: await self.connection.handshake(*self.handshake_args) except (Exception, asyncio.CancelledError): @@ -333,6 +402,48 @@ async def __aexit__( ) -> None: await self.connection.close() + # async for ... in connect(...): + + async def __aiter__(self) -> AsyncIterator[ClientConnection]: + delays: Generator[float, None, None] | None = None + while True: + try: + async with self as protocol: + yield protocol + except Exception as exc: + # Determine whether the exception is retryable or fatal. + # The API of process_exception is "return an exception or None"; + # "raise an exception" is also supported because it's a frequent + # mistake. It isn't documented in order to keep the API simple. + try: + new_exc = self.process_exception(exc) + except Exception as raised_exc: + new_exc = raised_exc + + # The connection failed with a fatal error. + # Raise the exception and exit the loop. + if new_exc is exc: + raise + if new_exc is not None: + raise new_exc from exc + + # The connection failed with a retryable error. + # Start or continue backoff and reconnect. + if delays is None: + delays = backoff() + delay = next(delays) + self.logger.info( + "! connect failed; reconnecting in %.1f seconds", + delay, + exc_info=True, + ) + await asyncio.sleep(delay) + continue + + else: + # The connection succeeded. Reset backoff. + delays = None + def unix_connect( path: str | None = None, diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 15178f8b8..7467d215a 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -1,4 +1,6 @@ import asyncio +import contextlib +import http import logging import socket import ssl @@ -7,20 +9,134 @@ from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError from websockets.asyncio.server import serve, unix_serve -from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.client import backoff +from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path from .server import args, get_host_port, get_uri, handler +# Decorate tests that need it with @short_backoff_delay() instead of using it as +# a context manager when dropping support for Python < 3.10. +@contextlib.asynccontextmanager +async def short_backoff_delay(): + defaults = backoff.__defaults__ + backoff.__defaults__ = ( + defaults[0] * MS, + defaults[1] * MS, + defaults[2] * MS, + defaults[3], + ) + try: + yield + finally: + backoff.__defaults__ = defaults + + class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): - """Client connects to server and the handshake succeeds.""" + """Client connects to server.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") + async def test_reconnection(self): + """Client reconnects to server.""" + iterations = 0 + successful = 0 + + def process_request(connection, request): + nonlocal iterations + iterations += 1 + # Retriable errors + if iterations == 1: + connection.transport.close() + elif iterations == 2: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + # Fatal error + elif iterations == 5: + return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") + + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with short_backoff_delay(): + async for client in connect(get_uri(server)): + self.assertEqual(client.protocol.state.name, "OPEN") + successful += 1 + + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 402", + ) + self.assertEqual(iterations, 5) + self.assertEqual(successful, 2) + + @unittest.skipUnless( + hasattr(http.HTTPStatus, "IM_A_TEAPOT"), + "test requires Python 3.9", + ) + async def test_reconnection_with_custom_process_exception(self): + """Client runs process_exception to tell if errors are retryable or fatal.""" + iteration = 0 + + def process_request(connection, request): + nonlocal iteration + iteration += 1 + if iteration == 1: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus): + if 500 <= exc.response.status_code < 600: + return None + if exc.response.status_code == 418: + return Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async with short_backoff_delay(): + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual(iteration, 2) + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + + @unittest.skipUnless( + hasattr(http.HTTPStatus, "IM_A_TEAPOT"), + "test requires Python 3.9", + ) + async def test_reconnection_with_custom_process_exception_raising_exception(self): + """Client supports raising an exception in process_exception.""" + + def process_request(connection, request): + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus) and exc.response.status_code == 418: + raise Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async with short_backoff_delay(): + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + async def test_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: From 032463e4b9a27cb349c67e21d10ad8a34f22bb59 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 22:56:43 +0200 Subject: [PATCH 1383/1539] Prevent a warning in twine upload. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index ae0aaa65d..7ea3a5e5f 100644 --- a/setup.py +++ b/setup.py @@ -34,5 +34,6 @@ setuptools.setup( version=version, long_description=long_description, + long_description_content_type="text/x-rst", ext_modules=ext_modules, ) From 5be975e440e00dafc32c303975b12aaaf723c5e8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 23:12:02 +0200 Subject: [PATCH 1384/1539] Make make build the C extension by default. --- Makefile | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index a69248b6e..3eb3ae6b2 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,8 @@ export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src export PYTHONWARNINGS=default -default: style types tests +build: + python setup.py build_ext --inplace style: black src tests @@ -26,9 +27,6 @@ maxi_cov: coverage html coverage report --show-missing --fail-under=100 -build: - python setup.py build_ext --inplace - clean: find src -name '*.so' -delete find . -name '*.pyc' -delete From 2b990e88b97d385d72d91c8354e8c9ed4a1f04c4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 23:12:37 +0200 Subject: [PATCH 1385/1539] Update FAQ after implementing reconnection. --- docs/faq/client.rst | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 0dfc84253..0a7aab6e2 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -81,15 +81,9 @@ The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- -.. admonition:: This feature is only supported by the legacy :mod:`asyncio` - implementation. - :class: warning +Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: - It will be added to the new :mod:`asyncio` implementation soon. - -Use :func:`~websockets.legacy.client.connect` as an asynchronous iterator:: - - from websockets.legacy.client import connect + from websockets.asyncio.client import connect async for websocket in connect(...): try: From f842a13a234c61cc0c75743db54e698474e58c05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 21:52:19 +0200 Subject: [PATCH 1386/1539] Prevent false positives with latest ruff. --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index c1d34c90b..fde9c3226 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,9 @@ exclude_lines = [ "@unittest.skip", ] +[tool.ruff] +target-version = "py312" + [tool.ruff.lint] select = [ "E", # pycodestyle From 1f89db78ca6363a7e4d9830b6b280e311e3934f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 21:59:44 +0200 Subject: [PATCH 1387/1539] Switch from black to ruff for code formatting. It's faster and achieves pretty much the same result. --- .github/workflows/tests.yml | 4 +--- Makefile | 2 +- src/websockets/asyncio/client.py | 1 - src/websockets/asyncio/server.py | 1 - src/websockets/legacy/server.py | 11 +++++++---- tests/legacy/test_client_server.py | 1 - tox.ini | 15 +++++++-------- 7 files changed, 16 insertions(+), 19 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b9172b7fb..43193ea50 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -41,9 +41,7 @@ jobs: python-version: "3.x" - name: Install tox run: pip install tox - - name: Check code formatting - run: tox -e black - - name: Check code style + - name: Check code formatting & style run: tox -e ruff - name: Check types statically run: tox -e mypy diff --git a/Makefile b/Makefile index 3eb3ae6b2..fd36d0367 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ build: python setup.py build_ext --inplace style: - black src tests + ruff format src tests ruff check --fix src tests types: diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 860db3238..5f7a37198 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -297,7 +297,6 @@ def __init__( # Other keyword arguments are passed to loop.create_connection **kwargs: Any, ) -> None: - wsuri = parse_uri(uri) if wsuri.secure: diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 5e71a892b..1aa47af88 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -722,7 +722,6 @@ def __init__( # Other keyword arguments are passed to loop.create_server **kwargs: Any, ) -> None: - if subprotocols is not None: validate_subprotocols(subprotocols) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index b31cc25b8..2cb9b1abb 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -100,9 +100,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, + # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] - | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), ws_server: WebSocketServer, *, @@ -983,9 +984,10 @@ class Serve: def __init__( self, + # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] - | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), host: str | Sequence[str] | None = None, port: int | None = None, @@ -1140,9 +1142,10 @@ async def __await_impl__(self) -> WebSocketServer: def unix_serve( + # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] - | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), path: str | None = None, **kwargs: Any, @@ -1169,7 +1172,7 @@ def remove_path_argument( ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] - ) + ), ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: try: inspect.signature(ws_handler).bind(None) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 0c5d66c92..2f3ba9b77 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1293,7 +1293,6 @@ def test_connection_error_during_closing_handshake(self, close): class ClientServerTests( CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase ): - def test_redirect_secure(self): with temp_test_redirecting_server(self): # websockets doesn't support serving non-TLS and TLS connections diff --git a/tox.ini b/tox.ini index 16d9c9f16..cba9b290b 100644 --- a/tox.ini +++ b/tox.ini @@ -7,12 +7,12 @@ env_list = py312 py313 coverage - black ruff mypy [testenv] -commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} +commands = + python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} pass_env = WEBSOCKETS_* [testenv:coverage] @@ -27,14 +27,13 @@ commands = python -m coverage report --show-missing --fail-under=100 deps = coverage -[testenv:black] -commands = black --check src tests -deps = black - [testenv:ruff] -commands = ruff check src tests +commands = + ruff format --check src tests + ruff check src tests deps = ruff [testenv:mypy] -commands = mypy --strict src +commands = + mypy --strict src deps = mypy From 6b2f06083573be750ac70c894ae13d87c03ef624 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 21:40:07 +0200 Subject: [PATCH 1388/1539] Follow redirects in the new asyncio implementation. Fix #631. --- docs/howto/upgrade.rst | 27 +-- docs/project/changelog.rst | 3 + docs/reference/features.rst | 2 +- docs/reference/variables.rst | 11 ++ src/websockets/asyncio/client.py | 169 ++++++++++++---- src/websockets/legacy/client.py | 2 +- tests/asyncio/test_client.py | 319 ++++++++++++++++++++++++------- 7 files changed, 403 insertions(+), 130 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 42edb978a..120509c9f 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -10,15 +10,13 @@ It provides a very similar API. However, there are a few differences. The recommended upgrade process is: -1. Make sure that your application doesn't use any `deprecated APIs`_. If it +#. Make sure that your application doesn't use any `deprecated APIs`_. If it doesn't raise any warnings, you can skip this step. -2. Check if your application depends on `missing features`_. If it does, you - should stick to the original implementation until they're added. -3. `Update import paths`_. For straightforward usage of websockets, this could +#. `Update import paths`_. For straightforward usage of websockets, this could be the only step you need to take. Upgrading could be transparent. -4. Check out `new features and improvements`_ and consider taking advantage of +#. Check out `new features and improvements`_ and consider taking advantage of them to improve your application. -5. Review `API changes`_ and adapt your application to preserve its current +#. Review `API changes`_ and adapt your application to preserve its current functionality. In the interest of brevity, only :func:`~asyncio.client.connect` and @@ -62,23 +60,6 @@ the release notes of the version in which the feature was deprecated. * The ``host``, ``port``, and ``secure`` attributes of connections — deprecated in :ref:`8.0`. -.. _missing features: - -Missing features ----------------- - -.. admonition:: All features listed below will be provided in a future release. - :class: tip - - If your application relies on one of them, you should stick to the original - implementation until the new implementation supports it in a future release. - -Following redirects -................... - -The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP -redirects yet. - .. _Update import paths: Import paths diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 61113fb81..c77876a73 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -50,6 +50,9 @@ New features :func:`~asyncio.client.connect` as an asynchronous iterator to the new :mod:`asyncio` implementation. +* :func:`~asyncio.client.connect` now follows redirects in the new + :mod:`asyncio` implementation. + * Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` implementations of servers. diff --git a/docs/reference/features.rst b/docs/reference/features.rst index d9941e408..32fc05baf 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -154,7 +154,7 @@ Client +------------------------------------+--------+--------+--------+--------+ | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Follow HTTP redirects | ❌ | ❌ | — | ✅ | + | Follow HTTP redirects | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index 498132251..b766e02a1 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -1,6 +1,8 @@ Environment variables ===================== +.. currentmodule:: websockets + Logging ------- @@ -77,3 +79,12 @@ Reconnection attempts are spaced out with truncated exponential backoff. The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds. The default value is ``90.0`` seconds. + +Redirections +------------ + +.. envvar:: WEBSOCKETS_MAX_REDIRECTS + + Maximum number of redirects that :func:`~asyncio.client.connect` follows. + + The default value is ``10``. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 5f7a37198..50f67b95f 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -1,27 +1,30 @@ from __future__ import annotations import asyncio -import functools import logging +import os +import urllib.parse from types import TracebackType from typing import Any, AsyncIterator, Callable, Generator, Sequence from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike -from ..exceptions import InvalidStatus +from ..exceptions import InvalidStatus, SecurityError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import parse_uri +from ..uri import WebSocketURI, parse_uri from .compatibility import TimeoutError, asyncio_timeout from .connection import Connection __all__ = ["connect", "unix_connect", "ClientConnection"] +MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) + class ClientConnection(Connection): """ @@ -126,7 +129,7 @@ def connection_lost(self, exc: Exception | None) -> None: def process_exception(exc: Exception) -> Exception | None: """ - Determine whether an error is retryable or fatal. + Determine whether a connection error is retryable or fatal. When reconnecting automatically with ``async for ... in connect(...)``, if a connection attempt fails, :func:`process_exception` is called to determine @@ -297,16 +300,7 @@ def __init__( # Other keyword arguments are passed to loop.create_connection **kwargs: Any, ) -> None: - wsuri = parse_uri(uri) - - if wsuri.secure: - kwargs.setdefault("ssl", True) - kwargs.setdefault("server_hostname", wsuri.host) - if kwargs.get("ssl") is None: - raise TypeError("ssl=None is incompatible with a wss:// URI") - else: - if kwargs.get("ssl") is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + self.uri = uri if subprotocols is not None: validate_subprotocols(subprotocols) @@ -316,10 +310,13 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + if logger is None: + logger = logging.getLogger("websockets.client") + if create_connection is None: create_connection = ClientConnection - def factory() -> ClientConnection: + def protocol_factory(wsuri: WebSocketURI) -> ClientConnection: # This is a protocol in the Sans-I/O implementation of websockets. protocol = ClientProtocol( wsuri, @@ -340,28 +337,104 @@ def factory() -> ClientConnection: ) return connection + self.protocol_factory = protocol_factory + self.handshake_args = ( + additional_headers, + user_agent_header, + ) + self.process_exception = process_exception + self.open_timeout = open_timeout + self.logger = logger + self.connection_kwargs = kwargs + + async def create_connection(self) -> ClientConnection: + """Create TCP or Unix connection.""" loop = asyncio.get_running_loop() + + wsuri = parse_uri(self.uri) + kwargs = self.connection_kwargs.copy() + + def factory() -> ClientConnection: + return self.protocol_factory(wsuri) + + if wsuri.secure: + kwargs.setdefault("ssl", True) + kwargs.setdefault("server_hostname", wsuri.host) + if kwargs.get("ssl") is None: + raise TypeError("ssl=None is incompatible with a wss:// URI") + else: + if kwargs.get("ssl") is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") + if kwargs.pop("unix", False): - self.create_connection = functools.partial( - loop.create_unix_connection, factory, **kwargs - ) + _, connection = await loop.create_unix_connection(factory, **kwargs) else: if kwargs.get("sock") is None: kwargs.setdefault("host", wsuri.host) kwargs.setdefault("port", wsuri.port) - self.create_connection = functools.partial( - loop.create_connection, factory, **kwargs + _, connection = await loop.create_connection(factory, **kwargs) + return connection + + def process_redirect(self, exc: Exception) -> Exception | str: + """ + Determine whether a connection error is a redirect that can be followed. + + Return the new URI if it's a valid redirect. Else, return an exception. + + """ + if not ( + isinstance(exc, InvalidStatus) + and exc.response.status_code + in [ + 300, # Multiple Choices + 301, # Moved Permanently + 302, # Found + 303, # See Other + 307, # Temporary Redirect + 308, # Permanent Redirect + ] + and "Location" in exc.response.headers + ): + return exc + + old_wsuri = parse_uri(self.uri) + new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) + new_wsuri = parse_uri(new_uri) + + # If connect() received a socket, it is closed and cannot be reused. + if self.connection_kwargs.get("sock") is not None: + return ValueError( + f"cannot follow redirect to {new_uri} with a preexisting socket" ) - self.handshake_args = ( - additional_headers, - user_agent_header, - ) - self.process_exception = process_exception - self.open_timeout = open_timeout - if logger is None: - logger = logging.getLogger("websockets.client") - self.logger = logger + # TLS downgrade is forbidden. + if old_wsuri.secure and not new_wsuri.secure: + return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") + + # Apply restrictions to cross-origin redirects. + if ( + old_wsuri.secure != new_wsuri.secure + or old_wsuri.host != new_wsuri.host + or old_wsuri.port != new_wsuri.port + ): + # Cross-origin redirects on Unix sockets don't quite make sense. + if self.connection_kwargs.get("unix", False): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with a Unix socket" + ) + + # Cross-origin redirects when host and port are overridden are ill-defined. + if ( + self.connection_kwargs.get("host") is not None + or self.connection_kwargs.get("port") is not None + ): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with an explicit host or port" + ) + + return new_uri # ... = await connect(...) @@ -372,14 +445,38 @@ def __await__(self) -> Generator[Any, None, ClientConnection]: async def __await_impl__(self) -> ClientConnection: try: async with asyncio_timeout(self.open_timeout): - _transport, self.connection = await self.create_connection() - try: - await self.connection.handshake(*self.handshake_args) - except (Exception, asyncio.CancelledError): - self.connection.transport.close() - raise + for _ in range(MAX_REDIRECTS): + self.connection = await self.create_connection() + try: + await self.connection.handshake(*self.handshake_args) + except asyncio.CancelledError: + self.connection.transport.close() + raise + except Exception as exc: + # Always close the connection even though keep-alive is + # the default in HTTP/1.1 because create_connection ties + # opening the network connection with initializing the + # protocol. In the current design of connect(), there is + # no easy way to reuse the network connection that works + # in every case nor to reinitialize the protocol. + self.connection.transport.close() + + uri_or_exc = self.process_redirect(exc) + # Response is a valid redirect; follow it. + if isinstance(uri_or_exc, str): + self.uri = uri_or_exc + continue + # Response isn't a valid redirect; raise the exception. + if uri_or_exc is exc: + raise + else: + raise uri_or_exc from exc + + else: + return self.connection else: - return self.connection + raise SecurityError(f"more than {MAX_REDIRECTS} redirects") + except TimeoutError: # Re-raise exception with an informative error message. raise TimeoutError("timed out during handshake") from None diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a1bc5cbae..ec4c2ff64 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -418,7 +418,7 @@ class Connect: """ - MAX_REDIRECTS_ALLOWED = 10 + MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) def __init__( self, diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 7467d215a..b0487552e 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -10,7 +10,12 @@ from websockets.asyncio.compatibility import TimeoutError from websockets.asyncio.server import serve, unix_serve from websockets.client import backoff -from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI +from websockets.exceptions import ( + InvalidHandshake, + InvalidStatus, + InvalidURI, + SecurityError, +) from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path @@ -34,6 +39,20 @@ async def short_backoff_delay(): backoff.__defaults__ = defaults +# Decorate tests that need it with @few_redirects() instead of using it as a +# context manager when dropping support for Python < 3.10. +@contextlib.asynccontextmanager +async def few_redirects(): + from websockets.asyncio import client + + max_redirects = client.MAX_REDIRECTS + client.MAX_REDIRECTS = 2 + try: + yield + finally: + client.MAX_REDIRECTS = max_redirects + + class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server.""" @@ -41,7 +60,93 @@ async def test_connection(self): async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - async def test_reconnection(self): + async def test_explicit_host_port(self): + """Client connects using an explicit host / port.""" + async with serve(*args) as server: + host, port = get_host_port(server) + async with connect("ws://overridden/", host=host, port=port) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with serve(*args) as server: + async with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with serve(*args) as server: + async with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with serve(*args) as server: + async with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await asyncio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with serve(*args) as server: + async with connect(get_uri(server), ping_interval=None) as client: + await asyncio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with serve(*args) as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with serve(*args) as server: + async with connect( + get_uri(server), create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + async def test_reconnect(self): """Client reconnects to server.""" iterations = 0 successful = 0 @@ -76,7 +181,7 @@ def process_request(connection, request): hasattr(http.HTTPStatus, "IM_A_TEAPOT"), "test requires Python 3.9", ) - async def test_reconnection_with_custom_process_exception(self): + async def test_reconnect_with_custom_process_exception(self): """Client runs process_exception to tell if errors are retryable or fatal.""" iteration = 0 @@ -113,7 +218,7 @@ def process_exception(exc): hasattr(http.HTTPStatus, "IM_A_TEAPOT"), "test requires Python 3.9", ) - async def test_reconnection_with_custom_process_exception_raising_exception(self): + async def test_reconnect_with_custom_process_exception_raising_exception(self): """Client supports raising an exception in process_exception.""" def process_request(connection, request): @@ -137,84 +242,107 @@ def process_exception(exc): "🫖 💔 ☕️", ) - async def test_existing_socket(self): - """Client connects using a pre-existing socket.""" - async with serve(*args) as server: - with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to the right socket. - async with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async def test_redirect(self): + """Client follows redirect.""" - async def test_additional_headers(self): - """Client can set additional headers with additional_headers.""" - async with serve(*args) as server: - async with connect( - get_uri(server), additional_headers={"Authorization": "Bearer ..."} - ) as client: - self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response - async def test_override_user_agent(self): - """Client can override User-Agent header with user_agent_header.""" - async with serve(*args) as server: - async with connect(get_uri(server), user_agent_header="Smith") as client: - self.assertEqual(client.request.headers["User-Agent"], "Smith") + async with serve(*args, process_request=redirect) as server: + async with connect(get_uri(server) + "/redirect") as client: + self.assertEqual(client.protocol.wsuri.path, "/") - async def test_remove_user_agent(self): - """Client can remove User-Agent header with user_agent_header.""" - async with serve(*args) as server: - async with connect(get_uri(server), user_agent_header=None) as client: - self.assertNotIn("User-Agent", client.request.headers) + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" - async def test_compression_is_enabled(self): - """Client enables compression by default.""" - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual( - [type(ext) for ext in client.protocol.extensions], - [PerMessageDeflate], - ) + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response - async def test_disable_compression(self): - """Client disables compression.""" - async with serve(*args) as server: - async with connect(get_uri(server), compression=None) as client: - self.assertEqual(client.protocol.extensions, []) + async with serve(*args, process_request=redirect) as server: + async with serve(*args) as other_server: + async with connect(get_uri(server)): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) - async def test_keepalive_is_enabled(self): - """Client enables keepalive and measures latency by default.""" - async with serve(*args) as server: - async with connect(get_uri(server), ping_interval=MS) as client: - self.assertEqual(client.latency, 0) - await asyncio.sleep(2 * MS) - self.assertGreater(client.latency, 0) + async def test_redirect_limit(self): + """Client stops following redirects after limit is reached.""" - async def test_disable_keepalive(self): - """Client disables keepalive.""" - async with serve(*args) as server: - async with connect(get_uri(server), ping_interval=None) as client: - await asyncio.sleep(2 * MS) - self.assertEqual(client.latency, 0) + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = request.path + return response - async def test_logger(self): - """Client accepts a logger argument.""" - logger = logging.getLogger("test") - async with serve(*args) as server: - async with connect(get_uri(server), logger=logger) as client: - self.assertEqual(client.logger.name, logger.name) + async with serve(*args, process_request=redirect) as server: + async with few_redirects(): + with self.assertRaises(SecurityError) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") - async def test_custom_connection_factory(self): - """Client runs ClientConnection factory provided in create_connection.""" + self.assertEqual( + str(raised.exception), + "more than 2 redirects", + ) - def create_connection(*args, **kwargs): - client = ClientConnection(*args, **kwargs) - client.create_connection_ran = True - return client + async def test_redirect_with_explicit_host_port(self): + """Client follows redirect with an explicit host / port.""" - async with serve(*args) as server: + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with serve(*args, process_request=redirect) as server: + host, port = get_host_port(server) async with connect( - get_uri(server), create_connection=create_connection + "ws://overridden/redirect", host=host, port=port ) as client: - self.assertTrue(client.create_connection_ran) + self.assertEqual(client.protocol.wsuri.path, "/") + + async def test_cross_origin_redirect_with_explicit_host_port(self): + """Client doesn't follow cross-origin redirect with an explicit host / port.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + async with serve(*args, process_request=redirect) as server: + host, port = get_host_port(server) + with self.assertRaises(ValueError) as raised: + async with connect("ws://overridden/", host=host, port=port): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ " + "with an explicit host or port", + ) + + async def test_redirect_with_existing_socket(self): + """Client doesn't follow redirect when using a pre-existing socket.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with serve(*args, process_request=redirect) as server: + with socket.create_connection(get_host_port(server)) as sock: + with self.assertRaises(ValueError) as raised: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/redirect", sock=sock): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow redirect to ws://invalid/ with a preexisting socket", + ) async def test_invalid_uri(self): """Client receives an invalid URI.""" @@ -336,6 +464,40 @@ async def test_reject_invalid_server_hostname(self): str(raised.exception), ) + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response + + async with serve(*args, ssl=SERVER_CONTEXT, process_request=redirect) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as other_server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + async def test_redirect_to_insecure_uri(self): + """Client doesn't follow redirect from secure URI to non-secure URI.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = insecure_uri + return response + + async with serve(*args, ssl=SERVER_CONTEXT, process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + secure_uri = get_uri(server) + insecure_uri = secure_uri.replace("wss://", "ws://") + async with connect(secure_uri, ssl=CLIENT_CONTEXT): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + f"cannot follow redirect to non-secure URI {insecure_uri}", + ) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.IsolatedAsyncioTestCase): @@ -354,6 +516,25 @@ async def test_set_host_header(self): async with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") + async def test_cross_origin_redirect(self): + """Client doesn't follows redirect to a URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + with temp_unix_socket_path() as path: + async with unix_serve(handler, path, process_request=redirect): + with self.assertRaises(ValueError) as raised: + async with unix_connect(path): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ with a Unix socket", + ) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): From 566ab1d4165f9dfaa469cd19858bc5b451b182cd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 22:15:28 +0200 Subject: [PATCH 1389/1539] The new asyncio implementation has reached parity. --- docs/howto/upgrade.rst | 3 ++- docs/index.rst | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 120509c9f..70254d93e 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -31,7 +31,8 @@ respectively. The next steps are: - 1. Deprecating it once the new implementation reaches feature parity. + 1. Deprecating it once the new implementation is considered sufficiently + robust. 2. Maintaining it for five years per the :ref:`backwards-compatibility policy `. 3. Removing it. This is expected to happen around 2030. diff --git a/docs/index.rst b/docs/index.rst index 218a489a3..b8cd300e3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,7 +26,7 @@ with a focus on correctness, simplicity, robustness, and performance. .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API -It supports several network I/O and control flow paradigms: +It supports several network I/O and control flow paradigms. 1. The primary implementation builds upon :mod:`asyncio`, Python's standard asynchronous I/O framework. It provides an elegant coroutine-based API. It's @@ -44,10 +44,10 @@ It supports several network I/O and control flow paradigms: the Sans-I/O implementation. It adds a few features that were impossible to implement within the original design. - The new implementation will become the default as soon as it reaches - feature parity. If you're using the historical implementation, you should - :doc:`ugrade to the new implementation `. It's usually - straightforward. + The new implementation provides all features of the historical + implementation, and a few more. If you're using the historical + implementation, you should :doc:`ugrade to the new implementation + `. It's usually straightforward. 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used From ef7b1e32d26b2104d8bc5e2ac3a001e0504aba11 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 22:24:21 +0200 Subject: [PATCH 1390/1539] Proof-read upgrade guide. --- docs/howto/upgrade.rst | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 70254d93e..f3e42591e 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -10,14 +10,14 @@ It provides a very similar API. However, there are a few differences. The recommended upgrade process is: -#. Make sure that your application doesn't use any `deprecated APIs`_. If it - doesn't raise any warnings, you can skip this step. -#. `Update import paths`_. For straightforward usage of websockets, this could - be the only step you need to take. Upgrading could be transparent. -#. Check out `new features and improvements`_ and consider taking advantage of - them to improve your application. -#. Review `API changes`_ and adapt your application to preserve its current - functionality. +#. Make sure that your code doesn't use any `deprecated APIs`_. If it doesn't + raise warnings, you're fine. +#. `Update import paths`_. For straightforward use cases, this could be the only + step you need to take. +#. Check out `new features and improvements`_. Consider taking advantage of them + in your code. +#. Review `API changes`_. If needed, update your application to preserve its + current behavior. In the interest of brevity, only :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` are discussed below but everything also applies @@ -146,9 +146,8 @@ Customizing the opening handshake ................................. On the server side, if you're customizing how :func:`~legacy.server.serve` -processes the opening handshake with the ``process_request``, ``extra_headers``, -or ``select_subprotocol``, you must update your code and you can probably make -it simpler. +processes the opening handshake with ``process_request``, ``extra_headers``, or +``select_subprotocol``, you must update your code. Probably you can simplify it! ``process_request`` and ``select_subprotocol`` have new signatures. ``process_response`` replaces ``extra_headers`` and provides more flexibility. @@ -481,16 +480,17 @@ a ``check_credentials`` coroutine as well as an optional ``realm`` just like This new API has more obvious semantics. That makes it easier to understand and also easier to extend. -In the original implementation, overriding ``create_protocol`` changed the type +In the original implementation, overriding ``create_protocol`` changes the type of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP -Basic Authentication in its ``process_request`` method. If you wanted to -customize ``process_request`` further, you had: +Basic Authentication in its ``process_request`` method. -* an ill-defined option: add a ``process_request`` argument to +To customize ``process_request`` further, you had only bad options: + +* the ill-defined option: add a ``process_request`` argument to :func:`~legacy.server.serve`; to tell which one would run first, you had to experiment or read the code; -* a cumbersome option: subclass +* the cumbersome option: subclass :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that subclass in the ``create_protocol`` argument of :func:`~legacy.auth.basic_auth_protocol_factory`. From 0158a246683437de2f85a27fecef06f5e5393e99 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 22:39:43 +0200 Subject: [PATCH 1391/1539] Forgotten in d8ab09b2. --- docs/reference/features.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 32fc05baf..eeade1462 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -142,7 +142,7 @@ Client +------------------------------------+--------+--------+--------+--------+ | Close connection on context exit | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Reconnect automatically | ❌ | ❌ | — | ✅ | + | Reconnect automatically | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ From 4a9dae23b3661210331463968cf4a5eeb54c41dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 16:58:50 +0200 Subject: [PATCH 1392/1539] Simplify handling of connection close during handshake. --- src/websockets/asyncio/client.py | 13 ++++--------- src/websockets/asyncio/server.py | 13 ++++--------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 50f67b95f..2b8fbfd3a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -90,7 +90,10 @@ async def handshake( self.protocol.send_request(self.request) # May raise CancelledError if open_timeout is exceeded. - await self.response_rcvd + await asyncio.wait( + [self.response_rcvd, self.connection_lost_waiter], + return_when=asyncio.FIRST_COMPLETED, + ) if self.response is None: raise ConnectionError("connection closed during handshake") @@ -118,14 +121,6 @@ def process_event(self, event: Event) -> None: else: super().process_event(event) - def connection_lost(self, exc: Exception | None) -> None: - try: - super().connection_lost(exc) - finally: - # If the connection is closed during the handshake, unblock it. - if not self.response_rcvd.done(): - self.response_rcvd.set_result(None) - def process_exception(exc: Exception) -> Exception | None: """ diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 1aa47af88..c04c5202f 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -139,7 +139,10 @@ async def handshake( """ # May raise CancelledError if open_timeout is exceeded. - await self.request_rcvd + await asyncio.wait( + [self.request_rcvd, self.connection_lost_waiter], + return_when=asyncio.FIRST_COMPLETED, + ) if self.request is None: raise ConnectionError("connection closed during handshake") @@ -229,14 +232,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: super().connection_made(transport) self.server.start_connection_handler(self) - def connection_lost(self, exc: Exception | None) -> None: - try: - super().connection_lost(exc) - finally: - # If the connection is closed during the handshake, unblock it. - if not self.request_rcvd.done(): - self.request_rcvd.set_result(None) - class Server: """ From 96055073d3cbeb6e52bb419205b2753d3898d612 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 18:53:57 +0200 Subject: [PATCH 1393/1539] Close connection when client receives bad response. This avoids stalling until the opening handshake timeouts. --- src/websockets/asyncio/client.py | 21 +++++++++----------- src/websockets/client.py | 2 ++ src/websockets/protocol.py | 12 ++++++------ src/websockets/sync/client.py | 28 ++++++++++++--------------- tests/asyncio/test_client.py | 24 +++++++++++++++++++++-- tests/sync/test_client.py | 33 ++++++++++++++++++++++++++++++-- tests/test_client.py | 2 +- 7 files changed, 83 insertions(+), 39 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 2b8fbfd3a..3985bfb6a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -89,23 +89,19 @@ async def handshake( self.request.headers["User-Agent"] = user_agent_header self.protocol.send_request(self.request) - # May raise CancelledError if open_timeout is exceeded. await asyncio.wait( [self.response_rcvd, self.connection_lost_waiter], return_when=asyncio.FIRST_COMPLETED, ) - if self.response is None: - raise ConnectionError("connection closed during handshake") + # self.protocol.handshake_exc is always set when the connection is lost + # before receiving a response, when the response cannot be parsed, or + # when the response fails the handshake. if self.protocol.handshake_exc is None: self.start_keepalive() else: - try: - async with asyncio_timeout(self.close_timeout): - await self.connection_lost_waiter - finally: - raise self.protocol.handshake_exc + raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: """ @@ -132,7 +128,8 @@ def process_exception(exc: Exception) -> Exception | None: This function defines the default behavior, which is to retry on: - * :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors; + * :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network + errors; * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500, 502, 503, or 504: server or proxy errors. @@ -150,7 +147,7 @@ def process_exception(exc: Exception) -> Exception | None: That exception will be raised, breaking out of the retry loop. """ - if isinstance(exc, (OSError, asyncio.TimeoutError)): + if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): return None if isinstance(exc, InvalidStatus) and exc.response.status_code in [ 500, # Internal Server Error @@ -445,7 +442,7 @@ async def __await_impl__(self) -> ClientConnection: try: await self.connection.handshake(*self.handshake_args) except asyncio.CancelledError: - self.connection.transport.close() + self.connection.close_transport() raise except Exception as exc: # Always close the connection even though keep-alive is @@ -454,7 +451,7 @@ async def __await_impl__(self) -> ClientConnection: # protocol. In the current design of connect(), there is # no easy way to reuse the network connection that works # in every case nor to reinitialize the protocol. - self.connection.transport.close() + self.connection.close_transport() uri_or_exc = self.process_redirect(exc) # Response is a valid redirect; follow it. diff --git a/src/websockets/client.py b/src/websockets/client.py index 95de99dc5..0e36fd028 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -323,6 +323,7 @@ def parse(self) -> Generator[None, None, None]: ) except Exception as exc: self.handshake_exc = exc + self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine yield @@ -341,6 +342,7 @@ def parse(self) -> Generator[None, None, None]: response._exception = exc self.events.append(response) self.handshake_exc = exc + self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine yield diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 3b3e80cf5..8751ebdb4 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -610,18 +610,18 @@ def discard(self) -> Generator[None, None, None]: - after sending a close frame, during an abnormal closure (7.1.7). """ - # The server close the TCP connection in the same circumstances where - # discard() replaces parse(). The client closes the connection later, - # after the server closes the connection or a timeout elapses. - # (The latter case cannot be handled in this Sans-I/O layer.) - assert (self.side is SERVER) == (self.eof_sent) + # After the opening handshake completes, the server closes the TCP + # connection in the same circumstances where discard() replaces parse(). + # The client closes it when it receives EOF from the server or times + # out. (The latter case cannot be handled in this Sans-I/O layer.) + assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") # A server closes the TCP connection immediately, while a client # waits for the server to close the TCP connection. - if self.side is CLIENT: + if self.state != CONNECTING and self.side is CLIENT: self.send_eof() self.state = CLOSED # If discard() completes normally, execution ends here. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 6a04515f0..d1e20a757 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -12,7 +12,7 @@ from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response -from ..protocol import CONNECTING, OPEN, Event +from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol from ..uri import parse_uri from .connection import Connection @@ -80,19 +80,11 @@ def handshake( self.protocol.send_request(self.request) if not self.response_rcvd.wait(timeout): - self.close_socket() - self.recv_events_thread.join() raise TimeoutError("timed out during handshake") - if self.response is None: - self.close_socket() - self.recv_events_thread.join() - raise ConnectionError("connection closed during handshake") - - if self.protocol.state is not OPEN: - self.recv_events_thread.join(self.close_timeout) - self.close_socket() - self.recv_events_thread.join() + # self.protocol.handshake_exc is always set when the connection is lost + # before receiving a response, when the response cannot be parsed, or + # when the response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -295,16 +287,20 @@ def connect( protocol, close_timeout=close_timeout, ) - # On failure, handshake() closes the socket and raises an exception. + except Exception: + if sock is not None: + sock.close() + raise + + try: connection.handshake( additional_headers, user_agent_header, deadline.timeout(), ) - except Exception: - if sock is not None: - sock.close() + connection.close_socket() + connection.recv_events_thread.join() raise return connection diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index b0487552e..725bac92b 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -401,12 +401,32 @@ def close_connection(self, request): self.close_transport() async with serve(*args, process_request=close_connection) as server: - with self.assertRaises(ConnectionError) as raised: + with self.assertRaises(EOFError) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), - "connection closed during handshake", + "connection closed while reading HTTP status line", + ) + + async def test_junk_handshake(self): + """Client closes the connection when receiving non-HTTP response from server.""" + + async def junk(reader, writer): + await asyncio.sleep(MS) # wait for the client to send the handshake request + writer.write(b"220 smtp.invalid ESMTP Postfix\r\n") + await reader.read(4096) # wait for the client to close the connection + writer.close() + + server = await asyncio.start_server(junk, "localhost", 0) + host, port = get_host_port(server) + async with server: + with self.assertRaises(ValueError) as raised: + async with connect(f"ws://{host}:{port}"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: 220", ) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 812412203..96f7f0c90 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,7 +1,9 @@ import logging import socket +import socketserver import ssl import threading +import time import unittest from websockets.exceptions import InvalidHandshake, InvalidURI @@ -146,14 +148,41 @@ def close_connection(self, request): self.close_socket() with run_server(process_request=close_connection) as server: - with self.assertRaises(ConnectionError) as raised: + with self.assertRaises(EOFError) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), - "connection closed during handshake", + "connection closed while reading HTTP status line", ) + def test_junk_handshake(self): + """Client closes the connection when receiving non-HTTP response from server.""" + + class JunkHandler(socketserver.BaseRequestHandler): + def handle(self): + time.sleep(MS) # wait for the client to send the handshake request + self.request.send(b"220 smtp.invalid ESMTP Postfix\r\n") + self.request.recv(4096) # wait for the client to close the connection + self.request.close() + + server = socketserver.TCPServer(("localhost", 0), JunkHandler) + host, port = server.server_address + with server: + thread = threading.Thread(target=server.serve_forever, args=(MS,)) + thread.start() + try: + with self.assertRaises(ValueError) as raised: + with connect(f"ws://{host}:{port}"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: 220", + ) + finally: + server.shutdown() + thread.join() + class SecureClientTests(unittest.TestCase): def test_connection(self): diff --git a/tests/test_client.py b/tests/test_client.py index 8b3bf4232..47558c1c0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -154,7 +154,7 @@ def test_receive_reject(self): ) [response] = client.events_received() self.assertIsInstance(response, Response) - self.assertEqual(client.data_to_send(), []) + self.assertEqual(client.data_to_send(), [b""]) self.assertTrue(client.close_expected()) self.assertEqual(client.state, CONNECTING) From 560f6eed24a516bf9c8c4c7eaf83f84d2ce606f0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 21:26:32 +0200 Subject: [PATCH 1394/1539] Log error when server receives bad request. --- src/websockets/asyncio/server.py | 155 ++++++++++++++++--------------- src/websockets/sync/server.py | 136 ++++++++++++++------------- tests/asyncio/test_server.py | 21 +++++ tests/sync/test_server.py | 32 +++++++ tests/test_server.py | 2 +- 5 files changed, 206 insertions(+), 140 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index c04c5202f..228b20012 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -138,81 +138,76 @@ async def handshake( Perform the opening handshake. """ - # May raise CancelledError if open_timeout is exceeded. await asyncio.wait( [self.request_rcvd, self.connection_lost_waiter], return_when=asyncio.FIRST_COMPLETED, ) - if self.request is None: - raise ConnectionError("connection closed during handshake") - - async with self.send_context(expected_state=CONNECTING): - response = None - - if process_request is not None: - try: - response = process_request(self, self.request) - if isinstance(response, Awaitable): - response = await response - except Exception as exc: - self.protocol.handshake_exc = exc - self.logger.error("opening handshake failed", exc_info=True) - response = self.protocol.reject( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ), - ) - - if response is None: - if self.server.is_serving(): - self.response = self.protocol.accept(self.request) + if self.request is not None: + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if self.server.is_serving(): + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) else: - self.response = self.protocol.reject( - http.HTTPStatus.SERVICE_UNAVAILABLE, - "Server is shutting down.\n", - ) - else: - assert isinstance(response, Response) # help mypy - self.response = response - - if server_header: - self.response.headers["Server"] = server_header - - response = None - - if process_response is not None: - try: - response = process_response(self, self.request, self.response) - if isinstance(response, Awaitable): - response = await response - except Exception as exc: - self.protocol.handshake_exc = exc - self.logger.error("opening handshake failed", exc_info=True) - response = self.protocol.reject( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ), - ) - - if response is not None: - assert isinstance(response, Response) # help mypy - self.response = response - - self.protocol.send_response(self.response) + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + # self.protocol.handshake_exc is always set when the connection is lost + # before receiving a request, when the request cannot be parsed, or when + # the response fails the handshake. if self.protocol.handshake_exc is None: self.start_keepalive() else: - try: - async with asyncio_timeout(self.close_timeout): - await self.connection_lost_waiter - finally: - raise self.protocol.handshake_exc + raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: """ @@ -359,25 +354,35 @@ async def conn_handler(self, connection: ServerConnection) -> None: """ try: - # On failure, handshake() closes the transport, raises an - # exception, and logs it. async with asyncio_timeout(self.open_timeout): - await connection.handshake( - self.process_request, - self.process_response, - self.server_header, - ) + try: + await connection.handshake( + self.process_request, + self.process_response, + self.server_header, + ) + except asyncio.CancelledError: + connection.close_transport() + raise + except Exception: + connection.logger.error("opening handshake failed", exc_info=True) + connection.close_transport() + return try: await self.handler(connection) except Exception: - self.logger.error("connection handler failed", exc_info=True) + connection.logger.error("connection handler failed", exc_info=True) await connection.close(CloseCode.INTERNAL_ERROR) else: await connection.close() - except Exception: - # Don't leak connections on errors. + except TimeoutError: + # When the opening handshake times out, there's nothing to log. + pass + + except Exception: # pragma: no cover + # Don't leak connections on unexpected errors. connection.transport.abort() finally: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 15de458b5..eb0536013 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -23,7 +23,7 @@ validate_subprotocols, ) from ..http11 import SERVER, Request, Response -from ..protocol import CONNECTING, OPEN, Event +from ..protocol import CONNECTING, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection @@ -118,61 +118,56 @@ def handshake( """ if not self.request_rcvd.wait(timeout): - self.close_socket() - self.recv_events_thread.join() raise TimeoutError("timed out during handshake") - if self.request is None: - self.close_socket() - self.recv_events_thread.join() - raise ConnectionError("connection closed during handshake") - - with self.send_context(expected_state=CONNECTING): - self.response = None - - if process_request is not None: - try: - self.response = process_request(self, self.request) - except Exception as exc: - self.protocol.handshake_exc = exc - self.logger.error("opening handshake failed", exc_info=True) - self.response = self.protocol.reject( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ), - ) - - if self.response is None: - self.response = self.protocol.accept(self.request) - - if server_header: - self.response.headers["Server"] = server_header - - if process_response is not None: - try: - response = process_response(self, self.request, self.response) - except Exception as exc: - self.protocol.handshake_exc = exc - self.logger.error("opening handshake failed", exc_info=True) - self.response = self.protocol.reject( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ), - ) + if self.request is not None: + with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + self.response = self.protocol.accept(self.request) else: + self.response = response + + if server_header: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + if response is not None: self.response = response - self.protocol.send_response(self.response) + self.protocol.send_response(self.response) - if self.protocol.state is not OPEN: - self.recv_events_thread.join(self.close_timeout) - self.close_socket() - self.recv_events_thread.join() + # self.protocol.handshake_exc is always set when the connection is lost + # before receiving a request, when the request cannot be parsed, or when + # the response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -552,26 +547,39 @@ def protocol_select_subprotocol( protocol, close_timeout=close_timeout, ) - # On failure, handshake() closes the socket, raises an exception, and - # logs it. - connection.handshake( - process_request, - process_response, - server_header, - deadline.timeout(), - ) - except Exception: sock.close() return try: - handler(connection) - except Exception: - protocol.logger.error("connection handler failed", exc_info=True) - connection.close(CloseCode.INTERNAL_ERROR) - else: - connection.close() + try: + connection.handshake( + process_request, + process_response, + server_header, + deadline.timeout(), + ) + except TimeoutError: + connection.close_socket() + connection.recv_events_thread.join() + return + except Exception: + connection.logger.error("opening handshake failed", exc_info=True) + connection.close_socket() + connection.recv_events_thread.join() + return + + try: + handler(connection) + except Exception: + connection.logger.error("connection handler failed", exc_info=True) + connection.close(CloseCode.INTERNAL_ERROR) + else: + connection.close() + + except Exception: # pragma: no cover + # Don't leak sockets on unexpected errors. + sock.close() # Initialize server diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index bc1d0444c..fdcbf9780 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -406,6 +406,27 @@ async def test_connection_closed_during_handshake(self): _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() + async def test_junk_handshake(self): + """Server closes the connection when receiving non-HTTP request from client.""" + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args) as server: + reader, writer = await asyncio.open_connection(*get_host_port(server)) + writer.write(b"HELO relay.invalid\r\n") + try: + # Wait for the server to close the connection. + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["invalid HTTP request line: HELO relay.invalid"], + ) + async def test_close_server_rejects_connecting_connections(self): """Server rejects connecting connections with HTTP 503 when closing.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 7bcf144a2..56565dab7 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -310,6 +310,38 @@ def handler(sock, addr): # Wait for the server thread to terminate. server_thread.join() + def test_junk_handshake(self): + """Server closes the connection when receiving non-HTTP request from client.""" + with self.assertLogs("websockets.server", logging.ERROR) as logs: + with run_server() as server: + # Patch handler to record a reference to the thread running it. + server_thread = None + original_handler = server.handler + + def handler(sock, addr): + nonlocal server_thread + server_thread = threading.current_thread() + original_handler(sock, addr) + + server.handler = handler + + with socket.create_connection(server.socket.getsockname()) as sock: + sock.send(b"HELO relay.invalid\r\n") + # Wait for the server to close the connection. + self.assertEqual(sock.recv(4096), b"") + + # Wait for the server thread to terminate. + server_thread.join() + + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["invalid HTTP request line: HELO relay.invalid"], + ) + class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): diff --git a/tests/test_server.py b/tests/test_server.py index e7d249f49..d34c8e83d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -83,7 +83,7 @@ def test_partial_request(self): server.receive_eof() self.assertEqual(server.events_received(), []) - def test_random_request(self): + def test_junk_request(self): server = ServerProtocol() server.receive_data(b"HELO relay.invalid\r\n") server.receive_data(b"MAIL FROM: \r\n") From cb42484bb9e79b12f5d7a19bc96e43a022bc511a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 22:15:33 +0200 Subject: [PATCH 1395/1539] Improve error messages on HTTP parsing errors. --- src/websockets/http11.py | 30 +++++++++++++++++------------- tests/asyncio/test_client.py | 3 ++- tests/sync/test_client.py | 3 ++- tests/test_http11.py | 25 ++++++++++++------------- 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 61865bb92..47cef7a9b 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -135,14 +135,15 @@ def parse( raise EOFError("connection closed while reading HTTP request line") from exc try: - method, raw_path, version = request_line.split(b" ", 2) + method, raw_path, protocol = request_line.split(b" ", 2) except ValueError: # not enough values to unpack (expected 3, got 1-2) raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None - + if protocol != b"HTTP/1.1": + raise ValueError( + f"unsupported protocol; expected HTTP/1.1: {d(request_line)}" + ) if method != b"GET": - raise ValueError(f"unsupported HTTP method: {d(method)}") - if version != b"HTTP/1.1": - raise ValueError(f"unsupported HTTP version: {d(version)}") + raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}") path = raw_path.decode("ascii", "surrogateescape") headers = yield from parse_headers(read_line) @@ -236,23 +237,26 @@ def parse( raise EOFError("connection closed while reading HTTP status line") from exc try: - version, raw_status_code, raw_reason = status_line.split(b" ", 2) + protocol, raw_status_code, raw_reason = status_line.split(b" ", 2) except ValueError: # not enough values to unpack (expected 3, got 1-2) raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None - - if version != b"HTTP/1.1": - raise ValueError(f"unsupported HTTP version: {d(version)}") + if protocol != b"HTTP/1.1": + raise ValueError( + f"unsupported protocol; expected HTTP/1.1: {d(status_line)}" + ) try: status_code = int(raw_status_code) except ValueError: # invalid literal for int() with base 10 raise ValueError( - f"invalid HTTP status code: {d(raw_status_code)}" + f"invalid status code; expected integer; got {d(raw_status_code)}" ) from None - if not 100 <= status_code < 1000: - raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}") + if not 100 <= status_code < 600: + raise ValueError( + f"invalid status code; expected 100–599; got {d(raw_status_code)}" + ) if not _value_re.fullmatch(raw_reason): raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") - reason = raw_reason.decode() + reason = raw_reason.decode("ascii", "surrogateescape") headers = yield from parse_headers(read_line) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 725bac92b..999ef1b71 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -426,7 +426,8 @@ async def junk(reader, writer): self.fail("did not raise") self.assertEqual( str(raised.exception), - "unsupported HTTP version: 220", + "unsupported protocol; expected HTTP/1.1: " + "220 smtp.invalid ESMTP Postfix", ) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 96f7f0c90..e63d774b7 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -177,7 +177,8 @@ def handle(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "unsupported HTTP version: 220", + "unsupported protocol; expected HTTP/1.1: " + "220 smtp.invalid ESMTP Postfix", ) finally: server.shutdown() diff --git a/tests/test_http11.py b/tests/test_http11.py index d2e5e0462..1fbcb3ba4 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -50,22 +50,22 @@ def test_parse_invalid_request_line(self): "invalid HTTP request line: GET /", ) - def test_parse_unsupported_method(self): - self.reader.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") + def test_parse_unsupported_protocol(self): + self.reader.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "unsupported HTTP method: OPTIONS", + "unsupported protocol; expected HTTP/1.1: GET /chat HTTP/1.0", ) - def test_parse_unsupported_version(self): - self.reader.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") + def test_parse_unsupported_method(self): + self.reader.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "unsupported HTTP version: HTTP/1.0", + "unsupported HTTP method; expected GET; got OPTIONS", ) def test_parse_invalid_header(self): @@ -171,31 +171,30 @@ def test_parse_invalid_status_line(self): "invalid HTTP status line: Hello!", ) - def test_parse_unsupported_version(self): + def test_parse_unsupported_protocol(self): self.reader.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "unsupported HTTP version: HTTP/1.0", + "unsupported protocol; expected HTTP/1.1: HTTP/1.0 400 Bad Request", ) - def test_parse_invalid_status(self): + def test_parse_non_integer_status(self): self.reader.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "invalid HTTP status code: OMG", + "invalid status code; expected integer; got OMG", ) - def test_parse_unsupported_status(self): + def test_parse_non_three_digit_status(self): self.reader.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( - str(raised.exception), - "unsupported HTTP status code: 007", + str(raised.exception), "invalid status code; expected 100–599; got 007" ) def test_parse_invalid_reason(self): From 90270d8b1bc19ee0c4ac94c6a200c95614ab8771 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 22:20:24 +0200 Subject: [PATCH 1396/1539] Add changelog for previous commits. --- docs/project/changelog.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c77876a73..5716c6f25 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -59,6 +59,11 @@ New features * Made the set of active connections available in the :attr:`Server.connections ` property. +Improvements +............ + +* Improved reporting of errors during the opening handshake. + 13.0.1 ------ From 14d9d40acd23ecc66a8ac8d0535af9d8a1b0fa07 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 22:52:04 +0200 Subject: [PATCH 1397/1539] Fix typo in convenience imports. Fix #1496. --- src/websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index ac02a9f7e..7bd35cfa7 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -183,7 +183,7 @@ "ExtensionParameter": ".typing", "LoggerLike": ".typing", "Origin": ".typing", - "StatusLike": "typing", + "StatusLike": ".typing", "Subprotocol": ".typing", }, deprecated_aliases={ From f9cea9cca568dc92704de3744639eb4248278a8f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 9 Sep 2024 21:45:49 +0200 Subject: [PATCH 1398/1539] Improve isolation of tests of sync implementation. Before this change, threads handling requests could continue running after the end of the test. This caused spurious failures. Specifically, a test expecting an error log could get an error log from a previous tests. This happened sporadically on PyPy. --- tests/sync/server.py | 18 +++++++++++++ tests/sync/test_server.py | 54 +++------------------------------------ 2 files changed, 21 insertions(+), 51 deletions(-) diff --git a/tests/sync/server.py b/tests/sync/server.py index 114c1545b..fd7a03d82 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -38,12 +38,30 @@ def run_server(handler=handler, host="localhost", port=0, **kwargs): with serve(handler, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() + + # HACK: since the sync server doesn't track connections (yet), we record + # a reference to the thread handling the most recent connection, then we + # can wait for that thread to terminate when exiting the context. + handler_thread = None + original_handler = server.handler + + def handler(sock, addr): + nonlocal handler_thread + handler_thread = threading.current_thread() + original_handler(sock, addr) + + server.handler = handler + try: yield server finally: server.shutdown() thread.join() + # HACK: wait for the thread handling the most recent connection. + if handler_thread is not None: + handler_thread.join() + @contextlib.contextmanager def run_unix_server(path, handler=handler, **kwargs): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 56565dab7..d0d2c0955 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -3,7 +3,7 @@ import http import logging import socket -import threading +import time import unittest from websockets.exceptions import ( @@ -289,50 +289,19 @@ def test_timeout_during_handshake(self): def test_connection_closed_during_handshake(self): """Server reads EOF before receiving handshake request from client.""" with run_server() as server: - # Patch handler to record a reference to the thread running it. - server_thread = None - conn_received = threading.Event() - original_handler = server.handler - - def handler(sock, addr): - nonlocal server_thread - server_thread = threading.current_thread() - nonlocal conn_received - conn_received.set() - original_handler(sock, addr) - - server.handler = handler - with socket.create_connection(server.socket.getsockname()): # Wait for the server to receive the connection, then close it. - conn_received.wait() - - # Wait for the server thread to terminate. - server_thread.join() + time.sleep(MS) def test_junk_handshake(self): """Server closes the connection when receiving non-HTTP request from client.""" with self.assertLogs("websockets.server", logging.ERROR) as logs: with run_server() as server: - # Patch handler to record a reference to the thread running it. - server_thread = None - original_handler = server.handler - - def handler(sock, addr): - nonlocal server_thread - server_thread = threading.current_thread() - original_handler(sock, addr) - - server.handler = handler - with socket.create_connection(server.socket.getsockname()) as sock: sock.send(b"HELO relay.invalid\r\n") # Wait for the server to close the connection. self.assertEqual(sock.recv(4096), b"") - # Wait for the server thread to terminate. - server_thread.join() - self.assertEqual( [record.getMessage() for record in logs.records], ["opening handshake failed"], @@ -360,26 +329,9 @@ def test_timeout_during_tls_handshake(self): def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" with run_server(ssl=SERVER_CONTEXT) as server: - # Patch handler to record a reference to the thread running it. - server_thread = None - conn_received = threading.Event() - original_handler = server.handler - - def handler(sock, addr): - nonlocal server_thread - server_thread = threading.current_thread() - nonlocal conn_received - conn_received.set() - original_handler(sock, addr) - - server.handler = handler - with socket.create_connection(server.socket.getsockname()): # Wait for the server to receive the connection, then close it. - conn_received.wait() - - # Wait for the server thread to terminate. - server_thread.join() + time.sleep(MS) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") From 070ff1a3e536e497a9ee11ea3b0649f82f3974c9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 9 Sep 2024 22:36:05 +0200 Subject: [PATCH 1399/1539] Add dedicated ConcurrencyError exception. Previously, a generic RuntimeError was used. Fix #1499. --- docs/project/changelog.rst | 4 +++ docs/reference/exceptions.rst | 5 ++++ src/websockets/__init__.py | 3 +++ src/websockets/asyncio/connection.py | 36 ++++++++++++++++----------- src/websockets/asyncio/messages.py | 19 +++++++------- src/websockets/exceptions.py | 12 +++++++++ src/websockets/sync/connection.py | 37 ++++++++++++++++------------ src/websockets/sync/messages.py | 15 +++++------ tests/asyncio/test_connection.py | 28 ++++++++++++--------- tests/asyncio/test_messages.py | 19 +++++++------- tests/sync/test_connection.py | 28 ++++++++++++--------- tests/sync/test_messages.py | 17 +++++++------ tests/test_exceptions.py | 4 +++ 13 files changed, 140 insertions(+), 87 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5716c6f25..f92ca68b6 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -64,6 +64,10 @@ Improvements * Improved reporting of errors during the opening handshake. +* Raised :exc:`~exceptions.ConcurrencyError` on unsupported concurrent calls. + Previously, :exc:`RuntimeError` was raised. For backwards compatibility, + :exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`. + 13.0.1 ------ diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 14a8edcd1..75934ef99 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -64,6 +64,11 @@ translated to :exc:`ConnectionClosedError` in the other implementations. .. autoexception:: InvalidState +Miscellaneous exceptions +------------------------ + +.. autoexception:: ConcurrencyError + Legacy exceptions ----------------- diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 7bd35cfa7..54591e9fd 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -14,6 +14,7 @@ "HeadersLike", "MultipleValuesError", # .exceptions + "ConcurrencyError", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", @@ -72,6 +73,7 @@ from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( + ConcurrencyError, ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, @@ -134,6 +136,7 @@ "HeadersLike": ".datastructures", "MultipleValuesError": ".datastructures", # .exceptions + "ConcurrencyError": ".exceptions", "ConnectionClosed": ".exceptions", "ConnectionClosedError": ".exceptions", "ConnectionClosedOK": ".exceptions", diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 069b3e1d2..1b24f9af0 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -19,7 +19,12 @@ cast, ) -from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State @@ -262,7 +267,7 @@ async def recv(self, decode: bool | None = None) -> Data: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If two coroutines call :meth:`recv` or + ConcurrencyError: If two coroutines call :meth:`recv` or :meth:`recv_streaming` concurrently. """ @@ -270,8 +275,8 @@ async def recv(self, decode: bool | None = None) -> Data: return await self.recv_messages.get(decode) except EOFError: raise self.protocol.close_exc from self.recv_exc - except RuntimeError: - raise RuntimeError( + except ConcurrencyError: + raise ConcurrencyError( "cannot call recv while another coroutine " "is already running recv or recv_streaming" ) from None @@ -283,8 +288,9 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data This method is designed for receiving fragmented messages. It returns an asynchronous iterator that yields each fragment as it is received. This iterator must be fully consumed. Else, future calls to :meth:`recv` or - :meth:`recv_streaming` will raise :exc:`RuntimeError`, making the - connection unusable. + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection + unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. @@ -315,7 +321,7 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If two coroutines call :meth:`recv` or + ConcurrencyError: If two coroutines call :meth:`recv` or :meth:`recv_streaming` concurrently. """ @@ -324,8 +330,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data yield frame except EOFError: raise self.protocol.close_exc from self.recv_exc - except RuntimeError: - raise RuntimeError( + except ConcurrencyError: + raise ConcurrencyError( "cannot call recv_streaming while another coroutine " "is already running recv or recv_streaming" ) from None @@ -593,7 +599,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If another ping was sent with the same data and + ConcurrencyError: If another ping was sent with the same data and the corresponding pong wasn't received yet. """ @@ -607,7 +613,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: async with self.send_context(): # Protect against duplicates if a payload is explicitly set. if data in self.pong_waiters: - raise RuntimeError("already waiting for a pong with the same data") + raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.pong_waiters: @@ -793,7 +799,7 @@ async def send_context( # Let the caller interact with the protocol. try: yield - except (ProtocolError, RuntimeError): + except (ProtocolError, ConcurrencyError): # The protocol state wasn't changed. Exit immediately. raise except Exception as exc: @@ -1092,15 +1098,17 @@ def broadcast( if raise_exceptions: if sys.version_info[:2] < (3, 11): # pragma: no cover raise ValueError("raise_exceptions requires at least Python 3.11") - exceptions = [] + exceptions: list[Exception] = [] for connection in connections: + exception: Exception + if connection.protocol.state is not OPEN: continue if connection.fragmented_send_waiter is not None: if raise_exceptions: - exception = RuntimeError("sending a fragmented message") + exception = ConcurrencyError("sending a fragmented message") exceptions.append(exception) else: connection.logger.warning( diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 33ab6a5e9..c2b4afd67 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -12,6 +12,7 @@ TypeVar, ) +from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -49,7 +50,7 @@ async def get(self) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: if self.get_waiter is not None: - raise RuntimeError("get is already running") + raise ConcurrencyError("get is already running") self.get_waiter = self.loop.create_future() try: await self.get_waiter @@ -135,15 +136,15 @@ async def get(self, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ if self.closed: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get() or get_iter() is already running") + raise ConcurrencyError("get() or get_iter() is already running") # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True @@ -190,7 +191,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :class:`str` or :class:`bytes` for each frame in the message. The iterator must be fully consumed before calling :meth:`get_iter` or - :meth:`get` again. Else, :exc:`RuntimeError` is raised. + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. @@ -202,15 +203,15 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ if self.closed: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get() or get_iter() is already running") + raise ConcurrencyError("get() or get_iter() is already running") # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True @@ -236,7 +237,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # We cannot handle asyncio.CancelledError because we don't buffer # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to - # get() or get_iter() will raise RuntimeError. + # get() or get_iter() will raise ConcurrencyError. frame = await self.frames.get() self.maybe_resume() assert frame.opcode is OP_CONT diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b2b679e6b..d723f2fec 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -25,6 +25,7 @@ * :exc:`ProtocolError` (Sans-I/O) * :exc:`PayloadTooBig` (Sans-I/O) * :exc:`InvalidState` (Sans-I/O) + * :exc:`ConcurrencyError` """ @@ -62,6 +63,7 @@ "WebSocketProtocolError", "PayloadTooBig", "InvalidState", + "ConcurrencyError", ] @@ -354,6 +356,16 @@ class InvalidState(WebSocketException, AssertionError): """ +class ConcurrencyError(WebSocketException, RuntimeError): + """ + Raised when receiving or sending messages concurrently. + + WebSocket is a connection-oriented protocol. Reads must be serialized; so + must be writes. However, reading and writing concurrently is possible. + + """ + + # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: from .legacy.exceptions import ( diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 16e51abda..65a7b63ed 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -10,7 +10,12 @@ from types import TracebackType from typing import Any, Iterable, Iterator, Mapping -from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State @@ -194,7 +199,7 @@ def recv(self, timeout: float | None = None) -> Data: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If two threads call :meth:`recv` or + ConcurrencyError: If two threads call :meth:`recv` or :meth:`recv_streaming` concurrently. """ @@ -202,8 +207,8 @@ def recv(self, timeout: float | None = None) -> Data: return self.recv_messages.get(timeout) except EOFError: raise self.protocol.close_exc from self.recv_exc - except RuntimeError: - raise RuntimeError( + except ConcurrencyError: + raise ConcurrencyError( "cannot call recv while another thread " "is already running recv or recv_streaming" ) from None @@ -227,7 +232,7 @@ def recv_streaming(self) -> Iterator[Data]: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If two threads call :meth:`recv` or + ConcurrencyError: If two threads call :meth:`recv` or :meth:`recv_streaming` concurrently. """ @@ -236,8 +241,8 @@ def recv_streaming(self) -> Iterator[Data]: yield frame except EOFError: raise self.protocol.close_exc from self.recv_exc - except RuntimeError: - raise RuntimeError( + except ConcurrencyError: + raise ConcurrencyError( "cannot call recv_streaming while another thread " "is already running recv or recv_streaming" ) from None @@ -277,7 +282,7 @@ def send(self, message: Data | Iterable[Data]) -> None: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If the connection is sending a fragmented message. + ConcurrencyError: If the connection is sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -287,7 +292,7 @@ def send(self, message: Data | Iterable[Data]) -> None: if isinstance(message, str): with self.send_context(): if self.send_in_progress: - raise RuntimeError( + raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) @@ -296,7 +301,7 @@ def send(self, message: Data | Iterable[Data]) -> None: elif isinstance(message, BytesLike): with self.send_context(): if self.send_in_progress: - raise RuntimeError( + raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) @@ -322,7 +327,7 @@ def send(self, message: Data | Iterable[Data]) -> None: text = True with self.send_context(): if self.send_in_progress: - raise RuntimeError( + raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) @@ -335,7 +340,7 @@ def send(self, message: Data | Iterable[Data]) -> None: text = False with self.send_context(): if self.send_in_progress: - raise RuntimeError( + raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) @@ -371,7 +376,7 @@ def send(self, message: Data | Iterable[Data]) -> None: self.protocol.send_continuation(b"", fin=True) self.send_in_progress = False - except RuntimeError: + except ConcurrencyError: # We didn't start sending a fragmented message. # The connection is still usable. raise @@ -445,7 +450,7 @@ def ping(self, data: Data | None = None) -> threading.Event: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If another ping was sent with the same data and + ConcurrencyError: If another ping was sent with the same data and the corresponding pong wasn't received yet. """ @@ -459,7 +464,7 @@ def ping(self, data: Data | None = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. if data in self.ping_waiters: - raise RuntimeError("already waiting for a pong with the same data") + raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.ping_waiters: @@ -665,7 +670,7 @@ def send_context( # Let the caller interact with the protocol. try: yield - except (ProtocolError, RuntimeError): + except (ProtocolError, ConcurrencyError): # The protocol state wasn't changed. Exit immediately. raise except Exception as exc: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index ff90345ac..8d090538f 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -5,6 +5,7 @@ import threading from typing import Iterator, cast +from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -74,7 +75,7 @@ def get(self, timeout: float | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` + ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a complete message is received. @@ -85,7 +86,7 @@ def get(self, timeout: float | None = None) -> Data: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get() or get_iter() is already running") + raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -128,14 +129,14 @@ def get_iter(self) -> Iterator[Data]: :class:`bytes` for each frame in the message. The iterator must be fully consumed before calling :meth:`get_iter` or - :meth:`get` again. Else, :exc:`RuntimeError` is raised. + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` + ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. """ @@ -144,7 +145,7 @@ def get_iter(self) -> Iterator[Data]: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get() or get_iter() is already running") + raise ConcurrencyError("get() or get_iter() is already running") chunks = self.chunks self.chunks = [] @@ -198,7 +199,7 @@ def put(self, frame: Frame) -> None: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`put` concurrently. + ConcurrencyError: If two threads run :meth:`put` concurrently. """ with self.mutex: @@ -206,7 +207,7 @@ def put(self, frame: Frame) -> None: raise EOFError("stream of frames ended") if self.put_in_progress: - raise RuntimeError("put is already running") + raise ConcurrencyError("put is already running") if frame.opcode is OP_TEXT: self.decoder = UTF8Decoder(errors="strict") diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 78f3adf68..70d9dad63 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -10,7 +10,11 @@ from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * from websockets.asyncio.connection import broadcast -from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.exceptions import ( + ConcurrencyError, + ConnectionClosedError, + ConnectionClosedOK, +) from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol, State @@ -219,12 +223,12 @@ async def test_recv_connection_closed_error(self): await self.connection.recv() async def test_recv_during_recv(self): - """recv raises RuntimeError when called concurrently with itself.""" + """recv raises ConcurrencyError when called concurrently.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task self.addCleanup(recv_task.cancel) - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: await self.connection.recv() self.assertEqual( str(raised.exception), @@ -233,14 +237,14 @@ async def test_recv_during_recv(self): ) async def test_recv_during_recv_streaming(self): - """recv raises RuntimeError when called concurrently with recv_streaming.""" + """recv raises ConcurrencyError when called concurrently with recv_streaming.""" recv_streaming_task = asyncio.create_task( alist(self.connection.recv_streaming()) ) await asyncio.sleep(0) # let the event loop start recv_streaming_task self.addCleanup(recv_streaming_task.cancel) - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: await self.connection.recv() self.assertEqual( str(raised.exception), @@ -349,12 +353,12 @@ async def test_recv_streaming_connection_closed_error(self): self.fail("did not raise") async def test_recv_streaming_during_recv(self): - """recv_streaming raises RuntimeError when called concurrently with recv.""" + """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task self.addCleanup(recv_task.cancel) - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: async for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( @@ -364,14 +368,14 @@ async def test_recv_streaming_during_recv(self): ) async def test_recv_streaming_during_recv_streaming(self): - """recv_streaming raises RuntimeError when called concurrently with itself.""" + """recv_streaming raises ConcurrencyError when called concurrently.""" recv_streaming_task = asyncio.create_task( alist(self.connection.recv_streaming()) ) await asyncio.sleep(0) # let the event loop start recv_streaming_task self.addCleanup(recv_streaming_task.cancel) - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: async for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( @@ -419,7 +423,7 @@ async def fragments(): gate.set_result(None) # Running recv_streaming again fails. - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await alist(self.connection.recv_streaming()) # Test send. @@ -856,7 +860,7 @@ async def test_ping_duplicate_payload(self): async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("idem") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: await self.connection.ping("idem") self.assertEqual( str(raised.exception), @@ -1252,7 +1256,7 @@ async def fragments(): self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") exc = raised.exception.exceptions[0] self.assertEqual(str(exc), "sending a fragmented message") - self.assertIsInstance(exc, RuntimeError) + self.assertIsInstance(exc, ConcurrencyError) gate.set_result(None) await send_task diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 615b1f3a8..d2cf25c9c 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -5,6 +5,7 @@ from websockets.asyncio.compatibility import aiter, anext from websockets.asyncio.messages import * from websockets.asyncio.messages import SimpleQueue +from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from .utils import alist @@ -37,10 +38,10 @@ async def test_get_then_put(self): self.assertEqual(item, 42) async def test_get_concurrently(self): - """get cannot be called concurrently with itself.""" + """get cannot be called concurrently.""" getter_task = asyncio.create_task(self.queue.get()) await asyncio.sleep(0) # let the task start - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await self.queue.get() getter_task.cancel() @@ -361,7 +362,7 @@ async def test_cancel_get_iter_after_first_frame(self): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) # Test put @@ -423,10 +424,10 @@ async def test_close_is_idempotent(self): # Test (non-)concurrency async def test_get_fails_when_get_is_running(self): - """get cannot be called concurrently with itself.""" + """get cannot be called concurrently.""" asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await self.assembler.get() self.assembler.close() # let task terminate @@ -434,7 +435,7 @@ async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await self.assembler.get() self.assembler.close() # let task terminate @@ -442,15 +443,15 @@ async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) self.assembler.close() # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): - """get_iter cannot be called concurrently with itself.""" + """get_iter cannot be called concurrently.""" asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) self.assembler.close() # let task terminate diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index d9fb2093b..16f92e164 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -9,7 +9,11 @@ import uuid from unittest.mock import patch -from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.exceptions import ( + ConcurrencyError, + ConnectionClosedError, + ConnectionClosedOK, +) from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol from websockets.sync.connection import * @@ -173,11 +177,11 @@ def test_recv_connection_closed_error(self): self.connection.recv() def test_recv_during_recv(self): - """recv raises RuntimeError when called concurrently with itself.""" + """recv raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: self.connection.recv() self.assertEqual( str(raised.exception), @@ -189,13 +193,13 @@ def test_recv_during_recv(self): recv_thread.join() def test_recv_during_recv_streaming(self): - """recv raises RuntimeError when called concurrently with recv_streaming.""" + """recv raises ConcurrencyError when called concurrently with recv_streaming.""" recv_streaming_thread = threading.Thread( target=lambda: list(self.connection.recv_streaming()) ) recv_streaming_thread.start() - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: self.connection.recv() self.assertEqual( str(raised.exception), @@ -257,11 +261,11 @@ def test_recv_streaming_connection_closed_error(self): self.fail("did not raise") def test_recv_streaming_during_recv(self): - """recv_streaming raises RuntimeError when called concurrently with recv.""" + """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( @@ -274,13 +278,13 @@ def test_recv_streaming_during_recv(self): recv_thread.join() def test_recv_streaming_during_recv_streaming(self): - """recv_streaming raises RuntimeError when called concurrently with itself.""" + """recv_streaming raises ConcurrencyError when called concurrently.""" recv_streaming_thread = threading.Thread( target=lambda: list(self.connection.recv_streaming()) ) recv_streaming_thread.start() - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( @@ -335,7 +339,7 @@ def test_send_connection_closed_error(self): self.connection.send("😀") def test_send_during_send(self): - """send raises RuntimeError when called concurrently with itself.""" + """send raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.remote_connection.recv) recv_thread.start() @@ -363,7 +367,7 @@ def fragments(): [b"\x01\x02", b"\xfe\xff"], ]: with self.subTest(message=message): - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: self.connection.send(message) self.assertEqual( str(raised.exception), @@ -653,7 +657,7 @@ def test_ping_duplicate_payload(self): with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("idem") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: self.connection.ping("idem") self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index c134b8304..d44b39b88 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,5 +1,6 @@ import time +from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from websockets.sync.messages import * @@ -411,40 +412,40 @@ def test_close_is_idempotent(self): # Test (non-)concurrency def test_get_fails_when_get_is_running(self): - """get cannot be called concurrently with itself.""" + """get cannot be called concurrently.""" with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_iter_fails_when_get_iter_is_running(self): - """get_iter cannot be called concurrently with itself.""" + """get_iter cannot be called concurrently.""" with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently with itself.""" + """put cannot be called concurrently.""" def putter(): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) with self.run_in_thread(putter): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): self.assembler.put(Frame(OP_BINARY, b"tea")) self.assembler.get() # unblock other thread diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 5620b8a53..8d41bf915 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -167,6 +167,10 @@ def test_str(self): InvalidState("WebSocket connection isn't established yet"), "WebSocket connection isn't established yet", ), + ( + ConcurrencyError("get() or get_iter() is already running"), + "get() or get_iter() is already running", + ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) From d19ed267b5e04ded752f78da6c85ad2905cf89e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 11 Sep 2024 07:01:12 +0200 Subject: [PATCH 1400/1539] Run spellcheck. --- docs/reference/variables.rst | 4 ++-- docs/spelling_wordlist.txt | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index b766e02a1..a24773074 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -80,8 +80,8 @@ Reconnection attempts are spaced out with truncated exponential backoff. The default value is ``90.0`` seconds. -Redirections ------------- +Redirects +--------- .. envvar:: WEBSOCKETS_MAX_REDIRECTS diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a1ba59a37..11b13250a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -59,6 +59,7 @@ reconnection redis redistributions retransmit +retryable runtime scalable stateful From 98f236f89529d317628fe8ee3d4d0564e675f68d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 10 Sep 2024 08:01:49 +0200 Subject: [PATCH 1401/1539] Run handler only when opening handshake succeeds. When process_request() or process_response() returned a HTTP response without calling accept() or reject() and with a status code other than 101, the connection handler used to start, which was incorrect. Fix #1419. Also move start_keepalive() outside of handshake() and bring it together with starting the connection handler, which is more logical. --- docs/project/changelog.rst | 7 +++++ src/websockets/asyncio/client.py | 5 ++-- src/websockets/asyncio/server.py | 11 ++++---- src/websockets/server.py | 23 +++++++++------- src/websockets/sync/server.py | 8 +++--- tests/asyncio/test_server.py | 10 +++++-- tests/sync/test_server.py | 5 +++- tests/test_server.py | 45 +++++++++++++++++++++++++++++--- 8 files changed, 88 insertions(+), 26 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f92ca68b6..69051d287 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -68,6 +68,13 @@ Improvements Previously, :exc:`RuntimeError` was raised. For backwards compatibility, :exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`. +Bug fixes +......... + +* The new :mod:`asyncio` and :mod:`threading` implementations of servers don't + start the connection handler anymore when ``process_request`` or + ``process_response`` returns a HTTP response. + 13.0.1 ------ diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 3985bfb6a..b1beb3e00 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -98,9 +98,7 @@ async def handshake( # before receiving a response, when the response cannot be parsed, or # when the response fails the handshake. - if self.protocol.handshake_exc is None: - self.start_keepalive() - else: + if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: @@ -465,6 +463,7 @@ async def __await_impl__(self) -> ClientConnection: raise uri_or_exc from exc else: + self.connection.start_keepalive() return self.connection else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 228b20012..78ee760d2 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -201,12 +201,11 @@ async def handshake( self.protocol.send_response(self.response) # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, or when - # the response fails the handshake. + # before receiving a request, when the request cannot be parsed, when + # the handshake encounters an error, or when process_request or + # process_response sends a HTTP response that rejects the handshake. - if self.protocol.handshake_exc is None: - self.start_keepalive() - else: + if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: @@ -369,7 +368,9 @@ async def conn_handler(self, connection: ServerConnection) -> None: connection.close_transport() return + assert connection.protocol.state is OPEN try: + connection.start_keepalive() await self.handler(connection) except Exception: connection.logger.error("connection handler failed", exc_info=True) diff --git a/src/websockets/server.py b/src/websockets/server.py index ac62800d6..b2671f402 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -204,7 +204,6 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - self.logger.info("connection open") return Response(101, "Switching Protocols", headers) def process_request( @@ -515,14 +514,7 @@ def reject(self, status: StatusLike, text: str) -> Response: ("Content-Type", "text/plain; charset=utf-8"), ] ) - response = Response(status.value, status.phrase, headers, body) - # When reject() is called from accept(), handshake_exc is already set. - # If a user calls reject(), set handshake_exc to guarantee invariant: - # "handshake_exc is None if and only if opening handshake succeeded." - if self.handshake_exc is None: - self.handshake_exc = InvalidStatus(response) - self.logger.info("connection rejected (%d %s)", status.value, status.phrase) - return response + return Response(status.value, status.phrase, headers, body) def send_response(self, response: Response) -> None: """ @@ -545,7 +537,20 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: assert self.state is CONNECTING self.state = OPEN + self.logger.info("connection open") + else: + # handshake_exc may be already set if accept() encountered an error. + # If the connection isn't open, set handshake_exc to guarantee that + # handshake_exc is None if and only if opening handshake succeeded. + if self.handshake_exc is None: + self.handshake_exc = InvalidStatus(response) + self.logger.info( + "connection rejected (%d %s)", + response.status_code, + response.reason_phrase, + ) + self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index eb0536013..0b19201a9 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -23,7 +23,7 @@ validate_subprotocols, ) from ..http11 import SERVER, Request, Response -from ..protocol import CONNECTING, Event +from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection @@ -166,8 +166,9 @@ def handshake( self.protocol.send_response(self.response) # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, or when - # the response fails the handshake. + # before receiving a request, when the request cannot be parsed, when + # the handshake encounters an error, or when process_request or + # process_response sends a HTTP response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -569,6 +570,7 @@ def protocol_select_subprotocol( connection.recv_events_thread.join() return + assert connection.protocol.state is OPEN try: handler(connection) except Exception: diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fdcbf9780..47e0148a6 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -145,7 +145,10 @@ async def test_process_request_returns_response(self): def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with serve(*args, process_request=process_request) as server: + async def handler(ws): + self.fail("handler must not run") + + async with serve(handler, *args[1:], process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -160,7 +163,10 @@ async def test_async_process_request_returns_response(self): async def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with serve(*args, process_request=process_request) as server: + async def handler(ws): + self.fail("handler must not run") + + async with serve(handler, *args[1:], process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index d0d2c0955..3bc6f76cd 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -133,7 +133,10 @@ def test_process_request_returns_response(self): def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - with run_server(process_request=process_request) as server: + def handler(ws): + self.fail("handler must not run") + + with run_server(handler, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") diff --git a/tests/test_server.py b/tests/test_server.py index d34c8e83d..52c8a2b99 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -106,10 +106,11 @@ def make_request(self): ), ) - def test_send_accept(self): + def test_send_response_after_successful_accept(self): server = ServerProtocol() + request = self.make_request() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(self.make_request()) + response = server.accept(request) self.assertIsInstance(response, Response) server.send_response(response) self.assertEqual( @@ -126,7 +127,32 @@ def test_send_accept(self): self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) - def test_send_reject(self): + def test_send_response_after_failed_accept(self): + server = ServerProtocol() + request = self.make_request() + del request.headers["Sec-WebSocket-Key"] + with unittest.mock.patch("email.utils.formatdate", return_value=DATE): + response = server.accept(request) + self.assertIsInstance(response, Response) + server.send_response(response) + self.assertEqual( + server.data_to_send(), + [ + f"HTTP/1.1 400 Bad Request\r\n" + f"Date: {DATE}\r\n" + f"Connection: close\r\n" + f"Content-Length: 94\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"\r\n" + f"Failed to open a WebSocket connection: " + f"missing Sec-WebSocket-Key header; 'sec-websocket-key'.\n".encode(), + b"", + ], + ) + self.assertTrue(server.close_expected()) + self.assertEqual(server.state, CONNECTING) + + def test_send_response_after_reject(self): server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") @@ -148,6 +174,19 @@ def test_send_reject(self): self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) + def test_send_response_without_accept_or_reject(self): + server = ServerProtocol() + server.send_response(Response(410, "Gone", Headers(), b"AWOL.\n")) + self.assertEqual( + server.data_to_send(), + [ + "HTTP/1.1 410 Gone\r\n\r\nAWOL.\n".encode(), + b"", + ], + ) + self.assertTrue(server.close_expected()) + self.assertEqual(server.state, CONNECTING) + def test_accept_response(self): server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): From 206624a6a700b8ee572b1730df5072adc47a17e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 14:01:59 +0200 Subject: [PATCH 1402/1539] Standard spelling on "an HTTP". --- docs/project/changelog.rst | 2 +- docs/reference/features.rst | 7 +++---- src/websockets/asyncio/server.py | 2 +- src/websockets/sync/server.py | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 69051d287..456c15dac 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -73,7 +73,7 @@ Bug fixes * The new :mod:`asyncio` and :mod:`threading` implementations of servers don't start the connection handler anymore when ``process_request`` or - ``process_response`` returns a HTTP response. + ``process_response`` returns an HTTP response. 13.0.1 ------ diff --git a/docs/reference/features.rst b/docs/reference/features.rst index eeade1462..8b04034eb 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -161,10 +161,9 @@ Client | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | | (`#784`_) | | | | | +------------------------------------+--------+--------+--------+--------+ - | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + | Connect via HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ - | Connect via a SOCKS5 proxy | ❌ | ❌ | — | ❌ | - | (`#475`_) | | | | | + | Connect via SOCKS5 proxy (`#475`_) | ❌ | ❌ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ .. _#364: https://github.com/python-websockets/websockets/issues/364 @@ -179,7 +178,7 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 -The server doesn't check the Host header and doesn't respond with a HTTP 400 Bad +The server doesn't check the Host header and doesn't respond with HTTP 400 Bad Request if it is missing or invalid (`#1246`). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 78ee760d2..19dae44b7 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -203,7 +203,7 @@ async def handshake( # self.protocol.handshake_exc is always set when the connection is lost # before receiving a request, when the request cannot be parsed, when # the handshake encounters an error, or when process_request or - # process_response sends a HTTP response that rejects the handshake. + # process_response sends an HTTP response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 0b19201a9..1b7cbb4b4 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -168,7 +168,7 @@ def handshake( # self.protocol.handshake_exc is always set when the connection is lost # before receiving a request, when the request cannot be parsed, when # the handshake encounters an error, or when process_request or - # process_response sends a HTTP response that rejects the handshake. + # process_response sends an HTTP response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc From 20739e03ec6ccb010391b5179315368a2dd3a594 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 13:58:33 +0200 Subject: [PATCH 1403/1539] Improve exception handling during handshake. Also refactor tests for Sans-I/O client and server. --- src/websockets/client.py | 8 +- src/websockets/server.py | 22 +- tests/test_client.py | 726 ++++++++++++++++++++------------------- tests/test_connection.py | 1 + tests/test_server.py | 658 +++++++++++++++++++++-------------- 5 files changed, 799 insertions(+), 616 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 0e36fd028..e5f294986 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -175,10 +175,10 @@ def process_response(self, response: Response) -> None: try: s_w_accept = headers["Sec-WebSocket-Accept"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Accept") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Accept") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) diff --git a/src/websockets/server.py b/src/websockets/server.py index b2671f402..006d5bdd5 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -253,10 +253,10 @@ def process_request( try: key = headers["Sec-WebSocket-Key"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Key") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Key") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None try: raw_key = base64.b64decode(key.encode(), validate=True) @@ -267,10 +267,10 @@ def process_request( try: version = headers["Sec-WebSocket-Version"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Version") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Version") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) @@ -308,8 +308,8 @@ def process_origin(self, headers: Headers) -> Origin | None: # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") - except MultipleValuesError as exc: - raise InvalidHeader("Origin", "multiple values") from exc + except MultipleValuesError: + raise InvalidHeader("Origin", "multiple values") from None if origin is not None: origin = cast(Origin, origin) if self.origins is not None: @@ -503,7 +503,7 @@ def reject(self, status: StatusLike, text: str) -> Response: HTTP response to send to the client. """ - # If a user passes an int instead of a HTTPStatus, fix it automatically. + # If status is an int instead of an HTTPStatus, fix it automatically. status = http.HTTPStatus(status) body = text.encode() headers = Headers( diff --git a/tests/test_client.py b/tests/test_client.py index 47558c1c0..2468be85e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,11 +1,14 @@ +import contextlib +import dataclasses import logging +import types import unittest -import unittest.mock +from unittest.mock import patch from websockets.client import * from websockets.client import backoff from websockets.datastructures import Headers -from websockets.exceptions import InvalidHandshake, InvalidHeader +from websockets.exceptions import InvalidHandshake, InvalidHeader, InvalidStatus from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN @@ -22,13 +25,19 @@ from .utils import DATE, DeprecationTestCase -class ConnectTests(unittest.TestCase): - def test_send_connect(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("wss://example.com/test")) +URI = parse_uri("wss://example.com/test") # for tests where the URI doesn't matter + + +@patch("websockets.client.generate_key", return_value=KEY) +class BasicTests(unittest.TestCase): + """Test basic opening handshake scenarios.""" + + def test_send_request(self, _generate_key): + """Client sends a handshake request.""" + client = ClientProtocol(URI) request = client.connect() - self.assertIsInstance(request, Request) client.send_request(request) + self.assertEqual( client.data_to_send(), [ @@ -42,11 +51,56 @@ def test_send_connect(self): ], ) self.assertFalse(client.close_expected()) + self.assertEqual(client.state, CONNECTING) + + def test_receive_successful_response(self, _generate_key): + """Client receives a successful handshake response.""" + client = ClientProtocol(URI) + client.receive_data( + ( + f"HTTP/1.1 101 Switching Protocols\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {ACCEPT}\r\n" + f"Date: {DATE}\r\n" + f"\r\n" + ).encode(), + ) + + self.assertEqual(client.data_to_send(), []) + self.assertFalse(client.close_expected()) + self.assertEqual(client.state, OPEN) + + def test_receive_failed_response(self, _generate_key): + """Client receives a failed handshake response.""" + client = ClientProtocol(URI) + client.receive_data( + ( + f"HTTP/1.1 404 Not Found\r\n" + f"Date: {DATE}\r\n" + f"Content-Length: 13\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"Connection: close\r\n" + f"\r\n" + f"Sorry folks.\n" + ).encode(), + ) + + self.assertEqual(client.data_to_send(), [b""]) + self.assertTrue(client.close_expected()) + self.assertEqual(client.state, CONNECTING) + + +class RequestTests(unittest.TestCase): + """Test generating opening handshake requests.""" - def test_connect_request(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("wss://example.com/test")) + @patch("websockets.client.generate_key", return_value=KEY) + def test_connect(self, _generate_key): + """connect() creates an opening handshake request.""" + client = ClientProtocol(URI) request = client.connect() + + self.assertIsInstance(request, Request) self.assertEqual(request.path, "/test") self.assertEqual( request.headers, @@ -62,12 +116,14 @@ def test_connect_request(self): ) def test_path(self): + """connect() uses the path from the URI.""" client = ClientProtocol(parse_uri("wss://example.com/endpoint?test=1")) request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") def test_port(self): + """connect() uses the port from the URI or the default port.""" for uri, host in [ ("ws://example.com/", "example.com"), ("ws://example.com:80/", "example.com"), @@ -83,85 +139,41 @@ def test_port(self): self.assertEqual(request.headers["Host"], host) def test_user_info(self): + """connect() perfoms HTTP Basic Authentication with user info from the URI.""" client = ClientProtocol(parse_uri("wss://hello:iloveyou@example.com/")) request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - origin="https://example.com", - ) + """connect(origin=...) generates an Origin header.""" + client = ClientProtocol(URI, origin="https://example.com") request = client.connect() self.assertEqual(request.headers["Origin"], "https://example.com") def test_extensions(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory()], - ) + """connect(extensions=...) generates a Sec-WebSocket-Extensions header.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory()]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["chat"], - ) + """connect(subprotocols=...) generates a Sec-WebSocket-Protocol header.""" + client = ClientProtocol(URI, subprotocols=["chat"]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") -class AcceptRejectTests(unittest.TestCase): - def test_receive_accept(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() - client.receive_data( - ( - f"HTTP/1.1 101 Switching Protocols\r\n" - f"Upgrade: websocket\r\n" - f"Connection: Upgrade\r\n" - f"Sec-WebSocket-Accept: {ACCEPT}\r\n" - f"Date: {DATE}\r\n" - f"\r\n" - ).encode(), - ) - [response] = client.events_received() - self.assertIsInstance(response, Response) - self.assertEqual(client.data_to_send(), []) - self.assertFalse(client.close_expected()) - self.assertEqual(client.state, OPEN) - - def test_receive_reject(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() - client.receive_data( - ( - f"HTTP/1.1 404 Not Found\r\n" - f"Date: {DATE}\r\n" - f"Content-Length: 13\r\n" - f"Content-Type: text/plain; charset=utf-8\r\n" - f"Connection: close\r\n" - f"\r\n" - f"Sorry folks.\n" - ).encode(), - ) - [response] = client.events_received() - self.assertIsInstance(response, Response) - self.assertEqual(client.data_to_send(), [b""]) - self.assertTrue(client.close_expected()) - self.assertEqual(client.state, CONNECTING) +@patch("websockets.client.generate_key", return_value=KEY) +class ResponseTests(unittest.TestCase): + """Test receiving opening handshake responses.""" - def test_accept_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_successful_response(self, _generate_key): + """Client receives a successful handshake response.""" + client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" @@ -173,6 +185,7 @@ def test_accept_response(self): ).encode(), ) [response] = client.events_received() + self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual( @@ -187,11 +200,11 @@ def test_accept_response(self): ), ) self.assertIsNone(response.body) + self.assertIsNone(client.handshake_exc) - def test_reject_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_failed_response(self, _generate_key): + """Client receives a failed handshake response.""" + client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" @@ -204,6 +217,7 @@ def test_reject_response(self): ).encode(), ) [response] = client.events_received() + self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") self.assertEqual( @@ -218,394 +232,416 @@ def test_reject_response(self): ), ) self.assertEqual(response.body, b"Sorry folks.\n") + self.assertIsInstance(client.handshake_exc, InvalidStatus) + self.assertEqual( + str(client.handshake_exc), + "server rejected WebSocket connection: HTTP 404", + ) - def test_no_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_no_response(self, _generate_key): + """Client receives no handshake response.""" + client = ClientProtocol(URI) client.receive_eof() + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, EOFError) + self.assertEqual( + str(client.handshake_exc), + "connection closed while reading HTTP status line", + ) - def test_partial_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_truncated_response(self, _generate_key): + """Client receives a truncated handshake response.""" + client = ClientProtocol(URI) client.receive_data(b"HTTP/1.1 101 Switching Protocols\r\n") client.receive_eof() + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, EOFError) + self.assertEqual( + str(client.handshake_exc), + "connection closed while reading HTTP headers", + ) - def test_random_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_random_response(self, _generate_key): + """Client receives a junk handshake response.""" + client = ClientProtocol(URI) client.receive_data(b"220 smtp.invalid\r\n") client.receive_data(b"250 Hello relay.invalid\r\n") client.receive_data(b"250 Ok\r\n") client.receive_data(b"250 Ok\r\n") - client.receive_eof() - self.assertEqual(client.events_received(), []) - def make_accept_response(self, client): - request = client.connect() - return Response( - status_code=101, - reason_phrase="Switching Protocols", - headers=Headers( - { - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Accept": accept_key( - request.headers["Sec-WebSocket-Key"] - ), - } - ), + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, ValueError) + self.assertEqual( + str(client.handshake_exc), + "invalid HTTP status line: 220 smtp.invalid", ) - def test_basic(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() +@contextlib.contextmanager +def alter_and_receive_response(client): + """Generate a handshake response that can be altered for testing.""" + # We could start by sending a handshake request, i.e.: + # request = client.connect() + # client.send_request(request) + # However, in the current implementation, these calls have no effect on the + # state of the client. Therefore, they're unnecessary and can be skipped. + response = Response( + status_code=101, + reason_phrase="Switching Protocols", + headers=Headers( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Accept": accept_key(client.key), + } + ), + ) + yield response + client.receive_data(response.serialize()) + [parsed_response] = client.events_received() + assert response == dataclasses.replace(parsed_response, _exception=None) + + +class HandshakeTests(unittest.TestCase): + """Test processing of handshake responses to configure the connection.""" + + def assertHandshakeSuccess(self, client): + """Assert that the opening handshake succeeded.""" self.assertEqual(client.state, OPEN) + self.assertIsNone(client.handshake_exc) - def test_missing_connection(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Connection"] - client.receive_data(response.serialize()) - [response] = client.events_received() - + def assertHandshakeError(self, client, exc_type, msg): + """Assert that the opening handshake failed with the given exception.""" self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Connection header") + self.assertIsInstance(client.handshake_exc, exc_type) + # Exception chaining isn't used is client handshake implementation. + assert client.handshake_exc.__cause__ is None + self.assertEqual(str(client.handshake_exc), msg) - def test_invalid_connection(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Connection"] - response.headers["Connection"] = "close" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_basic(self): + """Handshake succeeds.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "invalid Connection header: close") + self.assertHandshakeSuccess(client) - def test_missing_upgrade(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Upgrade"] - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_missing_connection(self): + """Handshake fails when the Connection header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Connection"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Connection header", + ) - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Upgrade header") + def test_invalid_connection(self): + """Handshake fails when the Connection header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Connection"] + response.headers["Connection"] = "close" + + self.assertHandshakeError( + client, + InvalidHeader, + "invalid Connection header: close", + ) - def test_invalid_upgrade(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Upgrade"] - response.headers["Upgrade"] = "h2c" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_missing_upgrade(self): + """Handshake fails when the Upgrade header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Upgrade"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Upgrade header", + ) - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") + def test_invalid_upgrade(self): + """Handshake fails when the Upgrade header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Upgrade"] + response.headers["Upgrade"] = "h2c" + + self.assertHandshakeError( + client, + InvalidHeader, + "invalid Upgrade header: h2c", + ) def test_missing_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Sec-WebSocket-Accept"] - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") + """Handshake fails when the Sec-WebSocket-Accept header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Sec-WebSocket-Accept"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Sec-WebSocket-Accept header", + ) def test_multiple_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Accept"] = ACCEPT - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + """Handshake fails when the Sec-WebSocket-Accept header is repeated.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Accept"] = ACCEPT + + self.assertHandshakeError( + client, + InvalidHeader, "invalid Sec-WebSocket-Accept header: multiple values", ) def test_invalid_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Sec-WebSocket-Accept"] - response.headers["Sec-WebSocket-Accept"] = ACCEPT - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), f"invalid Sec-WebSocket-Accept header: {ACCEPT}" + """Handshake fails when the Sec-WebSocket-Accept header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Sec-WebSocket-Accept"] + response.headers["Sec-WebSocket-Accept"] = ACCEPT + + self.assertHandshakeError( + client, + InvalidHeader, + f"invalid Sec-WebSocket-Accept header: {ACCEPT}", ) def test_no_extensions(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + """Handshake succeeds without extensions.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, []) - def test_no_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory()], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_offer_extension(self): + """Client offers an extension.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + request = client.connect() - self.assertEqual(client.state, OPEN) - self.assertEqual(client.extensions, [OpExtension()]) + self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-rsv2") - def test_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientRsv2ExtensionFactory()], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_enable_extension(self): + """Client offers an extension and the server enables it.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [Rsv2Extension()]) - def test_unexpected_extension(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_extension_not_enabled(self): + """Client offers an extension, but the server doesn't enable it.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "no extensions supported") + self.assertHandshakeSuccess(client) + self.assertEqual(client.extensions, []) - def test_unsupported_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientRsv2ExtensionFactory()], + def test_no_extensions_offered(self): + """Server enables an extension when the client didn't offer any.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + + self.assertHandshakeError( + client, + InvalidHandshake, + "no extensions supported", ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + def test_extension_not_offered(self): + """Server enables an extension that the client didn't offer.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + + self.assertHandshakeError( + client, + InvalidHandshake, "Unsupported extension: name = x-op, params = [('op', None)]", ) def test_supported_extension_parameters(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory("this")], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - client.receive_data(response.serialize()) - [response] = client.events_received() + """Server enables an extension with parameters supported by the client.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory("this")]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory("this")], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + """Server enables an extension with parameters unsupported by the client.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory("this")]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" + + self.assertHandshakeError( + client, + InvalidHandshake, "Unsupported extension: name = x-op, params = [('op', 'that')]", ) def test_multiple_supported_extension_parameters(self): + """Client offers the same extension with several parameters.""" client = ClientProtocol( - parse_uri("wss://example.com/"), + URI, extensions=[ ClientOpExtensionFactory("this"), ClientOpExtensionFactory("that"), ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension("that")]) def test_multiple_extensions(self): + """Client offers several extensions and the server enables them.""" client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], + URI, + extensions=[ + ClientOpExtensionFactory(), + ClientRsv2ExtensionFactory(), + ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): + """Client respects the order of extensions chosen by the server.""" client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], + URI, + extensions=[ + ClientOpExtensionFactory(), + ClientRsv2ExtensionFactory(), + ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + """Handshake succeeds without subprotocols.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertIsNone(client.subprotocol) - def test_no_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_no_subprotocol_requested(self): + """Client doesn't offer a subprotocol, but the server enables one.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, OPEN) - self.assertIsNone(client.subprotocol) + self.assertHandshakeError( + client, + InvalidHandshake, + "no subprotocols supported", + ) - def test_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_offer_subprotocol(self): + """Client offers a subprotocol.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + request = client.connect() - self.assertEqual(client.state, OPEN) - self.assertEqual(client.subprotocol, "chat") + self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") - def test_unexpected_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_enable_subprotocol(self): + """Client offers a subprotocol and the server enables it.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "no subprotocols supported") + self.assertHandshakeSuccess(client) + self.assertEqual(client.subprotocol, "chat") - def test_multiple_subprotocols(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "superchat" - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_no_subprotocol_accepted(self): + """Client offers a subprotocol, but the server doesn't enable it.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), - "invalid Sec-WebSocket-Protocol header: " - "multiple values: superchat, chat", - ) + self.assertHandshakeSuccess(client) + self.assertIsNone(client.subprotocol) - def test_supported_subprotocol(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_multiple_subprotocols(self): + """Client offers several subprotocols and the server enables one.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.subprotocol, "chat") def test_unsupported_subprotocol(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], + """Client offers subprotocols but the server enables another one.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "otherchat" + + self.assertHandshakeError( + client, + InvalidHandshake, + "unsupported subprotocol: otherchat", ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "otherchat" - client.receive_data(response.serialize()) - [response] = client.events_received() - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") + def test_multiple_subprotocols_accepted(self): + """Server attempts to enable multiple subprotocols.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "superchat" + response.headers["Sec-WebSocket-Protocol"] = "chat" + + self.assertHandshakeError( + client, + InvalidHandshake, + "invalid Sec-WebSocket-Protocol header: " + "multiple values: superchat, chat", + ) class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - client = ClientProtocol(parse_uri("ws://example.com/test"), state=OPEN) + """ClientProtocol bypasses the opening handshake.""" + client = ClientProtocol(URI, state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): + """ClientProtocol accepts a logger argument.""" logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ClientProtocol(parse_uri("wss://example.com/test"), logger=logger) + ClientProtocol(URI, logger=logger) self.assertEqual(len(logs.records), 1) class BackwardsCompatibilityTests(DeprecationTestCase): def test_client_connection_class(self): + """ClientConnection is a deprecated alias for ClientProtocol.""" with self.assertDeprecationWarning( "ClientConnection was renamed to ClientProtocol" ): @@ -618,7 +654,9 @@ def test_client_connection_class(self): class BackoffTests(unittest.TestCase): def test_backoff(self): + """backoff() yields a random delay, then exponentially increasing delays.""" backoff_gen = backoff() + self.assertIsInstance(backoff_gen, types.GeneratorType) initial_delay = next(backoff_gen) self.assertGreaterEqual(initial_delay, 0) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6592d67d0..9ad2ebea4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,6 +5,7 @@ class BackwardsCompatibilityTests(DeprecationTestCase): def test_connection_class(self): + """Connection is a deprecated alias for Protocol.""" with self.assertDeprecationWarning( "websockets.connection was renamed to websockets.protocol " "and Connection was renamed to Protocol" diff --git a/tests/test_server.py b/tests/test_server.py index 52c8a2b99..844ba64ec 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,7 +1,8 @@ import http import logging +import sys import unittest -import unittest.mock +from unittest.mock import patch from websockets.datastructures import Headers from websockets.exceptions import ( @@ -25,8 +26,28 @@ from .utils import DATE, DeprecationTestCase -class ConnectTests(unittest.TestCase): - def test_receive_connect(self): +def make_request(): + """Generate a handshake request that can be altered for testing.""" + return Request( + path="/test", + headers=Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + } + ), + ) + + +@patch("email.utils.formatdate", return_value=DATE) +class BasicTests(unittest.TestCase): + """Test basic opening handshake scenarios.""" + + def test_receive_request(self, _formatdate): + """Server receives a handshake request.""" server = ServerProtocol() server.receive_data( ( @@ -39,80 +60,18 @@ def test_receive_connect(self): f"\r\n" ).encode(), ) - [request] = server.events_received() - self.assertIsInstance(request, Request) + self.assertEqual(server.data_to_send(), []) self.assertFalse(server.close_expected()) + self.assertEqual(server.state, CONNECTING) - def test_connect_request(self): - server = ServerProtocol() - server.receive_data( - ( - f"GET /test HTTP/1.1\r\n" - f"Host: example.com\r\n" - f"Upgrade: websocket\r\n" - f"Connection: Upgrade\r\n" - f"Sec-WebSocket-Key: {KEY}\r\n" - f"Sec-WebSocket-Version: 13\r\n" - f"\r\n" - ).encode(), - ) - [request] = server.events_received() - self.assertEqual(request.path, "/test") - self.assertEqual( - request.headers, - Headers( - { - "Host": "example.com", - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Key": KEY, - "Sec-WebSocket-Version": "13", - } - ), - ) - - def test_no_request(self): - server = ServerProtocol() - server.receive_eof() - self.assertEqual(server.events_received(), []) - - def test_partial_request(self): - server = ServerProtocol() - server.receive_data(b"GET /test HTTP/1.1\r\n") - server.receive_eof() - self.assertEqual(server.events_received(), []) - - def test_junk_request(self): - server = ServerProtocol() - server.receive_data(b"HELO relay.invalid\r\n") - server.receive_data(b"MAIL FROM: \r\n") - server.receive_data(b"RCPT TO: \r\n") - self.assertEqual(server.events_received(), []) - - -class AcceptRejectTests(unittest.TestCase): - def make_request(self): - return Request( - path="/test", - headers=Headers( - { - "Host": "example.com", - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Key": KEY, - "Sec-WebSocket-Version": "13", - } - ), - ) - - def test_send_response_after_successful_accept(self): + def test_accept_and_send_successful_response(self, _formatdate): + """Server accepts a handshake request and sends a successful response.""" server = ServerProtocol() - request = self.make_request() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(request) - self.assertIsInstance(response, Response) + request = make_request() + response = server.accept(request) server.send_response(response) + self.assertEqual( server.data_to_send(), [ @@ -127,37 +86,37 @@ def test_send_response_after_successful_accept(self): self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) - def test_send_response_after_failed_accept(self): + def test_send_response_after_failed_accept(self, _formatdate): + """Server accepts a handshake request but sends a failed response.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(request) - self.assertIsInstance(response, Response) + response = server.accept(request) server.send_response(response) + self.assertEqual( server.data_to_send(), [ f"HTTP/1.1 400 Bad Request\r\n" f"Date: {DATE}\r\n" f"Connection: close\r\n" - f"Content-Length: 94\r\n" + f"Content-Length: 73\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"\r\n" f"Failed to open a WebSocket connection: " - f"missing Sec-WebSocket-Key header; 'sec-websocket-key'.\n".encode(), + f"missing Sec-WebSocket-Key header.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_send_response_after_reject(self): + def test_send_response_after_reject(self, _formatdate): + """Server rejects a handshake request and sends a failed response.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") - self.assertIsInstance(response, Response) + response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") server.send_response(response) + self.assertEqual( server.data_to_send(), [ @@ -174,23 +133,124 @@ def test_send_response_after_reject(self): self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_send_response_without_accept_or_reject(self): + def test_send_response_without_accept_or_reject(self, _formatdate): + """Server doesn't accept or reject and sends a failed response.""" server = ServerProtocol() - server.send_response(Response(410, "Gone", Headers(), b"AWOL.\n")) + server.send_response( + Response( + 410, + "Gone", + Headers( + { + "Connection": "close", + "Content-Length": 6, + "Content-Type": "text/plain", + } + ), + b"AWOL.\n", + ) + ) self.assertEqual( server.data_to_send(), [ - "HTTP/1.1 410 Gone\r\n\r\nAWOL.\n".encode(), + "HTTP/1.1 410 Gone\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "AWOL.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_accept_response(self): + +class RequestTests(unittest.TestCase): + """Test receiving opening handshake requests.""" + + def test_receive_request(self): + """Server receives a handshake request.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(self.make_request()) + server.receive_data( + ( + f"GET /test HTTP/1.1\r\n" + f"Host: example.com\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {KEY}\r\n" + f"Sec-WebSocket-Version: 13\r\n" + f"\r\n" + ).encode(), + ) + [request] = server.events_received() + + self.assertIsInstance(request, Request) + self.assertEqual(request.path, "/test") + self.assertEqual( + request.headers, + Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + } + ), + ) + self.assertIsNone(server.handshake_exc) + + def test_receive_no_request(self): + """Server receives no handshake request.""" + server = ServerProtocol() + server.receive_eof() + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual( + str(server.handshake_exc), + "connection closed while reading HTTP request line", + ) + + def test_receive_truncated_request(self): + """Server receives a truncated handshake request.""" + server = ServerProtocol() + server.receive_data(b"GET /test HTTP/1.1\r\n") + server.receive_eof() + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual( + str(server.handshake_exc), + "connection closed while reading HTTP headers", + ) + + def test_receive_junk_request(self): + """Server receives a junk handshake request.""" + server = ServerProtocol() + server.receive_data(b"HELO relay.invalid\r\n") + server.receive_data(b"MAIL FROM: \r\n") + server.receive_data(b"RCPT TO: \r\n") + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, ValueError) + self.assertEqual( + str(server.handshake_exc), + "invalid HTTP request line: HELO relay.invalid", + ) + + +class ResponseTests(unittest.TestCase): + """Test generating opening handshake responses.""" + + @patch("email.utils.formatdate", return_value=DATE) + def test_accept_response(self, _formatdate): + """accept() creates a successful opening handshake response.""" + server = ServerProtocol() + request = make_request() + response = server.accept(request) + self.assertIsInstance(response, Response) self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") @@ -207,10 +267,12 @@ def test_accept_response(self): ) self.assertIsNone(response.body) - def test_reject_response(self): + @patch("email.utils.formatdate", return_value=DATE) + def test_reject_response(self, _formatdate): + """reject() creates a failed opening handshake response.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + self.assertIsInstance(response, Response) self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") @@ -228,477 +290,552 @@ def test_reject_response(self): self.assertEqual(response.body, b"Sorry folks.\n") def test_reject_response_supports_int_status(self): + """reject() accepts an integer status code instead of an HTTPStatus.""" server = ServerProtocol() response = server.reject(404, "Sorry folks.\n") + self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") - def test_basic(self): + @patch("websockets.server.ServerProtocol.process_request") + def test_unexpected_error(self, process_request): + """accept() handles unexpected errors and returns an error response.""" server = ServerProtocol() - request = self.make_request() + request = make_request() + process_request.side_effect = (Exception("BOOM"),) response = server.accept(request) - self.assertEqual(response.status_code, 101) + self.assertEqual(response.status_code, 500) + self.assertIsInstance(server.handshake_exc, Exception) + self.assertEqual(str(server.handshake_exc), "BOOM") - def test_unexpected_exception(self): + +class HandshakeTests(unittest.TestCase): + """Test processing of handshake responses to configure the connection.""" + + def assertHandshakeSuccess(self, server): + """Assert that the opening handshake succeeded.""" + self.assertEqual(server.state, OPEN) + self.assertIsNone(server.handshake_exc) + + def assertHandshakeError(self, server, exc_type, msg): + """Assert that the opening handshake failed with the given exception.""" + self.assertEqual(server.state, CONNECTING) + self.assertIsInstance(server.handshake_exc, exc_type) + exc = server.handshake_exc + exc_str = str(exc) + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += "; " + str(exc) + self.assertEqual(exc_str, msg) + + def test_basic(self): + """Handshake succeeds.""" server = ServerProtocol() - request = self.make_request() - with unittest.mock.patch( - "websockets.server.ServerProtocol.process_request", - side_effect=Exception("BOOM"), - ): - response = server.accept(request) + request = make_request() + response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 500) - with self.assertRaises(Exception) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "BOOM", - ) + self.assertHandshakeSuccess(server) def test_missing_connection(self): + """Handshake fails when the Connection header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Connection"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "missing Connection header", ) def test_invalid_connection(self): + """Handshake fails when the Connection header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Connection"] request.headers["Connection"] = "close" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "invalid Connection header: close", ) def test_missing_upgrade(self): + """Handshake fails when the Upgrade header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Upgrade"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "missing Upgrade header", ) def test_invalid_upgrade(self): + """Handshake fails when the Upgrade header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Upgrade"] request.headers["Upgrade"] = "h2c" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "invalid Upgrade header: h2c", ) def test_missing_key(self): + """Handshake fails when the Sec-WebSocket-Key header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "missing Sec-WebSocket-Key header", ) def test_multiple_key(self): + """Handshake fails when the Sec-WebSocket-Key header is repeated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Key"] = KEY response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Key header: multiple values", ) def test_invalid_key(self): + """Handshake fails when the Sec-WebSocket-Key header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - request.headers["Sec-WebSocket-Key"] = "not Base64 data!" + request.headers["Sec-WebSocket-Key"] = "" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "invalid Sec-WebSocket-Key header: not Base64 data!", + if sys.version_info[:2] >= (3, 11): + b64_exc = "Only base64 data is allowed" + else: # pragma: no cover + b64_exc = "Non-base64 digit found" + self.assertHandshakeError( + server, + InvalidHeader, + f"invalid Sec-WebSocket-Key header: ; {b64_exc}", ) def test_truncated_key(self): + """Handshake fails when the Sec-WebSocket-Key header is truncated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - request.headers["Sec-WebSocket-Key"] = KEY[ - :16 - ] # 12 bytes instead of 16, Base64-encoded + # 12 bytes instead of 16, Base64-encoded + request.headers["Sec-WebSocket-Key"] = KEY[:16] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, f"invalid Sec-WebSocket-Key header: {KEY[:16]}", ) def test_missing_version(self): + """Handshake fails when the Sec-WebSocket-Version header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Version"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "missing Sec-WebSocket-Version header", ) def test_multiple_version(self): + """Handshake fails when the Sec-WebSocket-Version header is repeated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Version header: multiple values", ) def test_invalid_version(self): + """Handshake fails when the Sec-WebSocket-Version header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Version"] request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Version header: 11", ) - def test_no_origin(self): + def test_origin(self): + """Handshake succeeds when checking origin.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() + request = make_request() + request.headers["Origin"] = "https://example.com" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "missing Origin header", - ) + self.assertHandshakeSuccess(server) + self.assertEqual(server.origin, "https://example.com") - def test_origin(self): + def test_no_origin(self): + """Handshake fails when checking origin and the Origin header is missing.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() - request.headers["Origin"] = "https://example.com" + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) - self.assertEqual(server.origin, "https://example.com") + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "missing Origin header", + ) def test_unexpected_origin(self): + """Handshake fails when checking origin and the Origin header is unexpected.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidOrigin, "invalid Origin header: https://other.example.com", ) def test_multiple_origin(self): + """Handshake fails when checking origins and the Origin header is repeated.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://example.com" request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) # This is prohibited by the HTTP specification, so the return code is # 400 Bad Request rather than 403 Forbidden. self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Origin header: multiple values", ) def test_supported_origin(self): + """Handshake succeeds when checking origins and the origin is supported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): + """Handshake succeeds when checking origins and the origin is unsupported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://original.example.com" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidOrigin, "invalid Origin header: https://original.example.com", ) def test_no_origin_accepted(self): + """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertIsNone(server.origin) def test_no_extensions(self): + """Handshake succeeds without extensions.""" server = ServerProtocol() - request = self.make_request() - response = server.accept(request) - - self.assertEqual(response.status_code, 101) - self.assertNotIn("Sec-WebSocket-Extensions", response.headers) - self.assertEqual(server.extensions, []) - - def test_no_extension(self): - server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_extension(self): + """Server enables an extension when the client offers it.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op") self.assertEqual(server.extensions, [OpExtension()]) - def test_unexpected_extension(self): + def test_extension_not_enabled(self): + """Server doesn't enable an extension when the client doesn't offer it.""" + server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) + request = make_request() + response = server.accept(request) + server.send_response(response) + + self.assertHandshakeSuccess(server) + self.assertNotIn("Sec-WebSocket-Extensions", response.headers) + self.assertEqual(server.extensions, []) + + def test_no_extensions_supported(self): + """Client offers an extension, but the server doesn't support any.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) - def test_unsupported_extension(self): + def test_extension_not_supported(self): + """Client offers an extension, but the server doesn't support it.""" server = ServerProtocol(extensions=[ServerRsv2ExtensionFactory()]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_supported_extension_parameters(self): + """Client offers an extension with parameters supported by the server.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=this") self.assertEqual(server.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): + """Client offers an extension with parameters unsupported by the server.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_multiple_supported_extension_parameters(self): + """Server supports the same extension with several parameters.""" server = ServerProtocol( extensions=[ ServerOpExtensionFactory("this"), ServerOpExtensionFactory("that"), ] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=that") self.assertEqual(server.extensions, [OpExtension("that")]) def test_multiple_extensions(self): + """Server enables several extensions when the client offers them.""" server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual( response.headers["Sec-WebSocket-Extensions"], "x-op; op, x-rsv2" ) self.assertEqual(server.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): + """Server respects the order of extensions set in its configuration.""" server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual( response.headers["Sec-WebSocket-Extensions"], "x-rsv2, x-op; op" ) self.assertEqual(server.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): + """Handshake succeeds without subprotocols.""" server = ServerProtocol() - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) - def test_no_subprotocol(self): + def test_no_subprotocol_requested(self): + """Server expects a subprotocol, but the client doesn't offer it.""" server = ServerProtocol(subprotocols=["chat"]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(NegotiationError) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + NegotiationError, "missing subprotocol", ) def test_subprotocol(self): + """Server enables a subprotocol when the client offers it.""" server = ServerProtocol(subprotocols=["chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") - def test_unexpected_subprotocol(self): + def test_no_subprotocols_supported(self): + """Client offers a subprotocol, but the server doesn't support any.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_multiple_subprotocols(self): + """Server enables all of the subprotocols when the client offers them.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" request.headers["Sec-WebSocket-Protocol"] = "superchat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "superchat") self.assertEqual(server.subprotocol, "superchat") def test_supported_subprotocol(self): + """Server enables one of the subprotocols when the client offers it.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_unsupported_subprotocol(self): + """Server expects one of the subprotocols, but the client doesn't offer any.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(NegotiationError) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + NegotiationError, "invalid subprotocol; expected one of superchat, chat", ) @@ -708,34 +845,40 @@ def optional_chat(protocol, subprotocols): return "chat" def test_select_subprotocol(self): + """Server enables a subprotocol with select_subprotocol.""" server = ServerProtocol(select_subprotocol=self.optional_chat) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_select_no_subprotocol(self): + """Server doesn't enable any subprotocol with select_subprotocol.""" server = ServerProtocol(select_subprotocol=self.optional_chat) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) class MiscTests(unittest.TestCase): def test_bypass_handshake(self): + """ServerProtocol bypasses the opening handshake.""" server = ServerProtocol(state=OPEN) server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") [frame] = server.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): + """ServerProtocol accepts a logger argument.""" logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: ServerProtocol(logger=logger) @@ -744,6 +887,7 @@ def test_custom_logger(self): class BackwardsCompatibilityTests(DeprecationTestCase): def test_server_connection_class(self): + """ServerConnection is a deprecated alias for ServerProtocol.""" with self.assertDeprecationWarning( "ServerConnection was renamed to ServerProtocol" ): From 36409237c9377fb8fc2fbb90f9f74192777b518c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:04:08 +0200 Subject: [PATCH 1404/1539] Wait until state is CLOSED to acces close_exc. Fix #1449. --- docs/project/changelog.rst | 4 ++++ src/websockets/asyncio/connection.py | 14 +++++++++++--- src/websockets/sync/connection.py | 14 ++++++++++++-- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 456c15dac..615d3ab71 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -75,6 +75,10 @@ Bug fixes start the connection handler anymore when ``process_request`` or ``process_response`` returns an HTTP response. +* Fixed a bug in the :mod:`threading` implementation that could lead to + incorrect error reporting when closing a connection while + :meth:`~sync.connection.Connection.recv` is running. + 13.0.1 ------ diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 1b24f9af0..6af61a4a9 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -274,6 +274,8 @@ async def recv(self, decode: bool | None = None) -> Data: try: return await self.recv_messages.get(decode) except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -329,6 +331,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data async for frame in self.recv_messages.get_iter(decode): yield frame except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -864,6 +868,7 @@ async def send_context( # raise an exception. if raise_close_exc: self.close_transport() + # Wait for the protocol state to be CLOSED before accessing close_exc. await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from original_exc @@ -926,11 +931,14 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = transport def connection_lost(self, exc: Exception | None) -> None: - self.protocol.receive_eof() # receive_eof is idempotent + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED - # Abort recv() and pending pings with a ConnectionClosed exception. - # Set recv_exc first to get proper exception reporting. self.set_recv_exc(exc) + + # Abort recv() and pending pings with a ConnectionClosed exception. self.recv_messages.close() self.abort_pings() diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 65a7b63ed..77b488093 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -206,6 +206,8 @@ def recv(self, timeout: float | None = None) -> Data: try: return self.recv_messages.get(timeout) except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -240,6 +242,8 @@ def recv_streaming(self) -> Iterator[Data]: for frame in self.recv_messages.get_iter(): yield frame except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -629,8 +633,6 @@ def recv_events(self) -> None: self.logger.error("unexpected internal error", exc_info=True) with self.protocol_mutex: self.set_recv_exc(exc) - # We don't know where we crashed. Force protocol state to CLOSED. - self.protocol.state = CLOSED finally: # This isn't expected to raise an exception. self.close_socket() @@ -738,6 +740,7 @@ def send_context( # raise an exception. if raise_close_exc: self.close_socket() + # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() raise self.protocol.close_exc from original_exc @@ -788,4 +791,11 @@ def close_socket(self) -> None: except OSError: pass # socket is already closed self.socket.close() + + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + # Abort recv() with a ConnectionClosed exception. self.recv_messages.close() From 0afccc956639684af5082513da8ba8f5105448f5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:12:07 +0200 Subject: [PATCH 1405/1539] Clarify comment. --- src/websockets/sync/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 77b488093..97588870e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -87,9 +87,9 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: dict[bytes, threading.Event] = {} - # Receiving events from the socket. This thread explicitly is marked as - # to support creating a connection in a non-daemon thread then using it - # in a daemon thread; this shouldn't block the intpreter from exiting. + # Receiving events from the socket. This thread is marked as daemon to + # allow creating a connection in a non-daemon thread and using it in a + # daemon thread. This mustn't prevent the interpreter from exiting. self.recv_events_thread = threading.Thread( target=self.recv_events, daemon=True, From 4d229bf9f583d593aa103287aee0a77c9fbc3a79 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:13:12 +0200 Subject: [PATCH 1406/1539] Release version 13.1. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 615d3ab71..7e4bce9c6 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 13.1 ---- -*In development* +*September 21, 2024* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index bbda56d6b..00b0a985e 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "13.1" From 37c7f6529a0877c87f191bc566d2e14f7c96e192 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:16:44 +0200 Subject: [PATCH 1407/1539] Start version 14.0. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7e4bce9c6..65e26008f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.0: + +14.0 +---- + +*In development* + .. _13.1: 13.1 diff --git a/src/websockets/version.py b/src/websockets/version.py index 00b0a985e..34fc2eaef 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "13.1" +tag = version = commit = "14.0" if not released: # pragma: no cover From 44ccee17c519ea1397ee28a8ac3a7d7685cd0b89 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:30:56 +0200 Subject: [PATCH 1408/1539] Drop Python 3.8. It is EOL at the end of October. --- .github/workflows/tests.yml | 1 - docs/howto/django.rst | 3 +-- docs/intro/index.rst | 2 +- docs/project/changelog.rst | 14 +++++++++----- pyproject.toml | 3 +-- src/websockets/asyncio/client.py | 5 +++-- src/websockets/asyncio/connection.py | 11 ++--------- src/websockets/asyncio/messages.py | 10 ++-------- src/websockets/asyncio/server.py | 16 ++++------------ src/websockets/client.py | 7 ++++--- src/websockets/datastructures.py | 15 +++------------ src/websockets/extensions/base.py | 2 +- src/websockets/extensions/permessage_deflate.py | 3 ++- src/websockets/frames.py | 3 ++- src/websockets/headers.py | 3 ++- src/websockets/http11.py | 3 ++- src/websockets/imports.py | 3 ++- src/websockets/legacy/auth.py | 6 +++--- src/websockets/legacy/client.py | 10 ++-------- src/websockets/legacy/framing.py | 3 ++- src/websockets/legacy/protocol.py | 13 ++----------- src/websockets/legacy/server.py | 16 +++------------- src/websockets/protocol.py | 7 ++++--- src/websockets/server.py | 5 +++-- src/websockets/streams.py | 2 +- src/websockets/sync/client.py | 3 ++- src/websockets/sync/connection.py | 6 +++--- src/websockets/sync/messages.py | 6 +++--- src/websockets/sync/server.py | 7 ++++--- src/websockets/typing.py | 8 +++----- tests/asyncio/test_client.py | 8 -------- tox.ini | 1 - 32 files changed, 76 insertions(+), 129 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 43193ea50..beaf9d12b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -55,7 +55,6 @@ jobs: strategy: matrix: python: - - "3.8" - "3.9" - "3.10" - "3.11" diff --git a/docs/howto/django.rst b/docs/howto/django.rst index dada9c5e4..4fe2311cb 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -121,8 +121,7 @@ authentication fails, it closes the connection and exits. When we call an API that makes a database query such as ``get_user()``, we wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support asynchronous I/O. It would block the event loop if it didn't run in a -separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In -earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. +separate thread. Finally, we start a server with :func:`~websockets.asyncio.server.serve`. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 095262a20..642e50094 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -websockets requires Python ≥ 3.8. +websockets requires Python ≥ 3.9. .. admonition:: Use the most recent Python release :class: tip diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 65e26008f..5f07fc09f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,15 @@ notice. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: websockets 14.0 requires Python ≥ 3.9. + :class: tip + + websockets 13.1 is the last version supporting Python 3.8. + + .. _13.1: 13.1 @@ -106,11 +115,6 @@ Bug fixes Backwards-incompatible changes .............................. -.. admonition:: websockets 13.0 requires Python ≥ 3.8. - :class: tip - - websockets 12.0 is the last version supporting Python 3.7. - .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note diff --git a/pyproject.toml b/pyproject.toml index fde9c3226..6a0ab8d7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "websockets" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { text = "BSD-3-Clause" } authors = [ { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, @@ -19,7 +19,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b1beb3e00..23b1a348a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -4,8 +4,9 @@ import logging import os import urllib.parse +from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, AsyncIterator, Callable, Generator, Sequence +from typing import Any, Callable from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike @@ -492,7 +493,7 @@ async def __aexit__( # async for ... in connect(...): async def __aiter__(self) -> AsyncIterator[ClientConnection]: - delays: Generator[float, None, None] | None = None + delays: Generator[float] | None = None while True: try: async with self as protocol: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 6af61a4a9..702e69995 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -8,16 +8,9 @@ import struct import sys import uuid +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Iterable, - Mapping, - cast, -) +from typing import Any, cast from ..exceptions import ( ConcurrencyError, diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index c2b4afd67..e3ec5062f 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -3,14 +3,8 @@ import asyncio import codecs import collections -from typing import ( - Any, - AsyncIterator, - Callable, - Generic, - Iterable, - TypeVar, -) +from collections.abc import AsyncIterator, Iterable +from typing import Any, Callable, Generic, TypeVar from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 19dae44b7..e11dd91f1 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -6,17 +6,9 @@ import logging import socket import sys +from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, - Tuple, - cast, -) +from typing import Any, Callable, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -905,9 +897,9 @@ def basic_auth( if credentials is not None: if is_credentials(credentials): - credentials_list = [cast(Tuple[str, str], credentials)] + credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/client.py b/src/websockets/client.py index e5f294986..bce82d66b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -3,7 +3,8 @@ import os import random import warnings -from typing import Any, Generator, Sequence +from collections.abc import Generator, Sequence +from typing import Any from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -313,7 +314,7 @@ def send_request(self, request: Request) -> None: self.writes.append(request.serialize()) - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: if self.state is CONNECTING: try: response = yield from Response.parse( @@ -374,7 +375,7 @@ def backoff( min_delay: float = BACKOFF_MIN_DELAY, max_delay: float = BACKOFF_MAX_DELAY, factor: float = BACKOFF_FACTOR, -) -> Generator[float, None, None]: +) -> Generator[float]: """ Generate a series of backoff delays between reconnection attempts. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 106d6f393..77b6f86fa 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,15 +1,7 @@ from __future__ import annotations -from typing import ( - Any, - Iterable, - Iterator, - Mapping, - MutableMapping, - Protocol, - Tuple, - Union, -) +from collections.abc import Iterable, Iterator, Mapping, MutableMapping +from typing import Any, Protocol, Union __all__ = ["Headers", "HeadersLike", "MultipleValuesError"] @@ -179,8 +171,7 @@ def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ Headers, Mapping[str, str], - # Change to tuple[str, str] when dropping Python < 3.9. - Iterable[Tuple[str, str]], + Iterable[tuple[str, str]], SupportsKeysAndGetItem, ] """ diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 75bae6b77..42dd6c5fa 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from ..frames import Frame from ..typing import ExtensionName, ExtensionParameter diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 25d2c1c45..f962b65fb 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,8 @@ import dataclasses import zlib -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from .. import frames from ..exceptions import ( diff --git a/src/websockets/frames.py b/src/websockets/frames.py index a63bdc3b6..dace2c902 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -6,7 +6,8 @@ import os import secrets import struct -from typing import Callable, Generator, Sequence +from collections.abc import Generator, Sequence +from typing import Callable from .exceptions import PayloadTooBig, ProtocolError diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 9103018a0..e05948a1f 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,8 @@ import binascii import ipaddress import re -from typing import Callable, Sequence, TypeVar, cast +from collections.abc import Sequence +from typing import Callable, TypeVar, cast from .exceptions import InvalidHeaderFormat, InvalidHeaderValue from .typing import ( diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 47cef7a9b..af542c77b 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -5,7 +5,8 @@ import re import sys import warnings -from typing import Callable, Generator +from collections.abc import Generator +from typing import Callable from .datastructures import Headers from .exceptions import SecurityError diff --git a/src/websockets/imports.py b/src/websockets/imports.py index bb80e4eac..c63fb212e 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any __all__ = ["lazy_import"] diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 4d030e5e2..a262fcd79 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,8 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Tuple, cast +from collections.abc import Awaitable, Iterable +from typing import Any, Callable, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -13,8 +14,7 @@ __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] -# Change to tuple[str, str] when dropping Python < 3.9. -Credentials = Tuple[str, str] +Credentials = tuple[str, str] def is_credentials(value: Any) -> bool: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index ec4c2ff64..116445e25 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -7,15 +7,9 @@ import random import urllib.parse import warnings +from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import ( - Any, - AsyncIterator, - Callable, - Generator, - Sequence, - cast, -) +from typing import Any, Callable, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4c2f8c23f..4ec194ed7 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,8 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Sequence +from collections.abc import Awaitable, Sequence +from typing import Any, Callable, NamedTuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 998e390d4..cedde6200 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -11,17 +11,8 @@ import time import uuid import warnings -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Callable, - Deque, - Iterable, - Mapping, - cast, -) +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping +from typing import Any, Callable, Deque, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 2cb9b1abb..9326b6100 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -8,18 +8,9 @@ import logging import socket import warnings +from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, - Tuple, - Union, - cast, -) +from typing import Any, Callable, Union, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError @@ -59,8 +50,7 @@ # Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] -# Change to tuple[...] when dropping Python < 3.9. -HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] +HTTPResponse = tuple[StatusLike, HeadersLike, bytes] class WebSocketServerProtocol(WebSocketCommonProtocol): diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8751ebdb4..091b4a23a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,8 @@ import enum import logging import uuid -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from .exceptions import ( ConnectionClosed, @@ -529,7 +530,7 @@ def close_expected(self) -> bool: # Private methods for receiving data. - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: """ Parse incoming data into frames. @@ -600,7 +601,7 @@ def parse(self) -> Generator[None, None, None]: yield raise AssertionError("parse() shouldn't step after error") - def discard(self) -> Generator[None, None, None]: + def discard(self) -> Generator[None]: """ Discard incoming data. diff --git a/src/websockets/server.py b/src/websockets/server.py index 006d5bdd5..9fe970619 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,8 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, Sequence, cast +from collections.abc import Generator, Sequence +from typing import Any, Callable, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -555,7 +556,7 @@ def send_response(self, response: Response) -> None: self.parser = self.discard() next(self.parser) # start coroutine - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: if self.state is CONNECTING: try: request = yield from Request.parse( diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 956f139d4..f52e6193a 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generator +from collections.abc import Generator class StreamReader: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index d1e20a757..5e1ba6d84 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,8 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from ..client import ClientProtocol from ..datastructures import HeadersLike diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 97588870e..8c5df9592 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -7,8 +7,9 @@ import struct import threading import uuid +from collections.abc import Iterable, Iterator, Mapping from types import TracebackType -from typing import Any, Iterable, Iterator, Mapping +from typing import Any from ..exceptions import ( ConcurrencyError, @@ -239,8 +240,7 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - for frame in self.recv_messages.get_iter(): - yield frame + yield from self.recv_messages.get_iter() except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 8d090538f..b96cd6880 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,8 @@ import codecs import queue import threading -from typing import Iterator, cast +from collections.abc import Iterator +from typing import cast from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -150,8 +151,7 @@ def get_iter(self) -> Iterator[Data]: chunks = self.chunks self.chunks = [] self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.9. - "queue.SimpleQueue[Data | None]", + queue.SimpleQueue[Data | None], queue.SimpleQueue(), ) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 1b7cbb4b4..464c4a173 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,8 +10,9 @@ import sys import threading import warnings +from collections.abc import Iterable, Sequence from types import TracebackType -from typing import Any, Callable, Iterable, Sequence, Tuple, cast +from typing import Any, Callable, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -663,9 +664,9 @@ def basic_auth( if credentials is not None: if is_credentials(credentials): - credentials_list = [cast(Tuple[str, str], credentials)] + credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 447fe79da..0a37141c6 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -3,7 +3,7 @@ import http import logging import typing -from typing import Any, List, NewType, Optional, Tuple, Union +from typing import Any, NewType, Optional, Union __all__ = [ @@ -56,16 +56,14 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" -# Change to tuple[str, Optional[str]] when dropping Python < 3.9. # Change to tuple[str, str | None] when dropping Python < 3.10. -ExtensionParameter = Tuple[str, Optional[str]] +ExtensionParameter = tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" # Private types -# Change to tuple[.., list[...]] when dropping Python < 3.9. -ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] +ExtensionHeader = tuple[ExtensionName, list[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 999ef1b71..9354a6e0a 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -177,10 +177,6 @@ def process_request(connection, request): self.assertEqual(iterations, 5) self.assertEqual(successful, 2) - @unittest.skipUnless( - hasattr(http.HTTPStatus, "IM_A_TEAPOT"), - "test requires Python 3.9", - ) async def test_reconnect_with_custom_process_exception(self): """Client runs process_exception to tell if errors are retryable or fatal.""" iteration = 0 @@ -214,10 +210,6 @@ def process_exception(exc): "🫖 💔 ☕️", ) - @unittest.skipUnless( - hasattr(http.HTTPStatus, "IM_A_TEAPOT"), - "test requires Python 3.9", - ) async def test_reconnect_with_custom_process_exception_raising_exception(self): """Client supports raising an exception in process_exception.""" diff --git a/tox.ini b/tox.ini index cba9b290b..0bcec5ded 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,5 @@ [tox] env_list = - py38 py39 py310 py311 From e44a1eadf3c287335fd10c381ba5ccb8bd1ff4ce Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:00:24 +0200 Subject: [PATCH 1409/1539] Document why packagers mustn't run the test suite. Refs #1509, #1496, #1427, #1426, #1081, #1026, perhaps others. --- docs/project/contributing.rst | 26 ++++++++++++++++++++++++-- tests/utils.py | 18 +++++++++++++++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 020ed7ad8..3988c028a 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -17,8 +17,8 @@ apologies. I know I can mess up. I can't expect you to tell me, but if you choose to do so, I'll do my best to handle criticism constructively. -- Aymeric)* -Contributions -------------- +Contributing +------------ Bug reports, patches and suggestions are welcome! @@ -34,6 +34,28 @@ websockets. .. _issue: https://github.com/python-websockets/websockets/issues/new .. _pull request: https://github.com/python-websockets/websockets/compare/ +Packaging +--------- + +Some distributions package websockets so that it can be installed with the +system package manager rather than with pip, possibly in a virtualenv. + +If you're packaging websockets for a distribution, you must use `releases +published on PyPI`_ as input. You may check `SLSA attestations on GitHub`_. + +.. _releases published on PyPI: https://pypi.org/project/websockets/#files +.. _SLSA attestations on GitHub: https://github.com/python-websockets/websockets/attestations + +You mustn't rely on the git repository as input. Specifically, you mustn't +attempt to run the main test suite. It isn't treated as a deliverable of the +project. It doesn't do what you think it does. It's designed for the needs of +developers, not packagers. + +On a typical build farm for a distribution, tests that exercise timeouts will +fail randomly. Indeed, the test suite is optimized for running very fast, with a +tolerable level of flakiness, on a high-end laptop without noisy neighbors. This +isn't your context. + Questions --------- diff --git a/tests/utils.py b/tests/utils.py index 960439135..639fb7fe5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,8 @@ import unittest import warnings +from websockets.version import released + # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ @@ -39,9 +41,19 @@ DATE = email.utils.formatdate(usegmt=True) -# Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) +# Unit for timeouts. May be increased in slow or noisy environments by setting +# the WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. + +# Downstream distributors insist on running the test suite despites my pleas to +# the contrary. They do it on build farms with unstable performance, leading to +# flakiness, and then they file bugs. Make tests 100x slower to avoid flakiness. + +MS = 0.001 * float( + os.environ.get( + "WEBSOCKETS_TESTS_TIMEOUT_FACTOR", + "100" if released else "1", + ) +) # PyPy, asyncio's debug mode, and coverage penalize performance of this # test suite. Increase timeouts to reduce the risk of spurious failures. From 4f4e64442e84763ca634f67ba0062c3fb8c985c6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:46:55 +0200 Subject: [PATCH 1410/1539] Restore compatibility with Python 3.9. It was broken in 44ccee17. --- src/websockets/sync/messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index b96cd6880..997fa98df 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -151,7 +151,8 @@ def get_iter(self) -> Iterator[Data]: chunks = self.chunks self.chunks = [] self.chunks_queue = cast( - queue.SimpleQueue[Data | None], + # Remove quotes around type when dropping Python < 3.10. + "queue.SimpleQueue[Data | None]", queue.SimpleQueue(), ) From ddafc6682a3f6d0bed9d88dfbecd15dd246973bf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 20:25:10 +0200 Subject: [PATCH 1411/1539] Split "getting support" from "contributing". Also add a page about financial contributions. --- docs/project/contributing.rst | 31 ---------------------- docs/project/index.rst | 4 ++- docs/project/sponsoring.rst | 11 ++++++++ docs/project/support.rst | 49 +++++++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 32 deletions(-) create mode 100644 docs/project/sponsoring.rst create mode 100644 docs/project/support.rst diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 3988c028a..6ecd175f8 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -55,34 +55,3 @@ On a typical build farm for a distribution, tests that exercise timeouts will fail randomly. Indeed, the test suite is optimized for running very fast, with a tolerable level of flakiness, on a high-end laptop without noisy neighbors. This isn't your context. - -Questions ---------- - -GitHub issues aren't a good medium for handling questions. There are better -places to ask questions, for example Stack Overflow. - -If you want to ask a question anyway, please make sure that: - -- it's a question about websockets and not about :mod:`asyncio`; -- it isn't answered in the documentation; -- it wasn't asked already. - -A good question can be written as a suggestion to improve the documentation. - -Cryptocurrency users --------------------- - -websockets appears to be quite popular for interfacing with Bitcoin or other -cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. - -I'm aware of efforts to build proof-of-stake models. I'll care once the total -energy consumption of all cryptocurrencies drops to a non-bullshit level. - -You already negated all of humanity's efforts to develop renewable energy. -Please stop heating the planet where my children will have to live. - -Since websockets is released under an open-source license, you can use it for -any purpose you like. However, I won't spend any of my time to help you. - -I will summarily close issues related to Bitcoin or cryptocurrency in any way. diff --git a/docs/project/index.rst b/docs/project/index.rst index 459146345..56c98196a 100644 --- a/docs/project/index.rst +++ b/docs/project/index.rst @@ -8,5 +8,7 @@ This is about websockets-the-project rather than websockets-the-software. changelog contributing - license + sponsoring For enterprise + support + license diff --git a/docs/project/sponsoring.rst b/docs/project/sponsoring.rst new file mode 100644 index 000000000..77a4fd1d8 --- /dev/null +++ b/docs/project/sponsoring.rst @@ -0,0 +1,11 @@ +Sponsoring +========== + +You may sponsor the development of websockets through: + +* `GitHub Sponsors`_ +* `Open Collective`_ +* :doc:`Tidelift ` + +.. _GitHub Sponsors: https://github.com/sponsors/python-websockets +.. _Open Collective: https://opencollective.com/websockets diff --git a/docs/project/support.rst b/docs/project/support.rst new file mode 100644 index 000000000..21aad6e02 --- /dev/null +++ b/docs/project/support.rst @@ -0,0 +1,49 @@ +Getting support +=============== + +.. admonition:: There are no free support channels. + :class: tip + + websockets is an open-source project. It's primarily maintained by one + person as a hobby. + + For this reason, the focus is on flawless code and self-service + documentation, not support. + +Enterprise +---------- + +websockets is maintained with high standards, making it suitable for enterprise +use cases. Additional guarantees are available via :doc:`Tidelift `. +If you're using it in a professional setting, consider subscribing. + +Questions +--------- + +GitHub issues aren't a good medium for handling questions. There are better +places to ask questions, for example Stack Overflow. + +If you want to ask a question anyway, please make sure that: + +- it's a question about websockets and not about :mod:`asyncio`; +- it isn't answered in the documentation; +- it wasn't asked already. + +A good question can be written as a suggestion to improve the documentation. + +Cryptocurrency users +-------------------- + +websockets appears to be quite popular for interfacing with Bitcoin or other +cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. + +I'm aware of efforts to build proof-of-stake models. I'll care once the total +energy consumption of all cryptocurrencies drops to a non-bullshit level. + +You already negated all of humanity's efforts to develop renewable energy. +Please stop heating the planet where my children will have to live. + +Since websockets is released under an open-source license, you can use it for +any purpose you like. However, I won't spend any of my time to help you. + +I will summarily close issues related to cryptocurrency in any way. From a942fcc13f47d9a5bdcd93b6f84da21e1d185e63 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 22:05:54 +0200 Subject: [PATCH 1412/1539] Switch convenience aliases to the new implementation. --- docs/howto/upgrade.rst | 29 ++++++++++++++--------------- docs/project/changelog.rst | 11 +++++++++++ docs/reference/index.rst | 2 +- src/websockets/__init__.py | 38 ++++++++++++++++++++------------------ 4 files changed, 46 insertions(+), 34 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index f3e42591e..5b1b8e4a2 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -3,8 +3,8 @@ Upgrade to the new :mod:`asyncio` implementation .. currentmodule:: websockets -The new :mod:`asyncio` implementation is a rewrite of the original -implementation of websockets. +The new :mod:`asyncio` implementation, which is now the default, is a rewrite of +the original implementation of websockets. It provides a very similar API. However, there are a few differences. @@ -70,9 +70,8 @@ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package. -* The ``websockets`` package provides aliases for convenience. Currently, they - point to the original implementation. They will be updated to point to the new - implementation when it feels mature. +* The ``websockets`` package provides aliases for convenience. They were + switched to the new implementation in version 14.0. * The ``websockets.client`` and ``websockets.server`` packages provide aliases for backwards-compatibility with earlier versions of websockets. They will be deprecated together with the original implementation. @@ -90,12 +89,12 @@ Client APIs +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ -| ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | -| ``websockets.client.connect()`` |br| | | +| ``websockets.connect()`` *(before 14.0)* |br| | ``websockets.connect()`` *(since 14.0)* |br| | +| ``websockets.client.connect()`` |br| | :func:`websockets.asyncio.client.connect` | | :func:`websockets.legacy.client.connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | -| ``websockets.client.unix_connect()`` |br| | | +| ``websockets.unix_connect()`` *(before 14.0)* |br| | ``websockets.unix_connect()`` *(since 14.0)* |br| | +| ``websockets.client.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | | :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | @@ -109,12 +108,12 @@ Server APIs +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ -| ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | -| ``websockets.server.serve()`` |br| | | +| ``websockets.serve()`` *(before 14.0)* |br| | ``websockets.serve()`` *(since 14.0)* |br| | +| ``websockets.server.serve()`` |br| | :func:`websockets.asyncio.server.serve` | | :func:`websockets.legacy.server.serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | -| ``websockets.server.unix_serve()`` |br| | | +| ``websockets.unix_serve()`` *(before 14.0)* |br| | ``websockets.unix_serve()`` *(since 14.0)* |br| | +| ``websockets.server.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | @@ -125,8 +124,8 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast()`` |br| | :func:`websockets.asyncio.server.broadcast` | -| :func:`websockets.legacy.server.broadcast()` | | +| ``websockets.broadcast()`` *(before 14.0)* |br| | ``websockets.broadcast()`` *(since 14.0)* |br| | +| :func:`websockets.legacy.server.broadcast()` | :func:`websockets.asyncio.server.broadcast` | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | | ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | :func:`websockets.asyncio.server.basic_auth`. | diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5f07fc09f..45e4ef0cc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -40,6 +40,17 @@ Backwards-incompatible changes websockets 13.1 is the last version supporting Python 3.8. +.. admonition:: The new :mod:`asyncio` implementation is now the default. + :class: caution + + The following aliases in the ``websockets`` package were switched to the new + :mod:`asyncio` implementation:: + + from websockets import connect, unix_connext + from websockets import broadcast, serve, unix_serve + + If you're using any of them, then you must follow the :doc:`upgrade guide + <../howto/upgrade>` immediately. .. _13.1: diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 77b538b78..ed2341cc6 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -98,5 +98,5 @@ guarantees of behavior or backwards-compatibility for private APIs. Convenience imports ------------------- -For convenience, many public APIs can be imported directly from the +For convenience, some public APIs can be imported directly from the ``websockets`` package. diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 54591e9fd..036e71c23 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -7,6 +7,14 @@ __all__ = [ + # .asyncio.client + "connect", + "unix_connect", + # .asyncio.server + "basic_auth", + "broadcast", + "serve", + "unix_serve", # .client "ClientProtocol", # .datastructures @@ -41,8 +49,6 @@ "basic_auth_protocol_factory", # .legacy.client "WebSocketClientProtocol", - "connect", - "unix_connect", # .legacy.exceptions "AbortHandshake", "InvalidMessage", @@ -53,9 +59,6 @@ # .legacy.server "WebSocketServer", "WebSocketServerProtocol", - "broadcast", - "serve", - "unix_serve", # .server "ServerProtocol", # .typing @@ -70,6 +73,8 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: + from .asyncio.client import connect, unix_connect + from .asyncio.server import basic_auth, broadcast, serve, unix_serve from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( @@ -100,7 +105,7 @@ BasicAuthWebSocketServerProtocol, basic_auth_protocol_factory, ) - from .legacy.client import WebSocketClientProtocol, connect, unix_connect + from .legacy.client import WebSocketClientProtocol from .legacy.exceptions import ( AbortHandshake, InvalidMessage, @@ -108,13 +113,7 @@ RedirectHandshake, ) from .legacy.protocol import WebSocketCommonProtocol - from .legacy.server import ( - WebSocketServer, - WebSocketServerProtocol, - broadcast, - serve, - unix_serve, - ) + from .legacy.server import WebSocketServer, WebSocketServerProtocol from .server import ServerProtocol from .typing import ( Data, @@ -129,6 +128,14 @@ lazy_import( globals(), aliases={ + # .asyncio.client + "connect": ".asyncio.client", + "unix_connect": ".asyncio.client", + # .asyncio.server + "basic_auth": ".asyncio.server", + "broadcast": ".asyncio.server", + "serve": ".asyncio.server", + "unix_serve": ".asyncio.server", # .client "ClientProtocol": ".client", # .datastructures @@ -163,8 +170,6 @@ "basic_auth_protocol_factory": ".legacy.auth", # .legacy.client "WebSocketClientProtocol": ".legacy.client", - "connect": ".legacy.client", - "unix_connect": ".legacy.client", # .legacy.exceptions "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", @@ -175,9 +180,6 @@ # .legacy.server "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", - "broadcast": ".legacy.server", - "serve": ".legacy.server", - "unix_serve": ".legacy.server", # .server "ServerProtocol": ".server", # .typing From 8d055ebf383520f46c1f0b6d40742e3ef8c3d723 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 22:40:09 +0200 Subject: [PATCH 1413/1539] Deprecate aliases pointing to the legacy implementation. --- docs/howto/upgrade.rst | 3 +- docs/project/changelog.rst | 6 +++ src/websockets/__init__.py | 63 ++++++++---------------------- src/websockets/auth.py | 10 ++++- src/websockets/client.py | 20 ++++++---- src/websockets/exceptions.py | 41 ++++++------------- src/websockets/server.py | 22 +++++++---- tests/legacy/test_auth.py | 2 +- tests/legacy/test_client_server.py | 2 +- tests/test_exports.py | 30 ++++++++++---- 10 files changed, 99 insertions(+), 100 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 5b1b8e4a2..bdaefd768 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -71,7 +71,8 @@ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package. * The ``websockets`` package provides aliases for convenience. They were - switched to the new implementation in version 14.0. + switched to the new implementation in version 14.0 or deprecated when there + isn't an equivalent API. * The ``websockets.client`` and ``websockets.server`` packages provide aliases for backwards-compatibility with earlier versions of websockets. They will be deprecated together with the original implementation. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 45e4ef0cc..30e89a69c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,6 +52,12 @@ Backwards-incompatible changes If you're using any of them, then you must follow the :doc:`upgrade guide <../howto/upgrade>` immediately. +.. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. + :class: caution + + Aliases for deprecated API were removed from ``__all__``. As a consequence, + they cannot be imported e.g. with ``from websockets import *`` anymore. + .. _13.1: 13.1 diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 036e71c23..531ce49f7 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -43,22 +43,6 @@ "ProtocolError", "SecurityError", "WebSocketException", - "WebSocketProtocolError", - # .legacy.auth - "BasicAuthWebSocketServerProtocol", - "basic_auth_protocol_factory", - # .legacy.client - "WebSocketClientProtocol", - # .legacy.exceptions - "AbortHandshake", - "InvalidMessage", - "InvalidStatusCode", - "RedirectHandshake", - # .legacy.protocol - "WebSocketCommonProtocol", - # .legacy.server - "WebSocketServer", - "WebSocketServerProtocol", # .server "ServerProtocol", # .typing @@ -99,21 +83,7 @@ ProtocolError, SecurityError, WebSocketException, - WebSocketProtocolError, ) - from .legacy.auth import ( - BasicAuthWebSocketServerProtocol, - basic_auth_protocol_factory, - ) - from .legacy.client import WebSocketClientProtocol - from .legacy.exceptions import ( - AbortHandshake, - InvalidMessage, - InvalidStatusCode, - RedirectHandshake, - ) - from .legacy.protocol import WebSocketCommonProtocol - from .legacy.server import WebSocketServer, WebSocketServerProtocol from .server import ServerProtocol from .typing import ( Data, @@ -164,22 +134,6 @@ "ProtocolError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", - "WebSocketProtocolError": ".exceptions", - # .legacy.auth - "BasicAuthWebSocketServerProtocol": ".legacy.auth", - "basic_auth_protocol_factory": ".legacy.auth", - # .legacy.client - "WebSocketClientProtocol": ".legacy.client", - # .legacy.exceptions - "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", - "InvalidStatusCode": ".legacy.exceptions", - "RedirectHandshake": ".legacy.exceptions", - # .legacy.protocol - "WebSocketCommonProtocol": ".legacy.protocol", - # .legacy.server - "WebSocketServer": ".legacy.server", - "WebSocketServerProtocol": ".legacy.server", # .server "ServerProtocol": ".server", # .typing @@ -197,5 +151,22 @@ "handshake": ".legacy", "parse_uri": ".uri", "WebSocketURI": ".uri", + # deprecated in 14.0 + # .legacy.auth + "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "basic_auth_protocol_factory": ".legacy.auth", + # .legacy.client + "WebSocketClientProtocol": ".legacy.client", + # .legacy.exceptions + "AbortHandshake": ".legacy.exceptions", + "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + # .legacy.protocol + "WebSocketCommonProtocol": ".legacy.protocol", + # .legacy.server + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", }, ) diff --git a/src/websockets/auth.py b/src/websockets/auth.py index b792e02f5..1e0002cee 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,6 +1,12 @@ from __future__ import annotations -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. +import warnings + from .legacy.auth import * from .legacy.auth import __all__ # noqa: F401 + + +warnings.warn( # deprecated in 14.0 + "websockets.auth is deprecated", + DeprecationWarning, +) diff --git a/src/websockets/client.py b/src/websockets/client.py index bce82d66b..8b66900a8 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -27,6 +27,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .imports import lazy_import from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State from .typing import ( ConnectionOption, @@ -40,13 +41,7 @@ from .utils import accept_key, generate_key -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. -from .legacy.client import * # isort:skip # noqa: I001 -from .legacy.client import __all__ as legacy__all__ - - -__all__ = ["ClientProtocol"] + legacy__all__ +__all__ = ["ClientProtocol"] class ClientProtocol(Protocol): @@ -392,3 +387,14 @@ def backoff( delay *= factor while True: yield max_delay + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "WebSocketClientProtocol": ".legacy.client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + }, +) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index d723f2fec..7681736a4 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,7 +31,6 @@ from __future__ import annotations -import typing import warnings from .imports import lazy_import @@ -45,9 +44,7 @@ "InvalidURI", "InvalidHandshake", "SecurityError", - "InvalidMessage", "InvalidStatus", - "InvalidStatusCode", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", @@ -57,10 +54,7 @@ "DuplicateParameter", "InvalidParameterName", "InvalidParameterValue", - "AbortHandshake", - "RedirectHandshake", "ProtocolError", - "WebSocketProtocolError", "PayloadTooBig", "InvalidState", "ConcurrencyError", @@ -366,27 +360,18 @@ class ConcurrencyError(WebSocketException, RuntimeError): """ -# When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: - from .legacy.exceptions import ( - AbortHandshake, - InvalidMessage, - InvalidStatusCode, - RedirectHandshake, - ) - - WebSocketProtocolError = ProtocolError -else: - lazy_import( - globals(), - aliases={ - "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", - "InvalidStatusCode": ".legacy.exceptions", - "RedirectHandshake": ".legacy.exceptions", - "WebSocketProtocolError": ".legacy.exceptions", - }, - ) - # At the bottom to break import cycles created by type annotations. from . import frames, http11 # noqa: E402 + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "AbortHandshake": ".legacy.exceptions", + "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + }, +) diff --git a/src/websockets/server.py b/src/websockets/server.py index 9fe970619..527db8990 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -27,6 +27,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .imports import lazy_import from .protocol import CONNECTING, OPEN, SERVER, Protocol, State from .typing import ( ConnectionOption, @@ -40,13 +41,7 @@ from .utils import accept_key -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. -from .legacy.server import * # isort:skip # noqa: I001 -from .legacy.server import __all__ as legacy__all__ - - -__all__ = ["ServerProtocol"] + legacy__all__ +__all__ = ["ServerProtocol"] class ServerProtocol(Protocol): @@ -586,3 +581,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: DeprecationWarning, ) super().__init__(*args, **kwargs) + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + }, +) diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index 3754bcf3a..dabd4212a 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -2,10 +2,10 @@ import unittest import urllib.error -from websockets.exceptions import InvalidStatusCode from websockets.headers import build_authorization_basic from websockets.legacy.auth import * from websockets.legacy.auth import is_credentials +from websockets.legacy.exceptions import InvalidStatusCode from .test_client_server import ClientServerTestsMixin, with_client, with_server from .utils import AsyncioTestCase diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2f3ba9b77..502ab68e7 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -21,7 +21,6 @@ ConnectionClosed, InvalidHandshake, InvalidHeader, - InvalidStatusCode, NegotiationError, ) from websockets.extensions.permessage_deflate import ( @@ -32,6 +31,7 @@ from websockets.frames import CloseCode from websockets.http11 import USER_AGENT from websockets.legacy.client import * +from websockets.legacy.exceptions import InvalidStatusCode from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * diff --git a/tests/test_exports.py b/tests/test_exports.py index 67a1a6f99..93b0684f7 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -1,30 +1,46 @@ import unittest import websockets -import websockets.auth +import websockets.asyncio.client +import websockets.asyncio.server import websockets.client import websockets.datastructures import websockets.exceptions -import websockets.legacy.protocol import websockets.server import websockets.typing import websockets.uri combined_exports = ( - websockets.auth.__all__ + [] + + websockets.asyncio.client.__all__ + + websockets.asyncio.server.__all__ + websockets.client.__all__ + websockets.datastructures.__all__ + websockets.exceptions.__all__ - + websockets.legacy.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ ) +# These API are intentionally not re-exported by the top-level module. +missing_reexports = [ + # websockets.asyncio.client + "ClientConnection", + # websockets.asyncio.server + "ServerConnection", + "Server", +] + class ExportsTests(unittest.TestCase): - def test_top_level_module_reexports_all_submodule_exports(self): - self.assertEqual(set(combined_exports), set(websockets.__all__)) + def test_top_level_module_reexports_submodule_exports(self): + self.assertEqual( + set(combined_exports), + set(websockets.__all__ + missing_reexports), + ) def test_submodule_exports_are_globally_unique(self): - self.assertEqual(len(set(combined_exports)), len(combined_exports)) + self.assertEqual( + len(set(combined_exports)), + len(combined_exports), + ) From d62e423744b3e20ef6405019f96a63faa21612a2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:10:15 +0200 Subject: [PATCH 1414/1539] Deprecate the legacy asyncio implementation. --- docs/howto/upgrade.rst | 21 ++++++--------- docs/index.rst | 43 ++++++++++++++++--------------- docs/project/changelog.rst | 3 +++ docs/reference/index.rst | 30 +++++++++++---------- docs/reference/legacy/client.rst | 6 +++++ docs/reference/legacy/common.rst | 6 +++++ docs/reference/legacy/server.rst | 6 +++++ docs/topics/design.rst | 2 ++ docs/topics/index.rst | 1 - src/websockets/auth.py | 12 ++++++--- src/websockets/http.py | 7 ++++- src/websockets/legacy/__init__.py | 11 ++++++++ tests/legacy/__init__.py | 9 +++++++ tests/maxi_cov.py | 2 -- tests/test_auth.py | 14 ++++++++++ 15 files changed, 118 insertions(+), 55 deletions(-) create mode 100644 tests/test_auth.py diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index bdaefd768..02d4c6f01 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -27,15 +27,9 @@ respectively. .. admonition:: What will happen to the original implementation? :class: hint - The original implementation is now considered legacy. - - The next steps are: - - 1. Deprecating it once the new implementation is considered sufficiently - robust. - 2. Maintaining it for five years per the :ref:`backwards-compatibility - policy `. - 3. Removing it. This is expected to happen around 2030. + The original implementation is deprecated. It will be maintained for five + years after deprecation according to the :ref:`backwards-compatibility + policy `. Then, by 2030, it will be removed. .. _deprecated APIs: @@ -69,13 +63,14 @@ Import paths For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. -* The original implementation was moved to the ``websockets.legacy`` package. +* The original implementation was moved to the ``websockets.legacy`` package + and deprecated. * The ``websockets`` package provides aliases for convenience. They were switched to the new implementation in version 14.0 or deprecated when there - isn't an equivalent API. + wasn't an equivalent API. * The ``websockets.client`` and ``websockets.server`` packages provide aliases - for backwards-compatibility with earlier versions of websockets. They will - be deprecated together with the original implementation. + for backwards-compatibility with earlier versions of websockets. They were + deprecated. To upgrade to the new :mod:`asyncio` implementation, change import paths as shown in the tables below. diff --git a/docs/index.rst b/docs/index.rst index b8cd300e3..f9576f2dc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,30 +28,13 @@ with a focus on correctness, simplicity, robustness, and performance. It supports several network I/O and control flow paradigms. -1. The primary implementation builds upon :mod:`asyncio`, Python's standard - asynchronous I/O framework. It provides an elegant coroutine-based API. It's - ideal for servers that handle many clients concurrently. - - .. admonition:: As of version :ref:`13.0`, there is a new :mod:`asyncio` - implementation. - :class: important - - The historical implementation in ``websockets.legacy`` traces its roots to - early versions of websockets. Although it's stable and robust, it is now - considered legacy. - - The new implementation in ``websockets.asyncio`` is a rewrite on top of - the Sans-I/O implementation. It adds a few features that were impossible - to implement within the original design. - - The new implementation provides all features of the historical - implementation, and a few more. If you're using the historical - implementation, you should :doc:`ugrade to the new implementation - `. It's usually straightforward. +1. The default implementation builds upon :mod:`asyncio`, Python's built-in + asynchronous I/O library. It provides an elegant coroutine-based API. It's + ideal for servers that handle many client connections. 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used - for servers that don't need to serve many clients. + for servers that handle few client connections. 3. The `Sans-I/O`_ implementation is designed for integrating in third-party libraries, typically application servers, in addition being used internally @@ -59,6 +42,24 @@ It supports several network I/O and control flow paradigms. .. _Sans-I/O: https://sans-io.readthedocs.io/ +Refer to the :doc:`feature support matrices ` for the full +list of features provided by each implementation. + +.. admonition:: The :mod:`asyncio` implementation was rewritten. + :class: tip + + The new implementation in ``websockets.asyncio`` builds upon the Sans-I/O + implementation. It adds features that were impossible to provide in the + original design. It was introduced in version 13.0. + + The historical implementation in ``websockets.legacy`` traces its roots to + early versions of websockets. While it's stable and robust, it was deprecated + in version 14.0 and it will be removed by 2030. + + The new implementation provides the same features as the historical + implementation, and then some. If you're using the historical implementation, + you should :doc:`ugrade to the new implementation `. + Here's an echo server using the :mod:`asyncio` API: .. literalinclude:: ../example/echo.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 30e89a69c..c8d854ba4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -55,6 +55,9 @@ Backwards-incompatible changes .. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. :class: caution + The :doc:`upgrade guide <../howto/upgrade>` provides complete instructions + to migrate your application. + Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index ed2341cc6..c78a3c095 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -13,12 +13,12 @@ Check which implementations support which features and known limitations. features -:mod:`asyncio` (new) --------------------- +:mod:`asyncio` +-------------- It's ideal for servers that handle many clients concurrently. -It's a rewrite of the legacy :mod:`asyncio` implementation. +This is the default implementation. .. toctree:: :titlesonly: @@ -26,17 +26,6 @@ It's a rewrite of the legacy :mod:`asyncio` implementation. asyncio/server asyncio/client -:mod:`asyncio` (legacy) ------------------------ - -This is the historical implementation. - -.. toctree:: - :titlesonly: - - legacy/server - legacy/client - :mod:`threading` ---------------- @@ -62,6 +51,19 @@ application servers. sansio/server sansio/client +:mod:`asyncio` (legacy) +----------------------- + +This is the historical implementation. + +It is deprecated and will be removed. + +.. toctree:: + :titlesonly: + + legacy/server + legacy/client + Extensions ---------- diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst index fca45d218..a798409f0 100644 --- a/docs/reference/legacy/client.rst +++ b/docs/reference/legacy/client.rst @@ -1,6 +1,12 @@ Client (legacy :mod:`asyncio`) ============================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.client Opening a connection diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst index aee774479..45c56fccd 100644 --- a/docs/reference/legacy/common.rst +++ b/docs/reference/legacy/common.rst @@ -3,6 +3,12 @@ Both sides (legacy :mod:`asyncio`) ================================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.protocol .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index b6c383ce7..3c1d19fc6 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -1,6 +1,12 @@ Server (legacy :mod:`asyncio`) ============================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.server Starting a server diff --git a/docs/topics/design.rst b/docs/topics/design.rst index d2fd18d0c..b73ace517 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,3 +1,5 @@ +:orphan: + Design (legacy :mod:`asyncio`) ============================== diff --git a/docs/topics/index.rst b/docs/topics/index.rst index a2b8ca879..616753c6c 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -12,7 +12,6 @@ Get a deeper understanding of how websockets is built and why. broadcast compression keepalive - design memory security performance diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 1e0002cee..98e62af3c 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -2,11 +2,17 @@ import warnings -from .legacy.auth import * -from .legacy.auth import __all__ # noqa: F401 + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.auth import * + from .legacy.auth import __all__ # noqa: F401 warnings.warn( # deprecated in 14.0 - "websockets.auth is deprecated", + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", DeprecationWarning, ) diff --git a/src/websockets/http.py b/src/websockets/http.py index 0ff5598c7..0d860e537 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -3,7 +3,12 @@ import warnings from .datastructures import Headers, MultipleValuesError # noqa: F401 -from .legacy.http import read_request, read_response # noqa: F401 + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.http import read_request, read_response # noqa: F401 warnings.warn( # deprecated in 9.0 - 2021-09-01 diff --git a/src/websockets/legacy/__init__.py b/src/websockets/legacy/__init__.py index e69de29bb..84f870f3a 100644 --- a/src/websockets/legacy/__init__.py +++ b/src/websockets/legacy/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import warnings + + +warnings.warn( # deprecated in 14.0 + "websockets.legacy is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + DeprecationWarning, +) diff --git a/tests/legacy/__init__.py b/tests/legacy/__init__.py index e69de29bb..035834a89 100644 --- a/tests/legacy/__init__.py +++ b/tests/legacy/__init__.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import warnings + + +with warnings.catch_warnings(): + # Suppress DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + import websockets.legacy # noqa: F401 diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 83686c3d3..8ccef7d39 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -9,7 +9,6 @@ UNMAPPED_SRC_FILES = [ - "websockets/auth.py", "websockets/typing.py", "websockets/version.py", ] @@ -105,7 +104,6 @@ def get_ignored_files(src_dir="src"): # or websockets (import locations). "*/websockets/asyncio/async_timeout.py", "*/websockets/asyncio/compatibility.py", - "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. "*/websockets/legacy/*", diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 000000000..16c00c1b9 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,14 @@ +from .utils import DeprecationTestCase + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_headers_class(self): + with self.assertDeprecationWarning( + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + ): + from websockets.auth import ( + BasicAuthWebSocketServerProtocol, # noqa: F401 + basic_auth_protocol_factory, # noqa: F401 + ) From a0b20f081d7ae48409c6b79ed16bbc261d5109f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 21:16:26 +0200 Subject: [PATCH 1415/1539] Document that only asyncio supports keepalive. Fix #1508. --- docs/topics/keepalive.rst | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 91f11fb11..458fa3d05 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -1,6 +1,11 @@ Keepalive and latency ===================== +.. admonition:: This guide applies only to the :mod:`asyncio` implementation. + :class: tip + + The :mod:`threading` implementation doesn't provide keepalive yet. + .. currentmodule:: websockets Long-lived connections @@ -31,14 +36,8 @@ based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 -It loops through these steps: - -1. Wait 20 seconds. -2. Send a Ping frame. -3. Receive a corresponding Pong frame within 20 seconds. - -If the Pong frame isn't received, websockets considers the connection broken and -closes it. +It sends a Ping frame every 20 seconds. It expects a Pong frame in return within +20 seconds. Else, it considers the connection broken and terminates it. This mechanism serves three purposes: From baadc33364131ff7236249c9077ae10b561395b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 23 Sep 2024 23:52:25 +0200 Subject: [PATCH 1416/1539] Profile and optimize the permessage-deflate extension. dataclasses.replace is surprisingly expensive. zlib functions make up the bulk of the cost now. --- experiments/compression/corpus.py | 2 +- experiments/profiling/compression.py | 45 +++++++++++++++++++ .../extensions/permessage_deflate.py | 29 +++++++++--- 3 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 experiments/profiling/compression.py diff --git a/experiments/compression/corpus.py b/experiments/compression/corpus.py index da5661dfa..56e262114 100644 --- a/experiments/compression/corpus.py +++ b/experiments/compression/corpus.py @@ -47,6 +47,6 @@ def main(corpus): if __name__ == "__main__": if len(sys.argv) < 2: - print(f"Usage: {sys.argv[0]} [directory]") + print(f"Usage: {sys.argv[0]} ") sys.exit(2) main(pathlib.Path(sys.argv[1])) diff --git a/experiments/profiling/compression.py b/experiments/profiling/compression.py new file mode 100644 index 000000000..1ece1f10e --- /dev/null +++ b/experiments/profiling/compression.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python + +""" +Profile the permessage-deflate extension. + +Usage:: + $ pip install line_profiler + $ python experiments/compression/corpus.py experiments/compression/corpus + $ PYTHONPATH=src python -m kernprof \ + --line-by-line \ + --prof-mod src/websockets/extensions/permessage_deflate.py \ + --view \ + experiments/profiling/compression.py experiments/compression/corpus 12 5 6 + +""" + +import pathlib +import sys + +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.frames import OP_TEXT, Frame + + +def compress_and_decompress(corpus, max_window_bits, memory_level, level): + extension = PerMessageDeflate( + remote_no_context_takeover=False, + local_no_context_takeover=False, + remote_max_window_bits=max_window_bits, + local_max_window_bits=max_window_bits, + compress_settings={"memLevel": memory_level, "level": level}, + ) + for data in corpus: + frame = Frame(OP_TEXT, data) + frame = extension.encode(frame) + frame = extension.decode(frame) + + +if __name__ == "__main__": + if len(sys.argv) < 2 or not pathlib.Path(sys.argv[1]).is_dir(): + print(f"Usage: {sys.argv[0]} [] []") + corpus = [file.read_bytes() for file in pathlib.Path(sys.argv[1]).iterdir()] + max_window_bits = int(sys.argv[2]) if len(sys.argv) > 2 else 12 + memory_level = int(sys.argv[3]) if len(sys.argv) > 3 else 5 + level = int(sys.argv[4]) if len(sys.argv) > 4 else 6 + compress_and_decompress(corpus, max_window_bits, memory_level, level) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index f962b65fb..21df804fd 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import zlib from collections.abc import Sequence from typing import Any @@ -120,7 +119,6 @@ def decode( else: if not frame.rsv1: return frame - frame = dataclasses.replace(frame, rsv1=False) if not frame.fin: self.decode_cont_data = True @@ -146,7 +144,15 @@ def decode( if frame.fin and self.remote_no_context_takeover: del self.decoder - return dataclasses.replace(frame, data=data) + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Unset the rsv1 flag on the first frame of a compressed message. + False, + frame.rsv2, + frame.rsv3, + ) def encode(self, frame: frames.Frame) -> frames.Frame: """ @@ -161,8 +167,6 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # data" flag similar to "decode continuation data" at this time. if frame.opcode is not frames.OP_CONT: - # Set the rsv1 flag on the first frame of a compressed message. - frame = dataclasses.replace(frame, rsv1=True) # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( @@ -172,14 +176,25 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): + if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK: + # Making a copy is faster than memoryview(a)[:-4] until about 2kB. + # On larger messages, it's slower but profiling shows that it's + # marginal compared to compress() and flush(). Keep it simple. data = data[:-4] # Allow garbage collection of the encoder if it won't be reused. if frame.fin and self.local_no_context_takeover: del self.encoder - return dataclasses.replace(frame, data=data) + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Set the rsv1 flag on the first frame of a compressed message. + frame.opcode is not frames.OP_CONT, + frame.rsv2, + frame.rsv3, + ) def _build_parameters( From 524dd4afa8dc2a0d17ab08f20f592af03c165db1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 26 Sep 2024 23:11:28 +0200 Subject: [PATCH 1417/1539] Avoid making a copy of large frames. This isn't very significant compared to the cost of compression. It can make a real difference for decompression. --- docs/project/changelog.rst | 14 +++++++++ .../extensions/permessage_deflate.py | 30 ++++++++++++------- src/websockets/frames.py | 8 ++--- tests/extensions/test_permessage_deflate.py | 11 +++++++ 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c8d854ba4..f5b4812bd 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -61,6 +61,20 @@ Backwards-incompatible changes Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. +.. admonition:: :attr:`Frame.data ` is now a bytes-like object. + :class: note + + In addition to :class:`bytes`, it may be a :class:`bytearray` or a + :class:`memoryview`. + + If you wrote an :class:`extension ` that relies on + methods not provided by these new types, you may need to update your code. + +Improvements +............ + +* Sending or receiving large compressed frames is now faster. + .. _13.1: 13.1 diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 21df804fd..ed16937d8 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -129,16 +129,22 @@ def decode( # Uncompress data. Protect against zip bombs by preventing zlib from # decompressing more than max_length bytes (except when the limit is # disabled with max_size = None). - data = frame.data - if frame.fin: - data += _EMPTY_UNCOMPRESSED_BLOCK + if frame.fin and len(frame.data) < 2044: + # Profiling shows that appending four bytes, which makes a copy, is + # faster than calling decompress() again when data is less than 2kB. + data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK + else: + data = frame.data max_length = 0 if max_size is None else max_size try: data = self.decoder.decompress(data, max_length) + if self.decoder.unconsumed_tail: + raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + if frame.fin and len(frame.data) >= 2044: + # This cannot generate additional data. + self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) except zlib.error as exc: raise ProtocolError("decompression failed") from exc - if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -176,11 +182,15 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK: - # Making a copy is faster than memoryview(a)[:-4] until about 2kB. - # On larger messages, it's slower but profiling shows that it's - # marginal compared to compress() and flush(). Keep it simple. - data = data[:-4] + if frame.fin: + # Sync flush generates between 5 or 6 bytes, ending with the bytes + # 0x00 0x00 0xff 0xff, which must be removed. + assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK + # Making a copy is faster than memoryview(a)[:-4] until 2kB. + if len(data) < 2048: + data = data[:-4] + else: + data = memoryview(data)[:-4] # Allow garbage collection of the encoder if it won't be reused. if frame.fin and self.local_no_context_takeover: diff --git a/src/websockets/frames.py b/src/websockets/frames.py index dace2c902..5fadf3c2d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -7,7 +7,7 @@ import secrets import struct from collections.abc import Generator, Sequence -from typing import Callable +from typing import Callable, Union from .exceptions import PayloadTooBig, ProtocolError @@ -139,7 +139,7 @@ class Frame: """ opcode: Opcode - data: bytes + data: Union[bytes, bytearray, memoryview] fin: bool = True rsv1: bool = False rsv2: bool = False @@ -160,7 +160,7 @@ def __str__(self) -> str: if self.opcode is OP_TEXT: # Decoding only the beginning and the end is needlessly hard. # Decode the entire payload then elide later if necessary. - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) elif self.opcode is OP_BINARY: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. @@ -178,7 +178,7 @@ def __str__(self) -> str: # binary. If self.data is a memoryview, it has no decode() method, # which raises AttributeError. try: - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index ee09813c4..76cd48623 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,4 +1,5 @@ import dataclasses +import os import unittest from websockets.exceptions import ( @@ -167,6 +168,16 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) + def test_encode_decode_large_frame(self): + # There is a separate code path that avoids copying data + # when frames are larger than 2kB. Test it for coverage. + frame = Frame(OP_BINARY, os.urandom(4096)) + + enc_frame = self.extension.encode(frame) + dec_frame = self.extension.decode(enc_frame) + + self.assertEqual(dec_frame, frame) + def test_no_decode_text_frame(self): frame = Frame(OP_TEXT, "café".encode()) From 07dc56443acd8aaac4d79db6a30e450d0073137d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 21:44:36 +0200 Subject: [PATCH 1418/1539] Add asyncio & threading examples to the homepage. Fix #1437. --- docs/index.rst | 19 +++++++++++++++---- example/{ => asyncio}/echo.py | 12 +++++++++--- example/asyncio/hello.py | 17 +++++++++++++++++ example/ruff.toml | 2 ++ example/sync/echo.py | 19 +++++++++++++++++++ example/{ => sync}/hello.py | 10 +++++++--- 6 files changed, 69 insertions(+), 10 deletions(-) rename example/{ => asyncio}/echo.py (51%) create mode 100755 example/asyncio/hello.py create mode 100644 example/ruff.toml create mode 100755 example/sync/echo.py rename example/{ => sync}/hello.py (66%) diff --git a/docs/index.rst b/docs/index.rst index f9576f2dc..de14fa2d0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,13 +60,24 @@ list of features provided by each implementation. implementation, and then some. If you're using the historical implementation, you should :doc:`ugrade to the new implementation `. -Here's an echo server using the :mod:`asyncio` API: +Here's an echo server and corresponding client. -.. literalinclude:: ../example/echo.py +.. tab:: asyncio -Here's a client using the :mod:`threading` API: + .. literalinclude:: ../example/asyncio/echo.py -.. literalinclude:: ../example/hello.py +.. tab:: threading + + .. literalinclude:: ../example/sync/echo.py + +.. tab:: asyncio + :new-set: + + .. literalinclude:: ../example/asyncio/hello.py + +.. tab:: threading + + .. literalinclude:: ../example/sync/hello.py Don't worry about the opening and closing handshakes, pings and pongs, or any other behavior described in the WebSocket specification. websockets takes care diff --git a/example/echo.py b/example/asyncio/echo.py similarity index 51% rename from example/echo.py rename to example/asyncio/echo.py index b952a5cfb..28d877be7 100755 --- a/example/echo.py +++ b/example/asyncio/echo.py @@ -1,14 +1,20 @@ #!/usr/bin/env python +"""Echo server using the asyncio API.""" + import asyncio from websockets.asyncio.server import serve + async def echo(websocket): async for message in websocket: await websocket.send(message) + async def main(): - async with serve(echo, "localhost", 8765): - await asyncio.get_running_loop().create_future() # run forever + async with serve(echo, "localhost", 8765) as server: + await server.serve_forever() + -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/asyncio/hello.py b/example/asyncio/hello.py new file mode 100755 index 000000000..6e4518497 --- /dev/null +++ b/example/asyncio/hello.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +"""Client using the asyncio API.""" + +import asyncio +from websockets.asyncio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + await websocket.send("Hello world!") + message = await websocket.recv() + print(message) + + +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/ruff.toml b/example/ruff.toml new file mode 100644 index 000000000..13ae36c08 --- /dev/null +++ b/example/ruff.toml @@ -0,0 +1,2 @@ +[lint.isort] +no-sections = true diff --git a/example/sync/echo.py b/example/sync/echo.py new file mode 100755 index 000000000..4b47db1ba --- /dev/null +++ b/example/sync/echo.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +"""Echo server using the threading API.""" + +from websockets.sync.server import serve + + +def echo(websocket): + for message in websocket: + websocket.send(message) + + +def main(): + with serve(echo, "localhost", 8765) as server: + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/example/hello.py b/example/sync/hello.py similarity index 66% rename from example/hello.py rename to example/sync/hello.py index a3ce0699e..bb4cd3ffd 100755 --- a/example/hello.py +++ b/example/sync/hello.py @@ -1,12 +1,16 @@ #!/usr/bin/env python -import asyncio +"""Client using the threading API.""" + from websockets.sync.client import connect + def hello(): with connect("ws://localhost:8765") as websocket: websocket.send("Hello world!") message = websocket.recv() - print(f"Received: {message}") + print(message) + -hello() +if __name__ == "__main__": + hello() From a5c8943a99a8625836b150ca3559923ecf79bcd0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 21:53:49 +0200 Subject: [PATCH 1419/1539] Add asyncio & threading client & server examples. They're convenient to tweak to reproduce issues. --- example/asyncio/client.py | 22 ++++++++++++++++++++++ example/asyncio/server.py | 25 +++++++++++++++++++++++++ example/sync/client.py | 20 ++++++++++++++++++++ example/sync/server.py | 24 ++++++++++++++++++++++++ 4 files changed, 91 insertions(+) create mode 100644 example/asyncio/client.py create mode 100644 example/asyncio/server.py create mode 100644 example/sync/client.py create mode 100644 example/sync/server.py diff --git a/example/asyncio/client.py b/example/asyncio/client.py new file mode 100644 index 000000000..e3562642d --- /dev/null +++ b/example/asyncio/client.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +"""Client example using the asyncio API.""" + +import asyncio + +from websockets.asyncio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + name = input("What's your name? ") + + await websocket.send(name) + print(f">>> {name}") + + greeting = await websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/asyncio/server.py b/example/asyncio/server.py new file mode 100644 index 000000000..574e053bf --- /dev/null +++ b/example/asyncio/server.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +"""Server example using the asyncio API.""" + +import asyncio +from websockets.asyncio.server import serve + + +async def hello(websocket): + name = await websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" + + await websocket.send(greeting) + print(f">>> {greeting}") + + +async def main(): + async with serve(hello, "localhost", 8765) as server: + await server.serve_forever() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/sync/client.py b/example/sync/client.py new file mode 100644 index 000000000..c0d633c7b --- /dev/null +++ b/example/sync/client.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +"""Client example using the threading API.""" + +from websockets.sync.client import connect + + +def hello(): + with connect("ws://localhost:8765") as websocket: + name = input("What's your name? ") + + websocket.send(name) + print(f">>> {name}") + + greeting = websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + hello() diff --git a/example/sync/server.py b/example/sync/server.py new file mode 100644 index 000000000..030049f81 --- /dev/null +++ b/example/sync/server.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python + +"""Server example using the threading API.""" + +from websockets.sync.server import serve + + +def hello(websocket): + name = websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" + + websocket.send(greeting) + print(f">>> {greeting}") + + +def main(): + with serve(hello, "localhost", 8765) as server: + server.serve_forever() + + +if __name__ == "__main__": + main() From 21987f96ad93f8c8bbf0b8ea99f3a18a52335730 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 23:06:00 +0200 Subject: [PATCH 1420/1539] Migrate authentication experiment to new asyncio. --- docs/topics/authentication.rst | 118 ++++++-------- experiments/authentication/app.py | 194 ++++++++++-------------- experiments/authentication/script.js | 5 +- experiments/authentication/test.js | 2 - experiments/authentication/user_info.js | 2 +- 5 files changed, 127 insertions(+), 194 deletions(-) diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 86d2e2587..e2de4332e 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -1,13 +1,13 @@ Authentication ============== -The WebSocket protocol was designed for creating web applications that need -bidirectional communication between clients running in browsers and servers. +The WebSocket protocol is designed for creating web applications that require +bidirectional communication between browsers and servers. In most practical use cases, WebSocket servers need to authenticate clients in order to route communications appropriately and securely. -:rfc:`6455` stays elusive when it comes to authentication: +:rfc:`6455` remains elusive when it comes to authentication: This protocol doesn't prescribe any particular way that servers can authenticate clients during the WebSocket handshake. The WebSocket @@ -26,8 +26,8 @@ System design Consider a setup where the WebSocket server is separate from the HTTP server. -Most servers built with websockets to complement a web application adopt this -design because websockets doesn't aim at supporting HTTP. +Most servers built with websockets adopt this design because they're a component +in a web application and websockets doesn't aim at supporting HTTP. The following diagram illustrates the authentication flow. @@ -82,8 +82,8 @@ WebSocket server. credentials would be a session identifier or a serialized, signed session. Unfortunately, when the WebSocket server runs on a different domain from - the web application, this idea bumps into the `Same-Origin Policy`_. For - security reasons, setting a cookie on a different origin is impossible. + the web application, this idea hits the wall of the `Same-Origin Policy`_. + For security reasons, setting a cookie on a different origin is impossible. The proper workaround consists in: @@ -108,13 +108,11 @@ WebSocket server. Letting the browser perform HTTP Basic Auth is a nice idea in theory. - In practice it doesn't work due to poor support in browsers. + In practice it doesn't work due to browser support limitations: - As of May 2021: + * Chrome behaves as expected. - * Chrome 90 behaves as expected. - - * Firefox 88 caches credentials too aggressively. + * Firefox caches credentials too aggressively. When connecting again to the same server with new credentials, it reuses the old credentials, which may be expired, resulting in an HTTP 401. Then @@ -123,7 +121,7 @@ WebSocket server. When tokens are short-lived or single-use, this bug produces an interesting effect: every other WebSocket connection fails. - * Safari 14 ignores credentials entirely. + * Safari behaves as expected. Two other options are off the table: @@ -142,8 +140,10 @@ Two other options are off the table: While this is suggested by the RFC, installing a TLS certificate is too far from the mainstream experience of browser users. This could make sense in - high security contexts. I hope developers working on such projects don't - take security advice from the documentation of random open source projects. + high security contexts. + + I hope that developers working on projects in this category don't take + security advice from the documentation of random open source projects :-) Let's experiment! ----------------- @@ -185,6 +185,8 @@ connection: .. code-block:: python + from websockets.frames import CloseCode + async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) @@ -212,24 +214,16 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.server import WebSocketServerProtocol - - class QueryParamProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - token = get_query_parameter(path, "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + async def query_param_auth(connection, request): + token = get_query_param(request.path, "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - self.user = user + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") - async def query_param_handler(websocket): - user = websocket.user - - ... + connection.username = user Cookie ...... @@ -260,27 +254,19 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.server import WebSocketServerProtocol - - class CookieProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - # Serve iframe on non-WebSocket requests - ... - - token = get_cookie(headers.get("Cookie", ""), "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + async def cookie_auth(connection, request): + # Serve iframe on non-WebSocket requests + ... - self.user = user + token = get_cookie(request.headers.get("Cookie", ""), "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - async def cookie_handler(websocket): - user = websocket.user + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") - ... + connection.username = user User information ................ @@ -303,24 +289,12 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.auth import BasicAuthWebSocketServerProtocol - - class UserInfoProtocol(BasicAuthWebSocketServerProtocol): - async def check_credentials(self, username, password): - if username != "token": - return False - - user = get_user(password) - if user is None: - return False + from websockets.asyncio.server import basic_auth as websockets_basic_auth - self.user = user - return True + def check_credentials(username, password): + return username == get_user(password) - async def user_info_handler(websocket): - user = websocket.user - - ... + basic_auth = websockets_basic_auth(check_credentials=check_credentials) Machine-to-machine authentication --------------------------------- @@ -334,11 +308,9 @@ To authenticate a websockets client with HTTP Basic Authentication .. code-block:: python - from websockets.legacy.client import connect + from websockets.asyncio.client import connect - async with connect( - f"wss://{username}:{password}@example.com" - ) as websocket: + async with connect(f"wss://{username}:{password}@.../") as websocket: ... (You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they @@ -349,10 +321,8 @@ To authenticate a websockets client with HTTP Bearer Authentication .. code-block:: python - from websockets.legacy.client import connect + from websockets.asyncio.client import connect - async with connect( - "wss://example.com", - extra_headers={"Authorization": f"Bearer {token}"} - ) as websocket: + headers = {"Authorization": f"Bearer {token}"} + async with connect("wss://.../", additional_headers=headers) as websocket: ... diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index e3b2cf1f6..0bdd7fd2f 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio +import email.utils import http import http.cookies import pathlib @@ -8,9 +9,10 @@ import urllib.parse import uuid +from websockets.asyncio.server import basic_auth as websockets_basic_auth, serve +from websockets.datastructures import Headers from websockets.frames import CloseCode -from websockets.legacy.auth import BasicAuthWebSocketServerProtocol -from websockets.legacy.server import WebSocketServerProtocol, serve +from websockets.http11 import Response # User accounts database @@ -49,7 +51,19 @@ def get_query_param(path, key): return values[0] -# Main HTTP server +# WebSocket handler + + +async def handler(websocket): + try: + user = websocket.username + except AttributeError: + return + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + CONTENT_TYPES = { ".css": "text/css", @@ -59,9 +73,10 @@ def get_query_param(path, key): } -async def serve_html(path, request_headers): - user = get_query_param(path, "user") - path = urllib.parse.urlparse(path).path +async def serve_html(connection, request): + """Basic HTTP server implemented as a process_request hook.""" + user = get_query_param(request.path, "user") + path = urllib.parse.urlparse(request.path).path if path == "/": if user is None: page = "index.html" @@ -76,147 +91,96 @@ async def serve_html(path, request_headers): pass else: if template.is_file(): - headers = {"Content-Type": CONTENT_TYPES[template.suffix]} body = template.read_bytes() if user is not None: token = create_token(user) body = body.replace(b"TOKEN", token.encode()) - return http.HTTPStatus.OK, headers, body - - return http.HTTPStatus.NOT_FOUND, {}, b"Not found\n" - + headers = Headers( + { + "Date": email.utils.formatdate(usegmt=True), + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": CONTENT_TYPES[template.suffix], + } + ) + return Response(200, "OK", headers, body) -async def noop_handler(websocket): - pass - - -# Send credentials as the first message in the WebSocket connection + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not found\n") async def first_message_handler(websocket): + """Handler that sends credentials in the first WebSocket message.""" token = await websocket.recv() user = get_user(token) if user is None: await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Add credentials to the WebSocket URI in a query parameter - - -class QueryParamProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - token = get_query_param(path, "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user - - -async def query_param_handler(websocket): - user = websocket.user - - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Set a cookie on the domain of the WebSocket URI - - -class CookieProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - if "Upgrade" not in headers: - template = pathlib.Path(__file__).with_name(path[1:]) - headers = {"Content-Type": CONTENT_TYPES[template.suffix]} - body = template.read_bytes() - return http.HTTPStatus.OK, headers, body - - token = get_cookie(headers.get("Cookie", ""), "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user + websocket.username = user + await handler(websocket) -async def cookie_handler(websocket): - user = websocket.user +async def query_param_auth(connection, request): + """Authenticate user from token in query parameter.""" + token = get_query_param(request.path, "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Adding credentials to the WebSocket URI in user information - - -class UserInfoProtocol(BasicAuthWebSocketServerProtocol): - async def check_credentials(self, username, password): - if username != "token": - return False - - user = get_user(password) - if user is None: - return False + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") + + connection.username = user + + +async def cookie_auth(connection, request): + """Authenticate user from token in cookie.""" + if "Upgrade" not in request.headers: + template = pathlib.Path(__file__).with_name(request.path[1:]) + body = template.read_bytes() + headers = Headers( + { + "Date": email.utils.formatdate(usegmt=True), + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": CONTENT_TYPES[template.suffix], + } + ) + return Response(200, "OK", headers, body) + + token = get_cookie(request.headers.get("Cookie", ""), "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - self.user = user - return True + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") + connection.username = user -async def user_info_handler(websocket): - user = websocket.user - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." +def check_credentials(username, password): + """Authenticate user with HTTP Basic Auth.""" + return username == get_user(password) -# Start all five servers +basic_auth = websockets_basic_auth(check_credentials=check_credentials) async def main(): + """Start one HTTP server and four WebSocket servers.""" # Set the stop condition when receiving SIGINT or SIGTERM. loop = asyncio.get_running_loop() stop = loop.create_future() loop.add_signal_handler(signal.SIGINT, stop.set_result, None) loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with serve( - noop_handler, - host="", - port=8000, - process_request=serve_html, - ), serve( - first_message_handler, - host="", - port=8001, - ), serve( - query_param_handler, - host="", - port=8002, - create_protocol=QueryParamProtocol, - ), serve( - cookie_handler, - host="", - port=8003, - create_protocol=CookieProtocol, - ), serve( - user_info_handler, - host="", - port=8004, - create_protocol=UserInfoProtocol, + async with ( + serve(handler, host="", port=8000, process_request=serve_html), + serve(first_message_handler, host="", port=8001), + serve(handler, host="", port=8002, process_request=query_param_auth), + serve(handler, host="", port=8003, process_request=cookie_auth), + serve(handler, host="", port=8004, process_request=basic_auth), ): print("Running on http://localhost:8000/") await stop diff --git a/experiments/authentication/script.js b/experiments/authentication/script.js index ec4e5e670..01dd5b168 100644 --- a/experiments/authentication/script.js +++ b/experiments/authentication/script.js @@ -1,4 +1,5 @@ -var token = window.parent.token; +var token = window.parent.token, + user = window.parent.user; function getExpectedEvents() { return [ @@ -7,7 +8,7 @@ function getExpectedEvents() { }, { type: "message", - data: `Hello ${window.parent.user}!`, + data: `Hello ${user}!`, }, { type: "close", diff --git a/experiments/authentication/test.js b/experiments/authentication/test.js index 428830ff3..e05ca697e 100644 --- a/experiments/authentication/test.js +++ b/experiments/authentication/test.js @@ -1,6 +1,4 @@ -// for connecting to WebSocket servers var token = document.body.dataset.token; -// for test assertions only const params = new URLSearchParams(window.location.search); var user = params.get("user"); diff --git a/experiments/authentication/user_info.js b/experiments/authentication/user_info.js index 1dab2ce4c..bc9a3f148 100644 --- a/experiments/authentication/user_info.js +++ b/experiments/authentication/user_info.js @@ -1,5 +1,5 @@ window.addEventListener("DOMContentLoaded", () => { - const uri = `ws://token:${token}@localhost:8004/`; + const uri = `ws://${user}:${token}@localhost:8004/`; const websocket = new WebSocket(uri); websocket.onmessage = ({ data }) => { From bc4b8f2776cd4da9aee2e67f66764bf26a5ad09e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Sep 2024 21:07:28 +0200 Subject: [PATCH 1421/1539] Add option to force sending text or binary frames. Fix #1515. --- src/websockets/asyncio/connection.py | 124 +++++++++++++++------------ tests/asyncio/test_connection.py | 72 ++++++++++++++-- 2 files changed, 134 insertions(+), 62 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 702e69995..12871e4b3 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -251,12 +251,13 @@ async def recv(self, decode: bool | None = None) -> Data: You may override this behavior with the ``decode`` argument: - * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames - and return a bytestring (:class:`bytes`). This may be useful to - optimize performance when decoding isn't needed. + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return a string (:class:`str`). This is useful for servers - that send binary frames instead of text frames. + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -333,7 +334,11 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data "is already running recv or recv_streaming" ) from None - async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + async def send( + self, + message: Data | Iterable[Data] | AsyncIterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -344,6 +349,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. @@ -393,12 +409,20 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No # strings and bytes-like objects are iterable. if isinstance(message, str): - async with self.send_context(): - self.protocol.send_text(message.encode()) + if text is False: + async with self.send_context(): + self.protocol.send_binary(message.encode()) + else: + async with self.send_context(): + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): - async with self.send_context(): - self.protocol.send_binary(message) + if text is True: + async with self.send_context(): + self.protocol.send_text(message) + else: + async with self.send_context(): + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -419,36 +443,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("iterable must contain uniform types") @@ -481,36 +501,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("async iterable must contain bytes or str") # Other fragments async for chunk in achunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("async iterable must contain uniform types") diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 70d9dad63..563cf2b17 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -190,13 +190,13 @@ async def test_recv_binary(self): await self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") - async def test_recv_encoded_text(self): - """recv receives an UTF-8 encoded text message.""" + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) - async def test_recv_decoded_binary(self): - """recv receives an UTF-8 decoded binary message.""" + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual(await self.connection.recv(decode=True), "😀") @@ -304,16 +304,16 @@ async def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) - async def test_recv_streaming_encoded_text(self): - """recv_streaming receives an UTF-8 encoded text message.""" + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual( await alist(self.connection.recv_streaming(decode=False)), ["😀".encode()], ) - async def test_recv_streaming_decoded_binary(self): - """recv_streaming receives a UTF-8 decoded binary message.""" + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual( await alist(self.connection.recv_streaming(decode=True)), @@ -438,6 +438,16 @@ async def test_send_binary(self): await self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + async def test_send_fragmented_text(self): """send sends a fragmented text message.""" await self.connection.send(["😀", "😀"]) @@ -456,6 +466,24 @@ async def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_async_fragmented_text(self): """send sends a fragmented text message asynchronously.""" @@ -484,6 +512,34 @@ async def fragments(): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" await self.remote_connection.close() From 7fdd932c6b29a9ef4db46b2141371f42207b4f00 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2024 14:21:00 +0200 Subject: [PATCH 1422/1539] Review & update gitignore. --- .gitignore | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d8e6697a8..291bf1fb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ *.pyc *.so .coverage -.direnv +.direnv/ .envrc .idea/ -.mypy_cache -.tox +.mypy_cache/ +.tox/ +.vscode/ build/ compliance/reports/ -experiments/compression/corpus/ dist/ docs/_build/ +experiments/compression/corpus/ htmlcov/ -MANIFEST -websockets.egg-info/ +src/websockets.egg-info/ From c5985d5c4192390b2da58ae97015e3ab1ba41cd2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Oct 2024 16:21:16 +0200 Subject: [PATCH 1423/1539] Update compliance test suite. --- Makefile | 4 +- compliance/README.rst | 74 ++++++++++++++++++---------- compliance/asyncio/client.py | 52 +++++++++++++++++++ compliance/asyncio/server.py | 32 ++++++++++++ compliance/config/fuzzingclient.json | 11 +++++ compliance/config/fuzzingserver.json | 7 +++ compliance/fuzzingclient.json | 11 ----- compliance/fuzzingserver.json | 12 ----- compliance/sync/client.py | 51 +++++++++++++++++++ compliance/sync/server.py | 31 ++++++++++++ compliance/test_client.py | 48 ------------------ compliance/test_server.py | 29 ----------- 12 files changed, 234 insertions(+), 128 deletions(-) create mode 100644 compliance/asyncio/client.py create mode 100644 compliance/asyncio/server.py create mode 100644 compliance/config/fuzzingclient.json create mode 100644 compliance/config/fuzzingserver.json delete mode 100644 compliance/fuzzingclient.json delete mode 100644 compliance/fuzzingserver.json create mode 100644 compliance/sync/client.py create mode 100644 compliance/sync/server.py delete mode 100644 compliance/test_client.py delete mode 100644 compliance/test_server.py diff --git a/Makefile b/Makefile index fd36d0367..06bfe9edc 100644 --- a/Makefile +++ b/Makefile @@ -8,8 +8,8 @@ build: python setup.py build_ext --inplace style: - ruff format src tests - ruff check --fix src tests + ruff format compliance src tests + ruff check --fix compliance src tests types: mypy --strict src diff --git a/compliance/README.rst b/compliance/README.rst index 8570f9176..c7c7c93b4 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -4,47 +4,69 @@ Autobahn Testsuite General information and installation instructions are available at https://github.com/crossbario/autobahn-testsuite. -To improve performance, you should compile the C extension first:: +Running the test suite +---------------------- + +All commands below must be run from the root directory of the repository. + +To get acceptable performance, compile the C extension first: + +.. code-block:: console $ python setup.py build_ext --inplace -Running the test suite ----------------------- +Run each command in a different shell. Testing takes several minutes to complete +— wstest is the bottleneck. When clients finish, stop servers with Ctrl-C. + +You can exclude slow tests by modifying the configuration files as follows:: + + "exclude-cases": ["9.*", "12.*", "13.*"] -All commands below must be run from the directory containing this file. +The test server and client applications shouldn't display any exceptions. -To test the server:: +To test the servers: - $ PYTHONPATH=.. python test_server.py - $ wstest -m fuzzingclient +.. code-block:: console -To test the client:: + $ PYTHONPATH=src python compliance/asyncio/server.py + $ PYTHONPATH=src python compliance/sync/server.py - $ wstest -m fuzzingserver - $ PYTHONPATH=.. python test_client.py + $ docker run --interactive --tty --rm \ + --volume "${PWD}/compliance/config:/config" \ + --volume "${PWD}/compliance/reports:/reports" \ + --name fuzzingclient \ + crossbario/autobahn-testsuite \ + wstest --mode fuzzingclient --spec /config/fuzzingclient.json -Run the first command in a shell. Run the second command in another shell. -It should take about ten minutes to complete — wstest is the bottleneck. -Then kill the first one with Ctrl-C. + $ open reports/servers/index.html -The test client or server shouldn't display any exceptions. The results are -stored in reports/clients/index.html. +To test the clients: -Note that the Autobahn software only supports Python 2, while ``websockets`` -only supports Python 3; you need two different environments. +.. code-block:: console + $ docker run --interactive --tty --rm \ + --volume "${PWD}/compliance/config:/config" \ + --volume "${PWD}/compliance/reports:/reports" \ + --publish 9001:9001 \ + --name fuzzingserver \ + crossbario/autobahn-testsuite \ + wstest --mode fuzzingserver --spec /config/fuzzingserver.json + + $ PYTHONPATH=src python compliance/asyncio/client.py + $ PYTHONPATH=src python compliance/sync/client.py + + $ open reports/clients/index.html Conformance notes ----------------- Some test cases are more strict than the RFC. Given the implementation of the -library and the test echo client or server, ``websockets`` gets a "Non-Strict" -in these cases. - -In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 ``websockets`` notices the -protocol error and closes the connection before it has had a chance to echo -the previous frame. +library and the test client and server applications, websockets passes with a +"Non-Strict" result in these cases. -In 6.4.3 and 6.4.4, even though it uses an incremental decoder, ``websockets`` -doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These -tests are more strict than the RFC. +In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 websockets notices the +protocol error and closes the connection at the library level before the +application gets a chance to echo the previous frame. +In 6.4.3 and 6.4.4, even though it uses an incremental decoder, websockets +doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests +are more strict than the RFC. diff --git a/compliance/asyncio/client.py b/compliance/asyncio/client.py new file mode 100644 index 000000000..5b0bfb3ae --- /dev/null +++ b/compliance/asyncio/client.py @@ -0,0 +1,52 @@ +import asyncio +import json +import logging + +from websockets.asyncio.client import connect +from websockets.exceptions import WebSocketException + + +logging.basicConfig(level=logging.WARNING) + +SERVER = "ws://127.0.0.1:9001" + + +async def get_case_count(): + async with connect(f"{SERVER}/getCaseCount") as ws: + return json.loads(await ws.recv()) + + +async def run_case(case): + async with connect( + f"{SERVER}/runCase?case={case}", + user_agent_header="websockets.asyncio", + max_size=2**25, + ) as ws: + async for msg in ws: + await ws.send(msg) + + +async def update_reports(): + async with connect(f"{SERVER}/updateReports", open_timeout=60): + pass + + +async def main(): + cases = await get_case_count() + for case in range(1, cases + 1): + print(f"Running test case {case:03d} / {cases}... ", end="\t") + try: + await run_case(case) + except WebSocketException as exc: + print(f"ERROR: {type(exc).__name__}: {exc}") + except Exception as exc: + print(f"FAIL: {type(exc).__name__}: {exc}") + else: + print("OK") + print("Ran {cases} test cases") + await update_reports() + print("Updated reports") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/compliance/asyncio/server.py b/compliance/asyncio/server.py new file mode 100644 index 000000000..cff2728c9 --- /dev/null +++ b/compliance/asyncio/server.py @@ -0,0 +1,32 @@ +import asyncio +import logging + +from websockets.asyncio.server import serve + + +logging.basicConfig(level=logging.WARNING) + +HOST, PORT = "0.0.0.0", 9002 + + +async def echo(ws): + async for msg in ws: + await ws.send(msg) + + +async def main(): + async with serve( + echo, + HOST, + PORT, + server_header="websockets.sync", + max_size=2**25, + ) as server: + try: + await server.serve_forever() + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/compliance/config/fuzzingclient.json b/compliance/config/fuzzingclient.json new file mode 100644 index 000000000..37f8f3a35 --- /dev/null +++ b/compliance/config/fuzzingclient.json @@ -0,0 +1,11 @@ + +{ + "servers": [{ + "url": "ws://host.docker.internal:9002" + }, { + "url": "ws://host.docker.internal:9003" + }], + "outdir": "./reports/servers", + "cases": ["*"], + "exclude-cases": [] +} diff --git a/compliance/config/fuzzingserver.json b/compliance/config/fuzzingserver.json new file mode 100644 index 000000000..1bcb5959d --- /dev/null +++ b/compliance/config/fuzzingserver.json @@ -0,0 +1,7 @@ + +{ + "url": "ws://localhost:9001", + "outdir": "./reports/clients", + "cases": ["*"], + "exclude-cases": [] +} diff --git a/compliance/fuzzingclient.json b/compliance/fuzzingclient.json deleted file mode 100644 index 202ff49a0..000000000 --- a/compliance/fuzzingclient.json +++ /dev/null @@ -1,11 +0,0 @@ - -{ - "options": {"failByDrop": false}, - "outdir": "./reports/servers", - - "servers": [{"agent": "websockets", "url": "ws://localhost:8642", "options": {"version": 18}}], - - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} diff --git a/compliance/fuzzingserver.json b/compliance/fuzzingserver.json deleted file mode 100644 index 1bdb42723..000000000 --- a/compliance/fuzzingserver.json +++ /dev/null @@ -1,12 +0,0 @@ - -{ - "url": "ws://localhost:8642", - - "options": {"failByDrop": false}, - "outdir": "./reports/clients", - "webport": 8080, - - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} diff --git a/compliance/sync/client.py b/compliance/sync/client.py new file mode 100644 index 000000000..e585496f3 --- /dev/null +++ b/compliance/sync/client.py @@ -0,0 +1,51 @@ +import json +import logging + +from websockets.exceptions import WebSocketException +from websockets.sync.client import connect + + +logging.basicConfig(level=logging.WARNING) + +SERVER = "ws://127.0.0.1:9001" + + +def get_case_count(): + with connect(f"{SERVER}/getCaseCount") as ws: + return json.loads(ws.recv()) + + +def run_case(case): + with connect( + f"{SERVER}/runCase?case={case}", + user_agent_header="websockets.sync", + max_size=2**25, + ) as ws: + for msg in ws: + ws.send(msg) + + +def update_reports(): + with connect(f"{SERVER}/updateReports", open_timeout=60): + pass + + +def main(): + cases = get_case_count() + for case in range(1, cases + 1): + print(f"Running test case {case:03d} / {cases}... ", end="\t") + try: + run_case(case) + except WebSocketException as exc: + print(f"ERROR: {type(exc).__name__}: {exc}") + except Exception as exc: + print(f"FAIL: {type(exc).__name__}: {exc}") + else: + print("OK") + print("Ran {cases} test cases") + update_reports() + print("Updated reports") + + +if __name__ == "__main__": + main() diff --git a/compliance/sync/server.py b/compliance/sync/server.py new file mode 100644 index 000000000..c3cb4d989 --- /dev/null +++ b/compliance/sync/server.py @@ -0,0 +1,31 @@ +import logging + +from websockets.sync.server import serve + + +logging.basicConfig(level=logging.WARNING) + +HOST, PORT = "0.0.0.0", 9003 + + +def echo(ws): + for msg in ws: + ws.send(msg) + + +def main(): + with serve( + echo, + HOST, + PORT, + server_header="websockets.asyncio", + max_size=2**25, + ) as server: + try: + server.serve_forever() + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/compliance/test_client.py b/compliance/test_client.py deleted file mode 100644 index 8e22569fd..000000000 --- a/compliance/test_client.py +++ /dev/null @@ -1,48 +0,0 @@ -import asyncio -import json -import logging -import urllib.parse - -from websockets.asyncio.client import connect - - -logging.basicConfig(level=logging.WARNING) - -# Uncomment this line to make only websockets more verbose. -# logging.getLogger('websockets').setLevel(logging.DEBUG) - - -SERVER = "ws://127.0.0.1:8642" -AGENT = "websockets" - - -async def get_case_count(server): - uri = f"{server}/getCaseCount" - async with connect(uri) as ws: - msg = ws.recv() - return json.loads(msg) - - -async def run_case(server, case, agent): - uri = f"{server}/runCase?case={case}&agent={agent}" - async with connect(uri, max_size=2 ** 25, max_queue=1) as ws: - async for msg in ws: - await ws.send(msg) - - -async def update_reports(server, agent): - uri = f"{server}/updateReports?agent={agent}" - async with connect(uri): - pass - - -async def run_tests(server, agent): - cases = await get_case_count(server) - for case in range(1, cases + 1): - print(f"Running test case {case} out of {cases}", end="\r") - await run_case(server, case, agent) - print(f"Ran {cases} test cases ") - await update_reports(server, agent) - - -asyncio.run(run_tests(SERVER, urllib.parse.quote(AGENT))) diff --git a/compliance/test_server.py b/compliance/test_server.py deleted file mode 100644 index 39176e902..000000000 --- a/compliance/test_server.py +++ /dev/null @@ -1,29 +0,0 @@ -import asyncio -import logging - -from websockets.asyncio.server import serve - - -logging.basicConfig(level=logging.WARNING) - -# Uncomment this line to make only websockets more verbose. -# logging.getLogger('websockets').setLevel(logging.DEBUG) - - -HOST, PORT = "127.0.0.1", 8642 - - -async def echo(ws): - async for msg in ws: - await ws.send(msg) - - -async def main(): - with serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): - try: - await asyncio.get_running_loop().create_future() # run forever - except KeyboardInterrupt: - pass - - -asyncio.run(main()) From b2f0a7647f1402c84a8dabb391c3ca7371975eb3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Oct 2024 21:32:30 +0200 Subject: [PATCH 1424/1539] Add option to force sending text or binary frames. This adds the same functionality to the threading implemetation as bc4b8f2 did to the asyncio implementation. Refs #1515. --- src/websockets/asyncio/connection.py | 43 +++++++++++--------- src/websockets/sync/connection.py | 61 +++++++++++++++++----------- tests/sync/test_connection.py | 28 +++++++++++++ 3 files changed, 90 insertions(+), 42 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 12871e4b3..3b81e386b 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -409,19 +409,17 @@ async def send( # strings and bytes-like objects are iterable. if isinstance(message, str): - if text is False: - async with self.send_context(): + async with self.send_context(): + if text is False: self.protocol.send_binary(message.encode()) - else: - async with self.send_context(): + else: self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): - if text is True: - async with self.send_context(): + async with self.send_context(): + if text is True: self.protocol.send_text(message) - else: - async with self.send_context(): + else: self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -443,19 +441,17 @@ async def send( try: # First fragment. if isinstance(chunk, str): - if text is False: - async with self.send_context(): + async with self.send_context(): + if text is False: self.protocol.send_binary(chunk.encode(), fin=False) - else: - async with self.send_context(): + else: self.protocol.send_text(chunk.encode(), fin=False) encode = True elif isinstance(chunk, BytesLike): - if text is True: - async with self.send_context(): + async with self.send_context(): + if text is True: self.protocol.send_text(chunk, fin=False) - else: - async with self.send_context(): + else: self.protocol.send_binary(chunk, fin=False) encode = False else: @@ -480,7 +476,10 @@ async def send( # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -538,7 +537,10 @@ async def send( # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -568,7 +570,10 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # to terminate after calling a method that sends a close frame. async with self.send_context(): if self.fragmented_send_waiter is not None: - self.protocol.fail(1011, "close during fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) else: self.protocol.send_close(code, reason) except ConnectionClosed: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 8c5df9592..3f4cac09f 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -251,7 +251,11 @@ def recv_streaming(self) -> Iterator[Data]: "is already running recv or recv_streaming" ) from None - def send(self, message: Data | Iterable[Data]) -> None: + def send( + self, + message: Data | Iterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -262,6 +266,17 @@ def send(self, message: Data | Iterable[Data]) -> None: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. All items must be of the @@ -300,7 +315,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_text(message.encode()) + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): @@ -309,7 +327,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_binary(message) + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -328,7 +349,6 @@ def send(self, message: Data | Iterable[Data]) -> None: try: # First fragment. if isinstance(chunk, str): - text = True with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -336,12 +356,12 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -349,29 +369,24 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("data iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("data iterable must contain uniform types") diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 16f92e164..87333fd35 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -308,6 +308,16 @@ def test_send_binary(self): self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff") + def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + self.connection.send("😀", text=False) + self.assertEqual(self.remote_connection.recv(), "😀".encode()) + + def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + self.connection.send("😀".encode(), text=True) + self.assertEqual(self.remote_connection.recv(), "😀") + def test_send_fragmented_text(self): """send sends a fragmented text message.""" self.connection.send(["😀", "😀"]) @@ -326,6 +336,24 @@ def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() From e5182c95a3332535a034c409d59463afbd760f0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Oct 2024 21:45:17 +0200 Subject: [PATCH 1425/1539] Blind fix for coverage failing in GitHub Actions. It doesn't fail locally. --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6a0ab8d7c..4e26c757e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,15 +60,19 @@ source = [ [tool.coverage.report] exclude_lines = [ + "pragma: no cover", "except ImportError:", "if self.debug:", "if sys.platform != \"win32\":", "if typing.TYPE_CHECKING:", - "pragma: no cover", "raise AssertionError", "self.fail\\(\".*\"\\)", "@unittest.skip", ] +partial_branches = [ + "pragma: no branch", + "with self.assertRaises\\(.*\\)", +] [tool.ruff] target-version = "py312" From 1387c976833956ea4d44ed0bd541fe648a065ed7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 13:26:52 +0200 Subject: [PATCH 1426/1539] Rewrite sync Assembler to improve performance. Previously, a latch was used to synchronize the user thread reading messages and the background thread reading from the network. This required two thread switches per message. Now, the background thread writes messages to queue, from which the user thread reads. This allows passing several frames at each thread switch, reducing the overhead. With this server code: async def test(websocket): for i in range(int(await websocket.recv())): await websocket.send(f"{{\"iteration\": {i}}}") async with serve(test, "localhost", 8765) as server: await server.serve_forever() and this client code: with connect("ws://localhost:8765", compression=None) as websocket: websocket.send("1_000_000") for message in websocket: pass an unscientific benchmark (running it on my laptop) shows a 2.5x speedup, going from 11 seconds to 4.4 seconds. Setting a very large recv_bufsize and max_size doesn't yield significant further improvement. Flow control was tested by inserting debug logs in maybe_pause/resume() and by measuring the wait for the recv_flow_control lock. It showed the expected behavior of pausing and unpausing coupled with some wait time. The new implementation mirrors the asyncio implementation and gains the option to prevent or force decoding of frames. Fix #1376 for the threading implementation. --- docs/project/changelog.rst | 16 +- src/websockets/asyncio/client.py | 2 +- src/websockets/asyncio/messages.py | 59 +++-- src/websockets/asyncio/server.py | 2 +- src/websockets/sync/client.py | 12 +- src/websockets/sync/connection.py | 79 ++++-- src/websockets/sync/messages.py | 333 ++++++++++++------------ src/websockets/sync/server.py | 12 +- tests/asyncio/test_connection.py | 19 +- tests/asyncio/test_messages.py | 12 +- tests/sync/test_connection.py | 106 +++++--- tests/sync/test_messages.py | 400 ++++++++++++++--------------- 12 files changed, 585 insertions(+), 467 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f5b4812bd..410671239 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -70,10 +70,21 @@ Backwards-incompatible changes If you wrote an :class:`extension ` that relies on methods not provided by these new types, you may need to update your code. +New features +............ + +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`threading` implementation; also binary frames as :class:`str`. + +* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio` + and :mod:`threading` implementations, as well as :class:`str` a binary frame. + Improvements ............ -* Sending or receiving large compressed frames is now faster. +* The :mod:`threading` implementation receives messages faster. + +* Sending or receiving large compressed messages is now faster. .. _13.1: @@ -198,6 +209,9 @@ New features * Validated compatibility with Python 3.12 and 3.13. +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`asyncio` implementation; also binary frames as :class:`str`. + * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 23b1a348a..0c8bedc5d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -45,7 +45,7 @@ class ClientConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`connect`. + and ``write_limit`` arguments have the same meaning as in :func:`connect`. Args: protocol: Sans-I/O connection. diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e3ec5062f..09be22ba2 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -60,6 +60,7 @@ def reset(self, items: Iterable[T]) -> None: self.queue.extend(items) def abort(self) -> None: + """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) # Clear the queue to avoid storing unnecessary data in memory. @@ -89,7 +90,7 @@ def __init__( # pragma: no cover pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: - # Queue of incoming messages. Each item is a queue of frames. + # Queue of incoming frames. self.frames: SimpleQueue[Frame] = SimpleQueue() # We cannot put a hard limit on the size of the queue because a single @@ -140,36 +141,35 @@ async def get(self, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # First frame + # Locking with get_in_progress prevents concurrent execution until + # get() fetches a complete message or is cancelled. + try: + # First frame frame = await self.frames.get() - except asyncio.CancelledError: - self.get_in_progress = False - raise - self.maybe_resume() - assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY - if decode is None: - decode = frame.opcode is OP_TEXT - frames = [frame] - - # Following frames, for fragmented messages - while not frame.fin: - try: - frame = await self.frames.get() - except asyncio.CancelledError: - # Put frames already received back into the queue - # so that future calls to get() can return them. - self.frames.reset(frames) - self.get_in_progress = False - raise self.maybe_resume() - assert frame.opcode is OP_CONT - frames.append(frame) - - self.get_in_progress = False + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False data = b"".join(frame.data for frame in frames) if decode: @@ -207,9 +207,14 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + # First frame try: frame = await self.frames.get() diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index e11dd91f1..a6ae5996d 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -55,7 +55,7 @@ class ServerConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`serve`. + and ``write_limit`` arguments have the same meaning as in :func:`serve`. Args: protocol: Sans-I/O connection. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 5e1ba6d84..42daa32ea 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -40,10 +40,12 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`connect`. + Args: socket: Socket connected to a WebSocket server. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. """ @@ -53,6 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -60,6 +63,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) def handshake( @@ -135,6 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -183,6 +188,10 @@ def connect( :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -287,6 +296,7 @@ def connect( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: if sock is not None: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 3f4cac09f..3ab9f4937 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,10 +49,14 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int): + max_queue = (max_queue, None) + self.max_queue = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -76,8 +80,15 @@ def __init__( # Mutex serializing interactions with the protocol. self.protocol_mutex = threading.Lock() + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = threading.Lock() + # Assembler turning frames into messages and serializing reads. - self.recv_messages = Assembler() + self.recv_messages = Assembler( + *self.max_queue, + pause=self.recv_flow_control.acquire, + resume=self.recv_flow_control.release, + ) # Whether we are busy sending a fragmented message. self.send_in_progress = False @@ -88,6 +99,10 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: dict[bytes, threading.Event] = {} + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + # Receiving events from the socket. This thread is marked as daemon to # allow creating a connection in a non-daemon thread and using it in a # daemon thread. This mustn't prevent the interpreter from exiting. @@ -97,10 +112,6 @@ def __init__( ) self.recv_events_thread.start() - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - # Public attributes @property @@ -172,7 +183,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: float | None = None) -> Data: + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -191,6 +202,11 @@ def recv(self, timeout: float | None = None) -> Data: If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. + Args: + timeout: Timeout for receiving a message in seconds. + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -198,6 +214,16 @@ def recv(self, timeout: float | None = None) -> Data: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -205,7 +231,7 @@ def recv(self, timeout: float | None = None) -> Data: """ try: - return self.recv_messages.get(timeout) + return self.recv_messages.get(timeout, decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -216,16 +242,23 @@ def recv(self, timeout: float | None = None) -> Data: "is already running recv or recv_streaming" ) from None - def recv_streaming(self) -> Iterator[Data]: + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. - If the message is fragmented, yield each fragment as it is received. - The iterator must be fully consumed, or else the connection will become + This method is designed for receiving fragmented messages. It returns an + iterator that yields each fragment as it is received. This iterator must + be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -233,6 +266,15 @@ def recv_streaming(self) -> Iterator[Data]: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -240,7 +282,7 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + yield from self.recv_messages.get_iter(decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -571,8 +613,9 @@ def recv_events(self) -> None: try: while True: try: - if self.close_deadline is not None: - self.socket.settimeout(self.close_deadline.timeout()) + with self.recv_flow_control: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: @@ -622,13 +665,9 @@ def recv_events(self) -> None: # Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. - try: - for event in events: - # This may raise EOFError if the closing handshake - # times out while a message is waiting to be read. - self.process_event(event) - except EOFError: - break + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 997fa98df..983b114dc 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,12 +3,12 @@ import codecs import queue import threading -from collections.abc import Iterator -from typing import cast +from typing import Any, Callable, Iterable, Iterator from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data +from .utils import Deadline __all__ = ["Assembler"] @@ -20,47 +20,83 @@ class Assembler: """ Assemble messages from frames. + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + """ - def __init__(self) -> None: + def __init__( + self, + high: int = 16, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: # Serialize reads and writes -- except for reads via synchronization # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to synchronize the production of - # frames and the consumption of messages (or frames) without a buffer. - # This design requires a switch between the library thread and the user - # thread for each message; that shouldn't be a performance bottleneck. - - # put() sets this event to tell get() that a message can be fetched. - self.message_complete = threading.Event() - # get() sets this event to let put() that the message was fetched. - self.message_fetched = threading.Event() + # Queue of incoming frames. + self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if low is None: + low = high // 4 + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False - # This flag prevents concurrent calls to put() by library code. - self.put_in_progress = False - - # Decoder for text frames, None for binary frames. - self.decoder: codecs.IncrementalDecoder | None = None - - # Buffer of frames belonging to the same message. - self.chunks: list[Data] = [] - - # When switching from "buffering" to "streaming", we use a thread-safe - # queue for transferring frames from the writing thread (library code) - # to the reading thread (user code). We're buffering when chunks_queue - # is None and streaming when it's a SimpleQueue. None is a sentinel - # value marking the end of the message, superseding message_complete. - - # Stream data from frames belonging to the same message. - self.chunks_queue: queue.SimpleQueue[Data | None] | None = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: float | None = None) -> Data: + def get_next_frame(self, timeout: float | None = None) -> Frame: + # Helper to factor out the logic for getting the next frame from the + # queue, while handling timeouts and reaching the end of the stream. + try: + frame = self.frames.get(timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if frame is None: + raise EOFError("stream of frames ended") + return frame + + def reset_queue(self, frames: Iterable[Frame]) -> None: + # Helper to put frames back into the queue after they were fetched. + # This happens only when the queue is empty. However, by the time + # we acquire self.mutex, put() may have added items in the queue. + # Therefore, we must handle the case where the queue is not empty. + frame: Frame | None + with self.mutex: + queued = [] + try: + while True: + queued.append(self.frames.get_nowait()) + except queue.Empty: + pass + for frame in frames: + self.frames.put(frame) + # This loop runs only when a race condition occurs. + for frame in queued: # pragma: no cover + self.frames.put(frame) + + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. @@ -73,11 +109,14 @@ def get(self, timeout: float | None = None) -> Data: Args: timeout: If a timeout is provided and elapses before a complete message is received, :meth:`get` raises :exc:`TimeoutError`. + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a complete message is received. @@ -89,40 +128,45 @@ def get(self, timeout: float | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") + # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True - # If the message_complete event isn't set yet, release the lock to - # allow put() to run and eventually set it. - # Locking with get_in_progress ensures only one thread can get here. - completed = self.message_complete.wait(timeout) + try: + deadline = Deadline(timeout) + + # First frame + frame = self.get_next_frame(deadline.timeout()) + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = self.get_next_frame(deadline.timeout()) + except TimeoutError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.reset_queue(frames) + raise + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) - with self.mutex: + finally: self.get_in_progress = False - # Waiting for a complete message timed out. - if not completed: - raise TimeoutError(f"timed out in {timeout:.1f}s") - - # get() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - joiner: Data = b"" if self.decoder is None else "" - # mypy cannot figure out that chunks have the proper type. - message: Data = joiner.join(self.chunks) # type: ignore + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data - self.chunks = [] - assert self.chunks_queue is None - - assert not self.message_fetched.is_set() - self.message_fetched.set() - - return message - - def get_iter(self) -> Iterator[Data]: + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. @@ -135,10 +179,15 @@ def get_iter(self) -> Iterator[Data]: This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ with self.mutex: @@ -148,116 +197,81 @@ def get_iter(self) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - chunks = self.chunks - self.chunks = [] - self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.10. - "queue.SimpleQueue[Data | None]", - queue.SimpleQueue(), - ) - - # Sending None in chunk_queue supersedes setting message_complete - # when switching to "streaming". If message is already complete - # when the switch happens, put() didn't send None, so we have to. - if self.message_complete.is_set(): - self.chunks_queue.put(None) - + # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # Locking with get_in_progress ensures only one thread can get here. - chunk: Data | None - for chunk in chunks: - yield chunk - while (chunk := self.chunks_queue.get()) is not None: - yield chunk + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. - with self.mutex: - self.get_in_progress = False + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. - # get_iter() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - assert self.chunks == [] - self.chunks_queue = None + # First frame + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data - assert not self.message_fetched.is_set() - self.message_fetched.set() + self.get_in_progress = False def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. - When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, which can be achieved by calling :meth:`get` or - by fully consuming the return value of :meth:`get_iter`. - - :meth:`put` assumes that the stream of frames respects the protocol. If - it doesn't, the behavior is undefined. - Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`put` concurrently. """ with self.mutex: if self.closed: raise EOFError("stream of frames ended") - if self.put_in_progress: - raise ConcurrencyError("put is already running") - - if frame.opcode is OP_TEXT: - self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is OP_BINARY: - self.decoder = None - else: - assert frame.opcode is OP_CONT - - data: Data - if self.decoder is not None: - data = self.decoder.decode(frame.data, frame.fin) - else: - data = frame.data - - if self.chunks_queue is None: - self.chunks.append(data) - else: - self.chunks_queue.put(data) - - if not frame.fin: - return - - # Message is complete. Wait until it's fetched to return. - - assert not self.message_complete.is_set() - self.message_complete.set() - - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - assert not self.message_fetched.is_set() - - self.put_in_progress = True - - # Release the lock to allow get() to run and eventually set the event. - # Locking with put_in_progress ensures only one coroutine can get here. - self.message_fetched.wait() - - with self.mutex: - self.put_in_progress = False - - # put() was unblocked by close() rather than get() or get_iter(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_fetched.is_set() - self.message_fetched.clear() - - self.decoder = None + self.frames.put(frame) + self.maybe_pause() + + # put() and get/get_iter() call maybe_pause() and maybe_resume() while + # holding self.mutex. This guarantees that the calls interleave properly. + # Specifically, it prevents a race condition where maybe_resume() would + # run before maybe_pause(), leaving the connection incorrectly paused. + + # A race condition is possible when get/get_iter() call self.frames.get() + # without holding self.mutex. However, it's harmless — and even beneficial! + # It can only result in popping an item from the queue before maybe_resume() + # runs and skipping a pause() - resume() cycle that would otherwise occur. + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + assert self.mutex.locked() + # Check for "> high" to support high = 0 + if self.frames.qsize() > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + assert self.mutex.locked() + # Check for "<= low" to support low = 0 + if self.frames.qsize() <= self.low and self.paused: + self.paused = False + self.resume() def close(self) -> None: """ @@ -273,12 +287,5 @@ def close(self) -> None: self.closed = True - # Unblock get or get_iter. - if self.get_in_progress: - self.message_complete.set() - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - # Unblock put(). - if self.put_in_progress: - self.message_fetched.set() + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 464c4a173..94f76b658 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -51,10 +51,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`serve`. + Args: socket: Socket connected to a WebSocket client. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. """ @@ -64,6 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -71,6 +74,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) self.username: str # see basic_auth() @@ -349,6 +353,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -427,6 +432,10 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -548,6 +557,7 @@ def protocol_select_subprotocol( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: sock.close() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 563cf2b17..12e2bd5fa 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,8 +793,8 @@ async def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) async def test_close_does_not_wait_for_recv(self): - # The asyncio implementation has a buffer for incoming messages. Closing - # the connection discards buffered messages. This is allowed by the RFC: + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: # > However, there is no guarantee that the endpoint that has already # > sent a Close frame will continue to process data. await self.remote_connection.send("😀") @@ -1075,7 +1075,10 @@ async def test_max_queue(self): async def test_max_queue_tuple(self): """max_queue parameter configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + connection = Connection( + Protocol(self.LOCAL), + max_queue=(4, 2), + ) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) @@ -1083,14 +1086,20 @@ async def test_max_queue_tuple(self): async def test_write_limit(self): """write_limit parameter configures high-water mark of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=4096) + connection = Connection( + Protocol(self.LOCAL), + write_limit=4096, + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) async def test_write_limits(self): """write_limit parameter configures high and low-water marks of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + connection = Connection( + Protocol(self.LOCAL), + write_limit=(4096, 2048), + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index d2cf25c9c..2ff929d3a 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -350,7 +350,7 @@ async def test_cancel_get_iter_before_first_frame(self): self.assertEqual(fragments, ["café"]) async def test_cancel_get_iter_after_first_frame(self): - """get cannot be canceled after reading the first frame.""" + """get_iter cannot be canceled after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -429,7 +429,7 @@ async def test_get_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" @@ -437,7 +437,7 @@ async def test_get_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" @@ -445,7 +445,7 @@ async def test_get_iter_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" @@ -453,7 +453,7 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate # Test setting limits @@ -463,7 +463,7 @@ async def test_set_high_water_mark(self): self.assertEqual(assembler.high, 10) async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water mark and low-water mark.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 87333fd35..db1cc8e93 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -154,6 +154,16 @@ def test_recv_binary(self): self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(decode=False), "😀".encode()) + + def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual(self.connection.recv(decode=True), "😀") + def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -228,6 +238,22 @@ def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual( + list(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual( + list(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -499,28 +525,17 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_waits_for_recv(self): - # The sync implementation doesn't have a buffer for incoming messsages. - # It requires reading incoming frames until the close frame is reached. - # This behavior — close() blocks until recv() is called — is less than - # ideal and inconsistent with the asyncio implementation. + def test_close_does_not_wait_for_recv(self): + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. self.remote_connection.send("😀") + self.connection.close() close_thread = threading.Thread(target=self.connection.close) close_thread.start() - # Let close() initiate the closing handshake and send a close frame. - time.sleep(MS) - self.assertTrue(close_thread.is_alive()) - - # Connection isn't closed yet. - self.connection.recv() - - # Let close() receive a close frame and finish the closing handshake. - time.sleep(MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -528,24 +543,6 @@ def test_close_waits_for_recv(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) - def test_close_timeout_waiting_for_recv(self): - self.remote_connection.send("😀") - - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - - # Let close() time out during the closing handshake. - time.sleep(3 * MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.recv() - - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") - self.assertIsInstance(exc.__cause__, TimeoutError) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -724,6 +721,45 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test parameters. + + def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) + self.assertEqual(connection.close_timeout, 42 * MS) + + def test_max_queue(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + + def test_max_queue_tuple(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + # Test attributes. def test_id(self): diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index d44b39b88..02513894a 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,4 +1,6 @@ import time +import unittest +import unittest.mock from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -9,66 +11,23 @@ class AssemblerTests(ThreadTestCase): - """ - Tests in this class interact a lot with hidden synchronization mechanisms: - - - get() / get_iter() and put() must run in separate threads when a final - frame is set because put() waits for get() / get_iter() to fetch the - message before returning. - - - run_in_thread() lets its target run before yielding back control on entry, - which guarantees the intended execution order of test cases. - - - run_in_thread() waits for its target to finish running before yielding - back control on exit, which allows making assertions immediately. - - - When the main thread performs actions that let another thread progress, it - must wait before making assertions, to avoid depending on scheduling. - - """ - def setUp(self): - self.assembler = Assembler() - - def tearDown(self): - """ - Check that the assembler goes back to its default state after each test. - - This removes the need for testing various sequences. - - """ - self.assertFalse(self.assembler.mutex.locked()) - self.assertFalse(self.assembler.get_in_progress) - self.assertFalse(self.assembler.put_in_progress) - if not self.assembler.closed: - self.assertFalse(self.assembler.message_complete.is_set()) - self.assertFalse(self.assembler.message_fetched.is_set()) - self.assertIsNone(self.assembler.decoder) - self.assertEqual(self.assembler.chunks, []) - self.assertIsNone(self.assembler.chunks_queue) + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get def test_get_text_message_already_received(self): """get returns a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_binary_message_already_received(self): """get returns a binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_text_message_not_received_yet(self): @@ -99,112 +58,145 @@ def getter(): def test_get_fragmented_text_message_already_received(self): """get reassembles a fragmented a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_fragmented_binary_message_already_received(self): """get reassembles a fragmented binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = self.assembler.get() self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_being_received(self): - """get reassembles a fragmented text message that is partially received.""" + def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_being_received(self): - """get reassembles a fragmented binary message that is partially received.""" + def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_not_received_yet(self): - """get reassembles a fragmented text message when it is received.""" + def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_not_received_yet(self): - """get reassembles a fragmented binary message when it is received.""" + def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - # Test get_iter + def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + def test_get_timeout_before_first_frame(self): + """get times out before reading the first frame.""" + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_text_message_already_received(self): - """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_timeout_after_first_frame(self): + """get times out after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(fragments, ["café"]) + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_binary_message_already_received(self): - """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + # Test get_iter + def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"tea"]) def test_get_iter_text_message_not_received_yet(self): @@ -212,6 +204,7 @@ def test_get_iter_text_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -225,6 +218,7 @@ def test_get_iter_binary_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -235,121 +229,112 @@ def getter(): def test_get_iter_fragmented_text_message_already_received(self): """get_iter yields a fragmented text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, ["ca", "f", "é"]) def test_get_iter_fragmented_binary_message_already_received(self): """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") - self.assertEqual(fragments, [b"t", b"e", b"a"]) + def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - self.assertEqual(fragments, ["ca", "f", "é"]) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) - - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - def test_get_iter_fragmented_text_message_not_received_yet(self): - """get_iter yields a fragmented text message when it is received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") + + def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) + def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) - self.assertEqual(fragments, ["ca", "f", "é"]) + def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) - def test_get_iter_fragmented_binary_message_not_received_yet(self): - """get_iter yields a fragmented binary message when it is received.""" - fragments = [] + iterator = self.assembler.get_iter() - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + # queue is above the low-water mark + next(iterator) + self.resume.assert_not_called() - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) + # queue is at the low-water mark + next(iterator) + self.resume.assert_called_once_with() - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - # Test timeouts + # queue is below the low-water mark + next(iterator) + self.resume.assert_called_once_with() - def test_get_with_timeout_completes(self): - """get returns a message when it is received before the timeout.""" + # Test put - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get(MS) + def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() - self.assertEqual(message, "café") + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() - def test_get_with_timeout_times_out(self): - """get raises TimeoutError when no message is received before the timeout.""" - with self.assertRaises(TimeoutError): - self.assembler.get(MS) + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() # Test termination @@ -373,18 +358,8 @@ def closer(): with self.run_in_thread(closer): with self.assertRaises(EOFError): - list(self.assembler.get_iter()) - - def test_put_fails_when_interrupted_by_close(self): - """put raises EOFError when close is called.""" - - def closer(): - time.sleep(2 * MS) - self.assembler.close() - - with self.run_in_thread(closer): - with self.assertRaises(EOFError): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_get_fails_after_close(self): """get raises EOFError after close is called.""" @@ -396,7 +371,8 @@ def test_get_iter_fails_after_close(self): """get_iter raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): - list(self.assembler.get_iter()) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_put_fails_after_close(self): """put raises EOFError after close is called.""" @@ -439,13 +415,25 @@ def test_get_iter_fails_when_get_iter_is_running(self): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently.""" + # Test setting limits - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + def test_set_high_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) - with self.run_in_thread(putter): - with self.assertRaises(ConcurrencyError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread + def test_set_high_and_low_water_mark(self): + """high sets the high-water mark and low-water mark.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) From 6c9e3f48ff48dd4ad025f307e68d1be3c0687b4e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 14:16:06 +0200 Subject: [PATCH 1427/1539] Update feature list for encode/decode params. --- docs/reference/features.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 8b04034eb..576ea1025 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -43,12 +43,18 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Receive a fragmented message frame | ✅ | ✅ | ✅ | ❌ | + | Receive a fragmented message frame | ✅ | ✅ | — | ❌ | | by frame | | | | | +------------------------------------+--------+--------+--------+--------+ | Receive a fragmented message after | ✅ | ✅ | — | ✅ | | reassembly | | | | | +------------------------------------+--------+--------+--------+--------+ + | Force sending a message as Text or | ✅ | ✅ | — | ❌ | + | Binary | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Force receiving a message as | ✅ | ✅ | — | ❌ | + | :class:`bytes` or :class:`str` | | | | | + +------------------------------------+--------+--------+--------+--------+ | Send a ping | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | From 8315d3cbd356fb2fbbe3fd03e189e3750a4d0399 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 16:18:30 +0200 Subject: [PATCH 1428/1539] Close the connection with code 1007 on invalid UTF-8. Fix #1523. --- docs/reference/features.rst | 5 +++++ src/websockets/asyncio/connection.py | 33 +++++++++++++++++++++++----- src/websockets/asyncio/messages.py | 2 ++ src/websockets/frames.py | 1 - src/websockets/sync/connection.py | 33 +++++++++++++++++++++++----- src/websockets/sync/messages.py | 2 ++ tests/asyncio/test_connection.py | 18 +++++++++++++++ tests/sync/test_connection.py | 18 +++++++++++++++ 8 files changed, 99 insertions(+), 13 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 576ea1025..9187fa505 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -194,3 +194,8 @@ connection to a given IP address in a CONNECTING state. This behavior is mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. + +It is possible to send or receive a text message containing invalid UTF-8 with +``send(not_utf8_bytes, text=True)`` and ``not_utf8_bytes = recv(decode=False)`` +respectively. As a side effect of disabling UTF-8 encoding and decoding, these +options also disable UTF-8 validation. diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 3b81e386b..2568249c7 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -268,14 +268,24 @@ async def recv(self, decode: bool | None = None) -> Data: try: return await self.recv_messages.get(decode) except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - await asyncio.shield(self.connection_lost_waiter) - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another coroutine " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ @@ -324,15 +334,26 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data try: async for frame in self.recv_messages.get_iter(decode): yield frame + return except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - await asyncio.shield(self.connection_lost_waiter) - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another coroutine " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc async def send( self, diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 09be22ba2..b57c0ca4e 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -131,6 +131,7 @@ async def get(self, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. @@ -197,6 +198,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 5fadf3c2d..0ff9f4d71 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -222,7 +222,6 @@ def parse( Raises: EOFError: If the connection is closed without a full WebSocket frame. - UnicodeDecodeError: If the frame contains invalid UTF-8. PayloadTooBig: If the frame's payload size exceeds ``max_size``. ProtocolError: If the frame contains incorrect values. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 3ab9f4937..823b44f74 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -233,14 +233,24 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data try: return self.recv_messages.get(timeout, decode) except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - self.recv_events_thread.join() - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another thread " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ @@ -283,15 +293,26 @@ def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ try: yield from self.recv_messages.get_iter(decode) + return except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - self.recv_events_thread.join() - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another thread " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc def send( self, diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 983b114dc..17f8dce7e 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -115,6 +115,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a @@ -186,6 +187,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 12e2bd5fa..a3b65e956 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -222,6 +222,15 @@ async def test_recv_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): await self.connection.recv() + async def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + async def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_task = asyncio.create_task(self.connection.recv()) @@ -352,6 +361,15 @@ async def test_recv_streaming_connection_closed_error(self): async for _ in self.connection.recv_streaming(): self.fail("did not raise") + async def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await alist(self.connection.recv_streaming()) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + async def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_task = asyncio.create_task(self.connection.recv()) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index db1cc8e93..abdfd3f78 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -186,6 +186,15 @@ def test_recv_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): self.connection.recv() + def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + self.connection.recv() + self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.connection.recv) @@ -286,6 +295,15 @@ def test_recv_streaming_connection_closed_error(self): for _ in self.connection.recv_streaming(): self.fail("did not raise") + def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + list(self.connection.recv_streaming()) + self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_thread = threading.Thread(target=self.connection.recv) From 0d2e246f4ee44ece6f74066cfe608e0df906e312 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 16:36:51 +0200 Subject: [PATCH 1429/1539] Various fixes for the compliance test suite. --- compliance/README.rst | 17 ++++++++++++----- compliance/asyncio/client.py | 21 ++++++++++++++------- compliance/asyncio/server.py | 8 ++++++-- compliance/config/fuzzingclient.json | 2 +- compliance/config/fuzzingserver.json | 2 +- compliance/sync/client.py | 21 ++++++++++++++------- compliance/sync/server.py | 8 ++++++-- 7 files changed, 54 insertions(+), 25 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index c7c7c93b4..ee491310f 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -38,7 +38,7 @@ To test the servers: crossbario/autobahn-testsuite \ wstest --mode fuzzingclient --spec /config/fuzzingclient.json - $ open reports/servers/index.html + $ open compliance/reports/servers/index.html To test the clients: @@ -54,7 +54,7 @@ To test the clients: $ PYTHONPATH=src python compliance/asyncio/client.py $ PYTHONPATH=src python compliance/sync/client.py - $ open reports/clients/index.html + $ open compliance/reports/clients/index.html Conformance notes ----------------- @@ -67,6 +67,13 @@ In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 websockets notices the protocol error and closes the connection at the library level before the application gets a chance to echo the previous frame. -In 6.4.3 and 6.4.4, even though it uses an incremental decoder, websockets -doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests -are more strict than the RFC. +In 6.4.1, 6.4.2, 6.4.3, and 6.4.4, even though it uses an incremental decoder, +websockets doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. +These tests are more strict than the RFC. + +Test case 7.1.5 fails because websockets treats closing the connection in the +middle of a fragmented message as a protocol error. As a consequence, it sends +a close frame with code 1002. The test suite expects a close frame with code +1000, echoing the close code that it sent. This isn't required. RFC 6455 states +that "the endpoint typically echos the status code it received", which leaves +the possibility to send a close frame with a different status code. diff --git a/compliance/asyncio/client.py b/compliance/asyncio/client.py index 5b0bfb3ae..044ed6043 100644 --- a/compliance/asyncio/client.py +++ b/compliance/asyncio/client.py @@ -8,7 +8,9 @@ logging.basicConfig(level=logging.WARNING) -SERVER = "ws://127.0.0.1:9001" +SERVER = "ws://localhost:9001" + +AGENT = "websockets.asyncio" async def get_case_count(): @@ -18,16 +20,21 @@ async def get_case_count(): async def run_case(case): async with connect( - f"{SERVER}/runCase?case={case}", - user_agent_header="websockets.asyncio", + f"{SERVER}/runCase?case={case}&agent={AGENT}", max_size=2**25, ) as ws: - async for msg in ws: - await ws.send(msg) + try: + async for msg in ws: + await ws.send(msg) + except WebSocketException: + pass async def update_reports(): - async with connect(f"{SERVER}/updateReports", open_timeout=60): + async with connect( + f"{SERVER}/updateReports?agent={AGENT}", + open_timeout=60, + ): pass @@ -43,7 +50,7 @@ async def main(): print(f"FAIL: {type(exc).__name__}: {exc}") else: print("OK") - print("Ran {cases} test cases") + print(f"Ran {cases} test cases") await update_reports() print("Updated reports") diff --git a/compliance/asyncio/server.py b/compliance/asyncio/server.py index cff2728c9..84deb9727 100644 --- a/compliance/asyncio/server.py +++ b/compliance/asyncio/server.py @@ -2,6 +2,7 @@ import logging from websockets.asyncio.server import serve +from websockets.exceptions import WebSocketException logging.basicConfig(level=logging.WARNING) @@ -10,8 +11,11 @@ async def echo(ws): - async for msg in ws: - await ws.send(msg) + try: + async for msg in ws: + await ws.send(msg) + except WebSocketException: + pass async def main(): diff --git a/compliance/config/fuzzingclient.json b/compliance/config/fuzzingclient.json index 37f8f3a35..756ad03b6 100644 --- a/compliance/config/fuzzingclient.json +++ b/compliance/config/fuzzingclient.json @@ -5,7 +5,7 @@ }, { "url": "ws://host.docker.internal:9003" }], - "outdir": "./reports/servers", + "outdir": "/reports/servers", "cases": ["*"], "exclude-cases": [] } diff --git a/compliance/config/fuzzingserver.json b/compliance/config/fuzzingserver.json index 1bcb5959d..384caf0a2 100644 --- a/compliance/config/fuzzingserver.json +++ b/compliance/config/fuzzingserver.json @@ -1,7 +1,7 @@ { "url": "ws://localhost:9001", - "outdir": "./reports/clients", + "outdir": "/reports/clients", "cases": ["*"], "exclude-cases": [] } diff --git a/compliance/sync/client.py b/compliance/sync/client.py index e585496f3..c810e1beb 100644 --- a/compliance/sync/client.py +++ b/compliance/sync/client.py @@ -7,7 +7,9 @@ logging.basicConfig(level=logging.WARNING) -SERVER = "ws://127.0.0.1:9001" +SERVER = "ws://localhost:9001" + +AGENT = "websockets.sync" def get_case_count(): @@ -17,16 +19,21 @@ def get_case_count(): def run_case(case): with connect( - f"{SERVER}/runCase?case={case}", - user_agent_header="websockets.sync", + f"{SERVER}/runCase?case={case}&agent={AGENT}", max_size=2**25, ) as ws: - for msg in ws: - ws.send(msg) + try: + for msg in ws: + ws.send(msg) + except WebSocketException: + pass def update_reports(): - with connect(f"{SERVER}/updateReports", open_timeout=60): + with connect( + f"{SERVER}/updateReports?agent={AGENT}", + open_timeout=60, + ): pass @@ -42,7 +49,7 @@ def main(): print(f"FAIL: {type(exc).__name__}: {exc}") else: print("OK") - print("Ran {cases} test cases") + print(f"Ran {cases} test cases") update_reports() print("Updated reports") diff --git a/compliance/sync/server.py b/compliance/sync/server.py index c3cb4d989..494f56a44 100644 --- a/compliance/sync/server.py +++ b/compliance/sync/server.py @@ -1,5 +1,6 @@ import logging +from websockets.exceptions import WebSocketException from websockets.sync.server import serve @@ -9,8 +10,11 @@ def echo(ws): - for msg in ws: - ws.send(msg) + try: + for msg in ws: + ws.send(msg) + except WebSocketException: + pass def main(): From 6cea05e51d50455d66e90a1888aba9be8e8809db Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 08:06:21 +0200 Subject: [PATCH 1430/1539] Support HTTP response without Content-Length. Fix #1531. --- src/websockets/asyncio/connection.py | 13 +++++++++-- src/websockets/legacy/exceptions.py | 2 +- src/websockets/sync/connection.py | 13 +++++++++-- tests/asyncio/test_client.py | 30 +++++++++++++++++++++++++ tests/sync/test_client.py | 33 +++++++++++++++++++++++++++- 5 files changed, 85 insertions(+), 6 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 2568249c7..5545632d6 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -1060,13 +1060,22 @@ def eof_received(self) -> None: # Feed the end of the data stream to the connection. self.protocol.receive_eof() - # This isn't expected to generate events. - assert not self.protocol.events_received() + # This isn't expected to raise an exception. + events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it shouldn't raise errors. self.send_data() + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + # The WebSocket protocol has its own closing handshake: endpoints close # the TCP or TLS connection after sending and receiving a close frame. # As a consequence, they never need to write after receiving EOF, so diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 9ca9b7aff..e2279c825 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -50,7 +50,7 @@ def __init__( headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: - # If a user passes an int instead of a HTTPStatus, fix it automatically. + # If a user passes an int instead of an HTTPStatus, fix it automatically. self.status = http.HTTPStatus(status) self.headers = datastructures.Headers(headers) self.body = body diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 823b44f74..8d1dbcf58 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -696,13 +696,22 @@ def recv_events(self) -> None: # Feed the end of the data stream to the protocol. self.protocol.receive_eof() - # This isn't expected to generate events. - assert not self.protocol.events_received() + # This isn't expected to raise an exception. + events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it handles errors itself. self.send_data() + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + except Exception as exc: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 9354a6e0a..1b89977ea 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -401,6 +401,36 @@ def close_connection(self, request): "connection closed while reading HTTP status line", ) + async def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + async with serve(*args, process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + async with serve(*args, process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + async def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e63d774b7..e9b0f63ad 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,3 +1,4 @@ +import http import logging import socket import socketserver @@ -6,7 +7,7 @@ import time import unittest -from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -156,6 +157,36 @@ def close_connection(self, request): "connection closed while reading HTTP status line", ) + def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" From c75b1df159d83e46a2ae29069ae92789690ed22f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 09:22:17 +0200 Subject: [PATCH 1431/1539] Mention the option to keep the legacy implementation. --- docs/project/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 410671239..1b3b0073c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,6 +52,12 @@ Backwards-incompatible changes If you're using any of them, then you must follow the :doc:`upgrade guide <../howto/upgrade>` immediately. + Alternatively, you may stick to the legacy :mod:`asyncio` implementation for + now by importing it explicitly:: + + from websockets.legacy.client import connect, unix_connect + from websockets.legacy.server import broadcast, serve, unix_serve + .. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. :class: caution From d248160b6b509c5701fd749cd2ef103d244c7631 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 09:22:52 +0200 Subject: [PATCH 1432/1539] Raise ValueError for required or unacceptable arguments. This improves consistency with asyncio and within websockets. --- docs/project/changelog.rst | 13 +++++++++++++ src/websockets/asyncio/client.py | 4 ++-- src/websockets/asyncio/server.py | 4 +++- src/websockets/sync/client.py | 6 +++--- src/websockets/sync/server.py | 8 +++++--- tests/asyncio/test_client.py | 4 ++-- tests/asyncio/test_server.py | 4 ++-- tests/sync/test_client.py | 6 +++--- tests/sync/test_server.py | 8 ++++---- 9 files changed, 37 insertions(+), 20 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1b3b0073c..576b7252e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -76,6 +76,19 @@ Backwards-incompatible changes If you wrote an :class:`extension ` that relies on methods not provided by these new types, you may need to update your code. +.. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` + on invalid arguments. + :class: note + + :func:`~asyncio.client.connect`, :func:`~asyncio.client.unix_connect`, and + :func:`~asyncio.server.basic_auth` in the :mod:`asyncio` implementation as + well as :func:`~sync.client.connect`, :func:`~sync.client.unix_connect`, + :func:`~sync.server.serve`, :func:`~sync.server.unix_serve`, and + :func:`~sync.server.basic_auth` in the :mod:`threading` implementation now + raise :exc:`ValueError` when a required argument isn't provided or an + argument that is incompatible with others is provided. + + New features ............ diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 0c8bedc5d..ff7916d39 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -352,10 +352,10 @@ def factory() -> ClientConnection: kwargs.setdefault("ssl", True) kwargs.setdefault("server_hostname", wsuri.host) if kwargs.get("ssl") is None: - raise TypeError("ssl=None is incompatible with a wss:// URI") + raise ValueError("ssl=None is incompatible with a wss:// URI") else: if kwargs.get("ssl") is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + raise ValueError("ssl argument is incompatible with a ws:// URI") if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index a6ae5996d..180d3a5a9 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -890,10 +890,12 @@ def basic_auth( whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. """ if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 42daa32ea..0aada658e 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -221,7 +221,7 @@ def connect( wsuri = parse_uri(uri) if not wsuri.secure and ssl is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + raise ValueError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) @@ -229,9 +229,9 @@ def connect( if unix: if path is None and sock is None: - raise TypeError("missing path argument") + raise ValueError("missing path argument") elif path is not None and sock is not None: - raise TypeError("path and sock arguments are incompatible") + raise ValueError("path and sock arguments are incompatible") if subprotocols is not None: validate_subprotocols(subprotocols) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 94f76b658..44dbd7290 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -477,14 +477,14 @@ def handler(websocket): if sock is None: if unix: if path is None: - raise TypeError("missing path argument") + raise ValueError("missing path argument") kwargs.setdefault("family", socket.AF_UNIX) sock = socket.create_server(path, **kwargs) else: sock = socket.create_server((host, port), **kwargs) else: if path is not None: - raise TypeError("path and sock arguments are incompatible") + raise ValueError("path and sock arguments are incompatible") # Initialize TLS wrapper @@ -667,10 +667,12 @@ def basic_auth( whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. """ if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 1b89977ea..231d6b8ca 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -607,7 +607,7 @@ async def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): async def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: await connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), @@ -616,7 +616,7 @@ async def test_ssl_without_secure_uri(self): async def test_secure_uri_without_ssl(self): """Client rejects no ssl when URI is secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( str(raised.exception), diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 47e0148a6..1dcb8c7b7 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -717,7 +717,7 @@ async def check_credentials(username, password): async def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), @@ -726,7 +726,7 @@ async def test_without_credentials_or_check_credentials(self): async def test_with_credentials_and_check_credentials(self): """basic_auth requires only one of credentials and check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e9b0f63ad..9d457a912 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -311,7 +311,7 @@ def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.TestCase): def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), @@ -320,7 +320,7 @@ def test_ssl_without_secure_uri(self): def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_connect() self.assertEqual( str(raised.exception), @@ -331,7 +331,7 @@ def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_connect(path="/", sock=sock) self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 3bc6f76cd..54e49bf16 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -361,7 +361,7 @@ def test_connection(self): class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_serve(handler) self.assertEqual( str(raised.exception), @@ -372,7 +372,7 @@ def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), @@ -504,7 +504,7 @@ def check_credentials(username, password): def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), @@ -513,7 +513,7 @@ def test_without_credentials_or_check_credentials(self): def test_with_credentials_and_check_credentials(self): """basic_auth requires only one of credentials and check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover From 3b2c5223f7213b130d1469535fccbf9c3d08c4a9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Oct 2024 07:59:16 +0100 Subject: [PATCH 1433/1539] Rewrite tips for Sans-I/O integration. --- docs/howto/sansio.rst | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index d41519ff0..ca530e6a1 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -302,21 +302,24 @@ Tips Serialize operations .................... -The Sans-I/O layer expects to run sequentially. If your interact with it from -multiple threads or coroutines, you must ensure correct serialization. This -should happen automatically in a cooperative multitasking environment. +The Sans-I/O layer is designed to run sequentially. If you interact with it from +multiple threads or coroutines, you must ensure correct serialization. -However, you still have to make sure you don't break this property by -accident. For example, serialize writes to the network -when :meth:`~protocol.Protocol.data_to_send` returns multiple values to -prevent concurrent writes from interleaving incorrectly. +Usually, this comes for free in a cooperative multitasking environment. In a +preemptive multitasking environment, it requires mutual exclusion. -Avoid buffers -............. +Furthermore, you must serialize writes to the network. When +:meth:`~protocol.Protocol.data_to_send` returns several values, you must write +them all before starting the next write. -The Sans-I/O layer doesn't do any buffering. It makes events available in +Minimize buffers +................ + +The Sans-I/O layer doesn't perform any buffering. It makes events available in :meth:`~protocol.Protocol.events_received` as soon as they're received. -You should make incoming messages available to the application immediately and -stop further processing until the application fetches them. This will usually -result in the best performance. +You should make incoming messages available to the application immediately. + +A small buffer of incoming messages will usually result in the best performance. +It will reduce context switching between the library and the application while +ensuring that backpressure is propagated. From 018d2e5cf56ff03690bcbf271d188e76d59c62f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Oct 2024 07:59:35 +0100 Subject: [PATCH 1434/1539] Explain application-level keepalives. Fix #1514. --- docs/topics/keepalive.rst | 40 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 458fa3d05..003087fad 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -90,6 +90,46 @@ application layer. Read this `blog post `_ for a complete walk-through of this issue. +Application-level keepalive +--------------------------- + +Some servers require clients to send a keepalive message with a specific content +at regular intervals. Usually they expect Text_ frames rather than Ping_ frames, +meaning that you must send them with :attr:`~asyncio.connection.Connection.send` +rather than :attr:`~asyncio.connection.Connection.ping`. + +.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + +In websockets, such keepalive mechanisms are considered as application-level +because they rely on data frames. That's unlike the protocol-level keepalive +based on control frames. Therefore, it's your responsibility to implement the +required behavior. + +You can run a task in the background to send keepalive messages: + +.. code-block:: python + + import itertools + import json + + from websockets import ConnectionClosed + + async def keepalive(websocket, ping_interval=30): + for ping in itertools.count(): + await asyncio.sleep(ping_interval) + try: + await websocket.send(json.dumps({"ping": ping})) + except ConnectionClosed: + break + + async def main(): + async with connect(...) as websocket: + keepalive_task = asyncio.create_task(keepalive(websocket)) + try: + ... # your application logic goes here + finally: + keepalive_task.cancel() + Latency issues -------------- From e44d79559d16661df4709c1b54150d735f85ae54 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:52:16 +0000 Subject: [PATCH 1435/1539] Bump pypa/cibuildwheel from 2.20.0 to 2.21.3 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.20.0 to 2.21.3. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.20.0...v2.21.3) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 863d88aa9..cc26502ca 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.20.0 + uses: pypa/cibuildwheel@v2.21.3 env: BUILD_EXTENSION: yes - name: Save wheels From 810bdeb7943e4cc0dcb662ef27ea764e31740e05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2024 12:35:38 +0100 Subject: [PATCH 1436/1539] Report size correctly in PayloadTooBig. Previously, it was reported incorrectly for fragmented messages. Fix #1522. --- docs/project/changelog.rst | 31 +++++--- src/websockets/exceptions.py | 41 +++++++++++ .../extensions/permessage_deflate.py | 3 +- src/websockets/frames.py | 2 +- src/websockets/legacy/framing.py | 2 +- src/websockets/protocol.py | 5 +- tests/test_exceptions.py | 20 +++++- tests/test_protocol.py | 72 ++++++++++++++----- 8 files changed, 143 insertions(+), 33 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 576b7252e..71b2a6960 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -67,15 +67,6 @@ Backwards-incompatible changes Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. -.. admonition:: :attr:`Frame.data ` is now a bytes-like object. - :class: note - - In addition to :class:`bytes`, it may be a :class:`bytearray` or a - :class:`memoryview`. - - If you wrote an :class:`extension ` that relies on - methods not provided by these new types, you may need to update your code. - .. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` on invalid arguments. :class: note @@ -88,6 +79,26 @@ Backwards-incompatible changes raise :exc:`ValueError` when a required argument isn't provided or an argument that is incompatible with others is provided. +.. admonition:: :attr:`Frame.data ` is now a bytes-like object. + :class: note + + In addition to :class:`bytes`, it may be a :class:`bytearray` or a + :class:`memoryview`. + + If you wrote an :class:`extension ` that relies on + methods not provided by these new types, you may need to update your code. + +.. admonition:: The signature of :exc:`~exceptions.PayloadTooBig` changed. + :class: note + + If you wrote an extension that raises :exc:`~exceptions.PayloadTooBig` in + :meth:`~extensions.Extension.decode`, for example, you must replace:: + + PayloadTooBig(f"over size limit ({size} > {max_size} bytes)") + + with:: + + PayloadTooBig(size, max_size) New features ............ @@ -105,6 +116,8 @@ Improvements * Sending or receiving large compressed messages is now faster. +* Errors when a fragmented message is too large are clearer. + .. _13.1: 13.1 diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 7681736a4..be3d1ca5f 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -334,6 +334,47 @@ class PayloadTooBig(WebSocketException): """ + def __init__( + self, + size_or_message: int | None | str, + max_size: int | None = None, + cur_size: int | None = None, + ) -> None: + if isinstance(size_or_message, str): + assert max_size is None + assert cur_size is None + warnings.warn( # deprecated in 14.0 + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + DeprecationWarning, + ) + self.message: str | None = size_or_message + else: + self.message = None + self.size: int | None = size_or_message + assert max_size is not None + self.max_size: int = max_size + self.cur_size: int | None = None + self.set_current_size(cur_size) + + def __str__(self) -> str: + if self.message is not None: + return self.message + else: + message = "frame " + if self.size is not None: + message += f"with {self.size} bytes " + if self.cur_size is not None: + message += f"after reading {self.cur_size} bytes " + message += f"exceeds limit of {self.max_size} bytes" + return message + + def set_current_size(self, cur_size: int | None) -> None: + assert self.cur_size is None + if cur_size is not None: + self.max_size += cur_size + self.cur_size = cur_size + class InvalidState(WebSocketException, AssertionError): """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index ed16937d8..cefad4f56 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -139,7 +139,8 @@ def decode( try: data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + assert max_size is not None # help mypy + raise PayloadTooBig(None, max_size) if frame.fin and len(frame.data) >= 2044: # This cannot generate additional data. self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 0ff9f4d71..7898c8a5d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -252,7 +252,7 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise PayloadTooBig(length, max_size) if mask: mask_bytes = yield from read_exact(4) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4ec194ed7..add0c6e0e 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -93,7 +93,7 @@ async def read( data = await reader(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise PayloadTooBig(length, max_size) if mask: mask_bits = await reader(4) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 091b4a23a..19b813526 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -587,6 +587,7 @@ def parse(self) -> Generator[None]: self.parser_exc = exc except PayloadTooBig as exc: + exc.set_current_size(self.cur_size) self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) self.parser_exc = exc @@ -639,9 +640,7 @@ def recv_frame(self, frame: Frame) -> None: if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: if self.cur_size is not None: raise ProtocolError("expected a continuation frame") - if frame.fin: - self.cur_size = None - else: + if not frame.fin: self.cur_size = len(frame.data) elif frame.opcode is OP_CONT: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8d41bf915..fef41d136 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -160,8 +160,16 @@ def test_str(self): "invalid opcode: 7", ), ( - PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), - "payload length exceeds limit: 2 > 1 bytes", + PayloadTooBig(None, 4), + "frame exceeds limit of 4 bytes", + ), + ( + PayloadTooBig(8, 4), + "frame with 8 bytes exceeds limit of 4 bytes", + ), + ( + PayloadTooBig(8, 4, 12), + "frame with 8 bytes after reading 12 bytes exceeds limit of 16 bytes", ), ( InvalidState("WebSocket connection isn't established yet"), @@ -202,3 +210,11 @@ def test_connection_closed_attributes_deprecation_defaults(self): "use Protocol.close_reason or ConnectionClosed.rcvd.reason" ): self.assertEqual(exception.reason, "") + + def test_payload_too_big_with_message(self): + with self.assertDeprecationWarning( + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + ): + exc = PayloadTooBig("payload length exceeds limit: 2 > 1 bytes") + self.assertEqual(str(exc), "payload length exceeds limit: 2 > 1 bytes") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7f1276bb2..0ae804bb3 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -265,18 +265,28 @@ def test_client_receives_text_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_receives_text_without_size_limit(self): @@ -363,9 +373,14 @@ def test_client_receives_fragmented_text_over_size_limit(self): ) client.receive_data(b"\x80\x02\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_text_over_size_limit(self): @@ -377,9 +392,14 @@ def test_server_receives_fragmented_text_over_size_limit(self): ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_receives_fragmented_text_without_size_limit(self): @@ -533,18 +553,28 @@ def test_client_receives_binary_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_sends_fragmented_binary(self): @@ -615,9 +645,14 @@ def test_client_receives_fragmented_binary_over_size_limit(self): ) client.receive_data(b"\x80\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_binary_over_size_limit(self): @@ -629,9 +664,14 @@ def test_server_receives_fragmented_binary_over_size_limit(self): ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_sends_unexpected_binary(self): From cdeb882865145399ee0fb7d0e7623418916d6b78 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 08:48:53 +0100 Subject: [PATCH 1437/1539] Don't log an error when process_request returns a response. Fix #1513. --- src/websockets/asyncio/client.py | 6 +- src/websockets/asyncio/server.py | 17 ++-- src/websockets/protocol.py | 29 ++++-- src/websockets/server.py | 6 -- src/websockets/sync/client.py | 6 +- src/websockets/sync/server.py | 11 ++- tests/asyncio/test_connection.py | 2 +- tests/asyncio/test_server.py | 146 ++++++++++++++++++++----------- tests/test_protocol.py | 20 ++++- 9 files changed, 163 insertions(+), 80 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index ff7916d39..d276ac171 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -95,9 +95,9 @@ async def handshake( return_when=asyncio.FIRST_COMPLETED, ) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 180d3a5a9..15c9ba13e 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -192,10 +192,13 @@ async def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -360,7 +363,11 @@ async def conn_handler(self, connection: ServerConnection) -> None: connection.close_transport() return - assert connection.protocol.state is OPEN + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + connection.close_transport() + return + try: connection.start_keepalive() await self.handler(connection) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 19b813526..0f6fea250 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -518,15 +518,34 @@ def close_expected(self) -> bool: Whether the TCP connection is expected to close soon. """ - # We expect a TCP close if and only if we sent a close frame: + # During the opening handshake, when our state is CONNECTING, we expect + # a TCP close if and only if the hansdake fails. When it does, we start + # the TCP closing handshake by sending EOF with send_eof(). + + # Once the opening handshake completes successfully, we expect a TCP + # close if and only if we sent a close frame, meaning that our state + # progressed to CLOSING: + # * Normal closure: once we send a close frame, we expect a TCP close: # server waits for client to complete the TCP closing handshake; # client waits for server to initiate the TCP closing handshake. + # * Abnormal closure: we always send a close frame and the same logic # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. - # We already got a TCP Close if and only if the state is CLOSED. - return self.state is CLOSING or self.handshake_exc is not None + + # If our state is CLOSED, we already received a TCP close so we don't + # expect it anymore. + + # Micro-optimization: put the most common case first + if self.state is OPEN: + return False + if self.state is CLOSING: + return True + if self.state is CLOSED: + return False + assert self.state is CONNECTING + return self.eof_sent # Private methods for receiving data. @@ -616,14 +635,14 @@ def discard(self) -> Generator[None]: # connection in the same circumstances where discard() replaces parse(). # The client closes it when it receives EOF from the server or times # out. (The latter case cannot be handled in this Sans-I/O layer.) - assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent) + assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") # A server closes the TCP connection immediately, while a client # waits for the server to close the TCP connection. - if self.state != CONNECTING and self.side is CLIENT: + if self.side is CLIENT and self.state is not CONNECTING: self.send_eof() self.state = CLOSED # If discard() completes normally, execution ends here. diff --git a/src/websockets/server.py b/src/websockets/server.py index 527db8990..e3fdcc646 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -14,7 +14,6 @@ InvalidHeader, InvalidHeaderValue, InvalidOrigin, - InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -536,11 +535,6 @@ def send_response(self, response: Response) -> None: self.logger.info("connection open") else: - # handshake_exc may be already set if accept() encountered an error. - # If the connection isn't open, set handshake_exc to guarantee that - # handshake_exc is None if and only if opening handshake succeeded. - if self.handshake_exc is None: - self.handshake_exc = InvalidStatus(response) self.logger.info( "connection rejected (%d %s)", response.status_code, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 0aada658e..54d0aef68 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -87,9 +87,9 @@ def handshake( if not self.response_rcvd.wait(timeout): raise TimeoutError("timed out during handshake") - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 44dbd7290..8601ccef9 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -170,10 +170,13 @@ def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index a3b65e956..c98765d80 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -51,7 +51,7 @@ async def asyncTearDown(self): if sys.version_info[:2] < (3, 10): # pragma: no cover @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): + def assertNoLogs(self, logger=None, level=None): """ No message is logged on the given logger with at least the given level. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 1dcb8c7b7..c817f5ef6 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -148,14 +148,17 @@ def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" @@ -166,44 +169,65 @@ async def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_request_raises_exception(self): """Server returns an error if async process_request raises an exception.""" async def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_process_response_returns_none(self): """Server runs process_response but keeps the handshake response.""" @@ -277,31 +301,49 @@ async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_response_raises_exception(self): """Server returns an error if async process_response raises an exception.""" async def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_override_server(self): """Server can override Server header with server_header.""" diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 0ae804bb3..1c092459d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -20,7 +20,7 @@ Frame, ) from websockets.protocol import * -from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER +from websockets.protocol import CLIENT, CLOSED, CLOSING, CONNECTING, SERVER from .extensions.utils import Rsv2Extension from .test_frames import FramesTestCase @@ -1696,6 +1696,24 @@ def test_server_fails_connection(self): server.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(server.close_expected()) + def test_client_is_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + self.assertFalse(client.close_expected()) + + def test_server_is_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + self.assertFalse(server.close_expected()) + + def test_client_failed_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + client.send_eof() + self.assertTrue(client.close_expected()) + + def test_server_failed_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + server.send_eof() + self.assertTrue(server.close_expected()) + class ConnectionClosedTests(ProtocolTestCase): """ From 76f6f573e2ecb279230c2bf56c07bf4d4f717147 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 09:11:51 +0100 Subject: [PATCH 1438/1539] Factor out backport of assertNoLogs. Fix previous commit on Python 3.9. --- tests/asyncio/test_connection.py | 25 ++++--------------------- tests/asyncio/test_server.py | 3 ++- tests/legacy/test_protocol.py | 12 ++++++------ tests/legacy/utils.py | 23 +++-------------------- tests/utils.py | 26 ++++++++++++++++++++++++++ 5 files changed, 41 insertions(+), 48 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index c98765d80..d61798afb 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -19,7 +19,7 @@ from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol -from ..utils import MS +from ..utils import MS, AssertNoLogsMixin from .connection import InterceptingConnection from .utils import alist @@ -28,7 +28,7 @@ # All tests run on the client side and the server side to validate this. -class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): +class ClientConnectionTests(AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): LOCAL = CLIENT REMOTE = SERVER @@ -48,23 +48,6 @@ async def asyncTearDown(self): await self.remote_connection.close() await self.connection.close() - if sys.version_info[:2] < (3, 10): # pragma: no cover - - @contextlib.contextmanager - def assertNoLogs(self, logger=None, level=None): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - # Test helpers built upon RecordingProtocol and InterceptingConnection. async def assertFrameSent(self, frame): @@ -1277,7 +1260,7 @@ async def test_broadcast_skips_closed_connection(self): await self.connection.close() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() @@ -1288,7 +1271,7 @@ async def test_broadcast_skips_closing_connection(self): await asyncio.sleep(0) await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index c817f5ef6..3e289e592 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -21,6 +21,7 @@ CLIENT_CONTEXT, MS, SERVER_CONTEXT, + AssertNoLogsMixin, temp_unix_socket_path, ) from .server import ( @@ -32,7 +33,7 @@ ) -class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): +class ServerTests(EvalShellMixin, AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" async with serve(*args) as server: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index de2a320b5..be2910a8f 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -938,7 +938,7 @@ def test_answer_ping_does_not_crash_if_connection_closing(self): self.receive_frame(Frame(True, OP_PING, b"test")) self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close()) self.loop.run_until_complete(close_task) # cleanup @@ -951,7 +951,7 @@ def test_answer_ping_does_not_crash_if_connection_closed(self): self.receive_eof() self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close()) def test_ignore_pong(self): @@ -1028,7 +1028,7 @@ def test_acknowledge_aborted_ping(self): pong_waiter.result() # transfer_data doesn't crash, which would be logged. - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): # Unclog incoming queue. self.loop.run_until_complete(self.protocol.recv()) self.loop.run_until_complete(self.protocol.recv()) @@ -1375,7 +1375,7 @@ def test_remote_close_and_connection_lost(self): self.receive_eof() self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") @@ -1500,14 +1500,14 @@ def test_broadcast_two_clients(self): def test_broadcast_skips_closed_connection(self): self.close_connection() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() def test_broadcast_skips_closing_connection(self): close_task = self.half_close_connection_local() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 5b56050d5..1f79bb600 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -1,12 +1,12 @@ import asyncio -import contextlib import functools -import logging import sys import unittest +from ..utils import AssertNoLogsMixin -class AsyncioTestCase(unittest.TestCase): + +class AsyncioTestCase(AssertNoLogsMixin, unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. @@ -56,23 +56,6 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - if sys.version_info[:2] < (3, 10): # pragma: no cover - - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): """ Check recorded deprecation warnings match a list of expected messages. diff --git a/tests/utils.py b/tests/utils.py index 639fb7fe5..77d020726 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,11 @@ import contextlib import email.utils +import logging import os import pathlib import platform import ssl +import sys import tempfile import time import unittest @@ -113,6 +115,30 @@ def assertDeprecationWarning(self, message): self.assertEqual(str(warning.message), message) +class AssertNoLogsMixin: + """ + Backport of assertNoLogs for Python 3.9. + + """ + + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger=None, level=None): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + + @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: From 0a5a79c224c9be97f79909f7218fd1da7b2acabb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 09:47:51 +0100 Subject: [PATCH 1439/1539] Clean up prefixes of debug log messages. All debug messages and only debug messages should have them. --- docs/topics/logging.rst | 8 ++++++-- src/websockets/asyncio/client.py | 2 +- src/websockets/asyncio/connection.py | 6 +++--- src/websockets/legacy/client.py | 4 ++-- src/websockets/legacy/protocol.py | 8 ++++---- src/websockets/sync/connection.py | 15 ++++++++++++--- tests/legacy/test_client_server.py | 4 ++-- 7 files changed, 30 insertions(+), 17 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index be5678455..ae71be265 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -218,7 +218,10 @@ Here's what websockets logs at each level. ``ERROR`` ......... -* Exceptions raised by connection handler coroutines in servers +* Exceptions raised by your code in servers + * connection handler coroutines + * ``select_subprotocol`` callbacks + * ``process_request`` and ``process_response`` callbacks * Exceptions resulting from bugs in websockets ``WARNING`` @@ -250,4 +253,5 @@ Debug messages have cute prefixes that make logs easier to scan: * ``=`` - set connection state * ``x`` - shut down connection * ``%`` - manage pings and pongs -* ``!`` - handle errors and timeouts +* ``-`` - timeout +* ``!`` - error, with a traceback diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index d276ac171..302d0b94d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -521,7 +521,7 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays = backoff() delay = next(delays) self.logger.info( - "! connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds", delay, exc_info=True, ) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 5545632d6..c4961884c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -785,7 +785,7 @@ async def keepalive(self) -> None: self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for keepalive pong") + self.logger.debug("- timed out waiting for keepalive pong") async with self.send_context(): self.protocol.fail( CloseCode.INTERNAL_ERROR, @@ -866,7 +866,7 @@ async def send_context( await self.drain() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug("! error while sending data", exc_info=True) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False @@ -1042,7 +1042,7 @@ def data_received(self, data: bytes) -> None: self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug("! error while sending data", exc_info=True) self.set_recv_exc(exc) if self.protocol.close_expected(): diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 116445e25..a2dc0250f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -603,14 +603,14 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( - "! connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds", initial_delay, exc_info=True, ) await asyncio.sleep(initial_delay) else: self.logger.info( - "! connect failed again; retrying in %d seconds", + "connect failed again; retrying in %d seconds", int(backoff_delay), exc_info=True, ) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index cedde6200..bd998dfd1 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1246,7 +1246,7 @@ async def keepalive_ping(self) -> None: self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for keepalive pong") + self.logger.debug("- timed out waiting for keepalive pong") self.fail_connection( CloseCode.INTERNAL_ERROR, "keepalive ping timeout", @@ -1288,7 +1288,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): @@ -1306,7 +1306,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") finally: # The try/finally ensures that the transport never remains open, @@ -1332,7 +1332,7 @@ async def close_transport(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") # Abort the TCP connection. Buffers are discarded. if self.debug: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 8d1dbcf58..77f803c9b 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -640,7 +640,10 @@ def recv_events(self) -> None: data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: - self.logger.debug("error while receiving data", exc_info=True) + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. # In that case, send_context() already set recv_exc. @@ -665,7 +668,10 @@ def recv_events(self) -> None: self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # Similarly to the above, avoid overriding an exception # set by send_context(), in case of a race condition # i.e. send_context() closes the socket after recv() @@ -783,7 +789,10 @@ def send_context( self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 502ab68e7..375d47e29 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1615,12 +1615,12 @@ async def run_client(): [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed; reconnecting in X seconds", + "connect failed; reconnecting in X seconds", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed again; retrying in X seconds", + "connect failed again; retrying in X seconds", ] * ((len(logs.records) - 8) // 3) + [ From 9b3595d8d4d0573e00209f9e920f6a6fab981fa9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 4 Nov 2024 22:58:18 +0100 Subject: [PATCH 1440/1539] Remove stack traces from INFO and WARNING logs. Fix #1501. --- src/websockets/asyncio/client.py | 6 ++++-- src/websockets/asyncio/connection.py | 9 +++++++-- src/websockets/legacy/client.py | 13 ++++++++----- src/websockets/legacy/protocol.py | 9 +++++++-- tests/asyncio/test_connection.py | 5 ++++- tests/legacy/test_client_server.py | 8 ++++++-- tests/legacy/test_protocol.py | 4 ++-- 7 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 302d0b94d..b3d50c12e 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -3,6 +3,7 @@ import asyncio import logging import os +import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType @@ -521,9 +522,10 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays = backoff() delay = next(delays) self.logger.info( - "connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds: %s", delay, - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(delay) continue diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index c4961884c..186846ef3 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -7,6 +7,7 @@ import random import struct import sys +import traceback import uuid from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType @@ -1180,8 +1181,12 @@ def broadcast( exceptions.append(exception) else: connection.logger.warning( - "skipped broadcast: failed to write message", - exc_info=True, + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only( + # Remove first argument when dropping Python 3.9. + type(write_exception), + write_exception, + )[0].strip(), ) if raise_exceptions and exceptions: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a2dc0250f..555069e8c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -5,6 +5,7 @@ import logging import os import random +import traceback import urllib.parse import warnings from collections.abc import AsyncIterator, Generator, Sequence @@ -597,22 +598,24 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: try: async with self as protocol: yield protocol - except Exception: + except Exception as exc: # Add a random initial delay between 0 and 5 seconds. # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( - "connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds: %s", initial_delay, - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(initial_delay) else: self.logger.info( - "connect failed again; retrying in %d seconds", + "connect failed again; retrying in %d seconds: %s", int(backoff_delay), - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(int(backoff_delay)) # Increase delay with truncated exponential backoff. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index bd998dfd1..db126c01e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -9,6 +9,7 @@ import struct import sys import time +import traceback import uuid import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping @@ -1624,8 +1625,12 @@ def broadcast( exceptions.append(exception) else: websocket.logger.warning( - "skipped broadcast: failed to write message", - exc_info=True, + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only( + # Remove first argument when dropping Python 3.9. + type(write_exception), + write_exception, + )[0].strip(), ) if raise_exceptions and exceptions: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index d61798afb..902b3b847 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1337,7 +1337,10 @@ async def test_broadcast_skips_connection_failing_to_send(self): self.assertEqual( [record.getMessage() for record in logs.records], - ["skipped broadcast: failed to write message"], + [ + "skipped broadcast: failed to write message: " + "RuntimeError: Cannot call write() after write_eof()" + ], ) @unittest.skipIf( diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 375d47e29..c13c6c92e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1607,6 +1607,10 @@ async def run_client(): ], ) # Iteration 3 + exc = ( + "websockets.legacy.exceptions.InvalidStatusCode: " + "server rejected WebSocket connection: HTTP 503" + ) self.assertEqual( [ re.sub(r"[0-9\.]+ seconds", "X seconds", record.getMessage()) @@ -1615,12 +1619,12 @@ async def run_client(): [ "connection rejected (503 Service Unavailable)", "connection closed", - "connect failed; reconnecting in X seconds", + f"connect failed; reconnecting in X seconds: {exc}", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "connect failed again; retrying in X seconds", + f"connect failed again; retrying in X seconds: {exc}", ] * ((len(logs.records) - 8) // 3) + [ diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index be2910a8f..d30198934 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1547,14 +1547,14 @@ def test_broadcast_reports_connection_sending_fragmented_text(self): def test_broadcast_skips_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. - self.protocol.transport.write.side_effect = RuntimeError + self.protocol.transport.write.side_effect = RuntimeError("BOOM") with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") self.assertEqual( [record.getMessage() for record in logs.records], - ["skipped broadcast: failed to write message"], + ["skipped broadcast: failed to write message: RuntimeError: BOOM"], ) @unittest.skipIf( From 5f34e2741e94ac21ef99e3dc212aa152d77d1a37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 08:55:26 +0100 Subject: [PATCH 1441/1539] Fix remaining instances of shortcut imports. --- docs/topics/broadcast.rst | 2 +- docs/topics/keepalive.rst | 2 +- experiments/broadcast/server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index c9699feb2..66b0819b2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -83,7 +83,7 @@ to:: Here's a coroutine that broadcasts a message to all clients:: - from websockets import ConnectionClosed + from websockets.exceptions import ConnectionClosed async def broadcast(message): for websocket in CLIENTS.copy(): diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 003087fad..4897de2ba 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -112,7 +112,7 @@ You can run a task in the background to send keepalive messages: import itertools import json - from websockets import ConnectionClosed + from websockets.exceptions import ConnectionClosed async def keepalive(websocket, ping_interval=30): for ping in itertools.count(): diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index d5b50bd71..eca55357e 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -6,8 +6,8 @@ import sys import time -from websockets import ConnectionClosed from websockets.asyncio.server import broadcast, serve +from websockets.exceptions import ConnectionClosed CLIENTS = set() From c57bcb743fe1128f6d28e2eaebbdd202eb3ee2eb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 09:36:42 +0100 Subject: [PATCH 1442/1539] Standardize links to RFC. There were 70 links to https://datatracker.ietf.org/doc/html/ vs. 15 links https://www.rfc-editor.org/rfc/. Also :rfc:`....` links to https://datatracker.ietf.org/doc/html/ by default. While https://www.ietf.org/process/rfcs/#introduction says: > The RFC Editor website is the authoritative site for RFCs. the IETF Datatracker looks a bit better and has more information. --- docs/faq/common.rst | 4 ++-- docs/howto/extensions.rst | 2 +- docs/project/changelog.rst | 2 +- docs/reference/extensions.rst | 2 +- docs/topics/design.rst | 10 +++++----- docs/topics/keepalive.rst | 6 +++--- docs/topics/logging.rst | 4 ++-- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 0dc4a3aeb..ba7a95932 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -131,8 +131,8 @@ How do I respond to pings? If you are referring to Ping_ and Pong_ frames defined in the WebSocket protocol, don't bother, because websockets handles them for you. -.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 -.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 +.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 +.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 If you are connecting to a server that defines its own heartbeat at the application level, then you need to build that logic into your application. diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index 3c8a7d72a..c4e9da626 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -7,7 +7,7 @@ During the opening handshake, WebSocket clients and servers negotiate which extensions_ will be used with which parameters. Then each frame is processed by extensions before being sent or after being received. -.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 +.. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 As a consequence, writing an extension requires implementing several classes: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 71b2a6960..1056bc980 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -1487,7 +1487,7 @@ New features * Added support for providing and checking Origin_. -.. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 +.. _Origin: https://datatracker.ietf.org/doc/html/rfc6455.html#section-10.2 .. _2.0: diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index a70f1b1e5..f3da464a5 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -8,7 +8,7 @@ The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public specification, WebSocket Per-Message Deflate. -.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 +.. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name Per-Message Deflate diff --git a/docs/topics/design.rst b/docs/topics/design.rst index b73ace517..bc14bd332 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -173,16 +173,16 @@ differences between a server and a client: - `closing the TCP connection`_: the server closes the connection immediately; the client waits for the server to do it. -.. _client-to-server masking: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.3 -.. _closing the TCP connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.1 +.. _client-to-server masking: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.3 +.. _closing the TCP connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.1 These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented in the same class, :class:`~protocol.WebSocketCommonProtocol`. -.. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 -.. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 -.. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 +.. _data framing: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5 +.. _sending and receiving data: https://datatracker.ietf.org/doc/html/rfc6455.html#section-6 +.. _closing the connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7 The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 4897de2ba..a0467ced2 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -33,8 +33,8 @@ Keepalive in websockets To avoid these problems, websockets runs a keepalive and heartbeat mechanism based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. -.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 -.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 +.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 +.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 It sends a Ping frame every 20 seconds. It expects a Pong frame in return within 20 seconds. Else, it considers the connection broken and terminates it. @@ -98,7 +98,7 @@ at regular intervals. Usually they expect Text_ frames rather than Ping_ frames, meaning that you must send them with :attr:`~asyncio.connection.Connection.send` rather than :attr:`~asyncio.connection.Connection.ping`. -.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.6 In websockets, such keepalive mechanisms are considered as application-level because they rely on data frames. That's unlike the protocol-level keepalive diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index ae71be265..fff33a024 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -35,8 +35,8 @@ Instead, when running as a server, websockets logs one event when a `connection is established`_ and another event when a `connection is closed`_. -.. _connection is established: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 -.. _connection is closed: https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.4 +.. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455.html#section-4 +.. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7.1.4 By default, websockets doesn't log an event for every message. That would be excessive for many applications exchanging small messages at a fast rate. If From 178c88447c2754d2d2b0c02472868cbaef7cec52 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 13:56:49 +0100 Subject: [PATCH 1443/1539] Complete and polish changelog for 14.0. --- docs/project/changelog.rst | 41 +++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1056bc980..161f71c7f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,7 +41,7 @@ Backwards-incompatible changes websockets 13.1 is the last version supporting Python 3.8. .. admonition:: The new :mod:`asyncio` implementation is now the default. - :class: caution + :class: danger The following aliases in the ``websockets`` package were switched to the new :mod:`asyncio` implementation:: @@ -64,8 +64,8 @@ Backwards-incompatible changes The :doc:`upgrade guide <../howto/upgrade>` provides complete instructions to migrate your application. - Aliases for deprecated API were removed from ``__all__``. As a consequence, - they cannot be imported e.g. with ``from websockets import *`` anymore. + Aliases for deprecated API were removed from ``websockets.__all__``, meaning + that they cannot be imported with ``from websockets import *`` anymore. .. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` on invalid arguments. @@ -83,22 +83,16 @@ Backwards-incompatible changes :class: note In addition to :class:`bytes`, it may be a :class:`bytearray` or a - :class:`memoryview`. - - If you wrote an :class:`extension ` that relies on - methods not provided by these new types, you may need to update your code. + :class:`memoryview`. If you wrote an :class:`~extensions.Extension` that + relies on methods not provided by these types, you must update your code. .. admonition:: The signature of :exc:`~exceptions.PayloadTooBig` changed. :class: note If you wrote an extension that raises :exc:`~exceptions.PayloadTooBig` in - :meth:`~extensions.Extension.decode`, for example, you must replace:: - - PayloadTooBig(f"over size limit ({size} > {max_size} bytes)") - - with:: - - PayloadTooBig(size, max_size) + :meth:`~extensions.Extension.decode`, for example, you must replace + ``PayloadTooBig(f"over size limit ({size} > {max_size} bytes)")`` with + ``PayloadTooBig(size, max_size)``. New features ............ @@ -106,8 +100,8 @@ New features * Added an option to receive text frames as :class:`bytes`, without decoding, in the :mod:`threading` implementation; also binary frames as :class:`str`. -* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio` - and :mod:`threading` implementations, as well as :class:`str` a binary frame. +* Added an option to send :class:`bytes` in a text frame in the :mod:`asyncio` + and :mod:`threading` implementations; also :class:`str` in a binary frame. Improvements ............ @@ -118,6 +112,21 @@ Improvements * Errors when a fragmented message is too large are clearer. +* Log messages at the :data:`~logging.WARNING` and :data:`~logging.INFO` levels + no longer include stack traces. + +Bug fixes +......... + +* Clients no longer crash when the server rejects the opening handshake and the + HTTP response doesn't Include a ``Content-Length`` header. + +* Returning an HTTP response in ``process_request`` or ``process_response`` + doesn't generate a log message at the :data:`~logging.ERROR` level anymore. + +* Connections are closed with code 1007 (invalid data) when receiving invalid + UTF-8 in a text frame. + .. _13.1: 13.1 From f0d20aafab027e9b99460b193dcb709872b219a5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 13:58:10 +0100 Subject: [PATCH 1444/1539] Release version 14.0. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 161f71c7f..5aa58a09b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 14.0 ---- -*In development* +*November 9, 2024* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 34fc2eaef..7c64f566a 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "14.0" From b9d74504eaeedd35cb7dc2651a18420f10e3828d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 14:01:07 +0100 Subject: [PATCH 1445/1539] Start version 14.1. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5aa58a09b..9e6a9d113 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.1: + +14.1 +---- + +*In development* + .. _14.0: 14.0 diff --git a/src/websockets/version.py b/src/websockets/version.py index 7c64f566a..48d2edaea 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "14.0" +tag = version = commit = "14.1" if not released: # pragma: no cover From e9fc77da927793d05072163d61e137dd35f97e4d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 14:01:22 +0100 Subject: [PATCH 1446/1539] Add dates for deprecations. --- src/websockets/__init__.py | 2 +- src/websockets/auth.py | 2 +- src/websockets/client.py | 2 +- src/websockets/exceptions.py | 8 ++++---- src/websockets/legacy/__init__.py | 2 +- src/websockets/server.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 531ce49f7..0c7e9b4c6 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -151,7 +151,7 @@ "handshake": ".legacy", "parse_uri": ".uri", "WebSocketURI": ".uri", - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 # .legacy.auth "BasicAuthWebSocketServerProtocol": ".legacy.auth", "basic_auth_protocol_factory": ".legacy.auth", diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 98e62af3c..15b70a372 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -10,7 +10,7 @@ from .legacy.auth import __all__ # noqa: F401 -warnings.warn( # deprecated in 14.0 +warnings.warn( # deprecated in 14.0 - 2024-11-09 "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", diff --git a/src/websockets/client.py b/src/websockets/client.py index 8b66900a8..f6cbc9f65 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -392,7 +392,7 @@ def backoff( lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "WebSocketClientProtocol": ".legacy.client", "connect": ".legacy.client", "unix_connect": ".legacy.client", diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index be3d1ca5f..f3e751971 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -113,7 +113,7 @@ def __str__(self) -> str: @property def code(self) -> int: - warnings.warn( # deprecated in 13.1 + warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.code is deprecated; " "use Protocol.close_code or ConnectionClosed.rcvd.code", DeprecationWarning, @@ -124,7 +124,7 @@ def code(self) -> int: @property def reason(self) -> str: - warnings.warn( # deprecated in 13.1 + warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.reason is deprecated; " "use Protocol.close_reason or ConnectionClosed.rcvd.reason", DeprecationWarning, @@ -343,7 +343,7 @@ def __init__( if isinstance(size_or_message, str): assert max_size is None assert cur_size is None - warnings.warn( # deprecated in 14.0 + warnings.warn( # deprecated in 14.0 - 2024-11-09 "PayloadTooBig(message) is deprecated; " "change to PayloadTooBig(size, max_size)", DeprecationWarning, @@ -408,7 +408,7 @@ class ConcurrencyError(WebSocketException, RuntimeError): lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", diff --git a/src/websockets/legacy/__init__.py b/src/websockets/legacy/__init__.py index 84f870f3a..ad9aa2506 100644 --- a/src/websockets/legacy/__init__.py +++ b/src/websockets/legacy/__init__.py @@ -3,7 +3,7 @@ import warnings -warnings.warn( # deprecated in 14.0 +warnings.warn( # deprecated in 14.0 - 2024-11-09 "websockets.legacy is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", diff --git a/src/websockets/server.py b/src/websockets/server.py index e3fdcc646..607cc306e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -580,7 +580,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", "broadcast": ".legacy.server", From 083bcacd485a12dc1f9c6b98123bb92fafa4a9cf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:50:43 +0100 Subject: [PATCH 1447/1539] Import ConnectionClosed from websockets.exceptions. Fix #1539. --- docs/faq/client.rst | 3 ++- docs/faq/server.rst | 2 +- docs/intro/tutorial1.rst | 4 +++- src/websockets/asyncio/client.py | 2 +- src/websockets/legacy/client.py | 2 +- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 0a7aab6e2..cc9856a8b 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -84,11 +84,12 @@ How do I reconnect when the connection drops? Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: from websockets.asyncio.client import connect + from websockets.exceptions import ConnectionClosed async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except ConnectionClosed: continue Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 63eb5ffc6..ce7e1962d 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -147,7 +147,7 @@ Then, call :meth:`~ServerConnection.send`:: async def message_user(user_id, message): websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected - await websocket.send(message) # may raise websockets.ConnectionClosed + await websocket.send(message) # may raise websockets.exceptions.ConnectionClosed Add error handling according to the behavior you want if the user disconnected before the message could be sent. diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 6e91867c8..87074caee 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -271,11 +271,13 @@ spot real errors when you add functionality to the server. Catch it in the .. code-block:: python + from websockets.exceptions import ConnectionClosedOK + async def handler(websocket): while True: try: message = await websocket.recv() - except websockets.ConnectionClosedOK: + except ConnectionClosedOK: break print(message) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b3d50c12e..74ae70f0d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -182,7 +182,7 @@ class connect: async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except websockets.exceptions.ConnectionClosed: continue If the connection fails with a transient error, it is retried with diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 555069e8c..a3856b470 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -349,7 +349,7 @@ class Connect: async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except websockets.exceptions.ConnectionClosed: continue The connection is closed automatically after each iteration of the loop. From 86bf0c5afc4c88429231ac7ba5f857fd32a02d0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 17:40:55 +0100 Subject: [PATCH 1448/1539] Remove unnecessary branch. --- src/websockets/asyncio/messages.py | 3 +-- tests/asyncio/test_messages.py | 8 -------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index b57c0ca4e..69636e3d0 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -43,8 +43,7 @@ def put(self, item: T) -> None: async def get(self) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: - if self.get_waiter is not None: - raise ConcurrencyError("get is already running") + assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: await self.get_waiter diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 2ff929d3a..5c9ac9445 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -37,14 +37,6 @@ async def test_get_then_put(self): item = await getter_task self.assertEqual(item, 42) - async def test_get_concurrently(self): - """get cannot be called concurrently.""" - getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start - with self.assertRaises(ConcurrencyError): - await self.queue.get() - getter_task.cancel() - async def test_reset(self): """reset sets the content of the queue.""" self.queue.reset([42]) From f17c11ab6b46cfe6817cbab5d92ba2626d3e87d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 17:42:16 +0100 Subject: [PATCH 1449/1539] Keep queued messages after abort. --- src/websockets/asyncio/messages.py | 2 -- tests/asyncio/test_messages.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 69636e3d0..678a3c14e 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -62,8 +62,6 @@ def abort(self) -> None: """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) - # Clear the queue to avoid storing unnecessary data in memory. - self.queue.clear() class Assembler: diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 5c9ac9445..181ffd376 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -51,13 +51,6 @@ async def test_abort(self): with self.assertRaises(EOFError): await getter_task - async def test_abort_clears_queue(self): - """abort clears buffered data from the queue.""" - self.queue.put(42) - self.assertEqual(len(self.queue), 1) - self.queue.abort() - self.assertEqual(len(self.queue), 0) - class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): From bdfc8cf90301b528eebfbd0ef8a31e5a2fc7d5f7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:11:02 +0100 Subject: [PATCH 1450/1539] Uniformize comments. --- src/websockets/asyncio/messages.py | 8 ++++---- src/websockets/sync/messages.py | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 678a3c14e..814a3c03c 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -141,8 +141,8 @@ async def get(self, decode: bool | None = None) -> Data: self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is cancelled. try: # First frame @@ -208,8 +208,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get_iter() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is cancelled. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 17f8dce7e..ce08172b2 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -128,10 +128,11 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or times out. + try: deadline = Deadline(timeout) @@ -198,12 +199,10 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get_iter() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or times out. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. From 303483412dc5b420d09c1421792b3f8b99c323e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:30:36 +0100 Subject: [PATCH 1451/1539] Support recv() after the connection is closed. Fix #1538. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/messages.py | 20 ++++-------- src/websockets/sync/messages.py | 27 ++++++++-------- tests/asyncio/test_connection.py | 8 ++--- tests/asyncio/test_messages.py | 52 ++++++++++++++++++++++++++++++ tests/sync/test_connection.py | 19 ++++------- tests/sync/test_messages.py | 52 ++++++++++++++++++++++++++++++ 7 files changed, 142 insertions(+), 43 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9e6a9d113..074a81c85 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Bug fixes +......... + +* Once the connection is closed, messages previously received and buffered can + be read in the :mod:`asyncio` and :mod:`threading` implementations, just like + in the legacy implementation. + 14.0 ---- diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 814a3c03c..14ea7bf90 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -40,9 +40,11 @@ def put(self, item: T) -> None: if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_result(None) - async def get(self) -> T: + async def get(self, block: bool = True) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: + if not block: + raise EOFError("stream of frames ended") assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: @@ -133,12 +135,8 @@ async def get(self, decode: bool | None = None) -> Data: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -146,7 +144,7 @@ async def get(self, decode: bool | None = None) -> Data: try: # First frame - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: @@ -156,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data: # Following frames, for fragmented messages while not frame.fin: try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: # Put frames already received back into the queue # so that future calls to get() can return them. @@ -200,12 +198,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -216,7 +210,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # First frame try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: self.get_in_progress = False raise @@ -236,7 +230,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to # get() or get_iter() will raise ConcurrencyError. - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_CONT if decode: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index ce08172b2..af8635f16 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -69,10 +69,16 @@ def __init__( def get_next_frame(self, timeout: float | None = None) -> Frame: # Helper to factor out the logic for getting the next frame from the # queue, while handling timeouts and reaching the end of the stream. - try: - frame = self.frames.get(timeout=timeout) - except queue.Empty: - raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if self.closed: + try: + frame = self.frames.get(block=False) + except queue.Empty: + raise EOFError("stream of frames ended") from None + else: + try: + frame = self.frames.get(block=True, timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None if frame is None: raise EOFError("stream of frames ended") return frame @@ -87,7 +93,7 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: queued = [] try: while True: - queued.append(self.frames.get_nowait()) + queued.append(self.frames.get(block=False)) except queue.Empty: pass for frame in frames: @@ -123,9 +129,6 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -194,9 +197,6 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -288,5 +288,6 @@ def close(self) -> None: self.closed = True - # Unblock get() or get_iter(). - self.frames.put(None) + if self.get_in_progress: + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 902b3b847..b1c57c8ca 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - async def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + async def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" await self.remote_connection.send("😀") await self.connection.close() + self.assertEqual(await self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 181ffd376..566f71cea 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -395,6 +395,58 @@ async def test_get_iter_fails_after_close(self): async for _ in self.assembler.get_iter(): self.fail("no fragment expected") + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + async def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index abdfd3f78..408b9697a 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -543,17 +543,12 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" self.remote_connection.send("😀") self.connection.close() - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - + self.assertEqual(self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -576,10 +571,10 @@ def test_close_idempotency(self): def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" - self.connection.close_timeout = 5 * MS + self.connection.close_timeout = 6 * MS def closer(): - with self.delay_frames_rcvd(3 * MS): + with self.delay_frames_rcvd(4 * MS): self.connection.close() close_thread = threading.Thread(target=closer) @@ -591,14 +586,14 @@ def closer(): # Connection isn't closed yet. with self.assertRaises(TimeoutError): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) self.connection.close() self.assertNoFrameSent() # Connection is closed now. with self.assertRaises(ConnectionClosedOK): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) close_thread.join() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 02513894a..9ebe45088 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -374,6 +374,58 @@ def test_get_iter_fails_after_close(self): for _ in self.assembler.get_iter(): self.fail("no fragment expected") + def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, "café") + + def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, b"tea") + + def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() From 9a2f39fc66fd2427f904e551b1cc5f3995b02217 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 21:13:39 +0100 Subject: [PATCH 1452/1539] Support max_queue=None like the legacy implementation. Fix #1540. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/client.py | 7 ++-- src/websockets/asyncio/connection.py | 4 +- src/websockets/asyncio/messages.py | 23 ++++++++--- src/websockets/asyncio/server.py | 7 ++-- src/websockets/sync/client.py | 7 ++-- src/websockets/sync/connection.py | 4 +- src/websockets/sync/messages.py | 25 +++++++++--- src/websockets/sync/server.py | 7 ++-- tests/asyncio/test_connection.py | 12 +++++- tests/asyncio/test_messages.py | 61 +++++++++++++++++++++++++--- tests/sync/test_connection.py | 17 +++++++- tests/sync/test_messages.py | 61 +++++++++++++++++++++++++--- 13 files changed, 200 insertions(+), 42 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 074a81c85..8e1ad81f0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Improvements +............ + +* Supported ``max_queue=None`` in the :mod:`asyncio` and :mod:`threading` + implementations for consistency with the legacy implementation, even though + this is never a good idea. + Bug fixes ......... diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 74ae70f0d..cdd9bfac6 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -60,7 +60,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol @@ -222,7 +222,8 @@ class connect: max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the @@ -283,7 +284,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 186846ef3..f1dcbada6 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -56,14 +56,14 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue if isinstance(write_limit, int): diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 14ea7bf90..e6d1d31cc 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -84,7 +84,7 @@ class Assembler: # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, @@ -96,12 +96,15 @@ def __init__( # pragma: no cover # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -256,6 +259,10 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + # Check for "> high" to support high = 0 if len(self.frames) > self.high and not self.paused: self.paused = True @@ -263,6 +270,10 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + # Check for "<= low" to support low = 0 if len(self.frames) <= self.low and self.paused: self.paused = False diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 15c9ba13e..fdb928004 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -71,7 +71,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol @@ -643,7 +643,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the @@ -713,7 +714,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 54d0aef68..9e6da7caf 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -55,7 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -139,7 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -191,7 +191,8 @@ def connect( max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 77f803c9b..be3381c8a 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,12 +49,12 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index af8635f16..98490797f 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -33,7 +33,7 @@ class Assembler: def __init__( self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, @@ -49,12 +49,15 @@ def __init__( # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -260,7 +263,12 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + assert self.mutex.locked() + # Check for "> high" to support high = 0 if self.frames.qsize() > self.high and not self.paused: self.paused = True @@ -268,7 +276,12 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + assert self.mutex.locked() + # Check for "<= low" to support low = 0 if self.frames.qsize() <= self.low and self.paused: self.paused = False diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 8601ccef9..9506d6830 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -66,7 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -356,7 +356,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -438,7 +438,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index b1c57c8ca..8dd0a0335 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1066,14 +1066,22 @@ async def test_close_timeout(self): self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water mark of frames buffer.""" connection = Connection(Protocol(self.LOCAL), max_queue=4) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=None) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + async def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water and low-water marks of frames buffer.""" connection = Connection( Protocol(self.LOCAL), max_queue=(4, 2), diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 566f71cea..a90788d02 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -153,7 +153,7 @@ async def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") async def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -170,6 +170,19 @@ async def test_get_resumes_reading(self): await self.assembler.get() self.resume.assert_called_once_with() + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + + self.resume.assert_not_called() + async def test_cancel_get_before_first_frame(self): """get can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(self.assembler.get()) @@ -302,7 +315,7 @@ async def test_get_iter_decoded_binary_message(self): self.assertEqual(fragments, ["t", "e", "a"]) async def test_get_iter_resumes_reading(self): - """get_iter resumes reading when queue goes below the high-water mark.""" + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -321,6 +334,20 @@ async def test_get_iter_resumes_reading(self): await anext(iterator) self.resume.assert_called_once_with() + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) + await anext(iterator) + await anext(iterator) + await anext(iterator) + + self.resume.assert_not_called() + async def test_cancel_get_iter_before_first_frame(self): """get_iter can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -367,6 +394,17 @@ async def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination async def test_get_fails_when_interrupted_by_close(self): @@ -495,16 +533,29 @@ async def test_get_iter_fails_when_get_iter_is_running(self): # Test setting limits async def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark and low-water mark.""" + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + async def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + async def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 408b9697a..6be490a5d 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -749,7 +749,7 @@ def test_close_timeout(self): self.assertEqual(connection.close_timeout, 42 * MS) def test_max_queue(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water mark of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) @@ -760,8 +760,21 @@ def test_max_queue(self): ) self.assertEqual(connection.recv_messages.high, 4) + def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.high, None) + def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water and low-water marks of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 9ebe45088..d22693102 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -145,7 +145,7 @@ def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -162,6 +162,19 @@ def test_get_resumes_reading(self): self.assembler.get() self.resume.assert_called_once_with() + def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + self.assembler.get() + self.assembler.get() + self.assembler.get() + + self.resume.assert_not_called() + def test_get_timeout_before_first_frame(self): """get times out before reading the first frame.""" with self.assertRaises(TimeoutError): @@ -300,7 +313,7 @@ def test_get_iter_decoded_binary_message(self): self.assertEqual(fragments, ["t", "e", "a"]) def test_get_iter_resumes_reading(self): - """get_iter resumes reading when queue goes below the high-water mark.""" + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -319,6 +332,20 @@ def test_get_iter_resumes_reading(self): next(iterator) self.resume.assert_called_once_with() + def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = self.assembler.get_iter() + next(iterator) + next(iterator) + next(iterator) + + self.resume.assert_not_called() + # Test put def test_put_pauses_reading(self): @@ -336,6 +363,17 @@ def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination def test_get_fails_when_interrupted_by_close(self): @@ -470,16 +508,29 @@ def test_get_iter_fails_when_get_iter_is_running(self): # Test setting limits def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - def test_set_high_and_low_water_mark(self): - """high sets the high-water mark and low-water mark.""" + def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): From de2e7fb8b7eaca56d633ae7d2ffffdbb212048a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 21:40:26 +0100 Subject: [PATCH 1453/1539] Add close_code and close_reason to new implementations. Also add state to threading implementation. Fix #1537. --- docs/reference/asyncio/client.rst | 7 ++++++ docs/reference/asyncio/common.rst | 7 ++++++ docs/reference/asyncio/server.rst | 7 ++++++ docs/reference/sync/client.rst | 9 +++++++ docs/reference/sync/common.rst | 9 +++++++ docs/reference/sync/server.rst | 9 +++++++ src/websockets/asyncio/connection.py | 24 ++++++++++++++++++ src/websockets/protocol.py | 21 ++++++++++------ src/websockets/sync/connection.py | 37 ++++++++++++++++++++++++++++ tests/asyncio/test_connection.py | 10 +++++++- tests/sync/test_connection.py | 14 ++++++++++- 11 files changed, 145 insertions(+), 9 deletions(-) diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index e2b0ff550..ea7b21506 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -57,3 +57,10 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index a58325fb9..325f20450 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -45,3 +45,10 @@ Both sides (new :mod:`asyncio`) .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 2fcaeb414..49bd6f072 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -79,6 +79,13 @@ Using a connection .. autoproperty:: subprotocol + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + Broadcast --------- diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index af1132412..2aa491f6a 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -39,6 +39,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -47,3 +49,10 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 3dc6d4a50..3c03b25b6 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -31,6 +31,8 @@ Both sides (:mod:`threading`) .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -39,3 +41,10 @@ Both sides (:mod:`threading`) .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 80e9c17bb..1d80450f9 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -52,6 +52,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -61,6 +63,13 @@ Using a connection .. autoproperty:: subprotocol + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + HTTP Basic Authentication ------------------------- diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index f1dcbada6..e5c350fe2 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -185,6 +185,30 @@ def subprotocol(self) -> Subprotocol | None: """ return self.protocol.subprotocol + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + # Public methods async def __aenter__(self) -> Connection: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0f6fea250..bc64a216a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -159,7 +159,12 @@ def state(self) -> State: """ State of the WebSocket connection. - Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. + Defined in 4.1_, 4.2_, 7.1.3_, and 7.1.4_ of :rfc:`6455`. + + .. _4.1: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 + .. _4.2: https://datatracker.ietf.org/doc/html/rfc6455#section-4.2 + .. _7.1.3: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.3 + .. _7.1.4: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 """ return self._state @@ -173,10 +178,11 @@ def state(self, state: State) -> None: @property def close_code(self) -> int | None: """ - `WebSocket close code`_. + WebSocket close code received from the remote endpoint. + + Defined in 7.1.5_ of :rfc:`6455`. - .. _WebSocket close code: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 + .. _7.1.5: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -191,10 +197,11 @@ def close_code(self) -> int | None: @property def close_reason(self) -> str | None: """ - `WebSocket close reason`_. + WebSocket close reason received from the remote endpoint. + + Defined in 7.1.6_ of :rfc:`6455`. - .. _WebSocket close reason: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 + .. _7.1.6: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index be3381c8a..d8dbf140e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -140,6 +140,19 @@ def remote_address(self) -> Any: """ return self.socket.getpeername() + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + @property def subprotocol(self) -> Subprotocol | None: """ @@ -150,6 +163,30 @@ def subprotocol(self) -> Subprotocol | None: """ return self.protocol.subprotocol + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + # Public methods def __enter__(self) -> Connection: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 8dd0a0335..5a0b61bf7 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1139,7 +1139,7 @@ async def test_remote_address(self, get_extra_info): async def test_state(self): """Connection has a state attribute.""" - self.assertEqual(self.connection.state, State.OPEN) + self.assertIs(self.connection.state, State.OPEN) async def test_request(self): """Connection has a request attribute.""" @@ -1153,6 +1153,14 @@ async def test_subprotocol(self): """Connection has a subprotocol attribute.""" self.assertIsNone(self.connection.subprotocol) + async def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + async def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + # Test reporting of network errors. async def test_writing_in_data_received_fails(self): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 6be490a5d..4884bf13f 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -15,7 +15,7 @@ ConnectionClosedOK, ) from websockets.frames import CloseCode, Frame, Opcode -from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.protocol import CLIENT, SERVER, Protocol, State from websockets.sync.connection import * from ..protocol import RecordingProtocol @@ -808,6 +808,10 @@ def test_remote_address(self, getpeername): self.assertEqual(self.connection.remote_address, ("peer", 1234)) getpeername.assert_called_with() + def test_state(self): + """Connection has a state attribute.""" + self.assertIs(self.connection.state, State.OPEN) + def test_request(self): """Connection has a request attribute.""" self.assertIsNone(self.connection.request) @@ -820,6 +824,14 @@ def test_subprotocol(self): """Connection has a subprotocol attribute.""" self.assertIsNone(self.connection.subprotocol) + def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + # Test reporting of network errors. @unittest.skipUnless(sys.platform == "darwin", "works only on BSD") From 1f19487b0f55aac3f67a5f1b35209e0e3b294063 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 22:01:58 +0100 Subject: [PATCH 1454/1539] Add changelog for previous commit. --- docs/project/changelog.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8e1ad81f0..792f8ede4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,10 @@ Improvements implementations for consistency with the legacy implementation, even though this is never a good idea. +* Added ``close_code`` and ``close_reason`` attributes in the :mod:`asyncio` and + :mod:`threading` implementations for consistency with the legacy + implementation. + Bug fixes ......... From 7438b8ebee3f6ef59c15b86d627d32f98f49df8f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 12 Nov 2024 08:40:24 +0100 Subject: [PATCH 1455/1539] Stop testing on PyPy 3.9. PyPy v7.3.17 no longer provides PyPy 3.9. Also the test suite was flaky under PyPy 3.9. --- .github/workflows/tests.yml | 3 --- tests/sync/test_connection.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index beaf9d12b..5ab9c4c72 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,13 +60,10 @@ jobs: - "3.11" - "3.12" - "3.13" - - "pypy-3.9" - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.9" - is_main: false - python: "pypy-3.10" is_main: false steps: diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 4884bf13f..e21e310a2 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -1,6 +1,5 @@ import contextlib import logging -import platform import socket import sys import threading @@ -564,10 +563,6 @@ def test_close_idempotency(self): self.connection.close() self.assertNoFrameSent() - @unittest.skipIf( - platform.python_implementation() == "PyPy", - "this test fails randomly due to a bug in PyPy", # see #1314 for details - ) def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" From d0015c93f49511eb8dd073caa6a3a338979f741b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 07:51:11 +0100 Subject: [PATCH 1456/1539] Fix refactoring error in a78b5546. Fix #1546. --- example/faq/health_check_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index c0fa4327f..30623a4bb 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -6,7 +6,7 @@ def health_check(connection, request): if request.path == "/healthz": - return connection.respond(HTTPStatus.OK, b"OK\n") + return connection.respond(HTTPStatus.OK, "OK\n") async def echo(websocket): async for message in websocket: From 0403823185b272ae389da1c9182773932b4df950 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 07:53:45 +0100 Subject: [PATCH 1457/1539] Release version 14.1. --- docs/project/changelog.rst | 6 +++--- src/websockets/version.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 792f8ede4..ca6769199 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,9 +30,7 @@ notice. 14.1 ---- -*In development* - -.. _14.0: +*November 13, 2024* Improvements ............ @@ -52,6 +50,8 @@ Bug fixes be read in the :mod:`asyncio` and :mod:`threading` implementations, just like in the legacy implementation. +.. _14.0: + 14.0 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index 48d2edaea..f2defeff0 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "14.1" From d8891a101d7cfc2cf7e01de078ada7c93a264ebd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 07:55:13 +0100 Subject: [PATCH 1458/1539] Start version 14.2. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index ca6769199..9c594b653 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.2: + +14.2 +---- + +*In development* + .. _14.1: 14.1 diff --git a/src/websockets/version.py b/src/websockets/version.py index f2defeff0..b0df10f0a 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "14.1" +tag = version = commit = "14.2" if not released: # pragma: no cover From 59d4dcf779fe7d2b0302083b072d8b03adce2f61 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 23:00:22 +0100 Subject: [PATCH 1459/1539] Reintroduce InvalidMessage. This improves compatibility with the legacy implementation and clarifies error reporting. Fix #1548. --- docs/project/changelog.rst | 8 ++++++++ docs/reference/exceptions.rst | 4 ++-- src/websockets/__init__.py | 4 +++- src/websockets/asyncio/client.py | 6 ++++-- src/websockets/client.py | 6 +++++- src/websockets/exceptions.py | 11 +++++++++-- src/websockets/legacy/client.py | 3 ++- src/websockets/legacy/exceptions.py | 9 ++------- src/websockets/legacy/server.py | 3 ++- src/websockets/server.py | 6 +++++- tests/asyncio/test_client.py | 27 ++++++++++++++++++++------- tests/asyncio/test_server.py | 4 ++++ tests/legacy/test_exceptions.py | 4 ---- tests/sync/test_client.py | 21 ++++++++++++++++++--- tests/sync/test_server.py | 4 ++++ tests/test_client.py | 28 ++++++++++++++++++++++++---- tests/test_exceptions.py | 4 ++++ tests/test_server.py | 24 ++++++++++++++++++++---- 18 files changed, 136 insertions(+), 40 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9c594b653..b7f4f62f9 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,14 @@ notice. *In development* +Bug fixes +......... + +* Wrapped errors when reading the opening handshake request or response in + :exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect` + raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening + handshake fails. + .. _14.1: 14.1 diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 75934ef99..d6b7f0f57 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -30,6 +30,8 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs. .. autoexception:: InvalidHandshake +.. autoexception:: InvalidMessage + .. autoexception:: SecurityError .. autoexception:: InvalidStatus @@ -74,8 +76,6 @@ Legacy exceptions These exceptions are only used by the legacy :mod:`asyncio` implementation. -.. autoexception:: InvalidMessage - .. autoexception:: InvalidStatusCode .. autoexception:: AbortHandshake diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 0c7e9b4c6..c278b21d4 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -31,6 +31,7 @@ "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", + "InvalidMessage", "InvalidOrigin", "InvalidParameterName", "InvalidParameterValue", @@ -71,6 +72,7 @@ InvalidHeader, InvalidHeaderFormat, InvalidHeaderValue, + InvalidMessage, InvalidOrigin, InvalidParameterName, InvalidParameterValue, @@ -122,6 +124,7 @@ "InvalidHeader": ".exceptions", "InvalidHeaderFormat": ".exceptions", "InvalidHeaderValue": ".exceptions", + "InvalidMessage": ".exceptions", "InvalidOrigin": ".exceptions", "InvalidParameterName": ".exceptions", "InvalidParameterValue": ".exceptions", @@ -159,7 +162,6 @@ "WebSocketClientProtocol": ".legacy.client", # .legacy.exceptions "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index cdd9bfac6..8581c0551 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -11,7 +11,7 @@ from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike -from ..exceptions import InvalidStatus, SecurityError +from ..exceptions import InvalidMessage, InvalidStatus, SecurityError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -147,7 +147,9 @@ def process_exception(exc: Exception) -> Exception | None: That exception will be raised, breaking out of the retry loop. """ - if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): + if isinstance(exc, (OSError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError): return None if isinstance(exc, InvalidStatus) and exc.response.status_code in [ 500, # Internal Server Error diff --git a/src/websockets/client.py b/src/websockets/client.py index f6cbc9f65..5ced05c2a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -11,6 +11,7 @@ InvalidHandshake, InvalidHeader, InvalidHeaderValue, + InvalidMessage, InvalidStatus, InvalidUpgrade, NegotiationError, @@ -318,7 +319,10 @@ def parse(self) -> Generator[None]: self.reader.read_to_eof, ) except Exception as exc: - self.handshake_exc = exc + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP response" + ) + self.handshake_exc.__cause__ = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f3e751971..81fbb1efd 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -8,7 +8,7 @@ * :exc:`InvalidURI` * :exc:`InvalidHandshake` * :exc:`SecurityError` - * :exc:`InvalidMessage` (legacy) + * :exc:`InvalidMessage` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) * :exc:`InvalidHeader` @@ -48,6 +48,7 @@ "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", + "InvalidMessage", "InvalidOrigin", "InvalidUpgrade", "NegotiationError", @@ -185,6 +186,13 @@ class SecurityError(InvalidHandshake): """ +class InvalidMessage(InvalidHandshake): + """ + Raised when a handshake request or response is malformed. + + """ + + class InvalidStatus(InvalidHandshake): """ Raised when a handshake response rejects the WebSocket upgrade. @@ -410,7 +418,6 @@ class ConcurrencyError(WebSocketException, RuntimeError): deprecated_aliases={ # deprecated in 14.0 - 2024-11-09 "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a3856b470..29141f39a 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -17,6 +17,7 @@ from ..exceptions import ( InvalidHeader, InvalidHeaderValue, + InvalidMessage, NegotiationError, SecurityError, ) @@ -34,7 +35,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake +from .exceptions import InvalidStatusCode, RedirectHandshake from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index e2279c825..78fb696fa 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -3,18 +3,13 @@ from .. import datastructures from ..exceptions import ( InvalidHandshake, + # InvalidMessage was incorrectly moved here in versions 14.0 and 14.1. + InvalidMessage, # noqa: F401 ProtocolError as WebSocketProtocolError, # noqa: F401 ) from ..typing import StatusLike -class InvalidMessage(InvalidHandshake): - """ - Raised when a handshake request or response is malformed. - - """ - - class InvalidStatusCode(InvalidHandshake): """ Raised when a handshake response status code is invalid. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9326b6100..f9d57cb99 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -17,6 +17,7 @@ from ..exceptions import ( InvalidHandshake, InvalidHeader, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -32,7 +33,7 @@ from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .exceptions import AbortHandshake, InvalidMessage +from .exceptions import AbortHandshake from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol, broadcast diff --git a/src/websockets/server.py b/src/websockets/server.py index 607cc306e..1b663a137 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -13,6 +13,7 @@ InvalidHandshake, InvalidHeader, InvalidHeaderValue, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -552,7 +553,10 @@ def parse(self) -> Generator[None]: self.reader.read_line, ) except Exception as exc: - self.handshake_exc = exc + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP request" + ) + self.handshake_exc.__cause__ = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 231d6b8ca..1773c08bd 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -12,6 +12,7 @@ from websockets.client import backoff from websockets.exceptions import ( InvalidHandshake, + InvalidMessage, InvalidStatus, InvalidURI, SecurityError, @@ -151,22 +152,24 @@ async def test_reconnect(self): iterations = 0 successful = 0 - def process_request(connection, request): + async def process_request(connection, request): nonlocal iterations iterations += 1 # Retriable errors if iterations == 1: - connection.transport.close() + await asyncio.sleep(3 * MS) elif iterations == 2: + connection.transport.close() + elif iterations == 3: return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") # Fatal error - elif iterations == 5: + elif iterations == 6: return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with short_backoff_delay(): - async for client in connect(get_uri(server)): + async for client in connect(get_uri(server), open_timeout=3 * MS): self.assertEqual(client.protocol.state.name, "OPEN") successful += 1 @@ -174,7 +177,7 @@ def process_request(connection, request): str(raised.exception), "server rejected WebSocket connection: HTTP 402", ) - self.assertEqual(iterations, 5) + self.assertEqual(iterations, 6) self.assertEqual(successful, 2) async def test_reconnect_with_custom_process_exception(self): @@ -393,11 +396,16 @@ def close_connection(self, request): self.close_transport() async with serve(*args, process_request=close_connection) as server: - with self.assertRaises(EOFError) as raised: + with self.assertRaises(InvalidMessage) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), "connection closed while reading HTTP status line", ) @@ -443,11 +451,16 @@ async def junk(reader, writer): server = await asyncio.start_server(junk, "localhost", 0) host, port = get_host_port(server) async with server: - with self.assertRaises(ValueError) as raised: + with self.assertRaises(InvalidMessage) as raised: async with connect(f"ws://{host}:{port}"): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), "unsupported protocol; expected HTTP/1.1: " "220 smtp.invalid ESMTP Postfix", ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 3e289e592..83885fab5 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -473,6 +473,10 @@ async def test_junk_handshake(self): ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], ["invalid HTTP request line: HELO relay.invalid"], ) diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index e5d22a917..4e6ff952b 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -7,10 +7,6 @@ class ExceptionsTests(unittest.TestCase): def test_str(self): for exception, exception_str in [ - ( - InvalidMessage("malformed HTTP message"), - "malformed HTTP message", - ), ( InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 9d457a912..7d8170519 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,7 +7,12 @@ import time import unittest -from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI +from websockets.exceptions import ( + InvalidHandshake, + InvalidMessage, + InvalidStatus, + InvalidURI, +) from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -149,11 +154,16 @@ def close_connection(self, request): self.close_socket() with run_server(process_request=close_connection) as server: - with self.assertRaises(EOFError) as raised: + with self.assertRaises(InvalidMessage) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), "connection closed while reading HTTP status line", ) @@ -203,11 +213,16 @@ def handle(self): thread = threading.Thread(target=server.serve_forever, args=(MS,)) thread.start() try: - with self.assertRaises(ValueError) as raised: + with self.assertRaises(InvalidMessage) as raised: with connect(f"ws://{host}:{port}"): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), "unsupported protocol; expected HTTP/1.1: " "220 smtp.invalid ESMTP Postfix", ) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 54e49bf16..9a2676437 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -311,6 +311,10 @@ def test_junk_handshake(self): ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], ["invalid HTTP request line: HELO relay.invalid"], ) diff --git a/tests/test_client.py b/tests/test_client.py index 2468be85e..1edbae57d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,12 @@ from websockets.client import * from websockets.client import backoff from websockets.datastructures import Headers -from websockets.exceptions import InvalidHandshake, InvalidHeader, InvalidStatus +from websockets.exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidStatus, +) from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN @@ -244,9 +249,14 @@ def test_receive_no_response(self, _generate_key): client.receive_eof() self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, EOFError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(client.handshake_exc.__cause__), "connection closed while reading HTTP status line", ) @@ -257,9 +267,14 @@ def test_receive_truncated_response(self, _generate_key): client.receive_eof() self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, EOFError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(client.handshake_exc.__cause__), "connection closed while reading HTTP headers", ) @@ -272,9 +287,14 @@ def test_receive_random_response(self, _generate_key): client.receive_data(b"250 Ok\r\n") self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, ValueError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, ValueError) + self.assertEqual( + str(client.handshake_exc.__cause__), "invalid HTTP status line: 220 smtp.invalid", ) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index fef41d136..e0518b0e0 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -91,6 +91,10 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), + ( + InvalidMessage("malformed HTTP message"), + "malformed HTTP message", + ), ( InvalidStatus(Response(401, "Unauthorized", Headers())), "server rejected WebSocket connection: HTTP 401", diff --git a/tests/test_server.py b/tests/test_server.py index 844ba64ec..5efeca2d0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,6 +7,7 @@ from websockets.datastructures import Headers from websockets.exceptions import ( InvalidHeader, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -207,9 +208,15 @@ def test_receive_no_request(self): server.receive_eof() self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(server.handshake_exc.__cause__), "connection closed while reading HTTP request line", ) @@ -220,9 +227,14 @@ def test_receive_truncated_request(self): server.receive_eof() self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, EOFError) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(server.handshake_exc.__cause__), "connection closed while reading HTTP headers", ) @@ -233,10 +245,14 @@ def test_receive_junk_request(self): server.receive_data(b"MAIL FROM: \r\n") server.receive_data(b"RCPT TO: \r\n") - self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, ValueError) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, ValueError) + self.assertEqual( + str(server.handshake_exc.__cause__), "invalid HTTP request line: HELO relay.invalid", ) From d852df7dd6324eaee17fc848f029ada371678cbe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 7 Dec 2024 06:23:21 +0000 Subject: [PATCH 1460/1539] Bump actions/attest-build-provenance from 1 to 2 Bumps [actions/attest-build-provenance](https://github.com/actions/attest-build-provenance) from 1 to 2. - [Release notes](https://github.com/actions/attest-build-provenance/releases) - [Changelog](https://github.com/actions/attest-build-provenance/blob/main/RELEASE.md) - [Commits](https://github.com/actions/attest-build-provenance/compare/v1...v2) --- updated-dependencies: - dependency-name: actions/attest-build-provenance dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cc26502ca..aa1e2e7a9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -82,7 +82,7 @@ jobs: merge-multiple: true path: dist - name: Attest provenance - uses: actions/attest-build-provenance@v1 + uses: actions/attest-build-provenance@v2 with: subject-path: dist/* - name: Upload to PyPI From 197b3ec2c7acf3a3804a94c0c02ed8e27051b0f0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Jan 2025 22:50:31 +0100 Subject: [PATCH 1461/1539] Don't crash when acknowledging a cancelled ping. Fix #1566. --- src/websockets/asyncio/connection.py | 3 ++- tests/asyncio/test_connection.py | 23 ++++++++++++++++++++++- tests/sync/test_connection.py | 2 +- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index e5c350fe2..c34d19d58 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -750,7 +750,8 @@ def acknowledge_pings(self, data: bytes) -> None: for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) latency = pong_timestamp - ping_timestamp - pong_waiter.set_result(latency) + if not pong_waiter.done(): + pong_waiter.set_result(latency) if ping_id == data: self.latency = latency break diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 5a0b61bf7..788a457ed 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -892,6 +892,15 @@ async def test_acknowledge_ping(self): async with asyncio_timeout(MS): await pong_waiter + async def test_acknowledge_canceled_ping(self): + """ping is acknowledged by a pong with the same payload after being canceled.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + pong_waiter.cancel() + await self.remote_connection.pong("this") + with self.assertRaises(asyncio.CancelledError): + await pong_waiter + async def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping @@ -902,7 +911,7 @@ async def test_acknowledge_ping_non_matching_pong(self): await pong_waiter async def test_acknowledge_previous_ping(self): - """ping is acknowledged by a pong with the same payload as a later ping.""" + """ping is acknowledged by a pong for a later ping.""" async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("this") await self.connection.ping("that") @@ -910,6 +919,18 @@ async def test_acknowledge_previous_ping(self): async with asyncio_timeout(MS): await pong_waiter + async def test_acknowledge_previous_canceled_ping(self): + """ping is acknowledged by a pong for a later ping after being canceled.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + pong_waiter_2 = await self.connection.ping("that") + pong_waiter.cancel() + await self.remote_connection.pong("that") + async with asyncio_timeout(MS): + await pong_waiter_2 + with self.assertRaises(asyncio.CancelledError): + await pong_waiter + async def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" async with self.drop_frames_rcvd(): # drop automatic response to ping diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index e21e310a2..aa445498c 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -678,7 +678,7 @@ def test_acknowledge_ping_non_matching_pong(self): self.assertFalse(pong_waiter.wait(MS)) def test_acknowledge_previous_ping(self): - """ping is acknowledged by a pong with the same payload as a later ping.""" + """ping is acknowledged by a pong for as a later ping.""" with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.connection.ping("that") From 2abf77fb1ce67b3c5995fb63eeea6b829d96c0fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 4 Jan 2025 22:38:44 +0100 Subject: [PATCH 1462/1539] Improve consistency of convenience imports. APIs available as convenience imports should have their types also available as convenienc imports. Fix #1560. --- docs/howto/upgrade.rst | 12 ++++----- src/websockets/__init__.py | 44 ++++++++++++++++++++++++++++++-- src/websockets/datastructures.py | 6 ++++- src/websockets/frames.py | 1 + src/websockets/http11.py | 7 ++++- tests/test_exports.py | 36 ++++++++++++-------------- 6 files changed, 77 insertions(+), 29 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 02d4c6f01..db6bf11f1 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -93,8 +93,8 @@ Client APIs | ``websockets.client.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | | :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | -| ``websockets.client.WebSocketClientProtocol`` |br| | | +| ``websockets.WebSocketClientProtocol`` |br| | ``websockets.ClientConnection`` *(since 14.2)* |br| | +| ``websockets.client.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | | :class:`websockets.legacy.client.WebSocketClientProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ @@ -112,12 +112,12 @@ Server APIs | ``websockets.server.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | -| ``websockets.server.WebSocketServer`` |br| | | +| ``websockets.WebSocketServer`` |br| | ``websockets.Server`` *(since 14.2)* |br| | +| ``websockets.server.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | | :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | -| ``websockets.server.WebSocketServerProtocol`` |br| | | +| ``websockets.WebSocketServerProtocol`` |br| | ``websockets.ServerConnection`` *(since 14.2)* |br| | +| ``websockets.server.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.broadcast()`` *(before 14.0)* |br| | ``websockets.broadcast()`` *(since 14.0)* |br| | diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index c278b21d4..c8df54e0b 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -10,11 +10,14 @@ # .asyncio.client "connect", "unix_connect", + "ClientConnection", # .asyncio.server "basic_auth", "broadcast", "serve", "unix_serve", + "ServerConnection", + "Server", # .client "ClientProtocol", # .datastructures @@ -44,6 +47,18 @@ "ProtocolError", "SecurityError", "WebSocketException", + # .frames + "Close", + "CloseCode", + "Frame", + "Opcode", + # .http11 + "Request", + "Response", + # .protocol + "Protocol", + "Side", + "State", # .server "ServerProtocol", # .typing @@ -58,8 +73,15 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: - from .asyncio.client import connect, unix_connect - from .asyncio.server import basic_auth, broadcast, serve, unix_serve + from .asyncio.client import ClientConnection, connect, unix_connect + from .asyncio.server import ( + Server, + ServerConnection, + basic_auth, + broadcast, + serve, + unix_serve, + ) from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( @@ -86,6 +108,9 @@ SecurityError, WebSocketException, ) + from .frames import Close, CloseCode, Frame, Opcode + from .http11 import Request, Response + from .protocol import Protocol, Side, State from .server import ServerProtocol from .typing import ( Data, @@ -103,11 +128,14 @@ # .asyncio.client "connect": ".asyncio.client", "unix_connect": ".asyncio.client", + "ClientConnection": ".asyncio.client", # .asyncio.server "basic_auth": ".asyncio.server", "broadcast": ".asyncio.server", "serve": ".asyncio.server", "unix_serve": ".asyncio.server", + "ServerConnection": ".asyncio.server", + "Server": ".asyncio.server", # .client "ClientProtocol": ".client", # .datastructures @@ -137,6 +165,18 @@ "ProtocolError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", + # .frames + "Close": ".frames", + "CloseCode": ".frames", + "Frame": ".frames", + "Opcode": ".frames", + # .http11 + "Request": ".http11", + "Response": ".http11", + # .protocol + "Protocol": ".protocol", + "Side": ".protocol", + "State": ".protocol", # .server "ServerProtocol": ".server", # .typing diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 77b6f86fa..3c5dcbe9a 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -4,7 +4,11 @@ from typing import Any, Protocol, Union -__all__ = ["Headers", "HeadersLike", "MultipleValuesError"] +__all__ = [ + "Headers", + "HeadersLike", + "MultipleValuesError", +] class MultipleValuesError(LookupError): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 7898c8a5d..ab0869d01 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -28,6 +28,7 @@ "OP_PONG", "DATA_OPCODES", "CTRL_OPCODES", + "CloseCode", "Frame", "Close", ] diff --git a/src/websockets/http11.py b/src/websockets/http11.py index af542c77b..949a320f7 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -13,7 +13,12 @@ from .version import version as websockets_version -__all__ = ["SERVER", "USER_AGENT", "Request", "Response"] +__all__ = [ + "SERVER", + "USER_AGENT", + "Request", + "Response", +] PYTHON_VERSION = "{}.{}".format(*sys.version_info) diff --git a/tests/test_exports.py b/tests/test_exports.py index 93b0684f7..88e27e69d 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -11,24 +11,22 @@ import websockets.uri -combined_exports = ( - [] - + websockets.asyncio.client.__all__ - + websockets.asyncio.server.__all__ - + websockets.client.__all__ - + websockets.datastructures.__all__ - + websockets.exceptions.__all__ - + websockets.server.__all__ - + websockets.typing.__all__ -) - -# These API are intentionally not re-exported by the top-level module. -missing_reexports = [ - # websockets.asyncio.client - "ClientConnection", - # websockets.asyncio.server - "ServerConnection", - "Server", +combined_exports = [ + name + for name in ( + [] + + websockets.asyncio.client.__all__ + + websockets.asyncio.server.__all__ + + websockets.client.__all__ + + websockets.datastructures.__all__ + + websockets.exceptions.__all__ + + websockets.frames.__all__ + + websockets.http11.__all__ + + websockets.protocol.__all__ + + websockets.server.__all__ + + websockets.typing.__all__ + ) + if not name.isupper() # filter out constants ] @@ -36,7 +34,7 @@ class ExportsTests(unittest.TestCase): def test_top_level_module_reexports_submodule_exports(self): self.assertEqual( set(combined_exports), - set(websockets.__all__ + missing_reexports), + set(websockets.__all__), ) def test_submodule_exports_are_globally_unique(self): From e6d0ea1d6b13a979924329d02fb82f79d82c7236 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 4 Jan 2025 10:18:38 +0100 Subject: [PATCH 1463/1539] Read chunked HTTP responses. Fix #1550. --- src/websockets/client.py | 2 +- src/websockets/http11.py | 145 ++++++++++++++++++++++++--------------- src/websockets/server.py | 2 +- tests/test_client.py | 2 +- tests/test_http11.py | 93 ++++++++++++++++++++----- tests/test_server.py | 2 +- 6 files changed, 167 insertions(+), 79 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 5ced05c2a..37e2a8b3a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -333,7 +333,7 @@ def parse(self) -> Generator[None]: self.logger.debug("< HTTP/1.1 %d %s", code, phrase) for key, value in response.headers.raw_items(): self.logger.debug("< %s: %s", key, value) - if response.body is not None: + if response.body: self.logger.debug("< [body] (%d bytes)", len(response.body)) try: diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 949a320f7..396a43f07 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -185,14 +185,14 @@ class Response: status_code: Response code. reason_phrase: Response reason. headers: Response headers. - body: Response body, if any. + body: Response body. """ status_code: int reason_phrase: str headers: Headers - body: bytes | None = None + body: bytes = b"" _exception: Exception | None = None @@ -266,36 +266,9 @@ def parse( headers = yield from parse_headers(read_line) - # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 - - if "Transfer-Encoding" in headers: - raise NotImplementedError("transfer codings aren't supported") - - # Since websockets only does GET requests (no HEAD, no CONNECT), all - # responses except 1xx, 204, and 304 include a message body. - if 100 <= status_code < 200 or status_code == 204 or status_code == 304: - body = None - else: - content_length: int | None - try: - # MultipleValuesError is sufficiently unlikely that we don't - # attempt to handle it. Instead we document that its parent - # class, LookupError, may be raised. - raw_content_length = headers["Content-Length"] - except KeyError: - content_length = None - else: - content_length = int(raw_content_length) - - if content_length is None: - try: - body = yield from read_to_eof(MAX_BODY_SIZE) - except RuntimeError: - raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") - elif content_length > MAX_BODY_SIZE: - raise SecurityError(f"body too large: {content_length} bytes") - else: - body = yield from read_exact(content_length) + body = yield from read_body( + status_code, headers, read_line, read_exact, read_to_eof + ) return cls(status_code, reason, headers, body) @@ -308,11 +281,37 @@ def serialize(self) -> bytes: # we can keep this simple. response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode() response += self.headers.serialize() - if self.body is not None: - response += self.body + response += self.body return response +def parse_line( + read_line: Callable[[int], Generator[None, None, bytes]], +) -> Generator[None, None, bytes]: + """ + Parse a single line. + + CRLF is stripped from the return value. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: If the connection is closed without a CRLF. + SecurityError: If the response exceeds a security limit. + + """ + try: + line = yield from read_line(MAX_LINE_LENGTH) + except RuntimeError: + raise SecurityError("line too long") + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 + if not line.endswith(b"\r\n"): + raise EOFError("line without CRLF") + return line[:-2] + + def parse_headers( read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, Headers]: @@ -364,28 +363,62 @@ def parse_headers( return headers -def parse_line( +def read_body( + status_code: int, + headers: Headers, read_line: Callable[[int], Generator[None, None, bytes]], + read_exact: Callable[[int], Generator[None, None, bytes]], + read_to_eof: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, bytes]: - """ - Parse a single line. - - CRLF is stripped from the return value. - - Args: - read_line: Generator-based coroutine that reads a LF-terminated line - or raises an exception if there isn't enough data. - - Raises: - EOFError: If the connection is closed without a CRLF. - SecurityError: If the response exceeds a security limit. + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 + + # Since websockets only does GET requests (no HEAD, no CONNECT), all + # responses except 1xx, 204, and 304 include a message body. + if 100 <= status_code < 200 or status_code == 204 or status_code == 304: + return b"" + + # MultipleValuesError is sufficiently unlikely that we don't attempt to + # handle it when accessing headers. Instead we document that its parent + # class, LookupError, may be raised. + # Conversions from str to int are protected by sys.set_int_max_str_digits.. + + elif (coding := headers.get("Transfer-Encoding")) is not None: + if coding != "chunked": + raise NotImplementedError(f"transfer coding {coding} isn't supported") + + body = b"" + while True: + chunk_size_line = yield from parse_line(read_line) + raw_chunk_size = chunk_size_line.split(b";", 1)[0] + # Set a lower limit than default_max_str_digits; 1 EB is plenty. + if len(raw_chunk_size) > 15: + str_chunk_size = raw_chunk_size.decode(errors="backslashreplace") + raise SecurityError(f"chunk too large: 0x{str_chunk_size} bytes") + chunk_size = int(raw_chunk_size, 16) + if chunk_size == 0: + break + if len(body) + chunk_size > MAX_BODY_SIZE: + raise SecurityError( + f"chunk too large: {chunk_size} bytes after {len(body)} bytes" + ) + body += yield from read_exact(chunk_size) + if (yield from read_exact(2)) != b"\r\n": + raise ValueError("chunk without CRLF") + # Read the trailer. + yield from parse_headers(read_line) + return body + + elif (raw_content_length := headers.get("Content-Length")) is not None: + # Set a lower limit than default_max_str_digits; 1 EiB is plenty. + if len(raw_content_length) > 18: + raise SecurityError(f"body too large: {raw_content_length} bytes") + content_length = int(raw_content_length) + if content_length > MAX_BODY_SIZE: + raise SecurityError(f"body too large: {content_length} bytes") + return (yield from read_exact(content_length)) - """ - try: - line = yield from read_line(MAX_LINE_LENGTH) - except RuntimeError: - raise SecurityError("line too long") - # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 - if not line.endswith(b"\r\n"): - raise EOFError("line without CRLF") - return line[:-2] + else: + try: + return (yield from read_to_eof(MAX_BODY_SIZE)) + except RuntimeError: + raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") diff --git a/src/websockets/server.py b/src/websockets/server.py index 1b663a137..fe3c65a7d 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -525,7 +525,7 @@ def send_response(self, response: Response) -> None: self.logger.debug("> HTTP/1.1 %d %s", code, phrase) for key, value in response.headers.raw_items(): self.logger.debug("> %s: %s", key, value) - if response.body is not None: + if response.body: self.logger.debug("> [body] (%d bytes)", len(response.body)) self.writes.append(response.serialize()) diff --git a/tests/test_client.py b/tests/test_client.py index 1edbae57d..9f3ab09b2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -204,7 +204,7 @@ def test_receive_successful_response(self, _generate_key): } ), ) - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") self.assertIsNone(client.handshake_exc) def test_receive_failed_response(self, _generate_key): diff --git a/tests/test_http11.py b/tests/test_http11.py index 1fbcb3ba4..bb0d27b95 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -87,7 +87,7 @@ def test_parse_body(self): ) def test_parse_body_with_transfer_encoding(self): - self.reader.feed_data(b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + self.reader.feed_data(b"GET / HTTP/1.1\r\nTransfer-Encoding: compress\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: next(self.parse()) self.assertEqual( @@ -151,7 +151,7 @@ def test_parse(self): self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual(response.headers["Upgrade"], "websocket") - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") def test_parse_empty(self): self.reader.feed_eof() @@ -215,14 +215,7 @@ def test_parse_invalid_header(self): "invalid HTTP header line: Oops", ) - def test_parse_body_with_content_length(self): - self.reader.feed_data( - b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello world!\n" - ) - response = self.assertGeneratorReturns(self.parse()) - self.assertEqual(response.body, b"Hello world!\n") - - def test_parse_body_without_content_length(self): + def test_parse_body(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\nHello world!\n") gen = self.parse() self.assertGeneratorRunning(gen) @@ -230,7 +223,23 @@ def test_parse_body_without_content_length(self): response = self.assertGeneratorReturns(gen) self.assertEqual(response.body, b"Hello world!\n") - def test_parse_body_with_content_length_too_long(self): + def test_parse_body_too_large(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n" + b"a" * 1048577) + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "body too large: over 1048576 bytes", + ) + + def test_parse_body_with_content_length(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello world!\n" + ) + response = self.assertGeneratorReturns(self.parse()) + self.assertEqual(response.body, b"Hello world!\n") + + def test_parse_body_with_content_length_and_body_too_large(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\nContent-Length: 1048577\r\n\r\n") with self.assertRaises(SecurityError) as raised: next(self.parse()) @@ -239,33 +248,79 @@ def test_parse_body_with_content_length_too_long(self): "body too large: 1048577 bytes", ) - def test_parse_body_without_content_length_too_long(self): - self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n" + b"a" * 1048577) + def test_parse_body_with_content_length_and_body_way_too_large(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nContent-Length: 1234567890123456789\r\n\r\n" + ) with self.assertRaises(SecurityError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "body too large: over 1048576 bytes", + "body too large: 1234567890123456789 bytes", ) - def test_parse_body_with_transfer_encoding(self): - self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n") + def test_parse_body_with_chunked_transfer_encoding(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + b"6\r\nHello \r\n7\r\nworld!\n\r\n0\r\n\r\n" + ) + response = self.assertGeneratorReturns(self.parse()) + self.assertEqual(response.body, b"Hello world!\n") + + def test_parse_body_with_chunked_transfer_encoding_and_chunk_without_crlf(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + b"6\r\nHello 7\r\nworld!\n0\r\n" + ) + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "chunk without CRLF", + ) + + def test_parse_body_with_chunked_transfer_encoding_and_chunk_too_large(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + b"100000\r\n" + b"a" * 1048576 + b"\r\n1\r\na\r\n0\r\n\r\n" + ) + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "chunk too large: 1 bytes after 1048576 bytes", + ) + + def test_parse_body_with_chunked_transfer_encoding_and_chunk_way_too_large(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + b"1234567890ABCDEF\r\n\r\n" + ) + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "chunk too large: 0x1234567890ABCDEF bytes", + ) + + def test_parse_body_with_unsupported_transfer_encoding(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: compress\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "transfer codings aren't supported", + "transfer coding compress isn't supported", ) def test_parse_body_no_content(self): self.reader.feed_data(b"HTTP/1.1 204 No Content\r\n\r\n") response = self.assertGeneratorReturns(self.parse()) - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") def test_parse_body_not_modified(self): self.reader.feed_data(b"HTTP/1.1 304 Not Modified\r\n\r\n") response = self.assertGeneratorReturns(self.parse()) - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") def test_serialize(self): # Example from the protocol overview in RFC 6455 diff --git a/tests/test_server.py b/tests/test_server.py index 5efeca2d0..69f555689 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -281,7 +281,7 @@ def test_accept_response(self, _formatdate): } ), ) - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") @patch("email.utils.formatdate", return_value=DATE) def test_reject_response(self, _formatdate): From 916f841815070a0374434b13b9c91829a7e7a522 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 10 Jan 2025 22:32:44 +0100 Subject: [PATCH 1464/1539] Upgrade ruff. --- src/websockets/exceptions.py | 3 +-- src/websockets/http11.py | 3 +-- src/websockets/legacy/exceptions.py | 4 +--- src/websockets/sync/connection.py | 6 ++---- tests/asyncio/test_server.py | 3 +-- tests/sync/test_server.py | 3 +-- tests/test_client.py | 3 +-- 7 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 81fbb1efd..73b24debf 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -204,8 +204,7 @@ def __init__(self, response: http11.Response) -> None: def __str__(self) -> str: return ( - "server rejected WebSocket connection: " - f"HTTP {self.response.status_code:d}" + f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" ) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 396a43f07..49d7b9a41 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -94,8 +94,7 @@ class Request: @property def exception(self) -> Exception | None: # pragma: no cover warnings.warn( # deprecated in 10.3 - 2022-04-17 - "Request.exception is deprecated; " - "use ServerProtocol.handshake_exc instead", + "Request.exception is deprecated; use ServerProtocol.handshake_exc instead", DeprecationWarning, ) return self._exception diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 78fb696fa..29a2525b4 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -52,9 +52,7 @@ def __init__( def __str__(self) -> str: return ( - f"HTTP {self.status:d}, " - f"{len(self.headers)} headers, " - f"{len(self.body)} bytes" + f"HTTP {self.status:d}, {len(self.headers)} headers, {len(self.body)} bytes" ) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index d8dbf140e..60be245b4 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -412,8 +412,7 @@ def send( with self.send_context(): if self.send_in_progress: raise ConcurrencyError( - "cannot call send while another thread " - "is already running send" + "cannot call send while another thread is already running send" ) if text is False: self.protocol.send_binary(message.encode()) @@ -424,8 +423,7 @@ def send( with self.send_context(): if self.send_in_progress: raise ConcurrencyError( - "cannot call send while another thread " - "is already running send" + "cannot call send while another thread is already running send" ) if text is True: self.protocol.send_text(message) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 83885fab5..edad52da5 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -59,8 +59,7 @@ async def test_connection_handler_raises_exception(self): await client.recv() self.assertEqual( str(raised.exception), - "received 1011 (internal error); " - "then sent 1011 (internal error)", + "received 1011 (internal error); then sent 1011 (internal error)", ) async def test_existing_socket(self): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 9a2676437..bb2ebae14 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -58,8 +58,7 @@ def test_connection_handler_raises_exception(self): client.recv() self.assertEqual( str(raised.exception), - "received 1011 (internal error); " - "then sent 1011 (internal error)", + "received 1011 (internal error); then sent 1011 (internal error)", ) def test_existing_socket(self): diff --git a/tests/test_client.py b/tests/test_client.py index 9f3ab09b2..fc9f2ec9a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -638,8 +638,7 @@ def test_multiple_subprotocols_accepted(self): self.assertHandshakeError( client, InvalidHandshake, - "invalid Sec-WebSocket-Protocol header: " - "multiple values: superchat, chat", + "invalid Sec-WebSocket-Protocol header: multiple values: superchat, chat", ) From 668e56cf5d405d5f418d97d183e76fc36d21f857 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 10 Jan 2025 22:26:15 +0100 Subject: [PATCH 1465/1539] Fix Connection.recv(timeout=0) in the sync implementation. Fix #1552. --- docs/project/changelog.rst | 4 ++++ src/websockets/sync/messages.py | 13 ++++++++++--- tests/sync/test_messages.py | 6 ++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b7f4f62f9..de6b0be21 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,10 @@ notice. Bug fixes ......... +* Fixed ``connection.recv(timeout=0)`` in the :mod:`threading` implementation. + If a message is already received, it is returned. Previously, + :exc:`TimeoutError` was raised incorrectly. + * Wrapped errors when reading the opening handshake request or response in :exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect` raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 98490797f..12e8b1623 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -79,7 +79,12 @@ def get_next_frame(self, timeout: float | None = None) -> Frame: raise EOFError("stream of frames ended") from None else: try: - frame = self.frames.get(block=True, timeout=timeout) + # Check for a frame that's already received if timeout <= 0. + # SimpleQueue.get() doesn't support negative timeout values. + if timeout is not None and timeout <= 0: + frame = self.frames.get(block=False) + else: + frame = self.frames.get(block=True, timeout=timeout) except queue.Empty: raise TimeoutError(f"timed out in {timeout:.1f}s") from None if frame is None: @@ -143,7 +148,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: deadline = Deadline(timeout) # First frame - frame = self.get_next_frame(deadline.timeout()) + frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False)) with self.mutex: self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY @@ -154,7 +159,9 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: # Following frames, for fragmented messages while not frame.fin: try: - frame = self.get_next_frame(deadline.timeout()) + frame = self.get_next_frame( + deadline.timeout(raise_if_elapsed=False) + ) except TimeoutError: # Put frames already received back into the queue # so that future calls to get() can return them. diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index d22693102..0a94b4f85 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -198,6 +198,12 @@ def test_get_timeout_after_first_frame(self): message = self.assembler.get() self.assertEqual(message, "café") + def test_get_if_received(self): + """get returns a text message if it's already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get(timeout=0) + self.assertEqual(message, "café") + # Test get_iter def test_get_iter_text_message_already_received(self): From b1e88fcb77f6b74b2eab6a11a4ee53e4f04937ac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 30 Nov 2024 06:42:21 +0000 Subject: [PATCH 1466/1539] Bump pypa/cibuildwheel from 2.21.3 to 2.22.0 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.21.3 to 2.22.0. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.21.3...v2.22.0) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index aa1e2e7a9..0ff07c9df 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.21.3 + uses: pypa/cibuildwheel@v2.22.0 env: BUILD_EXTENSION: yes - name: Save wheels From 6317c00cc5af245116781ddfde518ec004de672e Mon Sep 17 00:00:00 2001 From: = Date: Fri, 10 Jan 2025 14:47:36 -0700 Subject: [PATCH 1467/1539] Clarify behavior of `recv(timeout=0)` behavior. Refs #1552. --- src/websockets/sync/connection.py | 7 ++++--- tests/sync/test_messages.py | 24 ++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 60be245b4..e78073242 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -232,9 +232,10 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data message stream. If ``timeout`` is :obj:`None`, block until a message is received. If - ``timeout`` is set and no message is received within ``timeout`` - seconds, raise :exc:`TimeoutError`. Set ``timeout`` to ``0`` to check if - a message was already received. + ``timeout`` is set, wait up to ``timeout`` seconds for a message to be + received and return it, else raise :exc:`TimeoutError`. If ``timeout`` + is ``0`` or negative, check if a message has been received already and + return it, else raise :exc:`TimeoutError`. If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 0a94b4f85..e5510af3e 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -198,12 +198,32 @@ def test_get_timeout_after_first_frame(self): message = self.assembler.get() self.assertEqual(message, "café") - def test_get_if_received(self): - """get returns a text message if it's already received.""" + def test_get_timeout_0_message_already_received(self): + """get(timeout=0) returns a message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = self.assembler.get(timeout=0) self.assertEqual(message, "café") + def test_get_timeout_0_message_not_received_yet(self): + """get(timeout=0) times out when no message is already received.""" + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=0) + + def test_get_timeout_0_fragmented_message_already_received(self): + """get(timeout=0) returns a fragmented message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = self.assembler.get(timeout=0) + self.assertEqual(message, "café") + + def test_get_timeout_0_fragmented_message_partially_received(self): + """get(timeout=0) times out when a fragmented message is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=0) + # Test get_iter def test_get_iter_text_message_already_received(self): From 031ec31b70adc527836c5565a7809724fb888c9c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 11 Jan 2025 10:11:59 +0100 Subject: [PATCH 1468/1539] Prevent close() from blocking when reading is paused. Fix #1555. --- docs/project/changelog.rst | 11 +++++++---- src/websockets/asyncio/messages.py | 2 +- src/websockets/sync/messages.py | 7 ++++++- tests/sync/test_messages.py | 12 ++++++++++++ 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index de6b0be21..83a8e6962 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,15 +35,18 @@ notice. Bug fixes ......... -* Fixed ``connection.recv(timeout=0)`` in the :mod:`threading` implementation. - If a message is already received, it is returned. Previously, - :exc:`TimeoutError` was raised incorrectly. - * Wrapped errors when reading the opening handshake request or response in :exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect` raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening handshake fails. +* Fixed :meth:`~sync.connection.Connection.recv` with ``timeout=0`` in the + :mod:`threading` implementation. If a message is already received, it is + returned. Previously, :exc:`TimeoutError` was raised incorrectly. + +* Prevented :meth:`~sync.connection.Connection.close` from blocking when + receive buffers are saturated in the :mod:`threading` implementation. + .. _14.1: 14.1 diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e6d1d31cc..c10072467 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -283,7 +283,7 @@ def close(self) -> None: """ End the stream of frames. - Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 12e8b1623..dfabedd65 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -298,7 +298,7 @@ def close(self) -> None: """ End the stream of frames. - Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ @@ -311,3 +311,8 @@ def close(self) -> None: if self.get_in_progress: # Unblock get() or get_iter(). self.frames.put(None) + + if self.paused: + # Unblock recv_events(). + self.paused = False + self.resume() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index e5510af3e..e42784094 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -496,6 +496,18 @@ def test_put_fails_after_close(self): with self.assertRaises(EOFError): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + def test_close_resumes_reading(self): + """close unblocks reading when queue is above the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is at the high-water mark + assert self.assembler.paused + + self.assembler.close() + self.resume.assert_called_once_with() + def test_close_is_idempotent(self): """close can be called multiple times safely.""" self.assembler.close() From e7a098e1a0d5dffcac5f0600703c4ec2de0be48a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 12 Jan 2025 09:00:51 +0100 Subject: [PATCH 1469/1539] Prevent AssertionError in the recv_events thread. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit close_socket() was interacting with the protocol, namely calling protocol.receive_of(), without locking the mutex. This created the possibility of a race condition. If two threads called receive_eof() concurrently, the second one could return before the first one finished running it. This led to receive_eof() returning (in the second thread) before the connection state was CLOSED, breaking an invariant. This race condition could be triggered reliably by shutting down the network (e.g., turning wifi off), closing the connection, and waiting for the timeout. Then, close() calls close_socket() — this happens in the `raise_close_exc` branch of send_context(). This unblocks the read in recv_events() which calls close_socket() in the `finally:` branch. Fix #1558. --- src/websockets/sync/connection.py | 5 +++-- tests/sync/test_client.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index e78073242..06ea00efc 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -923,8 +923,9 @@ def close_socket(self) -> None: # Calling protocol.receive_eof() is safe because it's idempotent. # This guarantees that the protocol state becomes CLOSED. - self.protocol.receive_eof() - assert self.protocol.state is CLOSED + with self.protocol_mutex: + self.protocol.receive_eof() + assert self.protocol.state is CLOSED # Abort recv() with a ConnectionClosed exception. self.recv_messages.close() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 7d8170519..7ab8f4dd4 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -151,7 +151,8 @@ def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" def close_connection(self, request): - self.close_socket() + self.socket.shutdown(socket.SHUT_RDWR) + self.socket.close() with run_server(process_request=close_connection) as server: with self.assertRaises(InvalidMessage) as raised: From 613f3f0ef83a0c80ae49a42766fb634295216c5c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 18:58:20 +0100 Subject: [PATCH 1470/1539] Prevent close() from blocking when reading is paused. Closing the transport normally is achieved with transport.write_eof(). Closing it in abnormal situations relied on transport.close(). However, that didn't lead to connection_lost() when reading is paused. Replacing it with transport.abort() ensures that buffers are dropped (which is what we want in abnormal situations) and connection_lost() called quickly. Fix #1555 (for real!) --- docs/project/changelog.rst | 5 +++-- src/websockets/asyncio/client.py | 4 ++-- src/websockets/asyncio/connection.py | 10 +--------- src/websockets/asyncio/server.py | 6 +++--- tests/asyncio/test_client.py | 2 +- 5 files changed, 10 insertions(+), 17 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 83a8e6962..5a2e1bc24 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -44,8 +44,9 @@ Bug fixes :mod:`threading` implementation. If a message is already received, it is returned. Previously, :exc:`TimeoutError` was raised incorrectly. -* Prevented :meth:`~sync.connection.Connection.close` from blocking when - receive buffers are saturated in the :mod:`threading` implementation. +* Prevented :meth:`~asyncio.connection.Connection.close` from blocking when + receive buffers are saturated in the :mod:`asyncio` and :mod:`threading` + implementations. .. _14.1: diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 8581c0551..f05f546d3 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -445,7 +445,7 @@ async def __await_impl__(self) -> ClientConnection: try: await self.connection.handshake(*self.handshake_args) except asyncio.CancelledError: - self.connection.close_transport() + self.connection.transport.abort() raise except Exception as exc: # Always close the connection even though keep-alive is @@ -454,7 +454,7 @@ async def __await_impl__(self) -> ClientConnection: # protocol. In the current design of connect(), there is # no easy way to reuse the network connection that works # in every case nor to reinitialize the protocol. - self.connection.close_transport() + self.connection.transport.abort() uri_or_exc = self.process_redirect(exc) # Response is a valid redirect; follow it. diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index c34d19d58..e2e587e7c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -928,7 +928,7 @@ async def send_context( # If an error occurred, close the transport to terminate the connection and # raise an exception. if raise_close_exc: - self.close_transport() + self.transport.abort() # Wait for the protocol state to be CLOSED before accessing close_exc. await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from original_exc @@ -969,14 +969,6 @@ def set_recv_exc(self, exc: BaseException | None) -> None: if self.recv_exc is None: self.recv_exc = exc - def close_transport(self) -> None: - """ - Close transport and message assembler. - - """ - self.transport.close() - self.recv_messages.close() - # asyncio.Protocol methods # Connection callbacks diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index fdb928004..49d6f8910 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -356,16 +356,16 @@ async def conn_handler(self, connection: ServerConnection) -> None: self.server_header, ) except asyncio.CancelledError: - connection.close_transport() + connection.transport.abort() raise except Exception: connection.logger.error("opening handshake failed", exc_info=True) - connection.close_transport() + connection.transport.abort() return if connection.protocol.state is not OPEN: # process_request or process_response rejected the handshake. - connection.close_transport() + connection.transport.abort() return try: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 1773c08bd..4db4c038c 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -393,7 +393,7 @@ async def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" def close_connection(self, request): - self.close_transport() + self.transport.close() async with serve(*args, process_request=close_connection) as server: with self.assertRaises(InvalidMessage) as raised: From 7e617b2a57177885926b3a4a8a092621ab719a00 Mon Sep 17 00:00:00 2001 From: dan005 Date: Mon, 13 Jan 2025 23:22:20 +0500 Subject: [PATCH 1471/1539] Add regex support in `ServerProtocol(origins=...)`. --- src/websockets/asyncio/server.py | 10 +++++---- src/websockets/server.py | 21 +++++++++++++----- src/websockets/sync/server.py | 10 +++++---- tests/test_server.py | 37 ++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 13 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 49d6f8910..080ea3f16 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -4,6 +4,7 @@ import hmac import http import logging +import re import socket import sys from collections.abc import Awaitable, Generator, Iterable, Sequence @@ -599,9 +600,10 @@ def handler(websocket): See :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on. See :meth:`~asyncio.loop.create_server` for details. - origins: Acceptable values of the ``Origin`` header, for defending - against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` - in the list if the lack of an origin is acceptable. + origins: Acceptable values of the ``Origin`` header, including regular + expressions, for defending against Cross-Site WebSocket Hijacking + attacks. Include :obj:`None` in the list if the lack of an origin + is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing @@ -681,7 +683,7 @@ def __init__( port: int | None = None, *, # WebSocket - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( diff --git a/src/websockets/server.py b/src/websockets/server.py index fe3c65a7d..67082ed72 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,6 +4,7 @@ import binascii import email.utils import http +import re import warnings from collections.abc import Generator, Sequence from typing import Any, Callable, cast @@ -49,9 +50,9 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: Acceptable values of the ``Origin`` header; include - :obj:`None` in the list if the lack of an origin is acceptable. - This is useful for defending against Cross-Site WebSocket + origins: Acceptable values of the ``Origin`` header, including regular + expressions; include :obj:`None` in the list if the lack of an origin + is acceptable. This is useful for defending against Cross-Site WebSocket Hijacking attacks. extensions: List of supported extensions, in order in which they should be tried. @@ -73,7 +74,7 @@ class ServerProtocol(Protocol): def __init__( self, *, - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( @@ -309,7 +310,17 @@ def process_origin(self, headers: Headers) -> Origin | None: if origin is not None: origin = cast(Origin, origin) if self.origins is not None: - if origin not in self.origins: + valid = False + for acceptable_origin_or_regex in self.origins: + if isinstance(acceptable_origin_or_regex, re.Pattern): + # `str(origin)` is needed for compatibility + # between `Pattern.match(string=...)` and `origin`. + valid = acceptable_origin_or_regex.match(str(origin)) is not None + else: + valid = acceptable_origin_or_regex == origin + if valid: + break + if not valid: raise InvalidOrigin(origin) return origin diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 9506d6830..c14e558ac 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -4,6 +4,7 @@ import http import logging import os +import re import selectors import socket import ssl as ssl_module @@ -325,7 +326,7 @@ def serve( sock: socket.socket | None = None, ssl: ssl_module.SSLContext | None = None, # WebSocket - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( @@ -399,9 +400,10 @@ def handler(websocket): You may call :func:`socket.create_server` to create a suitable TCP socket. ssl: Configuration for enabling TLS on the connection. - origins: Acceptable values of the ``Origin`` header, for defending - against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` - in the list if the lack of an origin is acceptable. + origins: Acceptable values of the ``Origin`` header, including regular + expressions, for defending against Cross-Site WebSocket Hijacking + attacks. Include :obj:`None` in the list if the lack of an origin + is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/tests/test_server.py b/tests/test_server.py index 69f555689..dd5e0d09a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,5 +1,6 @@ import http import logging +import re import sys import unittest from unittest.mock import patch @@ -623,6 +624,42 @@ def test_unsupported_origin(self): "invalid Origin header: https://original.example.com", ) + def test_supported_origin_by_regex(self): + """ + Handshake succeeds when checking origins and the origin is supported + by a regular expression. + """ + server = ServerProtocol( + origins=["https://example.com", re.compile(r"https://other.*")] + ) + request = make_request() + request.headers["Origin"] = "https://other.example.com" + response = server.accept(request) + server.send_response(response) + + self.assertHandshakeSuccess(server) + self.assertEqual(server.origin, "https://other.example.com") + + def test_unsupported_origin_by_regex(self): + """ + Handshake succeeds when checking origins and the origin is unsupported + by a regular expression. + """ + server = ServerProtocol( + origins=["https://example.com", re.compile(r"https://other.*")] + ) + request = make_request() + request.headers["Origin"] = "https://original.example.com" + response = server.accept(request) + server.send_response(response) + + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "invalid Origin header: https://original.example.com", + ) + def test_no_origin_accepted(self): """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None]) From 7de24bd087e6e51901b6fd5d28bdb6a32404de69 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:19:10 +0100 Subject: [PATCH 1472/1539] Improve previous commit. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Require fullmatch instead of match — this avoids a vulnerability. * Shorten code and tweak to match my preferred style. * Add changelog. --- docs/project/changelog.rst | 6 ++++++ src/websockets/asyncio/server.py | 9 ++++---- src/websockets/server.py | 25 +++++++++++---------- src/websockets/sync/server.py | 9 ++++---- tests/test_server.py | 37 +++++++++++++++++++++----------- 5 files changed, 52 insertions(+), 34 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5a2e1bc24..74fac904f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,12 @@ notice. *In development* +New features +............ + +* Added support for regular expressions in the ``origins`` argument of + :func:`~asyncio.server.serve`. + Bug fixes ......... diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 080ea3f16..ebe45c2a9 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -600,10 +600,11 @@ def handler(websocket): See :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on. See :meth:`~asyncio.loop.create_server` for details. - origins: Acceptable values of the ``Origin`` header, including regular - expressions, for defending against Cross-Site WebSocket Hijacking - attacks. Include :obj:`None` in the list if the lack of an origin - is acceptable. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/src/websockets/server.py b/src/websockets/server.py index 67082ed72..90e6c9921 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -50,9 +50,11 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: Acceptable values of the ``Origin`` header, including regular - expressions; include :obj:`None` in the list if the lack of an origin - is acceptable. This is useful for defending against Cross-Site WebSocket + origins: Acceptable values of the ``Origin`` header. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket Hijacking attacks. extensions: List of supported extensions, in order in which they should be tried. @@ -310,17 +312,14 @@ def process_origin(self, headers: Headers) -> Origin | None: if origin is not None: origin = cast(Origin, origin) if self.origins is not None: - valid = False - for acceptable_origin_or_regex in self.origins: - if isinstance(acceptable_origin_or_regex, re.Pattern): - # `str(origin)` is needed for compatibility - # between `Pattern.match(string=...)` and `origin`. - valid = acceptable_origin_or_regex.match(str(origin)) is not None - else: - valid = acceptable_origin_or_regex == origin - if valid: + for origin_or_regex in self.origins: + if origin_or_regex == origin or ( + isinstance(origin_or_regex, re.Pattern) + and origin is not None + and origin_or_regex.fullmatch(origin) is not None + ): break - if not valid: + else: raise InvalidOrigin(origin) return origin diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index c14e558ac..50a2f3c06 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -400,10 +400,11 @@ def handler(websocket): You may call :func:`socket.create_server` to create a suitable TCP socket. ssl: Configuration for enabling TLS on the connection. - origins: Acceptable values of the ``Origin`` header, including regular - expressions, for defending against Cross-Site WebSocket Hijacking - attacks. Include :obj:`None` in the list if the lack of an origin - is acceptable. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/tests/test_server.py b/tests/test_server.py index dd5e0d09a..9f328ded5 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -608,7 +608,7 @@ def test_supported_origin(self): self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): - """Handshake succeeds when checking origins and the origin is unsupported.""" + """Handshake fails when checking origins and the origin is unsupported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) @@ -624,13 +624,10 @@ def test_unsupported_origin(self): "invalid Origin header: https://original.example.com", ) - def test_supported_origin_by_regex(self): - """ - Handshake succeeds when checking origins and the origin is supported - by a regular expression. - """ + def test_supported_origin_regex(self): + """Handshake succeeds when checking origins and the origin is supported.""" server = ServerProtocol( - origins=["https://example.com", re.compile(r"https://other.*")] + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://other.example.com" @@ -640,13 +637,10 @@ def test_supported_origin_by_regex(self): self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://other.example.com") - def test_unsupported_origin_by_regex(self): - """ - Handshake succeeds when checking origins and the origin is unsupported - by a regular expression. - """ + def test_unsupported_origin_regex(self): + """Handshake fails when checking origins and the origin is unsupported.""" server = ServerProtocol( - origins=["https://example.com", re.compile(r"https://other.*")] + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://original.example.com" @@ -660,6 +654,23 @@ def test_unsupported_origin_by_regex(self): "invalid Origin header: https://original.example.com", ) + def test_partial_match_origin_regex(self): + """Handshake fails when checking origins and the origin a partial match.""" + server = ServerProtocol( + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] + ) + request = make_request() + request.headers["Origin"] = "https://other.example.com.hacked" + response = server.accept(request) + server.send_response(response) + + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "invalid Origin header: https://other.example.com.hacked", + ) + def test_no_origin_accepted(self): """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None]) From 17e309a830cabfe5e335bf96cb3795c45f053fbc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:27:28 +0100 Subject: [PATCH 1473/1539] Mention another symptom in the changelog. Refs #1527. --- docs/project/changelog.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 74fac904f..5f616d0e8 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -51,8 +51,8 @@ Bug fixes returned. Previously, :exc:`TimeoutError` was raised incorrectly. * Prevented :meth:`~asyncio.connection.Connection.close` from blocking when - receive buffers are saturated in the :mod:`asyncio` and :mod:`threading` - implementations. + the network becomes unavailable or when receive buffers are saturated in + the :mod:`asyncio` and :mod:`threading` implementations. .. _14.1: From c8242bbb3ab7e8054a2cbae2bb88cb649bf730f4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:38:01 +0100 Subject: [PATCH 1474/1539] Add changelog for #1566. --- docs/project/changelog.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5f616d0e8..73b28b7b4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -50,6 +50,9 @@ Bug fixes :mod:`threading` implementation. If a message is already received, it is returned. Previously, :exc:`TimeoutError` was raised incorrectly. +* Fixed a crash in the :mod:`asyncio` implementation when cancelling a ping + then receiving the corresponding pong. + * Prevented :meth:`~asyncio.connection.Connection.close` from blocking when the network becomes unavailable or when receive buffers are saturated in the :mod:`asyncio` and :mod:`threading` implementations. From 624a36cc9c1ea1369971387d7bb23533b2797350 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:42:16 +0100 Subject: [PATCH 1475/1539] Release version 14.2. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 73b28b7b4..bfb356c62 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 14.2 ---- -*In development* +*January 19, 2025* New features ............ diff --git a/src/websockets/version.py b/src/websockets/version.py index b0df10f0a..4b11b6fe9 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "14.2" From 4af270d035bccd0bd282727c0abfdb7ed2cc0205 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:57:27 +0100 Subject: [PATCH 1476/1539] Start version 14.3. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index bfb356c62..4ad8f5532 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.3: + +14.3 +---- + +*In development* + .. _14.2: 14.2 diff --git a/src/websockets/version.py b/src/websockets/version.py index 4b11b6fe9..ca9a9115b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "14.2" +tag = version = commit = "14.3" if not released: # pragma: no cover From 328a20cbb049fe738f77d13f274c29821646b1de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:57:36 +0100 Subject: [PATCH 1477/1539] Prepare upgrade to cibuildwheel 3. --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4e26c757e..d3128f8ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,9 @@ Documentation = "https://websockets.readthedocs.io/" Funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" Tracker = "https://github.com/python-websockets/websockets/issues" +[tool.cibuildwheel] +enable = ["pypy"] + # On a macOS runner, build Intel, Universal, and Apple Silicon wheels. [tool.cibuildwheel.macos] archs = ["x86_64", "universal2", "arm64"] From 7a8b757c56777509bf183bf973407de943d0d973 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jan 2025 19:50:41 +0100 Subject: [PATCH 1478/1539] Rename and reorder for consistency. --- src/websockets/asyncio/connection.py | 16 ++++++++-------- src/websockets/sync/connection.py | 24 ++++++++++++------------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index e2e587e7c..91bc0dda5 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -101,6 +101,14 @@ def __init__( # Protect sending fragmented messages. self.fragmented_send_waiter: asyncio.Future[None] | None = None + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} @@ -120,14 +128,6 @@ def __init__( # Task that sends keepalive pings. None when ping_interval is None. self.keepalive_task: asyncio.Task[None] | None = None - # Exception raised while reading from the connection, to be chained to - # ConnectionClosed in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - - # Completed when the TCP connection is closed and the WebSocket - # connection state becomes CLOSED. - self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() - # Adapted from asyncio.FlowControlMixin self.paused: bool = False self.drain_waiters: collections.deque[asyncio.Future[None]] = ( diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 06ea00efc..653310c35 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -90,14 +90,11 @@ def __init__( resume=self.recv_flow_control.release, ) - # Whether we are busy sending a fragmented message. - self.send_in_progress = False - # Deadline for the closing handshake. self.close_deadline: Deadline | None = None - # Mapping of ping IDs to pong waiters, in chronological order. - self.ping_waiters: dict[bytes, threading.Event] = {} + # Whether we are busy sending a fragmented message. + self.send_in_progress = False # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. @@ -112,6 +109,9 @@ def __init__( ) self.recv_events_thread.start() + # Mapping of ping IDs to pong waiters, in chronological order. + self.pong_waiters: dict[bytes, threading.Event] = {} + # Public attributes @property @@ -581,15 +581,15 @@ def ping(self, data: Data | None = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.ping_waiters: + if data in self.pong_waiters: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.ping_waiters: + while data is None or data in self.pong_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.ping_waiters[data] = pong_waiter + self.pong_waiters[data] = pong_waiter self.protocol.send_ping(data) return pong_waiter @@ -641,22 +641,22 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.ping_waiters: + if data not in self.pong_waiters: return # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.ping_waiters.items(): + for ping_id, ping in self.pong_waiters.items(): ping_ids.append(ping_id) ping.set() if ping_id == data: break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.ping_waiters. + # Remove acknowledged pings from self.pong_waiters. for ping_id in ping_ids: - del self.ping_waiters[ping_id] + del self.pong_waiters[ping_id] def recv_events(self) -> None: """ From 0fac3829353905c2079e05c36958a294b2b49cf1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:35:08 +0100 Subject: [PATCH 1479/1539] Add latency measurement to the threading implementation. --- docs/project/changelog.rst | 5 +++++ docs/reference/features.rst | 3 +-- docs/reference/sync/client.rst | 2 ++ docs/reference/sync/common.rst | 2 ++ docs/reference/sync/server.rst | 2 ++ src/websockets/sync/connection.py | 23 +++++++++++++++++++---- 6 files changed, 31 insertions(+), 6 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 4ad8f5532..867231241 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,11 @@ notice. *In development* +New features +............ + +* Added latency measurement to the :mod:`threading` implementation. + .. _14.2: 14.2 diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 9187fa505..1135bf829 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -65,10 +65,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Measure latency | ✅ | ❌ | — | ✅ | + | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index 2aa491f6a..89316c997 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -39,6 +39,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 3c03b25b6..d44ff55b6 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -31,6 +31,8 @@ Both sides (:mod:`threading`) .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 1d80450f9..c3d0e8f25 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -52,6 +52,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 653310c35..b0fbf45b6 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -6,6 +6,7 @@ import socket import struct import threading +import time import uuid from collections.abc import Iterable, Iterator, Mapping from types import TracebackType @@ -110,7 +111,16 @@ def __init__( self.recv_events_thread.start() # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, threading.Event] = {} + self.pong_waiters: dict[bytes, tuple[threading.Event, float]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. + """ # Public attributes @@ -589,7 +599,7 @@ def ping(self, data: Data | None = None) -> threading.Event: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pong_waiters[data] = pong_waiter + self.pong_waiters[data] = (pong_waiter, time.monotonic()) self.protocol.send_ping(data) return pong_waiter @@ -643,17 +653,22 @@ def acknowledge_pings(self, data: bytes) -> None: # Ignore unsolicited pong. if data not in self.pong_waiters: return + + pong_timestamp = time.monotonic() + # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pong_waiters.items(): + for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) - ping.set() + pong_waiter.set() if ping_id == data: + self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") + # Remove acknowledged pings from self.pong_waiters. for ping_id in ping_ids: del self.pong_waiters[ping_id] From fc7b151fdfbc092a8d2062ef522d374074153cfe Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:32:30 +0100 Subject: [PATCH 1480/1539] Add option to set pong waiters on connection close. --- src/websockets/sync/connection.py | 38 +++++++++++++++++++++++++++---- tests/sync/test_connection.py | 9 ++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index b0fbf45b6..5270c1fab 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -111,7 +111,7 @@ def __init__( self.recv_events_thread.start() # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[threading.Event, float]] = {} + self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} self.latency: float = 0 """ @@ -554,7 +554,11 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: # They mean that the connection is closed, which was the goal. pass - def ping(self, data: Data | None = None) -> threading.Event: + def ping( + self, + data: Data | None = None, + ack_on_close: bool = False, + ) -> threading.Event: """ Send a Ping_. @@ -566,6 +570,12 @@ def ping(self, data: Data | None = None) -> threading.Event: Args: data: Payload of the ping. A :class:`str` will be encoded to UTF-8. If ``data`` is :obj:`None`, the payload is four random bytes. + ack_on_close: when this option is :obj:`True`, the event will also + be set when the connection is closed. While this avoids getting + stuck waiting for a pong that will never arrive, it requires + checking that the state of the connection is still ``OPEN`` to + confirm that a pong was received, rather than the connection + being closed. Returns: An event that will be set when the corresponding pong is received. @@ -599,7 +609,7 @@ def ping(self, data: Data | None = None) -> threading.Event: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pong_waiters[data] = (pong_waiter, time.monotonic()) + self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close) self.protocol.send_ping(data) return pong_waiter @@ -660,7 +670,11 @@ def acknowledge_pings(self, data: bytes) -> None: # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + for ping_id, ( + pong_waiter, + ping_timestamp, + _ack_on_close, + ) in self.pong_waiters.items(): ping_ids.append(ping_id) pong_waiter.set() if ping_id == data: @@ -673,6 +687,19 @@ def acknowledge_pings(self, data: bytes) -> None: for ping_id in ping_ids: del self.pong_waiters[ping_id] + def acknowledge_pending_pings(self) -> None: + """ + Acknowledge pending pings when the connection is closed. + + """ + assert self.protocol.state is CLOSED + + for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values(): + if ack_on_close: + pong_waiter.set() + + self.pong_waiters.clear() + def recv_events(self) -> None: """ Read incoming data from the socket and process events. @@ -944,3 +971,6 @@ def close_socket(self) -> None: # Abort recv() with a ConnectionClosed exception. self.recv_messages.close() + + # Acknowledge pings sent with the ack_on_close option. + self.acknowledge_pending_pings() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index aa445498c..ee4aec5f6 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -685,6 +685,15 @@ def test_acknowledge_previous_ping(self): self.remote_connection.pong("that") self.assertTrue(pong_waiter.wait(MS)) + def test_acknowledge_ping_on_close(self): + """ping with ack_on_close is acknowledged when the connection is closed.""" + with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True) + pong_waiter = self.connection.ping("that") + self.connection.close() + self.assertTrue(pong_waiter_ack_on_close.wait(MS)) + self.assertFalse(pong_waiter.wait(MS)) + def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.drop_frames_rcvd(): # drop automatic response to ping From 8f12d8fd16d8945cbb7ad2d37962c47c16cb804f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:21:31 +0100 Subject: [PATCH 1481/1539] Add keepalive to the threading implementation. --- docs/project/changelog.rst | 3 +- docs/reference/features.rst | 4 +- docs/topics/keepalive.rst | 5 -- src/websockets/asyncio/connection.py | 17 +++-- src/websockets/sync/client.py | 17 ++++- src/websockets/sync/connection.py | 63 +++++++++++++++ src/websockets/sync/server.py | 17 ++++- tests/asyncio/test_connection.py | 26 +++---- tests/sync/test_client.py | 15 ++++ tests/sync/test_connection.py | 110 +++++++++++++++++++++++++++ tests/sync/test_server.py | 21 +++++ 11 files changed, 267 insertions(+), 31 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 867231241..67c16ba9e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,7 +35,8 @@ notice. New features ............ -* Added latency measurement to the :mod:`threading` implementation. +* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the + :mod:`threading` implementation. .. _14.2: diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 1135bf829..6ba42f66b 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -61,9 +61,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Keepalive | ✅ | ❌ | — | ✅ | + | Keepalive | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Heartbeat | ✅ | ❌ | — | ✅ | + | Heartbeat | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index a0467ced2..e63c2f8f5 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -1,11 +1,6 @@ Keepalive and latency ===================== -.. admonition:: This guide applies only to the :mod:`asyncio` implementation. - :class: tip - - The :mod:`threading` implementation doesn't provide keepalive yet. - .. currentmodule:: websockets Long-lived connections diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 91bc0dda5..75c43fa8a 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -686,8 +686,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: pong_waiter = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. - ping_timestamp = self.loop.time() - self.pong_waiters[data] = (pong_waiter, ping_timestamp) + self.pong_waiters[data] = (pong_waiter, self.loop.time()) self.protocol.send_ping(data) return pong_waiter @@ -792,13 +791,19 @@ async def keepalive(self) -> None: latency = 0.0 try: while True: - # If self.ping_timeout > latency > self.ping_interval, pings - # will be sent immediately after receiving pongs. The period - # will be longer than self.ping_interval. + # If self.ping_timeout > latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. await asyncio.sleep(self.ping_interval - latency) - self.logger.debug("% sending keepalive ping") + # This cannot raise ConnectionClosed when the connection is + # closing because ping(), via send_context(), waits for the + # connection to be closed before raising ConnectionClosed. + # However, connection_lost() cancels keepalive_task before + # it gets a chance to resume excuting. pong_waiter = await self.ping() + if self.debug: + self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: try: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 9e6da7caf..8325721b7 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -40,8 +40,8 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. - The ``close_timeout`` and ``max_queue`` arguments have the same meaning as - in :func:`connect`. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`connect`. Args: socket: Socket connected to a WebSocket server. @@ -54,6 +54,8 @@ def __init__( socket: socket.socket, protocol: ClientProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: @@ -62,6 +64,8 @@ def __init__( super().__init__( socket, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -136,6 +140,8 @@ def connect( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -184,6 +190,10 @@ def connect( :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -296,6 +306,8 @@ def connect( connection = create_connection( sock, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -315,6 +327,7 @@ def connect( connection.recv_events_thread.join() raise + connection.start_keepalive() return connection diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 5270c1fab..07f0543e4 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,11 +49,15 @@ def __init__( socket: socket.socket, protocol: Protocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout self.close_timeout = close_timeout if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) @@ -120,8 +124,15 @@ def __init__( Latency is defined as the round-trip time of the connection. It is measured by sending a Ping frame and waiting for a matching Pong frame. Before the first measurement, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ + # Thread that sends keepalive pings. None when ping_interval is None. + self.keepalive_thread: threading.Thread | None = None + # Public attributes @property @@ -700,6 +711,58 @@ def acknowledge_pending_pings(self) -> None: self.pong_waiters.clear() + def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + try: + while True: + # If self.ping_timeout > self.latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + self.recv_events_thread.join(self.ping_interval - self.latency) + if not self.recv_events_thread.is_alive(): + break + + try: + pong_waiter = self.ping(ack_on_close=True) + except ConnectionClosed: + break + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + # + if pong_waiter.wait(self.ping_timeout): + if self.debug: + self.logger.debug("% received keepalive pong") + else: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a thread, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + # This thread is marked as daemon like self.recv_events_thread. + self.keepalive_thread = threading.Thread( + target=self.keepalive, + daemon=True, + ) + self.keepalive_thread.start() + def recv_events(self) -> None: """ Read incoming data from the socket and process events. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 50a2f3c06..643f9b44b 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -52,8 +52,8 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. - The ``close_timeout`` and ``max_queue`` arguments have the same meaning as - in :func:`serve`. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`serve`. Args: socket: Socket connected to a WebSocket client. @@ -66,6 +66,8 @@ def __init__( socket: socket.socket, protocol: ServerProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: @@ -74,6 +76,8 @@ def __init__( super().__init__( socket, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -354,6 +358,8 @@ def serve( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -434,6 +440,10 @@ def handler(websocket): :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing connections in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -563,6 +573,8 @@ def protocol_select_subprotocol( connection = create_connection( sock, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -590,6 +602,7 @@ def protocol_select_subprotocol( assert connection.protocol.state is OPEN try: + connection.start_keepalive() handler(connection) except Exception: connection.logger.error("connection handler failed", exc_info=True) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 788a457ed..b53c97030 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1010,7 +1010,7 @@ async def test_keepalive_times_out(self, getrandbits): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for MS. + # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection is closed. await asyncio.sleep(2 * MS) @@ -1026,9 +1026,9 @@ async def test_keepalive_ignores_timeout(self, getrandbits): getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. - await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for MS. # 4.x ms: a pong frame is dropped. + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. # 6 ms: no pong frame is received; the connection remains open. await asyncio.sleep(2 * MS) # 7 ms: check that the connection is still open. @@ -1036,7 +1036,7 @@ async def test_keepalive_ignores_timeout(self, getrandbits): async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" - self.connection.ping_interval = 2 * MS + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() await asyncio.sleep(MS) await self.connection.close() @@ -1044,15 +1044,15 @@ async def test_keepalive_terminates_while_sleeping(self): async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" - self.connection.ping_interval = 2 * MS - self.connection.ping_timeout = 2 * MS + self.connection.ping_interval = MS + self.connection.ping_timeout = 3 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() - # 2 ms: keepalive() sends a ping frame. - await asyncio.sleep(2 * MS) - # Exiting the context manager sleeps for MS. - # 2.x ms: a pong frame is dropped. - # 3 ms: close the connection before ping_timeout elapses. + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + await asyncio.sleep(MS) + # Exiting the context manager sleeps for 1 ms. + # 2 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) @@ -1062,9 +1062,9 @@ async def test_keepalive_reports_errors(self): async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 2 ms: keepalive() sends a ping frame. - await asyncio.sleep(2 * MS) - # Exiting the context manager sleeps for MS. # 2.x ms: a pong frame is dropped. + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for 1 ms. # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 7ab8f4dd4..1669a0e84 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -76,6 +76,21 @@ def test_disable_compression(self): with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) + def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + with run_server() as server: + with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + time.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + def test_disable_keepalive(self): + """Client disables keepalive.""" + with run_server() as server: + with connect(get_uri(server), ping_interval=None) as client: + time.sleep(2 * MS) + self.assertEqual(client.latency, 0) + def test_logger(self): """Client accepts a logger argument.""" logger = logging.getLogger("test") diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index ee4aec5f6..f191d8944 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -738,6 +738,116 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test keepalive. + + @patch("random.getrandbits") + def test_keepalive(self, getrandbits): + """keepalive sends pings at ping_interval and measures latency.""" + self.connection.ping_interval = 2 * MS + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + self.assertEqual(self.connection.latency, 0) + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is received. + time.sleep(3 * MS) + # 3 ms: check that the ping frame was sent. + self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + time.sleep(3 * MS) + self.assertNoFrameSent() + + @patch("random.getrandbits") + def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + time.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection is closed. + time.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) + + @patch("random.getrandbits") + def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + time.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection remains open. + time.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) + + def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + time.sleep(MS) + self.connection.close() + self.connection.keepalive_thread.join(MS) + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_terminates_when_sending_ping_fails(self): + """keepalive task terminates when sending a ping fails.""" + self.connection.ping_interval = 1 * MS + self.connection.start_keepalive() + with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + self.connection.close() + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = MS + self.connection.ping_timeout = 4 * MS + with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + time.sleep(MS) + # Exiting the context manager sleeps for 1 ms. + # 2 ms: close the connection before ping_timeout elapses. + self.connection.close() + self.connection.keepalive_thread.join(MS) + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("threading.Event.wait", side_effect=Exception("BOOM")): + time.sleep(3 * MS) + # Exiting the context manager sleeps for 1 ms. + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + # Test parameters. def test_close_timeout(self): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index bb2ebae14..8e83f2a81 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -236,6 +236,27 @@ def test_disable_compression(self): with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.extensions", "[]") + def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + with run_server(ping_interval=MS) as server: + with connect(get_uri(server)) as client: + client.send("ws.latency") + latency = eval(client.recv()) + self.assertEqual(latency, 0) + time.sleep(2 * MS) + client.send("ws.latency") + latency = eval(client.recv()) + self.assertGreater(latency, 0) + + def test_disable_keepalive(self): + """Server disables keepalive.""" + with run_server(ping_interval=None) as server: + with connect(get_uri(server)) as client: + time.sleep(2 * MS) + client.send("ws.latency") + latency = eval(client.recv()) + self.assertEqual(latency, 0) + def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") From ee0d7441940131cc30b0f5a9a2afc905bb490564 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 24 Jan 2025 21:15:08 +0100 Subject: [PATCH 1482/1539] Avoid shadowing an import with another import. --- pyproject.toml | 2 +- src/websockets/__init__.py | 5 +++-- src/websockets/typing.py | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d3128f8ec..4044de0f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ exclude_lines = [ "except ImportError:", "if self.debug:", "if sys.platform != \"win32\":", - "if typing.TYPE_CHECKING:", + "if TYPE_CHECKING:", "raise AssertionError", "self.fail\\(\".*\"\\)", "@unittest.skip", diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index c8df54e0b..1d0abe5cd 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations -import typing +# Importing the typing module would conflict with websockets.typing. +from typing import TYPE_CHECKING from .imports import lazy_import from .version import version as __version__ # noqa: F401 @@ -72,7 +73,7 @@ ] # When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from .asyncio.client import ClientConnection, connect, unix_connect from .asyncio.server import ( Server, diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 0a37141c6..f10481b8b 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,8 +2,7 @@ import http import logging -import typing -from typing import Any, NewType, Optional, Union +from typing import TYPE_CHECKING, Any, NewType, Optional, Union __all__ = [ @@ -31,7 +30,7 @@ # Change to logging.Logger | ... when dropping Python < 3.10. -if typing.TYPE_CHECKING: +if TYPE_CHECKING: LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" else: # remove this branch when dropping support for Python < 3.11 From bba423e510cc422e27dfd77a95771d208e3e766a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 23 Jan 2025 21:22:14 +0100 Subject: [PATCH 1483/1539] Add type overloads for recv and recv_streaming. Fix #1578. --- docs/project/changelog.rst | 6 +++++ pyproject.toml | 1 + src/websockets/asyncio/connection.py | 20 ++++++++++++++++- src/websockets/asyncio/messages.py | 20 ++++++++++++++++- src/websockets/sync/connection.py | 33 +++++++++++++++++++++++++++- src/websockets/sync/messages.py | 29 +++++++++++++++++++++++- 6 files changed, 105 insertions(+), 4 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 67c16ba9e..7f341d942 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,12 @@ New features * Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the :mod:`threading` implementation. +Improvements +............ + +* Added type overloads for the ``decode`` argument of + :meth:`~asyncio.connection.Connection.recv`. This may simplify static typing. + .. _14.2: 14.2 diff --git a/pyproject.toml b/pyproject.toml index 4044de0f2..c0d9fcfd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ exclude_lines = [ "if TYPE_CHECKING:", "raise AssertionError", "self.fail\\(\".*\"\\)", + "@overload", "@unittest.skip", ] partial_branches = [ diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 75c43fa8a..79429923e 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -11,7 +11,7 @@ import uuid from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType -from typing import Any, cast +from typing import Any, Literal, cast, overload from ..exceptions import ( ConcurrencyError, @@ -243,6 +243,15 @@ async def __aiter__(self) -> AsyncIterator[Data]: except ConnectionClosedOK: return + @overload + async def recv(self, decode: Literal[True]) -> str: ... + + @overload + async def recv(self, decode: Literal[False]) -> bytes: ... + + @overload + async def recv(self, decode: bool | None = None) -> Data: ... + async def recv(self, decode: bool | None = None) -> Data: """ Receive the next message. @@ -312,6 +321,15 @@ async def recv(self, decode: bool | None = None) -> Data: await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc + @overload + def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Receive the next message frame by frame. diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index c10072467..581870037 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -4,7 +4,7 @@ import codecs import collections from collections.abc import AsyncIterator, Iterable -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Callable, Generic, Literal, TypeVar, overload from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -116,6 +116,15 @@ def __init__( # pragma: no cover # This flag marks the end of the connection. self.closed = False + @overload + async def get(self, decode: Literal[True]) -> str: ... + + @overload + async def get(self, decode: Literal[False]) -> bytes: ... + + @overload + async def get(self, decode: bool | None = None) -> Data: ... + async def get(self, decode: bool | None = None) -> Data: """ Read the next message. @@ -176,6 +185,15 @@ async def get(self, decode: bool | None = None) -> Data: else: return data + @overload + def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Stream the next message. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 07f0543e4..0c517cc64 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Iterable, Iterator, Mapping from types import TracebackType -from typing import Any +from typing import Any, Literal, overload from ..exceptions import ( ConcurrencyError, @@ -241,6 +241,28 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return + # This overload structure is required to avoid the error: + # "parameter without a default follows parameter with a default" + + @overload + def recv(self, timeout: float | None, decode: Literal[True]) -> str: ... + + @overload + def recv(self, timeout: float | None, decode: Literal[False]) -> bytes: ... + + @overload + def recv(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... + + @overload + def recv( + self, timeout: float | None = None, *, decode: Literal[False] + ) -> bytes: ... + + @overload + def recv( + self, timeout: float | None = None, decode: bool | None = None + ) -> Data: ... + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -311,6 +333,15 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc + @overload + def recv_streaming(self, decode: Literal[True]) -> Iterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> Iterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ... + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index dfabedd65..c619e78a1 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Any, Callable, Iterable, Iterator +from typing import Any, Callable, Iterable, Iterator, Literal, overload from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -110,6 +110,24 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: for frame in queued: # pragma: no cover self.frames.put(frame) + # This overload structure is required to avoid the error: + # "parameter without a default follows parameter with a default" + + @overload + def get(self, timeout: float | None, decode: Literal[True]) -> str: ... + + @overload + def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ... + + @overload + def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... + + @overload + def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ... + + @overload + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ... + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. @@ -181,6 +199,15 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: else: return data + @overload + def get_iter(self, decode: Literal[True]) -> Iterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ... + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. From c42672beb30e4c65a9f22ea8e0e68ee86b483a7c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 09:26:50 +0100 Subject: [PATCH 1484/1539] Narrow down type of client_max_window_bits. --- src/websockets/extensions/permessage_deflate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index cefad4f56..8e74cb282 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import zlib from collections.abc import Sequence -from typing import Any +from typing import Any, Literal from .. import frames from ..exceptions import ( @@ -212,7 +212,7 @@ def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, server_max_window_bits: int | None, - client_max_window_bits: int | bool | None, + client_max_window_bits: int | Literal[True] | None, ) -> list[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -234,7 +234,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> tuple[bool, bool, int | None, int | bool | None]: +) -> tuple[bool, bool, int | None, int | Literal[True] | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -245,7 +245,7 @@ def _extract_parameters( server_no_context_takeover: bool = False client_no_context_takeover: bool = False server_max_window_bits: int | None = None - client_max_window_bits: int | bool | None = None + client_max_window_bits: int | Literal[True] | None = None for name, value in params: if name == "server_no_context_takeover": @@ -324,7 +324,7 @@ def __init__( server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, - client_max_window_bits: int | bool | None = True, + client_max_window_bits: int | Literal[True] | None = True, compress_settings: dict[str, Any] | None = None, ) -> None: """ From a1bf0008fdcc39499cf1a270c5491e1ca1d4335c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 09:35:17 +0100 Subject: [PATCH 1485/1539] Rename the wsuri variable. The name looked confusing. --- docs/howto/sansio.rst | 8 ++++---- src/websockets/asyncio/client.py | 28 ++++++++++++++-------------- src/websockets/client.py | 16 +++++++--------- src/websockets/sync/client.py | 12 ++++++------ tests/asyncio/test_client.py | 4 ++-- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index ca530e6a1..27abcdabd 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -28,16 +28,16 @@ If you're building a client, parse the URI you'd like to connect to:: from websockets.uri import parse_uri - wsuri = parse_uri("ws://example.com/") + uri = parse_uri("ws://example.com/") -Open a TCP connection to ``(wsuri.host, wsuri.port)`` and perform a TLS -handshake if ``wsuri.secure`` is :obj:`True`. +Open a TCP connection to ``(uri.host, uri.port)`` and perform a TLS handshake +if ``uri.secure`` is :obj:`True`. Initialize a :class:`~client.ClientProtocol`:: from websockets.client import ClientProtocol - protocol = ClientProtocol(wsuri) + protocol = ClientProtocol(uri) Create a WebSocket handshake request with :meth:`~client.ClientProtocol.connect` and send it diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index f05f546d3..237bb273a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -311,10 +311,10 @@ def __init__( if create_connection is None: create_connection = ClientConnection - def protocol_factory(wsuri: WebSocketURI) -> ClientConnection: + def protocol_factory(uri: WebSocketURI) -> ClientConnection: # This is a protocol in the Sans-I/O implementation of websockets. protocol = ClientProtocol( - wsuri, + uri, origin=origin, extensions=extensions, subprotocols=subprotocols, @@ -346,15 +346,15 @@ async def create_connection(self) -> ClientConnection: """Create TCP or Unix connection.""" loop = asyncio.get_running_loop() - wsuri = parse_uri(self.uri) + ws_uri = parse_uri(self.uri) kwargs = self.connection_kwargs.copy() def factory() -> ClientConnection: - return self.protocol_factory(wsuri) + return self.protocol_factory(ws_uri) - if wsuri.secure: + if ws_uri.secure: kwargs.setdefault("ssl", True) - kwargs.setdefault("server_hostname", wsuri.host) + kwargs.setdefault("server_hostname", ws_uri.host) if kwargs.get("ssl") is None: raise ValueError("ssl=None is incompatible with a wss:// URI") else: @@ -365,8 +365,8 @@ def factory() -> ClientConnection: _, connection = await loop.create_unix_connection(factory, **kwargs) else: if kwargs.get("sock") is None: - kwargs.setdefault("host", wsuri.host) - kwargs.setdefault("port", wsuri.port) + kwargs.setdefault("host", ws_uri.host) + kwargs.setdefault("port", ws_uri.port) _, connection = await loop.create_connection(factory, **kwargs) return connection @@ -392,9 +392,9 @@ def process_redirect(self, exc: Exception) -> Exception | str: ): return exc - old_wsuri = parse_uri(self.uri) + old_ws_uri = parse_uri(self.uri) new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) - new_wsuri = parse_uri(new_uri) + new_ws_uri = parse_uri(new_uri) # If connect() received a socket, it is closed and cannot be reused. if self.connection_kwargs.get("sock") is not None: @@ -403,14 +403,14 @@ def process_redirect(self, exc: Exception) -> Exception | str: ) # TLS downgrade is forbidden. - if old_wsuri.secure and not new_wsuri.secure: + if old_ws_uri.secure and not new_ws_uri.secure: return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") # Apply restrictions to cross-origin redirects. if ( - old_wsuri.secure != new_wsuri.secure - or old_wsuri.host != new_wsuri.host - or old_wsuri.port != new_wsuri.port + old_ws_uri.secure != new_ws_uri.secure + or old_ws_uri.host != new_ws_uri.host + or old_ws_uri.port != new_ws_uri.port ): # Cross-origin redirects on Unix sockets don't quite make sense. if self.connection_kwargs.get("unix", False): diff --git a/src/websockets/client.py b/src/websockets/client.py index 37e2a8b3a..4c22dfc12 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -50,7 +50,7 @@ class ClientProtocol(Protocol): Sans-I/O implementation of a WebSocket client connection. Args: - wsuri: URI of the WebSocket server, parsed + uri: URI of the WebSocket server, parsed with :func:`~websockets.uri.parse_uri`. origin: Value of the ``Origin`` header. This is useful when connecting to a server that validates the ``Origin`` header to defend against @@ -70,7 +70,7 @@ class ClientProtocol(Protocol): def __init__( self, - wsuri: WebSocketURI, + uri: WebSocketURI, *, origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, @@ -85,7 +85,7 @@ def __init__( max_size=max_size, logger=logger, ) - self.wsuri = wsuri + self.uri = uri self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols @@ -105,12 +105,10 @@ def connect(self) -> Request: """ headers = Headers() - headers["Host"] = build_host( - self.wsuri.host, self.wsuri.port, self.wsuri.secure - ) + headers["Host"] = build_host(self.uri.host, self.uri.port, self.uri.secure) - if self.wsuri.user_info: - headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info) + if self.uri.user_info: + headers["Authorization"] = build_authorization_basic(*self.uri.user_info) if self.origin is not None: headers["Origin"] = self.origin @@ -133,7 +131,7 @@ def connect(self) -> Request: protocol_header = build_subprotocol(self.available_subprotocols) headers["Sec-WebSocket-Protocol"] = protocol_header - return Request(self.wsuri.resource_name, headers) + return Request(self.uri.resource_name, headers) def process_response(self, response: Response) -> None: """ diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 8325721b7..e95036595 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -230,8 +230,8 @@ def connect( DeprecationWarning, ) - wsuri = parse_uri(uri) - if not wsuri.secure and ssl is not None: + ws_uri = parse_uri(uri) + if not ws_uri.secure and ssl is not None: raise ValueError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() @@ -271,7 +271,7 @@ def connect( sock.connect(path) else: kwargs.setdefault("timeout", deadline.timeout()) - sock = socket.create_connection((wsuri.host, wsuri.port), **kwargs) + sock = socket.create_connection((ws_uri.host, ws_uri.port), **kwargs) sock.settimeout(None) # Disable Nagle algorithm @@ -281,11 +281,11 @@ def connect( # Initialize TLS wrapper and perform TLS handshake - if wsuri.secure: + if ws_uri.secure: if ssl is None: ssl = ssl_module.create_default_context() if server_hostname is None: - server_hostname = wsuri.host + server_hostname = ws_uri.host sock.settimeout(deadline.timeout()) sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) @@ -293,7 +293,7 @@ def connect( # Initialize WebSocket protocol protocol = ClientProtocol( - wsuri, + ws_uri, origin=origin, extensions=extensions, subprotocols=subprotocols, diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 4db4c038c..dd60805c1 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -248,7 +248,7 @@ def redirect(connection, request): async with serve(*args, process_request=redirect) as server: async with connect(get_uri(server) + "/redirect") as client: - self.assertEqual(client.protocol.wsuri.path, "/") + self.assertEqual(client.protocol.uri.path, "/") async def test_cross_origin_redirect(self): """Client follows redirect to a secure URI on a different origin.""" @@ -297,7 +297,7 @@ def redirect(connection, request): async with connect( "ws://overridden/redirect", host=host, port=port ) as client: - self.assertEqual(client.protocol.wsuri.path, "/") + self.assertEqual(client.protocol.uri.path, "/") async def test_cross_origin_redirect_with_explicit_host_port(self): """Client doesn't follow cross-origin redirect with an explicit host / port.""" From bd92cf70a2fc83a85ce1f20e9766cd207a11afad Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 09:41:23 +0100 Subject: [PATCH 1486/1539] Group parameters more logically. --- src/websockets/asyncio/client.py | 9 +++++---- src/websockets/asyncio/server.py | 9 +++++---- src/websockets/sync/client.py | 9 +++++---- src/websockets/sync/server.py | 9 +++++---- tests/asyncio/test_client.py | 30 +++++++++++++++--------------- tests/asyncio/test_server.py | 32 ++++++++++++++++---------------- tests/sync/test_client.py | 30 +++++++++++++++--------------- tests/sync/test_server.py | 32 ++++++++++++++++---------------- 8 files changed, 82 insertions(+), 78 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 237bb273a..bde0beeea 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -200,14 +200,14 @@ class connect: should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. additional_headers (HeadersLike | None): Arbitrary HTTP headers to add to the handshake request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - compression: The "permessage-deflate" extension is enabled by default. - Set ``compression`` to :obj:`None` to disable it. See the - :doc:`compression guide <../../topics/compression>` for details. process_exception: When reconnecting automatically, tell whether an error is transient or fatal. The default behavior is defined by :func:`process_exception`. Refer to its documentation for details. @@ -275,9 +275,10 @@ def __init__( origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, - compression: str | None = "deflate", process_exception: Callable[[Exception], Exception | None] = process_exception, # Timeouts open_timeout: float | None = 10, diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index ebe45c2a9..2e2b78782 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -617,6 +617,9 @@ def handler(websocket): it has the same behavior as the :meth:`ServerProtocol.select_subprotocol ` method. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response or :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the @@ -630,9 +633,6 @@ def handler(websocket): server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - compression: The "permessage-deflate" extension is enabled by default. - Set ``compression`` to :obj:`None` to disable it. See the - :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -694,6 +694,8 @@ def __init__( ] | None ) = None, + compression: str | None = "deflate", + # HTTP process_request: ( Callable[ [ServerConnection, Request], @@ -709,7 +711,6 @@ def __init__( | None ) = None, server_header: str | None = SERVER, - compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e95036595..da2b88591 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -135,9 +135,10 @@ def connect( origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, - compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -180,14 +181,14 @@ def connect( should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. additional_headers (HeadersLike | None): Arbitrary HTTP headers to add to the handshake request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - compression: The "permessage-deflate" extension is enabled by default. - Set ``compression`` to :obj:`None` to disable it. See the - :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 643f9b44b..2b753b2c5 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -340,6 +340,8 @@ def serve( ] | None ) = None, + compression: str | None = "deflate", + # HTTP process_request: ( Callable[ [ServerConnection, Request], @@ -355,7 +357,6 @@ def serve( | None ) = None, server_header: str | None = SERVER, - compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -423,6 +424,9 @@ def handler(websocket): it has the same behavior as the :meth:`ServerProtocol.select_subprotocol ` method. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the @@ -435,9 +439,6 @@ def handler(websocket): server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - compression: The "permessage-deflate" extension is enabled by default. - Set ``compression`` to :obj:`None` to disable it. See the - :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index dd60805c1..f05bfc699 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -76,6 +76,21 @@ async def test_existing_socket(self): async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with serve(*args) as server: + async with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + async def test_additional_headers(self): """Client can set additional headers with additional_headers.""" async with serve(*args) as server: @@ -96,21 +111,6 @@ async def test_remove_user_agent(self): async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) - async def test_compression_is_enabled(self): - """Client enables compression by default.""" - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual( - [type(ext) for ext in client.protocol.extensions], - [PerMessageDeflate], - ) - - async def test_disable_compression(self): - """Client disables compression.""" - async with serve(*args) as server: - async with connect(get_uri(server), compression=None) as client: - self.assertEqual(client.protocol.extensions, []) - async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" async with serve(*args) as server: diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index edad52da5..5c66dc727 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -117,6 +117,22 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with serve(*args, compression=None) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + async def test_process_request_returns_none(self): """Server runs process_request and continues the handshake.""" @@ -359,22 +375,6 @@ async def test_remove_server(self): client, "'Server' in ws.response.headers", "False" ) - async def test_compression_is_enabled(self): - """Server enables compression by default.""" - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - await self.assertEval( - client, - "[type(ext).__name__ for ext in ws.protocol.extensions]", - "['PerMessageDeflate']", - ) - - async def test_disable_compression(self): - """Server disables compression.""" - async with serve(*args, compression=None) as server: - async with connect(get_uri(server)) as client: - await self.assertEval(client, "ws.protocol.extensions", "[]") - async def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" async with serve(*args, ping_interval=MS) as server: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 1669a0e84..736a84c98 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -41,6 +41,21 @@ def test_existing_socket(self): with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") + def test_compression_is_enabled(self): + """Client enables compression by default.""" + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + def test_disable_compression(self): + """Client disables compression.""" + with run_server() as server: + with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + def test_additional_headers(self): """Client can set additional headers with additional_headers.""" with run_server() as server: @@ -61,21 +76,6 @@ def test_remove_user_agent(self): with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) - def test_compression_is_enabled(self): - """Client enables compression by default.""" - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual( - [type(ext) for ext in client.protocol.extensions], - [PerMessageDeflate], - ) - - def test_disable_compression(self): - """Client disables compression.""" - with run_server() as server: - with connect(get_uri(server), compression=None) as client: - self.assertEqual(client.protocol.extensions, []) - def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" with run_server() as server: diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 8e83f2a81..f59671efd 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -115,6 +115,22 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) + def test_compression_is_enabled(self): + """Server enables compression by default.""" + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + def test_disable_compression(self): + """Server disables compression.""" + with run_server(compression=None) as server: + with connect(get_uri(server)) as client: + self.assertEval(client, "ws.protocol.extensions", "[]") + def test_process_request_returns_none(self): """Server runs process_request and continues the handshake.""" @@ -220,22 +236,6 @@ def test_remove_server(self): with connect(get_uri(server)) as client: self.assertEval(client, "'Server' in ws.response.headers", "False") - def test_compression_is_enabled(self): - """Server enables compression by default.""" - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEval( - client, - "[type(ext).__name__ for ext in ws.protocol.extensions]", - "['PerMessageDeflate']", - ) - - def test_disable_compression(self): - """Server disables compression.""" - with run_server(compression=None) as server: - with connect(get_uri(server)) as client: - self.assertEval(client, "ws.protocol.extensions", "[]") - def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" with run_server(ping_interval=MS) as server: From 919cd92149edc3395314282e8c5799b52d7c4bc2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 09:48:53 +0100 Subject: [PATCH 1487/1539] Remove some excess vertical space. --- src/websockets/client.py | 21 ++++----------------- src/websockets/server.py | 21 ++------------------- 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 4c22dfc12..9ea21c39c 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -104,33 +104,26 @@ def connect(self) -> Request: """ headers = Headers() - headers["Host"] = build_host(self.uri.host, self.uri.port, self.uri.secure) - if self.uri.user_info: headers["Authorization"] = build_authorization_basic(*self.uri.user_info) - if self.origin is not None: headers["Origin"] = self.origin - headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Key"] = self.key headers["Sec-WebSocket-Version"] = "13" - if self.available_extensions is not None: - extensions_header = build_extension( + headers["Sec-WebSocket-Extensions"] = build_extension( [ (extension_factory.name, extension_factory.get_request_params()) for extension_factory in self.available_extensions ] ) - headers["Sec-WebSocket-Extensions"] = extensions_header - if self.available_subprotocols is not None: - protocol_header = build_subprotocol(self.available_subprotocols) - headers["Sec-WebSocket-Protocol"] = protocol_header - + headers["Sec-WebSocket-Protocol"] = build_subprotocol( + self.available_subprotocols + ) return Request(self.uri.resource_name, headers) def process_response(self, response: Response) -> None: @@ -153,7 +146,6 @@ def process_response(self, response: Response) -> None: connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) - if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade( "Connection", ", ".join(connection) if connection else None @@ -162,7 +154,6 @@ def process_response(self, response: Response) -> None: upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) - # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): @@ -174,12 +165,10 @@ def process_response(self, response: Response) -> None: raise InvalidHeader("Sec-WebSocket-Accept") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None - if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) self.extensions = self.process_extensions(headers) - self.subprotocol = self.process_subprotocol(headers) def process_extensions(self, headers: Headers) -> list[Extension]: @@ -279,7 +268,6 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None: parsed_subprotocols: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in subprotocols], [] ) - if len(parsed_subprotocols) > 1: raise InvalidHeader( "Sec-WebSocket-Protocol", @@ -287,7 +275,6 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None: ) subprotocol = parsed_subprotocols[0] - if subprotocol not in self.available_subprotocols: raise NegotiationError(f"unsupported subprotocol: {subprotocol}") diff --git a/src/websockets/server.py b/src/websockets/server.py index 90e6c9921..174441203 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -190,19 +190,14 @@ def accept(self, request: Request) -> Response: ) headers = Headers() - headers["Date"] = email.utils.formatdate(usegmt=True) - headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Accept"] = accept_header - if extensions_header is not None: headers["Sec-WebSocket-Extensions"] = extensions_header - if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - return Response(101, "Switching Protocols", headers) def process_request( @@ -234,7 +229,6 @@ def process_request( connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) - if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade( "Connection", ", ".join(connection) if connection else None @@ -243,7 +237,6 @@ def process_request( upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) - # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. The RFC always uses "websocket", except # in section 11.2. (IANA registration) where it uses "WebSocket". @@ -256,13 +249,13 @@ def process_request( raise InvalidHeader("Sec-WebSocket-Key") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None - try: raw_key = base64.b64decode(key.encode(), validate=True) except binascii.Error as exc: raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc if len(raw_key) != 16: raise InvalidHeaderValue("Sec-WebSocket-Key", key) + accept_header = accept_key(key) try: version = headers["Sec-WebSocket-Version"] @@ -270,23 +263,14 @@ def process_request( raise InvalidHeader("Sec-WebSocket-Version") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None - if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) - accept_header = accept_key(key) - self.origin = self.process_origin(headers) - extensions_header, self.extensions = self.process_extensions(headers) - protocol_header = self.subprotocol = self.process_subprotocol(headers) - return ( - accept_header, - extensions_header, - protocol_header, - ) + return (accept_header, extensions_header, protocol_header) def process_origin(self, headers: Headers) -> Origin | None: """ @@ -426,7 +410,6 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None: ], [], ) - return self.select_subprotocol(subprotocols) def select_subprotocol( From 1e62b3c384a9fc345c661ebdd66a2a1a0fbbff52 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 13:59:46 +0100 Subject: [PATCH 1488/1539] Disable PyPy in CI. Refs #1581. --- .github/workflows/tests.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5ab9c4c72..ca73cd499 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,12 +60,13 @@ jobs: - "3.11" - "3.12" - "3.13" - - "pypy-3.10" +# Disable PyPy per https://github.com/python-websockets/websockets/issues/1581 +# - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} - exclude: - - python: "pypy-3.10" - is_main: false +# exclude: +# - python: "pypy-3.10" +# is_main: false steps: - name: Check out repository uses: actions/checkout@v4 From bd1796625de9a087229bb805acae86b414a4449c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 14:19:29 +0100 Subject: [PATCH 1489/1539] Fix names of environment variables in docs. --- docs/reference/variables.rst | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index a24773074..a55057a0d 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -54,29 +54,30 @@ Reconnection Reconnection attempts are spaced out with truncated exponential backoff. -.. envvar:: BACKOFF_INITIAL_DELAY +.. envvar:: WEBSOCKETS_BACKOFF_INITIAL_DELAY The first attempt is delayed by a random amount of time between ``0`` and - ``BACKOFF_INITIAL_DELAY`` seconds. + ``WEBSOCKETS_BACKOFF_INITIAL_DELAY`` seconds. The default value is ``5.0`` seconds. -.. envvar:: BACKOFF_MIN_DELAY +.. envvar:: WEBSOCKETS_BACKOFF_MIN_DELAY - The second attempt is delayed by ``BACKOFF_MIN_DELAY`` seconds. + The second attempt is delayed by ``WEBSOCKETS_BACKOFF_MIN_DELAY`` seconds. The default value is ``3.1`` seconds. -.. envvar:: BACKOFF_FACTOR +.. envvar:: WEBSOCKETS_BACKOFF_FACTOR - After the second attempt, the delay is multiplied by ``BACKOFF_FACTOR`` - between each attempt. + After the second attempt, the delay is multiplied by + ``WEBSOCKETS_BACKOFF_FACTOR`` between each attempt. The default value is ``1.618``. -.. envvar:: BACKOFF_MAX_DELAY +.. envvar:: WEBSOCKETS_BACKOFF_MAX_DELAY - The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds. + The delay between attempts is capped at ``WEBSOCKETS_BACKOFF_MAX_DELAY`` + seconds. The default value is ``90.0`` seconds. From 387197f98c3bd605338ed898c4d05a15d4ee1bfc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 15:08:47 +0100 Subject: [PATCH 1490/1539] Mitigate test flakiness. --- tests/asyncio/test_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 5c66dc727..38c0315a1 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -542,7 +542,8 @@ async def test_close_server_keeps_handlers_running(self): async with asyncio_timeout(2 * MS): await server.wait_closed() - async with asyncio_timeout(3 * MS): + # Set a large timeout here, else the test becomes flaky. + async with asyncio_timeout(5 * MS): await server.wait_closed() From dfea9b311283e10511c897b9e0253571943a9123 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 18:30:45 +0100 Subject: [PATCH 1491/1539] Standardize mocking style to @patch in tests. --- tests/asyncio/test_connection.py | 35 ++++++----------- tests/legacy/test_client_server.py | 42 +++++++++----------- tests/sync/test_connection.py | 31 ++++++--------- tests/test_frames.py | 7 +--- tests/test_protocol.py | 61 +++++++++++++++--------------- tests/test_server.py | 6 ++- 6 files changed, 79 insertions(+), 103 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index b53c97030..dc4539948 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -866,10 +866,9 @@ async def test_wait_closed(self): # Test ping. - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) async def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" - getrandbits.return_value = 1918987876 await self.connection.ping() getrandbits.assert_called_once_with(32) await self.assertFrameSent(Frame(Opcode.PING, b"rand")) @@ -978,11 +977,10 @@ async def test_pong_unsupported_type(self): # Test keepalive. - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) async def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 2 * MS - getrandbits.return_value = 1918987876 self.connection.start_keepalive() self.assertEqual(self.connection.latency, 0) # 2 ms: keepalive() sends a ping frame. @@ -1000,13 +998,12 @@ async def test_disable_keepalive(self): await asyncio.sleep(3 * MS) await self.assertNoFrameSent() - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) async def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): - getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) @@ -1017,13 +1014,12 @@ async def test_keepalive_times_out(self, getrandbits): # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) async def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None async with self.drop_frames_rcvd(): - getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. # 4.x ms: a pong frame is dropped. @@ -1142,17 +1138,13 @@ async def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @unittest.mock.patch( - "asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234) - ) + @patch("asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234)) async def test_local_address(self, get_extra_info): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) get_extra_info.assert_called_with("sockname") - @unittest.mock.patch( - "asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234) - ) + @patch("asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234)) async def test_remote_address(self, get_extra_info): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) @@ -1213,12 +1205,11 @@ async def test_writing_in_send_context_fails(self): # Test safety nets — catching all exceptions in case of bugs. - @patch("websockets.protocol.Protocol.events_received") + # Inject a fault in a random call in data_received(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) async def test_unexpected_failure_in_data_received(self, events_received): """Unexpected internal error in data_received() is correctly reported.""" - # Inject a fault in a random call in data_received(). - # This test is tightly coupled to the implementation. - events_received.side_effect = AssertionError # Receive a message to trigger the fault. await self.remote_connection.send("😀") @@ -1229,13 +1220,11 @@ async def test_unexpected_failure_in_data_received(self, events_received): self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) - @patch("websockets.protocol.Protocol.send_text") + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) async def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" - # Inject a fault in a random call in send_context(). - # This test is tightly coupled to the implementation. - send_text.side_effect = AssertionError - # Send a message to trigger the fault. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c13c6c92e..2354db022 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -10,10 +10,10 @@ import ssl import sys import unittest -import unittest.mock import urllib.error import urllib.request import warnings +from unittest.mock import patch from websockets.asyncio.compatibility import asyncio_timeout from websockets.datastructures import Headers @@ -968,7 +968,7 @@ def test_extension_order(self): ) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_extensions") + @patch.object(WebSocketServerProtocol, "process_extensions") def test_extensions_error(self, _process_extensions): _process_extensions.return_value = "x-no-op", [NoOpExtension()] @@ -978,7 +978,7 @@ def test_extensions_error(self, _process_extensions): ) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_extensions") + @patch.object(WebSocketServerProtocol, "process_extensions") def test_extensions_error_no_extensions(self, _process_extensions): _process_extensions.return_value = "x-no-op", [NoOpExtension()] @@ -1051,7 +1051,7 @@ def test_subprotocol_not_requested(self): self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=["superchat"]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") + @patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error(self, _process_subprotocol): _process_subprotocol.return_value = "superchat" @@ -1060,7 +1060,7 @@ def test_subprotocol_error(self, _process_subprotocol): self.run_loop_once() @with_server(subprotocols=["superchat"]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") + @patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = "superchat" @@ -1069,7 +1069,7 @@ def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): self.run_loop_once() @with_server(subprotocols=["superchat", "chat"]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") + @patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = "superchat, chat" @@ -1078,7 +1078,7 @@ def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.legacy.server.read_request") + @patch("websockets.legacy.server.read_request") def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") @@ -1086,7 +1086,7 @@ def test_server_receives_malformed_request(self, _read_request): self.start_client() @with_server() - @unittest.mock.patch("websockets.legacy.client.read_response") + @patch("websockets.legacy.client.read_response") def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") @@ -1095,7 +1095,7 @@ def test_client_receives_malformed_response(self, _read_response): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.legacy.client.build_request") + @patch("websockets.legacy.client.build_request") def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(headers): return "42" @@ -1106,7 +1106,7 @@ def wrong_build_request(headers): self.start_client() @with_server() - @unittest.mock.patch("websockets.legacy.server.build_response") + @patch("websockets.legacy.server.build_response") def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(headers, key): return build_response(headers, "42") @@ -1117,7 +1117,7 @@ def wrong_build_response(headers, key): self.start_client() @with_server() - @unittest.mock.patch("websockets.legacy.client.read_response") + @patch("websockets.legacy.client.read_response") def test_server_does_not_switch_protocols(self, _read_response): async def wrong_read_response(stream): status_code, reason, headers = await read_response(stream) @@ -1130,9 +1130,7 @@ async def wrong_read_response(stream): self.run_loop_once() @with_server() - @unittest.mock.patch( - "websockets.legacy.server.WebSocketServerProtocol.process_request" - ) + @patch("websockets.legacy.server.WebSocketServerProtocol.process_request") def test_server_error_in_handshake(self, _process_request): _process_request.side_effect = Exception("process_request crashed") @@ -1156,7 +1154,7 @@ async def cancelled_client(): sock.send(b"") # socket is closed @with_server() - @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.send") + @patch("websockets.legacy.server.WebSocketServerProtocol.send") def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") @@ -1169,7 +1167,7 @@ def test_server_handler_crashes(self, send): self.assertEqual(self.client.close_code, CloseCode.INTERNAL_ERROR) @with_server() - @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.close") + @patch("websockets.legacy.server.WebSocketServerProtocol.close") def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") @@ -1183,7 +1181,7 @@ def test_server_close_crashes(self, close): @with_server() @with_client() - @unittest.mock.patch.object(WebSocketClientProtocol, "handshake") + @patch.object(WebSocketClientProtocol, "handshake") def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. @@ -1254,12 +1252,8 @@ def test_invalid_status_error_during_client_connect(self): self.assertEqual(exception.status_code, 403) @with_server() - @unittest.mock.patch( - "websockets.legacy.server.WebSocketServerProtocol.write_http_response" - ) - @unittest.mock.patch( - "websockets.legacy.server.WebSocketServerProtocol.read_http_request" - ) + @patch("websockets.legacy.server.WebSocketServerProtocol.write_http_response") + @patch("websockets.legacy.server.WebSocketServerProtocol.read_http_request") def test_connection_error_during_opening_handshake( self, _read_http_request, _write_http_response ): @@ -1277,7 +1271,7 @@ def test_connection_error_during_opening_handshake( _write_http_response.assert_not_called() @with_server() - @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.close") + @patch("websockets.legacy.server.WebSocketServerProtocol.close") def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index f191d8944..be7ff36f4 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -645,10 +645,9 @@ def fragments(): # Test ping. - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" - getrandbits.return_value = 1918987876 self.connection.ping() getrandbits.assert_called_once_with(32) self.assertFrameSent(Frame(Opcode.PING, b"rand")) @@ -740,11 +739,10 @@ def test_pong_unsupported_type(self): # Test keepalive. - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 2 * MS - getrandbits.return_value = 1918987876 self.connection.start_keepalive() self.assertEqual(self.connection.latency, 0) # 2 ms: keepalive() sends a ping frame. @@ -762,13 +760,12 @@ def test_disable_keepalive(self): time.sleep(3 * MS) self.assertNoFrameSent() - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS with self.drop_frames_rcvd(): - getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. time.sleep(4 * MS) @@ -779,13 +776,12 @@ def test_keepalive_times_out(self, getrandbits): # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None with self.drop_frames_rcvd(): - getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. time.sleep(4 * MS) @@ -910,13 +906,13 @@ def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @unittest.mock.patch("socket.socket.getsockname", return_value=("sock", 1234)) + @patch("socket.socket.getsockname", return_value=("sock", 1234)) def test_local_address(self, getsockname): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) getsockname.assert_called_with() - @unittest.mock.patch("socket.socket.getpeername", return_value=("peer", 1234)) + @patch("socket.socket.getpeername", return_value=("peer", 1234)) def test_remote_address(self, getpeername): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) @@ -984,12 +980,11 @@ def test_writing_in_send_context_fails(self): # Test safety nets — catching all exceptions in case of bugs. - @patch("websockets.protocol.Protocol.events_received") + # Inject a fault in a random call in recv_events(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) def test_unexpected_failure_in_recv_events(self, events_received): """Unexpected internal error in recv_events() is correctly reported.""" - # Inject a fault in a random call in recv_events(). - # This test is tightly coupled to the implementation. - events_received.side_effect = AssertionError # Receive a message to trigger the fault. self.remote_connection.send("😀") @@ -1000,13 +995,11 @@ def test_unexpected_failure_in_recv_events(self, events_received): self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) - @patch("websockets.protocol.Protocol.send_text") + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" - # Inject a fault in a random call in send_context(). - # This test is tightly coupled to the implementation. - send_text.side_effect = AssertionError - # Send a message to trigger the fault. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: diff --git a/tests/test_frames.py b/tests/test_frames.py index 857b41fe4..1c372b5de 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -1,7 +1,7 @@ import codecs import dataclasses import unittest -import unittest.mock +from unittest.mock import patch from websockets.exceptions import PayloadTooBig, ProtocolError from websockets.frames import * @@ -12,9 +12,6 @@ class FramesTestCase(GeneratorTestCase): - def enforce_mask(self, mask): - return unittest.mock.patch("secrets.token_bytes", return_value=mask) - def parse(self, data, mask, max_size=None, extensions=None): """ Parse a frame from a bytestring. @@ -41,7 +38,7 @@ def assertFrameData(self, frame, data, mask, extensions=None): # Make masking deterministic by reusing the same "random" mask. # This has an effect only when mask is True. mask_bytes = data[2:6] if mask else b"" - with self.enforce_mask(mask_bytes): + with patch("secrets.token_bytes", return_value=mask_bytes): serialized = frame.serialize(mask=mask, extensions=extensions) self.assertEqual(serialized, data) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 1c092459d..9e2d65041 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,5 +1,6 @@ import logging -import unittest.mock +import unittest +from unittest.mock import patch from websockets.exceptions import ( ConnectionClosedError, @@ -106,7 +107,7 @@ class MaskingTests(ProtocolTestCase): def test_client_sends_masked_frame(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\xff\x00\xff"): + with patch("secrets.token_bytes", return_value=b"\x00\xff\x00\xff"): client.send_text(b"Spam", True) self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) @@ -191,7 +192,7 @@ def test_client_sends_continuation_after_sending_close(self): # Since it isn't possible to send a close frame in a fragmented # message (see test_client_send_close_in_fragmented_message), in fact, # this is the same test as test_client_sends_unexpected_continuation. - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(ProtocolError) as raised: @@ -234,7 +235,7 @@ class TextTests(ProtocolTestCase): def test_client_sends_text(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_text("😀".encode()) self.assertEqual( client.data_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] @@ -307,15 +308,15 @@ def test_server_receives_text_without_size_limit(self): def test_client_sends_fragmented_text(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_text("😀".encode()[:2], fin=False) self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation("😀😀".encode()[2:6], fin=False) self.assertEqual( client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] ) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation("😀".encode()[2:], fin=True) self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) @@ -482,7 +483,7 @@ def test_server_receives_unexpected_text(self): def test_client_sends_text_after_sending_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState) as raised: @@ -522,7 +523,7 @@ class BinaryTests(ProtocolTestCase): def test_client_sends_binary(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02\xfe\xff") self.assertEqual( client.data_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] @@ -579,15 +580,15 @@ def test_server_receives_binary_over_size_limit(self): def test_client_sends_fragmented_binary(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02", fin=False) self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation(b"\xee\xff\x01\x02", fin=False) self.assertEqual( client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] ) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation(b"\xee\xff", fin=True) self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) @@ -718,7 +719,7 @@ def test_server_receives_unexpected_binary(self): def test_client_sends_binary_after_sending_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState) as raised: @@ -806,7 +807,7 @@ def test_close_reason_not_available_yet(self): def test_client_sends_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) @@ -819,7 +820,7 @@ def test_server_sends_close(self): def test_client_receives_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"): client.receive_data(b"\x88\x00") self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) @@ -890,7 +891,7 @@ def test_server_receives_close_then_sends_close(self): def test_client_sends_close_with_code(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) self.assertIs(client.state, CLOSING) @@ -915,7 +916,7 @@ def test_server_receives_close_with_code(self): def test_client_sends_close_with_code_and_reason(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY, "going away") self.assertEqual( client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] @@ -1002,7 +1003,7 @@ def test_server_receives_close_with_non_utf8_reason(self): def test_client_sends_close_twice(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState) as raised: @@ -1040,7 +1041,7 @@ class PingTests(ProtocolTestCase): def test_client_sends_ping(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) @@ -1075,7 +1076,7 @@ def test_server_receives_ping(self): def test_client_sends_ping_with_data(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"\x22\x66\xaa\xee") self.assertEqual( client.data_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] @@ -1144,10 +1145,10 @@ def test_server_receives_fragmented_ping_frame(self): def test_client_sends_ping_after_sending_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) @@ -1199,7 +1200,7 @@ class PongTests(ProtocolTestCase): def test_client_sends_pong(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"") self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) @@ -1226,7 +1227,7 @@ def test_server_receives_pong(self): def test_client_sends_pong_with_data(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"\x22\x66\xaa\xee") self.assertEqual( client.data_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] @@ -1287,10 +1288,10 @@ def test_server_receives_fragmented_pong_frame(self): def test_client_sends_pong_after_sending_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"") self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) @@ -1459,7 +1460,7 @@ def test_client_send_close_in_fragmented_message(self): client = Protocol(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) @@ -1819,7 +1820,7 @@ class ErrorTests(ProtocolTestCase): def test_client_hits_internal_error_reading_frame(self): client = Protocol(CLIENT) # This isn't supposed to happen, so we're simulating it. - with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + with patch("struct.unpack", side_effect=RuntimeError("BOOM")): client.receive_data(b"\x81\x00") self.assertIsInstance(client.parser_exc, RuntimeError) self.assertEqual(str(client.parser_exc), "BOOM") @@ -1828,7 +1829,7 @@ def test_client_hits_internal_error_reading_frame(self): def test_server_hits_internal_error_reading_frame(self): server = Protocol(SERVER) # This isn't supposed to happen, so we're simulating it. - with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + with patch("struct.unpack", side_effect=RuntimeError("BOOM")): server.receive_data(b"\x81\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, RuntimeError) self.assertEqual(str(server.parser_exc), "BOOM") @@ -1844,7 +1845,7 @@ class ExtensionsTests(ProtocolTestCase): def test_client_extension_encodes_frame(self): client = Protocol(CLIENT) client.extensions = [Rsv2Extension()] - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) diff --git a/tests/test_server.py b/tests/test_server.py index 9f328ded5..43970a7cd 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -314,12 +314,14 @@ def test_reject_response_supports_int_status(self): self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") - @patch("websockets.server.ServerProtocol.process_request") + @patch( + "websockets.server.ServerProtocol.process_request", + side_effect=Exception("BOOM"), + ) def test_unexpected_error(self, process_request): """accept() handles unexpected errors and returns an error response.""" server = ServerProtocol() request = make_request() - process_request.side_effect = (Exception("BOOM"),) response = server.accept(request) self.assertEqual(response.status_code, 500) From 2d515c84acef39e9642358f442575bdc2d04c8de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 18:35:10 +0100 Subject: [PATCH 1492/1539] Increase test timing to reduce flakiness. --- tests/asyncio/test_connection.py | 14 +++++++------- tests/sync/test_connection.py | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index dc4539948..5230eca89 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -980,13 +980,14 @@ async def test_pong_unsupported_type(self): @patch("random.getrandbits", return_value=1918987876) async def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" - self.connection.ping_interval = 2 * MS + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() + self.assertIsNotNone(self.connection.keepalive_task) self.assertEqual(self.connection.latency, 0) - # 2 ms: keepalive() sends a ping frame. - # 2.x ms: a pong frame is received. - await asyncio.sleep(3 * MS) - # 3 ms: check that the ping frame was sent. + # 3 ms: keepalive() sends a ping frame. + # 3.x ms: a pong frame is received. + await asyncio.sleep(4 * MS) + # 4 ms: check that the ping frame was sent. await self.assertFrameSent(Frame(Opcode.PING, b"rand")) self.assertGreater(self.connection.latency, 0) self.assertLess(self.connection.latency, MS) @@ -995,8 +996,7 @@ async def test_disable_keepalive(self): """keepalive is disabled when ping_interval is None.""" self.connection.ping_interval = None self.connection.start_keepalive() - await asyncio.sleep(3 * MS) - await self.assertNoFrameSent() + self.assertIsNone(self.connection.keepalive_task) @patch("random.getrandbits", return_value=1918987876) async def test_keepalive_times_out(self, getrandbits): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index be7ff36f4..a5aee35bb 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -742,13 +742,14 @@ def test_pong_unsupported_type(self): @patch("random.getrandbits", return_value=1918987876) def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" - self.connection.ping_interval = 2 * MS + self.connection.ping_interval = 4 * MS self.connection.start_keepalive() + self.assertIsNotNone(self.connection.keepalive_thread) self.assertEqual(self.connection.latency, 0) - # 2 ms: keepalive() sends a ping frame. - # 2.x ms: a pong frame is received. - time.sleep(3 * MS) - # 3 ms: check that the ping frame was sent. + # 3 ms: keepalive() sends a ping frame. + # 3.x ms: a pong frame is received. + time.sleep(4 * MS) + # 4 ms: check that the ping frame was sent. self.assertFrameSent(Frame(Opcode.PING, b"rand")) self.assertGreater(self.connection.latency, 0) self.assertLess(self.connection.latency, MS) @@ -757,8 +758,7 @@ def test_disable_keepalive(self): """keepalive is disabled when ping_interval is None.""" self.connection.ping_interval = None self.connection.start_keepalive() - time.sleep(3 * MS) - self.assertNoFrameSent() + self.assertIsNone(self.connection.keepalive_thread) @patch("random.getrandbits", return_value=1918987876) def test_keepalive_times_out(self, getrandbits): From 4ea521f62dfbda4335f21db6604260c83a0fac67 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 13:20:54 +0100 Subject: [PATCH 1493/1539] Add helpers to locate proxy for client connections. --- docs/reference/exceptions.rst | 10 +- src/websockets/__init__.py | 9 ++ src/websockets/exceptions.py | 42 +++++++- src/websockets/uri.py | 128 +++++++++++++++++++++++- tests/test_exceptions.py | 12 +++ tests/test_uri.py | 182 ++++++++++++++++++++++++++++++++-- tests/utils.py | 16 +++ 7 files changed, 382 insertions(+), 17 deletions(-) diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index d6b7f0f57..e0c2efdd1 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -28,14 +28,20 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs. .. autoexception:: InvalidURI -.. autoexception:: InvalidHandshake +.. autoexception:: InvalidProxy -.. autoexception:: InvalidMessage +.. autoexception:: InvalidHandshake .. autoexception:: SecurityError +.. autoexception:: InvalidMessage + .. autoexception:: InvalidStatus +.. autoexception:: InvalidProxyMessage + +.. autoexception:: InvalidProxyStatus + .. autoexception:: InvalidHeader .. autoexception:: InvalidHeaderFormat diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 1d0abe5cd..8bf282a73 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -39,6 +39,9 @@ "InvalidOrigin", "InvalidParameterName", "InvalidParameterValue", + "InvalidProxy", + "InvalidProxyMessage", + "InvalidProxyStatus", "InvalidState", "InvalidStatus", "InvalidUpgrade", @@ -99,6 +102,9 @@ InvalidOrigin, InvalidParameterName, InvalidParameterValue, + InvalidProxy, + InvalidProxyMessage, + InvalidProxyStatus, InvalidState, InvalidStatus, InvalidUpgrade, @@ -157,6 +163,9 @@ "InvalidOrigin": ".exceptions", "InvalidParameterName": ".exceptions", "InvalidParameterValue": ".exceptions", + "InvalidProxy": ".exceptions", + "InvalidProxyMessage": ".exceptions", + "InvalidProxyStatus": ".exceptions", "InvalidState": ".exceptions", "InvalidStatus": ".exceptions", "InvalidUpgrade": ".exceptions", diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 73b24debf..e70aac92e 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -6,11 +6,14 @@ * :exc:`ConnectionClosedOK` * :exc:`ConnectionClosedError` * :exc:`InvalidURI` + * :exc:`InvalidProxy` * :exc:`InvalidHandshake` * :exc:`SecurityError` * :exc:`InvalidMessage` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) + * :exc:`InvalidProxyMessage` + * :exc:`InvalidProxyStatus` * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` @@ -42,13 +45,16 @@ "ConnectionClosedOK", "ConnectionClosedError", "InvalidURI", + "InvalidProxy", "InvalidHandshake", "SecurityError", + "InvalidMessage", "InvalidStatus", + "InvalidProxyMessage", + "InvalidProxyStatus", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", - "InvalidMessage", "InvalidOrigin", "InvalidUpgrade", "NegotiationError", @@ -169,6 +175,20 @@ def __str__(self) -> str: return f"{self.uri} isn't a valid URI: {self.msg}" +class InvalidProxy(WebSocketException): + """ + Raised when connecting via a proxy that isn't valid. + + """ + + def __init__(self, proxy: str, msg: str) -> None: + self.proxy = proxy + self.msg = msg + + def __str__(self) -> str: + return f"{self.proxy} isn't a valid proxy: {self.msg}" + + class InvalidHandshake(WebSocketException): """ Base class for exceptions raised when the opening handshake fails. @@ -208,6 +228,26 @@ def __str__(self) -> str: ) +class InvalidProxyMessage(InvalidHandshake): + """ + Raised when a proxy response is malformed. + + """ + + +class InvalidProxyStatus(InvalidHandshake): + """ + Raised when a proxy rejects the connection. + + """ + + def __init__(self, response: http11.Response) -> None: + self.response = response + + def __str__(self) -> str: + return f"proxy rejected connection: HTTP {self.response.status_code:d}" + + class InvalidHeader(InvalidHandshake): """ Raised when an HTTP header doesn't have a valid format or value. diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 16bb3f1c1..b925b99b5 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,13 +2,18 @@ import dataclasses import urllib.parse +import urllib.request -from .exceptions import InvalidURI +from .exceptions import InvalidProxy, InvalidURI __all__ = ["parse_uri", "WebSocketURI"] +# All characters from the gen-delims and sub-delims sets in RFC 3987. +DELIMS = ":/?#[]@!$&'()*+,;=" + + @dataclasses.dataclass class WebSocketURI: """ @@ -53,10 +58,6 @@ def user_info(self) -> tuple[str, str] | None: return (self.username, self.password) -# All characters from the gen-delims and sub-delims sets in RFC 3987. -DELIMS = ":/?#[]@!$&'()*+,;=" - - def parse_uri(uri: str) -> WebSocketURI: """ Parse and validate a WebSocket URI. @@ -105,3 +106,120 @@ def parse_uri(uri: str) -> WebSocketURI: password = urllib.parse.quote(password, safe=DELIMS) return WebSocketURI(secure, host, port, path, query, username, password) + + +@dataclasses.dataclass +class Proxy: + """ + Proxy. + + Attributes: + scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``, + ``"https"``, or ``"http"``. + host: Normalized to lower case. + port: Always set even if it's the default. + username: Available when the proxy address contains `User Information`_. + password: Available when the proxy address contains `User Information`_. + + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 + + """ + + scheme: str + host: str + port: int + username: str | None = None + password: str | None = None + + @property + def user_info(self) -> tuple[str, str] | None: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) + + +def parse_proxy(proxy: str) -> Proxy: + """ + Parse and validate a proxy. + + Args: + proxy: proxy. + + Returns: + Parsed proxy. + + Raises: + InvalidProxy: If ``proxy`` isn't a valid proxy. + + """ + parsed = urllib.parse.urlparse(proxy) + if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]: + raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported") + if parsed.hostname is None: + raise InvalidProxy(proxy, "hostname isn't provided") + if parsed.path not in ["", "/"]: + raise InvalidProxy(proxy, "path is meaningless") + if parsed.query != "": + raise InvalidProxy(proxy, "query is meaningless") + if parsed.fragment != "": + raise InvalidProxy(proxy, "fragment is meaningless") + + scheme = parsed.scheme + host = parsed.hostname + port = parsed.port or (443 if parsed.scheme == "https" else 80) + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise InvalidProxy(proxy, "username provided without password") + + try: + proxy.encode("ascii") + except UnicodeEncodeError: + # Input contains non-ASCII characters. + # It must be an IRI. Convert it to a URI. + host = host.encode("idna").decode() + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return Proxy(scheme, host, port, username, password) + + +def get_proxy(uri: WebSocketURI) -> str | None: + """ + Return the proxy to use for connecting to the given WebSocket URI, if any. + + """ + if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"): + return None + + # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if + # available, else favor the proxy for HTTPS connections over the proxy for + # HTTP connections. + + # The priority of a proxy for WebSocket connections is unspecified. We give + # it the highest priority. This makes it easy to configure a specific proxy + # for websockets. + + # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or + # as {"https": "socks5h://host:port"} depending on whether they're declared + # in the operating system or in environment variables. + + proxies = urllib.request.getproxies() + if uri.secure: + schemes = ["wss", "socks", "https"] + else: + schemes = ["ws", "socks", "https", "http"] + + for scheme in schemes: + proxy = proxies.get(scheme) + if proxy is not None: + if scheme == "socks" and proxy.startswith("http://"): + proxy = "socks5h://" + proxy[7:] + return proxy + else: + return None diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index e0518b0e0..8b437ab5e 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -83,6 +83,10 @@ def test_str(self): InvalidURI("|", "not at all!"), "| isn't a valid URI: not at all!", ), + ( + InvalidProxy("|", "not at all!"), + "| isn't a valid proxy: not at all!", + ), ( InvalidHandshake("invalid request"), "invalid request", @@ -99,6 +103,14 @@ def test_str(self): InvalidStatus(Response(401, "Unauthorized", Headers())), "server rejected WebSocket connection: HTTP 401", ), + ( + InvalidProxyMessage("malformed HTTP message"), + "malformed HTTP message", + ), + ( + InvalidProxyStatus(Response(401, "Unauthorized", Headers())), + "proxy rejected connection: HTTP 401", + ), ( InvalidHeader("Name"), "missing Name header", diff --git a/tests/test_uri.py b/tests/test_uri.py index 8acc01c18..35b51fa58 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -1,7 +1,10 @@ import unittest -from websockets.exceptions import InvalidURI +from websockets.exceptions import InvalidProxy, InvalidURI from websockets.uri import * +from websockets.uri import Proxy, get_proxy, parse_proxy + +from .utils import patch_environ VALID_URIS = [ @@ -59,38 +62,199 @@ "ws:///path", ] -RESOURCE_NAMES = [ +URIS_WITH_RESOURCE_NAMES = [ ("ws://localhost/", "/"), ("ws://localhost", "/"), ("ws://localhost/path?query", "/path?query"), ("ws://høst/πass?qùéry", "/%CF%80ass?q%C3%B9%C3%A9ry"), ] -USER_INFOS = [ +URIS_WITH_USER_INFO = [ ("ws://localhost/", None), ("ws://user:pass@localhost/", ("user", "pass")), ("ws://üser:påss@høst/", ("%C3%BCser", "p%C3%A5ss")), ] +VALID_PROXIES = [ + ( + "http://proxy:8080", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "https://proxy:8080", + Proxy("https", "proxy", 8080, None, None), + ), + ( + "http://proxy", + Proxy("http", "proxy", 80, None, None), + ), + ( + "http://proxy:8080/", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "http://PROXY:8080", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "http://user:pass@proxy:8080", + Proxy("http", "proxy", 8080, "user", "pass"), + ), + ( + "http://høst:8080/", + Proxy("http", "xn--hst-0na", 8080, None, None), + ), + ( + "http://üser:påss@høst:8080", + Proxy("http", "xn--hst-0na", 8080, "%C3%BCser", "p%C3%A5ss"), + ), +] + +INVALID_PROXIES = [ + "ws://proxy:8080", + "wss://proxy:8080", + "http://proxy:8080/path", + "http://proxy:8080/?query", + "http://proxy:8080/#fragment", + "http://user@proxy", + "http:///", +] + +PROXIES_WITH_USER_INFO = [ + ("http://proxy", None), + ("http://user:pass@proxy", ("user", "pass")), + ("http://üser:påss@høst", ("%C3%BCser", "p%C3%A5ss")), +] + +PROXY_ENVS = [ + ( + {"ws_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"ws_proxy": "http://proxy:8080"}, + "wss://example.com/", + None, + ), + ( + {"wss_proxy": "http://proxy:8080"}, + "ws://example.com/", + None, + ), + ( + {"wss_proxy": "http://proxy:8080"}, + "wss://example.com/", + "http://proxy:8080", + ), + ( + {"http_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"http_proxy": "http://proxy:8080"}, + "wss://example.com/", + None, + ), + ( + {"https_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"https_proxy": "http://proxy:8080"}, + "wss://example.com/", + "http://proxy:8080", + ), + ( + {"socks_proxy": "http://proxy:1080"}, + "ws://example.com/", + "socks5h://proxy:1080", + ), + ( + {"socks_proxy": "http://proxy:1080"}, + "wss://example.com/", + "socks5h://proxy:1080", + ), + ( + {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, + "ws://example.com/", + "http://proxy1:8080", + ), + ( + {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, + "wss://example.com/", + "http://proxy2:8080", + ), + ( + {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, + "ws://example.com/", + "http://proxy2:8080", + ), + ( + {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, + "wss://example.com/", + "http://proxy2:8080", + ), + ( + {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, + "ws://example.com/", + "socks5h://proxy:1080", + ), + ( + {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, + "wss://example.com/", + "socks5h://proxy:1080", + ), + ( + {"socks_proxy": "http://proxy:1080", "no_proxy": ".local"}, + "ws://example.local/", + None, + ), +] + class URITests(unittest.TestCase): - def test_success(self): + def test_parse_valid_uris(self): for uri, parsed in VALID_URIS: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri), parsed) - def test_error(self): + def test_parse_invalid_uris(self): for uri in INVALID_URIS: with self.subTest(uri=uri): with self.assertRaises(InvalidURI): parse_uri(uri) - def test_resource_name(self): - for uri, resource_name in RESOURCE_NAMES: + def test_parse_resource_name(self): + for uri, resource_name in URIS_WITH_RESOURCE_NAMES: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri).resource_name, resource_name) - def test_user_info(self): - for uri, user_info in USER_INFOS: + def test_parse_user_info(self): + for uri, user_info in URIS_WITH_USER_INFO: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri).user_info, user_info) + + def test_parse_valid_proxies(self): + for proxy, parsed in VALID_PROXIES: + with self.subTest(proxy=proxy): + self.assertEqual(parse_proxy(proxy), parsed) + + def test_parse_invalid_proxies(self): + for proxy in INVALID_PROXIES: + with self.subTest(proxy=proxy): + with self.assertRaises(InvalidProxy): + parse_proxy(proxy) + + def test_parse_proxy_user_info(self): + for proxy, user_info in PROXIES_WITH_USER_INFO: + with self.subTest(proxy=proxy): + self.assertEqual(parse_proxy(proxy).user_info, user_info) + + def test_get_proxy(self): + for environ, uri, proxy in PROXY_ENVS: + with patch_environ(environ): + with self.subTest(environ=environ, uri=uri): + self.assertEqual(get_proxy(parse_uri(uri)), proxy) diff --git a/tests/utils.py b/tests/utils.py index 77d020726..f68a447b1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -139,6 +139,22 @@ def assertNoLogs(self, logger=None, level=None): self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) +@contextlib.contextmanager +def patch_environ(environ): + backup = {} + for key, value in environ.items(): + backup[key] = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + for key, value in backup.items(): + if value is None: + del os.environ[key] + else: # pragma: no cover + os.environ[key] = value + + @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: From 4a89e5616ffed1a8662fe195ad14827bb93a9bed Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 17:18:27 +0100 Subject: [PATCH 1494/1539] Add support for SOCKS proxies. Fix #475. --- docs/project/changelog.rst | 19 ++++++- docs/reference/features.rst | 3 +- docs/topics/index.rst | 1 + docs/topics/proxies.rst | 66 +++++++++++++++++++++++ src/websockets/asyncio/client.py | 64 +++++++++++++++++++++-- src/websockets/sync/client.py | 68 ++++++++++++++++++++++-- src/websockets/version.py | 2 +- tests/asyncio/test_client.py | 83 ++++++++++++++++++++++++++++- tests/proxy.py | 89 ++++++++++++++++++++++++++++++++ tests/requirements.txt | 2 + tests/sync/test_client.py | 80 +++++++++++++++++++++++++++- tox.ini | 21 ++++++-- 12 files changed, 479 insertions(+), 19 deletions(-) create mode 100644 docs/topics/proxies.rst create mode 100644 tests/proxy.py create mode 100644 tests/requirements.txt diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7f341d942..2a429b43e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,13 +25,28 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. -.. _14.3: +.. _15.0: -14.3 +15.0 ---- *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: Client connections use SOCKS proxies automatically. + :class: important + + If a proxy is configured in the operating system or with an environment + variable, websockets uses it automatically when connecting to a server. + This feature requires installing the third-party library `python-socks`_. + + If you want to disable the proxy, add ``proxy=None`` when calling + :func:`~asyncio.client.connect`. See :doc:`../topics/proxies` for details. + + .. _python-socks: https://github.com/romis2012/python-socks + New features ............ diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 6ba42f66b..eaecd02a9 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -168,11 +168,10 @@ Client +------------------------------------+--------+--------+--------+--------+ | Connect via HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ - | Connect via SOCKS5 proxy (`#475`_) | ❌ | ❌ | — | ❌ | + | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ .. _#364: https://github.com/python-websockets/websockets/issues/364 -.. _#475: https://github.com/python-websockets/websockets/issues/475 .. _#784: https://github.com/python-websockets/websockets/issues/784 Known limitations diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 616753c6c..a08d487c9 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -15,3 +15,4 @@ Get a deeper understanding of how websockets is built and why. memory security performance + proxies diff --git a/docs/topics/proxies.rst b/docs/topics/proxies.rst new file mode 100644 index 000000000..fd3ae78b6 --- /dev/null +++ b/docs/topics/proxies.rst @@ -0,0 +1,66 @@ +Proxies +======= + +.. currentmodule:: websockets + +If a proxy is configured in the operating system or with an environment +variable, websockets uses it automatically when connecting to a server. + +Configuration +------------- + +First, if the server is in the proxy bypass list of the operating system or in +the ``no_proxy`` environment variable, websockets connects directly. + +Then, it looks for a proxy in the following locations: + +1. The ``wss_proxy`` or ``ws_proxy`` environment variables for ``wss://`` and + ``ws://`` connections respectively. They allow configuring a specific proxy + for WebSocket connections. +2. A SOCKS proxy configured in the operating system. +3. An HTTP proxy configured in the operating system or in the ``https_proxy`` + environment variable, for both ``wss://`` and ``ws://`` connections. +4. An HTTP proxy configured in the operating system or in the ``http_proxy`` + environment variable, only for ``ws://`` connections. + +Finally, if no proxy is found, websockets connects directly. + +While environment variables are case-insensitive, the lower-case spelling is the +most common, for `historical reasons`_, and recommended. + +.. _historical reasons: https://unix.stackexchange.com/questions/212894/ + +.. admonition:: Any environment variable can configure a SOCKS proxy or an HTTP proxy. + :class: tip + + For example, ``https_proxy=socks5h://proxy:1080/`` configures a SOCKS proxy + for all WebSocket connections. Likewise, ``wss_proxy=http://proxy:8080/`` + configures an HTTP proxy only for ``wss://`` connections. + +.. admonition:: What if websockets doesn't select the right proxy? + :class: hint + + websockets relies on :func:`~urllib.request.getproxies()` to read the proxy + configuration. Check that it returns what you expect. If it doesn't, review + your proxy configuration. + +You can override the default configuration and configure a proxy explicitly with +the ``proxy`` argument of :func:`~asyncio.client.connect`. Set ``proxy=None`` to +disable the proxy. + +SOCKS proxies +------------- + +Connecting through a SOCKS proxy requires installing the third-party library +`python-socks`_:: + + $ pip install python-socks\[asyncio\] + +.. _python-socks: https://github.com/romis2012/python-socks + +python-socks supports SOCKS4, SOCKS4a, SOCKS5, and SOCKS5h. The protocol version +is configured in the address of the proxy e.g. ``socks5h://proxy:1080/``. When a +SOCKS proxy is configured in the operating system, python-socks uses SOCKS5h. + +python-socks supports username/password authentication for SOCKS5 (:rfc:`1929`) +but does not support other authentication methods such as GSSAPI (:rfc:`1961`). diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index bde0beeea..f76095ead 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -7,7 +7,7 @@ import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, Callable +from typing import Any, Callable, Literal from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike @@ -18,7 +18,7 @@ from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import WebSocketURI, parse_uri +from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .compatibility import TimeoutError, asyncio_timeout from .connection import Connection @@ -208,6 +208,10 @@ class connect: user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. process_exception: When reconnecting automatically, tell whether an error is transient or fatal. The default behavior is defined by :func:`process_exception`. Refer to its documentation for details. @@ -279,6 +283,7 @@ def __init__( # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, process_exception: Callable[[Exception], Exception | None] = process_exception, # Timeouts open_timeout: float | None = 10, @@ -333,6 +338,7 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: ) return connection + self.proxy = proxy self.protocol_factory = protocol_factory self.handshake_args = ( additional_headers, @@ -346,9 +352,20 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: async def create_connection(self) -> ClientConnection: """Create TCP or Unix connection.""" loop = asyncio.get_running_loop() + kwargs = self.connection_kwargs.copy() ws_uri = parse_uri(self.uri) - kwargs = self.connection_kwargs.copy() + + proxy = self.proxy + proxy_uri: Proxy | None = None + if kwargs.get("unix", False): + proxy = None + if kwargs.get("sock") is not None: + proxy = None + if proxy is True: + proxy = get_proxy(ws_uri) + if proxy is not None: + proxy_uri = parse_proxy(proxy) def factory() -> ClientConnection: return self.protocol_factory(ws_uri) @@ -365,6 +382,47 @@ def factory() -> ClientConnection: if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) else: + if proxy_uri is not None: + if proxy_uri.scheme[:5] == "socks": + try: + from python_socks import ProxyType + from python_socks.async_.asyncio import Proxy + except ImportError: + raise ImportError( + "python-socks is required to use a SOCKS proxy" + ) + if proxy_uri.scheme == "socks5h": + proxy_type = ProxyType.SOCKS5 + rdns = True + elif proxy_uri.scheme == "socks5": + proxy_type = ProxyType.SOCKS5 + rdns = False + # We use mitmproxy for testing and it doesn't support SOCKS4. + elif proxy_uri.scheme == "socks4a": # pragma: no cover + proxy_type = ProxyType.SOCKS4 + rdns = True + elif proxy_uri.scheme == "socks4": # pragma: no cover + proxy_type = ProxyType.SOCKS4 + rdns = False + # Proxy types are enforced in parse_proxy(). + else: + raise AssertionError("unsupported SOCKS proxy") + socks_proxy = Proxy( + proxy_type, + proxy_uri.host, + proxy_uri.port, + proxy_uri.username, + proxy_uri.password, + rdns, + ) + kwargs["sock"] = await socks_proxy.connect( + ws_uri.host, + ws_uri.port, + local_addr=kwargs.pop("local_addr", None), + ) + # Proxy types are enforced in parse_proxy(). + else: + raise AssertionError("unsupported proxy") if kwargs.get("sock") is None: kwargs.setdefault("host", ws_uri.host) kwargs.setdefault("port", ws_uri.port) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index da2b88591..96f62edab 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -5,7 +5,7 @@ import threading import warnings from collections.abc import Sequence -from typing import Any +from typing import Any, Literal from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -15,7 +15,7 @@ from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import parse_uri +from ..uri import Proxy, get_proxy, parse_proxy, parse_uri from .connection import Connection from .utils import Deadline @@ -139,6 +139,7 @@ def connect( # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -189,6 +190,10 @@ def connect( user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -253,6 +258,16 @@ def connect( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + proxy_uri: Proxy | None = None + if unix: + proxy = None + if sock is not None: + proxy = None + if proxy is True: + proxy = get_proxy(ws_uri) + if proxy is not None: + proxy_uri = parse_proxy(proxy) + # Calculate timeouts on the TCP, TLS, and WebSocket handshakes. # The TCP and TLS timeouts must be set on the socket, then removed # to avoid conflicting with the WebSocket timeout in handshake(). @@ -271,8 +286,53 @@ def connect( assert path is not None # mypy cannot figure this out sock.connect(path) else: - kwargs.setdefault("timeout", deadline.timeout()) - sock = socket.create_connection((ws_uri.host, ws_uri.port), **kwargs) + if proxy_uri is not None: + if proxy_uri.scheme[:5] == "socks": + try: + from python_socks import ProxyType + from python_socks.sync import Proxy + except ImportError: + raise ImportError( + "python-socks is required to use a SOCKS proxy" + ) + if proxy_uri.scheme == "socks5h": + proxy_type = ProxyType.SOCKS5 + rdns = True + elif proxy_uri.scheme == "socks5": + proxy_type = ProxyType.SOCKS5 + rdns = False + # We use mitmproxy for testing and it doesn't support SOCKS4. + elif proxy_uri.scheme == "socks4a": # pragma: no cover + proxy_type = ProxyType.SOCKS4 + rdns = True + elif proxy_uri.scheme == "socks4": # pragma: no cover + proxy_type = ProxyType.SOCKS4 + rdns = False + # Proxy types are enforced in parse_proxy(). + else: + raise AssertionError("unsupported SOCKS proxy") + socks_proxy = Proxy( + proxy_type, + proxy_uri.host, + proxy_uri.port, + proxy_uri.username, + proxy_uri.password, + rdns, + ) + sock = socks_proxy.connect( + ws_uri.host, + ws_uri.port, + timeout=deadline.timeout(), + local_addr=kwargs.pop("local_addr", None), + ) + # Proxy types are enforced in parse_proxy(). + else: + raise AssertionError("unsupported proxy") + else: + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection( + (ws_uri.host, ws_uri.port), **kwargs + ) sock.settimeout(None) # Disable Nagle algorithm diff --git a/src/websockets/version.py b/src/websockets/version.py index ca9a9115b..611e7d238 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -20,7 +20,7 @@ released = False -tag = version = commit = "14.3" +tag = version = commit = "15.0" if not released: # pragma: no cover diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index f05bfc699..cb2b8ede6 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -4,6 +4,7 @@ import logging import socket import ssl +import sys import unittest from websockets.asyncio.client import * @@ -13,13 +14,21 @@ from websockets.exceptions import ( InvalidHandshake, InvalidMessage, + InvalidProxy, InvalidStatus, InvalidURI, SecurityError, ) from websockets.extensions.permessage_deflate import PerMessageDeflate -from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path +from ..proxy import async_proxy +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + patch_environ, + temp_unix_socket_path, +) from .server import args, get_host_port, get_uri, handler @@ -555,6 +564,78 @@ def redirect(connection, request): ) +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class ProxyClientTests(unittest.IsolatedAsyncioTestCase): + @contextlib.asynccontextmanager + async def socks_proxy(self, auth=None): + if auth: + proxyauth = "hello:iloveyou" + proxy_uri = "http://hello:iloveyou@localhost:1080" + else: + proxyauth = None + proxy_uri = "http://localhost:1080" + async with async_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows: + with patch_environ({"socks_proxy": proxy_uri}): + yield record_flows + + async def test_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy.""" + async with self.socks_proxy() as proxy: + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + async def test_secure_socks_proxy(self): + """Client connects to server securely through a SOCKS5 proxy.""" + async with self.socks_proxy() as proxy: + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + async def test_authenticated_socks_proxy(self): + """Client connects to server through an authenticated SOCKS5 proxy.""" + async with self.socks_proxy(auth=True) as proxy: + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + async def test_explicit_proxy(self): + """Client connects to server through a proxy set explicitly.""" + async with async_proxy(mode=["socks5"]) as proxy: + async with serve(*args) as server: + async with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:1080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + async def test_ignore_proxy_with_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with self.socks_proxy() as proxy: + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 0) + + async def test_unsupported_proxy(self): + """Client connects to server through an unsupported proxy.""" + with patch_environ({"ws_proxy": "other://localhost:1080"}): + with self.assertRaises(InvalidProxy) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:1080 isn't a valid proxy: scheme other isn't supported", + ) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): diff --git a/tests/proxy.py b/tests/proxy.py new file mode 100644 index 000000000..95525a360 --- /dev/null +++ b/tests/proxy.py @@ -0,0 +1,89 @@ +import asyncio +import contextlib +import pathlib +import threading +import warnings + + +warnings.filterwarnings("ignore", category=DeprecationWarning, module="mitmproxy") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") + +try: + from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig + from mitmproxy.master import Master + from mitmproxy.options import Options +except ImportError: + pass + + +class RecordFlows: + def __init__(self): + self.ready = asyncio.get_running_loop().create_future() + self.flows = [] + + def running(self): + self.ready.set_result(None) + + def websocket_start(self, flow): + self.flows.append(flow) + + def get_flows(self): + flows, self.flows[:] = self.flows[:], [] + return flows + + +@contextlib.asynccontextmanager +async def async_proxy(mode, **config): + options = Options(mode=mode) + master = Master(options) + record_flows = RecordFlows() + master.addons.add( + core.Core(), + proxyauth.ProxyAuth(), + proxyserver.Proxyserver(), + next_layer.NextLayer(), + tlsconfig.TlsConfig(), + record_flows, + ) + config.update( + # Use our test certificate for TLS between client and proxy + # and disable TLS verification between proxy and upstream. + certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], + ssl_insecure=True, + ) + options.update(**config) + + asyncio.create_task(master.run()) + try: + await record_flows.ready + yield record_flows + finally: + for server in master.addons.get("proxyserver").servers: + await server.stop() + master.shutdown() + + +@contextlib.contextmanager +def sync_proxy(mode, **config): + loop = None + test_done = None + proxy_ready = threading.Event() + record_flows = None + + async def proxy_coroutine(): + nonlocal loop, test_done, proxy_ready, record_flows + loop = asyncio.get_running_loop() + test_done = loop.create_future() + async with async_proxy(mode, **config) as record_flows: + proxy_ready.set() + await test_done + + proxy_thread = threading.Thread(target=asyncio.run, args=(proxy_coroutine(),)) + proxy_thread.start() + try: + proxy_ready.wait() + yield record_flows + finally: + loop.call_soon_threadsafe(test_done.set_result, None) + proxy_thread.join() diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 000000000..f375e6f69 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,2 @@ +python-socks[asyncio] +mitmproxy diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 736a84c98..2f62dd34d 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,8 +1,10 @@ +import contextlib import http import logging import socket import socketserver import ssl +import sys import threading import time import unittest @@ -10,17 +12,20 @@ from websockets.exceptions import ( InvalidHandshake, InvalidMessage, + InvalidProxy, InvalidStatus, InvalidURI, ) from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * +from ..proxy import sync_proxy from ..utils import ( CLIENT_CONTEXT, MS, SERVER_CONTEXT, DeprecationTestCase, + patch_environ, temp_unix_socket_path, ) from .server import get_uri, run_server, run_unix_server @@ -37,7 +42,7 @@ def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to the right socket. + # Use a non-existing domain to ensure we connect to sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -300,6 +305,79 @@ def test_reject_invalid_server_hostname(self): ) +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class ProxyClientTests(unittest.TestCase): + @contextlib.contextmanager + def socks_proxy(self, auth=None): + if auth: + proxyauth = "hello:iloveyou" + proxy_uri = "http://hello:iloveyou@localhost:1080" + else: + proxyauth = None + proxy_uri = "http://localhost:1080" + + with sync_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows: + with patch_environ({"socks_proxy": proxy_uri}): + yield record_flows + + def test_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy.""" + with self.socks_proxy() as proxy: + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + def test_secure_socks_proxy(self): + """Client connects to server securely through a SOCKS5 proxy.""" + with self.socks_proxy() as proxy: + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + def test_authenticated_socks_proxy(self): + """Client connects to server through an authenticated SOCKS5 proxy.""" + with self.socks_proxy(auth=True) as proxy: + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + def test_explicit_proxy(self): + """Client connects to server through a proxy set explicitly.""" + with sync_proxy(mode=["socks5"]) as proxy: + with run_server() as server: + with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:1080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + def test_ignore_proxy_with_existing_socket(self): + """Client connects using a pre-existing socket.""" + with self.socks_proxy() as proxy: + with run_server() as server: + with socket.create_connection(server.socket.getsockname()) as sock: + # Use a non-existing domain to ensure we connect to sock. + with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 0) + + def test_unsupported_proxy(self): + """Client connects to server through an unsupported proxy.""" + with patch_environ({"ws_proxy": "other://localhost:1080"}): + with self.assertRaises(InvalidProxy) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:1080 isn't a valid proxy: scheme other isn't supported", + ) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.TestCase): def test_connection(self): diff --git a/tox.ini b/tox.ini index 0bcec5ded..f5a2f5d3c 100644 --- a/tox.ini +++ b/tox.ini @@ -12,27 +12,38 @@ env_list = [testenv] commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} -pass_env = WEBSOCKETS_* +pass_env = + WEBSOCKETS_* +deps = + mitmproxy + python-socks[asyncio] [testenv:coverage] commands = python -m coverage run --source {envsitepackagesdir}/websockets,tests -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 -deps = coverage +deps = + coverage + {[testenv]deps} [testenv:maxi_cov] commands = python tests/maxi_cov.py {envsitepackagesdir} python -m coverage report --show-missing --fail-under=100 -deps = coverage +deps = + coverage + {[testenv]deps} [testenv:ruff] commands = ruff format --check src tests ruff check src tests -deps = ruff +deps = + ruff [testenv:mypy] commands = mypy --strict src -deps = mypy +deps = + mypy + python-socks From 10175f7a41ea1cf17aff7983d7eace3c2e4da5c0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 27 Jan 2025 22:07:21 +0100 Subject: [PATCH 1495/1539] Refactor SOCKS proxy implementation. --- src/websockets/asyncio/client.py | 109 +++++++++++++++----------- src/websockets/sync/client.py | 127 ++++++++++++++++++------------- 2 files changed, 141 insertions(+), 95 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index f76095ead..7052ca85a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -3,6 +3,7 @@ import asyncio import logging import os +import socket import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence @@ -357,15 +358,12 @@ async def create_connection(self) -> ClientConnection: ws_uri = parse_uri(self.uri) proxy = self.proxy - proxy_uri: Proxy | None = None if kwargs.get("unix", False): proxy = None if kwargs.get("sock") is not None: proxy = None if proxy is True: proxy = get_proxy(ws_uri) - if proxy is not None: - proxy_uri = parse_proxy(proxy) def factory() -> ClientConnection: return self.protocol_factory(ws_uri) @@ -381,48 +379,14 @@ def factory() -> ClientConnection: if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) + elif proxy is not None: + kwargs["sock"] = await connect_proxy( + parse_proxy(proxy), + ws_uri, + local_addr=kwargs.pop("local_addr", None), + ) + _, connection = await loop.create_connection(factory, **kwargs) else: - if proxy_uri is not None: - if proxy_uri.scheme[:5] == "socks": - try: - from python_socks import ProxyType - from python_socks.async_.asyncio import Proxy - except ImportError: - raise ImportError( - "python-socks is required to use a SOCKS proxy" - ) - if proxy_uri.scheme == "socks5h": - proxy_type = ProxyType.SOCKS5 - rdns = True - elif proxy_uri.scheme == "socks5": - proxy_type = ProxyType.SOCKS5 - rdns = False - # We use mitmproxy for testing and it doesn't support SOCKS4. - elif proxy_uri.scheme == "socks4a": # pragma: no cover - proxy_type = ProxyType.SOCKS4 - rdns = True - elif proxy_uri.scheme == "socks4": # pragma: no cover - proxy_type = ProxyType.SOCKS4 - rdns = False - # Proxy types are enforced in parse_proxy(). - else: - raise AssertionError("unsupported SOCKS proxy") - socks_proxy = Proxy( - proxy_type, - proxy_uri.host, - proxy_uri.port, - proxy_uri.username, - proxy_uri.password, - rdns, - ) - kwargs["sock"] = await socks_proxy.connect( - ws_uri.host, - ws_uri.port, - local_addr=kwargs.pop("local_addr", None), - ) - # Proxy types are enforced in parse_proxy(). - else: - raise AssertionError("unsupported proxy") if kwargs.get("sock") is None: kwargs.setdefault("host", ws_uri.host) kwargs.setdefault("port", ws_uri.port) @@ -624,3 +588,60 @@ def unix_connect( else: uri = "wss://localhost/" return connect(uri=uri, unix=True, path=path, **kwargs) + + +try: + from python_socks import ProxyType + from python_socks.async_.asyncio import Proxy as SocksProxy + + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> socket.socket: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + +except ImportError: + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> socket.socket: + raise ImportError("python-socks is required to use a SOCKS proxy") + + +async def connect_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, +) -> socket.socket: + """Connect via a proxy and return the socket.""" + # parse_proxy() validates proxy.scheme. + if proxy.scheme[:5] == "socks": + return await connect_socks_proxy(proxy, ws_uri, **kwargs) + else: + raise AssertionError("unsupported proxy") diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 96f62edab..5fbec67ad 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -15,7 +15,7 @@ from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import Proxy, get_proxy, parse_proxy, parse_uri +from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .connection import Connection from .utils import Deadline @@ -258,15 +258,12 @@ def connect( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") - proxy_uri: Proxy | None = None if unix: proxy = None if sock is not None: proxy = None if proxy is True: proxy = get_proxy(ws_uri) - if proxy is not None: - proxy_uri = parse_proxy(proxy) # Calculate timeouts on the TCP, TLS, and WebSocket handshakes. # The TCP and TLS timeouts must be set on the socket, then removed @@ -285,54 +282,21 @@ def connect( sock.settimeout(deadline.timeout()) assert path is not None # mypy cannot figure this out sock.connect(path) + elif proxy is not None: + sock = connect_proxy( + parse_proxy(proxy), + ws_uri, + deadline, + # websockets is consistent with the socket module while + # python_socks is consistent across implementations. + local_addr=kwargs.pop("source_address", None), + ) else: - if proxy_uri is not None: - if proxy_uri.scheme[:5] == "socks": - try: - from python_socks import ProxyType - from python_socks.sync import Proxy - except ImportError: - raise ImportError( - "python-socks is required to use a SOCKS proxy" - ) - if proxy_uri.scheme == "socks5h": - proxy_type = ProxyType.SOCKS5 - rdns = True - elif proxy_uri.scheme == "socks5": - proxy_type = ProxyType.SOCKS5 - rdns = False - # We use mitmproxy for testing and it doesn't support SOCKS4. - elif proxy_uri.scheme == "socks4a": # pragma: no cover - proxy_type = ProxyType.SOCKS4 - rdns = True - elif proxy_uri.scheme == "socks4": # pragma: no cover - proxy_type = ProxyType.SOCKS4 - rdns = False - # Proxy types are enforced in parse_proxy(). - else: - raise AssertionError("unsupported SOCKS proxy") - socks_proxy = Proxy( - proxy_type, - proxy_uri.host, - proxy_uri.port, - proxy_uri.username, - proxy_uri.password, - rdns, - ) - sock = socks_proxy.connect( - ws_uri.host, - ws_uri.port, - timeout=deadline.timeout(), - local_addr=kwargs.pop("local_addr", None), - ) - # Proxy types are enforced in parse_proxy(). - else: - raise AssertionError("unsupported proxy") - else: - kwargs.setdefault("timeout", deadline.timeout()) - sock = socket.create_connection( - (ws_uri.host, ws_uri.port), **kwargs - ) + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection( + (ws_uri.host, ws_uri.port), + **kwargs, + ) sock.settimeout(None) # Disable Nagle algorithm @@ -420,3 +384,64 @@ def unix_connect( else: uri = "wss://localhost/" return connect(uri=uri, unix=True, path=path, **kwargs) + + +try: + from python_socks import ProxyType + from python_socks.sync import Proxy as SocksProxy + + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + **kwargs: Any, + ) -> socket.socket: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + kwargs.setdefault("timeout", deadline.timeout()) + return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + +except ImportError: + + def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + **kwargs: Any, + ) -> socket.socket: + raise ImportError("python-socks is required to use a SOCKS proxy") + + +def connect_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + **kwargs: Any, +) -> socket.socket: + """Connect via a proxy and return the socket.""" + # parse_proxy() validates proxy.scheme. + if proxy.scheme[:5] == "socks": + return connect_socks_proxy(proxy, ws_uri, deadline, **kwargs) + else: + raise AssertionError("unsupported proxy") From 321be894176262a74a0b020aa1327f2aaca728ca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 28 Jan 2025 22:42:11 +0100 Subject: [PATCH 1496/1539] Improve error handling for SOCKS proxy. --- docs/reference/exceptions.rst | 8 ++- src/websockets/__init__.py | 3 + src/websockets/asyncio/client.py | 17 +++++- src/websockets/exceptions.py | 41 ++++++++------ src/websockets/sync/client.py | 10 +++- tests/asyncio/test_client.py | 94 +++++++++++++++++++++++--------- tests/asyncio/test_server.py | 6 +- tests/sync/test_client.py | 94 +++++++++++++++++++++++--------- tests/sync/test_server.py | 4 +- tests/test_exceptions.py | 4 ++ 10 files changed, 203 insertions(+), 78 deletions(-) diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index e0c2efdd1..6c09a13fa 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -34,14 +34,16 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs. .. autoexception:: SecurityError -.. autoexception:: InvalidMessage - -.. autoexception:: InvalidStatus +.. autoexception:: ProxyError .. autoexception:: InvalidProxyMessage .. autoexception:: InvalidProxyStatus +.. autoexception:: InvalidMessage + +.. autoexception:: InvalidStatus + .. autoexception:: InvalidHeader .. autoexception:: InvalidHeaderFormat diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 8bf282a73..28a10910b 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -49,6 +49,7 @@ "NegotiationError", "PayloadTooBig", "ProtocolError", + "ProxyError", "SecurityError", "WebSocketException", # .frames @@ -112,6 +113,7 @@ NegotiationError, PayloadTooBig, ProtocolError, + ProxyError, SecurityError, WebSocketException, ) @@ -173,6 +175,7 @@ "NegotiationError": ".exceptions", "PayloadTooBig": ".exceptions", "ProtocolError": ".exceptions", + "ProxyError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", # .frames diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 7052ca85a..9582a4bb9 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -12,7 +12,7 @@ from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike -from ..exceptions import InvalidMessage, InvalidStatus, SecurityError +from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -148,7 +148,9 @@ def process_exception(exc: Exception) -> Exception | None: That exception will be raised, breaking out of the retry loop. """ - if isinstance(exc, (OSError, asyncio.TimeoutError)): + # This catches python-socks' ProxyConnectionError and ProxyTimeoutError. + # Remove asyncio.TimeoutError when dropping Python < 3.11. + if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)): return None if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError): return None @@ -266,6 +268,7 @@ class connect: Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. TimeoutError: If the opening handshake times out. @@ -622,7 +625,15 @@ async def connect_socks_proxy( proxy.password, SOCKS_PROXY_RDNS[proxy.scheme], ) - return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + # connect() is documented to raise OSError. + # socks_proxy.connect() doesn't raise TimeoutError; it gets canceled. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. + try: + return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + except OSError: + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc except ImportError: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index e70aac92e..ab1a15ca8 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -9,11 +9,12 @@ * :exc:`InvalidProxy` * :exc:`InvalidHandshake` * :exc:`SecurityError` + * :exc:`ProxyError` + * :exc:`InvalidProxyMessage` + * :exc:`InvalidProxyStatus` * :exc:`InvalidMessage` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) - * :exc:`InvalidProxyMessage` - * :exc:`InvalidProxyStatus` * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` @@ -48,10 +49,11 @@ "InvalidProxy", "InvalidHandshake", "SecurityError", - "InvalidMessage", - "InvalidStatus", + "ProxyError", "InvalidProxyMessage", "InvalidProxyStatus", + "InvalidMessage", + "InvalidStatus", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", @@ -206,16 +208,23 @@ class SecurityError(InvalidHandshake): """ -class InvalidMessage(InvalidHandshake): +class ProxyError(InvalidHandshake): """ - Raised when a handshake request or response is malformed. + Raised when failing to connect to a proxy. """ -class InvalidStatus(InvalidHandshake): +class InvalidProxyMessage(ProxyError): """ - Raised when a handshake response rejects the WebSocket upgrade. + Raised when an HTTP proxy response is malformed. + + """ + + +class InvalidProxyStatus(ProxyError): + """ + Raised when an HTTP proxy rejects the connection. """ @@ -223,21 +232,19 @@ def __init__(self, response: http11.Response) -> None: self.response = response def __str__(self) -> str: - return ( - f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" - ) + return f"proxy rejected connection: HTTP {self.response.status_code:d}" -class InvalidProxyMessage(InvalidHandshake): +class InvalidMessage(InvalidHandshake): """ - Raised when a proxy response is malformed. + Raised when a handshake request or response is malformed. """ -class InvalidProxyStatus(InvalidHandshake): +class InvalidStatus(InvalidHandshake): """ - Raised when a proxy rejects the connection. + Raised when a handshake response rejects the WebSocket upgrade. """ @@ -245,7 +252,9 @@ def __init__(self, response: http11.Response) -> None: self.response = response def __str__(self) -> str: - return f"proxy rejected connection: HTTP {self.response.status_code:d}" + return ( + f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" + ) class InvalidHeader(InvalidHandshake): diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 5fbec67ad..e2a287648 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -9,6 +9,7 @@ from ..client import ClientProtocol from ..datastructures import HeadersLike +from ..exceptions import ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -420,7 +421,14 @@ def connect_socks_proxy( SOCKS_PROXY_RDNS[proxy.scheme], ) kwargs.setdefault("timeout", deadline.timeout()) - return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + # connect() is documented to raise OSError and TimeoutError. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. + try: + return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + except (OSError, TimeoutError, socket.timeout): + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc except ImportError: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index cb2b8ede6..bdd519fb8 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -17,6 +17,7 @@ InvalidProxy, InvalidStatus, InvalidURI, + ProxyError, SecurityError, ) from websockets.extensions.permessage_deflate import PerMessageDeflate @@ -379,24 +380,16 @@ def remove_accept_header(self, request, response): async def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - gate = asyncio.get_running_loop().create_future() - - async def stall_connection(self, request): - await gate - - # The connection will be open for the server but failed for the client. - # Use a connection handler that exits immediately to avoid an exception. - async with serve(*args, process_request=stall_connection) as server: - try: - with self.assertRaises(TimeoutError) as raised: - async with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "timed out during handshake", - ) - finally: - gate.set_result(None) + # Replace the WebSocket server with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + async with connect(f"ws://{host}:{port}", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) async def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" @@ -570,11 +563,14 @@ class ProxyClientTests(unittest.IsolatedAsyncioTestCase): async def socks_proxy(self, auth=None): if auth: proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:1080" + proxy_uri = "http://hello:iloveyou@localhost:51080" else: proxyauth = None - proxy_uri = "http://localhost:1080" - async with async_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows: + proxy_uri = "http://localhost:51080" + async with async_proxy( + mode=["socks5@51080"], + proxyauth=proxyauth, + ) as record_flows: with patch_environ({"socks_proxy": proxy_uri}): yield record_flows @@ -602,14 +598,62 @@ async def test_authenticated_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(len(proxy.get_flows()), 1) + async def test_socks_proxy_connection_error(self): + """Client receives an error when connecting to the SOCKS5 proxy.""" + from python_socks import ProxyError as SocksProxyError + + async with self.socks_proxy(auth=True) as proxy: + with self.assertRaises(ProxyError) as raised: + async with connect( + "ws://example.com/", + proxy="socks5h://localhost:51080", # remove credentials + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "failed to connect to SOCKS proxy", + ) + self.assertIsInstance(raised.exception.__cause__, SocksProxyError) + self.assertEqual(len(proxy.get_flows()), 0) + + async def test_socks_proxy_connection_fails(self): + """Client fails to connect to the SOCKS5 proxy.""" + from python_socks import ProxyConnectionError as SocksProxyConnectionError + + with self.assertRaises(OSError) as raised: + async with connect( + "ws://example.com/", + proxy="socks5h://localhost:51080", # nothing at this address + ): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertIsInstance(raised.exception, SocksProxyConnectionError) + + async def test_socks_proxy_connection_timeout(self): + """Client times out while connecting to the SOCKS5 proxy.""" + # Replace the proxy with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + async with connect( + "ws://example.com/", + proxy=f"socks5h://{host}:{port}/", + open_timeout=MS, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) + async def test_explicit_proxy(self): """Client connects to server through a proxy set explicitly.""" - async with async_proxy(mode=["socks5"]) as proxy: + async with async_proxy(mode=["socks5@51080"]) as proxy: async with serve(*args) as server: async with connect( get_uri(server), # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:1080", + proxy="socks5://localhost:51080", ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(len(proxy.get_flows()), 1) @@ -626,13 +670,13 @@ async def test_ignore_proxy_with_existing_socket(self): async def test_unsupported_proxy(self): """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:1080"}): + with patch_environ({"ws_proxy": "other://localhost:51080"}): with self.assertRaises(InvalidProxy) as raised: async with connect("ws://example.com/"): self.fail("did not raise") self.assertEqual( str(raised.exception), - "other://localhost:1080 isn't a valid proxy: scheme other isn't supported", + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 38c0315a1..6adfff8e9 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -65,9 +65,9 @@ async def test_connection_handler_raises_exception(self): async def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: - async with serve(handler, sock=sock, host=None, port=None): - uri = "ws://{}:{}/".format(*sock.getsockname()) - async with connect(uri) as client: + host, port = sock.getsockname() + async with serve(handler, sock=sock): + async with connect(f"ws://{host}:{port}/") as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_select_subprotocol(self): diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 2f62dd34d..dbecadcac 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -15,6 +15,7 @@ InvalidProxy, InvalidStatus, InvalidURI, + ProxyError, ) from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -148,24 +149,16 @@ def remove_accept_header(self, request, response): def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - gate = threading.Event() - - def stall_connection(self, request): - gate.wait() - - # The connection will be open for the server but failed for the client. - # Use a connection handler that exits immediately to avoid an exception. - with run_server(process_request=stall_connection) as server: - try: - with self.assertRaises(TimeoutError) as raised: - with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "timed out during handshake", - ) - finally: - gate.set() + # Replace the WebSocket server with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + with connect(f"ws://{host}:{port}", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" @@ -311,12 +304,15 @@ class ProxyClientTests(unittest.TestCase): def socks_proxy(self, auth=None): if auth: proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:1080" + proxy_uri = "http://hello:iloveyou@localhost:51080" else: proxyauth = None - proxy_uri = "http://localhost:1080" + proxy_uri = "http://localhost:51080" - with sync_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows: + with sync_proxy( + mode=["socks5@51080"], + proxyauth=proxyauth, + ) as record_flows: with patch_environ({"socks_proxy": proxy_uri}): yield record_flows @@ -344,14 +340,62 @@ def test_authenticated_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(len(proxy.get_flows()), 1) + def test_socks_proxy_connection_error(self): + """Client receives an error when connecting to the SOCKS5 proxy.""" + from python_socks import ProxyError as SocksProxyError + + with self.socks_proxy(auth=True) as proxy: + with self.assertRaises(ProxyError) as raised: + with connect( + "ws://example.com/", + proxy="socks5h://localhost:51080", # remove credentials + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "failed to connect to SOCKS proxy", + ) + self.assertIsInstance(raised.exception.__cause__, SocksProxyError) + self.assertEqual(len(proxy.get_flows()), 0) + + def test_socks_proxy_connection_fails(self): + """Client fails to connect to the SOCKS5 proxy.""" + from python_socks import ProxyConnectionError as SocksProxyConnectionError + + with self.assertRaises(OSError) as raised: + with connect( + "ws://example.com/", + proxy="socks5h://localhost:51080", # nothing at this address + ): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertIsInstance(raised.exception, SocksProxyConnectionError) + + def test_socks_proxy_timeout(self): + """Client times out before connecting to the SOCKS5 proxy.""" + from python_socks import ProxyTimeoutError as SocksProxyTimeoutError + + # Replace the proxy with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + with connect( + "ws://example.com/", + proxy=f"socks5h://{host}:{port}/", + open_timeout=MS, + ): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertIsInstance(raised.exception, SocksProxyTimeoutError) + def test_explicit_proxy(self): """Client connects to server through a proxy set explicitly.""" - with sync_proxy(mode=["socks5"]) as proxy: + with sync_proxy(mode=["socks5@51080"]) as proxy: with run_server() as server: with connect( get_uri(server), # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:1080", + proxy="socks5://localhost:51080", ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(len(proxy.get_flows()), 1) @@ -368,13 +412,13 @@ def test_ignore_proxy_with_existing_socket(self): def test_unsupported_proxy(self): """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:1080"}): + with patch_environ({"ws_proxy": "other://localhost:51080"}): with self.assertRaises(InvalidProxy) as raised: with connect("ws://example.com/"): self.fail("did not raise") self.assertEqual( str(raised.exception), - "other://localhost:1080 isn't a valid proxy: scheme other isn't supported", + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", ) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f59671efd..d04d1859a 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -64,9 +64,9 @@ def test_connection_handler_raises_exception(self): def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() with run_server(sock=sock): - uri = "ws://{}:{}/".format(*sock.getsockname()) - with connect(uri) as client: + with connect(f"ws://{host}:{port}/") as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_select_subprotocol(self): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8b437ab5e..b4e7acee7 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -95,6 +95,10 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), + ( + ProxyError("failed to connect to SOCKS proxy"), + "failed to connect to SOCKS proxy", + ), ( InvalidMessage("malformed HTTP message"), "malformed HTTP message", From a00c18436a8c1500928f032f7131fa226f745334 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Jan 2025 21:38:33 +0100 Subject: [PATCH 1497/1539] Support the legacy pattern for setting User-Agent. Support setting it with additional_headers for backwards compatibility with the legacy implementation and the API until websoockets 10.4. Before this change, the header was added twice: once with the custom value and once with the default value. Fix #1583. --- docs/howto/upgrade.rst | 14 ++++++++++---- src/websockets/asyncio/client.py | 2 +- src/websockets/sync/client.py | 2 +- tests/asyncio/test_client.py | 8 ++++++++ tests/sync/test_client.py | 8 ++++++++ 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index db6bf11f1..8cfd7b4b5 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -259,9 +259,12 @@ Arguments of :func:`~asyncio.client.connect` ``extra_headers`` → ``additional_headers`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you're adding headers to the handshake request sent by -:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must -rename it to ``additional_headers``. +If you're setting the ``User-Agent`` header with the ``extra_headers`` argument, +you should set it with ``user_agent_header`` instead. + +If you're adding other headers to the handshake request sent by +:func:`~legacy.client.connect` with ``extra_headers``, you must rename it to +``additional_headers``. Arguments of :func:`~asyncio.server.serve` .......................................... @@ -310,7 +313,10 @@ replace it with a ``process_request`` function or coroutine. ``extra_headers`` → ``process_response`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you're adding headers to the handshake response sent by +If you're setting the ``Server`` header with ``extra_headers``, you should set +it with the ``server_header`` argument instead. + +If you're adding other headers to the handshake response sent by :func:`~legacy.server.serve` with the ``extra_headers`` argument, you must write a ``process_response`` callable instead. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 9582a4bb9..1e560fe0c 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -89,7 +89,7 @@ async def handshake( if additional_headers is not None: self.request.headers.update(additional_headers) if user_agent_header: - self.request.headers["User-Agent"] = user_agent_header + self.request.headers.setdefault("User-Agent", user_agent_header) self.protocol.send_request(self.request) await asyncio.wait( diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e2a287648..b7ab83664 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -86,7 +86,7 @@ def handshake( if additional_headers is not None: self.request.headers.update(additional_headers) if user_agent_header is not None: - self.request.headers["User-Agent"] = user_agent_header + self.request.headers.setdefault("User-Agent", user_agent_header) self.protocol.send_request(self.request) if not self.response_rcvd.wait(timeout): diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index bdd519fb8..9c7ee46ad 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -121,6 +121,14 @@ async def test_remove_user_agent(self): async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) + async def test_legacy_user_agent(self): + """Client can override User-Agent header with additional_headers.""" + async with serve(*args) as server: + async with connect( + get_uri(server), additional_headers={"User-Agent": "Smith"} + ) as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" async with serve(*args) as server: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index dbecadcac..4844d3b5e 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -82,6 +82,14 @@ def test_remove_user_agent(self): with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) + def test_legacy_user_agent(self): + """Client can override User-Agent header with additional_headers.""" + with run_server() as server: + with connect( + get_uri(server), additional_headers={"User-Agent": "Smith"} + ) as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" with run_server() as server: From 7bb226a4aff8479994851a1cc83fc1aaf21a3724 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Jan 2025 21:46:45 +0100 Subject: [PATCH 1498/1539] Standardize spelling of canceling/canceled. --- docs/project/changelog.rst | 2 +- src/websockets/asyncio/messages.py | 4 ++-- tests/asyncio/test_connection.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 2a429b43e..bfbfa793f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -84,7 +84,7 @@ Bug fixes :mod:`threading` implementation. If a message is already received, it is returned. Previously, :exc:`TimeoutError` was raised incorrectly. -* Fixed a crash in the :mod:`asyncio` implementation when cancelling a ping +* Fixed a crash in the :mod:`asyncio` implementation when canceling a ping then receiving the corresponding pong. * Prevented :meth:`~asyncio.connection.Connection.close` from blocking when diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 581870037..1fd41811c 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -152,7 +152,7 @@ async def get(self, decode: bool | None = None) -> Data: self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution - # until get() fetches a complete message or is cancelled. + # until get() fetches a complete message or is canceled. try: # First frame @@ -224,7 +224,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution - # until get_iter() fetches a complete message or is cancelled. + # until get_iter() fetches a complete message or is canceled. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 5230eca89..668f55cbd 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -245,7 +245,7 @@ async def test_recv_during_recv_streaming(self): ) async def test_recv_cancellation_before_receiving(self): - """recv can be cancelled before receiving a frame.""" + """recv can be canceled before receiving a frame.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task @@ -257,7 +257,7 @@ async def test_recv_cancellation_before_receiving(self): self.assertEqual(await self.connection.recv(), "😀") async def test_recv_cancellation_while_receiving(self): - """recv cannot be cancelled after receiving a frame.""" + """recv cannot be canceled after receiving a frame.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task @@ -386,7 +386,7 @@ async def test_recv_streaming_during_recv_streaming(self): ) async def test_recv_streaming_cancellation_before_receiving(self): - """recv_streaming can be cancelled before receiving a frame.""" + """recv_streaming can be canceled before receiving a frame.""" recv_streaming_task = asyncio.create_task( alist(self.connection.recv_streaming()) ) @@ -403,7 +403,7 @@ async def test_recv_streaming_cancellation_before_receiving(self): ) async def test_recv_streaming_cancellation_while_receiving(self): - """recv_streaming cannot be cancelled after receiving a frame.""" + """recv_streaming cannot be canceled after receiving a frame.""" recv_streaming_task = asyncio.create_task( alist(self.connection.recv_streaming()) ) From 737d5ffb5121531f4c1efba2a25c3fd7db76b868 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 31 Jan 2025 21:37:15 +0100 Subject: [PATCH 1499/1539] Run mitmproxy only once per test case. This reduces the run time of the test suite by 40%, from 6.7s to 4.1s. --- tests/asyncio/test_client.py | 146 +++++++++++++++------------------ tests/proxy.py | 132 ++++++++++++++++-------------- tests/sync/test_client.py | 152 ++++++++++++++++------------------- 3 files changed, 207 insertions(+), 223 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 9c7ee46ad..be8ef8a42 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -22,7 +22,7 @@ ) from websockets.extensions.permessage_deflate import PerMessageDeflate -from ..proxy import async_proxy +from ..proxy import ProxyMixin from ..utils import ( CLIENT_CONTEXT, MS, @@ -388,7 +388,7 @@ def remove_accept_header(self, request, response): async def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - # Replace the WebSocket server with a TCP server that does't respond. + # Replace the WebSocket server with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with self.assertRaises(TimeoutError) as raised: @@ -508,7 +508,7 @@ async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate isn't trusted system-wide. + # The test certificate is self-signed. async with connect(get_uri(server)): self.fail("did not raise") self.assertIn( @@ -566,126 +566,105 @@ def redirect(connection, request): @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") -class ProxyClientTests(unittest.IsolatedAsyncioTestCase): - @contextlib.asynccontextmanager - async def socks_proxy(self, auth=None): - if auth: - proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:51080" - else: - proxyauth = None - proxy_uri = "http://localhost:51080" - async with async_proxy( - mode=["socks5@51080"], - proxyauth=proxyauth, - ) as record_flows: - with patch_environ({"socks_proxy": proxy_uri}): - yield record_flows +class SocksProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "socks5@51080" async def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - async with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + self.assertNumFlows(1) async def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - async with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + self.assertNumFlows(1) async def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" - async with self.socks_proxy(auth=True) as proxy: - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"socks_proxy": "http://hello:iloveyou@localhost:51080"} + ): + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) - async def test_socks_proxy_connection_error(self): - """Client receives an error when connecting to the SOCKS5 proxy.""" + async def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError - async with self.socks_proxy(auth=True) as proxy: - with self.assertRaises(ProxyError) as raised: - async with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # remove credentials - ): - self.fail("did not raise") + try: + self.proxy_options.update(proxyauth="any") + with patch_environ({"socks_proxy": "http://localhost:51080"}): + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "failed to connect to SOCKS proxy", ) self.assertIsInstance(raised.exception.__cause__, SocksProxyError) - self.assertEqual(len(proxy.get_flows()), 0) + self.assertNumFlows(0) - async def test_socks_proxy_connection_fails(self): + async def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with self.assertRaises(OSError) as raised: - async with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # nothing at this address - ): - self.fail("did not raise") + with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) + self.assertNumFlows(0) async def test_socks_proxy_connection_timeout(self): """Client times out while connecting to the SOCKS5 proxy.""" - # Replace the proxy with a TCP server that does't respond. + # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with self.assertRaises(TimeoutError) as raised: - async with connect( - "ws://example.com/", - proxy=f"socks5h://{host}:{port}/", - open_timeout=MS, - ): - self.fail("did not raise") + with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") self.assertEqual( str(raised.exception), "timed out during handshake", ) + self.assertNumFlows(0) - async def test_explicit_proxy(self): - """Client connects to server through a proxy set explicitly.""" - async with async_proxy(mode=["socks5@51080"]) as proxy: - async with serve(*args) as server: - async with connect( - get_uri(server), - # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:51080", - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + async def test_explicit_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy set explicitly.""" + async with serve(*args) as server: + async with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:51080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - async with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: # Use a non-existing domain to ensure we connect to sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 0) - - async def test_unsupported_proxy(self): - """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:51080"}): - with self.assertRaises(InvalidProxy) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", - ) + self.assertNumFlows(0) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -724,10 +703,7 @@ def redirect(connection, request): "cannot follow cross-origin redirect to ws://other/ with a Unix socket", ) - -@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") -class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): - async def test_connection(self): + async def test_secure_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: async with unix_serve(handler, path, ssl=SERVER_CONTEXT): @@ -769,6 +745,16 @@ async def test_secure_uri_without_ssl(self): "ssl=None is incompatible with a wss:// URI", ) + async def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + async with connect("ws://example.com/", proxy="other://localhost:51080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", + ) + async def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: diff --git a/tests/proxy.py b/tests/proxy.py index 95525a360..8b62b0eb7 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -1,15 +1,14 @@ import asyncio -import contextlib import pathlib import threading import warnings -warnings.filterwarnings("ignore", category=DeprecationWarning, module="mitmproxy") -warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") -warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") - try: + # Ignore deprecation warnings raised by mitmproxy dependencies at import time. + warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") + warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") + from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig from mitmproxy.master import Master from mitmproxy.options import Options @@ -18,13 +17,10 @@ class RecordFlows: - def __init__(self): - self.ready = asyncio.get_running_loop().create_future() + def __init__(self, on_running): + self.running = on_running self.flows = [] - def running(self): - self.ready.set_result(None) - def websocket_start(self, flow): self.flows.append(flow) @@ -32,58 +28,76 @@ def get_flows(self): flows, self.flows[:] = self.flows[:], [] return flows + def reset_flows(self): + self.flows = [] + + +class ProxyMixin: + """ + Run mitmproxy in a background thread. + + While it's uncommon to run two event loops in two threads, tests for the + asyncio implementation rely on this class too because it starts an event + loop for mitm proxy once, then a new event loop for each test. + """ + + proxy_mode = None + + @classmethod + async def run_proxy(cls): + cls.proxy_loop = loop = asyncio.get_event_loop() + cls.proxy_stop = stop = loop.create_future() + + cls.proxy_options = options = Options(mode=[cls.proxy_mode]) + cls.proxy_master = master = Master(options) + master.addons.add( + core.Core(), + proxyauth.ProxyAuth(), + proxyserver.Proxyserver(), + next_layer.NextLayer(), + tlsconfig.TlsConfig(), + RecordFlows(on_running=cls.proxy_ready.set), + ) + options.update( + # Use test certificate for TLS between client and proxy. + certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], + # Disable TLS verification between proxy and upstream. + ssl_insecure=True, + ) + + task = loop.create_task(cls.proxy_master.run()) + await stop -@contextlib.asynccontextmanager -async def async_proxy(mode, **config): - options = Options(mode=mode) - master = Master(options) - record_flows = RecordFlows() - master.addons.add( - core.Core(), - proxyauth.ProxyAuth(), - proxyserver.Proxyserver(), - next_layer.NextLayer(), - tlsconfig.TlsConfig(), - record_flows, - ) - config.update( - # Use our test certificate for TLS between client and proxy - # and disable TLS verification between proxy and upstream. - certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], - ssl_insecure=True, - ) - options.update(**config) - - asyncio.create_task(master.run()) - try: - await record_flows.ready - yield record_flows - finally: for server in master.addons.get("proxyserver").servers: await server.stop() master.shutdown() + await task + + @classmethod + def setUpClass(cls): + super().setUpClass() + + # Ignore deprecation warnings raised by mitmproxy at run time. + warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="mitmproxy" + ) + + cls.proxy_ready = threading.Event() + cls.proxy_thread = threading.Thread(target=asyncio.run, args=(cls.run_proxy(),)) + cls.proxy_thread.start() + cls.proxy_ready.wait() + + def assertNumFlows(self, num_flows): + record_flows = self.proxy_master.addons.get("recordflows") + self.assertEqual(len(record_flows.get_flows()), num_flows) + def tearDown(self): + record_flows = self.proxy_master.addons.get("recordflows") + record_flows.reset_flows() + super().tearDown() -@contextlib.contextmanager -def sync_proxy(mode, **config): - loop = None - test_done = None - proxy_ready = threading.Event() - record_flows = None - - async def proxy_coroutine(): - nonlocal loop, test_done, proxy_ready, record_flows - loop = asyncio.get_running_loop() - test_done = loop.create_future() - async with async_proxy(mode, **config) as record_flows: - proxy_ready.set() - await test_done - - proxy_thread = threading.Thread(target=asyncio.run, args=(proxy_coroutine(),)) - proxy_thread.start() - try: - proxy_ready.wait() - yield record_flows - finally: - loop.call_soon_threadsafe(test_done.set_result, None) - proxy_thread.join() + @classmethod + def tearDownClass(cls): + cls.proxy_loop.call_soon_threadsafe(cls.proxy_stop.set_result, None) + cls.proxy_thread.join() + super().tearDownClass() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 4844d3b5e..38cfab7b3 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,4 +1,3 @@ -import contextlib import http import logging import socket @@ -20,7 +19,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..proxy import sync_proxy +from ..proxy import ProxyMixin from ..utils import ( CLIENT_CONTEXT, MS, @@ -157,7 +156,7 @@ def remove_accept_header(self, request, response): def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - # Replace the WebSocket server with a TCP server that does't respond. + # Replace the WebSocket server with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with self.assertRaises(TimeoutError) as raised: @@ -283,7 +282,7 @@ def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate isn't trusted system-wide. + # The test certificate is self-signed. with connect(get_uri(server)): self.fail("did not raise") self.assertIn( @@ -307,127 +306,105 @@ def test_reject_invalid_server_hostname(self): @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") -class ProxyClientTests(unittest.TestCase): - @contextlib.contextmanager - def socks_proxy(self, auth=None): - if auth: - proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:51080" - else: - proxyauth = None - proxy_uri = "http://localhost:51080" - - with sync_proxy( - mode=["socks5@51080"], - proxyauth=proxyauth, - ) as record_flows: - with patch_environ({"socks_proxy": proxy_uri}): - yield record_flows +class SocksProxyClientTests(ProxyMixin, unittest.TestCase): + proxy_mode = "socks5@51080" def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): with run_server() as server: with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + self.assertNumFlows(1) def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): with run_server(ssl=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + self.assertNumFlows(1) def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" - with self.socks_proxy(auth=True) as proxy: - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"socks_proxy": "http://hello:iloveyou@localhost:51080"} + ): + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) - def test_socks_proxy_connection_error(self): - """Client receives an error when connecting to the SOCKS5 proxy.""" + def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError - with self.socks_proxy(auth=True) as proxy: - with self.assertRaises(ProxyError) as raised: - with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # remove credentials - ): - self.fail("did not raise") + try: + self.proxy_options.update(proxyauth="any") + with patch_environ({"socks_proxy": "http://localhost:51080"}): + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "failed to connect to SOCKS proxy", ) self.assertIsInstance(raised.exception.__cause__, SocksProxyError) - self.assertEqual(len(proxy.get_flows()), 0) + self.assertNumFlows(0) - def test_socks_proxy_connection_fails(self): + def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with self.assertRaises(OSError) as raised: - with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # nothing at this address - ): - self.fail("did not raise") + with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) + self.assertNumFlows(0) - def test_socks_proxy_timeout(self): - """Client times out before connecting to the SOCKS5 proxy.""" + def test_socks_proxy_connection_timeout(self): + """Client times out while connecting to the SOCKS5 proxy.""" from python_socks import ProxyTimeoutError as SocksProxyTimeoutError - # Replace the proxy with a TCP server that does't respond. + # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with self.assertRaises(TimeoutError) as raised: - with connect( - "ws://example.com/", - proxy=f"socks5h://{host}:{port}/", - open_timeout=MS, - ): - self.fail("did not raise") + with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyTimeoutError) + self.assertNumFlows(0) - def test_explicit_proxy(self): - """Client connects to server through a proxy set explicitly.""" - with sync_proxy(mode=["socks5@51080"]) as proxy: - with run_server() as server: - with connect( - get_uri(server), - # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:51080", - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + def test_explicit_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy set explicitly.""" + with run_server() as server: + with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:51080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with self.socks_proxy() as proxy: + with patch_environ({"ws_proxy": "http://localhost:58080"}): with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: # Use a non-existing domain to ensure we connect to sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 0) - - def test_unsupported_proxy(self): - """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:51080"}): - with self.assertRaises(InvalidProxy) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", - ) + self.assertNumFlows(0) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -447,10 +424,7 @@ def test_set_host_header(self): with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") - -@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") -class SecureUnixClientTests(unittest.TestCase): - def test_connection(self): + def test_secure_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): @@ -488,6 +462,16 @@ def test_unix_without_path_or_sock(self): "missing path argument", ) + def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + with connect("ws://example.com/", proxy="other://localhost:58080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:58080 isn't a valid proxy: scheme other isn't supported", + ) + def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) From 0eab9b82959cdcea658ff5ab943f03503caa082b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 31 Jan 2025 21:54:06 +0100 Subject: [PATCH 1500/1539] Simplify SOCKS proxy implementation. --- src/websockets/asyncio/client.py | 37 ++++++++++++++++---------------- src/websockets/sync/client.py | 35 +++++++++++------------------- 2 files changed, 31 insertions(+), 41 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 1e560fe0c..a3fcab039 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -383,16 +383,28 @@ def factory() -> ClientConnection: if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) elif proxy is not None: - kwargs["sock"] = await connect_proxy( - parse_proxy(proxy), - ws_uri, - local_addr=kwargs.pop("local_addr", None), - ) - _, connection = await loop.create_connection(factory, **kwargs) + proxy_parsed = parse_proxy(proxy) + if proxy_parsed.scheme[:5] == "socks": + # Connect to the server through the proxy. + sock = await connect_socks_proxy( + proxy_parsed, + ws_uri, + local_addr=kwargs.pop("local_addr", None), + ) + # Initialize WebSocket connection via the proxy. + _, connection = await loop.create_connection( + factory, + sock=sock, + **kwargs, + ) + else: + raise AssertionError("unsupported proxy") else: + # Connect to the server directly. if kwargs.get("sock") is None: kwargs.setdefault("host", ws_uri.host) kwargs.setdefault("port", ws_uri.port) + # Initialize WebSocket connection. _, connection = await loop.create_connection(factory, **kwargs) return connection @@ -643,16 +655,3 @@ async def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") - - -async def connect_proxy( - proxy: Proxy, - ws_uri: WebSocketURI, - **kwargs: Any, -) -> socket.socket: - """Connect via a proxy and return the socket.""" - # parse_proxy() validates proxy.scheme. - if proxy.scheme[:5] == "socks": - return await connect_socks_proxy(proxy, ws_uri, **kwargs) - else: - raise AssertionError("unsupported proxy") diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index b7ab83664..722def319 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -284,14 +284,19 @@ def connect( assert path is not None # mypy cannot figure this out sock.connect(path) elif proxy is not None: - sock = connect_proxy( - parse_proxy(proxy), - ws_uri, - deadline, - # websockets is consistent with the socket module while - # python_socks is consistent across implementations. - local_addr=kwargs.pop("source_address", None), - ) + proxy_parsed = parse_proxy(proxy) + if proxy_parsed.scheme[:5] == "socks": + # Connect to the server through the proxy. + sock = connect_socks_proxy( + proxy_parsed, + ws_uri, + deadline, + # websockets is consistent with the socket module while + # python_socks is consistent across implementations. + local_addr=kwargs.pop("source_address", None), + ) + else: + raise AssertionError("unsupported proxy") else: kwargs.setdefault("timeout", deadline.timeout()) sock = socket.create_connection( @@ -439,17 +444,3 @@ def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") - - -def connect_proxy( - proxy: Proxy, - ws_uri: WebSocketURI, - deadline: Deadline, - **kwargs: Any, -) -> socket.socket: - """Connect via a proxy and return the socket.""" - # parse_proxy() validates proxy.scheme. - if proxy.scheme[:5] == "socks": - return connect_socks_proxy(proxy, ws_uri, deadline, **kwargs) - else: - raise AssertionError("unsupported proxy") From 3def92ead354c8bd9dc5e2901e82fbfa43902485 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 12:12:11 +0100 Subject: [PATCH 1501/1539] Don't intercept TLS connections in tests. This avoids mixing TLS termination by mitmproxy and by websockets. --- tests/proxy.py | 17 ++++++++--------- tox.ini | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/proxy.py b/tests/proxy.py index 8b62b0eb7..6cae7b49e 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -1,5 +1,4 @@ import asyncio -import pathlib import threading import warnings @@ -21,7 +20,7 @@ def __init__(self, on_running): self.running = on_running self.flows = [] - def websocket_start(self, flow): + def tcp_start(self, flow): self.flows.append(flow) def get_flows(self): @@ -48,7 +47,13 @@ async def run_proxy(cls): cls.proxy_loop = loop = asyncio.get_event_loop() cls.proxy_stop = stop = loop.create_future() - cls.proxy_options = options = Options(mode=[cls.proxy_mode]) + cls.proxy_options = options = Options( + mode=[cls.proxy_mode], + # Don't intercept connections, but record them. + ignore_hosts=["^localhost:", "^127.0.0.1:", "^::1:"], + # This option requires mitmproxy 11.0.0, which requires Python 3.11. + show_ignored_hosts=True, + ) cls.proxy_master = master = Master(options) master.addons.add( core.Core(), @@ -58,12 +63,6 @@ async def run_proxy(cls): tlsconfig.TlsConfig(), RecordFlows(on_running=cls.proxy_ready.set), ) - options.update( - # Use test certificate for TLS between client and proxy. - certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], - # Disable TLS verification between proxy and upstream. - ssl_insecure=True, - ) task = loop.create_task(cls.proxy_master.run()) await stop diff --git a/tox.ini b/tox.ini index f5a2f5d3c..918aeaaec 100644 --- a/tox.ini +++ b/tox.ini @@ -15,8 +15,8 @@ commands = pass_env = WEBSOCKETS_* deps = - mitmproxy - python-socks[asyncio] + py311,py312,py313,coverage,maxi_cov: mitmproxy + py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] [testenv:coverage] commands = From ba7104caa7540b29db7c02aaf9f07c95691e42f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 21:28:58 +0100 Subject: [PATCH 1502/1539] Clarify timeout errors. Specifically, not hiding the __cause__ of TimeoutError makes it visible when it happens while connecting to a proxy. --- src/websockets/asyncio/client.py | 4 ++-- src/websockets/sync/client.py | 2 +- src/websockets/sync/server.py | 2 +- tests/asyncio/test_client.py | 4 ++-- tests/sync/test_client.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index a3fcab039..058313388 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -511,9 +511,9 @@ async def __await_impl__(self) -> ClientConnection: else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") - except TimeoutError: + except TimeoutError as exc: # Re-raise exception with an informative error message. - raise TimeoutError("timed out during handshake") from None + raise TimeoutError("timed out during opening handshake") from exc # ... = yield from connect(...) - remove when dropping Python < 3.10 diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 722def319..8ce9e7d84 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -90,7 +90,7 @@ def handshake( self.protocol.send_request(self.request) if not self.response_rcvd.wait(timeout): - raise TimeoutError("timed out during handshake") + raise TimeoutError("timed out while waiting for handshake response") # self.protocol.handshake_exc is set when the connection is lost before # receiving a response, when the response cannot be parsed, or when the diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 2b753b2c5..10e3b6816 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -128,7 +128,7 @@ def handshake( """ if not self.request_rcvd.wait(timeout): - raise TimeoutError("timed out during handshake") + raise TimeoutError("timed out while waiting for handshake request") if self.request is not None: with self.send_context(expected_state=CONNECTING): diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index be8ef8a42..3728c7734 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -396,7 +396,7 @@ async def test_timeout_during_handshake(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out during opening handshake", ) async def test_connection_closed_during_handshake(self): @@ -641,7 +641,7 @@ async def test_socks_proxy_connection_timeout(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out during opening handshake", ) self.assertNumFlows(0) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 38cfab7b3..47886e015 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -164,7 +164,7 @@ def test_timeout_during_handshake(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out while waiting for handshake response", ) def test_connection_closed_during_handshake(self): From 6c15b9c8b039d9da610d0a1563a302cc54cf62b0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 13:35:19 +0100 Subject: [PATCH 1503/1539] Add option to always include port in build_host helper. --- src/websockets/headers.py | 10 +++++++-- tests/test_headers.py | 43 +++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/websockets/headers.py b/src/websockets/headers.py index e05948a1f..c42abd976 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -36,7 +36,13 @@ T = TypeVar("T") -def build_host(host: str, port: int, secure: bool) -> str: +def build_host( + host: str, + port: int, + secure: bool, + *, + always_include_port: bool = False, +) -> str: """ Build a ``Host`` header. @@ -53,7 +59,7 @@ def build_host(host: str, port: int, secure: bool) -> str: if address.version == 6: host = f"[{host}]" - if port != (443 if secure else 80): + if always_include_port or port != (443 if secure else 80): host = f"{host}:{port}" return host diff --git a/tests/test_headers.py b/tests/test_headers.py index 4ebd8b90c..816afc541 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -6,26 +6,33 @@ class HeadersTests(unittest.TestCase): def test_build_host(self): - for (host, port, secure), result in [ - (("localhost", 80, False), "localhost"), - (("localhost", 8000, False), "localhost:8000"), - (("localhost", 443, True), "localhost"), - (("localhost", 8443, True), "localhost:8443"), - (("example.com", 80, False), "example.com"), - (("example.com", 8000, False), "example.com:8000"), - (("example.com", 443, True), "example.com"), - (("example.com", 8443, True), "example.com:8443"), - (("127.0.0.1", 80, False), "127.0.0.1"), - (("127.0.0.1", 8000, False), "127.0.0.1:8000"), - (("127.0.0.1", 443, True), "127.0.0.1"), - (("127.0.0.1", 8443, True), "127.0.0.1:8443"), - (("::1", 80, False), "[::1]"), - (("::1", 8000, False), "[::1]:8000"), - (("::1", 443, True), "[::1]"), - (("::1", 8443, True), "[::1]:8443"), + for (host, port, secure), (result, result_with_port) in [ + (("localhost", 80, False), ("localhost", "localhost:80")), + (("localhost", 8000, False), ("localhost:8000", "localhost:8000")), + (("localhost", 443, True), ("localhost", "localhost:443")), + (("localhost", 8443, True), ("localhost:8443", "localhost:8443")), + (("example.com", 80, False), ("example.com", "example.com:80")), + (("example.com", 8000, False), ("example.com:8000", "example.com:8000")), + (("example.com", 443, True), ("example.com", "example.com:443")), + (("example.com", 8443, True), ("example.com:8443", "example.com:8443")), + (("127.0.0.1", 80, False), ("127.0.0.1", "127.0.0.1:80")), + (("127.0.0.1", 8000, False), ("127.0.0.1:8000", "127.0.0.1:8000")), + (("127.0.0.1", 443, True), ("127.0.0.1", "127.0.0.1:443")), + (("127.0.0.1", 8443, True), ("127.0.0.1:8443", "127.0.0.1:8443")), + (("::1", 80, False), ("[::1]", "[::1]:80")), + (("::1", 8000, False), ("[::1]:8000", "[::1]:8000")), + (("::1", 443, True), ("[::1]", "[::1]:443")), + (("::1", 8443, True), ("[::1]:8443", "[::1]:8443")), ]: with self.subTest(host=host, port=port, secure=secure): - self.assertEqual(build_host(host, port, secure), result) + self.assertEqual( + build_host(host, port, secure), + result, + ) + self.assertEqual( + build_host(host, port, secure, always_include_port=True), + result_with_port, + ) def test_parse_connection(self): for header, parsed in [ From e5e85d21a07995ac993c325c7b2c38441c2cbf83 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 31 Jan 2025 22:16:21 +0100 Subject: [PATCH 1504/1539] Add option not to read body in Response.parse. This allows keeping the connection open after reading the response to a CONNECT request. --- src/websockets/http11.py | 10 +++++++--- tests/test_http11.py | 8 +++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 49d7b9a41..530ac3d09 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -210,6 +210,7 @@ def parse( read_line: Callable[[int], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], read_to_eof: Callable[[int], Generator[None, None, bytes]], + include_body: bool = True, ) -> Generator[None, None, Response]: """ Parse a WebSocket handshake response. @@ -265,9 +266,12 @@ def parse( headers = yield from parse_headers(read_line) - body = yield from read_body( - status_code, headers, read_line, read_exact, read_to_eof - ) + if include_body: + body = yield from read_body( + status_code, headers, read_line, read_exact, read_to_eof + ) + else: + body = b"" return cls(status_code, reason, headers, body) diff --git a/tests/test_http11.py b/tests/test_http11.py index bb0d27b95..3afb6d02c 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -130,11 +130,12 @@ def setUp(self): super().setUp() self.reader = StreamReader() - def parse(self): + def parse(self, **kwargs): return Response.parse( self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, + **kwargs, ) def test_parse(self): @@ -322,6 +323,11 @@ def test_parse_body_not_modified(self): response = self.assertGeneratorReturns(self.parse()) self.assertEqual(response.body, b"") + def test_parse_without_body(self): + self.reader.feed_data(b"HTTP/1.1 200 Connection Established\r\n\r\n") + response = self.assertGeneratorReturns(self.parse(include_body=False)) + self.assertEqual(response.body, b"") + def test_serialize(self): # Example from the protocol overview in RFC 6455 response = Response( From ec706c922deeaf6fcf98174d4b2f16a6819117c7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 21:21:49 +0100 Subject: [PATCH 1505/1539] Add support for HTTP(S) proxies. Fix #364. --- docs/project/changelog.rst | 4 +- docs/reference/features.rst | 3 +- docs/topics/proxies.rst | 19 +++ src/websockets/asyncio/client.py | 154 ++++++++++++++++++++++- src/websockets/sync/client.py | 204 ++++++++++++++++++++++++++++++- tests/asyncio/test_client.py | 198 +++++++++++++++++++++++++++++- tests/proxy.py | 37 +++++- tests/sync/test_client.py | 184 ++++++++++++++++++++++++++++ tests/utils.py | 7 +- 9 files changed, 791 insertions(+), 19 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index bfbfa793f..7bb94b349 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,12 +35,12 @@ notice. Backwards-incompatible changes .............................. -.. admonition:: Client connections use SOCKS proxies automatically. +.. admonition:: Client connections use SOCKS and HTTP proxies automatically. :class: important If a proxy is configured in the operating system or with an environment variable, websockets uses it automatically when connecting to a server. - This feature requires installing the third-party library `python-socks`_. + SOCKS proxies require installing the third-party library `python-socks`_. If you want to disable the proxy, add ``proxy=None`` when calling :func:`~asyncio.client.connect`. See :doc:`../topics/proxies` for details. diff --git a/docs/reference/features.rst b/docs/reference/features.rst index eaecd02a9..93b083d20 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -166,12 +166,11 @@ Client | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | | (`#784`_) | | | | | +------------------------------------+--------+--------+--------+--------+ - | Connect via HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ -.. _#364: https://github.com/python-websockets/websockets/issues/364 .. _#784: https://github.com/python-websockets/websockets/issues/784 Known limitations diff --git a/docs/topics/proxies.rst b/docs/topics/proxies.rst index fd3ae78b6..14fc68c0c 100644 --- a/docs/topics/proxies.rst +++ b/docs/topics/proxies.rst @@ -30,6 +30,9 @@ most common, for `historical reasons`_, and recommended. .. _historical reasons: https://unix.stackexchange.com/questions/212894/ +websockets authenticates automatically when the address of the proxy includes +credentials e.g. ``http://user:password@proxy:8080/``. + .. admonition:: Any environment variable can configure a SOCKS proxy or an HTTP proxy. :class: tip @@ -64,3 +67,19 @@ SOCKS proxy is configured in the operating system, python-socks uses SOCKS5h. python-socks supports username/password authentication for SOCKS5 (:rfc:`1929`) but does not support other authentication methods such as GSSAPI (:rfc:`1961`). + +HTTP proxies +------------ + +When the address of the proxy starts with ``https://``, websockets secures the +connection to the proxy with TLS. + +When the address of the server starts with ``wss://``, websockets secures the +connection from the proxy to the server with TLS. + +These two options are compatible. TLS-in-TLS is supported. + +The documentation of :func:`~asyncio.client.connect` describes how to configure +TLS from websockets to the proxy and from the proxy to the server. + +websockets supports proxy authentication with Basic Auth. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 058313388..c19a53f8c 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -4,20 +4,29 @@ import logging import os import socket +import ssl as ssl_module import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, cast from ..client import ClientProtocol, backoff -from ..datastructures import HeadersLike -from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError +from ..datastructures import Headers, HeadersLike +from ..exceptions import ( + InvalidMessage, + InvalidProxyMessage, + InvalidProxyStatus, + InvalidStatus, + ProxyError, + SecurityError, +) from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import validate_subprotocols +from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .compatibility import TimeoutError, asyncio_timeout @@ -266,6 +275,16 @@ class connect: :meth:`~asyncio.loop.create_connection` method) to create a suitable client socket and customize it. + When using a proxy: + + * Prefix keyword arguments with ``proxy_`` for configuring TLS between the + client and an HTTPS proxy: ``proxy_ssl``, ``proxy_server_hostname``, + ``proxy_ssl_handshake_timeout``, and ``proxy_ssl_shutdown_timeout``. + * Use the standard keyword arguments for configuring TLS between the proxy + and the WebSocket server: ``ssl``, ``server_hostname``, + ``ssl_handshake_timeout``, and ``ssl_shutdown_timeout``. + * Other keyword arguments are used only for connecting to the proxy. + Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. InvalidProxy: If ``proxy`` isn't a valid proxy. @@ -397,6 +416,47 @@ def factory() -> ClientConnection: sock=sock, **kwargs, ) + elif proxy_parsed.scheme[:4] == "http": + # Split keyword arguments between the proxy and the server. + all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {} + for key, value in all_kwargs.items(): + if key.startswith("ssl") or key == "server_hostname": + kwargs[key] = value + elif key.startswith("proxy_"): + proxy_kwargs[key[6:]] = value + else: + proxy_kwargs[key] = value + # Validate the proxy_ssl argument. + if proxy_parsed.scheme == "https": + proxy_kwargs.setdefault("ssl", True) + if proxy_kwargs.get("ssl") is None: + raise ValueError( + "proxy_ssl=None is incompatible with an https:// proxy" + ) + else: + if proxy_kwargs.get("ssl") is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + transport = await connect_http_proxy( + proxy_parsed, + ws_uri, + **proxy_kwargs, + ) + # Initialize WebSocket connection via the proxy. + connection = factory() + transport.set_protocol(connection) + ssl = kwargs.pop("ssl", None) + if ssl is True: + ssl = ssl_module.create_default_context() + if ssl is not None: + new_transport = await loop.start_tls( + transport, connection, ssl, **kwargs + ) + assert new_transport is not None # help mypy + transport = new_transport + connection.connection_made(transport) else: raise AssertionError("unsupported proxy") else: @@ -655,3 +715,89 @@ async def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") + + +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() + + +class HTTPProxyConnection(asyncio.Protocol): + def __init__(self, ws_uri: WebSocketURI, proxy: Proxy): + self.ws_uri = ws_uri + self.proxy = proxy + + self.reader = StreamReader() + self.parser = Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + include_body=False, + ) + + loop = asyncio.get_running_loop() + self.response: asyncio.Future[Response] = loop.create_future() + + def run_parser(self) -> None: + try: + next(self.parser) + except StopIteration as exc: + response = exc.value + if 200 <= response.status_code < 300: + self.response.set_result(response) + else: + self.response.set_exception(InvalidProxyStatus(response)) + except Exception as exc: + proxy_exc = InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) + proxy_exc.__cause__ = exc + self.response.set_exception(proxy_exc) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.transport.write(prepare_connect_request(self.proxy, self.ws_uri)) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + self.run_parser() + + def eof_received(self) -> None: + self.reader.feed_eof() + self.run_parser() + + def connection_lost(self, exc: Exception | None) -> None: + self.reader.feed_eof() + if exc is not None: + self.response.set_exception(exc) + + +async def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, +) -> asyncio.Transport: + transport, protocol = await asyncio.get_running_loop().create_connection( + lambda: HTTPProxyConnection(ws_uri, proxy), + proxy.host, + proxy.port, + **kwargs, + ) + + try: + # This raises exceptions if the connection to the proxy fails. + await protocol.response + except Exception: + transport.close() + raise + + return transport diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 8ce9e7d84..c0fe6901a 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -5,16 +5,17 @@ import threading import warnings from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Callable, Literal, TypeVar, cast from ..client import ClientProtocol -from ..datastructures import HeadersLike -from ..exceptions import ProxyError +from ..datastructures import Headers, HeadersLike +from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import validate_subprotocols +from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .connection import Connection @@ -141,6 +142,8 @@ def connect( additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, proxy: str | Literal[True] | None = True, + proxy_ssl: ssl_module.SSLContext | None = None, + proxy_server_hostname: str | None = None, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -195,6 +198,9 @@ def connect( to :obj:`None` to disable the proxy or to the address of a proxy to override the system configuration. See the :doc:`proxy docs <../../topics/proxies>` for details. + proxy_ssl: Configuration for enabling TLS on the proxy connection. + proxy_server_hostname: Host name for the TLS handshake with the proxy. + ``proxy_server_hostname`` overrides the host name from ``proxy``. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -295,6 +301,21 @@ def connect( # python_socks is consistent across implementations. local_addr=kwargs.pop("source_address", None), ) + elif proxy_parsed.scheme[:4] == "http": + # Validate the proxy_ssl argument. + if proxy_parsed.scheme != "https" and proxy_ssl is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + sock = connect_http_proxy( + proxy_parsed, + ws_uri, + deadline, + ssl=proxy_ssl, + server_hostname=proxy_server_hostname, + **kwargs, + ) else: raise AssertionError("unsupported proxy") else: @@ -318,7 +339,12 @@ def connect( if server_hostname is None: server_hostname = ws_uri.host sock.settimeout(deadline.timeout()) - sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + if proxy_ssl is None: + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + else: + sock_2 = SSLSSLSocket(sock, ssl, server_hostname=server_hostname) + # Let's pretend that sock is a socket, even though it isn't. + sock = cast(socket.socket, sock_2) sock.settimeout(None) # Initialize WebSocket protocol @@ -444,3 +470,171 @@ def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") + + +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() + + +def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response: + reader = StreamReader() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + include_body=False, + ) + try: + while True: + sock.settimeout(deadline.timeout()) + data = sock.recv(4096) + if data: + reader.feed_data(data) + else: + reader.feed_eof() + next(parser) + except StopIteration as exc: + assert isinstance(exc.value, Response) # help mypy + response = exc.value + if 200 <= response.status_code < 300: + return response + else: + raise InvalidProxyStatus(response) + except socket.timeout: + raise TimeoutError("timed out while connecting to HTTP proxy") + except Exception as exc: + raise InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) from exc + finally: + sock.settimeout(None) + + +def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + *, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + **kwargs: Any, +) -> socket.socket: + # Connect socket + + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection((proxy.host, proxy.port), **kwargs) + + # Initialize TLS wrapper and perform TLS handshake + + if proxy.scheme == "https": + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = proxy.host + sock.settimeout(deadline.timeout()) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + sock.settimeout(None) + + # Send CONNECT request to the proxy and read response. + + sock.sendall(prepare_connect_request(proxy, ws_uri)) + try: + read_connect_response(sock, deadline) + except Exception: + sock.close() + raise + + return sock + + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., T]) + + +class SSLSSLSocket: + """ + Socket-like object providing TLS-in-TLS. + + Only methods that are used by websockets are implemented. + + """ + + recv_bufsize = 65536 + + def __init__( + self, + sock: socket.socket, + ssl_context: ssl_module.SSLContext, + server_hostname: str | None = None, + ) -> None: + self.incoming = ssl_module.MemoryBIO() + self.outgoing = ssl_module.MemoryBIO() + self.ssl_socket = sock + self.ssl_object = ssl_context.wrap_bio( + self.incoming, + self.outgoing, + server_hostname=server_hostname, + ) + self.run_io(self.ssl_object.do_handshake) + + def run_io(self, func: Callable[..., T], *args: Any) -> T: + while True: + want_read = False + want_write = False + try: + result = func(*args) + except ssl_module.SSLWantReadError: + want_read = True + except ssl_module.SSLWantWriteError: # pragma: no cover + want_write = True + + # Write outgoing data in all cases. + data = self.outgoing.read() + if data: + self.ssl_socket.sendall(data) + + # Read incoming data and retry on SSLWantReadError. + if want_read: + data = self.ssl_socket.recv(self.recv_bufsize) + if data: + self.incoming.write(data) + else: + self.incoming.write_eof() + continue + # Retry after writing outgoing data on SSLWantWriteError. + if want_write: # pragma: no cover + continue + # Return result if no error happened. + return result + + def recv(self, buflen: int) -> bytes: + try: + return self.run_io(self.ssl_object.read, buflen) + except ssl_module.SSLEOFError: + return b"" # always ignore ragged EOFs + + def send(self, data: bytes) -> int: + return self.run_io(self.ssl_object.write, data) + + def sendall(self, data: bytes) -> None: + # adapted from ssl_module.SSLSocket.sendall() + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + count += self.send(byte_view[count:]) + + # recv_into(), recvfrom(), recvfrom_into(), sendto(), unwrap(), and the + # flags argument aren't implemented because websockets doesn't need them. + + def __getattr__(self, name: str) -> Any: + return getattr(self.ssl_socket, name) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 3728c7734..c6ff26ae4 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -15,6 +15,7 @@ InvalidHandshake, InvalidMessage, InvalidProxy, + InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, @@ -667,6 +668,181 @@ async def test_ignore_proxy_with_existing_socket(self): self.assertNumFlows(0) +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "regular@58080" + + async def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + async def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + async def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"https_proxy": "http://hello:iloveyou@localhost:58080"} + ): + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + async def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + async def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + async def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + async def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + async def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + async def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args) as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + async def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + async def test_https_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args) as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + self.assertNumFlows(1) + + async def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception), + ) + + async def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + self.assertNumFlows(1) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): @@ -737,7 +913,7 @@ async def test_ssl_without_secure_uri(self): ) async def test_secure_uri_without_ssl(self): - """Client rejects no ssl when URI is secure.""" + """Client rejects ssl=None when URI is secure.""" with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( @@ -745,6 +921,26 @@ async def test_secure_uri_without_ssl(self): "ssl=None is incompatible with a wss:// URI", ) + async def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with patch_environ({"https_proxy": "http://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", proxy_ssl=True) + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + + async def test_https_proxy_without_ssl(self): + """Client rejects proxy_ssl=None when proxy is HTTPS.""" + with patch_environ({"https_proxy": "https://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", proxy_ssl=None) + self.assertEqual( + str(raised.exception), + "proxy_ssl=None is incompatible with an https:// proxy", + ) + async def test_unsupported_proxy(self): """Client rejects unsupported proxy.""" with self.assertRaises(InvalidProxy) as raised: diff --git a/tests/proxy.py b/tests/proxy.py index 6cae7b49e..9746e3382 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -1,4 +1,6 @@ import asyncio +import pathlib +import ssl import threading import warnings @@ -8,9 +10,11 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") + from mitmproxy import ctx from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig + from mitmproxy.http import Response from mitmproxy.master import Master - from mitmproxy.options import Options + from mitmproxy.options import CONF_BASENAME, CONF_DIR, Options except ImportError: pass @@ -31,6 +35,31 @@ def reset_flows(self): self.flows = [] +class AlterRequest: + def load(self, loader): + loader.add_option( + name="break_http_connect", + typespec=bool, + default=False, + help="Respond to HTTP CONNECT requests with a 999 status code.", + ) + loader.add_option( + name="close_http_connect", + typespec=bool, + default=False, + help="Do not respond to HTTP CONNECT requests.", + ) + + def http_connect(self, flow): + if ctx.options.break_http_connect: + # mitmproxy can send a response with a status code not between 100 + # and 599, while websockets treats it as a protocol error. + # This is used for testing HTTP parsing errors. + flow.response = Response.make(999, "not a valid HTTP response") + if ctx.options.close_http_connect: + flow.kill() + + class ProxyMixin: """ Run mitmproxy in a background thread. @@ -62,6 +91,7 @@ async def run_proxy(cls): next_layer.NextLayer(), tlsconfig.TlsConfig(), RecordFlows(on_running=cls.proxy_ready.set), + AlterRequest(), ) task = loop.create_task(cls.proxy_master.run()) @@ -86,6 +116,11 @@ def setUpClass(cls): cls.proxy_thread.start() cls.proxy_ready.wait() + certificate = pathlib.Path(CONF_DIR) / f"{CONF_BASENAME}-ca-cert.pem" + certificate = certificate.expanduser() + cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + cls.proxy_context.load_verify_locations(bytes(certificate)) + def assertNumFlows(self, num_flows): record_flows = self.proxy_master.addons.get("recordflows") self.assertEqual(len(record_flows.get_flows()), num_flows) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 47886e015..386caf56a 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -12,6 +12,7 @@ InvalidHandshake, InvalidMessage, InvalidProxy, + InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, @@ -407,6 +408,179 @@ def test_ignore_proxy_with_existing_socket(self): self.assertNumFlows(0) +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "regular@58080" + + def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + self.assertNumFlows(1) + + def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"https_proxy": "http://hello:iloveyou@localhost:58080"} + ): + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError): + with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out while connecting to HTTP proxy", + ) + + def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server() as server: + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + self.assertNumFlows(1) + + def test_https_proxy_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + self.assertNumFlows(1) + + def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception), + ) + self.assertNumFlows(0) + + def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when server certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + self.assertNumFlows(1) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.TestCase): def test_connection(self): @@ -453,6 +627,16 @@ def test_ssl_without_secure_uri(self): "ssl argument is incompatible with a ws:// URI", ) + def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with patch_environ({"https_proxy": "http://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + connect("ws://localhost/", proxy_ssl=True) + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: diff --git a/tests/utils.py b/tests/utils.py index f68a447b1..389381345 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,14 +20,13 @@ # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt -CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) +CERTIFICATE = pathlib.Path(__file__).with_name("test_localhost.pem") CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) - +CLIENT_CONTEXT.load_verify_locations(bytes(CERTIFICATE)) SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) -SERVER_CONTEXT.load_cert_chain(CERTIFICATE) +SERVER_CONTEXT.load_cert_chain(bytes(CERTIFICATE)) # Work around https://github.com/openssl/openssl/issues/7967 From 2b9a90a7edc305f5619b229ae5cbda755e173da9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 21:51:10 +0100 Subject: [PATCH 1506/1539] Replace patch_environ with unittest.mock.patch.dict. --- tests/asyncio/test_client.py | 228 +++++++++++++++++------------------ tests/sync/test_client.py | 206 +++++++++++++++---------------- tests/test_uri.py | 6 +- tests/utils.py | 16 --- 4 files changed, 219 insertions(+), 237 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index c6ff26ae4..c2a96f3ec 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -2,10 +2,12 @@ import contextlib import http import logging +import os import socket import ssl import sys import unittest +from unittest.mock import patch from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError @@ -24,13 +26,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from ..proxy import ProxyMixin -from ..utils import ( - CLIENT_CONTEXT, - MS, - SERVER_CONTEXT, - patch_environ, - temp_unix_socket_path, -) +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path from .server import args, get_host_port, get_uri, handler @@ -570,46 +566,44 @@ def redirect(connection, request): class SocksProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "socks5@51080" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) async def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"socks_proxy": "http://hello:iloveyou@localhost:51080"} - ): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_authenticated_socks_proxy_error(self): """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError try: self.proxy_options.update(proxyauth="any") - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with self.assertRaises(ProxyError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -619,14 +613,14 @@ async def test_authenticated_socks_proxy_error(self): self.assertIsInstance(raised.exception.__cause__, SocksProxyError) self.assertNumFlows(0) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port async def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) self.assertNumFlows(0) @@ -636,7 +630,7 @@ async def test_socks_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: async with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -657,14 +651,14 @@ async def test_explicit_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args) as server: - with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. - async with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -672,46 +666,44 @@ async def test_ignore_proxy_with_existing_socket(self): class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "regular@58080" + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy(self): """Client connects to server through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_secure_http_proxy(self): """Client connects to server securely through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.version()[:3], "TLS") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) async def test_authenticated_http_proxy(self): """Client connects to server through an authenticated HTTP proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"https_proxy": "http://hello:iloveyou@localhost:58080"} - ): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_authenticated_http_proxy_error(self): """Client fails to authenticate to the HTTP proxy.""" try: self.proxy_options.update(proxyauth="any") - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(ProxyError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -720,14 +712,14 @@ async def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" try: self.proxy_options.update(break_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(break_http_connect=False) self.assertEqual( @@ -736,14 +728,14 @@ async def test_http_proxy_protocol_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_connection_error(self): """Client receives no response when connecting to the HTTP proxy.""" try: self.proxy_options.update(close_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(close_http_connect=False) self.assertEqual( @@ -752,12 +744,12 @@ async def test_http_proxy_connection_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port async def test_http_proxy_connection_failure(self): """Client fails to connect to the HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError): - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertNumFlows(0) @@ -766,7 +758,7 @@ async def test_http_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: async with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -775,67 +767,67 @@ async def test_http_proxy_connection_timeout(self): "timed out during opening handshake", ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy(self): """Client connects to server through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args) as server: - async with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_secure_https_proxy(self): """Client connects to server securely through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect( - get_uri(server), - ssl=CLIENT_CONTEXT, - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.version()[:3], "TLS") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_server_hostname(self): """Client sets server_hostname to the value of proxy_server_hostname.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args) as server: - # Pass an argument not prefixed with proxy_ for coverage. - kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} - async with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - proxy_server_hostname="overridden", - **kwargs, - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with serve(*args) as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy_invalid_proxy_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The proxy certificate isn't trusted. - async with connect("wss://example.com/"): - self.fail("did not raise") + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") self.assertIn( "certificate verify failed: unable to get local issuer certificate", str(raised.exception), ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy_invalid_server_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate is self-signed. - async with connect(get_uri(server), proxy_ssl=self.proxy_context): - self.fail("did not raise") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), @@ -923,9 +915,12 @@ async def test_secure_uri_without_ssl(self): async def test_proxy_ssl_without_https_proxy(self): """Client rejects proxy_ssl when proxy isn't HTTPS.""" - with patch_environ({"https_proxy": "http://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - await connect("ws://localhost/", proxy_ssl=True) + with self.assertRaises(ValueError) as raised: + await connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ) self.assertEqual( str(raised.exception), "proxy_ssl argument is incompatible with an http:// proxy", @@ -933,9 +928,12 @@ async def test_proxy_ssl_without_https_proxy(self): async def test_https_proxy_without_ssl(self): """Client rejects proxy_ssl=None when proxy is HTTPS.""" - with patch_environ({"https_proxy": "https://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - await connect("ws://localhost/", proxy_ssl=None) + with self.assertRaises(ValueError) as raised: + await connect( + "ws://localhost/", + proxy="https://localhost:8080", + proxy_ssl=None, + ) self.assertEqual( str(raised.exception), "proxy_ssl=None is incompatible with an https:// proxy", diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 386caf56a..e4927bb32 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,5 +1,6 @@ import http import logging +import os import socket import socketserver import ssl @@ -7,6 +8,7 @@ import threading import time import unittest +from unittest.mock import patch from websockets.exceptions import ( InvalidHandshake, @@ -26,7 +28,6 @@ MS, SERVER_CONTEXT, DeprecationTestCase, - patch_environ, temp_unix_socket_path, ) from .server import get_uri, run_server, run_unix_server @@ -310,46 +311,44 @@ def test_reject_invalid_server_hostname(self): class SocksProxyClientTests(ProxyMixin, unittest.TestCase): proxy_mode = "socks5@51080" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"socks_proxy": "http://hello:iloveyou@localhost:51080"} - ): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_authenticated_socks_proxy_error(self): """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError try: self.proxy_options.update(proxyauth="any") - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with self.assertRaises(ProxyError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -359,14 +358,14 @@ def test_authenticated_socks_proxy_error(self): self.assertIsInstance(raised.exception.__cause__, SocksProxyError) self.assertNumFlows(0) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) self.assertNumFlows(0) @@ -378,7 +377,7 @@ def test_socks_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -397,14 +396,14 @@ def test_explicit_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"ws_proxy": "http://localhost:58080"}) def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with patch_environ({"ws_proxy": "http://localhost:58080"}): - with run_server() as server: - with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. - with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with socket.create_connection(server.socket.getsockname()) as sock: + # Use a non-existing domain to ensure we connect to sock. + with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -412,45 +411,43 @@ def test_ignore_proxy_with_existing_socket(self): class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "regular@58080" + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy(self): """Client connects to server through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_secure_http_proxy(self): """Client connects to server securely through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(client.socket.version()[:3], "TLS") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) def test_authenticated_http_proxy(self): """Client connects to server through an authenticated HTTP proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"https_proxy": "http://hello:iloveyou@localhost:58080"} - ): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_authenticated_http_proxy_error(self): """Client fails to authenticate to the HTTP proxy.""" try: self.proxy_options.update(proxyauth="any") - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(ProxyError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -459,14 +456,14 @@ def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" try: self.proxy_options.update(break_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(break_http_connect=False) self.assertEqual( @@ -475,14 +472,14 @@ def test_http_proxy_protocol_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_connection_error(self): """Client receives no response when connecting to the HTTP proxy.""" try: self.proxy_options.update(close_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(close_http_connect=False) self.assertEqual( @@ -491,12 +488,12 @@ def test_http_proxy_connection_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port def test_http_proxy_connection_failure(self): """Client fails to connect to the HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError): - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError): + with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertNumFlows(0) @@ -505,7 +502,7 @@ def test_http_proxy_connection_timeout(self): # Replace the proxy with a TCP server that does't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -514,66 +511,66 @@ def test_http_proxy_connection_timeout(self): "timed out while connecting to HTTP proxy", ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy(self): """Client connects to server through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server() as server: - with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_secure_https_proxy(self): """Client connects to server securely through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect( - get_uri(server), - ssl=CLIENT_CONTEXT, - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(client.socket.version()[:3], "TLS") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_server_hostname(self): """Client sets server_hostname to the value of proxy_server_hostname.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server() as server: - # Pass an argument not prefixed with proxy_ for coverage. - kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} - with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - proxy_server_hostname="overridden", - **kwargs, - ) as client: - self.assertEqual(client.socket.server_hostname, "overridden") + with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_invalid_proxy_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The proxy certificate isn't trusted. - with connect("wss://example.com/"): - self.fail("did not raise") + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + with connect("wss://example.com/"): + self.fail("did not raise") self.assertIn( "certificate verify failed: unable to get local issuer certificate", str(raised.exception), ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_invalid_server_certificate(self): """Client rejects certificate when server certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate is self-signed. - with connect(get_uri(server), proxy_ssl=self.proxy_context): - self.fail("did not raise") + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), @@ -629,9 +626,12 @@ def test_ssl_without_secure_uri(self): def test_proxy_ssl_without_https_proxy(self): """Client rejects proxy_ssl when proxy isn't HTTPS.""" - with patch_environ({"https_proxy": "http://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - connect("ws://localhost/", proxy_ssl=True) + with self.assertRaises(ValueError) as raised: + connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ) self.assertEqual( str(raised.exception), "proxy_ssl argument is incompatible with an http:// proxy", diff --git a/tests/test_uri.py b/tests/test_uri.py index 35b51fa58..3ccf21158 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -1,11 +1,11 @@ +import os import unittest +from unittest.mock import patch from websockets.exceptions import InvalidProxy, InvalidURI from websockets.uri import * from websockets.uri import Proxy, get_proxy, parse_proxy -from .utils import patch_environ - VALID_URIS = [ ( @@ -255,6 +255,6 @@ def test_parse_proxy_user_info(self): def test_get_proxy(self): for environ, uri, proxy in PROXY_ENVS: - with patch_environ(environ): + with patch.dict(os.environ, environ): with self.subTest(environ=environ, uri=uri): self.assertEqual(get_proxy(parse_uri(uri)), proxy) diff --git a/tests/utils.py b/tests/utils.py index 389381345..7932aae60 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -138,22 +138,6 @@ def assertNoLogs(self, logger=None, level=None): self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) -@contextlib.contextmanager -def patch_environ(environ): - backup = {} - for key, value in environ.items(): - backup[key] = os.environ.get(key) - os.environ[key] = value - try: - yield - finally: - for key, value in backup.items(): - if value is None: - del os.environ[key] - else: # pragma: no cover - os.environ[key] = value - - @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: From 1a90f1e9ebf6a9ea034d89d1e706bb8e3bee1f25 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 23:14:32 +0100 Subject: [PATCH 1507/1539] Set User-Agent header in CONNECT requests. --- src/websockets/asyncio/client.py | 37 ++++++++++++++++++++++--------- src/websockets/sync/client.py | 12 ++++++++-- tests/asyncio/test_client.py | 18 +++++++++++++++ tests/proxy.py | 38 ++++++++++++++++++++++---------- tests/sync/test_client.py | 18 +++++++++++++++ 5 files changed, 99 insertions(+), 24 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index c19a53f8c..38a56ddda 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -97,7 +97,7 @@ async def handshake( self.request = self.protocol.connect() if additional_headers is not None: self.request.headers.update(additional_headers) - if user_agent_header: + if user_agent_header is not None: self.request.headers.setdefault("User-Agent", user_agent_header) self.protocol.send_request(self.request) @@ -363,10 +363,8 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: self.proxy = proxy self.protocol_factory = protocol_factory - self.handshake_args = ( - additional_headers, - user_agent_header, - ) + self.additional_headers = additional_headers + self.user_agent_header = user_agent_header self.process_exception = process_exception self.open_timeout = open_timeout self.logger = logger @@ -442,6 +440,7 @@ def factory() -> ClientConnection: transport = await connect_http_proxy( proxy_parsed, ws_uri, + user_agent_header=self.user_agent_header, **proxy_kwargs, ) # Initialize WebSocket connection via the proxy. @@ -541,7 +540,10 @@ async def __await_impl__(self) -> ClientConnection: for _ in range(MAX_REDIRECTS): self.connection = await self.create_connection() try: - await self.connection.handshake(*self.handshake_args) + await self.connection.handshake( + self.additional_headers, + self.user_agent_header, + ) except asyncio.CancelledError: self.connection.transport.abort() raise @@ -717,10 +719,16 @@ async def connect_socks_proxy( raise ImportError("python-socks is required to use a SOCKS proxy") -def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: +def prepare_connect_request( + proxy: Proxy, + ws_uri: WebSocketURI, + user_agent_header: str | None = None, +) -> bytes: host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) headers = Headers() headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if user_agent_header is not None: + headers["User-Agent"] = user_agent_header if proxy.username is not None: assert proxy.password is not None # enforced by parse_proxy() headers["Proxy-Authorization"] = build_authorization_basic( @@ -731,9 +739,15 @@ def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: class HTTPProxyConnection(asyncio.Protocol): - def __init__(self, ws_uri: WebSocketURI, proxy: Proxy): + def __init__( + self, + ws_uri: WebSocketURI, + proxy: Proxy, + user_agent_header: str | None = None, + ): self.ws_uri = ws_uri self.proxy = proxy + self.user_agent_header = user_agent_header self.reader = StreamReader() self.parser = Response.parse( @@ -765,7 +779,9 @@ def run_parser(self) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: transport = cast(asyncio.Transport, transport) self.transport = transport - self.transport.write(prepare_connect_request(self.proxy, self.ws_uri)) + self.transport.write( + prepare_connect_request(self.proxy, self.ws_uri, self.user_agent_header) + ) def data_received(self, data: bytes) -> None: self.reader.feed_data(data) @@ -784,10 +800,11 @@ def connection_lost(self, exc: Exception | None) -> None: async def connect_http_proxy( proxy: Proxy, ws_uri: WebSocketURI, + user_agent_header: str | None = None, **kwargs: Any, ) -> asyncio.Transport: transport, protocol = await asyncio.get_running_loop().create_connection( - lambda: HTTPProxyConnection(ws_uri, proxy), + lambda: HTTPProxyConnection(ws_uri, proxy, user_agent_header), proxy.host, proxy.port, **kwargs, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index c0fe6901a..58cb84710 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -312,6 +312,7 @@ def connect( proxy_parsed, ws_uri, deadline, + user_agent_header=user_agent_header, ssl=proxy_ssl, server_hostname=proxy_server_hostname, **kwargs, @@ -472,10 +473,16 @@ def connect_socks_proxy( raise ImportError("python-socks is required to use a SOCKS proxy") -def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: +def prepare_connect_request( + proxy: Proxy, + ws_uri: WebSocketURI, + user_agent_header: str | None = None, +) -> bytes: host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) headers = Headers() headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if user_agent_header is not None: + headers["User-Agent"] = user_agent_header if proxy.username is not None: assert proxy.password is not None # enforced by parse_proxy() headers["Proxy-Authorization"] = build_authorization_basic( @@ -524,6 +531,7 @@ def connect_http_proxy( ws_uri: WebSocketURI, deadline: Deadline, *, + user_agent_header: str | None = None, ssl: ssl_module.SSLContext | None = None, server_hostname: str | None = None, **kwargs: Any, @@ -546,7 +554,7 @@ def connect_http_proxy( # Send CONNECT request to the proxy and read response. - sock.sendall(prepare_connect_request(proxy, ws_uri)) + sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header)) try: read_connect_response(sock, deadline) except Exception: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index c2a96f3ec..465ea2bdb 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -712,6 +712,24 @@ async def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.protocol.state.name, "OPEN") + [http_connect] = self.get_http_connects() + self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith") + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header=None) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + [http_connect] = self.get_http_connects() + self.assertNotIn(b"User-Agent", http_connect.request.headers) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" diff --git a/tests/proxy.py b/tests/proxy.py index 9746e3382..236c49337 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -22,17 +22,26 @@ class RecordFlows: def __init__(self, on_running): self.running = on_running - self.flows = [] + self.http_connects = [] + self.tcp_flows = [] + + def http_connect(self, flow): + self.http_connects.append(flow) def tcp_start(self, flow): - self.flows.append(flow) + self.tcp_flows.append(flow) + + def get_http_connects(self): + http_connects, self.http_connects[:] = self.http_connects[:], [] + return http_connects - def get_flows(self): - flows, self.flows[:] = self.flows[:], [] - return flows + def get_tcp_flows(self): + tcp_flows, self.tcp_flows[:] = self.tcp_flows[:], [] + return tcp_flows - def reset_flows(self): - self.flows = [] + def reset(self): + self.http_connects = [] + self.tcp_flows = [] class AlterRequest: @@ -121,13 +130,18 @@ def setUpClass(cls): cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) cls.proxy_context.load_verify_locations(bytes(certificate)) - def assertNumFlows(self, num_flows): - record_flows = self.proxy_master.addons.get("recordflows") - self.assertEqual(len(record_flows.get_flows()), num_flows) + def get_http_connects(self): + return self.proxy_master.addons.get("recordflows").get_http_connects() + + def get_tcp_flows(self): + return self.proxy_master.addons.get("recordflows").get_tcp_flows() + + def assertNumFlows(self, num_tcp_flows): + self.assertEqual(len(self.get_tcp_flows()), num_tcp_flows) def tearDown(self): - record_flows = self.proxy_master.addons.get("recordflows") - record_flows.reset_flows() + record_tcp_flows = self.proxy_master.addons.get("recordflows") + record_tcp_flows.reset() super().tearDown() @classmethod diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e4927bb32..415343911 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -456,6 +456,24 @@ def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_http_proxy_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + with run_server() as server: + with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.protocol.state.name, "OPEN") + [http_connect] = self.get_http_connects() + self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith") + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_http_proxy_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + with run_server() as server: + with connect(get_uri(server), user_agent_header=None) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + [http_connect] = self.get_http_connects() + self.assertNotIn(b"User-Agent", http_connect.request.headers) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" From d60255bfbb21853f764999dad462e06d9ac3ba72 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 23:27:01 +0100 Subject: [PATCH 1508/1539] Highlight potential backwards incompatibility. --- docs/project/changelog.rst | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7bb94b349..4eaa0a4d1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -47,11 +47,14 @@ Backwards-incompatible changes .. _python-socks: https://github.com/romis2012/python-socks -New features -............ +.. admonition:: Keepalive is enabled in the :mod:`threading` implementation. + :class: note + + The :mod:`threading` implementation now sends Ping frames at regular + intervals and closes the connection if it doesn't receive a matching Pong + frame just like the :mod:`asyncio` implementation. -* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the - :mod:`threading` implementation. + See :doc:`keepalive and latency <../topics/keepalive>` for details. Improvements ............ From 7a2f8f40af801de9e09b4411aa4af3cc5a86139f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 23:24:43 +0100 Subject: [PATCH 1509/1539] Start recv_events only after attributes are initialized. Else, a race condition could lead to accessing self.pong_waiters before it is defined. --- src/websockets/asyncio/connection.py | 16 ++++++++-------- src/websockets/sync/connection.py | 28 +++++++++++++++------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 79429923e..1b51e4791 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -101,14 +101,6 @@ def __init__( # Protect sending fragmented messages. self.fragmented_send_waiter: asyncio.Future[None] | None = None - # Exception raised while reading from the connection, to be chained to - # ConnectionClosed in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - - # Completed when the TCP connection is closed and the WebSocket - # connection state becomes CLOSED. - self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() - # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} @@ -128,6 +120,14 @@ def __init__( # Task that sends keepalive pings. None when ping_interval is None. self.keepalive_task: asyncio.Task[None] | None = None + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + # Adapted from asyncio.FlowControlMixin self.paused: bool = False self.drain_waiters: collections.deque[asyncio.Future[None]] = ( diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 0c517cc64..8b9e06257 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -101,19 +101,6 @@ def __init__( # Whether we are busy sending a fragmented message. self.send_in_progress = False - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - - # Receiving events from the socket. This thread is marked as daemon to - # allow creating a connection in a non-daemon thread and using it in a - # daemon thread. This mustn't prevent the interpreter from exiting. - self.recv_events_thread = threading.Thread( - target=self.recv_events, - daemon=True, - ) - self.recv_events_thread.start() - # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} @@ -133,6 +120,21 @@ def __init__( # Thread that sends keepalive pings. None when ping_interval is None. self.keepalive_thread: threading.Thread | None = None + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Receiving events from the socket. This thread is marked as daemon to + # allow creating a connection in a non-daemon thread and using it in a + # daemon thread. This mustn't prevent the interpreter from exiting. + self.recv_events_thread = threading.Thread( + target=self.recv_events, + daemon=True, + ) + + # Start recv_events only after all attributes are initialized. + self.recv_events_thread.start() + # Public attributes @property From 602d7195c09c3c568bfd0268311a7b07a07ffb3d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 4 Feb 2025 22:09:49 +0100 Subject: [PATCH 1510/1539] Standardized and improved admonition types. --- docs/faq/asyncio.rst | 2 +- docs/faq/client.rst | 2 +- docs/faq/server.rst | 2 +- docs/project/changelog.rst | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index 3bc381cfd..a1bb663b5 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -4,7 +4,7 @@ Using asyncio .. currentmodule:: websockets.asyncio.connection .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. - :class: hint + :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. diff --git a/docs/faq/client.rst b/docs/faq/client.rst index cc9856a8b..d3f627684 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -4,7 +4,7 @@ Client .. currentmodule:: websockets.asyncio.client .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. - :class: hint + :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. diff --git a/docs/faq/server.rst b/docs/faq/server.rst index ce7e1962d..e6a3abe85 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -4,7 +4,7 @@ Server .. currentmodule:: websockets.asyncio.server .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. - :class: hint + :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 4eaa0a4d1..1f02a6cde 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -48,7 +48,7 @@ Backwards-incompatible changes .. _python-socks: https://github.com/romis2012/python-socks .. admonition:: Keepalive is enabled in the :mod:`threading` implementation. - :class: note + :class: important The :mod:`threading` implementation now sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong @@ -135,7 +135,7 @@ Backwards-incompatible changes websockets 13.1 is the last version supporting Python 3.8. .. admonition:: The new :mod:`asyncio` implementation is now the default. - :class: danger + :class: attention The following aliases in the ``websockets`` package were switched to the new :mod:`asyncio` implementation:: @@ -749,7 +749,7 @@ Security fix ............ .. admonition:: websockets 9.1 fixes a security issue introduced in 8.0. - :class: important + :class: danger Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords (`CVE-2021-33880`_). @@ -1168,7 +1168,7 @@ Security fix ............ .. admonition:: websockets 5.0 fixes a security issue introduced in 4.0. - :class: important + :class: danger Version 4.0 was vulnerable to denial of service by memory exhaustion because it didn't enforce ``max_size`` when decompressing compressed From 4b9caad779c3c995845abb099c185e7d6009570f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Feb 2025 23:15:59 +0100 Subject: [PATCH 1511/1539] Add a router based on werkzeug.routing. Fix #311. --- docs/conf.py | 5 +- docs/faq/client.rst | 2 +- docs/faq/server.rst | 4 +- docs/project/changelog.rst | 6 + docs/reference/asyncio/server.rst | 19 ++- docs/reference/features.rst | 2 + docs/reference/sync/server.rst | 27 +++- docs/requirements.txt | 1 + src/websockets/__init__.py | 9 ++ src/websockets/asyncio/router.py | 196 +++++++++++++++++++++++++++++ src/websockets/asyncio/server.py | 4 +- src/websockets/sync/router.py | 190 ++++++++++++++++++++++++++++ src/websockets/sync/server.py | 4 +- tests/asyncio/server.py | 8 +- tests/asyncio/test_router.py | 198 ++++++++++++++++++++++++++++++ tests/sync/server.py | 44 +++++-- tests/sync/test_router.py | 174 ++++++++++++++++++++++++++ tests/test_exports.py | 2 + tox.ini | 2 + 19 files changed, 879 insertions(+), 18 deletions(-) create mode 100644 src/websockets/asyncio/router.py create mode 100644 src/websockets/sync/router.py create mode 100644 tests/asyncio/test_router.py create mode 100644 tests/sync/test_router.py diff --git a/docs/conf.py b/docs/conf.py index 2c621bf41..c6b9ac7d8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,7 +82,10 @@ assert PythonDomain.object_types["data"].roles == ("data", "obj") PythonDomain.object_types["data"].roles = ("data", "class", "obj") -intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None), +} spelling_show_suggestions = True diff --git a/docs/faq/client.rst b/docs/faq/client.rst index d3f627684..c39e588ca 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -81,7 +81,7 @@ The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- -Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: +Use :func:`connect` as an asynchronous iterator:: from websockets.asyncio.client import connect from websockets.exceptions import ConnectionClosed diff --git a/docs/faq/server.rst b/docs/faq/server.rst index e6a3abe85..d00dcafba 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -116,7 +116,7 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.asyncio.server.broadcast`:: +Then, call :func:`broadcast`:: from websockets.asyncio.server import broadcast @@ -219,6 +219,8 @@ You may route a connection to different handlers depending on the request path:: # No handler for this path; close the connection. return +For more complex routing, you may use :func:`~websockets.asyncio.router.route`. + You may also route the connection based on the first message received from the client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to authenticate the connection before routing it, this is usually more convenient. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1f02a6cde..d7db6167a 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -56,6 +56,12 @@ Backwards-incompatible changes See :doc:`keepalive and latency <../topics/keepalive>` for details. +New features +............ + +* Added :func:`~asyncio.router.route` and :func:`~asyncio.router.unix_route` to + dispatch connections to different handlers depending on the URL. + Improvements ............ diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 49bd6f072..8d8b700f3 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -12,6 +12,21 @@ Creating a server .. autofunction:: unix_serve :async: +Routing connections +------------------- + +.. automodule:: websockets.asyncio.router + +.. autofunction:: route + :async: + +.. autofunction:: unix_route + :async: + +.. autoclass:: Router + +.. currentmodule:: websockets.asyncio.server + Running a server ---------------- @@ -89,7 +104,7 @@ Using a connection Broadcast --------- -.. autofunction:: websockets.asyncio.server.broadcast +.. autofunction:: broadcast HTTP Basic Authentication ------------------------- @@ -97,4 +112,4 @@ HTTP Basic Authentication websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. -.. autofunction:: websockets.asyncio.server.basic_auth +.. autofunction:: basic_auth diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 93b083d20..0da966ccc 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -127,6 +127,8 @@ Server +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | +------------------------------------+--------+--------+--------+--------+ + | Dispatch connections to handlers | ✅ | ✅ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+ Client ------ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index c3d0e8f25..f6a45a659 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -10,6 +10,31 @@ Creating a server .. autofunction:: unix_serve +Routing connections +------------------- + +.. automodule:: websockets.sync.router + +.. autofunction:: route + +.. autofunction:: unix_route + +.. autoclass:: Router + +.. currentmodule:: websockets.sync.server + +Routing connections +------------------- + +.. autofunction:: route + :async: + +.. autofunction:: unix_route + :async: + +.. autoclass:: Server + + Running a server ---------------- @@ -78,4 +103,4 @@ HTTP Basic Authentication websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. -.. autofunction:: websockets.sync.server.basic_auth +.. autofunction:: basic_auth diff --git a/docs/requirements.txt b/docs/requirements.txt index bcd1d7114..77c87f4dc 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,3 +6,4 @@ sphinx-inline-tabs sphinxcontrib-spelling sphinxcontrib-trio sphinxext-opengraph +werkzeug diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 28a10910b..f90aff5b9 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -12,6 +12,10 @@ "connect", "unix_connect", "ClientConnection", + # .asyncio.router + "route", + "unix_route", + "Router", # .asyncio.server "basic_auth", "broadcast", @@ -79,6 +83,7 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if TYPE_CHECKING: from .asyncio.client import ClientConnection, connect, unix_connect + from .asyncio.router import Router, route, unix_route from .asyncio.server import ( Server, ServerConnection, @@ -138,6 +143,10 @@ "connect": ".asyncio.client", "unix_connect": ".asyncio.client", "ClientConnection": ".asyncio.client", + # .asyncio.router + "route": ".asyncio.router", + "unix_route": ".asyncio.router", + "Router": ".asyncio.router", # .asyncio.server "basic_auth": ".asyncio.server", "broadcast": ".asyncio.server", diff --git a/src/websockets/asyncio/router.py b/src/websockets/asyncio/router.py new file mode 100644 index 000000000..cd95022c1 --- /dev/null +++ b/src/websockets/asyncio/router.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import http +import ssl as ssl_module +import urllib.parse +from typing import Any, Awaitable, Callable, Literal + +from werkzeug.exceptions import NotFound +from werkzeug.routing import Map, RequestRedirect + +from ..http11 import Request, Response +from .server import Server, ServerConnection, serve + + +__all__ = ["route", "unix_route", "Router"] + + +class Router: + """WebSocket router supporting :func:`route`.""" + + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + else: + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request( + self, connection: ServerConnection, request: Request + ) -> Response | None: + """Route incoming request.""" + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler, connection.handler_kwargs = handler, kwargs + return None + + async def handler(self, connection: ServerConnection) -> None: + """Handle a connection.""" + return await connection.handler(connection, **connection.handler_kwargs) + + +def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, +) -> Awaitable[Server]: + """ + Create a WebSocket server dispatching connections to different handlers. + + This feature requires the third-party library `werkzeug`_:: + + $ pip install werkzeug + + .. _werkzeug: https://werkzeug.palletsprojects.com/ + + :func:`route` accepts the same arguments as + :func:`~websockets.sync.server.serve`, except as described below. + + The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns + to connection handlers. In addition to the connection, handlers receive + parameters captured in the URL as keyword arguments. + + Here's an example:: + + + from websockets.asyncio.router import route + from werkzeug.routing import Map, Rule + + async def channel_handler(websocket, channel_id): + ... + + url_map = Map([ + Rule("/channel/", endpoint=channel_handler), + ... + ]) + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with route(url_map, ...) as server: + await stop + + + Refer to the documentation of :mod:`werkzeug.routing` for details. + + If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, + when the server runs behind a reverse proxy that modifies the ``Host`` + header or terminates TLS, you need additional configuration: + + * Set ``server_name`` to the name of the server as seen by clients. When not + provided, websockets uses the value of the ``Host`` header. + + * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling + TLS. Under the hood, this bind the URL map with a ``url_scheme`` of + ``wss://`` instead of ``ws://``. + + There is no need to specify ``websocket=True`` in each rule. It is added + automatically. + + Args: + url_map: Mapping of URL patterns to connection handlers. + server_name: Name of the server as seen by clients. If :obj:`None`, + websockets uses the value of the ``Host`` header. + ssl: Configuration for enabling TLS on the connection. Set it to + :obj:`True` if a reverse proxy terminates TLS connections. + create_router: Factory for the :class:`Router` dispatching requests to + handlers. Set it to a wrapper or a subclass to customize routing. + + """ + url_scheme = "ws" if ssl is None else "wss" + if ssl is not True and ssl is not None: + kwargs["ssl"] = ssl + + if create_router is None: + create_router = Router + + router = create_router(url_map, server_name, url_scheme) + + _process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = kwargs.pop("process_request", None) + if _process_request is None: + process_request: Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] = router.route_request + else: + + async def process_request( + connection: ServerConnection, request: Request + ) -> Response | None: + response = _process_request(connection, request) + if isinstance(response, Awaitable): + response = await response + if response is not None: + return response + return router.route_request(connection, request) + + return serve(router.handler, *args, process_request=process_request, **kwargs) + + +def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, +) -> Awaitable[Server]: + """ + Create a WebSocket Unix server dispatching connections to different handlers. + + :func:`unix_route` combines the behaviors of :func:`route` and + :func:`~websockets.asyncio.server.unix_serve`. + + Args: + url_map: Mapping of URL patterns to connection handlers. + path: File system path to the Unix socket. + + """ + return route(url_map, unix=True, path=path, **kwargs) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 2e2b78782..ec7fc4383 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -9,7 +9,7 @@ import sys from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import Any, Callable, cast +from typing import Any, Callable, Mapping, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -87,6 +87,8 @@ def __init__( self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], Awaitable[None]] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() def respond(self, status: StatusLike, text: str) -> Response: """ diff --git a/src/websockets/sync/router.py b/src/websockets/sync/router.py new file mode 100644 index 000000000..33105bf32 --- /dev/null +++ b/src/websockets/sync/router.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import http +import ssl as ssl_module +import urllib.parse +from typing import Any, Callable, Literal + +from werkzeug.exceptions import NotFound +from werkzeug.routing import Map, RequestRedirect + +from ..http11 import Request, Response +from .server import Server, ServerConnection, serve + + +__all__ = ["route", "unix_route", "Router"] + + +class Router: + """WebSocket router supporting :func:`route`.""" + + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + else: + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request( + self, connection: ServerConnection, request: Request + ) -> Response | None: + """Route incoming request.""" + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler, connection.handler_kwargs = handler, kwargs + return None + + def handler(self, connection: ServerConnection) -> None: + """Handle a connection.""" + return connection.handler(connection, **connection.handler_kwargs) + + +def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, +) -> Server: + """ + Create a WebSocket server dispatching connections to different handlers. + + This feature requires the third-party library `werkzeug`_:: + + $ pip install werkzeug + + .. _werkzeug: https://werkzeug.palletsprojects.com/ + + :func:`route` accepts the same arguments as + :func:`~websockets.sync.server.serve`, except as described below. + + The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns + to connection handlers. In addition to the connection, handlers receive + parameters captured in the URL as keyword arguments. + + Here's an example:: + + + from websockets.sync.router import route + from werkzeug.routing import Map, Rule + + def channel_handler(websocket, channel_id): + ... + + url_map = Map([ + Rule("/channel/", endpoint=channel_handler), + ... + ]) + + with route(url_map, ...) as server: + server.serve_forever() + + Refer to the documentation of :mod:`werkzeug.routing` for details. + + If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, + when the server runs behind a reverse proxy that modifies the ``Host`` + header or terminates TLS, you need additional configuration: + + * Set ``server_name`` to the name of the server as seen by clients. When not + provided, websockets uses the value of the ``Host`` header. + + * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling + TLS. Under the hood, this bind the URL map with a ``url_scheme`` of + ``wss://`` instead of ``ws://``. + + There is no need to specify ``websocket=True`` in each rule. It is added + automatically. + + Args: + url_map: Mapping of URL patterns to connection handlers. + server_name: Name of the server as seen by clients. If :obj:`None`, + websockets uses the value of the ``Host`` header. + ssl: Configuration for enabling TLS on the connection. Set it to + :obj:`True` if a reverse proxy terminates TLS connections. + create_router: Factory for the :class:`Router` dispatching requests to + handlers. Set it to a wrapper or a subclass to customize routing. + + """ + url_scheme = "ws" if ssl is None else "wss" + if ssl is not True and ssl is not None: + kwargs["ssl"] = ssl + + if create_router is None: + create_router = Router + + router = create_router(url_map, server_name, url_scheme) + + _process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = kwargs.pop("process_request", None) + if _process_request is None: + process_request: Callable[ + [ServerConnection, Request], + Response | None, + ] = router.route_request + else: + + def process_request( + connection: ServerConnection, request: Request + ) -> Response | None: + response = _process_request(connection, request) + if response is not None: + return response + return router.route_request(connection, request) + + return serve(router.handler, *args, process_request=process_request, **kwargs) + + +def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, +) -> Server: + """ + Create a WebSocket Unix server dispatching connections to different handlers. + + :func:`unix_route` combines the behaviors of :func:`route` and + :func:`~websockets.sync.server.unix_serve`. + + Args: + url_map: Mapping of URL patterns to connection handlers. + path: File system path to the Unix socket. + + """ + return route(url_map, unix=True, path=path, **kwargs) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 10e3b6816..efb40a7f4 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -13,7 +13,7 @@ import warnings from collections.abc import Iterable, Sequence from types import TracebackType -from typing import Any, Callable, cast +from typing import Any, Callable, Mapping, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -82,6 +82,8 @@ def __init__( max_queue=max_queue, ) self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], None] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() def respond(self, status: StatusLike, text: str) -> Response: """ diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index acf6500c6..b142bcd7e 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -1,5 +1,6 @@ import asyncio import socket +import urllib.parse def get_host_port(server): @@ -9,15 +10,16 @@ def get_host_port(server): raise AssertionError("expected at least one IPv4 socket") -def get_uri(server): - secure = server.server._ssl_context is not None # hack +def get_uri(server, secure=None): + if secure is None: + secure = server.server._ssl_context is not None # hack protocol = "wss" if secure else "ws" host, port = get_host_port(server) return f"{protocol}://{host}:{port}" async def handler(ws): - path = ws.request.path + path = urllib.parse.urlparse(ws.request.path).path if path == "/": # The default path is an eval shell. async for expr in ws: diff --git a/tests/asyncio/test_router.py b/tests/asyncio/test_router.py new file mode 100644 index 000000000..1426cc9f3 --- /dev/null +++ b/tests/asyncio/test_router.py @@ -0,0 +1,198 @@ +import http +import socket +import sys +import unittest +from unittest.mock import patch + +from websockets.asyncio.client import connect, unix_connect +from websockets.asyncio.router import * +from websockets.exceptions import InvalidStatus + +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path +from .server import EvalShellMixin, get_uri, handler +from .utils import alist + + +try: + from werkzeug.routing import Map, Rule +except ImportError: + pass + + +async def echo(websocket, count): + message = await websocket.recv() + for _ in range(count): + await websocket.send(message) + + +@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") +class RouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + # This is a small realistic example of werkzeug's basic URL routing + # features: path matching, parameter extraction, and default values. + + async def test_router_matches_paths_and_extracts_parameters(self): + """Router matches paths and extracts parameters.""" + url_map = Map( + [ + Rule("/echo", defaults={"count": 1}, endpoint=echo), + Rule("/echo/", endpoint=echo), + ] + ) + async with route(url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/echo") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello"]) + + async with connect(get_uri(server) + "/echo/3") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) + + @property # avoids an import-time dependency on werkzeug + def url_map(self): + return Map( + [ + Rule("/", endpoint=handler), + Rule("/r", redirect_to="/"), + ] + ) + + async def test_route_with_query_string(self): + """Router ignores query strings when matching paths.""" + async with route(self.url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/?a=b") as client: + await self.assertEval(client, "ws.request.path", "/?a=b") + + async def test_redirect(self): + """Router redirects connections according to redirect_to.""" + async with route(self.url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/r") as client: + await self.assertEval(client, "ws.request.path", "/") + + async def test_secure_redirect(self): + """Router redirects connections to a wss:// URI when TLS is enabled.""" + async with route(self.url_map, "localhost", 0, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.request.path", "/") + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + async def test_force_secure_redirect(self): + """Router redirects ws:// connections to a wss:// URI when ssl=True.""" + async with route(self.url_map, "localhost", 0, ssl=True) as server: + redirect_uri = get_uri(server, secure=True) + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + redirect_uri + "/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + async def test_force_redirect_server_name(self): + """Router redirects connections to the host declared in server_name.""" + async with route(self.url_map, "localhost", 0, server_name="other") as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://other/", + ) + + async def test_not_found(self): + """Router rejects requests to unknown paths with an HTTP 404 error.""" + async with route(self.url_map, "localhost", 0) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/n"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 404", + ) + + async def test_process_request_function_returning_none(self): + """Router supports a process_request function returning None.""" + + def process_request(ws, request): + ws.process_request_ran = True + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + async with connect(get_uri(server) + "/") as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_coroutine_returning_none(self): + """Router supports a process_request coroutine returning None.""" + + async def process_request(ws, request): + ws.process_request_ran = True + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + async with connect(get_uri(server) + "/") as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_function_returning_response(self): + """Router supports a process_request function returning a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_coroutine_returning_response(self): + """Router supports a process_request coroutine returning a response.""" + + async def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_custom_router_factory(self): + """Router supports a custom router factory.""" + + class MyRouter(Router): + async def handler(self, connection): + connection.my_router_ran = True + return await super().handler(connection) + + async with route( + self.url_map, "localhost", 0, create_router=MyRouter + ) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.my_router_ran", "True") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_router_supports_unix_sockets(self): + """Router supports Unix sockets.""" + url_map = Map([Rule("/echo/", endpoint=echo)]) + with temp_unix_socket_path() as path: + async with unix_route(url_map, path): + async with unix_connect(path, "ws://localhost/echo/3") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) diff --git a/tests/sync/server.py b/tests/sync/server.py index fd7a03d82..cadaa267e 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,19 +1,22 @@ import contextlib import ssl import threading +import urllib.parse +from websockets.sync.router import * from websockets.sync.server import * -def get_uri(server): - secure = isinstance(server.socket, ssl.SSLSocket) # hack +def get_uri(server, secure=None): + if secure is None: + secure = isinstance(server.socket, ssl.SSLSocket) # hack protocol = "wss" if secure else "ws" host, port = server.socket.getsockname() return f"{protocol}://{host}:{port}" def handler(ws): - path = ws.request.path + path = urllib.parse.urlparse(ws.request.path).path if path == "/": # The default path is an eval shell. for expr in ws: @@ -34,8 +37,14 @@ def assertEval(self, client, expr, value): @contextlib.contextmanager -def run_server(handler=handler, host="localhost", port=0, **kwargs): - with serve(handler, host, port, **kwargs) as server: +def run_server_or_router( + serve_or_route, + handler_or_url_map, + host="localhost", + port=0, + **kwargs, +): + with serve_or_route(handler_or_url_map, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() @@ -63,9 +72,22 @@ def handler(sock, addr): handler_thread.join() +def run_server(handler=handler, **kwargs): + return run_server_or_router(serve, handler, **kwargs) + + +def run_router(url_map, **kwargs): + return run_server_or_router(route, url_map, **kwargs) + + @contextlib.contextmanager -def run_unix_server(path, handler=handler, **kwargs): - with unix_serve(handler, path, **kwargs) as server: +def run_unix_server_or_router( + path, + unix_serve_or_route, + handler_or_url_map, + **kwargs, +): + with unix_serve_or_route(handler_or_url_map, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: @@ -73,3 +95,11 @@ def run_unix_server(path, handler=handler, **kwargs): finally: server.shutdown() thread.join() + + +def run_unix_server(path, handler=handler, **kwargs): + return run_unix_server_or_router(path, unix_serve, handler, **kwargs) + + +def run_unix_router(path, url_map, **kwargs): + return run_unix_server_or_router(path, unix_route, url_map, **kwargs) diff --git a/tests/sync/test_router.py b/tests/sync/test_router.py new file mode 100644 index 000000000..07274e625 --- /dev/null +++ b/tests/sync/test_router.py @@ -0,0 +1,174 @@ +import http +import socket +import sys +import unittest +from unittest.mock import patch + +from websockets.exceptions import InvalidStatus +from websockets.sync.client import connect, unix_connect +from websockets.sync.router import * + +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path +from .server import EvalShellMixin, get_uri, handler, run_router, run_unix_router + + +try: + from werkzeug.routing import Map, Rule +except ImportError: + pass + + +def echo(websocket, count): + message = websocket.recv() + for _ in range(count): + websocket.send(message) + + +@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") +class RouterTests(EvalShellMixin, unittest.TestCase): + # This is a small realistic example of werkzeug's basic URL routing + # features: path matching, parameter extraction, and default values. + + def test_router_matches_paths_and_extracts_parameters(self): + """Router matches paths and extracts parameters.""" + url_map = Map( + [ + Rule("/echo", defaults={"count": 1}, endpoint=echo), + Rule("/echo/", endpoint=echo), + ] + ) + with run_router(url_map) as server: + with connect(get_uri(server) + "/echo") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello"]) + + with connect(get_uri(server) + "/echo/3") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) + + @property # avoids an import-time dependency on werkzeug + def url_map(self): + return Map( + [ + Rule("/", endpoint=handler), + Rule("/r", redirect_to="/"), + ] + ) + + def test_route_with_query_string(self): + """Router ignores query strings when matching paths.""" + with run_router(self.url_map) as server: + with connect(get_uri(server) + "/?a=b") as client: + self.assertEval(client, "ws.request.path", "/?a=b") + + def test_redirect(self): + """Router redirects connections according to redirect_to.""" + with run_router(self.url_map, server_name="localhost") as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://localhost/", + ) + + def test_secure_redirect(self): + """Router redirects connections to a wss:// URI when TLS is enabled.""" + with run_router( + self.url_map, server_name="localhost", ssl=SERVER_CONTEXT + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "wss://localhost/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + def test_force_secure_redirect(self): + """Router redirects ws:// connections to a wss:// URI when ssl=True.""" + with run_router(self.url_map, ssl=True) as server: + redirect_uri = get_uri(server, secure=True) + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + redirect_uri + "/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + def test_force_redirect_server_name(self): + """Router redirects connections to the host declared in server_name.""" + with run_router(self.url_map, server_name="other") as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://other/", + ) + + def test_not_found(self): + """Router rejects requests to unknown paths with an HTTP 404 error.""" + with run_router(self.url_map) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/n"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 404", + ) + + def test_process_request_returning_none(self): + """Router supports a process_request returning None.""" + + def process_request(ws, request): + ws.process_request_ran = True + + with run_router(self.url_map, process_request=process_request) as server: + with connect(get_uri(server) + "/") as client: + self.assertEval(client, "ws.process_request_ran", "True") + + def test_process_request_returning_response(self): + """Router supports a process_request returning a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + with run_router(self.url_map, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + def test_custom_router_factory(self): + """Router supports a custom router factory.""" + + class MyRouter(Router): + def handler(self, connection): + connection.my_router_ran = True + return super().handler(connection) + + with run_router(self.url_map, create_router=MyRouter) as server: + with connect(get_uri(server)) as client: + self.assertEval(client, "ws.my_router_ran", "True") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + def test_router_supports_unix_sockets(self): + """Router supports Unix sockets.""" + url_map = Map([Rule("/echo/", endpoint=echo)]) + with temp_unix_socket_path() as path: + with run_unix_router(path, url_map): + with unix_connect(path, "ws://localhost/echo/3") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) diff --git a/tests/test_exports.py b/tests/test_exports.py index 88e27e69d..34a470661 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -2,6 +2,7 @@ import websockets import websockets.asyncio.client +import websockets.asyncio.router import websockets.asyncio.server import websockets.client import websockets.datastructures @@ -16,6 +17,7 @@ for name in ( [] + websockets.asyncio.client.__all__ + + websockets.asyncio.router.__all__ + websockets.asyncio.server.__all__ + websockets.client.__all__ + websockets.datastructures.__all__ diff --git a/tox.ini b/tox.ini index 918aeaaec..9450e9714 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ pass_env = deps = py311,py312,py313,coverage,maxi_cov: mitmproxy py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] + werkzeug [testenv:coverage] commands = @@ -47,3 +48,4 @@ commands = deps = mypy python-socks + werkzeug From 983e484cb77a1f0c490aaab6a092d26c38fd4f61 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 4 Feb 2025 22:38:42 +0100 Subject: [PATCH 1512/1539] Add example of routing. --- example/routing.py | 154 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 example/routing.py diff --git a/example/routing.py b/example/routing.py new file mode 100644 index 000000000..9f2df4980 --- /dev/null +++ b/example/routing.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import time +import zoneinfo + +from websockets.asyncio.router import route +from websockets.exceptions import ConnectionClosed +from werkzeug.routing import BaseConverter, Map, Rule, ValidationError + + +async def clock(websocket, tzinfo): + """Send the current time in the given timezone every second.""" + loop = asyncio.get_running_loop() + loop_offset = (loop.time() - time.time()) % 1 + try: + while True: + # Sleep until the next second according to the wall clock. + await asyncio.sleep(1 - (loop.time() - loop_offset) % 1) + now = datetime.datetime.now(tzinfo).replace(microsecond=0) + await websocket.send(now.isoformat()) + except ConnectionClosed: + return + + +async def alarm(websocket, alarm_at, tzinfo): + """Send the alarm time in the given timezone when it is reached.""" + alarm_at = alarm_at.replace(tzinfo=tzinfo) + now = datetime.datetime.now(tz=datetime.timezone.utc) + + try: + async with asyncio.timeout((alarm_at - now).total_seconds()): + await websocket.wait_closed() + except asyncio.TimeoutError: + try: + await websocket.send(alarm_at.isoformat()) + except ConnectionClosed: + return + + +async def timer(websocket, alarm_after): + """Send the remaining time until the alarm time every second.""" + alarm_at = datetime.datetime.now(tz=datetime.timezone.utc) + alarm_after + loop = asyncio.get_running_loop() + loop_offset = (loop.time() - time.time() + alarm_at.timestamp()) % 1 + + try: + while alarm_after.total_seconds() > 0: + # Sleep until the next second as a delta to the alarm time. + await asyncio.sleep(1 - (loop.time() - loop_offset) % 1) + alarm_after = alarm_at - datetime.datetime.now(tz=datetime.timezone.utc) + # Round up to the next second. + alarm_after += datetime.timedelta( + seconds=1, + microseconds=-alarm_after.microseconds, + ) + await websocket.send(format_timedelta(alarm_after)) + except ConnectionClosed: + return + + +class ZoneInfoConverter(BaseConverter): + regex = r"[A-Za-z0-9_/+-]+" + + def to_python(self, value): + try: + return zoneinfo.ZoneInfo(value) + except zoneinfo.ZoneInfoNotFoundError: + raise ValidationError + + def to_url(self, value): + return value.key + + +class DateTimeConverter(BaseConverter): + regex = r"[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(?:\.[0-9]{3})?" + + def to_python(self, value): + try: + return datetime.datetime.fromisoformat(value) + except ValueError: + raise ValidationError + + def to_url(self, value): + return value.isoformat() + + +class TimeDeltaConverter(BaseConverter): + regex = r"[0-9]{2}:[0-9]{2}:[0-9]{2}(?:\.[0-9]{3}(?:[0-9]{3})?)?" + + def to_python(self, value): + return datetime.timedelta( + hours=int(value[0:2]), + minutes=int(value[3:5]), + seconds=int(value[6:8]), + milliseconds=int(value[9:12]) if len(value) == 12 else 0, + microseconds=int(value[9:15]) if len(value) == 15 else 0, + ) + + def to_url(self, value): + return format_timedelta(value) + + +def format_timedelta(delta): + assert 0 <= delta.seconds < 86400 + hours = delta.seconds // 3600 + minutes = (delta.seconds % 3600) // 60 + seconds = delta.seconds % 60 + if delta.microseconds: + return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{delta.microseconds:06d}" + else: + return f"{hours:02d}:{minutes:02d}:{seconds:02d}" + + +url_map = Map( + [ + Rule( + "/", + redirect_to="/clock", + ), + Rule( + "/clock", + defaults={"tzinfo": datetime.timezone.utc}, + endpoint=clock, + ), + Rule( + "/clock/", + endpoint=clock, + ), + Rule( + "/alarm//", + endpoint=alarm, + ), + Rule( + "/timer/", + endpoint=timer, + ), + ], + converters={ + "tzinfo": ZoneInfoConverter, + "datetime": DateTimeConverter, + "timedelta": TimeDeltaConverter, + }, +) + + +async def main(): + async with route(url_map, "localhost", 8888): + await asyncio.get_running_loop().create_future() # run forever + + +if __name__ == "__main__": + asyncio.run(main()) From a28bb7fff5d4141a805697a4645fe84c4101c718 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 4 Feb 2025 23:12:37 +0100 Subject: [PATCH 1513/1539] Update references to third-party servers. --- README.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index 94cd79ab9..39ae1e8ba 100644 --- a/README.rst +++ b/README.rst @@ -128,11 +128,12 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for an HTTP health check. - If you want to do both in the same server, look at HTTP frameworks that - build on top of ``websockets`` to support WebSocket connections, like - Sanic_. + If you want to do both in the same server, look at HTTP + WebSocket servers + that build on top of ``websockets`` to support WebSocket connections, like + uvicorn_ or Sanic_. -.. _Sanic: https://sanicframework.org/en/ +.. _uvicorn: https://www.uvicorn.org/ +.. _Sanic: https://sanic.dev/en/ What else? ---------- From dde3716725332f76d7f62fdedfde172017bd02b8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 4 Feb 2025 23:12:54 +0100 Subject: [PATCH 1514/1539] There are faster libraries today. --- docs/faq/misc.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 4936aa6f3..3b5106006 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -24,7 +24,9 @@ you must disable: * Keepalive: set ``ping_interval=None`` * UTF-8 decoding: send ``bytes`` rather than ``str`` -If websockets is still slower than another Python library, please file a bug. +Then, please consider whether websockets is the bottleneck of the performance +of your application. Usually, in real-world applications, CPU time spent in +websockets is negligible compared to time spent in the application logic. Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ From 0bdfbd1dbe72ce3770428325eb4cfa6f3952773c Mon Sep 17 00:00:00 2001 From: Forrest Li Date: Tue, 4 Feb 2025 18:38:27 -0500 Subject: [PATCH 1515/1539] Fix incorrect asyncio import in README --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 39ae1e8ba..dc2530c23 100644 --- a/README.rst +++ b/README.rst @@ -47,7 +47,7 @@ Here's an echo server with the ``asyncio`` API: #!/usr/bin/env python import asyncio - from websockets.server import serve + from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: From 5b516463c9a2ab9cf76ea087f0687cc0c8d54f0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Feb 2025 10:17:51 +0100 Subject: [PATCH 1516/1539] Revert "Disable PyPy in CI." This reverts commit 1e62b3c384a9fc345c661ebdd66a2a1a0fbbff52. Fix #1581. --- .github/workflows/tests.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ca73cd499..5ab9c4c72 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,13 +60,12 @@ jobs: - "3.11" - "3.12" - "3.13" -# Disable PyPy per https://github.com/python-websockets/websockets/issues/1581 -# - "pypy-3.10" + - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} -# exclude: -# - python: "pypy-3.10" -# is_main: false + exclude: + - python: "pypy-3.10" + is_main: false steps: - name: Check out repository uses: actions/checkout@v4 From 90db2de19feb1bd6147f2335f9facdec9f87e690 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Feb 2025 10:53:51 +0100 Subject: [PATCH 1517/1539] Document HTTP Digest Auth as a known limitation. Ref #784. --- docs/reference/features.rst | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 0da966ccc..321e2e832 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -125,8 +125,6 @@ Server +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ❌ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | - +------------------------------------+--------+--------+--------+--------+ | Dispatch connections to handlers | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ @@ -165,16 +163,11 @@ Client +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | - | (`#784`_) | | | | | - +------------------------------------+--------+--------+--------+--------+ | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ -.. _#784: https://github.com/python-websockets/websockets/issues/784 - Known limitations ----------------- @@ -188,6 +181,10 @@ Request if it is missing or invalid (`#1246`). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 +The client doesn't support HTTP Digest Authentication (`#784`_). + +.. _#784: https://github.com/python-websockets/websockets/issues/784 + The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` From 2dd33c86dfac2861d2adddb6e32eac9d1787a221 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Feb 2025 17:49:02 +0100 Subject: [PATCH 1518/1539] Add guide for deployment to Koyeb. Fix #1369. --- docs/howto/fly.rst | 2 +- docs/howto/heroku.rst | 4 +- docs/howto/index.rst | 1 + docs/howto/koyeb.rst | 165 ++++++++++++++++++++++ docs/spelling_wordlist.txt | 1 + example/deployment/koyeb/Procfile | 1 + example/deployment/koyeb/app.py | 37 +++++ example/deployment/koyeb/requirements.txt | 1 + 8 files changed, 208 insertions(+), 4 deletions(-) create mode 100644 docs/howto/koyeb.rst create mode 100644 example/deployment/koyeb/Procfile create mode 100644 example/deployment/koyeb/app.py create mode 100644 example/deployment/koyeb/requirements.txt diff --git a/docs/howto/fly.rst b/docs/howto/fly.rst index ed001a2ae..4262e8aeb 100644 --- a/docs/howto/fly.rst +++ b/docs/howto/fly.rst @@ -1,5 +1,5 @@ Deploy to Fly -================ +============= This guide describes how to deploy a websockets server to Fly_. diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index b335e14c5..8d16eccfb 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -58,12 +58,10 @@ on websockets: .. literalinclude:: ../../example/deployment/heroku/requirements.txt :language: text -Create a ``Procfile``. +Create a ``Procfile`` to tell Heroku how to run the app. .. literalinclude:: ../../example/deployment/heroku/Procfile -This tells Heroku how to run the app. - Confirm that you created the correct files and commit them to git: .. code-block:: console diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 863c1c63c..619b11fa8 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -48,6 +48,7 @@ Once your application is ready, learn how to deploy it on various platforms. :titlesonly: render + koyeb fly heroku kubernetes diff --git a/docs/howto/koyeb.rst b/docs/howto/koyeb.rst new file mode 100644 index 000000000..1ac17daac --- /dev/null +++ b/docs/howto/koyeb.rst @@ -0,0 +1,165 @@ +Deploy to Koyeb +================ + +This guide describes how to deploy a websockets server to Koyeb_. + +.. _Koyeb: https://www.koyeb.com + +.. admonition:: The free tier of Koyeb is sufficient for trying this guide. + :class: tip + + The `free tier`__ include one web service, which this guide uses. + + __ https://www.koyeb.com/pricing + +We’re going to deploy a very simple app. The process would be identical to a +more realistic app. + +Create repository +----------------- + +Koyeb supports multiple deployment methods. Its quick start guides recommend +git-driven deployment as the first option. Let's initialize a git repository: + +.. code-block:: console + + $ mkdir websockets-echo + $ cd websockets-echo + $ git init -b main + Initialized empty Git repository in websockets-echo/.git/ + $ git commit --allow-empty -m "Initial commit." + [main (root-commit) 740f699] Initial commit. + +Render requires the git repository to be hosted at GitHub. + +Sign up or log in to GitHub. Create a new repository named ``websockets-echo``. +Don't enable any of the initialization options offered by GitHub. Then, follow +instructions for pushing an existing repository from the command line. + +After pushing, refresh your repository's homepage on GitHub. You should see an +empty repository with an empty initial commit. + +Create application +------------------ + +Here’s the implementation of the app, an echo server. Save it in a file +called ``app.py``: + +.. literalinclude:: ../../example/deployment/koyeb/app.py + :language: python + +This app implements typical requirements for running on a Platform as a Service: + +* it listens on the port provided in the ``$PORT`` environment variable; +* it provides a health check at ``/healthz``; +* it closes connections and exits cleanly when it receives a ``SIGTERM`` signal; + while not documented, this is how Koyeb terminates apps. + +Create a ``requirements.txt`` file containing this line to declare a dependency +on websockets: + +.. literalinclude:: ../../example/deployment/koyeb/requirements.txt + :language: text + +Create a ``Procfile`` to tell Koyeb how to run the app. + +.. literalinclude:: ../../example/deployment/koyeb/Procfile + +Confirm that you created the correct files and commit them to git: + +.. code-block:: console + + $ ls + Procfile app.py requirements.txt + $ git add . + $ git commit -m "Initial implementation." + [main f634b8b] Initial implementation. +  3 files changed, 39 insertions(+) +  create mode 100644 Procfile +  create mode 100644 app.py +  create mode 100644 requirements.txt + +The app is ready. Let's deploy it! + +Deploy application +------------------ + +Sign up or log in to Koyeb. + +In the Koyeb control panel, create a web service with GitHub as the deployment +method. Install and authorize Koyeb's GitHub app if you haven't done that yet. + +Follow the steps to create a new service: + +1. Select the ``websockets-echo`` repository in the list of your repositories. +2. Confirm that the **Free** instance type is selected. Click **Next**. +3. Configure health checks: change the protocol from TCP to HTTP and set the + path to ``/healthz``. Review other settings; defaults should be correct. + Click **Deploy**. + +Koyeb builds the app, deploys it, verifies that the health checks passes, and +makes the deployment active. + +Validate deployment +------------------- + +Let's confirm that your application is running as expected. + +Since it's a WebSocket server, you need a WebSocket client, such as the +interactive client that comes with websockets. + +If you're currently building a websockets server, perhaps you're already in a +virtualenv where websockets is installed. If not, you can install it in a new +virtualenv as follows: + +.. code-block:: console + + $ python -m venv websockets-client + $ . websockets-client/bin/activate + $ pip install websockets + +Look for the URL of your app in the Koyeb control panel. It looks like +``https://--.koyeb.app/``. Connect the +interactive client — you must replace ``https`` with ``wss`` in the URL: + +.. code-block:: console + + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. + > + +Great! Your app is running! + +Once you're connected, you can send any message and the server will echo it, +or press Ctrl-D to terminate the connection: + +.. code-block:: console + + > Hello! + < Hello! + Connection closed: 1000 (OK). + + +You can also confirm that your application shuts down gracefully. + +Connect an interactive client again: + +.. code-block:: console + + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. + > + +In the Koyeb control panel, go to the **Settings** tab, click **Pause**, and +confirm. + +Eventually, the connection gets closed with code 1001 (going away). + +.. code-block:: console + + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. + Connection closed: 1001 (going away). + +If graceful shutdown wasn't working, the server wouldn't perform a closing +handshake and the connection would be closed with code 1006 (abnormal closure). diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 11b13250a..c841bda2d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -38,6 +38,7 @@ iterable js keepalive KiB +Koyeb kubernetes lifecycle linkerd diff --git a/example/deployment/koyeb/Procfile b/example/deployment/koyeb/Procfile new file mode 100644 index 000000000..2e35818f6 --- /dev/null +++ b/example/deployment/koyeb/Procfile @@ -0,0 +1 @@ +web: python app.py diff --git a/example/deployment/koyeb/app.py b/example/deployment/koyeb/app.py new file mode 100644 index 000000000..978467a30 --- /dev/null +++ b/example/deployment/koyeb/app.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +import asyncio +import http +import os +import signal + +from websockets.asyncio.server import serve + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") + + +async def main(): + # Set the stop condition when receiving SIGINT. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGINT, stop.set_result, None) + + async with serve( + echo, + host="", + port=int(os.environ["PORT"]), + process_request=health_check, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/koyeb/requirements.txt b/example/deployment/koyeb/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/deployment/koyeb/requirements.txt @@ -0,0 +1 @@ +websockets From 277f0f80e00fca87a9dbda8f5766d8dbf1347eda Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 16:24:39 +0100 Subject: [PATCH 1519/1539] Switch tutorial to Koyeb. Fix #1299. --- docs/howto/koyeb.rst | 14 +-- docs/intro/tutorial3.rst | 167 +++++++++++++++++---------------- docs/spelling_wordlist.txt | 3 +- example/tutorial/step3/app.py | 8 +- example/tutorial/step3/main.js | 2 +- 5 files changed, 101 insertions(+), 93 deletions(-) diff --git a/docs/howto/koyeb.rst b/docs/howto/koyeb.rst index 1ac17daac..0ad126dd8 100644 --- a/docs/howto/koyeb.rst +++ b/docs/howto/koyeb.rst @@ -119,13 +119,13 @@ virtualenv as follows: $ pip install websockets Look for the URL of your app in the Koyeb control panel. It looks like -``https://--.koyeb.app/``. Connect the +``https://--.koyeb.app/``. Connect the interactive client — you must replace ``https`` with ``wss`` in the URL: .. code-block:: console - $ python -m websockets wss://--.koyeb.app/ - Connected to wss://--.koyeb.app/. + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. > Great! Your app is running! @@ -146,8 +146,8 @@ Connect an interactive client again: .. code-block:: console - $ python -m websockets wss://--.koyeb.app/ - Connected to wss://--.koyeb.app/. + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. > In the Koyeb control panel, go to the **Settings** tab, click **Pause**, and @@ -157,8 +157,8 @@ Eventually, the connection gets closed with code 1001 (going away). .. code-block:: console - $ python -m websockets wss://--.koyeb.app/ - Connected to wss://--.koyeb.app/. + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 21d51371b..9356fdbe1 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -28,14 +28,20 @@ and a WebSocket server on ``ws://localhost:8001/`` with: Now you want to deploy these servers on the Internet. There's a vast range of hosting providers to choose from. For the sake of simplicity, we'll rely on: -* GitHub Pages for the HTTP server; -* Heroku for the WebSocket server. +* `GitHub Pages`_ for the HTTP server; +* Koyeb_ for the WebSocket server. + +.. _GitHub Pages: https://pages.github.com/ +.. _Koyeb: https://www.koyeb.com/ + +Koyeb is a modern Platform as a Service provider whose free tier allows you to +run a web application, including a WebSocket server. Commit project to git --------------------- Perhaps you committed your work to git while you were progressing through the -tutorial. If you didn't, now is a good time, because GitHub and Heroku offer +tutorial. If you didn't, now is a good time, because GitHub and Koyeb offer git-based deployment workflows. Initialize a git repository: @@ -45,7 +51,7 @@ Initialize a git repository: $ git init -b main Initialized empty Git repository in websockets-tutorial/.git/ $ git commit --allow-empty -m "Initial commit." - [main (root-commit) ...] Initial commit. + [main (root-commit) 8195c1d] Initial commit. Add all files and commit: @@ -53,7 +59,7 @@ Add all files and commit: $ git add . $ git commit -m "Initial implementation of Connect Four game." - [main ...] Initial implementation of Connect Four game. + [main 7f0b2c4] Initial implementation of Connect Four game. 6 files changed, 500 insertions(+) create mode 100644 app.py create mode 100644 connect4.css @@ -62,32 +68,58 @@ Add all files and commit: create mode 100644 index.html create mode 100644 main.js -Prepare the WebSocket server ----------------------------- +Sign up or log in to GitHub. -Before you deploy the server, you must adapt it to meet requirements of -Heroku's runtime. This involves two small changes: +Create a new repository. Set the repository name to ``websockets-tutorial``, +the visibility to Public, and click **Create repository**. -1. Heroku expects the server to `listen on a specific port`_, provided in the - ``$PORT`` environment variable. +Push your code to this repository. You must replace ``python-websockets`` by +your GitHub username in the following command: -2. Heroku sends a ``SIGTERM`` signal when `shutting down a dyno`_, which - should trigger a clean exit. +.. code-block:: console -.. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port + $ git remote add origin git@github.com:python-websockets/websockets-tutorial.git + $ git branch -M main + $ git push -u origin main + ... + To github.com:python-websockets/websockets-tutorial.git + * [new branch] main -> main + Branch 'main' set up to track remote branch 'main' from 'origin'. + +Adapt the WebSocket server +-------------------------- -.. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown +Before you deploy the server, you must adapt it for Koyeb's environment. This +involves three small changes: + +1. Koyeb provides the port on which the server should listen in the ``$PORT`` + environment variable. + +2. Koyeb requires a health check to verify that the server is running. We'll add + a HTTP health check. + +3. Koyeb sends a ``SIGTERM`` signal when terminating the server. We'll catch it + and trigger a clean exit. Adapt the ``main()`` coroutine accordingly: .. code-block:: python + import http import os import signal +.. literalinclude:: ../../example/tutorial/step3/app.py + :pyobject: health_check + .. literalinclude:: ../../example/tutorial/step3/app.py :pyobject: main +The ``process_request`` parameter of :func:`~asyncio.server.serve` is a callback +that runs for each request. When it returns an HTTP response, websockets sends +that response instead of opening a WebSocket connection. Here, requests to +``/healthz`` return an HTTP 200 status code. + To catch the ``SIGTERM`` signal, ``main()`` creates a :class:`~asyncio.Future` called ``stop`` and registers a signal handler that sets the result of this future. The value of the future doesn't matter; it's only for waiting for @@ -97,8 +129,6 @@ Then, by using :func:`~asyncio.server.serve` as a context manager and exiting the context when ``stop`` has a result, ``main()`` ensures that the server closes connections cleanly and exits on ``SIGTERM``. -The app is now fully compatible with Heroku. - Deploy the WebSocket server --------------------------- @@ -108,12 +138,12 @@ when building the image: .. literalinclude:: ../../example/tutorial/step3/requirements.txt :language: text -.. admonition:: Heroku treats ``requirements.txt`` as a signal to `detect a Python app`_. +.. admonition:: Koyeb treats ``requirements.txt`` as a signal to `detect a Python app`__. :class: tip That's why you don't need to declare that you need a Python runtime. -.. _detect a Python app: https://devcenter.heroku.com/articles/python-support#recognizing-a-python-app + __ https://www.koyeb.com/docs/build-and-deploy/build-from-git/python#detection Create a ``Procfile`` file with this content to configure the command for running the server: @@ -121,66 +151,52 @@ running the server: .. literalinclude:: ../../example/tutorial/step3/Procfile :language: text -Commit your changes: +Commit and push your changes: .. code-block:: console $ git add . - $ git commit -m "Deploy to Heroku." - [main ...] Deploy to Heroku. - 3 files changed, 12 insertions(+), 2 deletions(-) + $ git commit -m "Deploy to Koyeb." + [main ac96d65] Deploy to Koyeb. + 3 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 Procfile create mode 100644 requirements.txt + $ git push + ... + To github.com:python-websockets/websockets-tutorial.git + + 6bd6032...ac96d65 main -> main -Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if -you haven't done that yet. - -.. _set-up instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up - -Create a Heroku app. You must choose a unique name and replace -``websockets-tutorial`` by this name in the following command: +Sign up or log in to Koyeb. -.. code-block:: console - - $ heroku create websockets-tutorial - Creating ⬢ websockets-tutorial... done - https://websockets-tutorial.herokuapp.com/ | https://git.heroku.com/websockets-tutorial.git - -If you reuse a name that someone else already uses, you will receive this -error; if this happens, try another name: +In the Koyeb control panel, create a web service with GitHub as the deployment +method. `Install and authorize Koyeb's GitHub app`__ if you haven't done that yet. -.. code-block:: console - - $ heroku create websockets-tutorial - Creating ⬢ websockets-tutorial... ! - ▸ Name websockets-tutorial is already taken +__ https://www.koyeb.com/docs/build-and-deploy/deploy-with-git#connect-your-github-account-to-koyeb -Deploy by pushing the code to Heroku: +Follow the steps to create a new service: -.. code-block:: console +1. Select the ``websockets-tutorial`` repository in the list of your repositories. +2. Confirm that the **Free** instance type is selected. Click **Next**. +3. Configure health checks: change the protocol from TCP to HTTP and set the + path to ``/healthz``. Review other settings; defaults should be correct. + Click **Deploy**. - $ git push heroku - - ... lots of output... - - remote: Released v1 - remote: https://websockets-tutorial.herokuapp.com/ deployed to Heroku - remote: - remote: Verifying deploy... done. - To https://git.heroku.com/websockets-tutorial.git - * [new branch] main -> main +Koyeb builds the app, deploys it, verifies that the health checks passes, and +makes the deployment active. You can test the WebSocket server with the interactive client exactly like you -did in the first part of the tutorial. Replace ``websockets-tutorial`` by the -name of your app in the following command: +did in the first part of the tutorial. The Koyeb control panel provides the URL +of your app in the format: ``https://--.koyeb.app/``. Replace +``https`` with ``wss`` in the URL and connect the interactive client: .. code-block:: console - $ python -m websockets wss://websockets-tutorial.herokuapp.com/ - Connected to wss://websockets-tutorial.herokuapp.com/. + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. > {"type": "init"} < {"type": "init", "join": "54ICxFae_Ip7TJE2", "watch": "634w44TblL5Dbd9a"} - Connection closed: 1000 (OK). + +Press Ctrl-D to terminate the connection. It works! @@ -199,7 +215,7 @@ You can take this strategy one step further by checking the address of the HTTP server and determining the address of the WebSocket server accordingly. Add this function to ``main.js``; replace ``python-websockets`` by your GitHub -username and ``websockets-tutorial`` by the name of your app on Heroku: +username and ``websockets-tutorial`` by the name of your app on Koyeb: .. literalinclude:: ../../example/tutorial/step3/main.js :language: js @@ -218,42 +234,27 @@ Commit your changes: $ git add . $ git commit -m "Configure WebSocket server address." - [main ...] Configure WebSocket server address. + [main 0903526] Configure WebSocket server address. 1 file changed, 11 insertions(+), 1 deletion(-) + $ git push + ... + To github.com:python-websockets/websockets-tutorial.git + + ac96d65...0903526 main -> main Deploy the web application -------------------------- -Go to GitHub and create a new repository called ``websockets-tutorial``. - -Push your code to this repository. You must replace ``python-websockets`` by -your GitHub username in the following command: - -.. code-block:: console - - $ git remote add origin git@github.com:python-websockets/websockets-tutorial.git - $ git push -u origin main - Enumerating objects: 11, done. - Counting objects: 100% (11/11), done. - Delta compression using up to 8 threads - Compressing objects: 100% (10/10), done. - Writing objects: 100% (11/11), 5.90 KiB | 2.95 MiB/s, done. - Total 11 (delta 0), reused 0 (delta 0), pack-reused 0 - To github.com:/websockets-tutorial.git - * [new branch] main -> main - Branch 'main' set up to track remote branch 'main' from 'origin'. - Go back to GitHub, open the Settings tab of the repository and select Pages in the menu. Select the main branch as source and click Save. GitHub tells you that your site is published. -Follow the link and start a game! +Open https://.github.io/websockets-tutorial/ and start a game! Summary ------- In this third part of the tutorial, you learned how to deploy a WebSocket -application with Heroku. +application with Koyeb. You can start a Connect Four game, send the JOIN link to a friend, and play over the Internet! diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index c841bda2d..dd32a78c3 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -27,6 +27,7 @@ Dockerfile dyno formatter fractalideas +github gunicorn healthz html @@ -38,7 +39,7 @@ iterable js keepalive KiB -Koyeb +koyeb kubernetes lifecycle linkerd diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 261057f9a..335fd48d3 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio +import http import json import os import secrets @@ -183,6 +184,11 @@ async def handler(websocket): await start(websocket) +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") + + async def main(): # Set the stop condition when receiving SIGTERM. loop = asyncio.get_running_loop() @@ -190,7 +196,7 @@ async def main(): loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) port = int(os.environ.get("PORT", "8001")) - async with serve(handler, "", port): + async with serve(handler, "", port, process_request=health_check): await stop diff --git a/example/tutorial/step3/main.js b/example/tutorial/step3/main.js index 3000fa2f7..3a7a0db49 100644 --- a/example/tutorial/step3/main.js +++ b/example/tutorial/step3/main.js @@ -2,7 +2,7 @@ import { createBoard, playMove } from "./connect4.js"; function getWebSocketServer() { if (window.location.host === "python-websockets.github.io") { - return "wss://websockets-tutorial.herokuapp.com/"; + return "wss://websockets-tutorial.koyeb.app/"; } else if (window.location.host === "localhost:8000") { return "ws://localhost:8001/"; } else { From 755eac176700548d430ef580c4ad1128a096674a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 17:28:02 +0100 Subject: [PATCH 1520/1539] Simplify pattern for shutting down servers. This avoids having to create manually a future. Also, it's robust to receiving multiple times the signal. Fix #1593. --- docs/faq/client.rst | 2 +- docs/faq/server.rst | 2 +- docs/howto/haproxy.rst | 2 +- docs/howto/nginx.rst | 2 +- docs/intro/tutorial3.rst | 20 ++++++++------------ docs/topics/deployment.rst | 3 +-- example/deployment/fly/app.py | 16 ++++------------ example/deployment/haproxy/app.py | 16 +++++----------- example/deployment/heroku/app.py | 16 +++++----------- example/deployment/koyeb/app.py | 17 +++++------------ example/deployment/kubernetes/app.py | 16 ++++------------ example/deployment/nginx/app.py | 15 +++++---------- example/deployment/render/app.py | 16 ++++------------ example/deployment/supervisor/app.py | 16 ++++------------ example/faq/shutdown_client.py | 3 +-- example/faq/shutdown_server.py | 16 +++++++--------- example/tutorial/step3/app.py | 11 ++++------- experiments/compression/server.py | 19 ++++++++----------- 18 files changed, 69 insertions(+), 139 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index c39e588ca..cf27fcd45 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -103,7 +103,7 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_client.py - :emphasize-lines: 11-13 + :emphasize-lines: 10-12 How do I disable TLS/SSL certificate verification? -------------------------------------------------- diff --git a/docs/faq/server.rst b/docs/faq/server.rst index d00dcafba..bb04c5e1c 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -310,7 +310,7 @@ Exit the :func:`~serve` context manager. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 13-16,19 + :emphasize-lines: 14-16 How do I stop a server while keeping existing connections open? --------------------------------------------------------------- diff --git a/docs/howto/haproxy.rst b/docs/howto/haproxy.rst index fdaab0401..5ffe21ea3 100644 --- a/docs/howto/haproxy.rst +++ b/docs/howto/haproxy.rst @@ -15,7 +15,7 @@ Run server processes Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/haproxy/app.py - :emphasize-lines: 24 + :language: python Each server process listens on a different port by extracting an incremental index from an environment variable set by Supervisor. diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index 872353cad..5b37c9b36 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -15,7 +15,7 @@ Run server processes Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/nginx/app.py - :emphasize-lines: 21,23 + :language: python We'd like nginx to connect to websockets servers via Unix sockets in order to avoid the overhead of TCP for communicating between processes running in the diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 9356fdbe1..bdd1bd50b 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -120,14 +120,10 @@ that runs for each request. When it returns an HTTP response, websockets sends that response instead of opening a WebSocket connection. Here, requests to ``/healthz`` return an HTTP 200 status code. -To catch the ``SIGTERM`` signal, ``main()`` creates a :class:`~asyncio.Future` -called ``stop`` and registers a signal handler that sets the result of this -future. The value of the future doesn't matter; it's only for waiting for -``SIGTERM``. - -Then, by using :func:`~asyncio.server.serve` as a context manager and exiting -the context when ``stop`` has a result, ``main()`` ensures that the server -closes connections cleanly and exits on ``SIGTERM``. +``main()`` registers a signal handler that closes the server when receiving the +``SIGTERM`` signal. Then, it waits for the server to be closed. Additionally, +using :func:`~asyncio.server.serve` as a context manager ensures that the server +will always be closed cleanly, even if the program crashes. Deploy the WebSocket server --------------------------- @@ -157,14 +153,14 @@ Commit and push your changes: $ git add . $ git commit -m "Deploy to Koyeb." - [main ac96d65] Deploy to Koyeb. - 3 files changed, 18 insertions(+), 2 deletions(-) + [main 4a4b6e9] Deploy to Koyeb. + 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 Procfile create mode 100644 requirements.txt $ git push ... To github.com:python-websockets/websockets-tutorial.git - + 6bd6032...ac96d65 main -> main + + 6bd6032...4a4b6e9 main -> main Sign up or log in to Koyeb. @@ -239,7 +235,7 @@ Commit your changes: $ git push ... To github.com:python-websockets/websockets-tutorial.git - + ac96d65...0903526 main -> main + + 4a4b6e9...968eaaa main -> main Deploy the web application -------------------------- diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index 48ef72b56..00d1f9285 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -101,8 +101,7 @@ Here's an example: :emphasize-lines: 13-16,19 When exiting the context manager, :func:`~asyncio.server.serve` closes all -connections -with code 1001 (going away). As a consequence: +connections with code 1001 (going away). As a consequence: * If the connection handler is awaiting :meth:`~asyncio.server.ServerConnection.recv`, it receives a diff --git a/example/deployment/fly/app.py b/example/deployment/fly/app.py index c8e6af4f9..a841831cf 100644 --- a/example/deployment/fly/app.py +++ b/example/deployment/fly/app.py @@ -18,18 +18,10 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="", - port=8080, - process_request=health_check, - ): - await stop + async with serve(echo, "", 8080, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py index ef6d9c42d..6596c9f32 100644 --- a/example/deployment/haproxy/app.py +++ b/example/deployment/haproxy/app.py @@ -13,17 +13,11 @@ async def echo(websocket): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="localhost", - port=8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]), - ): - await stop + port = 8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]) + async with serve(echo, "localhost", port) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index 17ad09d26..524fb35f8 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -13,17 +13,11 @@ async def echo(websocket): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="", - port=int(os.environ["PORT"]), - ): - await stop + port = int(os.environ["PORT"]) + async with serve(echo, "localhost", port) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/koyeb/app.py b/example/deployment/koyeb/app.py index 978467a30..4bfbee793 100644 --- a/example/deployment/koyeb/app.py +++ b/example/deployment/koyeb/app.py @@ -19,18 +19,11 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGINT. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGINT, stop.set_result, None) - - async with serve( - echo, - host="", - port=int(os.environ["PORT"]), - process_request=health_check, - ): - await stop + port = int(os.environ["PORT"]) + async with serve(echo, "", port, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGINT, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py index 387f0ade1..95125773d 100755 --- a/example/deployment/kubernetes/app.py +++ b/example/deployment/kubernetes/app.py @@ -31,18 +31,10 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - slow_echo, - host="", - port=80, - process_request=health_check, - ): - await stop + async with serve(slow_echo, "", 80, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py index 134070f61..4b3ad9b13 100644 --- a/example/deployment/nginx/app.py +++ b/example/deployment/nginx/app.py @@ -13,16 +13,11 @@ async def echo(websocket): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with unix_serve( - echo, - path=f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock", - ): - await stop + path = f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock" + async with unix_serve(echo, path) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/render/app.py b/example/deployment/render/app.py index c8e6af4f9..a841831cf 100644 --- a/example/deployment/render/app.py +++ b/example/deployment/render/app.py @@ -18,18 +18,10 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="", - port=8080, - process_request=health_check, - ): - await stop + async with serve(echo, "", 8080, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index 5e69f16a6..1ca70bdc0 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -12,18 +12,10 @@ async def echo(websocket): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="", - port=8080, - reuse_port=True, - ): - await stop + async with serve(echo, "", 8080, reuse_port=True) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/faq/shutdown_client.py b/example/faq/shutdown_client.py index 5c8bd8cbe..3280c6f9b 100755 --- a/example/faq/shutdown_client.py +++ b/example/faq/shutdown_client.py @@ -6,8 +6,7 @@ from websockets.asyncio.client import connect async def client(): - uri = "ws://localhost:8765" - async with connect(uri) as websocket: + async with connect("ws://localhost:8765") as websocket: # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, loop.create_task, websocket.close()) diff --git a/example/faq/shutdown_server.py b/example/faq/shutdown_server.py index 3f7bc5732..ea00e2520 100755 --- a/example/faq/shutdown_server.py +++ b/example/faq/shutdown_server.py @@ -5,17 +5,15 @@ from websockets.asyncio.server import serve -async def echo(websocket): +async def handler(websocket): async for message in websocket: - await websocket.send(message) + ... async def server(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve(echo, "localhost", 8765): - await stop + async with serve(handler, "localhost", 8765) as server: + # Close the server when receiving SIGTERM. + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() asyncio.run(server()) diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 335fd48d3..8a285e92e 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -190,14 +190,11 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - port = int(os.environ.get("PORT", "8001")) - async with serve(handler, "", port, process_request=health_check): - await stop + async with serve(handler, "", port, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/experiments/compression/server.py b/experiments/compression/server.py index 1c28f7355..dd399a29f 100644 --- a/experiments/compression/server.py +++ b/experiments/compression/server.py @@ -35,15 +35,6 @@ async def handler(ws): async def server(): - loop = asyncio.get_running_loop() - stop = loop.create_future() - - # Set the stop condition when receiving SIGTERM. - print("Stop the server with:") - print(f"kill -TERM {os.getpid()}") - print() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with serve( handler, "localhost", @@ -55,9 +46,15 @@ async def server(): compress_settings={"memLevel": ML}, ) ], - ): + ) as server: + print("Stop the server with:") + print(f"kill -TERM {os.getpid()}") + print() + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + tracemalloc.start() - await stop + await server.wait_closed() asyncio.run(server()) From 9f01cefde6f0d55de8bd2738a9601ddf661a1db1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 18:44:28 +0100 Subject: [PATCH 1521/1539] Simplify pattern for running servers forever. --- README.rst | 4 ++-- docs/intro/tutorial1.rst | 4 ++-- example/django/authentication.py | 4 ++-- example/faq/health_check_server.py | 4 ++-- example/quickstart/counter.py | 4 ++-- example/quickstart/server.py | 4 ++-- example/quickstart/server_secure.py | 4 ++-- example/quickstart/show_time.py | 4 ++-- example/routing.py | 4 ++-- example/tutorial/step1/app.py | 4 ++-- example/tutorial/step2/app.py | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/README.rst b/README.rst index dc2530c23..cc47b2910 100644 --- a/README.rst +++ b/README.rst @@ -54,8 +54,8 @@ Here's an echo server with the ``asyncio`` API: await websocket.send(message) async def main(): - async with serve(echo, "localhost", 8765): - await asyncio.get_running_loop().create_future() # run forever + async with serve(echo, "localhost", 8765) as server: + await server.serve_forever() asyncio.run(main()) diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 87074caee..cc88e6a4a 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -194,8 +194,8 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: async def main(): - async with serve(handler, "", 8001): - await asyncio.get_running_loop().create_future() # run forever + async with serve(handler, "", 8001) as server: + await server.serve_forever() if __name__ == "__main__": diff --git a/example/django/authentication.py b/example/django/authentication.py index c4f12a3f8..e61d70432 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -22,8 +22,8 @@ async def handler(websocket): async def main(): - async with serve(handler, "localhost", 8888): - await asyncio.get_running_loop().create_future() # run forever + async with serve(handler, "localhost", 8888) as server: + await server.serve_forever() if __name__ == "__main__": diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index 30623a4bb..3fdffb501 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -13,7 +13,7 @@ async def echo(websocket): await websocket.send(message) async def main(): - async with serve(echo, "localhost", 8765, process_request=health_check): - await asyncio.get_running_loop().create_future() # run forever + async with serve(echo, "localhost", 8765, process_request=health_check) as server: + await server.serve_forever() asyncio.run(main()) diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index 91eedc56a..8f0ff81be 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -42,8 +42,8 @@ async def counter(websocket): broadcast(USERS, users_event()) async def main(): - async with serve(counter, "localhost", 6789): - await asyncio.get_running_loop().create_future() # run forever + async with serve(counter, "localhost", 6789) as server: + await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server.py b/example/quickstart/server.py index bde5e6126..a01f91703 100755 --- a/example/quickstart/server.py +++ b/example/quickstart/server.py @@ -14,8 +14,8 @@ async def hello(websocket): print(f">>> {greeting}") async def main(): - async with serve(hello, "localhost", 8765): - await asyncio.get_running_loop().create_future() # run forever + async with serve(hello, "localhost", 8765) as server: + await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server_secure.py b/example/quickstart/server_secure.py index 8b456ed6e..92c6629b5 100755 --- a/example/quickstart/server_secure.py +++ b/example/quickstart/server_secure.py @@ -20,8 +20,8 @@ async def hello(websocket): ssl_context.load_cert_chain(localhost_pem) async def main(): - async with serve(hello, "localhost", 8765, ssl=ssl_context): - await asyncio.get_running_loop().create_future() # run forever + async with serve(hello, "localhost", 8765, ssl=ssl_context) as server: + await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index 8aeb811db..ecb908f30 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -13,8 +13,8 @@ async def show_time(websocket): await asyncio.sleep(random.random() * 2 + 1) async def main(): - async with serve(show_time, "localhost", 5678): - await asyncio.get_running_loop().create_future() # run forever + async with serve(show_time, "localhost", 5678) as server: + await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) diff --git a/example/routing.py b/example/routing.py index 9f2df4980..7fc4ad4b3 100644 --- a/example/routing.py +++ b/example/routing.py @@ -146,8 +146,8 @@ def format_timedelta(delta): async def main(): - async with route(url_map, "localhost", 8888): - await asyncio.get_running_loop().create_future() # run forever + async with route(url_map, "localhost", 8888) as server: + await server.serve_forever() if __name__ == "__main__": diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index 595a10dc7..bc8f02484 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -57,8 +57,8 @@ async def handler(websocket): async def main(): - async with serve(handler, "", 8001): - await asyncio.get_running_loop().create_future() # run forever + async with serve(handler, "", 8001) as server: + await server.serve_forever() if __name__ == "__main__": diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index ef3dd9483..fe50fb3af 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -182,8 +182,8 @@ async def handler(websocket): async def main(): - async with serve(handler, "", 8001): - await asyncio.get_running_loop().create_future() # run forever + async with serve(handler, "", 8001) as server: + await server.serve_forever() if __name__ == "__main__": From 930defe196547aba051dbac8d409d6008a86f454 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 18:57:13 +0100 Subject: [PATCH 1522/1539] Fix signal name. --- example/deployment/koyeb/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/deployment/koyeb/app.py b/example/deployment/koyeb/app.py index 4bfbee793..62ba9d843 100644 --- a/example/deployment/koyeb/app.py +++ b/example/deployment/koyeb/app.py @@ -22,7 +22,7 @@ async def main(): port = int(os.environ["PORT"]) async with serve(echo, "", port, process_request=health_check) as server: loop = asyncio.get_running_loop() - loop.add_signal_handler(signal.SIGINT, server.close) + loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() From 281a5b39b5529e40946f5e54d6c205acb7cfefda Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 20:14:17 +0100 Subject: [PATCH 1523/1539] Remove cheat sheet. The information wasn't presented in a useful format and some of it was out of date. Refs #1209. --- docs/faq/common.rst | 32 ++++++++++++++- docs/faq/index.rst | 10 +++-- docs/howto/cheatsheet.rst | 86 --------------------------------------- docs/howto/index.rst | 1 - docs/topics/logging.rst | 2 +- 5 files changed, 38 insertions(+), 93 deletions(-) delete mode 100644 docs/howto/cheatsheet.rst diff --git a/docs/faq/common.rst b/docs/faq/common.rst index ba7a95932..b31d4504f 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -3,6 +3,33 @@ Both sides .. currentmodule:: websockets.asyncio.connection +.. _enable-debug-logs: + +How do I enable debug logs? +--------------------------- + +You can enable debug logs to see exactly what websockets is doing. + +If logging isn't configured in your application:: + + import logging + + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.DEBUG, + ) + +If logging is already configured:: + + import logging + + logger = logging.getLogger("websockets") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + +Refer to the :doc:`logging documentation <../topics/logging>` for more details +on logging in websockets. + What does ``ConnectionClosedError: no close frame received or sent`` mean? -------------------------------------------------------------------------- @@ -39,8 +66,9 @@ There are several reasons why long-lived connections may be lost: connections may terminate connections after a short amount of time, usually 30 seconds, despite websockets' keepalive mechanism. -If you're facing a reproducible issue, :ref:`enable debug logs ` to -see when and how connections are closed. +If you're facing a reproducible issue, :ref:`enable debug logs +` to see when and how connections are closed. connections are +closed. What does ``ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received`` mean? --------------------------------------------------------------------------------------------------------------------- diff --git a/docs/faq/index.rst b/docs/faq/index.rst index 9d5b0d538..7488a5397 100644 --- a/docs/faq/index.rst +++ b/docs/faq/index.rst @@ -7,10 +7,14 @@ Frequently asked questions about :mod:`asyncio`. :class: seealso - Python's documentation about `developing with asyncio`_ is a good - complement. + If you're new to ``asyncio``, you will certainly encounter issues that are + related to asynchronous programming in general rather than to websockets in + particular. - .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html + Fortunately, Python's official documentation provides advice to `develop + with asyncio`_. Check it out: it's invaluable! + + .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html .. toctree:: diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst deleted file mode 100644 index 8df2f234b..000000000 --- a/docs/howto/cheatsheet.rst +++ /dev/null @@ -1,86 +0,0 @@ -Cheat sheet -=========== - -.. currentmodule:: websockets - -Server ------- - -* Write a coroutine that handles a single connection. It receives a WebSocket - protocol instance and the URI path in argument. - - * Call :meth:`~asyncio.connection.Connection.recv` and - :meth:`~asyncio.connection.Connection.send` to receive and send messages at - any time. - - * When :meth:`~asyncio.connection.Connection.recv` or - :meth:`~asyncio.connection.Connection.send` raises - :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started other - :class:`asyncio.Task`, terminate them before exiting. - - * If you aren't awaiting :meth:`~asyncio.connection.Connection.recv`, consider - awaiting :meth:`~asyncio.connection.Connection.wait_closed` to detect - quickly when the connection is closed. - - * You may :meth:`~asyncio.connection.Connection.ping` or - :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed - in general. - -* Create a server with :func:`~asyncio.server.serve` which is similar to asyncio's - :meth:`~asyncio.loop.create_server`. You can also use it as an asynchronous - context manager. - - * The server takes care of establishing connections, then lets the handler - execute the application logic, and finally closes the connection after the - handler exits normally or with an exception. - - * For advanced customization, you may subclass - :class:`~asyncio.server.ServerConnection` and pass either this subclass or a - factory function as the ``create_connection`` argument. - -Client ------- - -* Create a client with :func:`~asyncio.client.connect` which is similar to - asyncio's :meth:`~asyncio.loop.create_connection`. You can also use it as an - asynchronous context manager. - - * For advanced customization, you may subclass - :class:`~asyncio.client.ClientConnection` and pass either this subclass or - a factory function as the ``create_connection`` argument. - -* Call :meth:`~asyncio.connection.Connection.recv` and - :meth:`~asyncio.connection.Connection.send` to receive and send messages at - any time. - -* You may :meth:`~asyncio.connection.Connection.ping` or - :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed in - general. - -* If you aren't using :func:`~asyncio.client.connect` as a context manager, call - :meth:`~asyncio.connection.Connection.close` to terminate the connection. - -.. _debugging: - -Debugging ---------- - -If you don't understand what websockets is doing, enable logging:: - - import logging - logger = logging.getLogger('websockets') - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler()) - -The logs contain: - -* Exceptions in the connection handler at the ``ERROR`` level -* Exceptions in the opening or closing handshake at the ``INFO`` level -* All frames at the ``DEBUG`` level — this can be very verbose - -If you're new to ``asyncio``, you will certainly encounter issues that are -related to asynchronous programming in general rather than to websockets in -particular. Fortunately Python's official documentation provides advice to -`develop with asyncio`_. Check it out: it's invaluable! - -.. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 619b11fa8..0b573b229 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -21,7 +21,6 @@ If you're stuck, perhaps you'll find the answer here. .. toctree:: :titlesonly: - cheatsheet patterns autoreload diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index fff33a024..03c5bef5f 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -72,7 +72,7 @@ Here's a basic configuration for a server in production:: Here's how to enable debug logs for development:: logging.basicConfig( - format="%(message)s", + format="%(asctime)s %(message)s", level=logging.DEBUG, ) From 476aaac5bd146f82c8b30658233897acaedbe82e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 20:29:42 +0100 Subject: [PATCH 1524/1539] Deindent module contents in API docs. This the dominant style in the project. --- docs/reference/datastructures.rst | 76 +++++++++++++++---------------- docs/reference/extensions.rst | 41 ++++++++--------- docs/reference/types.rst | 14 +++--- 3 files changed, 65 insertions(+), 66 deletions(-) diff --git a/docs/reference/datastructures.rst b/docs/reference/datastructures.rst index ec02d4210..04a7466fa 100644 --- a/docs/reference/datastructures.rst +++ b/docs/reference/datastructures.rst @@ -6,61 +6,61 @@ WebSocket events .. automodule:: websockets.frames - .. autoclass:: Frame - - .. autoclass:: Opcode - - .. autoattribute:: CONT - .. autoattribute:: TEXT - .. autoattribute:: BINARY - .. autoattribute:: CLOSE - .. autoattribute:: PING - .. autoattribute:: PONG - - .. autoclass:: Close - - .. autoclass:: CloseCode - - .. autoattribute:: NORMAL_CLOSURE - .. autoattribute:: GOING_AWAY - .. autoattribute:: PROTOCOL_ERROR - .. autoattribute:: UNSUPPORTED_DATA - .. autoattribute:: NO_STATUS_RCVD - .. autoattribute:: ABNORMAL_CLOSURE - .. autoattribute:: INVALID_DATA - .. autoattribute:: POLICY_VIOLATION - .. autoattribute:: MESSAGE_TOO_BIG - .. autoattribute:: MANDATORY_EXTENSION - .. autoattribute:: INTERNAL_ERROR - .. autoattribute:: SERVICE_RESTART - .. autoattribute:: TRY_AGAIN_LATER - .. autoattribute:: BAD_GATEWAY - .. autoattribute:: TLS_HANDSHAKE +.. autoclass:: Frame + +.. autoclass:: Opcode + + .. autoattribute:: CONT + .. autoattribute:: TEXT + .. autoattribute:: BINARY + .. autoattribute:: CLOSE + .. autoattribute:: PING + .. autoattribute:: PONG + +.. autoclass:: Close + +.. autoclass:: CloseCode + + .. autoattribute:: NORMAL_CLOSURE + .. autoattribute:: GOING_AWAY + .. autoattribute:: PROTOCOL_ERROR + .. autoattribute:: UNSUPPORTED_DATA + .. autoattribute:: NO_STATUS_RCVD + .. autoattribute:: ABNORMAL_CLOSURE + .. autoattribute:: INVALID_DATA + .. autoattribute:: POLICY_VIOLATION + .. autoattribute:: MESSAGE_TOO_BIG + .. autoattribute:: MANDATORY_EXTENSION + .. autoattribute:: INTERNAL_ERROR + .. autoattribute:: SERVICE_RESTART + .. autoattribute:: TRY_AGAIN_LATER + .. autoattribute:: BAD_GATEWAY + .. autoattribute:: TLS_HANDSHAKE HTTP events ----------- .. automodule:: websockets.http11 - .. autoclass:: Request +.. autoclass:: Request - .. autoclass:: Response +.. autoclass:: Response .. automodule:: websockets.datastructures - .. autoclass:: Headers +.. autoclass:: Headers - .. automethod:: get_all + .. automethod:: get_all - .. automethod:: raw_items + .. automethod:: raw_items - .. autoexception:: MultipleValuesError +.. autoexception:: MultipleValuesError URIs ---- .. automodule:: websockets.uri - .. autofunction:: parse_uri +.. autofunction:: parse_uri - .. autoclass:: WebSocketURI +.. autoclass:: WebSocketURI diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index f3da464a5..880ef4a2a 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -16,45 +16,44 @@ Per-Message Deflate .. automodule:: websockets.extensions.permessage_deflate - :mod:`websockets.extensions.permessage_deflate` implements WebSocket - Per-Message Deflate. +:mod:`websockets.extensions.permessage_deflate` implements WebSocket Per-Message +Deflate. - This extension is specified in :rfc:`7692`. +This extension is specified in :rfc:`7692`. - Refer to the :doc:`topic guide on compression <../topics/compression>` to - learn more about tuning compression settings. +Refer to the :doc:`topic guide on compression <../topics/compression>` to learn +more about tuning compression settings. - .. autoclass:: ClientPerMessageDeflateFactory +.. autoclass:: ServerPerMessageDeflateFactory - .. autoclass:: ServerPerMessageDeflateFactory +.. autoclass:: ClientPerMessageDeflateFactory Base classes ------------ .. automodule:: websockets.extensions - :mod:`websockets.extensions` defines base classes for implementing - extensions. +:mod:`websockets.extensions` defines base classes for implementing extensions. - Refer to the :doc:`how-to guide on extensions <../howto/extensions>` to - learn more about writing an extension. +Refer to the :doc:`how-to guide on extensions <../howto/extensions>` to learn +more about writing an extension. - .. autoclass:: Extension +.. autoclass:: Extension - .. autoattribute:: name + .. autoattribute:: name - .. automethod:: decode + .. automethod:: decode - .. automethod:: encode + .. automethod:: encode - .. autoclass:: ClientExtensionFactory +.. autoclass:: ServerExtensionFactory - .. autoattribute:: name + .. automethod:: process_request_params - .. automethod:: get_request_params +.. autoclass:: ClientExtensionFactory - .. automethod:: process_response_params + .. autoattribute:: name - .. autoclass:: ServerExtensionFactory + .. automethod:: get_request_params - .. automethod:: process_request_params + .. automethod:: process_response_params diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 9d3aa8310..d249b9294 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -3,19 +3,19 @@ Types .. automodule:: websockets.typing - .. autodata:: Data +.. autodata:: Data - .. autodata:: LoggerLike +.. autodata:: LoggerLike - .. autodata:: StatusLike +.. autodata:: StatusLike - .. autodata:: Origin +.. autodata:: Origin - .. autodata:: Subprotocol +.. autodata:: Subprotocol - .. autodata:: ExtensionName +.. autodata:: ExtensionName - .. autodata:: ExtensionParameter +.. autodata:: ExtensionParameter .. autodata:: websockets.protocol.Event From 4a2cfd5eaccb8e48971d21f65c4d04cf766a273f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 20:30:33 +0100 Subject: [PATCH 1525/1539] Assorted fixes to the docs. --- docs/intro/index.rst | 1 + docs/reference/asyncio/client.rst | 4 ++-- docs/reference/asyncio/common.rst | 4 ++-- docs/reference/asyncio/server.rst | 4 ++-- docs/reference/features.rst | 3 ++- docs/reference/index.rst | 11 +++++------ docs/reference/legacy/client.rst | 4 ++-- docs/reference/legacy/common.rst | 4 ++-- docs/reference/legacy/server.rst | 4 ++-- docs/spelling_wordlist.txt | 1 + docs/topics/design.rst | 4 ++-- 11 files changed, 23 insertions(+), 21 deletions(-) diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 642e50094..76994b6a2 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -35,6 +35,7 @@ Tutorial Learn how to build an real-time web application with websockets. .. toctree:: + :maxdepth: 1 tutorial1 tutorial2 diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index ea7b21506..72c7dce37 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,5 +1,5 @@ -Client (new :mod:`asyncio`) -=========================== +Client (:mod:`asyncio`) +======================= .. automodule:: websockets.asyncio.client diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index 325f20450..d772adc25 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (new :mod:`asyncio`) -=============================== +Both sides (:mod:`asyncio`) +=========================== .. automodule:: websockets.asyncio.connection diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 8d8b700f3..a245929ef 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,5 +1,5 @@ -Server (new :mod:`asyncio`) -=========================== +Server (:mod:`asyncio`) +======================= .. automodule:: websockets.asyncio.server diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 321e2e832..e5f6e0de0 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -68,6 +68,7 @@ Both sides | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | @@ -177,7 +178,7 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 The server doesn't check the Host header and doesn't respond with HTTP 400 Bad -Request if it is missing or invalid (`#1246`). +Request if it is missing or invalid (`#1246`_). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 diff --git a/docs/reference/index.rst b/docs/reference/index.rst index c78a3c095..cc9542c24 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -51,12 +51,11 @@ application servers. sansio/server sansio/client -:mod:`asyncio` (legacy) ------------------------ - -This is the historical implementation. +Legacy +------ -It is deprecated and will be removed. +This is the historical implementation. It is deprecated. It will be removed by +2030. .. toctree:: :titlesonly: @@ -67,7 +66,7 @@ It is deprecated and will be removed. Extensions ---------- -The Per-Message Deflate extension is built in. You may also define custom +The Per-Message Deflate extension is built-in. You may also define custom extensions. .. toctree:: diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst index a798409f0..ede887f32 100644 --- a/docs/reference/legacy/client.rst +++ b/docs/reference/legacy/client.rst @@ -1,5 +1,5 @@ -Client (legacy :mod:`asyncio`) -============================== +Client (legacy) +=============== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst index 45c56fccd..821576020 100644 --- a/docs/reference/legacy/common.rst +++ b/docs/reference/legacy/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (legacy :mod:`asyncio`) -================================== +Both sides (legacy) +=================== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index 3c1d19fc6..8636034e2 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -1,5 +1,5 @@ -Server (legacy :mod:`asyncio`) -============================== +Server (legacy) +=============== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index dd32a78c3..1980aafa3 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -33,6 +33,7 @@ healthz html hypercorn iframe +io IPv istio iterable diff --git a/docs/topics/design.rst b/docs/topics/design.rst index bc14bd332..c1f55a9dc 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,7 +1,7 @@ :orphan: -Design (legacy :mod:`asyncio`) -============================== +Design (legacy) +=============== .. currentmodule:: websockets.legacy From 2a612686a0c7f0f39d09e2478d23ca91d9cfe5d6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 22:03:58 +0100 Subject: [PATCH 1526/1539] Move deployment guides to their own section. --- docs/{howto => deploy}/fly.rst | 0 docs/{howto => deploy}/haproxy.rst | 0 docs/{howto => deploy}/heroku.rst | 0 docs/deploy/index.rst | 26 ++++++++++++++++++++++++++ docs/{howto => deploy}/koyeb.rst | 0 docs/{howto => deploy}/kubernetes.rst | 0 docs/{howto => deploy}/nginx.rst | 0 docs/{howto => deploy}/render.rst | 0 docs/{howto => deploy}/supervisor.rst | 0 docs/howto/index.rst | 14 -------------- docs/index.rst | 1 + 11 files changed, 27 insertions(+), 14 deletions(-) rename docs/{howto => deploy}/fly.rst (100%) rename docs/{howto => deploy}/haproxy.rst (100%) rename docs/{howto => deploy}/heroku.rst (100%) create mode 100644 docs/deploy/index.rst rename docs/{howto => deploy}/koyeb.rst (100%) rename docs/{howto => deploy}/kubernetes.rst (100%) rename docs/{howto => deploy}/nginx.rst (100%) rename docs/{howto => deploy}/render.rst (100%) rename docs/{howto => deploy}/supervisor.rst (100%) diff --git a/docs/howto/fly.rst b/docs/deploy/fly.rst similarity index 100% rename from docs/howto/fly.rst rename to docs/deploy/fly.rst diff --git a/docs/howto/haproxy.rst b/docs/deploy/haproxy.rst similarity index 100% rename from docs/howto/haproxy.rst rename to docs/deploy/haproxy.rst diff --git a/docs/howto/heroku.rst b/docs/deploy/heroku.rst similarity index 100% rename from docs/howto/heroku.rst rename to docs/deploy/heroku.rst diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst new file mode 100644 index 000000000..a965e649b --- /dev/null +++ b/docs/deploy/index.rst @@ -0,0 +1,26 @@ +Deployment guides +================= + +Discover how to deploy your application on various platforms. + +Platforms-as-a-Service +---------------------- + +.. toctree:: + :titlesonly: + + render + koyeb + fly + heroku + +Self-hosted +----------- + +.. toctree:: + :titlesonly: + + kubernetes + supervisor + nginx + haproxy diff --git a/docs/howto/koyeb.rst b/docs/deploy/koyeb.rst similarity index 100% rename from docs/howto/koyeb.rst rename to docs/deploy/koyeb.rst diff --git a/docs/howto/kubernetes.rst b/docs/deploy/kubernetes.rst similarity index 100% rename from docs/howto/kubernetes.rst rename to docs/deploy/kubernetes.rst diff --git a/docs/howto/nginx.rst b/docs/deploy/nginx.rst similarity index 100% rename from docs/howto/nginx.rst rename to docs/deploy/nginx.rst diff --git a/docs/howto/render.rst b/docs/deploy/render.rst similarity index 100% rename from docs/howto/render.rst rename to docs/deploy/render.rst diff --git a/docs/howto/supervisor.rst b/docs/deploy/supervisor.rst similarity index 100% rename from docs/howto/supervisor.rst rename to docs/deploy/supervisor.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 0b573b229..deae44942 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -41,20 +41,6 @@ features, which websockets supports fully. .. _deployment-howto: -Once your application is ready, learn how to deploy it on various platforms. - -.. toctree:: - :titlesonly: - - render - koyeb - fly - heroku - kubernetes - supervisor - nginx - haproxy - If you're integrating the Sans-I/O layer of websockets into a library, rather than building an application with websockets, follow this guide. diff --git a/docs/index.rst b/docs/index.rst index de14fa2d0..bc3cc9df4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -100,6 +100,7 @@ Do you like it? :doc:`Let's dive in! ` intro/index howto/index + deploy/index faq/index reference/index topics/index From 8c9f6fc27bb54610f1c389d7769cddc1504ca57e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 22:46:17 +0100 Subject: [PATCH 1527/1539] Merge deployment topic guide into deployment section. --- .../architecture.svg} | 0 docs/deploy/index.rst | 202 +++++++++++++++++- docs/howto/index.rst | 2 - docs/spelling_wordlist.txt | 1 + docs/topics/deployment.rst | 181 ---------------- docs/topics/index.rst | 1 - 6 files changed, 197 insertions(+), 190 deletions(-) rename docs/{topics/deployment.svg => deploy/architecture.svg} (100%) delete mode 100644 docs/topics/deployment.rst diff --git a/docs/topics/deployment.svg b/docs/deploy/architecture.svg similarity index 100% rename from docs/topics/deployment.svg rename to docs/deploy/architecture.svg diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index a965e649b..a6432d835 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -1,11 +1,39 @@ -Deployment guides -================= +Deployment +========== -Discover how to deploy your application on various platforms. +.. currentmodule:: websockets -Platforms-as-a-Service +Architecture decisions ---------------------- +When you deploy your websockets server to production, at a high level, your +architecture will almost certainly look like the following diagram: + +.. image:: architecture.svg + +The basic unit for scaling a websockets server is "one server process". Each +blue box in the diagram represents one server process. + +There's more variation in routing. While the routing layer is shown as one big +box, it is likely to involve several subsystems. + +As a consequence, when you design a deployment, you must answer two questions: + +1. How will I run the appropriate number of server processes? +2. How will I route incoming connections to these processes? + +These questions are interrelated. There's a wide range of valid answers, +depending on your goals and your constraints. + +Platforms-as-a-Service +...................... + +Platforms-as-a-Service are the easiest option. They provide end-to-end, +integrated solutions and they require little configuration. + +Here's how to deploy on some popular PaaS providers. Since all PaaS use +similar patterns, the concepts translate to other providers. + .. toctree:: :titlesonly: @@ -14,8 +42,13 @@ Platforms-as-a-Service fly heroku -Self-hosted ------------ +Self-hosted infrastructure +.......................... + +If you need more control over your infrastructure, you can deploy on your own +infrastructure. This requires more configuration. + +Here's how to configure some components mentioned in this guide. .. toctree:: :titlesonly: @@ -24,3 +57,160 @@ Self-hosted supervisor nginx haproxy + +Running server processes +------------------------ + +How many processes do I need? +............................. + +Typically, one server process will manage a few hundreds or thousands +connections, depending on the frequency of messages and the amount of work +they require. + +CPU and memory usage increase with the number of connections to the server. + +Often CPU is the limiting factor. If a server process goes to 100% CPU, then +you reached the limit. How much headroom you want to keep is up to you. + +Once you know how many connections a server process can manage and how many +connections you need to handle, you can calculate how many processes to run. + +You can also automate this calculation by configuring an autoscaler to keep +CPU usage or connection count within acceptable limits. + +.. admonition:: Don't scale with threads. Scale only with processes. + :class: tip + + Threads don't make sense for a server built with :mod:`asyncio`. + +How do I run processes? +....................... + +Most solutions for running multiple instances of a server process fall into +one of these three buckets: + +1. Running N processes on a platform: + + * a Kubernetes Deployment + + * its equivalent on a Platform as a Service provider + +2. Running N servers: + + * an AWS Auto Scaling group, a GCP Managed instance group, etc. + + * a fixed set of long-lived servers + +3. Running N processes on a server: + + * preferably via a process manager or supervisor + +Option 1 is easiest if you have access to such a platform. Option 2 usually +combines with option 3. + +How do I start a process? +......................... + +Run a Python program that invokes :func:`~asyncio.server.serve` or +:func:`~asyncio.router.route`. That's it! + +Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're +alternatives to websockets, not complements. + +Don't run a WSGI server such as Gunicorn, Waitress, or mod_wsgi. They aren't +designed to run WebSocket applications. + +Applications servers handle network connections and expose a Python API. You +don't need one because websockets handles network connections directly. + +How do I stop a process? +........................ + +Process managers send the SIGTERM signal to terminate processes. Catch this +signal and exit the server to ensure a graceful shutdown. + +Here's an example: + +.. literalinclude:: ../../example/faq/shutdown_server.py + :emphasize-lines: 14-16 + +When exiting the context manager, :func:`~asyncio.server.serve` closes all +connections with code 1001 (going away). As a consequence: + +* If the connection handler is awaiting + :meth:`~asyncio.server.ServerConnection.recv`, it receives a + :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception + and clean up before exiting. + +* Otherwise, it should be waiting on + :meth:`~asyncio.server.ServerConnection.wait_closed`, so it can receive the + :exc:`~exceptions.ConnectionClosedOK` exception and exit. + +This example is easily adapted to handle other signals. + +If you override the default signal handler for SIGINT, which raises +:exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a +program with Ctrl-C anymore when it's stuck in a loop. + +Routing connections +------------------- + +What does routing involve? +.......................... + +Since the routing layer is directly exposed to the Internet, it should provide +appropriate protection against threats ranging from Internet background noise +to targeted attacks. + +You should always secure WebSocket connections with TLS. Since the routing +layer carries the public domain name, it should terminate TLS connections. + +Finally, it must route connections to the server processes, balancing new +connections across them. + +How do I route connections? +........................... + +Here are typical solutions for load balancing, matched to ways of running +processes: + +1. If you're running on a platform, it comes with a routing layer: + + * a Kubernetes Ingress and Service + + * a service mesh: Istio, Consul, Linkerd, etc. + + * the routing mesh of a Platform as a Service + +2. If you're running N servers, you may load balance with: + + * a cloud load balancer: AWS Elastic Load Balancing, GCP Cloud Load + Balancing, etc. + + * A software load balancer: HAProxy, NGINX, etc. + +3. If you're running N processes on a server, you may load balance with: + + * A software load balancer: HAProxy, NGINX, etc. + + * The operating system — all processes listen on the same port + +You may trust the load balancer to handle encryption and to provide security. +You may add another layer in front of the load balancer for these purposes. + +There are many possibilities. Don't add layers that you don't need, though. + +How do I implement a health check? +.................................. + +Load balancers need a way to check whether server processes are up and running +to avoid routing connections to a non-functional backend. + +websockets provide minimal support for responding to HTTP requests with the +``process_request`` hook. + +Here's an example: + +.. literalinclude:: ../../example/faq/health_check_server.py + :emphasize-lines: 7-9,16 diff --git a/docs/howto/index.rst b/docs/howto/index.rst index deae44942..c7859f3cd 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -39,8 +39,6 @@ features, which websockets supports fully. extensions -.. _deployment-howto: - If you're integrating the Sans-I/O layer of websockets into a library, rather than building an application with websockets, follow this guide. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1980aafa3..4a7dcd5ab 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -51,6 +51,7 @@ middleware mutex mypy nginx +PaaS Paketo permessage pid diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst deleted file mode 100644 index 00d1f9285..000000000 --- a/docs/topics/deployment.rst +++ /dev/null @@ -1,181 +0,0 @@ -Deployment -========== - -.. currentmodule:: websockets - -When you deploy your websockets server to production, at a high level, your -architecture will almost certainly look like the following diagram: - -.. image:: deployment.svg - -The basic unit for scaling a websockets server is "one server process". Each -blue box in the diagram represents one server process. - -There's more variation in routing. While the routing layer is shown as one big -box, it is likely to involve several subsystems. - -When you design a deployment, your should consider two questions: - -1. How will I run the appropriate number of server processes? -2. How will I route incoming connections to these processes? - -These questions are strongly related. There's a wide range of acceptable -answers, depending on your goals and your constraints. - -You can find a few concrete examples in the :ref:`deployment how-to guides -`. - -Running server processes ------------------------- - -How many processes do I need? -............................. - -Typically, one server process will manage a few hundreds or thousands -connections, depending on the frequency of messages and the amount of work -they require. - -CPU and memory usage increase with the number of connections to the server. - -Often CPU is the limiting factor. If a server process goes to 100% CPU, then -you reached the limit. How much headroom you want to keep is up to you. - -Once you know how many connections a server process can manage and how many -connections you need to handle, you can calculate how many processes to run. - -You can also automate this calculation by configuring an autoscaler to keep -CPU usage or connection count within acceptable limits. - -Don't scale with threads. Threads doesn't make sense for a server built with -:mod:`asyncio`. - -How do I run processes? -....................... - -Most solutions for running multiple instances of a server process fall into -one of these three buckets: - -1. Running N processes on a platform: - - * a Kubernetes Deployment - - * its equivalent on a Platform as a Service provider - -2. Running N servers: - - * an AWS Auto Scaling group, a GCP Managed instance group, etc. - - * a fixed set of long-lived servers - -3. Running N processes on a server: - - * preferably via a process manager or supervisor - -Option 1 is easiest of you have access to such a platform. - -Option 2 almost always combines with option 3. - -How do I start a process? -......................... - -Run a Python program that invokes :func:`~asyncio.server.serve`. That's it. - -Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're -alternatives to websockets, not complements. - -Don't run a WSGI server such as Gunicorn, Waitress, or mod_wsgi. They aren't -designed to run WebSocket applications. - -Applications servers handle network connections and expose a Python API. You -don't need one because websockets handles network connections directly. - -How do I stop a process? -........................ - -Process managers send the SIGTERM signal to terminate processes. Catch this -signal and exit the server to ensure a graceful shutdown. - -Here's an example: - -.. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 13-16,19 - -When exiting the context manager, :func:`~asyncio.server.serve` closes all -connections with code 1001 (going away). As a consequence: - -* If the connection handler is awaiting - :meth:`~asyncio.server.ServerConnection.recv`, it receives a - :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception - and clean up before exiting. - -* Otherwise, it should be waiting on - :meth:`~asyncio.server.ServerConnection.wait_closed`, so it can receive the - :exc:`~exceptions.ConnectionClosedOK` exception and exit. - -This example is easily adapted to handle other signals. - -If you override the default signal handler for SIGINT, which raises -:exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a -program with Ctrl-C anymore when it's stuck in a loop. - -Routing connections -------------------- - -What does routing involve? -.......................... - -Since the routing layer is directly exposed to the Internet, it should provide -appropriate protection against threats ranging from Internet background noise -to targeted attacks. - -You should always secure WebSocket connections with TLS. Since the routing -layer carries the public domain name, it should terminate TLS connections. - -Finally, it must route connections to the server processes, balancing new -connections across them. - -How do I route connections? -........................... - -Here are typical solutions for load balancing, matched to ways of running -processes: - -1. If you're running on a platform, it comes with a routing layer: - - * a Kubernetes Ingress and Service - - * a service mesh: Istio, Consul, Linkerd, etc. - - * the routing mesh of a Platform as a Service - -2. If you're running N servers, you may load balance with: - - * a cloud load balancer: AWS Elastic Load Balancing, GCP Cloud Load - Balancing, etc. - - * A software load balancer: HAProxy, NGINX, etc. - -3. If you're running N processes on a server, you may load balance with: - - * A software load balancer: HAProxy, NGINX, etc. - - * The operating system — all processes listen on the same port - -You may trust the load balancer to handle encryption and to provide security. -You may add another layer in front of the load balancer for these purposes. - -There are many possibilities. Don't add layers that you don't need, though. - -How do I implement a health check? -.................................. - -Load balancers need a way to check whether server processes are up and running -to avoid routing connections to a non-functional backend. - -websockets provide minimal support for responding to HTTP requests with the -``process_request`` hook. - -Here's an example: - -.. literalinclude:: ../../example/faq/health_check_server.py - :emphasize-lines: 7-9,18 diff --git a/docs/topics/index.rst b/docs/topics/index.rst index a08d487c9..ebffff71f 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -6,7 +6,6 @@ Get a deeper understanding of how websockets is built and why. .. toctree:: :titlesonly: - deployment logging authentication broadcast From 89b037fbf7aca8af23ff68f773615b242ff75ae7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 10:53:52 +0100 Subject: [PATCH 1528/1539] Remove content added by mistake in 4b9caad7. --- docs/reference/sync/server.rst | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index f6a45a659..59dde9b35 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -23,18 +23,6 @@ Routing connections .. currentmodule:: websockets.sync.server -Routing connections -------------------- - -.. autofunction:: route - :async: - -.. autofunction:: unix_route - :async: - -.. autoclass:: Server - - Running a server ---------------- From 74a7ac20a7009d68cd4a38705d8aac969a17e78a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 13:27:12 +0100 Subject: [PATCH 1529/1539] Add topic guide on routing. --- docs/deploy/index.rst | 8 ++-- docs/faq/server.rst | 22 +--------- docs/project/changelog.rst | 7 +++- docs/topics/index.rst | 3 +- docs/topics/routing.rst | 83 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 27 deletions(-) create mode 100644 docs/topics/routing.rst diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index a6432d835..2bdab9464 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -14,8 +14,8 @@ architecture will almost certainly look like the following diagram: The basic unit for scaling a websockets server is "one server process". Each blue box in the diagram represents one server process. -There's more variation in routing. While the routing layer is shown as one big -box, it is likely to involve several subsystems. +There's more variation in routing connections to processes. While the routing +layer is shown as one big box, it is likely to involve several subsystems. As a consequence, when you design a deployment, you must answer two questions: @@ -153,8 +153,8 @@ If you override the default signal handler for SIGINT, which raises :exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a program with Ctrl-C anymore when it's stuck in a loop. -Routing connections -------------------- +Routing connections to processes +-------------------------------- What does routing involve? .......................... diff --git a/docs/faq/server.rst b/docs/faq/server.rst index bb04c5e1c..10b041095 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -208,26 +208,8 @@ How do I access the request path? It is available in the :attr:`~ServerConnection.request` object. -You may route a connection to different handlers depending on the request path:: - - async def handler(websocket): - if websocket.request.path == "/blue": - await blue_handler(websocket) - elif websocket.request.path == "/green": - await green_handler(websocket) - else: - # No handler for this path; close the connection. - return - -For more complex routing, you may use :func:`~websockets.asyncio.router.route`. - -You may also route the connection based on the first message received from the -client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to -authenticate the connection before routing it, this is usually more convenient. - -Generally speaking, there is far less emphasis on the request path in WebSocket -servers than in HTTP servers. When a WebSocket server provides a single endpoint, -it may ignore the request path entirely. +Refer to the :doc:`routing guide <../topics/routing>` for details on how to +route connections to different handlers depending on the request path. How do I access HTTP headers? ----------------------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index d7db6167a..31a537af5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,7 +43,9 @@ Backwards-incompatible changes SOCKS proxies require installing the third-party library `python-socks`_. If you want to disable the proxy, add ``proxy=None`` when calling - :func:`~asyncio.client.connect`. See :doc:`../topics/proxies` for details. + :func:`~asyncio.client.connect`. + + See :doc:`proxies <../topics/proxies>` for details. .. _python-socks: https://github.com/romis2012/python-socks @@ -60,7 +62,8 @@ New features ............ * Added :func:`~asyncio.router.route` and :func:`~asyncio.router.unix_route` to - dispatch connections to different handlers depending on the URL. + dispatch connections to handlers based on the request path. Read more about + routing in :doc:`routing <../topics/routing>`. Improvements ............ diff --git a/docs/topics/index.rst b/docs/topics/index.rst index ebffff71f..4273599b7 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -6,12 +6,13 @@ Get a deeper understanding of how websockets is built and why. .. toctree:: :titlesonly: - logging authentication broadcast compression keepalive + logging memory security performance proxies + routing diff --git a/docs/topics/routing.rst b/docs/topics/routing.rst new file mode 100644 index 000000000..64c675e74 --- /dev/null +++ b/docs/topics/routing.rst @@ -0,0 +1,83 @@ +Routing +======= + +.. currentmodule:: websockets + +Many WebSocket servers provide just one endpoint. That's why +:func:`~asyncio.server.serve` accepts a single connection handler as its first +argument. + +This may come as a surprise to you if you're used to HTTP servers. In a standard +HTTP application, each request gets dispatched to a handler based on the request +path. Clients know which path to use for which operation. + +In a WebSocket application, clients open a persistent connection then they send +all messages over that unique connection. When different messages correspond to +different operations, they must be dispatched based on the message content. + +Simple routing +-------------- + +If you need different handlers for different clients or different use cases, you +may route each connection to the right handler based on the request path. + +Since WebSocket servers typically provide fewer routes than HTTP servers, you +can keep it simple:: + + async def handler(websocket): + match websocket.request.path: + case "/blue": + await blue_handler(websocket) + case "/green": + await green_handler(websocket) + case _: + # No handler for this path. Close the connection. + return + +You may also route connections based on the first message received from the +client, as demonstrated in the :doc:`tutorial <../intro/tutorial2>`:: + + import json + + async def handler(websocket): + message = await websocket.recv() + settings = json.loads(message) + match settings["color"]: + case "blue": + await blue_handler(websocket) + case "green": + await green_handler(websocket) + case _: + # No handler for this message. Close the connection. + return + +When you need to authenticate the connection before routing it, this pattern is +more convenient. + +Complex routing +--------------- + +If you have outgrow these simple patterns, websockets provides full-fledged +routing based on the request path with :func:`~asyncio.router.route`. + +This feature builds upon Flask_'s router. To use it, you must install the +third-party library `werkzeug`_:: + + $ pip install werkzeug + +.. _Flask: https://flask.palletsprojects.com/ +.. _werkzeug: https://werkzeug.palletsprojects.com/ + +:func:`~asyncio.router.route` expects a :class:`werkzeug.routing.Map` as its +first argument to declare which URL patterns map to which handlers. Review the +documentation of :mod:`werkzeug.routing` to learn about its functionality. + +To give you a sense of what's possible, here's the URL map of the example in +`example/routing.py`_: + +.. _example/routing.py: https://github.com/python-websockets/websockets/blob/main/example/routing.py + +.. literalinclude:: ../../example/routing.py + :language: python + :start-at: url_map = Map( + :end-at: await server.serve_forever() From 7bfb1140e1c19aa3b59ecf2b08bde56a82cfe04a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 13:48:51 +0100 Subject: [PATCH 1530/1539] Add standalone guide to enable debug logs. Pointing to this information is frequently needed in the issue tracker. --- docs/faq/common.rst | 32 ++------------------------------ docs/howto/autoreload.rst | 7 +++---- docs/howto/debugging.rst | 33 +++++++++++++++++++++++++++++++++ docs/howto/index.rst | 29 +++++++++++++---------------- docs/intro/tutorial1.rst | 5 ++++- 5 files changed, 55 insertions(+), 51 deletions(-) create mode 100644 docs/howto/debugging.rst diff --git a/docs/faq/common.rst b/docs/faq/common.rst index b31d4504f..1ee0062af 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -3,33 +3,6 @@ Both sides .. currentmodule:: websockets.asyncio.connection -.. _enable-debug-logs: - -How do I enable debug logs? ---------------------------- - -You can enable debug logs to see exactly what websockets is doing. - -If logging isn't configured in your application:: - - import logging - - logging.basicConfig( - format="%(asctime)s %(message)s", - level=logging.DEBUG, - ) - -If logging is already configured:: - - import logging - - logger = logging.getLogger("websockets") - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler()) - -Refer to the :doc:`logging documentation <../topics/logging>` for more details -on logging in websockets. - What does ``ConnectionClosedError: no close frame received or sent`` mean? -------------------------------------------------------------------------- @@ -66,9 +39,8 @@ There are several reasons why long-lived connections may be lost: connections may terminate connections after a short amount of time, usually 30 seconds, despite websockets' keepalive mechanism. -If you're facing a reproducible issue, :ref:`enable debug logs -` to see when and how connections are closed. connections are -closed. +If you're facing a reproducible issue, :doc:`enable debug logs +<../howto/debugging>` to see when and how connections are closed. What does ``ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received`` mean? --------------------------------------------------------------------------------------------------------------------- diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst index fc736a591..4d1adee8b 100644 --- a/docs/howto/autoreload.rst +++ b/docs/howto/autoreload.rst @@ -7,8 +7,7 @@ stop the server and restart it, which slows down your development process. Web frameworks such as Django or Flask provide a development server that reloads the application automatically when you make code changes. There is no -such functionality in websockets because it's designed for production rather -than development. +such functionality in websockets because it's designed only for production. However, you can achieve the same result easily. @@ -27,5 +26,5 @@ Run your server with ``watchmedo auto-restart``: $ watchmedo auto-restart --pattern "*.py" --recursive --signal SIGTERM \ python app.py -This example assumes that the server is defined in a script called ``app.py``. -Adapt it as necessary. +This example assumes that the server is defined in a script called ``app.py`` +and exits cleanly when receiving the ``SIGTERM`` signal. Adapt it as necessary. diff --git a/docs/howto/debugging.rst b/docs/howto/debugging.rst new file mode 100644 index 000000000..fc5dcba56 --- /dev/null +++ b/docs/howto/debugging.rst @@ -0,0 +1,33 @@ +Enable debug logs +================== + +websockets logs events with the :mod:`logging` module from the standard library. + +It writes to the ``"websockets.server"`` and ``"websockets.client"`` loggers. + +Enable logs at the ``DEBUG`` level to see exactly what websockets is doing. + +If logging isn't configured in your application:: + + import logging + + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.DEBUG, + ) + +If logging is already configured:: + + import logging + + logger = logging.getLogger("websockets") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + +Refer to the :doc:`logging guide <../topics/logging>` for more details on +logging in websockets. + +In addition, you may enable asyncio's `debug mode`_ to see what asyncio is +doing. + +.. _debug mode: https://docs.python.org/3/library/asyncio-dev.html#asyncio-debug-mode diff --git a/docs/howto/index.rst b/docs/howto/index.rst index c7859f3cd..8a9717ed1 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -4,40 +4,30 @@ How-to guides In a hurry? Check out these examples. .. toctree:: - :titlesonly: quickstart - -Upgrading from the legacy :mod:`asyncio` implementation to the new one? -Read this. - -.. toctree:: - :titlesonly: - - upgrade + debugging + autoreload If you're stuck, perhaps you'll find the answer here. .. toctree:: - :titlesonly: patterns - autoreload This guide will help you integrate websockets into a broader system. .. toctree:: - :titlesonly: django -The WebSocket protocol makes provisions for extending or specializing its -features, which websockets supports fully. +Upgrading from the legacy :mod:`asyncio` implementation to the new one? +Read this. .. toctree:: - :titlesonly: + :maxdepth: 2 - extensions + upgrade If you're integrating the Sans-I/O layer of websockets into a library, rather than building an application with websockets, follow this guide. @@ -46,3 +36,10 @@ than building an application with websockets, follow this guide. :maxdepth: 2 sansio + +The WebSocket protocol makes provisions for extending or specializing its +features, which websockets supports fully. + +.. toctree:: + + extensions diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index cc88e6a4a..39e693aae 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -545,7 +545,10 @@ taking alternate turns. import logging - logging.basicConfig(format="%(message)s", level=logging.DEBUG) + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.DEBUG, + ) If you're stuck, a solution is available at the bottom of this document. From e60e04cdc755fe41e8057b23842e87cabacdc694 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 14:12:00 +0100 Subject: [PATCH 1531/1539] Unify examples for topic guides under experiments. --- docs/topics/logging.rst | 2 +- docs/topics/routing.rst | 6 +++--- {example/logging => experiments}/json_log_formatter.py | 0 {example => experiments}/routing.py | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename {example/logging => experiments}/json_log_formatter.py (100%) rename {example => experiments}/routing.py (100%) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 03c5bef5f..2eedd32a4 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -146,7 +146,7 @@ output logs as JSON with a bit of effort. First, we need a :class:`~logging.Formatter` that renders JSON: -.. literalinclude:: ../../example/logging/json_log_formatter.py +.. literalinclude:: ../../experiments/json_log_formatter.py Then, we configure logging to apply this formatter:: diff --git a/docs/topics/routing.rst b/docs/topics/routing.rst index 64c675e74..d1790532a 100644 --- a/docs/topics/routing.rst +++ b/docs/topics/routing.rst @@ -73,11 +73,11 @@ first argument to declare which URL patterns map to which handlers. Review the documentation of :mod:`werkzeug.routing` to learn about its functionality. To give you a sense of what's possible, here's the URL map of the example in -`example/routing.py`_: +`experiments/routing.py`_: -.. _example/routing.py: https://github.com/python-websockets/websockets/blob/main/example/routing.py +.. _experiments/routing.py: https://github.com/python-websockets/websockets/blob/main/experiments/routing.py -.. literalinclude:: ../../example/routing.py +.. literalinclude:: ../../experiments/routing.py :language: python :start-at: url_map = Map( :end-at: await server.serve_forever() diff --git a/example/logging/json_log_formatter.py b/experiments/json_log_formatter.py similarity index 100% rename from example/logging/json_log_formatter.py rename to experiments/json_log_formatter.py diff --git a/example/routing.py b/experiments/routing.py similarity index 100% rename from example/routing.py rename to experiments/routing.py From 82dfd83cee4814f27867c909da785ebe3228cbf8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 14:16:39 +0100 Subject: [PATCH 1532/1539] Improve index page for topic guides. --- docs/topics/authentication.rst | 4 ++-- docs/topics/index.rst | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index e2de4332e..7c022066f 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -313,8 +313,8 @@ To authenticate a websockets client with HTTP Basic Authentication async with connect(f"wss://{username}:{password}@.../") as websocket: ... -(You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they -contain unsafe characters.) +You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they +contain unsafe characters. To authenticate a websockets client with HTTP Bearer Authentication (:rfc:`6750`), add a suitable ``Authorization`` header: diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 4273599b7..ca5d83c97 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -1,18 +1,26 @@ Topic guides ============ -Get a deeper understanding of how websockets is built and why. +These documents discuss how websockets is designed and how to make the best of +its features when building applications. .. toctree:: - :titlesonly: + :maxdepth: 2 authentication broadcast + logging + proxies + routing + +These guides describe how to optimize the configuration of websockets +applications for performance and reliability. + +.. toctree:: + :maxdepth: 2 + compression keepalive - logging memory security performance - proxies - routing From 61a6a7a99534e2522773476fc72ca615f5a6ea97 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 21:48:27 +0100 Subject: [PATCH 1533/1539] Clean up quick start guides. * Move most examples back to the getting started section. * Add a dedicated howto on encryption. And some drive-by fixes. Refs #1209. --- docs/deploy/koyeb.rst | 1 - docs/howto/encryption.rst | 65 +++++++ docs/howto/index.rst | 12 +- docs/howto/quickstart.rst | 170 ------------------ docs/intro/examples.rst | 112 ++++++++++++ docs/intro/index.rst | 9 +- docs/reference/legacy/server.rst | 1 - docs/topics/compression.rst | 1 - docs/topics/proxies.rst | 4 +- docs/topics/routing.rst | 5 +- docs/topics/security.rst | 9 +- example/quick/client.py | 17 ++ example/{quickstart => quick}/counter.css | 0 example/{quickstart => quick}/counter.html | 0 example/{quickstart => quick}/counter.js | 0 example/{quickstart => quick}/counter.py | 1 + example/{quickstart => quick}/server.py | 0 example/{quickstart => quick}/show_time.html | 0 example/{quickstart => quick}/show_time.js | 0 example/{quickstart => quick}/show_time.py | 2 +- example/quick/sync_time.py | 23 +++ example/quickstart/client.py | 19 -- example/quickstart/show_time_2.py | 29 --- .../client_secure.py => tls/client.py} | 13 +- example/{quickstart => tls}/localhost.pem | 0 .../server_secure.py => tls/server.py} | 0 src/websockets/asyncio/router.py | 4 +- src/websockets/sync/router.py | 4 +- 28 files changed, 258 insertions(+), 243 deletions(-) create mode 100644 docs/howto/encryption.rst delete mode 100644 docs/howto/quickstart.rst create mode 100644 docs/intro/examples.rst create mode 100755 example/quick/client.py rename example/{quickstart => quick}/counter.css (100%) rename example/{quickstart => quick}/counter.html (100%) rename example/{quickstart => quick}/counter.js (100%) rename example/{quickstart => quick}/counter.py (99%) rename example/{quickstart => quick}/server.py (100%) rename example/{quickstart => quick}/show_time.html (100%) rename example/{quickstart => quick}/show_time.js (100%) rename example/{quickstart => quick}/show_time.py (84%) create mode 100755 example/quick/sync_time.py delete mode 100755 example/quickstart/client.py delete mode 100755 example/quickstart/show_time_2.py rename example/{quickstart/client_secure.py => tls/client.py} (61%) rename example/{quickstart => tls}/localhost.pem (100%) rename example/{quickstart/server_secure.py => tls/server.py} (100%) diff --git a/docs/deploy/koyeb.rst b/docs/deploy/koyeb.rst index 0ad126dd8..0b7c96cb9 100644 --- a/docs/deploy/koyeb.rst +++ b/docs/deploy/koyeb.rst @@ -139,7 +139,6 @@ or press Ctrl-D to terminate the connection: < Hello! Connection closed: 1000 (OK). - You can also confirm that your application shuts down gracefully. Connect an interactive client again: diff --git a/docs/howto/encryption.rst b/docs/howto/encryption.rst new file mode 100644 index 000000000..af19fefd0 --- /dev/null +++ b/docs/howto/encryption.rst @@ -0,0 +1,65 @@ +Encrypt connections +==================== + +.. currentmodule:: websockets + +You should always secure WebSocket connections with TLS_ (Transport Layer +Security). + +.. admonition:: TLS vs. SSL + :class: tip + + TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an + earlier encryption protocol; the name stuck. + +The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. + +Secure WebSocket connections require certificates just like HTTPS. + +.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security + +.. admonition:: Configure the TLS context securely + :class: attention + + The examples below demonstrate the ``ssl`` argument with a TLS certificate + shared between the client and the server. This is a simplistic setup. + + Please review the advice and security considerations in the documentation of + the :mod:`ssl` module to configure the TLS context appropriately. + +Servers +------- + +In a typical :doc:`deployment <../deploy/index>`, the server is behind a reverse +proxy that terminates TLS. The client connects to the reverse proxy with TLS and +the reverse proxy connects to the server without TLS. + +In that case, you don't need to configure TLS in websockets. + +If needed in your setup, you can terminate TLS in the server. + +In the example below, :func:`~asyncio.server.serve` is configured to receive +secure connections. Before running this server, download +:download:`localhost.pem <../../example/tls/localhost.pem>` and save it in the +same directory as ``server.py``. + +.. literalinclude:: ../../example/tls/server.py + :caption: server.py + +Receive both plain and TLS connections on the same port isn't supported. + +Clients +------- + +:func:`~asyncio.client.connect` enables TLS automatically when connecting to a +``wss://...`` URI. + +This works out of the box when the TLS certificate of the server is valid, +meaning it's signed by a certificate authority that your Python installation +trusts. + +In the example above, since the server uses a self-signed certificate, the +client needs to be configured to trust the certificate. Here's how to do so. + +.. literalinclude:: ../../example/tls/client.py + :caption: client.py diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 8a9717ed1..ffded9ff0 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -1,13 +1,18 @@ How-to guides ============= -In a hurry? Check out these examples. +Set up your development environment comfortably. .. toctree:: - quickstart - debugging autoreload + debugging + +Configure websockets securely in production. + +.. toctree:: + + encryption If you're stuck, perhaps you'll find the answer here. @@ -25,7 +30,6 @@ Upgrading from the legacy :mod:`asyncio` implementation to the new one? Read this. .. toctree:: - :maxdepth: 2 upgrade diff --git a/docs/howto/quickstart.rst b/docs/howto/quickstart.rst deleted file mode 100644 index e6bd362a4..000000000 --- a/docs/howto/quickstart.rst +++ /dev/null @@ -1,170 +0,0 @@ -Quick start -=========== - -.. currentmodule:: websockets - -Here are a few examples to get you started quickly with websockets. - -Say "Hello world!" ------------------- - -Here's a WebSocket server. - -It receives a name from the client, sends a greeting, and closes the connection. - -.. literalinclude:: ../../example/quickstart/server.py - :caption: server.py - :language: python - :linenos: - -:func:`~asyncio.server.serve` executes the connection handler coroutine -``hello()`` once for each WebSocket connection. It closes the WebSocket -connection when the handler returns. - -Here's a corresponding WebSocket client. - -It sends a name to the server, receives a greeting, and closes the connection. - -.. literalinclude:: ../../example/quickstart/client.py - :caption: client.py - :language: python - :linenos: - -Using :func:`~asyncio.client.connect` as an asynchronous context manager ensures -the WebSocket connection is closed. - -.. _secure-server-example: - -Encrypt connections -------------------- - -Secure WebSocket connections improve confidentiality and also reliability -because they reduce the risk of interference by bad proxies. - -The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. The -connection is encrypted with TLS_ (Transport Layer Security). ``wss`` -requires certificates like ``https``. - -.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security - -.. admonition:: TLS vs. SSL - :class: tip - - TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an - earlier encryption protocol; the name stuck. - -Here's how to adapt the server to encrypt connections. You must download -:download:`localhost.pem <../../example/quickstart/localhost.pem>` and save it -in the same directory as ``server_secure.py``. - -.. literalinclude:: ../../example/quickstart/server_secure.py - :caption: server_secure.py - :language: python - :linenos: - -Here's how to adapt the client similarly. - -.. literalinclude:: ../../example/quickstart/client_secure.py - :caption: client_secure.py - :language: python - :linenos: - -In this example, the client needs a TLS context because the server uses a -self-signed certificate. - -When connecting to a secure WebSocket server with a valid certificate — any -certificate signed by a CA that your Python installation trusts — you can simply -pass ``ssl=True`` to :func:`~asyncio.client.connect`. - -.. admonition:: Configure the TLS context securely - :class: attention - - This example demonstrates the ``ssl`` argument with a TLS certificate shared - between the client and the server. This is a simplistic setup. - - Please review the advice and security considerations in the documentation of - the :mod:`ssl` module to configure the TLS context securely. - -Connect from a browser ----------------------- - -The WebSocket protocol was invented for the web — as the name says! - -Here's how to connect to a WebSocket server from a browser. - -Run this script in a console: - -.. literalinclude:: ../../example/quickstart/show_time.py - :caption: show_time.py - :language: python - :linenos: - -Save this file as ``show_time.html``: - -.. literalinclude:: ../../example/quickstart/show_time.html - :caption: show_time.html - :language: html - :linenos: - -Save this file as ``show_time.js``: - -.. literalinclude:: ../../example/quickstart/show_time.js - :caption: show_time.js - :language: js - :linenos: - -Then, open ``show_time.html`` in several browsers. Clocks tick irregularly. - -Broadcast messages ------------------- - -Let's change the previous example to send the same timestamps to all browsers, -instead of generating independent sequences for each client. - -Stop the previous script if it's still running and run this script in a console: - -.. literalinclude:: ../../example/quickstart/show_time_2.py - :caption: show_time_2.py - :language: python - :linenos: - -Refresh ``show_time.html`` in all browsers. Clocks tick in sync. - -Manage application state ------------------------- - -A WebSocket server can receive events from clients, process them to update the -application state, and broadcast the updated state to all connected clients. - -Here's an example where any client can increment or decrement a counter. The -concurrency model of :mod:`asyncio` guarantees that updates are serialized. - -Run this script in a console: - -.. literalinclude:: ../../example/quickstart/counter.py - :caption: counter.py - :language: python - :linenos: - -Save this file as ``counter.html``: - -.. literalinclude:: ../../example/quickstart/counter.html - :caption: counter.html - :language: html - :linenos: - -Save this file as ``counter.css``: - -.. literalinclude:: ../../example/quickstart/counter.css - :caption: counter.css - :language: css - :linenos: - -Save this file as ``counter.js``: - -.. literalinclude:: ../../example/quickstart/counter.js - :caption: counter.js - :language: js - :linenos: - -Then open ``counter.html`` file in several browsers and play with [+] and [-]. diff --git a/docs/intro/examples.rst b/docs/intro/examples.rst new file mode 100644 index 000000000..341712475 --- /dev/null +++ b/docs/intro/examples.rst @@ -0,0 +1,112 @@ +Quick examples +============== + +.. currentmodule:: websockets + +Start a server +-------------- + +This WebSocket server receives a name from the client, sends a greeting, and +closes the connection. + +.. literalinclude:: ../../example/quick/server.py + :caption: server.py + :language: python + +:func:`~asyncio.server.serve` executes the connection handler coroutine +``hello()`` once for each WebSocket connection. It closes the WebSocket +connection when the handler returns. + +Connect a client +---------------- + +This WebSocket client sends a name to the server, receives a greeting, and +closes the connection. + +.. literalinclude:: ../../example/quick/client.py + :caption: client.py + :language: python + +Using :func:`~sync.client.connect` as a context manager ensures that the +WebSocket connection is closed. + +Connect a browser +----------------- + +The WebSocket protocol was invented for the web — as the name says! + +Here's how to connect a browser to a WebSocket server. + +Run this script in a console: + +.. literalinclude:: ../../example/quick/show_time.py + :caption: show_time.py + :language: python + +Save this file as ``show_time.html``: + +.. literalinclude:: ../../example/quick/show_time.html + :caption: show_time.html + :language: html + +Save this file as ``show_time.js``: + +.. literalinclude:: ../../example/quick/show_time.js + :caption: show_time.js + :language: js + +Then, open ``show_time.html`` in several browsers or tabs. Clocks tick +irregularly. + +Broadcast messages +------------------ + +Let's send the same timestamps to everyone instead of generating independent +sequences for each connection. + +Stop the previous script if it's still running and run this script in a console: + +.. literalinclude:: ../../example/quick/sync_time.py + :caption: sync_time.py + :language: python + +Refresh ``show_time.html`` in all browsers or tabs. Clocks tick in sync. + +Manage application state +------------------------ + +A WebSocket server can receive events from clients, process them to update the +application state, and broadcast the updated state to all connected clients. + +Here's an example where any client can increment or decrement a counter. The +concurrency model of :mod:`asyncio` guarantees that updates are serialized. + +This example keep tracks of connected users explicitly in ``USERS`` instead of +relying on :attr:`server.connections `. The +result is the same. + +Run this script in a console: + +.. literalinclude:: ../../example/quick/counter.py + :caption: counter.py + :language: python + +Save this file as ``counter.html``: + +.. literalinclude:: ../../example/quick/counter.html + :caption: counter.html + :language: html + +Save this file as ``counter.css``: + +.. literalinclude:: ../../example/quick/counter.css + :caption: counter.css + :language: css + +Save this file as ``counter.js``: + +.. literalinclude:: ../../example/quick/counter.js + :caption: counter.js + :language: js + +Then open ``counter.html`` file in several browsers and play with [+] and [-]. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 76994b6a2..d6f8fb9e0 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -35,7 +35,7 @@ Tutorial Learn how to build an real-time web application with websockets. .. toctree:: - :maxdepth: 1 + :maxdepth: 2 tutorial1 tutorial2 @@ -44,4 +44,9 @@ Learn how to build an real-time web application with websockets. In a hurry? ----------- -Look at the :doc:`quick start guide <../howto/quickstart>`. +These examples will get you started quickly with websockets. + +.. toctree:: + :maxdepth: 2 + + examples diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index 8636034e2..0ac84156d 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -94,7 +94,6 @@ Using a connection .. autoproperty:: close_reason - Broadcast --------- diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index 5f09bbf73..06bba0922 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -20,7 +20,6 @@ based on the Deflate_ algorithm specified in :rfc:`7692`. the reduction in network bandwidth is usually worth the additional memory and CPU cost. - Configuring compression ----------------------- diff --git a/docs/topics/proxies.rst b/docs/topics/proxies.rst index 14fc68c0c..a2536d4c0 100644 --- a/docs/topics/proxies.rst +++ b/docs/topics/proxies.rst @@ -55,7 +55,9 @@ SOCKS proxies ------------- Connecting through a SOCKS proxy requires installing the third-party library -`python-socks`_:: +`python-socks`_: + +.. code-block:: console $ pip install python-socks\[asyncio\] diff --git a/docs/topics/routing.rst b/docs/topics/routing.rst index d1790532a..44d89e00b 100644 --- a/docs/topics/routing.rst +++ b/docs/topics/routing.rst @@ -61,7 +61,9 @@ If you have outgrow these simple patterns, websockets provides full-fledged routing based on the request path with :func:`~asyncio.router.route`. This feature builds upon Flask_'s router. To use it, you must install the -third-party library `werkzeug`_:: +third-party library `werkzeug`_: + +.. code-block:: console $ pip install werkzeug @@ -78,6 +80,5 @@ To give you a sense of what's possible, here's the URL map of the example in .. _experiments/routing.py: https://github.com/python-websockets/websockets/blob/main/experiments/routing.py .. literalinclude:: ../../experiments/routing.py - :language: python :start-at: url_map = Map( :end-at: await server.serve_forever() diff --git a/docs/topics/security.rst b/docs/topics/security.rst index a22b752c7..e91f73b15 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -6,10 +6,13 @@ Security Encryption ---------- -For production use, a server should require encrypted connections. +In production, you should always secure WebSocket connections with TLS. -See this example of :ref:`encrypting connections with TLS -`. +Secure WebSocket connections provide confidentiality and integrity, as well as +better reliability because they reduce the risk of interference by bad proxies. + +WebSocket servers are usually deployed behind a reverse proxy that terminates +TLS. Else, you can :doc:`configure TLS <../howto/encryption>` for the server. Memory usage ------------ diff --git a/example/quick/client.py b/example/quick/client.py new file mode 100755 index 000000000..4f34c0628 --- /dev/null +++ b/example/quick/client.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +from websockets.sync.client import connect + +def hello(): + uri = "ws://localhost:8765" + with connect(uri) as websocket: + name = input("What's your name? ") + + websocket.send(name) + print(f">>> {name}") + + greeting = websocket.recv() + print(f"<<< {greeting}") + +if __name__ == "__main__": + hello() diff --git a/example/quickstart/counter.css b/example/quick/counter.css similarity index 100% rename from example/quickstart/counter.css rename to example/quick/counter.css diff --git a/example/quickstart/counter.html b/example/quick/counter.html similarity index 100% rename from example/quickstart/counter.html rename to example/quick/counter.html diff --git a/example/quickstart/counter.js b/example/quick/counter.js similarity index 100% rename from example/quickstart/counter.js rename to example/quick/counter.js diff --git a/example/quickstart/counter.py b/example/quick/counter.py similarity index 99% rename from example/quickstart/counter.py rename to example/quick/counter.py index 8f0ff81be..b31345ce2 100755 --- a/example/quickstart/counter.py +++ b/example/quick/counter.py @@ -3,6 +3,7 @@ import asyncio import json import logging + from websockets.asyncio.server import broadcast, serve logging.basicConfig() diff --git a/example/quickstart/server.py b/example/quick/server.py similarity index 100% rename from example/quickstart/server.py rename to example/quick/server.py diff --git a/example/quickstart/show_time.html b/example/quick/show_time.html similarity index 100% rename from example/quickstart/show_time.html rename to example/quick/show_time.html diff --git a/example/quickstart/show_time.js b/example/quick/show_time.js similarity index 100% rename from example/quickstart/show_time.js rename to example/quick/show_time.js diff --git a/example/quickstart/show_time.py b/example/quick/show_time.py similarity index 84% rename from example/quickstart/show_time.py rename to example/quick/show_time.py index ecb908f30..b56aada7b 100755 --- a/example/quickstart/show_time.py +++ b/example/quick/show_time.py @@ -8,7 +8,7 @@ async def show_time(websocket): while True: - message = datetime.datetime.utcnow().isoformat() + "Z" + message = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() await websocket.send(message) await asyncio.sleep(random.random() * 2 + 1) diff --git a/example/quick/sync_time.py b/example/quick/sync_time.py new file mode 100755 index 000000000..cdbe731af --- /dev/null +++ b/example/quick/sync_time.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import random + +from websockets.asyncio.server import broadcast, serve + +async def noop(websocket): + await websocket.wait_closed() + +async def show_time(server): + while True: + message = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + broadcast(server.connections, message) + await asyncio.sleep(random.random() * 2 + 1) + +async def main(): + async with serve(noop, "localhost", 5678) as server: + await show_time(server) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/quickstart/client.py b/example/quickstart/client.py deleted file mode 100755 index 934af69e3..000000000 --- a/example/quickstart/client.py +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python - -import asyncio - -from websockets.asyncio.client import connect - -async def hello(): - uri = "ws://localhost:8765" - async with connect(uri) as websocket: - name = input("What's your name? ") - - await websocket.send(name) - print(f">>> {name}") - - greeting = await websocket.recv() - print(f"<<< {greeting}") - -if __name__ == "__main__": - asyncio.run(hello()) diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py deleted file mode 100755 index 9c9659d14..000000000 --- a/example/quickstart/show_time_2.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import datetime -import random - -from websockets.asyncio.server import broadcast, serve - -CONNECTIONS = set() - -async def register(websocket): - CONNECTIONS.add(websocket) - try: - await websocket.wait_closed() - finally: - CONNECTIONS.remove(websocket) - -async def show_time(): - while True: - message = datetime.datetime.utcnow().isoformat() + "Z" - broadcast(CONNECTIONS, message) - await asyncio.sleep(random.random() * 2 + 1) - -async def main(): - async with serve(register, "localhost", 5678): - await show_time() - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/example/quickstart/client_secure.py b/example/tls/client.py similarity index 61% rename from example/quickstart/client_secure.py rename to example/tls/client.py index a1449587a..c97ccf8e4 100755 --- a/example/quickstart/client_secure.py +++ b/example/tls/client.py @@ -1,25 +1,24 @@ #!/usr/bin/env python -import asyncio import pathlib import ssl -from websockets.asyncio.client import connect +from websockets.sync.client import connect ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") ssl_context.load_verify_locations(localhost_pem) -async def hello(): +def hello(): uri = "wss://localhost:8765" - async with connect(uri, ssl=ssl_context) as websocket: + with connect(uri, ssl=ssl_context) as websocket: name = input("What's your name? ") - await websocket.send(name) + websocket.send(name) print(f">>> {name}") - greeting = await websocket.recv() + greeting = websocket.recv() print(f"<<< {greeting}") if __name__ == "__main__": - asyncio.run(hello()) + hello() diff --git a/example/quickstart/localhost.pem b/example/tls/localhost.pem similarity index 100% rename from example/quickstart/localhost.pem rename to example/tls/localhost.pem diff --git a/example/quickstart/server_secure.py b/example/tls/server.py similarity index 100% rename from example/quickstart/server_secure.py rename to example/tls/server.py diff --git a/src/websockets/asyncio/router.py b/src/websockets/asyncio/router.py index cd95022c1..047e7ef1c 100644 --- a/src/websockets/asyncio/router.py +++ b/src/websockets/asyncio/router.py @@ -81,7 +81,9 @@ def route( """ Create a WebSocket server dispatching connections to different handlers. - This feature requires the third-party library `werkzeug`_:: + This feature requires the third-party library `werkzeug`_: + + .. code-block:: console $ pip install werkzeug diff --git a/src/websockets/sync/router.py b/src/websockets/sync/router.py index 33105bf32..5572c4261 100644 --- a/src/websockets/sync/router.py +++ b/src/websockets/sync/router.py @@ -81,7 +81,9 @@ def route( """ Create a WebSocket server dispatching connections to different handlers. - This feature requires the third-party library `werkzeug`_:: + This feature requires the third-party library `werkzeug`_: + + .. code-block:: console $ pip install werkzeug From e934680c21b719bad631d8a6ee12bbc54c783601 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 22:20:05 +0100 Subject: [PATCH 1534/1539] Broaden type of extension parameters. --- src/websockets/extensions/base.py | 2 +- src/websockets/extensions/permessage_deflate.py | 2 +- src/websockets/headers.py | 2 +- src/websockets/typing.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 42dd6c5fa..2fdc59f0f 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -58,7 +58,7 @@ class ClientExtensionFactory: name: ExtensionName """Extension identifier.""" - def get_request_params(self) -> list[ExtensionParameter]: + def get_request_params(self) -> Sequence[ExtensionParameter]: """ Build parameters to send to the server for this extension. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 8e74cb282..7e9e7a5dd 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -351,7 +351,7 @@ def __init__( self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings - def get_request_params(self) -> list[ExtensionParameter]: + def get_request_params(self) -> Sequence[ExtensionParameter]: """ Build request parameters. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index c42abd976..e05ff5b4c 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -390,7 +390,7 @@ def parse_extension(header: str) -> list[ExtensionHeader]: def build_extension_item( - name: ExtensionName, parameters: list[ExtensionParameter] + name: ExtensionName, parameters: Sequence[ExtensionParameter] ) -> str: """ Build an extension definition. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index f10481b8b..ab7ddd33e 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,7 +2,7 @@ import http import logging -from typing import TYPE_CHECKING, Any, NewType, Optional, Union +from typing import TYPE_CHECKING, Any, NewType, Optional, Sequence, Union __all__ = [ @@ -62,7 +62,7 @@ # Private types -ExtensionHeader = tuple[ExtensionName, list[ExtensionParameter]] +ExtensionHeader = tuple[ExtensionName, Sequence[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" From 667e418ae474235c2bc26024df11c8e1d573d27a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 22:55:58 +0100 Subject: [PATCH 1535/1539] Review how-to guides, notably the patterns guide. Fix #1209. --- docs/conf.py | 1 + docs/howto/autoreload.rst | 17 +++++----- docs/howto/debugging.rst | 11 ++++--- docs/howto/django.rst | 57 ++++++++++++++++----------------- docs/howto/extensions.rst | 39 ++++++++++++++--------- docs/howto/index.rst | 9 ++---- docs/howto/patterns.rst | 64 +++++++++++++++++++++++--------------- docs/project/changelog.rst | 2 ++ 8 files changed, 113 insertions(+), 87 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index c6b9ac7d8..798d595db 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -84,6 +84,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), + "sesame": ("https://django-sesame.readthedocs.io/en/stable/", None), "werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None), } diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst index 4d1adee8b..dfa84ada3 100644 --- a/docs/howto/autoreload.rst +++ b/docs/howto/autoreload.rst @@ -1,15 +1,16 @@ Reload on code changes ====================== -When developing a websockets server, you may run it locally to test changes. -Unfortunately, whenever you want to try a new version of the code, you must -stop the server and restart it, which slows down your development process. +When developing a websockets server, you are likely to run it locally to test +changes. Unfortunately, whenever you want to try a new version of the code, you +must stop the server and restart it, which slows down your development process. -Web frameworks such as Django or Flask provide a development server that -reloads the application automatically when you make code changes. There is no -such functionality in websockets because it's designed only for production. +Web frameworks such as Django or Flask provide a development server that reloads +the application automatically when you make code changes. There is no equivalent +functionality in websockets because it's designed only for production. -However, you can achieve the same result easily. +However, you can achieve the same result easily with a third-party library and a +shell command. Install watchdog_ with the ``watchmedo`` shell utility: @@ -27,4 +28,4 @@ Run your server with ``watchmedo auto-restart``: python app.py This example assumes that the server is defined in a script called ``app.py`` -and exits cleanly when receiving the ``SIGTERM`` signal. Adapt it as necessary. +and exits cleanly when receiving the ``SIGTERM`` signal. Adapt as necessary. diff --git a/docs/howto/debugging.rst b/docs/howto/debugging.rst index fc5dcba56..546f70a6f 100644 --- a/docs/howto/debugging.rst +++ b/docs/howto/debugging.rst @@ -3,9 +3,10 @@ Enable debug logs websockets logs events with the :mod:`logging` module from the standard library. -It writes to the ``"websockets.server"`` and ``"websockets.client"`` loggers. +It emits logs in the ``"websockets.server"`` and ``"websockets.client"`` +loggers. -Enable logs at the ``DEBUG`` level to see exactly what websockets is doing. +You can enable logs at the ``DEBUG`` level to see exactly what websockets does. If logging isn't configured in your application:: @@ -24,10 +25,10 @@ If logging is already configured:: logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) -Refer to the :doc:`logging guide <../topics/logging>` for more details on +Refer to the :doc:`logging guide <../topics/logging>` for more information about logging in websockets. -In addition, you may enable asyncio's `debug mode`_ to see what asyncio is -doing. +You may also enable asyncio's `debug mode`_ to get warnings about classic +pitfalls. .. _debug mode: https://docs.python.org/3/library/asyncio-dev.html#asyncio-debug-mode diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 4fe2311cb..2fa0b89fd 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -101,6 +101,7 @@ call its APIs in the websockets server. Now here's how to implement authentication. .. literalinclude:: ../../example/django/authentication.py + :caption: authentication.py Let's unpack this code. @@ -113,23 +114,25 @@ your settings module. The connection handler reads the first message received from the client, which is expected to contain a django-sesame token. Then it authenticates the user -with ``get_user()``, the API for `authentication outside a view`_. If -authentication fails, it closes the connection and exits. +with :func:`~sesame.utils.get_user`, the API provided by django-sesame for +`authentication outside a view`_. .. _authentication outside a view: https://django-sesame.readthedocs.io/en/stable/howto.html#outside-a-view -When we call an API that makes a database query such as ``get_user()``, we -wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't -support asynchronous I/O. It would block the event loop if it didn't run in a -separate thread. +If authentication fails, it closes the connection and exits. + +When we call an API that makes a database query such as +:func:`~sesame.utils.get_user`, we wrap the call in :func:`~asyncio.to_thread`. +Indeed, the Django ORM doesn't support asynchronous I/O. It would block the +event loop if it didn't run in a separate thread. Finally, we start a server with :func:`~websockets.asyncio.server.serve`. We're ready to test! -Save this code to a file called ``authentication.py``, make sure the -``DJANGO_SETTINGS_MODULE`` environment variable is set properly, and start the -websockets server: +Download :download:`authentication.py <../../example/django/authentication.py>`, +make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set properly, +and start the websockets server: .. code-block:: console @@ -169,7 +172,7 @@ following code in the JavaScript console of the browser: websocket.onmessage = (event) => console.log(event.data); If you don't want to import your entire Django project into the websockets -server, you can build a separate Django project with ``django.contrib.auth``, +server, you can create a simpler Django project with ``django.contrib.auth``, ``django-sesame``, a suitable ``User`` model, and a subset of the settings of the main project. @@ -184,11 +187,11 @@ action was made. This may be used for showing notifications to other users. Many use cases for WebSocket with Django follow a similar pattern. -Set up event bus -................ +Set up event stream +................... -We need a event bus to enable communications between Django and websockets. -Both sides connect permanently to the bus. Then Django writes events and +We need an event stream to enable communications between Django and websockets. +Both sides connect permanently to the stream. Then Django writes events and websockets reads them. For the sake of simplicity, we'll rely on `Redis Pub/Sub`_. @@ -219,14 +222,15 @@ change ``get_redis_connection("default")`` in the code below to the same name. Publish events .............. -Now let's write events to the bus. +Now let's write events to the stream. Add the following code to a module that is imported when your Django project -starts. Typically, you would put it in a ``signals.py`` module, which you -would import in the ``AppConfig.ready()`` method of one of your apps: +starts. Typically, you would put it in a :download:`signals.py +<../../example/django/signals.py>` module, which you would import in the +``AppConfig.ready()`` method of one of your apps: .. literalinclude:: ../../example/django/signals.py - + :caption: signals.py This code runs every time the admin saves a ``LogEntry`` object to keep track of a change. It extracts interesting data, serializes it to JSON, and writes an event to Redis. @@ -256,13 +260,13 @@ We need to add several features: * Keep track of connected clients so we can broadcast messages. * Tell which content types the user has permission to view or to change. -* Connect to the message bus and read events. +* Connect to the message stream and read events. * Broadcast these events to users who have corresponding permissions. Here's a complete implementation. .. literalinclude:: ../../example/django/notifications.py - + :caption: notifications.py Since the ``get_content_types()`` function makes a database query, it is wrapped inside :func:`asyncio.to_thread()`. It runs once when each WebSocket connection is open; then its result is cached for the lifetime of the @@ -273,13 +277,10 @@ The connection handler merely registers the connection in a global variable, associated to the list of content types for which events should be sent to that connection, and waits until the client disconnects. -The ``process_events()`` function reads events from Redis and broadcasts them -to all connections that should receive them. We don't care much if a sending a -notification fails — this happens when a connection drops between the moment -we iterate on connections and the moment the corresponding message is sent — -so we start a task with for each message and forget about it. Also, this means -we're immediately ready to process the next event, even if it takes time to -send a message to a slow client. +The ``process_events()`` function reads events from Redis and broadcasts them to +all connections that should receive them. We don't care much if a sending a +notification fails. This happens when a connection drops between the moment we +iterate on connections and the moment the corresponding message is sent. Since Redis can publish a message to multiple subscribers, multiple instances of this server can safely run in parallel. @@ -290,4 +291,4 @@ Does it scale? In theory, given enough servers, this design can scale to a hundred million clients, since Redis can handle ten thousand servers and each server can handle ten thousand clients. In practice, you would need a more scalable -message bus before reaching that scale, due to the volume of messages. +message stream before reaching that scale, due to the volume of messages. diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index c4e9da626..2f73e2f87 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -1,30 +1,39 @@ Write an extension ================== -.. currentmodule:: websockets.extensions +.. currentmodule:: websockets During the opening handshake, WebSocket clients and servers negotiate which -extensions_ will be used with which parameters. Then each frame is processed -by extensions before being sent or after being received. +extensions_ will be used and with which parameters. .. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 -As a consequence, writing an extension requires implementing several classes: +Then, each frame is processed before being sent and after being received +according to the extensions that were negotiated. -* Extension Factory: it negotiates parameters and instantiates the extension. +Writing an extension requires implementing at least two classes, an extension +factory and an extension. They inherit from base classes provided by websockets. - Clients and servers require separate extension factories with distinct APIs. +Extension factory +----------------- - Extension factories are the public API of an extension. +An extension factory negotiates parameters and instantiates the extension. -* Extension: it decodes incoming frames and encodes outgoing frames. +Clients and servers require separate extension factories with distinct APIs. +Base classes are :class:`~extensions.ClientExtensionFactory` and +:class:`~extensions.ServerExtensionFactory`. - If the extension is symmetrical, clients and servers can use the same - class. +Extension factories are the public API of an extension. Extensions are enabled +with the ``extensions`` parameter of :func:`~asyncio.client.connect` or +:func:`~asyncio.server.serve`. - Extensions are initialized by extension factories, so they don't need to be - part of the public API of an extension. +Extension +--------- -websockets provides base classes for extension factories and extensions. -See :class:`ClientExtensionFactory`, :class:`ServerExtensionFactory`, -and :class:`Extension` for details. +An extension decodes incoming frames and encodes outgoing frames. + +If the extension is symmetrical, clients and servers can use the same class. The +base class is :class:`~extensions.Extension`. + +Since extensions are initialized by extension factories, they don't need to be +part of the public API of an extension. diff --git a/docs/howto/index.rst b/docs/howto/index.rst index ffded9ff0..12b38ed06 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -14,22 +14,19 @@ Configure websockets securely in production. encryption -If you're stuck, perhaps you'll find the answer here. +These guides will help you design and build your application. .. toctree:: + :maxdepth: 2 patterns - -This guide will help you integrate websockets into a broader system. - -.. toctree:: - django Upgrading from the legacy :mod:`asyncio` implementation to the new one? Read this. .. toctree:: + :maxdepth: 2 upgrade diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst index bfb78b6ca..e0602b6f8 100644 --- a/docs/howto/patterns.rst +++ b/docs/howto/patterns.rst @@ -1,46 +1,52 @@ -Patterns -======== +Design a WebSocket application +============================== .. currentmodule:: websockets -Here are typical patterns for processing messages in a WebSocket server or -client. You will certainly implement some of them in your application. +WebSocket server or client applications follow common patterns. This guide +describes patterns that you're likely to implement in your application. -This page gives examples of connection handlers for a server. However, they're -also applicable to a client, simply by assuming that ``websocket`` is a -connection created with :func:`~asyncio.client.connect`. +All examples are connection handlers for a server. However, they would also +apply to a client, assuming that ``websocket`` is a connection created with +:func:`~asyncio.client.connect`. -WebSocket connections are long-lived. You will usually write a loop to process -several messages during the lifetime of a connection. +.. admonition:: WebSocket connections are long-lived. + :class: tip -Consumer --------- + You need a loop to process several messages during the lifetime of a + connection. + +Consumer pattern +---------------- To receive messages from the WebSocket connection:: async def consumer_handler(websocket): async for message in websocket: - await consumer(message) + await consume(message) -In this example, ``consumer()`` is a coroutine implementing your business -logic for processing a message received on the WebSocket connection. Each -message may be :class:`str` or :class:`bytes`. +In this example, ``consume()`` is a coroutine implementing your business logic +for processing a message received on the WebSocket connection. Iteration terminates when the client disconnects. -Producer --------- +Producer pattern +---------------- To send messages to the WebSocket connection:: + from websockets.exceptions import ConnectionClosed + async def producer_handler(websocket): - while True: - message = await producer() - await websocket.send(message) + try: + while True: + message = await produce() + await websocket.send(message) + except ConnectionClosed: + break -In this example, ``producer()`` is a coroutine implementing your business -logic for generating the next message to send on the WebSocket connection. -Each message must be :class:`str` or :class:`bytes`. +In this example, ``produce()`` is a coroutine implementing your business logic +for generating the next message to send on the WebSocket connection. Iteration terminates when the client disconnects because :meth:`~asyncio.server.ServerConnection.send` raises a @@ -51,8 +57,12 @@ Consumer and producer --------------------- You can receive and send messages on the same WebSocket connection by -combining the consumer and producer patterns. This requires running two tasks -in parallel:: +combining the consumer and producer patterns. + +This requires running two tasks in parallel. The simplest option offered by +:mod:`asyncio` is:: + + import asyncio async def handler(websocket): await asyncio.gather( @@ -99,6 +109,10 @@ connect and unregister them when they disconnect:: This example maintains the set of connected clients in memory. This works as long as you run a single process. It doesn't scale to multiple processes. +If you just need the set of connected clients, as in this example, use the +:attr:`~asyncio.server.Server.connections` property of the server. This pattern +is needed only when recording additional information about each client. + Publish–subscribe ----------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 31a537af5..7dc79cf7d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -68,6 +68,8 @@ New features Improvements ............ +* Refreshed several how-to guides and topic guides. + * Added type overloads for the ``decode`` argument of :meth:`~asyncio.connection.Connection.recv`. This may simplify static typing. From 963f13f0bc9ed72afc38fdc0f373b2549744d570 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2025 08:41:55 +0100 Subject: [PATCH 1536/1539] Restore display of close code and reason. Fix #1591. --- src/websockets/__main__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 8647481d0..fbc8b0568 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -12,6 +12,7 @@ except ImportError: # Windows has no `readline` normally pass +from .frames import Close from .sync.client import ClientConnection, connect from .version import version as websockets_version @@ -150,7 +151,10 @@ def main() -> None: except (KeyboardInterrupt, EOFError): # ^C, ^D stop.set() websocket.close() - print_over_input("Connection closed.") + + assert websocket.close_code is not None and websocket.close_reason is not None + close_status = Close(websocket.close_code, websocket.close_reason) + print_over_input(f"Connection closed: {close_status}.") thread.join() From bc3fd2946a67276faff6e1a1043e6f1f1d0ad69f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2025 09:22:55 +0100 Subject: [PATCH 1537/1539] Simplify enabling VT100 mode on Windows. Arguably, this relies on a bug, but: * The current implementation was vulnerable to platform-specific bugs, as highlighted in discussions of VT100 on Windows in the Python bug tracker, while the new one is clearly not going to cause harm. * If the bug in Python is fixed, hopefully we will gain a proper way to enable VT100. --- src/websockets/__main__.py | 49 ++++---------------------------------- 1 file changed, 5 insertions(+), 44 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index fbc8b0568..c356510c5 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -9,7 +9,7 @@ try: import readline # noqa: F401 -except ImportError: # Windows has no `readline` normally +except ImportError: # readline isn't available on all platforms pass from .frames import Close @@ -17,38 +17,6 @@ from .version import version as websockets_version -if sys.platform == "win32": - - def win_enable_vt100() -> None: - """ - Enable VT-100 for console output on Windows. - - See also https://github.com/python/cpython/issues/73245. - - """ - import ctypes - - STD_OUTPUT_HANDLE = ctypes.c_uint(-11) - INVALID_HANDLE_VALUE = ctypes.c_uint(-1) - ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004 - - handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE) - if handle == INVALID_HANDLE_VALUE: - raise RuntimeError("unable to obtain stdout handle") - - cur_mode = ctypes.c_uint() - if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0: - raise RuntimeError("unable to query current console mode") - - # ctypes ints lack support for the required bit-OR operation. - # Temporarily convert to Py int, do the OR and convert back. - py_int_mode = int.from_bytes(cur_mode, sys.byteorder) - new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING) - - if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0: - raise RuntimeError("unable to set console mode") - - def print_during_input(string: str) -> None: sys.stdout.write( # Save cursor position @@ -116,17 +84,10 @@ def main() -> None: if args.uri is None: parser.error("the following arguments are required: ") - # If we're on Windows, enable VT100 terminal support. - if sys.platform == "win32": - try: - win_enable_vt100() - except RuntimeError as exc: - sys.stderr.write( - f"Unable to set terminal to VT100 mode. This is only " - f"supported since Win10 anniversary update. Expect " - f"weird symbols on the terminal.\nError: {exc}\n" - ) - sys.stderr.flush() + # Enable VT100 to support ANSI escape codes in Command Prompt on Windows. + # See https://github.com/python/cpython/issues/74261 for why this works. + if sys.platform == "win32": # pragma: no cover + os.system("") try: websocket = connect(args.uri) From a1ba01db142459db0ea6f7659b3a5f4962749fa6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2025 11:32:49 +0100 Subject: [PATCH 1538/1539] Rewrite interactive client (again) without threads. Fix #1592. --- src/websockets/__main__.py | 138 ++++++++++++++++++++++++------------- 1 file changed, 91 insertions(+), 47 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index c356510c5..5043e1e29 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -1,19 +1,16 @@ from __future__ import annotations import argparse +import asyncio import os -import signal import sys -import threading - - -try: - import readline # noqa: F401 -except ImportError: # readline isn't available on all platforms - pass +from typing import Generator +from .asyncio.client import ClientConnection, connect +from .asyncio.messages import SimpleQueue +from .exceptions import ConnectionClosed from .frames import Close -from .sync.client import ClientConnection, connect +from .streams import StreamReader from .version import version as websockets_version @@ -49,24 +46,94 @@ def print_over_input(string: str) -> None: sys.stdout.flush() -def print_incoming_messages(websocket: ClientConnection, stop: threading.Event) -> None: - for message in websocket: +class ReadLines(asyncio.Protocol): + def __init__(self) -> None: + self.reader = StreamReader() + self.messages: SimpleQueue[str] = SimpleQueue() + + def parse(self) -> Generator[None, None, None]: + while True: + sys.stdout.write("> ") + sys.stdout.flush() + line = yield from self.reader.read_line(sys.maxsize) + self.messages.put(line.decode().rstrip("\r\n")) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.parser = self.parse() + next(self.parser) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + next(self.parser) + + def eof_received(self) -> None: + self.reader.feed_eof() + # next(self.parser) isn't useful and would raise EOFError. + + def connection_lost(self, exc: Exception | None) -> None: + self.reader.discard() + self.messages.abort() + + +async def print_incoming_messages(websocket: ClientConnection) -> None: + async for message in websocket: if isinstance(message, str): print_during_input("< " + message) else: print_during_input("< (binary) " + message.hex()) - if not stop.is_set(): - # When the server closes the connection, raise KeyboardInterrupt - # in the main thread to exit the program. - if sys.platform == "win32": - ctrl_c = signal.CTRL_C_EVENT - else: - ctrl_c = signal.SIGINT - os.kill(os.getpid(), ctrl_c) + + +async def send_outgoing_messages( + websocket: ClientConnection, + messages: SimpleQueue[str], +) -> None: + while True: + try: + message = await messages.get() + except EOFError: + break + try: + await websocket.send(message) + except ConnectionClosed: + break + + +async def interactive_client(uri: str) -> None: + try: + websocket = await connect(uri) + except Exception as exc: + print(f"Failed to connect to {uri}: {exc}.") + sys.exit(1) + else: + print(f"Connected to {uri}.") + + loop = asyncio.get_running_loop() + transport, protocol = await loop.connect_read_pipe(ReadLines, sys.stdin) + incoming = asyncio.create_task( + print_incoming_messages(websocket), + ) + outgoing = asyncio.create_task( + send_outgoing_messages(websocket, protocol.messages), + ) + try: + await asyncio.wait( + [incoming, outgoing], + return_when=asyncio.FIRST_COMPLETED, + ) + except (KeyboardInterrupt, EOFError): # ^C, ^D + pass + finally: + incoming.cancel() + outgoing.cancel() + transport.close() + + await websocket.close() + assert websocket.close_code is not None and websocket.close_reason is not None + close_status = Close(websocket.close_code, websocket.close_reason) + print_over_input(f"Connection closed: {close_status}.") def main() -> None: - # Parse command line arguments. parser = argparse.ArgumentParser( prog="python -m websockets", description="Interactive WebSocket client.", @@ -90,34 +157,11 @@ def main() -> None: os.system("") try: - websocket = connect(args.uri) - except Exception as exc: - print(f"Failed to connect to {args.uri}: {exc}.") - sys.exit(1) - else: - print(f"Connected to {args.uri}.") - - stop = threading.Event() - - # Start the thread that reads messages from the connection. - thread = threading.Thread(target=print_incoming_messages, args=(websocket, stop)) - thread.start() - - # Read from stdin in the main thread in order to receive signals. - try: - while True: - # Since there's no size limit, put_nowait is identical to put. - message = input("> ") - websocket.send(message) - except (KeyboardInterrupt, EOFError): # ^C, ^D - stop.set() - websocket.close() - - assert websocket.close_code is not None and websocket.close_reason is not None - close_status = Close(websocket.close_code, websocket.close_reason) - print_over_input(f"Connection closed: {close_status}.") + import readline # noqa: F401 + except ImportError: # readline isn't available on all platforms + pass - thread.join() + asyncio.run(interactive_client(args.uri)) if __name__ == "__main__": From 7ac73c645329055a3c352077b8055e6ed65fa46c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2025 11:47:18 +0100 Subject: [PATCH 1539/1539] Release version 15.0. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7dc79cf7d..287d2fe31 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 15.0 ---- -*In development* +*February 16, 2025* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 611e7d238..738f2cac1 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "15.0"